office_translator/middleware/rate_limiting.py

329 lines
12 KiB
Python

"""
Rate Limiting Middleware for SaaS robustness
Protects against abuse and ensures fair usage
"""
import asyncio
import bisect
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
# Logger named after this module so handlers/filters can target it specifically.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass
class RateLimitConfig:
    """Tunable limits consumed by the rate limiter.

    Field order is also the positional order of the generated constructor,
    so new fields must only ever be appended.
    """

    # -- General request limits, enforced with sliding windows --
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000

    # -- Translation-specific limits --
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # -- Upload volume limits, in megabytes --
    # NOTE(review): max_file_size_mb is not read by the limiter classes in this
    # module; presumably enforced by the upload endpoint - confirm.
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # -- Burst protection: roughly the most requests accepted within one second --
    burst_limit: int = 10

    # -- Client ids (IP strings) that bypass every limit; fresh list per instance --
    whitelist_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1"])
class TokenBucket:
    """Token-bucket limiter.

    The bucket holds at most ``capacity`` tokens (it starts full) and regains
    ``refill_rate`` tokens per second of elapsed time. ``consume`` never
    waits: it either takes the tokens immediately or reports failure.
    """

    def __init__(self, capacity: int, refill_rate: float, clock=time.monotonic):
        """
        Args:
            capacity: Maximum number of tokens the bucket can hold.
            refill_rate: Tokens added per second.
            clock: Zero-argument callable returning seconds. Defaults to
                ``time.monotonic`` so wall-clock adjustments (NTP sync, manual
                changes) can neither drain nor instantly refill the bucket.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens per second
        self.tokens = capacity
        self._clock = clock
        self.last_refill = clock()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` tokens from the bucket.

        Returns:
            True if enough tokens were available (they are removed);
            False otherwise (only the time-based refill is applied).
        """
        async with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def _refill(self) -> None:
        """Credit the tokens earned since the last refill, capped at ``capacity``."""
        now = self._clock()
        elapsed = now - self.last_refill
        # Only move forward: a clock reading that does not advance must never
        # remove tokens or rewind the reference point.
        if elapsed > 0:
            self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
            self.last_refill = now
class SlidingWindowCounter:
    """Sliding-window log: accept at most ``max_requests`` events per ``window_seconds``.

    Timestamps of accepted events are stored oldest-first in ``requests``; an
    event counts while it is strictly younger than ``window_seconds``.
    """

    def __init__(self, window_seconds: int, max_requests: int, clock=time.monotonic):
        """
        Args:
            window_seconds: Length of the sliding window in seconds.
            max_requests: Maximum number of accepted events per window.
            clock: Zero-argument callable returning seconds. It must never go
                backwards because expiry relies on ``requests`` being sorted;
                the default ``time.monotonic`` guarantees this.
        """
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []  # accepted event timestamps, ascending
        self._clock = clock
        self._lock = asyncio.Lock()

    def _expired_prefix(self, now: float) -> int:
        """Return how many leading timestamps have left the window at *now*."""
        # ts is expired iff now - ts >= window, i.e. ts <= now - window.
        return bisect.bisect_right(self.requests, now - self.window_seconds)

    async def is_allowed(self) -> bool:
        """Record a new event and return True if the window has room, else False."""
        async with self._lock:
            now = self._clock()
            # One C-level slice delete instead of rebuilding the list each call.
            del self.requests[: self._expired_prefix(now)]
            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True
            return False

    @property
    def current_count(self) -> int:
        """Number of accepted events still inside the window (does not prune)."""
        return len(self.requests) - self._expired_prefix(self._clock())
