"""
Rate Limiting Middleware for SaaS robustness.

Protects against abuse and ensures fair usage by applying, per client:
a burst token bucket, per-minute/hour/day request windows, and
translation-specific limits (rate, concurrency and hourly data volume).
"""

import asyncio
import ipaddress
import logging
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Dict, Optional

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)


@dataclass
class RateLimitConfig:
    """Configuration for rate limiting"""

    # Requests per window
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000

    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # File size limits (MB)
    # NOTE(review): max_file_size_mb is not enforced in this module;
    # presumably the upload endpoint checks it — confirm.
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Burst protection
    burst_limit: int = 10  # Max requests in 1 second

    # Whitelist IPs (no rate limiting); compared by exact string against the
    # resolved client id.
    whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])

    # Direct peers whose X-Forwarded-For / X-Real-IP headers are believed.
    # Entries are IP addresses or CIDR networks. The special entry "*" trusts
    # every peer (legacy behaviour: lets any client spoof its identity and
    # claim a whitelisted IP). Defaults cover loopback and private networks,
    # where reverse proxies normally live. Read once when the manager is built.
    trusted_proxies: list = field(
        default_factory=lambda: [
            "127.0.0.0/8",
            "::1",
            "10.0.0.0/8",
            "172.16.0.0/12",
            "192.168.0.0/16",
            "fc00::/7",
        ]
    )


