Initial commit
This commit is contained in:
3
backend/app/middleware/__init__.py
Normal file
3
backend/app/middleware/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Middleware package for the application.
|
||||
"""
|
||||
154
backend/app/middleware/rate_limiter.py
Normal file
154
backend/app/middleware/rate_limiter.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Rate limiting middleware for API.
|
||||
|
||||
This module provides rate limiting functionality for API endpoints.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, status, Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class InMemoryRateLimiter:
|
||||
"""
|
||||
In-memory rate limiter using sliding window algorithm.
|
||||
|
||||
For production, consider using Redis or a similar distributed cache.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Stores: {key: [(timestamp, count), ...]}
|
||||
self.requests: Dict[str, list[tuple[datetime, int]]] = defaultdict(list)
|
||||
# Stores last cleanup time
|
||||
self.last_cleanup: Dict[str, datetime] = {}
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int = 60):
|
||||
"""Remove requests older than the time window."""
|
||||
if key not in self.requests:
|
||||
return
|
||||
|
||||
cutoff_time = datetime.utcnow() - timedelta(seconds=window_seconds)
|
||||
self.requests[key] = [
|
||||
(ts, count) for ts, count in self.requests[key]
|
||||
if ts > cutoff_time
|
||||
]
|
||||
|
||||
def _get_request_count(self, key: str, window_seconds: int = 60) -> int:
|
||||
"""Count requests in the time window."""
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
return sum(count for _, count in self.requests[key])
|
||||
|
||||
def is_allowed(self, key: str, limit: int, window_seconds: int = 60) -> bool:
|
||||
"""
|
||||
Check if request is allowed based on rate limit.
|
||||
|
||||
Args:
|
||||
key: Unique identifier (e.g., user_id or IP address)
|
||||
limit: Maximum requests per window
|
||||
window_seconds: Time window in seconds (default: 60)
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise
|
||||
"""
|
||||
current_count = self._get_request_count(key, window_seconds)
|
||||
return current_count < limit
|
||||
|
||||
def record_request(self, key: str):
|
||||
"""Record a request for a key."""
|
||||
self.requests[key].append((datetime.utcnow(), 1))
|
||||
|
||||
def get_remaining(self, key: str, limit: int, window_seconds: int = 60) -> int:
|
||||
"""
|
||||
Get remaining requests for a key.
|
||||
|
||||
Args:
|
||||
key: Unique identifier
|
||||
limit: Maximum requests per window
|
||||
window_seconds: Time window in seconds (default: 60)
|
||||
|
||||
Returns:
|
||||
Number of remaining requests
|
||||
"""
|
||||
current_count = self._get_request_count(key, window_seconds)
|
||||
return max(0, limit - current_count)
|
||||
|
||||
|
||||
# Global rate limiter instance
|
||||
rate_limiter = InMemoryRateLimiter()
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce rate limits on API requests.
|
||||
|
||||
This middleware applies rate limiting based on:
|
||||
- If API key is present: use user_id from API key
|
||||
- If no API key: use IP address (basic protection)
|
||||
"""
|
||||
|
||||
def __init__(self, app, public_limit: int = 10, authenticated_limit: int = 100):
|
||||
super().__init__(app)
|
||||
self.public_limit = public_limit # 10 req/min for unauthenticated
|
||||
self.authenticated_limit = authenticated_limit # 100 req/min for authenticated
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Process request with rate limiting.
|
||||
|
||||
Args:
|
||||
request: Incoming request
|
||||
call_next: Next middleware/handler
|
||||
|
||||
Returns:
|
||||
Response with rate limit headers
|
||||
"""
|
||||
# Get client IP for fallback
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# Get API key if present
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
|
||||
# Determine rate limit key
|
||||
if api_key:
|
||||
# For authenticated requests, we'd validate the key here
|
||||
# For now, use API key hash as identifier
|
||||
import hashlib
|
||||
key_id = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
||||
limit = self.authenticated_limit
|
||||
else:
|
||||
# Use IP address for unauthenticated requests
|
||||
key_id = client_ip
|
||||
limit = self.public_limit
|
||||
|
||||
# Check if request is allowed
|
||||
if not rate_limiter.is_allowed(key_id, limit):
|
||||
remaining = rate_limiter.get_remaining(key_id, limit)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"code": "RATE_LIMIT_EXCEEDED",
|
||||
"message": f"Rate limit exceeded. Maximum {limit} requests per minute.",
|
||||
"details": {
|
||||
"limit": limit,
|
||||
"remaining": remaining,
|
||||
"retry_after": 60
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Record the request
|
||||
rate_limiter.record_request(key_id)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers to response
|
||||
remaining = rate_limiter.get_remaining(key_id, limit)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
response.headers["X-RateLimit-Reset"] = str(60) # Reset after 60 seconds
|
||||
|
||||
return response
|
||||
Reference in New Issue
Block a user