class ClientRateLimiter:
    """Per-client limiter combining burst, request-window and translation limits.

    A request or translation is recorded in its windows only after every
    check has passed, so a denied attempt never consumes quota in the
    windows that happened to be checked first.
    """

    def __init__(self, config: "RateLimitConfig"):
        self.config = config
        # General request windows: (window length in seconds, max requests)
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        # Capacity and refill rate are both burst_limit (tokens per second).
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # (time.monotonic() timestamp, size_mb) pairs
        self._lock = asyncio.Lock()

    @staticmethod
    def _is_full(counter) -> bool:
        """Return True if *counter* has no free slot in its current window."""
        return counter.current_count >= counter.max_requests

    async def check_request(self) -> tuple[bool, str]:
        """Check (and on success record) a general request.

        Returns:
            ``(True, "")`` if allowed, otherwise ``(False, reason)``.
        """
        # The burst token is spent even if a window below denies the request,
        # so hammering the API is still throttled.
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."
        cfg = self.config
        async with self._lock:
            if self._is_full(self.minute_counter):
                return False, f"Rate limit exceeded. Max {cfg.requests_per_minute} requests per minute."
            if self._is_full(self.hour_counter):
                return False, f"Hourly limit exceeded. Max {cfg.requests_per_hour} requests per hour."
            if self._is_full(self.day_counter):
                return False, f"Daily limit exceeded. Max {cfg.requests_per_day} requests per day."
            # Every window has room: record the request in all of them.
            await self.minute_counter.is_allowed()
            await self.hour_counter.is_allowed()
            await self.day_counter.is_allowed()
        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check (and on success record) a translation of ``file_size_mb`` MB.

        On success the translation is recorded in the per-minute and per-hour
        translation windows and its size in the hourly data budget; on denial
        nothing is recorded. The concurrency check does not reserve a slot -
        ``start_translation``/``end_translation`` must wrap the actual work.

        Returns:
            ``(True, "")`` if allowed, otherwise ``(False, reason)``.
        """
        cfg = self.config
        async with self._lock:
            if self.concurrent_translations >= cfg.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {cfg.max_concurrent_translations} at a time."
            if self._is_full(self.translation_minute):
                return False, f"Translation rate limit exceeded. Max {cfg.translations_per_minute} translations per minute."
            if self._is_full(self.translation_hour):
                return False, f"Hourly translation limit exceeded. Max {cfg.translations_per_hour} translations per hour."
            # Hourly data budget: drop entries older than one hour, then sum.
            now = time.monotonic()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            used_mb = sum(size for _, size in self.total_size_hour)
            if used_mb + file_size_mb > cfg.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {cfg.max_total_size_per_hour_mb}MB per hour."
            # Every check passed: commit the translation to all budgets.
            await self.translation_minute.is_allowed()
            await self.translation_hour.is_allowed()
            self.total_size_hour.append((now, file_size_mb))
        return True, ""

    async def start_translation(self):
        """Count one more in-flight translation."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Count one fewer in-flight translation (never below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Return current window usage and the in-flight translation count."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
