"""
|
|
Rate Limiting Middleware for SaaS robustness
|
|
Protects against abuse and ensures fair usage
|
|
"""
|
|
import time
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Optional
|
|
from fastapi import Request, HTTPException
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RateLimitConfig:
    """Configuration for rate limiting"""
    # Requests per window
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000

    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # File size limits (MB)
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Burst protection
    burst_limit: int = 10  # Max requests in 1 second

    # Whitelist IPs (no rate limiting)
    whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])


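# Illustrative sketch (not part of the original module): the defaults above can be
# tightened per deployment by passing a custom config to RateLimitManager (defined
# below). The specific values here are arbitrary assumptions.
#
#   strict_config = RateLimitConfig(
#       requests_per_minute=10,
#       translations_per_minute=2,
#       max_concurrent_translations=1,
#   )
#   manager = RateLimitManager(config=strict_config)

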
class TokenBucket:
    """Token bucket algorithm for rate limiting"""

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens per second
        self.tokens = capacity
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Try to consume tokens, return True if successful"""
        async with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def _refill(self):
        """Refill tokens based on time elapsed"""
        now = time.time()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now


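# Worked example (illustrative, based on _refill() above): with capacity=10 and
# refill_rate=10 tokens/second -- the values ClientRateLimiter below uses for its
# burst bucket when burst_limit=10 -- ten consume() calls in quick succession
# succeed, the eleventh returns False, and roughly 0.1 s later one token has been
# refilled, so a single retry succeeds again.
#
#   bucket = TokenBucket(capacity=10, refill_rate=10)
#   # await bucket.consume()  -> True for the first 10 calls, then False

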
class SlidingWindowCounter:
    """Sliding window counter for accurate rate limiting"""

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []
        self._lock = asyncio.Lock()

    async def is_allowed(self) -> bool:
        """Check if a new request is allowed"""
        async with self._lock:
            now = time.time()
            # Remove old requests outside the window
            self.requests = [ts for ts in self.requests if now - ts < self.window_seconds]

            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True
            return False

    @property
    def current_count(self) -> int:
        """Get current request count in window"""
        now = time.time()
        return len([ts for ts in self.requests if now - ts < self.window_seconds])


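# Note (illustrative): unlike a fixed-window counter, this keeps the raw request
# timestamps, so limits roll continuously. For example, SlidingWindowCounter(60, 3)
# allows three is_allowed() calls within any 60-second span; a fourth is rejected
# until the oldest of the three timestamps falls out of the window.

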
class ClientRateLimiter:
    """Per-client rate limiter with multiple windows"""

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # List of (timestamp, size_mb)
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if request is allowed, return (allowed, reason)"""
        # Check burst limit
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        # Check minute limit
        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        # Check hour limit
        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        # Check day limit
        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if translation request is allowed"""
        # Check concurrent limit. Note: self._lock is not reentrant, so this check
        # and the size accounting below use separate, sequential lock sections.
        async with self._lock:
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

        # Check translations per minute
        if not await self.translation_minute.is_allowed():
            return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

        # Check translations per hour
        if not await self.translation_hour.is_allowed():
            return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

        # Check total size per hour
        async with self._lock:
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)

            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."

            self.total_size_hour.append((now, file_size_mb))

        return True, ""

    async def start_translation(self):
        """Mark start of translation"""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark end of translation"""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Get current rate limit stats"""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }


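# Expected call pattern (illustrative sketch, not from the original module): an
# endpoint that runs a translation would pair start_translation()/end_translation()
# around the work so the concurrent counter is always released, e.g.:
#
#   allowed, reason = await limiter.check_translation(file_size_mb=upload_size_mb)
#   if not allowed:
#       ...  # return a 429 to the caller with `reason`
#   await limiter.start_translation()
#   try:
#       ...  # perform the translation
#   finally:
#       await limiter.end_translation()
#
# `limiter`, `upload_size_mb`, and the elided endpoint body are assumptions; only
# the method names come from ClientRateLimiter above.

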
class RateLimitManager:
    """Manages rate limiters for all clients"""

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Cleanup old clients every hour
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Extract client identifier from request"""
        # Try to get real IP from headers (for proxied requests)
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is whitelisted"""
        return client_id in self.config.whitelist_ips

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed, return (allowed, reason, client_id)"""
        client_id = self.get_client_id(request)
        self._total_requests += 1

        if self.is_whitelisted(client_id):
            return True, "", client_id

        client = self.clients[client_id]
        allowed, reason = await client.check_request()

        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if translation is allowed"""
        client_id = self.get_client_id(request)
        self._total_translations += 1

        if self.is_whitelisted(client_id):
            return True, ""

        client = self.clients[client_id]
        return await client.check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if translation is allowed for a specific client ID"""
        if self.is_whitelisted(client_id):
            return True

        client = self.clients[client_id]
        allowed, _ = await client.check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for a client"""
        client_id = self.get_client_id(request)
        client = self.clients[client_id]
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client"""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        client = self.clients[client_id]
        stats = client.get_stats()

        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics"""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }


class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware for rate limiting"""

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        # Skip rate limiting for health checks and static files
        if request.url.path in ["/health", "/", "/docs", "/openapi.json", "/redoc"]:
            return await call_next(request)

        if request.url.path.startswith("/static"):
            return await call_next(request)

        # Check rate limit
        allowed, reason, client_id = await self.manager.check_request(request)

        if not allowed:
            logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )

        # Add client info to request state for use in endpoints
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]

        return await call_next(request)


# Global rate limit manager
rate_limit_manager = RateLimitManager()
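

# Wiring sketch (illustrative, assumes a FastAPI `app` object defined elsewhere;
# Starlette passes the keyword argument through to RateLimitMiddleware.__init__):
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#   app.add_middleware(RateLimitMiddleware, rate_limit_manager=rate_limit_manager)
#
#   @app.get("/limits")
#   async def limits(request: Request) -> dict:
#       # client_id is set on request.state by RateLimitMiddleware.dispatch()
#       return rate_limit_manager.get_client_stats(request)
#
# The "/limits" route name is an assumption; everything else references names
# defined in this module.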