"""
Rate Limiting Middleware for SaaS robustness.

Protects against abuse and ensures fair usage.

When REDIS_URL is set, uses Redis for sliding-window counters (shared across
instances). Otherwise falls back to in-memory per-process limits.
"""

import asyncio
import hashlib
import logging
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)


@dataclass
class RateLimitConfig:
    """Configuration for rate limiting."""

    # Requests per window
    requests_per_minute: int = 120
    requests_per_hour: int = 1000
    requests_per_day: int = 5000

    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # File size limits (MB)
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Burst protection
    burst_limit: int = 30  # Max requests in 1 second

    # Whitelist IPs (no rate limiting)
    whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])


KEY_PREFIX_IP = "rate_limit:ip"
KEY_PREFIX_SIZE_HOUR = "rate_limit:size_hour"


async def _redis_sliding_window_is_allowed(
    redis_client,
    client_id: str,
    window_key: str,
    window_seconds: int,
    max_requests: int,
) -> bool:
    """Redis sliding window check.

    Key layout: ``rate_limit:ip:{client_id}:{window_key}`` (a sorted set whose
    scores are request timestamps).

    Args:
        redis_client: Async Redis client exposing ``pipeline()``.
        client_id: Identifier the counter is kept for.
        window_key: Name of the window (e.g. ``"minute"``), part of the key.
        window_seconds: Length of the sliding window in seconds.
        max_requests: Maximum number of entries allowed inside the window.

    Returns:
        True if the request is under the limit (and has been recorded),
        False if the window is already full.
    """
    now = time.time()
    key = f"{KEY_PREFIX_IP}:{client_id}:{window_key}"

    # Drop entries that fell out of the window, then count what is left.
    prune_and_count = redis_client.pipeline()
    prune_and_count.zremrangebyscore(key, "-inf", now - window_seconds)
    prune_and_count.zcard(key)
    results = await prune_and_count.execute()
    if results[1] >= max_requests:
        return False

    # The member must be unique per request: sorted-set members are
    # de-duplicated, so using the bare timestamp would collapse two requests
    # that share the same time.time() value into a single entry.
    member = f"{now}:{uuid.uuid4().hex}"
    record = redis_client.pipeline()
    record.zadd(key, {member: now})
    # Keep the key slightly longer than the window so it expires on its own
    # once the client goes quiet.
    record.expire(key, window_seconds + 60)
    await record.execute()
    # NOTE(review): count and insert are two round-trips, so concurrent
    # requests can still slightly overshoot the limit.
    return True


async def _check_request_redis(
    redis_client, client_id: str, config: RateLimitConfig
) -> tuple[bool, str]:
    """Check request limits using Redis. Returns (allowed, reason)."""
    # (window key, window length in seconds, limit, rejection message),
    # evaluated in order; the first full window rejects the request.
    windows = (
        (
            "minute",
            60,
            config.requests_per_minute,
            f"Rate limit exceeded. Max {config.requests_per_minute} requests per minute.",
        ),
        (
            "hour",
            3600,
            config.requests_per_hour,
            f"Hourly limit exceeded. Max {config.requests_per_hour} requests per hour.",
        ),
        (
            "day",
            86400,
            config.requests_per_day,
            f"Daily limit exceeded. Max {config.requests_per_day} requests per day.",
        ),
    )
    for window_key, seconds, limit, message in windows:
        if not await _redis_sliding_window_is_allowed(
            redis_client, client_id, window_key, seconds, limit
        ):
            return False, message
    return True, ""


