- Restructured docker-compose for Nginx Proxy Manager (no custom nginx) - Added domain wordly.art configuration - Added Prometheus + Grafana monitoring stack with pre-configured dashboards - Added PostgreSQL backup script to NAS (daily/weekly/monthly rotation) - Added alert rules for backend, system, and Docker metrics - Updated deployment guide for NPM + IONOS DNS homelab setup - Added marketing plan document - PDF translator and watermark support - Enhanced middleware, routes, and translator modules Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
444 lines
18 KiB
Python
444 lines
18 KiB
Python
"""
|
|
Rate Limiting Middleware for SaaS robustness
|
|
Protects against abuse and ensures fair usage.
|
|
When REDIS_URL is set, uses Redis for sliding-window counters (shared across instances).
|
|
Otherwise falls back to in-memory per-process limits.
|
|
"""
|
|
import time
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Optional
|
|
from fastapi import Request, HTTPException
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
import logging
|
|
|
|
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RateLimitConfig:
    """Tunable limits shared by every per-client limiter.

    Field order is significant: the generated ``__init__`` accepts these
    values positionally in the order declared below.
    """

    # General API traffic: maximum accepted requests per sliding window.
    requests_per_minute: int = 120
    requests_per_hour: int = 1000
    requests_per_day: int = 5000

    # Translation jobs: per-window caps plus the number of simultaneous jobs.
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # Upload volume in megabytes.
    # NOTE(review): max_file_size_mb is not read anywhere in this module;
    # presumably the upload route enforces it — confirm.
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Token-bucket capacity against short bursts ("max requests in 1 second").
    burst_limit: int = 30

    # Client identifiers that bypass rate limiting entirely (loopback by default).
    # default_factory gives every instance its own list.
    whitelist_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1"])
|
|
|
|
|
|
# Redis key namespaces used by the Redis-backed limiters:
#   "rate_limit:ip:{client_id}:{window_key}"      -> sorted set of hit timestamps
#   "rate_limit:size_hour:{client_id}:{hour_ts}"  -> running upload total (MB)
# Despite the "ip" name, client_id may also be a user/token identifier.
KEY_PREFIX_IP: str = "rate_limit:ip"
KEY_PREFIX_SIZE_HOUR: str = "rate_limit:size_hour"
|
|
|
|
|
|
async def _redis_sliding_window_is_allowed(redis_client, client_id: str, window_key: str, window_seconds: int, max_requests: int) -> bool:
    """Record one hit in a Redis sliding window and report whether it fits the limit.

    Key: ``rate_limit:ip:{client_id}:{window_key}``, a sorted set whose scores
    are hit timestamps (``time.time()`` seconds).

    The hit is inserted *before* counting, in a single pipeline (redis-py
    pipelines wrap their commands in MULTI/EXEC by default). Every concurrent
    caller, whether another coroutine or another app instance, therefore sees
    a count that includes its own entry. Two callers cannot both slip under
    the limit. A rejected hit is removed again so that it does not consume
    quota.

    Args:
        redis_client: async Redis client exposing ``pipeline()`` and ``zrem()``.
        client_id: identifier being limited.
        window_key: window name used in the key (e.g. ``"minute"``).
        window_seconds: window length in seconds.
        max_requests: maximum hits allowed inside the window.

    Returns:
        True if this hit is within ``max_requests`` for the window.
    """
    import uuid

    now = time.time()
    key = f"{KEY_PREFIX_IP}:{client_id}:{window_key}"
    # The member must be unique. ZADD on an existing member only updates its
    # score, so two hits sharing a timestamp would otherwise count once.
    member = f"{now}:{uuid.uuid4().hex}"

    pipe = redis_client.pipeline()
    pipe.zremrangebyscore(key, "-inf", now - window_seconds)
    pipe.zadd(key, {member: now})
    pipe.zcard(key)
    pipe.expire(key, window_seconds + 60)
    results = await pipe.execute()
    count_including_this_hit = results[2]

    if count_including_this_hit > max_requests:
        # Over the limit: undo our insert so rejected hits do not count.
        await redis_client.zrem(key, member)
        return False
    return True