class TokenBucket:
    """Token bucket algorithm for rate limiting.

    Holds at most ``capacity`` tokens and refills continuously at
    ``refill_rate`` tokens per second.
    """

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens per second
        self.tokens = capacity
        # Monotonic clock: a wall-clock step backwards (NTP, manual change)
        # would otherwise make `elapsed` negative and drain the bucket.
        self.last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Try to consume ``tokens``; return True if successful."""
        async with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def _refill(self):
        """Refill tokens based on time elapsed since the last refill."""
        now = time.monotonic()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now


class SlidingWindowCounter:
    """Sliding window counter for accurate rate limiting.

    Stores the monotonic timestamp of every accepted event, oldest first,
    and allows at most ``max_requests`` events within ``window_seconds``.
    """

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        # deque so expired timestamps are dropped from the left in O(1)
        # instead of rebuilding the whole list on every request.
        self.requests: deque = deque()
        self._lock = asyncio.Lock()

    def _prune(self, now: float) -> None:
        """Drop timestamps that have slid out of the window."""
        cutoff = now - self.window_seconds
        while self.requests and self.requests[0] <= cutoff:
            self.requests.popleft()

    def has_capacity(self, now: Optional[float] = None) -> bool:
        """Return True if one more event fits in the window (does not record it)."""
        now = time.monotonic() if now is None else now
        self._prune(now)
        return len(self.requests) < self.max_requests

    def record(self, now: Optional[float] = None) -> None:
        """Record one event at ``now`` (defaults to the current monotonic time)."""
        self.requests.append(time.monotonic() if now is None else now)

    async def is_allowed(self) -> bool:
        """Check whether a new request is allowed and record it if so."""
        async with self._lock:
            now = time.monotonic()
            if self.has_capacity(now):
                self.record(now)
                return True
            return False

    @property
    def current_count(self) -> int:
        """Number of recorded events still inside the window."""
        now = time.monotonic()
        return sum(1 for ts in self.requests if now - ts < self.window_seconds)


class ClientRateLimiter:
    """Per-client rate limiter with multiple windows"""

    # Length of the data-volume window, in seconds.
    _SIZE_WINDOW_SECONDS = 3600

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # List of (monotonic timestamp, size_mb)
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if a request is allowed; return ``(allowed, reason)``.

        The burst bucket is charged even for requests that a later window
        rejects, so a client hammering the API is throttled immediately.
        The minute/hour/day windows are only charged when all of them have
        room.
        """
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        now = time.monotonic()
        windows = (
            (self.minute_counter,
             f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."),
            (self.hour_counter,
             f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."),
            (self.day_counter,
             f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."),
        )
        for counter, message in windows:
            if not counter.has_capacity(now):
                return False, message
        # No await between the checks above and these records, so no other
        # coroutine can interleave: the check-then-commit is atomic.
        for counter, _ in windows:
            counter.record(now)
        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation request is allowed; return ``(allowed, reason)``.

        Verifies concurrency, per-minute and per-hour translation counts, and
        the hourly data volume. Quota is recorded only when every check passes,
        so a rejected translation does not consume any window.
        """
        # Single critical section: asyncio.Lock is not re-entrant, so the lock
        # must never be acquired twice in this method.
        async with self._lock:
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

            now = time.monotonic()
            if not self.translation_minute.has_capacity(now):
                return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."
            if not self.translation_hour.has_capacity(now):
                return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

            self.total_size_hour = [
                (ts, size) for ts, size in self.total_size_hour
                if now - ts < self._SIZE_WINDOW_SECONDS
            ]
            total_size = sum(size for _, size in self.total_size_hour)
            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."

            # Every check passed: commit the usage.
            self.translation_minute.record(now)
            self.translation_hour.record(now)
            self.total_size_hour.append((now, file_size_mb))
            return True, ""

    async def start_translation(self):
        """Mark start of translation"""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark end of translation (never drops below zero)"""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def is_idle(self) -> bool:
        """Return True when this limiter holds no live state worth keeping.

        An idle limiter has no translation in flight and nothing recorded in
        any window, so discarding it cannot reset a quota still in effect.
        """
        if self.concurrent_translations:
            return False
        counters = (
            self.day_counter,
            self.hour_counter,
            self.minute_counter,
            self.translation_hour,
            self.translation_minute,
        )
        if any(counter.current_count for counter in counters):
            return False
        now = time.monotonic()
        return not any(now - ts < self._SIZE_WINDOW_SECONDS for ts, _ in self.total_size_hour)

    def get_stats(self) -> dict:
        """Get current rate limit stats"""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }


class RateLimitManager:
    """Manages rate limiters for all clients"""

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Evict idle clients at most once an hour
        self._last_cleanup = time.monotonic()
        self._total_requests = 0
        self._total_translations = 0
        # Parsed once; later edits to config.trusted_proxies are not picked up.
        self._trust_all_proxies = "*" in self.config.trusted_proxies
        self._trusted_networks = self._parse_networks(self.config.trusted_proxies)

    @staticmethod
    def _parse_networks(entries) -> list:
        """Parse IP/CIDR strings into networks, skipping "*" and invalid entries."""
        networks = []
        for entry in entries:
            if entry == "*":
                continue
            try:
                networks.append(ipaddress.ip_network(entry, strict=False))
            except ValueError:
                logger.warning("Ignoring invalid trusted proxy entry: %r", entry)
        return networks

    def _is_trusted_proxy(self, host: str) -> bool:
        """Return True if ``host`` is an address allowed to set proxy headers."""
        if self._trust_all_proxies:
            return True
        try:
            addr = ipaddress.ip_address(host)
        except ValueError:
            return False  # not an IP literal (e.g. a test-client name)
        # Dual-stack sockets report IPv4 peers as ::ffff:a.b.c.d.
        if addr.version == 6 and addr.ipv4_mapped is not None:
            addr = addr.ipv4_mapped
        return any(addr in network for network in self._trusted_networks)

    def _maybe_cleanup(self) -> None:
        """Evict idle client limiters once per cleanup interval to bound memory."""
        now = time.monotonic()
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        stale = [cid for cid, limiter in self.clients.items() if limiter.is_idle()]
        for cid in stale:
            del self.clients[cid]
        if stale:
            logger.debug("Evicted %d idle rate-limit clients", len(stale))

    def get_client_id(self, request: Request) -> str:
        """Extract client identifier from request.

        Proxy headers are client-controlled unless written by our own proxy,
        so they are only honoured when the direct peer is a trusted proxy.
        Otherwise anyone could pick a fresh identity per request or claim a
        whitelisted address such as 127.0.0.1.
        """
        peer = request.client.host if request.client else None
        if peer and self._is_trusted_proxy(peer):
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
                # Proxies append on the right, so the rightmost untrusted hop is
                # the first address our own infrastructure actually observed.
                for hop in reversed(hops):
                    if not self._is_trusted_proxy(hop):
                        return hop
                if hops:
                    return hops[0]  # every hop trusted (e.g. "*" mode)
            real_ip = request.headers.get("X-Real-IP")
            if real_ip and real_ip.strip():
                return real_ip.strip()
        # Fall back to direct client IP
        if peer:
            return peer
        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is whitelisted"""
        return client_id in self.config.whitelist_ips

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed, return (allowed, reason, client_id)"""
        self._maybe_cleanup()
        client_id = self.get_client_id(request)
        self._total_requests += 1
        if self.is_whitelisted(client_id):
            return True, "", client_id
        client = self.clients[client_id]
        allowed, reason = await client.check_request()
        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if translation is allowed"""
        self._maybe_cleanup()
        client_id = self.get_client_id(request)
        self._total_translations += 1
        if self.is_whitelisted(client_id):
            return True, ""
        client = self.clients[client_id]
        return await client.check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if translation is allowed for a specific client ID"""
        self._maybe_cleanup()
        if self.is_whitelisted(client_id):
            return True
        client = self.clients[client_id]
        allowed, _ = await client.check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for a client (does not register unknown clients)"""
        client_id = self.get_client_id(request)
        client = self.clients.get(client_id)
        if client is None:
            # Throwaway limiter: report zeros without growing self.clients.
            client = ClientRateLimiter(self.config)
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats(),
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client"""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}
        client = self.clients[client_id]
        stats = client.get_stats()
        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id),
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics"""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations,
            },
        }


class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware for rate limiting"""

    # Paths served without any rate limiting (health checks, docs, landing page).
    EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json", "/redoc"})

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        # Skip rate limiting for health checks and static files
        if path in self.EXEMPT_PATHS or path.startswith("/static"):
            return await call_next(request)

        # Check rate limit
        allowed, reason, client_id = await self.manager.check_request(request)
        if not allowed:
            logger.warning("Rate limit exceeded for %s: %s", client_id, reason)
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60,
                },
                headers={"Retry-After": "60"},
            )

        # Add client info to request state for use in endpoints
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]
        return await call_next(request)


# Global rate limit manager
rate_limit_manager = RateLimitManager()