async def _check_translation_redis(
    redis_client,
    client_id: str,
    config: RateLimitConfig,
    file_size_mb: float = 0,
) -> tuple[bool, str]:
    """Check translation limits using Redis. Returns (allowed, reason).

    Checks the per-minute and per-hour translation windows, then the hourly
    total upload size (MB) for the client.
    """
    if not await _redis_sliding_window_is_allowed(
        redis_client, client_id, "trans_minute", 60, config.translations_per_minute
    ):
        return (
            False,
            f"Translation rate limit exceeded. Max {config.translations_per_minute} translations per minute.",
        )
    if not await _redis_sliding_window_is_allowed(
        redis_client, client_id, "trans_hour", 3600, config.translations_per_hour
    ):
        return (
            False,
            f"Hourly translation limit exceeded. Max {config.translations_per_hour} translations per hour.",
        )

    # Hourly total size (MB) per client, bucketed by wall-clock hour.
    hour_bucket = int(time.time() // 3600)
    size_key = f"{KEY_PREFIX_SIZE_HOUR}:{client_id}:{hour_bucket}"
    try:
        used_mb = float(await redis_client.get(size_key) or 0)
        if used_mb + file_size_mb > config.max_total_size_per_hour_mb:
            return (
                False,
                f"Hourly data limit exceeded. Max {config.max_total_size_per_hour_mb}MB per hour.",
            )
        # INCRBYFLOAT applies the increment atomically on the server, so
        # concurrent uploads no longer overwrite each other's totals the way
        # a read-then-SET does.
        record = redis_client.pipeline()
        record.incrbyfloat(size_key, file_size_mb)
        record.expire(size_key, 7200)
        await record.execute()
    except Exception as e:
        # Deliberate best effort: a failing size check must not block uploads.
        logger.warning("Redis size-hour check failed: %s; allowing request", e)
    return True, ""


class TokenBucket:
    """Token bucket algorithm for rate limiting."""

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens per second
        self.tokens = capacity
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Try to consume tokens, return True if successful."""
        async with self._lock:
            self._refill()
            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True

    def _refill(self):
        """Refill tokens based on time elapsed, capped at ``capacity``."""
        now = time.time()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now


class SlidingWindowCounter:
    """Sliding window counter for accurate rate limiting."""

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []  # timestamps of accepted requests
        self._lock = asyncio.Lock()

    def _within_window(self, now: float) -> list:
        """Return the recorded timestamps that are still inside the window."""
        return [ts for ts in self.requests if now - ts < self.window_seconds]

    async def is_allowed(self) -> bool:
        """Check if a new request is allowed; record it when it is."""
        async with self._lock:
            now = time.time()
            # Remove old requests outside the window
            self.requests = self._within_window(now)
            if len(self.requests) >= self.max_requests:
                return False
            self.requests.append(now)
            return True

    @property
    def current_count(self) -> int:
        """Get current request count in window (does not mutate state)."""
        return len(self._within_window(time.time()))


class ClientRateLimiter:
    """Per-client rate limiter with multiple windows."""

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        # Bucket holds burst_limit tokens and refills burst_limit per second.
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # List of (timestamp, size_mb)
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if request is allowed, return (allowed, reason)."""
        # Check burst limit
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        # Check minute limit
        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        # Check hour limit
        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        # Check day limit
        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if translation request is allowed, return (allowed, reason)."""
        # asyncio.Lock is not reentrant, so the lock is released before the
        # window counters (which take their own locks) are consulted and
        # re-acquired for the size bookkeeping below.
        async with self._lock:
            # Check concurrent limit
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

        # Check translation per minute
        if not await self.translation_minute.is_allowed():
            return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

        # Check translation per hour
        if not await self.translation_hour.is_allowed():
            return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

        # Check total size per hour
        async with self._lock:
            now = time.time()
            self.total_size_hour = [
                (ts, size) for ts, size in self.total_size_hour if now - ts < 3600
            ]
            total_size = sum(size for _, size in self.total_size_hour)
            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."
            self.total_size_hour.append((now, file_size_mb))

        return True, ""

    async def start_translation(self):
        """Mark start of translation."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark end of translation (never drops below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def is_idle(self) -> bool:
        """Return True when no window holds activity and nothing is running.

        Used by the manager to decide which per-client limiters can be
        discarded without losing any accounting.
        """
        if self.concurrent_translations:
            return False
        counters = (
            self.minute_counter,
            self.hour_counter,
            self.day_counter,
            self.translation_minute,
            self.translation_hour,
        )
        if any(counter.current_count for counter in counters):
            return False
        now = time.time()
        return not any(now - ts < 3600 for ts, _ in self.total_size_hour)

    def get_stats(self) -> dict:
        """Get current rate limit stats."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }


class RateLimitManager:
    """Manages rate limiters for all clients."""

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(
            lambda: ClientRateLimiter(self.config)
        )
        self._cleanup_interval = 3600  # Cleanup old clients every hour
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Extract the client's network identifier (IP) from the request."""
        # Try to get real IP from headers (for proxied requests).
        # NOTE(review): these headers are client-controlled unless a trusted
        # proxy overwrites them; a caller can spoof them to dodge per-IP
        # limits or to hit a whitelisted address. Confirm the deployment
        # strips/sets them at the edge.
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:  # ignore malformed values such as "," or " "
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip and real_ip.strip():
            return real_ip.strip()

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is whitelisted."""
        return client_id in self.config.whitelist_ips

    def _identify(self, request: Request) -> tuple[str, str]:
        """Return ``(ip, client_id)`` for a request.

        ``client_id`` is the key all counters are kept under: an id already
        placed on ``request.state.client_id``, else a hash of the bearer
        token, else the IP. Preferring a user identity avoids shared limits
        for users behind the same proxy.
        """
        ip = self.get_client_id(request)

        # Try to extract from already-set state (auth middleware ran first)
        user_id = getattr(getattr(request, "state", None), "client_id", None)
        if not user_id:
            # Use a hash of the token as user identifier (no decoding needed)
            auth_header = request.headers.get("Authorization", "")
            if auth_header.startswith("Bearer "):
                token = auth_header[7:]
                user_id = f"tok:{hashlib.sha256(token.encode()).hexdigest()[:16]}"

        return ip, user_id or ip

    def _cleanup_idle_clients(self) -> None:
        """Drop in-memory limiters with no recent activity, at most hourly.

        Without this the ``clients`` mapping grows by one entry for every
        distinct client id ever seen.
        """
        now = time.time()
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        # Snapshot the ids first; the dict must not change size while iterated.
        idle_ids = [cid for cid, limiter in self.clients.items() if limiter.is_idle()]
        for cid in idle_ids:
            del self.clients[cid]

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed, return (allowed, reason, client_id)."""
        ip, client_id = self._identify(request)

        self._total_requests += 1
        self._cleanup_idle_clients()

        # The whitelist holds IPs, so test the network address as well as the
        # resolved id; otherwise a whitelisted host that sends a bearer token
        # would still be limited.
        if self.is_whitelisted(ip) or self.is_whitelisted(client_id):
            return True, "", client_id

        try:
            from core.redis import get_async_redis

            redis_client = get_async_redis()
            if redis_client:
                allowed, reason = await _check_request_redis(
                    redis_client, client_id, self.config
                )
                return allowed, reason, client_id
        except Exception as e:
            logger.warning("Redis rate limit check failed, using in-memory: %s", e)

        allowed, reason = await self.clients[client_id].check_request()
        return allowed, reason, client_id

    async def _check_translation_for(
        self, client_id: str, file_size_mb: float
    ) -> tuple[bool, str]:
        """Run translation limits for ``client_id``: Redis first, else in-memory."""
        try:
            from core.redis import get_async_redis

            redis_client = get_async_redis()
            if redis_client:
                return await _check_translation_redis(
                    redis_client, client_id, self.config, file_size_mb
                )
        except Exception as e:
            logger.warning("Redis translation limit check failed, using in-memory: %s", e)

        return await self.clients[client_id].check_translation(file_size_mb)

    async def check_translation(
        self, request: Request, file_size_mb: float = 0
    ) -> tuple[bool, str]:
        """Check if translation is allowed, return (allowed, reason)."""
        # Use the same identity as check_request / the middleware, so the
        # limiter consulted here is the one endpoints update through
        # request.state.rate_limiter (e.g. the concurrent-translation count).
        ip, client_id = self._identify(request)

        self._total_translations += 1

        if self.is_whitelisted(ip) or self.is_whitelisted(client_id):
            return True, ""

        return await self._check_translation_for(client_id, file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if translation is allowed for a specific client ID."""
        if self.is_whitelisted(client_id):
            return True

        allowed, _ = await self._check_translation_for(client_id, file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for the client that sent ``request``."""
        ip, client_id = self._identify(request)
        client = self.clients[client_id]
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(ip) or self.is_whitelisted(client_id),
            **client.get_stats(),
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client."""
        # Membership test (not indexing) so unknown ids do not create entries.
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        stats = self.clients[client_id].get_stats()
        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id),
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics."""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations,
            },
        }