|
|
|
|
|
|
async def _check_request_redis(redis_client, client_id: str, config: RateLimitConfig) -> tuple[bool, str]:
    """Apply the minute, hour and day request windows in Redis.

    Windows are evaluated in that order, stopping at the first one that is
    full. Returns ``(allowed, reason)``; ``reason`` is empty when allowed.
    """
    windows = (
        ("minute", 60, config.requests_per_minute, "Rate limit exceeded. Max {} requests per minute."),
        ("hour", 3600, config.requests_per_hour, "Hourly limit exceeded. Max {} requests per hour."),
        ("day", 86400, config.requests_per_day, "Daily limit exceeded. Max {} requests per day."),
    )
    for window_key, seconds, limit, template in windows:
        within_limit = await _redis_sliding_window_is_allowed(redis_client, client_id, window_key, seconds, limit)
        if not within_limit:
            return False, template.format(limit)
    return True, ""
|
|
|
|
|
|
async def _check_translation_redis(redis_client, client_id: str, config: RateLimitConfig, file_size_mb: float = 0) -> tuple[bool, str]:
    """Check translation limits using Redis.

    Checks, in order:
    1. translations per minute and per hour (sliding windows);
    2. total uploaded megabytes in the current clock hour.

    The size total is a fixed clock-hour bucket (``int(now // 3600)``), not a
    sliding hour like the in-memory limiter. It is updated atomically with
    INCRBYFLOAT, so concurrent uploads cannot overwrite each other's
    contribution. A rejected upload's increment is rolled back.

    If Redis fails during the size check, a warning is logged and the request
    is allowed (best effort).

    NOTE(review): ``max_concurrent_translations`` is not enforced on this path.

    Returns:
        ``(allowed, reason)``; ``reason`` is empty when allowed.
    """
    if not await _redis_sliding_window_is_allowed(redis_client, client_id, "trans_minute", 60, config.translations_per_minute):
        return False, f"Translation rate limit exceeded. Max {config.translations_per_minute} translations per minute."

    if not await _redis_sliding_window_is_allowed(redis_client, client_id, "trans_hour", 3600, config.translations_per_hour):
        return False, f"Hourly translation limit exceeded. Max {config.translations_per_hour} translations per hour."

    hour_ts = int(time.time() // 3600)
    size_key = f"{KEY_PREFIX_SIZE_HOUR}:{client_id}:{hour_ts}"
    try:
        pipe = redis_client.pipeline()
        pipe.incrbyfloat(size_key, file_size_mb)
        pipe.expire(size_key, 7200)
        results = await pipe.execute()
        new_total = float(results[0])
    except Exception as e:
        logger.warning("Redis size-hour check failed: %s; allowing request", e)
        return True, ""

    if new_total > config.max_total_size_per_hour_mb:
        # Undo our contribution so a rejected upload does not eat into the quota.
        try:
            await redis_client.incrbyfloat(size_key, -file_size_mb)
        except Exception as e:
            logger.warning("Redis size-hour rollback failed: %s", e)
        return False, f"Hourly data limit exceeded. Max {config.max_total_size_per_hour_mb}MB per hour."

    return True, ""
|
|
|
|
|
|
class TokenBucket:
    """Token bucket algorithm for rate limiting.

    The bucket holds at most ``capacity`` tokens and refills continuously at
    ``refill_rate`` tokens per second. Elapsed time is measured with
    ``time.monotonic()``, so wall-clock adjustments (NTP steps, manual changes)
    cannot drain or overfill the bucket.
    """

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens per second
        self.tokens = capacity  # starts full; becomes a float after the first refill
        self.last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Take ``tokens`` from the bucket.

        Returns:
            True if enough tokens were available (and were removed). False
            otherwise, in which case the bucket is left untouched.
        """
        async with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def _refill(self):
        """Add tokens for the time elapsed since the last refill, capped at capacity."""
        now = time.monotonic()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now
|
|
|
|
|
|
class SlidingWindowCounter:
    """Sliding window counter: at most ``max_requests`` hits per ``window_seconds``.

    Accepted-hit timestamps are stored in ``requests``, oldest first. They come
    from ``time.monotonic()``, which never goes backwards, so the list stays
    sorted. Expiry can then trim only the stale prefix instead of rebuilding
    the whole list on every call, which matters for large windows such as a
    5000-per-day limit.
    """

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []  # monotonic timestamps of accepted hits, oldest first
        self._lock = asyncio.Lock()

    def _expire(self, now: float) -> None:
        """Drop timestamps that have left the window (they form a prefix of the list)."""
        stale = 0
        for ts in self.requests:
            if now - ts < self.window_seconds:
                break
            stale += 1
        if stale:
            del self.requests[:stale]

    async def is_allowed(self) -> bool:
        """Record a hit if the window has room.

        Returns:
            True if the hit was accepted (and recorded). False if the window is
            full, in which case nothing is recorded.
        """
        async with self._lock:
            now = time.monotonic()
            self._expire(now)
            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True
            return False

    @property
    def current_count(self) -> int:
        """Number of accepted hits still inside the window (read-only, no pruning)."""
        now = time.monotonic()
        return sum(1 for ts in self.requests if now - ts < self.window_seconds)
|
|
|
|
|
|
class ClientRateLimiter:
    """Per-client, in-memory rate limiter combining several independent limits.

    * a token bucket against bursts (capacity ``burst_limit``, refilled at
      ``burst_limit`` tokens per second);
    * sliding windows for requests per minute / hour / day;
    * sliding windows for translations per minute / hour;
    * a cap on concurrent translations, tracked by ``start_translation`` /
      ``end_translation``;
    * a sliding one-hour cap on total uploaded megabytes.

    Checks run in order and stop at the first failure. A window that was
    already passed keeps the hit it recorded.
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config
        # General request windows.
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        # Burst guard: burst_limit tokens, refilled at burst_limit tokens/second.
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        # Translation windows.
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # List of (timestamp, size_mb)
        # Guards concurrent_translations and total_size_hour. asyncio.Lock is
        # NOT reentrant: never acquire it while already holding it.
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if a request is allowed; return ``(allowed, reason)``."""
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation of ``file_size_mb`` MB is allowed.

        This does not increment ``concurrent_translations``; callers are
        expected to call ``start_translation`` / ``end_translation`` around the
        actual work.

        Returns:
            ``(allowed, reason)``; ``reason`` is empty when allowed.
        """
        # Critical section 1: read the concurrency counter.
        async with self._lock:
            at_capacity = self.concurrent_translations >= self.config.max_concurrent_translations
        if at_capacity:
            return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

        # The counters have their own locks, so these awaits happen outside self._lock.
        if not await self.translation_minute.is_allowed():
            return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

        if not await self.translation_hour.is_allowed():
            return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

        # Critical section 2 (separate from section 1, never nested): sliding one-hour upload volume.
        async with self._lock:
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)
            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."
            self.total_size_hour.append((now, file_size_mb))

        return True, ""

    async def start_translation(self):
        """Mark the start of a translation (increments the concurrency counter)."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark the end of a translation (decrements the counter, never below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Return current window counts and the number of in-flight translations."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
|
|
|
|
|
|
class RateLimitManager:
    """Manages rate limiters for all clients.

    Limits are checked in Redis when ``core.redis.get_async_redis()`` returns a
    client (shared across instances). Otherwise they are checked in
    per-process ``ClientRateLimiter`` objects held in ``clients``.

    ``clients`` entries idle for longer than the longest tracked window (one
    day) are evicted periodically, so the map does not grow without bound.
    This matters when callers rotate identifiers.
    """

    # Longest window a ClientRateLimiter tracks (its per-day request counter).
    _IDLE_EVICTION_SECONDS = 86400

    def __init__(self, config: Optional[RateLimitConfig] = None, trusted_proxy_hops: Optional[int] = None):
        """Create a manager.

        Args:
            config: limits to enforce; ``None`` means ``RateLimitConfig()``.
            trusted_proxy_hops: number of trusted reverse proxies in front of
                the app, each of which appends one address to
                ``X-Forwarded-For``. When set, the client address is read that
                many entries from the right of the header, which is the part a
                client cannot forge. ``None`` keeps the legacy behavior of
                using the left-most entry.

        Raises:
            ValueError: if ``trusted_proxy_hops`` is given and is less than 1.
        """
        if trusted_proxy_hops is not None and trusted_proxy_hops < 1:
            raise ValueError("trusted_proxy_hops must be >= 1")
        self.config = config or RateLimitConfig()
        self.trusted_proxy_hops = trusted_proxy_hops
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Sweep idle clients at most once per hour
        self._last_cleanup = time.time()
        self._last_seen: Dict[str, float] = {}  # client_id -> time.time() of last check
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Extract a client identifier (an IP string) from the request.

        Order: ``X-Forwarded-For``, then ``X-Real-IP``, then the socket peer,
        then ``"unknown"``.

        NOTE(review): without ``trusted_proxy_hops`` the left-most
        X-Forwarded-For entry is used. That entry is whatever the caller sent,
        so a client can rotate it to obtain fresh limits. Configure
        ``trusted_proxy_hops`` to match the proxy chain in front of the app.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            if self.trusted_proxy_hops is None:
                return forwarded.split(",")[0].strip()
            entries = [part.strip() for part in forwarded.split(",") if part.strip()]
            if entries:
                # Each trusted proxy appended one entry; the client is that many from the right.
                return entries[-min(self.trusted_proxy_hops, len(entries))]

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check whether ``client_id`` is listed in ``config.whitelist_ips``."""
        return client_id in self.config.whitelist_ips

    def _is_whitelisted_request(self, request: Request, client_id: str) -> bool:
        """Whitelist only callers whose identity was not taken from forwarding headers.

        ``X-Forwarded-For`` / ``X-Real-IP`` come from the caller or a proxy
        acting for it. A spoofed ``X-Forwarded-For: 127.0.0.1`` must not turn
        off rate limiting.
        """
        if not self.is_whitelisted(client_id):
            return False
        return not (request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP"))

    def _note_activity(self, client_id: str) -> None:
        """Record that ``client_id`` was checked now, sweeping idle clients when due."""
        now = time.time()
        if now - self._last_cleanup >= self._cleanup_interval:
            self._evict_idle_clients(now)
        self._last_seen[client_id] = now

    def _evict_idle_clients(self, now: float) -> None:
        """Drop tracking for clients idle longer than a day that hold no live state."""
        self._last_cleanup = now
        # Limiters created outside the check methods (e.g. via ``clients[...]``
        # lookups) have no access time yet; start their idle clock now.
        for client_id in list(self.clients):
            self._last_seen.setdefault(client_id, now)
        for client_id, last_seen in list(self._last_seen.items()):
            if now - last_seen < self._IDLE_EVICTION_SECONDS:
                continue
            limiter = self.clients.get(client_id)  # .get: do not create via defaultdict
            if limiter is not None:
                # Keep limiters with in-flight translations or hits still in a window.
                if any(limiter.get_stats().values()):
                    continue
                del self.clients[client_id]
            del self._last_seen[client_id]

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if a request is allowed; return ``(allowed, reason, client_id)``."""
        client_id = self.get_client_id(request)

        # Prefer a user identifier for authenticated users (avoids shared limits
        # behind a proxy/NAT). First look for one already set on request.state.
        user_id = None
        if hasattr(request, "state"):
            user_id = getattr(request.state, "client_id", None)
        if not user_id:
            auth_header = request.headers.get("Authorization", "")
            if auth_header.startswith("Bearer "):
                # Use a hash of the token as the identifier (no decoding needed).
                # NOTE(review): the token is not verified here, so a caller can
                # send a fresh random token per request to get a fresh bucket.
                # Consider keying only on a verified user id.
                import hashlib
                token = auth_header[7:]
                user_id = f"tok:{hashlib.sha256(token.encode()).hexdigest()[:16]}"
        if user_id:
            client_id = user_id

        self._total_requests += 1
        self._note_activity(client_id)

        if self._is_whitelisted_request(request, client_id):
            return True, "", client_id

        try:
            from core.redis import get_async_redis
            redis_client = get_async_redis()
            if redis_client:
                allowed, reason = await _check_request_redis(redis_client, client_id, self.config)
                return allowed, reason, client_id
        except Exception as e:
            logger.warning("Redis rate limit check failed, using in-memory: %s", e)

        allowed, reason = await self.clients[client_id].check_request()
        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation is allowed; return ``(allowed, reason)``.

        NOTE(review): keyed by IP only (``get_client_id``), unlike
        ``check_request``, which prefers a user/token id. Confirm this is intended.
        """
        client_id = self.get_client_id(request)
        self._total_translations += 1
        self._note_activity(client_id)

        if self._is_whitelisted_request(request, client_id):
            return True, ""

        try:
            from core.redis import get_async_redis
            redis_client = get_async_redis()
            if redis_client:
                return await _check_translation_redis(redis_client, client_id, self.config, file_size_mb)
        except Exception as e:
            logger.warning("Redis translation limit check failed, using in-memory: %s", e)

        return await self.clients[client_id].check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if a translation is allowed for an explicit client id.

        NOTE(review): no request is available here, so whitelisting trusts
        ``client_id`` as given. Callers must not pass a header-derived IP they
        have not validated.
        """
        self._note_activity(client_id)

        if self.is_whitelisted(client_id):
            return True

        try:
            from core.redis import get_async_redis
            redis_client = get_async_redis()
            if redis_client:
                allowed, _ = await _check_translation_redis(redis_client, client_id, self.config, file_size_mb)
                return allowed
        except Exception as e:
            logger.warning("Redis translation limit check failed, using in-memory: %s", e)

        allowed, _ = await self.clients[client_id].check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get in-memory rate limit stats for the request's client."""
        client_id = self.get_client_id(request)
        client = self.clients[client_id]
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current in-memory usage status for a client."""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        stats = self.clients[client_id].get_stats()
        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics."""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware applying ``RateLimitManager.check_request`` to each request.

    Exempt traffic bypasses the limiter: CORS preflight, health/docs pages,
    static files and a few lightweight GET endpoints. Rejected requests get a
    429 JSON response. Accepted requests get ``request.state.client_id`` and
    ``request.state.rate_limiter`` set for downstream handlers.
    """

    # Exact paths never rate-limited (health check, root, API docs).
    _EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json", "/redoc"})
    # Path prefixes never rate-limited (static files).
    _EXEMPT_PREFIXES = ("/static",)
    # GET-only prefixes: read-only config/fetch endpoints that don't consume resources.
    _EXEMPT_GET_PREFIXES = (
        "/api/v1/languages",
        "/api/v1/providers",
        "/api/v1/auth/me",
        "/api/v1/auth/usage",
        "/api/v1/translations/",  # status polling (uses job_id suffix)
    )

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    def _is_exempt(self, request: Request) -> bool:
        """Return True for requests that must bypass rate limiting."""
        # CORS preflight must not be rate-limited (no ACAO header -> browser blocks the real request).
        if request.method == "OPTIONS":
            return True
        path = request.url.path
        if path in self._EXEMPT_PATHS or path.startswith(self._EXEMPT_PREFIXES):
            return True
        return request.method == "GET" and path.startswith(self._EXEMPT_GET_PREFIXES)

    async def dispatch(self, request: Request, call_next):
        if self._is_exempt(request):
            return await call_next(request)

        allowed, reason, client_id = await self.manager.check_request(request)

        if not allowed:
            # Lazy %-formatting: the message is only rendered if the record is emitted.
            logger.warning("Rate limit exceeded for %s: %s", client_id, reason)
            # NOTE(review): Retry-After is a fixed 60 s even when the hourly or
            # daily window is the one that is exhausted.
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )

        # Expose client info to endpoints (e.g. for start/end_translation bookkeeping).
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]

        return await call_next(request)
|
|
|
|
|
|
# Module-level singleton. config=None makes RateLimitManager fall back to the
# default RateLimitConfig().
rate_limit_manager: RateLimitManager = RateLimitManager(config=None)
|