feat: Add SaaS robustness middleware - Rate limiting with token bucket and sliding window algorithms - Input validation (file, language, provider) - Security headers middleware (CSP, XSS protection, etc.) - Automatic file cleanup with TTL tracking - Memory and disk monitoring - Enhanced health check and metrics endpoints - Request logging with unique IDs

This commit is contained in:
2025-11-30 19:25:09 +01:00
parent 8c7716bf4d
commit 500502440c
9 changed files with 1681 additions and 26 deletions

62
middleware/__init__.py Normal file
View File

@@ -0,0 +1,62 @@
"""
Middleware package for SaaS robustness
This package provides:
- Rate limiting: Protect against abuse and ensure fair usage
- Validation: Validate all inputs before processing
- Security: Security headers, request logging, error handling
- Cleanup: Automatic file cleanup and resource management
"""
from .rate_limiting import (
RateLimitConfig,
RateLimitManager,
RateLimitMiddleware,
ClientRateLimiter,
)
from .validation import (
ValidationError,
ValidationResult,
FileValidator,
LanguageValidator,
ProviderValidator,
InputSanitizer,
)
from .security import (
SecurityHeadersMiddleware,
RequestLoggingMiddleware,
ErrorHandlingMiddleware,
)
from .cleanup import (
FileCleanupManager,
MemoryMonitor,
HealthChecker,
create_cleanup_manager,
)
# Public API of the middleware package, grouped to mirror the submodule
# imports above.  Keep this list in sync when adding new middleware.
__all__ = [
    # Rate limiting
    "RateLimitConfig",
    "RateLimitManager",
    "RateLimitMiddleware",
    "ClientRateLimiter",
    # Validation
    "ValidationError",
    "ValidationResult",
    "FileValidator",
    "LanguageValidator",
    "ProviderValidator",
    "InputSanitizer",
    # Security
    "SecurityHeadersMiddleware",
    "RequestLoggingMiddleware",
    "ErrorHandlingMiddleware",
    # Cleanup
    "FileCleanupManager",
    "MemoryMonitor",
    "HealthChecker",
    "create_cleanup_manager",
]

400
middleware/cleanup.py Normal file
View File

