chartbastan/backend/app/middleware/rate_limiter.py
2026-02-01 09:31:38 +01:00

155 lines
5.3 KiB
Python

"""
Rate limiting middleware for API.
This module provides rate limiting functionality for API endpoints.
"""
import hashlib
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, Optional

from fastapi import HTTPException, status, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
class InMemoryRateLimiter:
    """
    In-memory rate limiter using a sliding window algorithm.

    Each key (e.g. a user identifier or client IP) maps to a list of
    ``(timestamp, count)`` entries stamped with naive UTC datetimes. Entries
    older than the window are pruned lazily whenever a key is checked, and at
    most every ``_SWEEP_INTERVAL_SECONDS`` all keys are swept. This way,
    clients that stop sending requests do not leave their buckets in memory
    forever.

    State lives in this object only, so limits are not shared between
    processes or hosts. For production, consider using Redis or a similar
    distributed cache.
    """

    # Minimum number of seconds between full sweeps over every stored key.
    _SWEEP_INTERVAL_SECONDS = 60

    def __init__(self):
        # Stores: {key: [(timestamp, count), ...]}; timestamps are naive UTC.
        self.requests: Dict[str, list[tuple[datetime, int]]] = defaultdict(list)
        # Kept for backward compatibility; this class does not read or write it.
        self.last_cleanup: Dict[str, datetime] = {}
        # Time of the last full sweep over all keys.
        self._last_sweep: datetime = datetime.utcnow()
        # Largest window any caller has asked about. Sweeps prune with it so
        # they never discard entries that a longer-window check still needs.
        self._max_window_seconds: int = 0

    def _cleanup_old_requests(self, key: str, window_seconds: int = 60):
        """Remove requests older than the time window; drop the key if none remain."""
        entries = self.requests.get(key)
        if entries is None:
            return
        cutoff_time = datetime.utcnow() - timedelta(seconds=window_seconds)
        fresh = [(ts, count) for ts, count in entries if ts > cutoff_time]
        if fresh:
            self.requests[key] = fresh
        else:
            # An empty bucket carries no information. Delete it so idle keys
            # (e.g. one-off client IPs) do not accumulate without bound.
            del self.requests[key]

    def _sweep_all(self, window_seconds: int):
        """Periodically prune every stored key, not only the one being checked."""
        self._max_window_seconds = max(self._max_window_seconds, window_seconds)
        now = datetime.utcnow()
        if now - self._last_sweep < timedelta(seconds=self._SWEEP_INTERVAL_SECONDS):
            return
        self._last_sweep = now
        # Snapshot the keys: _cleanup_old_requests may delete from the dict.
        for key in list(self.requests):
            self._cleanup_old_requests(key, self._max_window_seconds)

    def _get_request_count(self, key: str, window_seconds: int = 60) -> int:
        """Count requests recorded for ``key`` within the time window."""
        self._sweep_all(window_seconds)
        self._cleanup_old_requests(key, window_seconds)
        # .get() instead of [] so that checking an unseen key does not make
        # the defaultdict insert an empty bucket for it.
        return sum(count for _, count in self.requests.get(key, ()))

    def is_allowed(self, key: str, limit: int, window_seconds: int = 60) -> bool:
        """
        Check if a request is allowed based on the rate limit.

        This does not record the request; call ``record_request`` for that.

        Args:
            key: Unique identifier (e.g., user_id or IP address)
            limit: Maximum requests per window
            window_seconds: Time window in seconds (default: 60)

        Returns:
            True if fewer than ``limit`` requests were recorded in the window,
            False otherwise.
        """
        current_count = self._get_request_count(key, window_seconds)
        return current_count < limit

    def record_request(self, key: str):
        """Record one request for ``key`` at the current UTC time."""
        self.requests[key].append((datetime.utcnow(), 1))

    def get_remaining(self, key: str, limit: int, window_seconds: int = 60) -> int:
        """
        Get the number of remaining requests for a key.

        Args:
            key: Unique identifier
            limit: Maximum requests per window
            window_seconds: Time window in seconds (default: 60)

        Returns:
            ``limit`` minus the requests recorded in the window, floored at 0.
        """
        current_count = self._get_request_count(key, window_seconds)
        return max(0, limit - current_count)