class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware for rate limiting."""

    # Health checks and API docs are never rate-limited.
    _EXEMPT_PATHS = ("/health", "/", "/docs", "/openapi.json", "/redoc")

    # Lightweight GET endpoints (read-only, cacheable) that are skipped.
    _EXEMPT_GET_PREFIXES = (
        "/api/v1/languages",
        "/api/v1/providers",
        "/api/v1/auth/me",
        "/api/v1/auth/usage",
        "/api/v1/translations/",  # status polling (uses job_id suffix)
    )

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        """Apply rate limits to the request or answer with HTTP 429."""
        path = request.url.path

        # CORS preflight must not be rate-limited (no ACAO header → browser
        # blocks the real request)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Skip rate limiting for health checks and static files
        if path in self._EXEMPT_PATHS or path.startswith("/static"):
            return await call_next(request)

        # Skip rate limiting for lightweight GET endpoints; str.startswith
        # accepts a tuple of prefixes.
        if request.method == "GET" and path.startswith(self._EXEMPT_GET_PREFIXES):
            return await call_next(request)

        # Check rate limit
        allowed, reason, client_id = await self.manager.check_request(request)

        if not allowed:
            logger.warning("Rate limit exceeded for %s: %s", client_id, reason)
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60,
                },
                headers={"Retry-After": "60"},
            )

        # Add client info to request state for use in endpoints
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]

        return await call_next(request)


# Global rate limit manager
rate_limit_manager = RateLimitManager()