@@ -0,0 +1,400 @@
"""
Cleanup and Resource Management for SaaS robustness
Automatic cleanup of temporary files and resources
"""
import os
import time
import asyncio
import threading
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Set
import logging
logger = logging.getLogger(__name__)
class FileCleanupManager:
    """Manages automatic cleanup of temporary and output files.

    Files older than ``max_file_age_hours`` in the upload/output/temp
    directories are deleted by a periodic background task, and total disk
    usage is capped at ``max_total_size_gb`` (oldest files evicted first).
    Individual files can also be tracked with an explicit TTL via
    :meth:`track_file`, and shielded from deletion while in use via
    :meth:`protect_file` / :meth:`unprotect_file`.
    """

    def __init__(
        self,
        upload_dir: Path,
        output_dir: Path,
        temp_dir: Path,
        max_file_age_hours: int = 1,
        cleanup_interval_minutes: int = 10,
        max_total_size_gb: float = 10.0
    ):
        self.upload_dir = Path(upload_dir)
        self.output_dir = Path(output_dir)
        self.temp_dir = Path(temp_dir)
        self.max_file_age_seconds = max_file_age_hours * 3600
        self.cleanup_interval = cleanup_interval_minutes * 60
        self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)
        self._running = False
        self._task: Optional[asyncio.Task] = None
        # Files currently being processed; never deleted while protected.
        self._protected_files: Set[str] = set()
        # filepath -> {"created": ts, "ttl_minutes": n, "expires_at": ts}
        self._tracked_files: dict = {}
        # Guards _protected_files, _tracked_files and _stats: the cleanup
        # task and request handlers may touch them concurrently.
        # NOTE: threading.Lock is NOT reentrant — helpers that take the
        # lock (e.g. is_protected) must not be called while holding it.
        self._lock = threading.Lock()
        self._stats = {
            "files_cleaned": 0,
            "bytes_freed": 0,
            "cleanup_runs": 0
        }

    async def track_file(self, filepath: Path, ttl_minutes: int = 60):
        """Register *filepath* for automatic deletion once its TTL expires."""
        now = time.time()
        with self._lock:
            self._tracked_files[str(filepath)] = {
                "created": now,
                "ttl_minutes": ttl_minutes,
                "expires_at": now + (ttl_minutes * 60)
            }

    def get_tracked_files(self) -> list:
        """Return a status entry (path, existence, remaining TTL) per tracked file."""
        now = time.time()
        with self._lock:
            snapshot = list(self._tracked_files.items())
        return [
            {
                "path": filepath,
                "exists": Path(filepath).exists(),
                "expires_in_seconds": max(0, int(info["expires_at"] - now)),
                "ttl_minutes": info["ttl_minutes"]
            }
            for filepath, info in snapshot
        ]

    async def cleanup_expired(self) -> int:
        """Delete tracked files whose TTL has expired.

        Protected files are skipped and retried on the next run; entries
        whose file has already disappeared are simply dropped from
        tracking.  Returns the number of files actually deleted.
        """
        now = time.time()
        cleaned = 0
        with self._lock:
            expired = [fp for fp, info in self._tracked_files.items()
                       if now > info["expires_at"]]
        for filepath in expired:
            path = Path(filepath)
            try:
                # is_protected() takes self._lock, so it must be called
                # outside the lock held above (the lock is not reentrant).
                if self.is_protected(path):
                    continue  # still in use; retry on the next run
                if path.exists():
                    size = path.stat().st_size
                    path.unlink()
                    cleaned += 1
                    with self._lock:
                        self._stats["files_cleaned"] += 1
                        self._stats["bytes_freed"] += size
                    logger.info(f"Cleaned expired file: {filepath}")
                # Either deleted or already gone: stop tracking it.
                with self._lock:
                    self._tracked_files.pop(filepath, None)
            except Exception as e:
                logger.warning(f"Failed to clean expired file {filepath}: {e}")
        return cleaned

    def get_stats(self) -> dict:
        """Snapshot of cumulative cleanup statistics and current disk usage."""
        disk_usage = self.get_disk_usage()
        with self._lock:
            files_cleaned = self._stats["files_cleaned"]
            bytes_freed = self._stats["bytes_freed"]
            runs = self._stats["cleanup_runs"]
            tracked_count = len(self._tracked_files)
        return {
            "files_cleaned_total": files_cleaned,
            "bytes_freed_total_mb": round(bytes_freed / (1024 * 1024), 2),
            "cleanup_runs": runs,
            "tracked_files": tracked_count,
            "disk_usage": disk_usage,
            "is_running": self._running
        }

    def protect_file(self, filepath: Path):
        """Mark a file as protected (being processed) so cleanup skips it."""
        with self._lock:
            self._protected_files.add(str(filepath))

    def unprotect_file(self, filepath: Path):
        """Remove protection from a file, making it eligible for cleanup."""
        with self._lock:
            self._protected_files.discard(str(filepath))

    def is_protected(self, filepath: Path) -> bool:
        """Return True while the file is marked as in use."""
        with self._lock:
            return str(filepath) in self._protected_files

    async def start(self):
        """Start the periodic cleanup background task (idempotent)."""
        if self._running:
            return
        self._running = True
        self._task = asyncio.create_task(self._cleanup_loop())
        logger.info("File cleanup manager started")

    async def stop(self):
        """Cancel the background task and wait for it to finish."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        logger.info("File cleanup manager stopped")

    async def _cleanup_loop(self):
        """Run age/TTL cleanup every ``cleanup_interval`` seconds."""
        while self._running:
            try:
                await self.cleanup()
                await self.cleanup_expired()
                with self._lock:
                    self._stats["cleanup_runs"] += 1
            except Exception as e:
                logger.error(f"Cleanup error: {e}")
            await asyncio.sleep(self.cleanup_interval)

    async def cleanup(self) -> dict:
        """Delete unprotected files older than the configured max age.

        Also enforces the total-size cap afterwards.  Returns per-run
        counters: files_deleted, bytes_freed and a list of error strings.
        """
        stats = {
            "files_deleted": 0,
            "bytes_freed": 0,
            "errors": []
        }
        now = time.time()
        for directory in (self.upload_dir, self.output_dir, self.temp_dir):
            if not directory.exists():
                continue
            for filepath in directory.iterdir():
                if not filepath.is_file() or self.is_protected(filepath):
                    continue
                try:
                    if now - filepath.stat().st_mtime > self.max_file_age_seconds:
                        file_size = filepath.stat().st_size
                        filepath.unlink()
                        stats["files_deleted"] += 1
                        stats["bytes_freed"] += file_size
                        logger.debug(f"Deleted old file: {filepath}")
                except Exception as e:
                    stats["errors"].append(str(e))
                    logger.warning(f"Failed to delete {filepath}: {e}")
        # Force cleanup if total size exceeds limit
        await self._enforce_size_limit(stats)
        if stats["files_deleted"] > 0:
            # Fold this run into the global counters so get_stats() also
            # reflects age/size-based deletions (previously it only
            # counted TTL expirations).
            with self._lock:
                self._stats["files_cleaned"] += stats["files_deleted"]
                self._stats["bytes_freed"] += stats["bytes_freed"]
            mb_freed = stats["bytes_freed"] / (1024 * 1024)
            logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")
        return stats

    async def _enforce_size_limit(self, stats: dict):
        """Delete oldest unprotected files until total size is under the cap."""
        files_with_mtime = []
        total_size = 0
        for directory in (self.upload_dir, self.output_dir, self.temp_dir):
            if not directory.exists():
                continue
            for filepath in directory.iterdir():
                if not filepath.is_file() or self.is_protected(filepath):
                    continue
                try:
                    stat = filepath.stat()
                    files_with_mtime.append((filepath, stat.st_mtime, stat.st_size))
                    total_size += stat.st_size
                except Exception:
                    pass
        if total_size <= self.max_total_size_bytes:
            return
        # Oldest first, so the most recent files survive.
        files_with_mtime.sort(key=lambda x: x[1])
        for filepath, _, size in files_with_mtime:
            if total_size <= self.max_total_size_bytes:
                break
            try:
                filepath.unlink()
                total_size -= size
                stats["files_deleted"] += 1
                stats["bytes_freed"] += size
                logger.info(f"Deleted file to free space: {filepath}")
            except Exception as e:
                stats["errors"].append(str(e))

    def get_disk_usage(self) -> dict:
        """Aggregate file count and size across all managed directories."""
        total_files = 0
        total_size = 0
        for directory in (self.upload_dir, self.output_dir, self.temp_dir):
            if not directory.exists():
                continue
            for filepath in directory.iterdir():
                if filepath.is_file():
                    total_files += 1
                    try:
                        total_size += filepath.stat().st_size
                    except Exception:
                        pass  # file may vanish between iterdir and stat
        return {
            "total_files": total_files,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
            "usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0,
            "directories": {
                "uploads": str(self.upload_dir),
                "outputs": str(self.output_dir),
                "temp": str(self.temp_dir)
            }
        }
class MemoryMonitor:
    """Reports process/system memory usage, degrading gracefully without psutil."""

    def __init__(self, max_memory_percent: float = 80.0):
        # System memory percentage above which check_memory() reports trouble.
        self.max_memory_percent = max_memory_percent
        # Callbacks registered via on_high_memory (invocation is up to callers).
        self._high_memory_callbacks = []

    def get_memory_usage(self) -> dict:
        """Return memory figures, or an {"error": ...} dict when unavailable."""
        try:
            import psutil
            proc_mem = psutil.Process().memory_info()
            sys_mem = psutil.virtual_memory()
            mb = 1024 * 1024
            gb = mb * 1024
            return {
                "process_rss_mb": round(proc_mem.rss / mb, 2),
                "process_vms_mb": round(proc_mem.vms / mb, 2),
                "system_total_gb": round(sys_mem.total / gb, 2),
                "system_available_gb": round(sys_mem.available / gb, 2),
                "system_percent": sys_mem.percent
            }
        except ImportError:
            return {"error": "psutil not installed"}
        except Exception as e:
            return {"error": str(e)}

    def check_memory(self) -> bool:
        """True while system memory usage is below the configured ceiling."""
        usage = self.get_memory_usage()
        if "error" in usage:
            # No way to measure: assume we are fine rather than blocking work.
            return True
        return usage.get("system_percent", 0) < self.max_memory_percent

    def on_high_memory(self, callback):
        """Register a callback to be notified of high-memory situations."""
        self._high_memory_callbacks.append(callback)
class HealthChecker:
    """Aggregates memory, disk and translation stats into one health report.

    Thresholds (80% warn / 90% degrade, for both memory and disk) are
    applied to the figures reported by the injected monitors.
    """

    def __init__(self, cleanup_manager: "FileCleanupManager", memory_monitor: "MemoryMonitor"):
        self.cleanup_manager = cleanup_manager
        self.memory_monitor = memory_monitor
        self.start_time = datetime.now()
        self._translation_count = 0
        self._error_count = 0
        # Guards the two counters against concurrent record_translation calls.
        self._lock = threading.Lock()

    def record_translation(self, success: bool = True):
        """Count one translation attempt (and one error when it failed)."""
        with self._lock:
            self._translation_count += 1
            if not success:
                self._error_count += 1

    async def check_health(self) -> dict:
        """Async wrapper around get_health() for async endpoint handlers."""
        return self.get_health()

    def get_health(self) -> dict:
        """Build the full health report: status, issues, counters, resources."""
        memory = self.memory_monitor.get_memory_usage()
        disk = self.cleanup_manager.get_disk_usage()
        # Snapshot both counters under the lock so total/errors are a
        # consistent pair even while translations are being recorded.
        with self._lock:
            total = self._translation_count
            errors = self._error_count
        status = "healthy"
        issues = []
        if "error" not in memory:
            if memory.get("system_percent", 0) > 90:
                status = "degraded"
                issues.append("High memory usage")
            elif memory.get("system_percent", 0) > 80:
                issues.append("Memory usage elevated")
        if disk.get("usage_percent", 0) > 90:
            status = "degraded"
            issues.append("High disk usage")
        elif disk.get("usage_percent", 0) > 80:
            issues.append("Disk usage elevated")
        uptime = datetime.now() - self.start_time
        success_rate = round(
            ((total - errors) / total * 100) if total > 0 else 100, 1
        )
        return {
            "status": status,
            "issues": issues,
            "uptime_seconds": int(uptime.total_seconds()),
            "uptime_human": str(uptime).split('.')[0],
            "translations": {
                "total": total,
                "errors": errors,
                "success_rate": success_rate
            },
            "memory": memory,
            "disk": disk,
            "cleanup_service": self.cleanup_manager.get_stats(),
            "timestamp": datetime.now().isoformat()
        }
# Create default instances
def create_cleanup_manager(config) -> FileCleanupManager:
    """Build a FileCleanupManager from an application config object.

    The three directory attributes are required; the tuning attributes
    fall back to defaults when the config object does not define them.
    """
    age_hours = getattr(config, 'MAX_FILE_AGE_HOURS', 1)
    interval_minutes = getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10)
    size_cap_gb = getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0)
    return FileCleanupManager(
        upload_dir=config.UPLOAD_DIR,
        output_dir=config.OUTPUT_DIR,
        temp_dir=config.TEMP_DIR,
        max_file_age_hours=age_hours,
        cleanup_interval_minutes=interval_minutes,
        max_total_size_gb=size_cap_gb,
    )

328
middleware/rate_limiting.py Normal file
View File

@@ -0,0 +1,328 @@
"""
Rate Limiting Middleware for SaaS robustness
Protects against abuse and ensures fair usage
"""
import time
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)
@dataclass
class RateLimitConfig:
    """Tunable limits for per-client rate limiting.

    All counters are applied per client (IP address) by ClientRateLimiter;
    whitelisted IPs bypass every limit.
    """
    # General request limits per sliding window
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000
    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5
    # File size limits (MB)
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500
    # Burst protection: max requests in roughly 1 second (token bucket size)
    burst_limit: int = 10
    # IPs exempt from all rate limiting (loopback by default)
    whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])
class TokenBucket:
    """Classic token bucket: tokens refill continuously, consumption is atomic."""

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens credited per second
        self.tokens = capacity
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Atomically take *tokens* from the bucket; False when not enough."""
        async with self._lock:
            self._refill()
            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True

    def _refill(self):
        """Credit tokens for the time elapsed since the last refill, up to capacity."""
        now = time.time()
        gained = (now - self.last_refill) * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + gained)
        self.last_refill = now