class RateLimitManager:
    """Registry of per-client rate limiters plus global usage counters.

    A ClientRateLimiter is created lazily the first time a client id is looked
    up. Limiters with no remaining activity are swept out at most once every
    ``_cleanup_interval`` seconds. Without that sweep the registry grows
    forever, and because client ids may come from request headers, a single
    caller could otherwise create arbitrarily many entries.
    """

    def __init__(
        self,
        config: Optional[RateLimitConfig] = None,
        *,
        trusted_proxies: Optional[list] = None,
    ):
        """
        Args:
            config: Limits to apply; defaults to ``RateLimitConfig()``.
            trusted_proxies: IPs of reverse proxies whose ``X-Forwarded-For`` /
                ``X-Real-IP`` headers may be believed. ``None`` (default) keeps
                the legacy behaviour of trusting those headers from ANY peer,
                which lets a client choose its own id - including a whitelisted
                one such as ``127.0.0.1``. Pass the proxy address(es), or an
                empty list when not behind a proxy, to close that hole.
        """
        self.config = config or RateLimitConfig()
        # defaultdict: looking up an unknown id creates its limiter.
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(
            lambda: ClientRateLimiter(self.config)
        )
        self._trusted_proxies = (
            None if trusted_proxies is None else frozenset(trusted_proxies)
        )
        self._cleanup_interval = 3600  # seconds between idle-client sweeps
        self._last_cleanup = time.monotonic()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Return the identifier (an IP string) used to rate limit *request*.

        Proxy headers are consulted first - the left-most ``X-Forwarded-For``
        entry, then ``X-Real-IP`` - but only when the direct peer is trusted
        (see ``trusted_proxies``). Otherwise, or when those headers are absent
        or empty, the socket peer host is used, and ``"unknown"`` if none.
        """
        peer = request.client.host if request.client is not None else None
        if self._trusted_proxies is None or peer in self._trusted_proxies:
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                first_hop = forwarded.split(",")[0].strip()
                if first_hop:  # skip malformed values such as ", 1.2.3.4"
                    return first_hop
            real_ip = (request.headers.get("X-Real-IP") or "").strip()
            if real_ip:
                return real_ip
        if request.client is not None:
            return request.client.host
        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Return True if *client_id* is listed in ``config.whitelist_ips``."""
        return client_id in self.config.whitelist_ips

    @staticmethod
    def _is_idle(limiter) -> bool:
        """True when *limiter* has no events in any window and nothing in flight."""
        return not any(limiter.get_stats().values())

    def _cleanup_idle_clients(self) -> None:
        """Evict idle limiters, at most once every ``_cleanup_interval`` seconds."""
        now = time.monotonic()
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        # An idle limiter carries no window state, so dropping it is
        # effectively lossless; the defaultdict recreates it on next use.
        idle_ids = [cid for cid, limiter in self.clients.items() if self._is_idle(limiter)]
        for client_id in idle_ids:
            del self.clients[client_id]

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check a request; return ``(allowed, reason, client_id)``."""
        self._cleanup_idle_clients()
        client_id = self.get_client_id(request)
        self._total_requests += 1
        if self.is_whitelisted(client_id):
            return True, "", client_id
        allowed, reason = await self.clients[client_id].check_request()
        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check whether the requesting client may start a translation."""
        client_id = self.get_client_id(request)
        self._total_translations += 1
        if self.is_whitelisted(client_id):
            return True, ""
        return await self.clients[client_id].check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check whether *client_id* may start a translation (boolean result only)."""
        if self.is_whitelisted(client_id):
            return True
        allowed, _ = await self.clients[client_id].check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Return usage stats for the requesting client without registering it."""
        client_id = self.get_client_id(request)
        # .get() so that merely reading stats does not create a registry entry.
        client = self.clients.get(client_id)
        if client is None:
            client = ClientRateLimiter(self.config)
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Return current usage for *client_id*, or a no-activity marker if unknown."""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}
        stats = self.clients[client_id].get_stats()
        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Return global counters and the active configuration."""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }
class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI/Starlette middleware applying per-client request limits.

    Exempt paths (health check, root, API docs, static assets) skip the
    limiter. Other requests go through ``RateLimitManager.check_request``:
    denied ones receive a 429 JSON response, and allowed ones get
    ``request.state.client_id`` and ``request.state.rate_limiter`` set for
    endpoints to use.
    """

    # Exact paths that are never rate limited.
    EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json", "/redoc"})
    # Static-asset prefix, matched on a path-segment boundary.
    STATIC_PREFIX = "/static"

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    def _is_exempt(self, path: str) -> bool:
        """Return True if *path* should bypass rate limiting."""
        if path in self.EXEMPT_PATHS:
            return True
        # Match "/static" and "/static/..." but not e.g. "/statistics".
        return path == self.STATIC_PREFIX or path.startswith(self.STATIC_PREFIX + "/")

    async def dispatch(self, request, call_next):
        if self._is_exempt(request.url.path):
            return await call_next(request)

        allowed, reason, client_id = await self.manager.check_request(request)
        if not allowed:
            logger.warning("Rate limit exceeded for %s: %s", client_id, reason)
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )

        # Expose client info to endpoints (e.g. for translation checks).
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]
        return await call_next(request)
# Module-level manager built with the default RateLimitConfig.
# NOTE(review): presumably passed to RateLimitMiddleware at app setup - confirm.
rate_limit_manager: RateLimitManager = RateLimitManager()