"""
Rate limiting middleware for API.

This module provides rate limiting functionality for API endpoints: an
in-process sliding-window limiter (``InMemoryRateLimiter``), a module-level
instance of it (``rate_limiter``), and a Starlette/FastAPI middleware
(``RateLimitMiddleware``) that enforces limits per API key or per client IP.
"""
import hashlib
import math
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from fastapi import HTTPException, status, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime.

    ``datetime.utcnow()`` is deprecated since Python 3.12; an aware value
    avoids that warning and any naive/aware comparison mix-ups.
    """
    return datetime.now(timezone.utc)


class InMemoryRateLimiter:
    """
    In-memory rate limiter using a sliding-window log.

    Each key keeps a list of ``(timestamp, count)`` entries. A request is
    allowed while the summed count of entries newer than ``window_seconds``
    stays below the limit.

    State lives in this process only and is not shared between workers.
    For production, consider using Redis or a similar distributed cache.
    """

    def __init__(self):
        # Stores: {key: [(timestamp, count), ...]}, appended in arrival order.
        self.requests: Dict[str, list[tuple[datetime, int]]] = defaultdict(list)
        # Kept for backward compatibility. Nothing in this class writes it; it is
        # only cleared alongside a key when that key is forgotten.
        self.last_cleanup: Dict[str, datetime] = {}
        # Time of the last sweep over *all* keys, and the widest window seen so far
        # (a sweep must never discard entries that a longer window still counts).
        self._last_sweep: Optional[datetime] = None
        self._sweep_window_seconds: int = 0

    def _forget(self, key: str) -> None:
        """Drop every piece of state held for ``key``."""
        self.requests.pop(key, None)
        self.last_cleanup.pop(key, None)

    def _cleanup_old_requests(self, key: str, window_seconds: int = 60):
        """Remove entries for ``key`` older than the time window.

        If no entries remain, the key itself is removed so that idle clients
        do not occupy memory.
        """
        # .get() (not []) so that the defaultdict does not insert unseen keys.
        entries = self.requests.get(key)
        if entries is None:
            return
        cutoff_time = _utcnow() - timedelta(seconds=window_seconds)
        kept = [(ts, count) for ts, count in entries if ts > cutoff_time]
        if kept:
            self.requests[key] = kept
        else:
            self._forget(key)

    def _sweep_stale_keys(self, window_seconds: int) -> None:
        """Forget every key that has no entry inside the widest window seen so far.

        This runs at most once per that window, so its cost (one pass over all
        keys) is amortised across requests. Without it, keys that stop sending
        requests (e.g. one per client IP) would stay in memory forever.
        """
        self._sweep_window_seconds = max(self._sweep_window_seconds, window_seconds)
        window = timedelta(seconds=self._sweep_window_seconds)
        now = _utcnow()
        if self._last_sweep is not None and now - self._last_sweep < window:
            return
        self._last_sweep = now
        cutoff_time = now - window
        stale = [
            key
            for key, entries in self.requests.items()
            if not entries or max(ts for ts, _ in entries) <= cutoff_time
        ]
        for key in stale:
            self._forget(key)

    def _get_request_count(self, key: str, window_seconds: int = 60) -> int:
        """Count requests for ``key`` in the time window (after pruning old ones)."""
        self._cleanup_old_requests(key, window_seconds)
        # .get() avoids creating an empty entry for keys that were never recorded.
        return sum(count for _, count in self.requests.get(key, ()))

    def is_allowed(self, key: str, limit: int, window_seconds: int = 60) -> bool:
        """
        Check if request is allowed based on rate limit.

        This does not record the request; call ``record_request`` for that.

        Args:
            key: Unique identifier (e.g., user_id or IP address)
            limit: Maximum requests per window
            window_seconds: Time window in seconds (default: 60)

        Returns:
            True if fewer than ``limit`` requests are counted in the window,
            False otherwise
        """
        self._sweep_stale_keys(window_seconds)
        return self._get_request_count(key, window_seconds) < limit

    def record_request(self, key: str):
        """Record one request for ``key`` at the current time."""
        self.requests[key].append((_utcnow(), 1))

    def get_remaining(self, key: str, limit: int, window_seconds: int = 60) -> int:
        """
        Get remaining requests for a key.

        Args:
            key: Unique identifier
            limit: Maximum requests per window
            window_seconds: Time window in seconds (default: 60)

        Returns:
            Number of remaining requests (never negative)
        """
        current_count = self._get_request_count(key, window_seconds)
        return max(0, limit - current_count)

    def get_reset_seconds(self, key: str, window_seconds: int = 60) -> int:
        """
        Seconds until the oldest counted request for ``key`` leaves the window.

        That is the moment at least one request slot frees up again.

        Args:
            key: Unique identifier
            window_seconds: Time window in seconds (default: 60)

        Returns:
            Whole seconds, rounded up; 0 when nothing is counted for ``key``
        """
        self._cleanup_old_requests(key, window_seconds)
        entries = self.requests.get(key)
        if not entries:
            return 0
        oldest = min(ts for ts, _ in entries)
        seconds_left = (
            oldest + timedelta(seconds=window_seconds) - _utcnow()
        ).total_seconds()
        # Round up so that clients never retry a fraction of a second too early.
        return max(0, math.ceil(seconds_left))


# Global rate limiter instance
rate_limiter = InMemoryRateLimiter()


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce rate limits on API requests.

    The rate-limit bucket is chosen as follows:
    - If an ``X-API-Key`` header is present: a truncated SHA-256 hash of the key,
      limited to ``authenticated_limit`` requests per window
    - If no API key: the client IP address, limited to ``public_limit``

    Requests over the limit receive a 429 JSON response. Allowed responses carry
    ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and ``X-RateLimit-Reset``
    headers.
    """

    def __init__(
        self,
        app,
        public_limit: int = 10,
        authenticated_limit: int = 100,
        window_seconds: int = 60,
    ):
        super().__init__(app)
        self.public_limit = public_limit  # default: 10 req/window for unauthenticated
        self.authenticated_limit = authenticated_limit  # default: 100 req/window
        self.window_seconds = window_seconds  # default: 60 s, i.e. "per minute"

    def _window_description(self) -> str:
        """Return a human-readable window for messages ("minute" for 60 s)."""
        if self.window_seconds == 60:
            return "minute"
        return f"{self.window_seconds} seconds"

    async def dispatch(self, request: Request, call_next):
        """
        Process request with rate limiting.

        Args:
            request: Incoming request
            call_next: Next middleware/handler

        Returns:
            The downstream response with rate-limit headers added, or a 429
            JSON response when the limit is exceeded
        """
        # Get client IP for fallback
        client_ip = request.client.host if request.client else "unknown"

        # Get API key if present
        api_key = request.headers.get("X-API-Key")

        # Determine rate limit key
        if api_key:
            # NOTE(security): the key is not validated here. Any client can send an
            # arbitrary X-API-Key to get the higher limit, and a fresh bucket per
            # distinct value, which sidesteps the per-IP limit entirely. Validate
            # the key (and bucket by its owner) before trusting it.
            key_id = hashlib.sha256(api_key.encode()).hexdigest()[:16]
            limit = self.authenticated_limit
        else:
            # Use IP address for unauthenticated requests
            key_id = client_ip
            limit = self.public_limit

        window = self.window_seconds

        # Check and record are synchronous with no await between them, so no
        # other request on this event loop can interleave.
        if not rate_limiter.is_allowed(key_id, limit, window):
            remaining = rate_limiter.get_remaining(key_id, limit, window)
            retry_after = rate_limiter.get_reset_seconds(key_id, window)
            # Return the 429 directly. An HTTPException raised inside a
            # BaseHTTPMiddleware is not seen by FastAPI's exception handlers
            # (they sit further inside the middleware stack), so it would
            # surface to the client as a 500 instead.
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": {
                        "code": "RATE_LIMIT_EXCEEDED",
                        "message": (
                            f"Rate limit exceeded. Maximum {limit} requests "
                            f"per {self._window_description()}."
                        ),
                        "details": {
                            "limit": limit,
                            "remaining": remaining,
                            "retry_after": retry_after,
                        },
                    }
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": str(remaining),
                    "X-RateLimit-Reset": str(retry_after),
                },
            )

        # Record the request
        rate_limiter.record_request(key_id)

        # Process request
        response = await call_next(request)

        # Add rate limit headers to response
        remaining = rate_limiter.get_remaining(key_id, limit, window)
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        # Seconds until the oldest counted request expires and a slot frees up.
        response.headers["X-RateLimit-Reset"] = str(
            rate_limiter.get_reset_seconds(key_id, window)
        )

        return response