class SlidingWindowCounter:
    """Rate limiter that counts request timestamps within a rolling window."""

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []  # timestamps of accepted requests
        self._lock = asyncio.Lock()

    async def is_allowed(self) -> bool:
        """Record and allow the request unless the window is already full."""
        async with self._lock:
            now = time.time()
            cutoff = now - self.window_seconds
            # Drop timestamps that have fallen out of the window.
            self.requests = [ts for ts in self.requests if ts > cutoff]
            if len(self.requests) >= self.max_requests:
                return False
            self.requests.append(now)
            return True

    @property
    def current_count(self) -> int:
        """Number of requests recorded within the current window."""
        cutoff = time.time() - self.window_seconds
        return sum(1 for ts in self.requests if ts > cutoff)
class ClientRateLimiter:
    """Rate-limiting state for a single client (IP address).

    Combines a burst token bucket, sliding windows for general requests,
    translation-specific windows, a concurrency counter and an hourly
    data-volume budget.
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # (timestamp, size_mb) pairs
        # Guards concurrent_translations and total_size_hour.
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check burst/minute/hour/day limits; returns (allowed, reason)."""
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."
        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."
        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."
        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."
        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check concurrency, frequency and data-volume limits for a translation.

        All shared state is examined under a SINGLE acquisition of
        self._lock: asyncio.Lock is not reentrant, so acquiring it a
        second time within the same call path (as the original code did
        for the hourly size budget) would deadlock the caller.
        """
        async with self._lock:
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."
            if not await self.translation_minute.is_allowed():
                return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."
            if not await self.translation_hour.is_allowed():
                return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."
            # Prune entries older than one hour, then check the data budget.
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)
            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."
            self.total_size_hour.append((now, file_size_mb))
            return True, ""

    async def start_translation(self):
        """Mark the start of a translation (raises the concurrency count)."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark the end of a translation (never drops the count below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Current usage counts across all windows for this client."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