# Module-level limiter consulted by RateLimitMiddleware.dispatch; its state is
# per process, so separate workers each keep their own counts.
rate_limiter: InMemoryRateLimiter = InMemoryRateLimiter()
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce rate limits on API requests.

    Requests are bucketed by:
    - a SHA-256 prefix of the ``X-API-Key`` header value when one is present
      (``authenticated_limit`` per window; the key is NOT validated here)
    - otherwise the client IP address (``public_limit`` per window)

    Rejected requests get a 429 JSON response. Allowed responses carry
    ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and ``X-RateLimit-Reset``
    headers.
    """

    def __init__(
        self,
        app,
        public_limit: int = 10,
        authenticated_limit: int = 100,
        window_seconds: int = 60,
    ):
        super().__init__(app)
        self.public_limit = public_limit  # default 10 req/window for unauthenticated
        self.authenticated_limit = authenticated_limit  # default 100 req/window with a key
        self.window_seconds = window_seconds  # sliding window length in seconds

    def _resolve_client(self, request: Request) -> tuple[str, int]:
        """Return the ``(bucket key, limit)`` pair that applies to ``request``."""
        api_key = request.headers.get("X-API-Key")
        if api_key:
            # NOTE(review): the key is not validated here. Any client can send
            # an arbitrary X-API-Key to get the higher limit, and can rotate
            # values to obtain fresh buckets. Validate the key (and bucket by
            # the owning user) before trusting it.
            key_id = hashlib.sha256(api_key.encode()).hexdigest()[:16]
            return key_id, self.authenticated_limit
        client_ip = request.client.host if request.client else "unknown"
        return client_ip, self.public_limit

    def _window_label(self) -> str:
        """Human-readable window length used in the error message."""
        if self.window_seconds == 60:
            return "minute"
        return f"{self.window_seconds} seconds"

    def _too_many_requests(self, key_id: str, limit: int) -> JSONResponse:
        """Build the 429 response for a client that exceeded its limit."""
        window = self.window_seconds
        remaining = rate_limiter.get_remaining(key_id, limit, window)
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            # Same envelope FastAPI's default HTTPException handler produces.
            content={
                "detail": {
                    "code": "RATE_LIMIT_EXCEEDED",
                    "message": (
                        f"Rate limit exceeded. Maximum {limit} requests "
                        f"per {self._window_label()}."
                    ),
                    "details": {
                        "limit": limit,
                        "remaining": remaining,
                        "retry_after": window,
                    },
                }
            },
            headers={
                # Upper bound: the oldest counted request leaves the sliding
                # window within this many seconds.
                "Retry-After": str(window),
                "X-RateLimit-Limit": str(limit),
                "X-RateLimit-Remaining": str(remaining),
                "X-RateLimit-Reset": str(window),
            },
        )

    async def dispatch(self, request: Request, call_next):
        """
        Process a request with rate limiting.

        Args:
            request: Incoming request
            call_next: Next middleware/handler

        Returns:
            A 429 JSONResponse if the client is over its limit. Otherwise, the
            downstream response with rate limit headers added.
        """
        key_id, limit = self._resolve_client(request)
        window = self.window_seconds

        if not rate_limiter.is_allowed(key_id, limit, window):
            # Return instead of raising HTTPException: exceptions raised in a
            # BaseHTTPMiddleware bypass FastAPI's exception handlers and would
            # surface as a 500 Internal Server Error rather than a 429.
            return self._too_many_requests(key_id, limit)

        rate_limiter.record_request(key_id)
        response = await call_next(request)

        remaining = rate_limiter.get_remaining(key_id, limit, window)
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(window)
        return response