class RateLimitManager:
    """Holds one ClientRateLimiter per client and routes checks to it."""

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        # Lazily creates a limiter the first time a client ID is seen.
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Cleanup old clients every hour
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Identify the client: proxy headers first, then the socket peer."""
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # The leftmost entry is the originating client behind proxies.
            return forwarded.split(",")[0].strip()
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip
        return request.client.host if request.client else "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """True when the client bypasses all rate limiting."""
        return client_id in self.config.whitelist_ips

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Apply the general request limits; returns (allowed, reason, client_id)."""
        client_id = self.get_client_id(request)
        self._total_requests += 1
        if self.is_whitelisted(client_id):
            return True, "", client_id
        allowed, reason = await self.clients[client_id].check_request()
        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Apply the translation-specific limits for the requesting client."""
        client_id = self.get_client_id(request)
        self._total_translations += 1
        if self.is_whitelisted(client_id):
            return True, ""
        return await self.clients[client_id].check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Boolean-only translation check for an already-resolved client ID."""
        if self.is_whitelisted(client_id):
            return True
        allowed, _reason = await self.clients[client_id].check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Per-client usage counters plus identification info."""
        client_id = self.get_client_id(request)
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **self.clients[client_id].get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Usage summary for *client_id*, without creating a limiter for it."""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}
        stats = self.clients[client_id].get_stats()
        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Global totals plus the effective limit configuration."""
        cfg = self.config
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": cfg.requests_per_minute,
                "requests_per_hour": cfg.requests_per_hour,
                "translations_per_minute": cfg.translations_per_minute,
                "translations_per_hour": cfg.translations_per_hour,
                "max_concurrent_translations": cfg.max_concurrent_translations
            }
        }
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rejects requests over the configured limits with an HTTP 429 response."""

    # Paths that are never rate limited (health checks, root, docs UI).
    _EXEMPT_PATHS = {"/health", "/", "/docs", "/openapi.json", "/redoc"}

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        if path in self._EXEMPT_PATHS or path.startswith("/static"):
            return await call_next(request)
        allowed, reason, client_id = await self.manager.check_request(request)
        if not allowed:
            logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )
        # Expose the client identity and its limiter to endpoint handlers.
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]
        return await call_next(request)
# Module-level singleton shared by the middleware and the endpoint handlers.
rate_limit_manager = RateLimitManager()

142
middleware/security.py Normal file
View File

@@ -0,0 +1,142 @@
"""
Security Headers Middleware for SaaS robustness
Adds security headers to all responses
"""
import logging
from typing import Optional

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach standard security headers to every response.

    Recognized config keys:
        enable_hsts (bool): emit Strict-Transport-Security; only enable on
            HTTPS deployments (default False).
    """

    def __init__(self, app, config: Optional[dict] = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        # Prevent clickjacking
        response.headers["X-Frame-Options"] = "DENY"
        # Prevent MIME type sniffing
        response.headers["X-Content-Type-Options"] = "nosniff"
        # Legacy XSS filter header (deprecated in modern browsers; the CSP
        # below is the actual protection — kept for older clients)
        response.headers["X-XSS-Protection"] = "1; mode=block"
        # Referrer policy
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Permissions policy
        response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
        # Content Security Policy (adjust for your frontend); skipped on the
        # API docs paths — presumably because the docs UI loads assets a
        # strict CSP would block (TODO confirm)
        if not request.url.path.startswith("/docs") and not request.url.path.startswith("/redoc"):
            response.headers["Content-Security-Policy"] = (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data: blob:; "
                "font-src 'self' data:; "
                "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
                "worker-src 'self' blob:; "
            )
        # HSTS (only in production with HTTPS)
        if self.config.get("enable_hsts", False):
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        return response
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Assign each request a short ID and log its start, outcome and timing."""

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        # Stored but not read anywhere in this class; presumably reserved
        # for future request-body logging — TODO confirm.
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next) -> Response:
        import time
        import uuid

        # Short random ID correlating all log lines for this request; also
        # exposed to handlers via request.state and the response header.
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id
        client_ip = self._get_client_ip(request)
        started = time.time()
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"from {client_ip} - Started"
        )
        try:
            response = await call_next(request)
        except Exception as e:
            elapsed = time.time() - started
            logger.error(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- ERROR in {elapsed:.3f}s: {str(e)}"
            )
            raise
        elapsed = time.time() - started
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"- {response.status_code} in {elapsed:.3f}s"
        )
        response.headers["X-Request-ID"] = request_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Resolve the originating client IP, honoring common proxy headers."""
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip
        return request.client.host if request.client else "unknown"
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Last-resort handler converting unhandled exceptions into JSON 500s."""

    async def dispatch(self, request: Request, call_next) -> Response:
        from starlette.responses import JSONResponse
        try:
            return await call_next(request)
        except Exception as exc:
            rid = getattr(request.state, 'request_id', 'unknown')
            # Full traceback goes to the log; the client only ever sees a
            # generic message so internals are not exposed.
            logger.exception(f"[{rid}] Unhandled exception: {str(exc)}")
            return JSONResponse(
                status_code=500,
                content={
                    "error": "internal_server_error",
                    "message": "An unexpected error occurred. Please try again later.",
                    "request_id": rid
                }
            )

440
middleware/validation.py Normal file
View File

@@ -0,0 +1,440 @@
"""
Input Validation Module for SaaS robustness
Validates all user inputs before processing
"""
import re
import magic
from pathlib import Path
from typing import Optional, List, Set
from fastapi import UploadFile, HTTPException
import logging
logger = logging.getLogger(__name__)
class ValidationError(Exception):
    """Validation failure carrying a user-facing message and machine code.

    Attributes:
        message: human-readable explanation shown to the client.
        code: stable identifier for programmatic handling.
        details: optional structured context about the failure.
    """

    def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}
class ValidationResult:
    """Outcome of a validation check.

    Attributes:
        is_valid: overall verdict.
        errors: blocking problems (non-empty implies is_valid is False).
        warnings: non-blocking observations.
        data: values extracted during validation (e.g. safe filename, size).
    """

    def __init__(self, is_valid: bool = True, errors: List[str] = None, warnings: List[str] = None, data: dict = None):
        # Falsy inputs are normalized to fresh empty containers.
        self.is_valid = is_valid
        self.errors = errors or []
        self.warnings = warnings or []
        self.data = data or {}
class FileValidator:
"""Validates uploaded files for security and compatibility"""
# Allowed MIME types mapped to extensions
ALLOWED_MIME_TYPES = {
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
}
# Magic bytes for Office Open XML files (ZIP format)
OFFICE_MAGIC_BYTES = b"PK\x03\x04"
def __init__(
self,
max_size_mb: int = 50,
allowed_extensions: Set[str] = None,
scan_content: bool = True
):
self.max_size_bytes = max_size_mb * 1024 * 1024
self.max_size_mb = max_size_mb
self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
self.scan_content = scan_content
async def validate_async(self, file: UploadFile) -> ValidationResult:
"""
Validate an uploaded file asynchronously
Returns ValidationResult with is_valid, errors, warnings
"""
errors = []
warnings = []
data = {}
try:
# Validate filename
if not file.filename:
errors.append("Filename is required")
return ValidationResult(is_valid=False, errors=errors)
# Sanitize filename
try:
safe_filename = self._sanitize_filename(file.filename)
data["safe_filename"] = safe_filename
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors)
# Validate extension
try:
extension = self._validate_extension(safe_filename)
data["extension"] = extension
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors)
# Read file content for validation
content = await file.read()
await file.seek(0) # Reset for later processing
# Validate file size
file_size = len(content)
data["size_bytes"] = file_size
data["size_mb"] = round(file_size / (1024*1024), 2)
if file_size > self.max_size_bytes:
errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
return ValidationResult(is_valid=False, errors=errors, data=data)
if file_size == 0:
errors.append("File is empty")
return ValidationResult(is_valid=False, errors=errors, data=data)
# Warn about large files
if file_size > self.max_size_bytes * 0.8:
warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")
# Validate magic bytes
if self.scan_content:
try:
self._validate_magic_bytes(content, extension)
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors, data=data)
# Validate MIME type
try:
mime_type = self._detect_mime_type(content)
data["mime_type"] = mime_type
self._validate_mime_type(mime_type, extension)
except ValidationError as e:
warnings.append(f"MIME type warning: {e.message}")
except Exception:
warnings.append("Could not verify MIME type")
data["original_filename"] = file.filename
return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)
except Exception as e:
logger.error(f"Validation error: {str(e)}")
errors.append(f"Validation failed: {str(e)}")
return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)
async def validate(self, file: UploadFile) -> dict:
"""
Validate an uploaded file
Returns validation info dict or raises ValidationError
"""
# Validate filename
if not file.filename:
raise ValidationError(
"Filename is required",
code="missing_filename"
)
# Sanitize filename
safe_filename = self._sanitize_filename(file.filename)
# Validate extension
extension = self._validate_extension(safe_filename)
# Read file content for validation
content = await file.read()
await file.seek(0) # Reset for later processing
# Validate file size
file_size = len(content)
if file_size > self.max_size_bytes:
raise ValidationError(
f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
code="file_too_large",
details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
)
if file_size == 0:
raise ValidationError(
"File is empty",
code="empty_file"
)
# Validate magic bytes (file signature)
if self.scan_content:
self._validate_magic_bytes(content, extension)
# Validate MIME type
mime_type = self._detect_mime_type(content)
self._validate_mime_type(mime_type, extension)
return {
"original_filename": file.filename,
"safe_filename": safe_filename,
"extension": extension,
"size_bytes": file_size,
"size_mb": round(file_size / (1024*1024), 2),
"mime_type": mime_type
}
def _sanitize_filename(self, filename: str) -> str:
"""Sanitize filename to prevent path traversal and other attacks"""
# Remove path components
filename = Path(filename).name
# Remove null bytes and control characters
filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)
# Remove potentially dangerous characters
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
# Limit length
if len(filename) > 255:
name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
filename = name[:250] + ('.' + ext if ext else '')
# Ensure not empty after sanitization
if not filename or filename.strip() == '':
raise ValidationError(
"Invalid filename",
code="invalid_filename"
)
return filename
def _validate_extension(self, filename: str) -> str:
"""Validate and return the file extension"""
if '.' not in filename:
raise ValidationError(
f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
code="missing_extension",
details={"allowed_extensions": list(self.allowed_extensions)}
)
extension = '.' + filename.rsplit('.', 1)[1].lower()
if extension not in self.allowed_extensions:
raise ValidationError(
f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
code="unsupported_file_type",
details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
)
return extension
def _validate_magic_bytes(self, content: bytes, extension: str):
"""Validate file magic bytes match expected format"""
# All supported formats are Office Open XML (ZIP-based)
if not content.startswith(self.OFFICE_MAGIC_BYTES):
raise ValidationError(
"File content does not match expected format. The file may be corrupted or not a valid Office document.",
code="invalid_file_content"
)
def _detect_mime_type(self, content: bytes) -> str:
"""Detect MIME type from file content"""
try:
mime = magic.Magic(mime=True)
return mime.from_buffer(content)
except Exception:
# Fallback to basic detection
if content.startswith(self.OFFICE_MAGIC_BYTES):
return "application/zip"
return "application/octet-stream"
def _validate_mime_type(self, mime_type: str, extension: str):
"""Validate MIME type matches extension"""
# Office Open XML files may be detected as ZIP
allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]
if mime_type not in allowed_mimes:
raise ValidationError(
f"Invalid file type detected. Expected Office document, got: {mime_type}",
code="invalid_mime_type",
details={"detected_mime": mime_type}
)
class LanguageValidator:
    """Validates and normalizes language codes (ISO 639-1 plus variants)."""

    # Accepted codes; "auto" requests source-language auto-detection.
    SUPPORTED_LANGUAGES = {
        # ISO 639-1 codes
        "af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
        "bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
        "da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
        "de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
        "is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
        "rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
        "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
        "ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
        "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
        "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
        "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
        "zu", "auto"
    }

    # Human-readable names for the most common codes (fallback: upper-cased code).
    LANGUAGE_NAMES = {
        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
        "it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
        "zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
        "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
        "nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
        "da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
        "el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
        "uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
    }

    @classmethod
    def validate(cls, language_code: str, field_name: str = "language") -> str:
        """Validate and normalize a language code.

        Returns the canonical code. Raises ValidationError when the
        code is missing or unsupported.
        """
        if not language_code:
            raise ValidationError(
                f"{field_name} is required",
                code="missing_language"
            )
        # Normalize
        normalized = language_code.strip().lower()
        # Map aliases AND restore the canonical casing of region-tagged
        # codes. Without the zh-cn/zh-tw entries, a valid "zh-CN" input
        # would be rejected: lowercasing yields "zh-cn", which is not in
        # SUPPORTED_LANGUAGES (it stores the mixed-case "zh-CN").
        if normalized in ("chinese", "cn", "zh-cn"):
            normalized = "zh-CN"
        elif normalized in ("chinese-traditional", "tw", "zh-tw"):
            normalized = "zh-TW"
        if normalized not in cls.SUPPORTED_LANGUAGES:
            raise ValidationError(
                f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
                code="unsupported_language",
                details={"language": language_code}
            )
        return normalized

    @classmethod
    def get_language_name(cls, code: str) -> str:
        """Get human-readable language name"""
        return cls.LANGUAGE_NAMES.get(code, code.upper())
class ProviderValidator:
    """Validates translation provider configuration"""

    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm"}

    # Providers that cannot run without an API key:
    # provider -> (kwarg name, error message, error code)
    _REQUIRED_KEYS = {
        "deepl": (
            "deepl_api_key",
            "DeepL API key is required when using DeepL provider",
            "missing_deepl_key",
        ),
        "openai": (
            "openai_api_key",
            "OpenAI API key is required when using OpenAI provider",
            "missing_openai_key",
        ),
    }

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate a provider name and its provider-specific settings.

        Returns ``{"provider": <normalized name>, "validated": True}``.
        Raises ValidationError when the provider is missing, unknown,
        or lacks a required credential.
        """
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )
        name = provider.strip().lower()
        if name not in cls.SUPPORTED_PROVIDERS:
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(cls.SUPPORTED_PROVIDERS)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": list(cls.SUPPORTED_PROVIDERS)}
            )
        # Provider-specific validation.
        if name in cls._REQUIRED_KEYS:
            kwarg_name, message, error_code = cls._REQUIRED_KEYS[name]
            if not kwargs.get(kwarg_name):
                raise ValidationError(message, code=error_code)
        elif name == "ollama":
            # Ollama needs no API key; just warn when no model was chosen.
            if not kwargs.get("ollama_model", ""):
                logger.warning("No Ollama model specified, will use default")
        return {"provider": name, "validated": True}
class InputSanitizer:
    """Sanitizes user inputs to prevent injection attacks"""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Strip null bytes, enforce *max_length*, and trim whitespace."""
        if not text:
            return ""
        # Order matters: drop nulls, cap the length, then strip edges.
        cleaned = text.replace('\x00', '')[:max_length]
        return cleaned.strip()

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Keep only [A-Za-z0-9-], cap at 10 chars, lower-case.

        Falls back to "auto" when the input is empty or nothing
        survives sanitization.
        """
        if not code:
            return "auto"
        cleaned = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())[:10]
        return cleaned.lower() if cleaned else "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Trim whitespace, require an http(s) scheme, drop trailing slashes.

        Raises ValidationError when the URL does not start with
        http:// or https://.
        """
        if not url:
            return ""
        trimmed = url.strip()
        if not re.match(r'^https?://', trimmed, re.IGNORECASE):
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )
        return trimmed.rstrip('/')

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Trim surrounding whitespace; never log the value."""
        return key.strip() if key else ""
# Default validators
# Shared module-level FileValidator instance built with its default
# settings — presumably imported by the request handlers; confirm usage.
file_validator = FileValidator()