feat: Add SaaS robustness middleware - Rate limiting with token bucket and sliding window algorithms - Input validation (file, language, provider) - Security headers middleware (CSP, XSS protection, etc.) - Automatic file cleanup with TTL tracking - Memory and disk monitoring - Enhanced health check and metrics endpoints - Request logging with unique IDs
This commit is contained in:
62
middleware/__init__.py
Normal file
62
middleware/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Middleware package for SaaS robustness
|
||||
|
||||
This package provides:
|
||||
- Rate limiting: Protect against abuse and ensure fair usage
|
||||
- Validation: Validate all inputs before processing
|
||||
- Security: Security headers, request logging, error handling
|
||||
- Cleanup: Automatic file cleanup and resource management
|
||||
"""
|
||||
|
||||
from .rate_limiting import (
|
||||
RateLimitConfig,
|
||||
RateLimitManager,
|
||||
RateLimitMiddleware,
|
||||
ClientRateLimiter,
|
||||
)
|
||||
|
||||
from .validation import (
|
||||
ValidationError,
|
||||
ValidationResult,
|
||||
FileValidator,
|
||||
LanguageValidator,
|
||||
ProviderValidator,
|
||||
InputSanitizer,
|
||||
)
|
||||
|
||||
from .security import (
|
||||
SecurityHeadersMiddleware,
|
||||
RequestLoggingMiddleware,
|
||||
ErrorHandlingMiddleware,
|
||||
)
|
||||
|
||||
from .cleanup import (
|
||||
FileCleanupManager,
|
||||
MemoryMonitor,
|
||||
HealthChecker,
|
||||
create_cleanup_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Rate limiting
|
||||
"RateLimitConfig",
|
||||
"RateLimitManager",
|
||||
"RateLimitMiddleware",
|
||||
"ClientRateLimiter",
|
||||
# Validation
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"FileValidator",
|
||||
"LanguageValidator",
|
||||
"ProviderValidator",
|
||||
"InputSanitizer",
|
||||
# Security
|
||||
"SecurityHeadersMiddleware",
|
||||
"RequestLoggingMiddleware",
|
||||
"ErrorHandlingMiddleware",
|
||||
# Cleanup
|
||||
"FileCleanupManager",
|
||||
"MemoryMonitor",
|
||||
"HealthChecker",
|
||||
"create_cleanup_manager",
|
||||
]
|
||||
400
middleware/cleanup.py
Normal file
400
middleware/cleanup.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
Cleanup and Resource Management for SaaS robustness
|
||||
Automatic cleanup of temporary files and resources
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCleanupManager:
    """Manages automatic cleanup of temporary and output files.

    Three mechanisms run from a single background asyncio task:
      1. Age-based cleanup: unprotected files older than ``max_file_age_hours``
         in the upload/output/temp directories are deleted.
      2. TTL cleanup: files registered via :meth:`track_file` are deleted once
         their TTL expires.
      3. Size enforcement: the oldest files are evicted until the combined
         size of all three directories is back under ``max_total_size_gb``.

    Files marked via :meth:`protect_file` (e.g. currently being processed)
    are never deleted.
    """

    def __init__(
        self,
        upload_dir: Path,
        output_dir: Path,
        temp_dir: Path,
        max_file_age_hours: int = 1,
        cleanup_interval_minutes: int = 10,
        max_total_size_gb: float = 10.0
    ):
        """
        Args:
            upload_dir, output_dir, temp_dir: directories to watch.
            max_file_age_hours: age beyond which files are deleted.
            cleanup_interval_minutes: period of the background loop.
            max_total_size_gb: combined size cap across all directories.
        """
        self.upload_dir = Path(upload_dir)
        self.output_dir = Path(output_dir)
        self.temp_dir = Path(temp_dir)
        self.max_file_age_seconds = max_file_age_hours * 3600
        self.cleanup_interval = cleanup_interval_minutes * 60
        self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)

        self._running = False
        self._task: Optional[asyncio.Task] = None
        self._protected_files: Set[str] = set()
        self._tracked_files: dict = {}  # filepath -> {created, ttl_minutes, expires_at}
        # threading.Lock (not asyncio.Lock): the critical sections are short,
        # never await, and may be entered from sync callers as well.
        self._lock = threading.Lock()
        self._stats = {
            "files_cleaned": 0,
            "bytes_freed": 0,
            "cleanup_runs": 0
        }

    @property
    def _directories(self):
        """The three managed directories, in scan order."""
        return (self.upload_dir, self.output_dir, self.temp_dir)

    async def track_file(self, filepath: Path, ttl_minutes: int = 60):
        """Track a file for automatic cleanup after *ttl_minutes* expire."""
        now = time.time()
        with self._lock:
            self._tracked_files[str(filepath)] = {
                "created": now,
                "ttl_minutes": ttl_minutes,
                "expires_at": now + ttl_minutes * 60
            }

    def get_tracked_files(self) -> list:
        """Return a status dict for every tracked file.

        The tracking table is snapshotted first so filesystem calls
        (``Path.exists``) happen outside the lock.
        """
        now = time.time()
        with self._lock:
            snapshot = list(self._tracked_files.items())

        return [
            {
                "path": filepath,
                "exists": Path(filepath).exists(),
                "expires_in_seconds": max(0, int(info["expires_at"] - now)),
                "ttl_minutes": info["ttl_minutes"]
            }
            for filepath, info in snapshot
        ]

    async def cleanup_expired(self) -> int:
        """Delete tracked files whose TTL has expired.

        Protected files stay in the tracking table and are retried on a later
        run (fix: previously an expired-but-protected file was dropped from
        tracking without being deleted, so it could leak permanently).

        Returns:
            Number of files actually deleted.
        """
        now = time.time()
        cleaned = 0

        with self._lock:
            expired = [fp for fp, info in self._tracked_files.items()
                       if now > info["expires_at"]]

        for filepath in expired:
            path = Path(filepath)
            if self.is_protected(path):
                continue  # still in use: keep tracked, retry next run
            try:
                if path.exists():
                    size = path.stat().st_size
                    path.unlink()
                    cleaned += 1
                    with self._lock:
                        self._stats["files_cleaned"] += 1
                        self._stats["bytes_freed"] += size
                    logger.info(f"Cleaned expired file: {filepath}")
                # Deleted (or already gone): forget it.
                with self._lock:
                    self._tracked_files.pop(filepath, None)
            except Exception as e:
                logger.warning(f"Failed to clean expired file {filepath}: {e}")

        return cleaned

    def get_stats(self) -> dict:
        """Cleanup statistics plus a current disk-usage snapshot."""
        disk_usage = self.get_disk_usage()

        # Read counters and tracking size under the lock in one go.
        with self._lock:
            tracked_count = len(self._tracked_files)
            files_cleaned = self._stats["files_cleaned"]
            bytes_freed = self._stats["bytes_freed"]
            cleanup_runs = self._stats["cleanup_runs"]

        return {
            "files_cleaned_total": files_cleaned,
            "bytes_freed_total_mb": round(bytes_freed / (1024 * 1024), 2),
            "cleanup_runs": cleanup_runs,
            "tracked_files": tracked_count,
            "disk_usage": disk_usage,
            "is_running": self._running
        }

    def protect_file(self, filepath: Path):
        """Mark a file as protected (being processed); it will not be deleted."""
        with self._lock:
            self._protected_files.add(str(filepath))

    def unprotect_file(self, filepath: Path):
        """Remove protection from a file."""
        with self._lock:
            self._protected_files.discard(str(filepath))

    def is_protected(self, filepath: Path) -> bool:
        """True if the file is currently protected from deletion."""
        with self._lock:
            return str(filepath) in self._protected_files

    async def start(self):
        """Start the periodic cleanup background task (idempotent)."""
        if self._running:
            return

        self._running = True
        self._task = asyncio.create_task(self._cleanup_loop())
        logger.info("File cleanup manager started")

    async def stop(self):
        """Cancel the cleanup task and wait for it to finish."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass  # expected on cancellation
        logger.info("File cleanup manager stopped")

    async def _cleanup_loop(self):
        """Background loop: run both cleanup passes every interval."""
        while self._running:
            try:
                await self.cleanup()
                await self.cleanup_expired()
                with self._lock:
                    self._stats["cleanup_runs"] += 1
            except Exception as e:
                # A single failed pass must not kill the loop.
                logger.error(f"Cleanup error: {e}")

            await asyncio.sleep(self.cleanup_interval)

    async def cleanup(self) -> dict:
        """Delete unprotected files older than the max age, then enforce
        the total-size cap.

        Returns:
            Dict with ``files_deleted``, ``bytes_freed`` and a list of
            error strings (best-effort: errors never abort the pass).
        """
        stats = {
            "files_deleted": 0,
            "bytes_freed": 0,
            "errors": []
        }
        now = time.time()

        for directory in self._directories:
            if not directory.exists():
                continue

            for filepath in directory.iterdir():
                if not filepath.is_file() or self.is_protected(filepath):
                    continue

                try:
                    # Single stat() call gives both age and size.
                    st = filepath.stat()
                    if now - st.st_mtime > self.max_file_age_seconds:
                        filepath.unlink()
                        stats["files_deleted"] += 1
                        stats["bytes_freed"] += st.st_size
                        logger.debug(f"Deleted old file: {filepath}")
                except Exception as e:
                    stats["errors"].append(str(e))
                    logger.warning(f"Failed to delete {filepath}: {e}")

        # Force cleanup if total size exceeds the configured cap.
        await self._enforce_size_limit(stats)

        if stats["files_deleted"] > 0:
            mb_freed = stats["bytes_freed"] / (1024 * 1024)
            logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")

        return stats

    async def _enforce_size_limit(self, stats: dict):
        """Delete the oldest unprotected files until total size is under the cap."""
        candidates = []  # (filepath, mtime, size)
        total_size = 0

        for directory in self._directories:
            if not directory.exists():
                continue

            for filepath in directory.iterdir():
                if not filepath.is_file() or self.is_protected(filepath):
                    continue
                try:
                    st = filepath.stat()
                except Exception:
                    continue  # file vanished between listing and stat
                candidates.append((filepath, st.st_mtime, st.st_size))
                total_size += st.st_size

        # Under the limit: nothing to do.
        if total_size <= self.max_total_size_bytes:
            return

        # Oldest first.
        candidates.sort(key=lambda item: item[1])

        for filepath, _, size in candidates:
            if total_size <= self.max_total_size_bytes:
                break
            try:
                filepath.unlink()
                total_size -= size
                stats["files_deleted"] += 1
                stats["bytes_freed"] += size
                logger.info(f"Deleted file to free space: {filepath}")
            except Exception as e:
                stats["errors"].append(str(e))

    def get_disk_usage(self) -> dict:
        """Current file count and total size across all managed directories."""
        total_files = 0
        total_size = 0

        for directory in self._directories:
            if not directory.exists():
                continue

            for filepath in directory.iterdir():
                if not filepath.is_file():
                    continue
                total_files += 1
                try:
                    total_size += filepath.stat().st_size
                except Exception:
                    pass  # best effort: skip files we cannot stat

        return {
            "total_files": total_files,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
            "usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0,
            "directories": {
                "uploads": str(self.upload_dir),
                "outputs": str(self.output_dir),
                "temp": str(self.temp_dir)
            }
        }
|
||||
|
||||
|
||||
class MemoryMonitor:
    """Watches process/system memory and flags when usage is too high."""

    def __init__(self, max_memory_percent: float = 80.0):
        self.max_memory_percent = max_memory_percent
        self._high_memory_callbacks = []

    def get_memory_usage(self) -> dict:
        """Return process + system memory stats, or a dict with an ``error`` key.

        ``psutil`` is imported lazily so the monitor degrades gracefully when
        the package is not installed.
        """
        try:
            import psutil
            proc_mem = psutil.Process().memory_info()
            sys_mem = psutil.virtual_memory()
        except ImportError:
            return {"error": "psutil not installed"}
        except Exception as exc:
            return {"error": str(exc)}

        mb = 1024 * 1024
        gb = mb * 1024
        return {
            "process_rss_mb": round(proc_mem.rss / mb, 2),
            "process_vms_mb": round(proc_mem.vms / mb, 2),
            "system_total_gb": round(sys_mem.total / gb, 2),
            "system_available_gb": round(sys_mem.available / gb, 2),
            "system_percent": sys_mem.percent
        }

    def check_memory(self) -> bool:
        """True when system memory usage is below the configured ceiling."""
        usage = self.get_memory_usage()
        if "error" in usage:
            return True  # cannot measure -> assume OK
        return usage.get("system_percent", 0) < self.max_memory_percent

    def on_high_memory(self, callback):
        """Register a callback to be invoked in high-memory situations."""
        self._high_memory_callbacks.append(callback)
|
||||
|
||||
|
||||
class HealthChecker:
    """Aggregates memory, disk, and translation metrics into one health report."""

    def __init__(self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor):
        self.cleanup_manager = cleanup_manager
        self.memory_monitor = memory_monitor
        self.start_time = datetime.now()
        self._translation_count = 0
        self._error_count = 0
        self._lock = threading.Lock()

    def record_translation(self, success: bool = True):
        """Record one translation attempt (and its failure, if any)."""
        with self._lock:
            self._translation_count += 1
            self._error_count += 0 if success else 1

    async def check_health(self) -> dict:
        """Async wrapper around :meth:`get_health`."""
        return self.get_health()

    def get_health(self) -> dict:
        """Build the full health report.

        Status is "degraded" when system memory or managed-disk usage exceeds
        90%; usage above 80% only adds an advisory issue.
        """
        memory = self.memory_monitor.get_memory_usage()
        disk = self.cleanup_manager.get_disk_usage()

        status = "healthy"
        issues = []

        # Memory thresholds (skipped entirely when psutil is unavailable).
        if "error" not in memory:
            mem_pct = memory.get("system_percent", 0)
            if mem_pct > 90:
                status = "degraded"
                issues.append("High memory usage")
            elif mem_pct > 80:
                issues.append("Memory usage elevated")

        # Disk thresholds.
        disk_pct = disk.get("usage_percent", 0)
        if disk_pct > 90:
            status = "degraded"
            issues.append("High disk usage")
        elif disk_pct > 80:
            issues.append("Disk usage elevated")

        uptime = datetime.now() - self.start_time
        total = self._translation_count
        errors = self._error_count
        rate = 100 if total == 0 else (total - errors) / total * 100

        return {
            "status": status,
            "issues": issues,
            "uptime_seconds": int(uptime.total_seconds()),
            "uptime_human": str(uptime).split('.')[0],
            "translations": {
                "total": total,
                "errors": errors,
                "success_rate": round(rate, 1)
            },
            "memory": memory,
            "disk": disk,
            "cleanup_service": self.cleanup_manager.get_stats(),
            "timestamp": datetime.now().isoformat()
        }
|
||||
|
||||
|
||||
# Factory helper for wiring the cleanup manager from application config.
def create_cleanup_manager(config) -> FileCleanupManager:
    """Build a :class:`FileCleanupManager` from an application config object.

    The three directory attributes are required; the tuning knobs fall back
    to the class defaults when the config does not define them.
    """
    tuning = {
        "max_file_age_hours": getattr(config, 'MAX_FILE_AGE_HOURS', 1),
        "cleanup_interval_minutes": getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10),
        "max_total_size_gb": getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0),
    }
    return FileCleanupManager(
        upload_dir=config.UPLOAD_DIR,
        output_dir=config.OUTPUT_DIR,
        temp_dir=config.TEMP_DIR,
        **tuning,
    )
|
||||
328
middleware/rate_limiting.py
Normal file
328
middleware/rate_limiting.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Rate Limiting Middleware for SaaS robustness
|
||||
Protects against abuse and ensures fair usage
|
||||
"""
|
||||
import time
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RateLimitConfig:
    """Configuration for rate limiting.

    All limits apply per client; a "client" is the IP address resolved by
    the rate-limit manager. Whitelisted IPs bypass every limit.
    """
    # General request limits, one per sliding window
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000

    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # File size limits (MB)
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Burst protection
    burst_limit: int = 10  # Max requests in 1 second

    # Whitelist IPs (no rate limiting); default_factory avoids a shared mutable default
    whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])
|
||||
|
||||
|
||||
class TokenBucket:
    """Token-bucket rate limiter: refills continuously, spends on demand."""

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens added per second
        self.tokens = capacity
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Spend *tokens* if available; return whether the spend succeeded."""
        async with self._lock:
            self._refill()
            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True

    def _refill(self):
        """Top the bucket up according to elapsed wall-clock time, capped at capacity."""
        now = time.time()
        earned = (now - self.last_refill) * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + earned)
        self.last_refill = now
|
||||
|
||||
|
||||
class SlidingWindowCounter:
    """Counts requests inside a rolling time window of fixed length."""

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []
        self._lock = asyncio.Lock()

    async def is_allowed(self) -> bool:
        """Record and allow the request unless the window is already full."""
        async with self._lock:
            now = time.time()
            cutoff = now - self.window_seconds
            # Drop timestamps that have aged out of the window.
            self.requests = [ts for ts in self.requests if ts > cutoff]

            if len(self.requests) >= self.max_requests:
                return False
            self.requests.append(now)
            return True

    @property
    def current_count(self) -> int:
        """Number of recorded requests still inside the window (read-only)."""
        cutoff = time.time() - self.window_seconds
        return sum(1 for ts in self.requests if ts > cutoff)
|
||||
|
||||
|
||||
class ClientRateLimiter:
    """Per-client rate limiter with multiple windows.

    Combines a burst token bucket, minute/hour/day sliding windows for
    general requests, minute/hour windows for translations, a concurrency
    cap, and an hourly data-volume cap.
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # list of (timestamp, size_mb)
        # Guards concurrent_translations and total_size_hour. asyncio.Lock is
        # NOT reentrant: it must never be acquired while already held.
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if a generic request is allowed; return (allowed, reason)."""
        # Burst protection first: cheapest check, catches floods early.
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation request is allowed; return (allowed, reason).

        On success the request is recorded against the per-minute/per-hour
        translation windows and the hourly data budget.

        The internal lock is taken in two short, strictly sequential critical
        sections — never nested, and never held across the window checks
        (which use their own locks). This avoids re-acquiring the
        non-reentrant ``asyncio.Lock``, which would deadlock.
        """
        # Critical section 1: concurrency cap.
        async with self._lock:
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

        # Window checks (each counter locks itself).
        if not await self.translation_minute.is_allowed():
            return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

        if not await self.translation_hour.is_allowed():
            return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

        # Critical section 2: hourly data budget.
        async with self._lock:
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)

            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."

            self.total_size_hour.append((now, file_size_mb))

        return True, ""

    async def start_translation(self):
        """Mark the start of a translation (increments the concurrency count)."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark the end of a translation (count never drops below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Snapshot of current per-window usage for this client."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
|
||||
|
||||
|
||||
class RateLimitManager:
    """Manages rate limiters for all clients.

    Keeps one ClientRateLimiter per client id (resolved from proxy headers
    or the socket address) in a defaultdict, plus global request counters.

    NOTE(review): ``_cleanup_interval`` / ``_last_cleanup`` are initialized
    but no eviction of idle clients is performed anywhere in this class, so
    ``self.clients`` grows without bound — confirm whether cleanup was
    intended to be implemented.
    """

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        # defaultdict: first access for a new client id lazily creates its limiter.
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Cleanup old clients every hour
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Extract client identifier (an IP string) from the request.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled when
        the app is not behind a trusted proxy — a caller could spoof another
        client's id. Confirm the deployment always sits behind a proxy that
        overwrites these headers.
        """
        # Try to get real IP from headers (for proxied requests)
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # First entry is the originating client; later entries are proxies.
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is whitelisted (bypasses all limits)."""
        return client_id in self.config.whitelist_ips

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed; return (allowed, reason, client_id).

        The global request counter includes whitelisted requests.
        """
        client_id = self.get_client_id(request)
        self._total_requests += 1

        if self.is_whitelisted(client_id):
            return True, "", client_id

        client = self.clients[client_id]
        allowed, reason = await client.check_request()

        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation is allowed; counts the attempt either way."""
        client_id = self.get_client_id(request)
        self._total_translations += 1

        if self.is_whitelisted(client_id):
            return True, ""

        client = self.clients[client_id]
        return await client.check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if a translation is allowed for a specific client id.

        Boolean-only variant of :meth:`check_translation` for callers that
        already resolved the client id; does not touch the global counters.
        """
        if self.is_whitelisted(client_id):
            return True

        client = self.clients[client_id]
        allowed, _ = await client.check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for the requesting client.

        NOTE(review): accessing ``self.clients[client_id]`` on a defaultdict
        creates a limiter for clients that never made a limited request.
        """
        client_id = self.get_client_id(request)
        client = self.clients[client_id]
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client without creating a limiter."""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        client = self.clients[client_id]
        stats = client.get_stats()

        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics and the effective config."""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware for rate limiting.

    Rejects over-limit requests with HTTP 429 before they reach any route
    handler; health/docs endpoints and static assets are exempt.
    """

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        # Skip rate limiting for health checks and static files
        if request.url.path in ["/health", "/", "/docs", "/openapi.json", "/redoc"]:
            return await call_next(request)

        if request.url.path.startswith("/static"):
            return await call_next(request)

        # Check rate limit
        allowed, reason, client_id = await self.manager.check_request(request)

        if not allowed:
            logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    # Fixed 60s hint — not computed from the actual window remaining.
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )

        # Add client info to request state for use in endpoints
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]

        return await call_next(request)
|
||||
|
||||
|
||||
# Global rate limit manager: module-level singleton shared by the middleware
# and any endpoint that needs per-client stats.
rate_limit_manager = RateLimitManager()
|
||||
142
middleware/security.py
Normal file
142
middleware/security.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Security Headers Middleware for SaaS robustness
|
||||
Adds security headers to all responses
|
||||
"""
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach defensive HTTP headers to every response.

    Covers clickjacking, MIME sniffing, the legacy XSS filter, referrer
    leakage, feature permissions, CSP (skipped for the API docs pages,
    which need their own inline assets), and — when enabled via config —
    HSTS.
    """

    # Headers applied unconditionally to every response, in this order.
    _BASE_HEADERS = {
        "X-Frame-Options": "DENY",                 # no framing -> no clickjacking
        "X-Content-Type-Options": "nosniff",       # disable MIME sniffing
        "X-XSS-Protection": "1; mode=block",       # legacy browser XSS filter
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    }

    def __init__(self, app, config: dict = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)

        for header_name, header_value in self._BASE_HEADERS.items():
            response.headers[header_name] = header_value

        # Content Security Policy (adjust for your frontend). Swagger / ReDoc
        # pages are exempt: they rely on their own scripts and styles.
        path = request.url.path
        if not (path.startswith("/docs") or path.startswith("/redoc")):
            response.headers["Content-Security-Policy"] = (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data: blob:; "
                "font-src 'self' data:; "
                "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
                "worker-src 'self' blob:; "
            )

        # HSTS only makes sense behind HTTPS, so it is opt-in via config.
        if self.config.get("enable_hsts", False):
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests for monitoring and debugging.

    Each request gets an 8-character id that is logged on start, completion,
    and error, stored on ``request.state.request_id``, and echoed back in the
    ``X-Request-ID`` response header.
    """

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        # NOTE(review): log_body is stored but never used in this class —
        # confirm whether body logging was meant to be implemented.
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next) -> Response:
        import time
        import uuid

        # Generate request ID (first 8 hex chars of a UUID4 — short enough
        # for log lines, unique enough for correlation)
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        # Get client info
        client_ip = self._get_client_ip(request)

        # Log request start
        start_time = time.time()
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"from {client_ip} - Started"
        )

        try:
            response = await call_next(request)

            # Log request completion
            duration = time.time() - start_time
            logger.info(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- {response.status_code} in {duration:.3f}s"
            )

            # Add request ID to response headers
            response.headers["X-Request-ID"] = request_id

            return response

        except Exception as e:
            # Log the failure with timing, then let the exception propagate so
            # downstream error handling (e.g. ErrorHandlingMiddleware) runs.
            duration = time.time() - start_time
            logger.error(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- ERROR in {duration:.3f}s: {str(e)}"
            )
            raise

    def _get_client_ip(self, request: Request) -> str:
        """Get real client IP from headers or connection.

        NOTE(review): these headers are client-controlled unless a trusted
        proxy overwrites them — verify the deployment topology.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # First entry is the originating client; later entries are proxies.
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert any unhandled exception into a sanitized HTTP 500 JSON response."""

    async def dispatch(self, request: Request, call_next) -> Response:
        # Imported lazily, mirroring the file's existing layout.
        from starlette.responses import JSONResponse

        try:
            return await call_next(request)

        except Exception as exc:
            rid = getattr(request.state, 'request_id', 'unknown')
            logger.exception(f"[{rid}] Unhandled exception: {str(exc)}")

            # Generic payload: internal error details are never exposed to clients.
            payload = {
                "error": "internal_server_error",
                "message": "An unexpected error occurred. Please try again later.",
                "request_id": rid,
            }
            return JSONResponse(status_code=500, content=payload)
|
||||
440
middleware/validation.py
Normal file
440
middleware/validation.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Input Validation Module for SaaS robustness
|
||||
Validates all user inputs before processing
|
||||
"""
|
||||
import re
|
||||
import magic
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Set
|
||||
from fastapi import UploadFile, HTTPException
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationError(Exception):
    """Validation failure carrying a user-friendly message.

    Instance attributes: ``message`` (human-readable text), ``code``
    (machine-readable identifier) and ``details`` (extra context dict).
    """

    def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}
|
||||
|
||||
|
||||
class ValidationResult:
    """Result of a validation check.

    Attributes:
        is_valid: True when the input passed all checks.
        errors: Blocking problems that make the input unusable.
        warnings: Non-blocking notices surfaced to the caller.
        data: Metadata gathered during validation (sizes, names, ...).
    """

    def __init__(
        self,
        is_valid: bool = True,
        # Explicit Optional[...] instead of the implicit-Optional form
        # (`errors: List[str] = None`), per PEP 484.
        errors: Optional[List[str]] = None,
        warnings: Optional[List[str]] = None,
        data: Optional[dict] = None,
    ):
        self.is_valid = is_valid
        self.errors = errors or []
        self.warnings = warnings or []
        self.data = data or {}
|
||||
|
||||
|
||||
class FileValidator:
    """Validate uploaded Office documents for security and compatibility.

    Checks performed: filename sanitization (path traversal, control
    characters), extension whitelist, size limits, magic-byte signature,
    and MIME-type detection via libmagic.
    """

    # Allowed MIME types mapped to extensions
    ALLOWED_MIME_TYPES = {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    }

    # Magic bytes for Office Open XML files (ZIP format)
    OFFICE_MAGIC_BYTES = b"PK\x03\x04"

    def __init__(
        self,
        max_size_mb: int = 50,
        # Explicit Optional[...] per PEP 484 (was implicit-Optional).
        allowed_extensions: Optional[Set[str]] = None,
        scan_content: bool = True
    ):
        """
        Args:
            max_size_mb: Maximum accepted file size in megabytes.
            allowed_extensions: Permitted extensions (with leading dot);
                defaults to the Office Open XML trio.
            scan_content: When True, verify the file's magic-byte signature.
        """
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.max_size_mb = max_size_mb
        self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
        self.scan_content = scan_content

    async def validate_async(self, file: UploadFile) -> ValidationResult:
        """Validate an uploaded file, collecting problems instead of raising.

        Returns:
            ValidationResult whose ``data`` dict accumulates safe_filename,
            extension, size_bytes, size_mb, mime_type and original_filename.
        """
        errors = []
        warnings = []
        data = {}

        try:
            # A filename must exist before anything else can be checked.
            if not file.filename:
                errors.append("Filename is required")
                return ValidationResult(is_valid=False, errors=errors)

            # Sanitize filename
            try:
                safe_filename = self._sanitize_filename(file.filename)
                data["safe_filename"] = safe_filename
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Validate extension
            try:
                extension = self._validate_extension(safe_filename)
                data["extension"] = extension
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Read the whole upload once for size/signature/MIME checks.
            content = await file.read()
            await file.seek(0)  # Reset so downstream consumers can re-read

            # Validate file size
            file_size = len(content)
            data["size_bytes"] = file_size
            data["size_mb"] = round(file_size / (1024*1024), 2)

            if file_size > self.max_size_bytes:
                errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            if file_size == 0:
                errors.append("File is empty")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            # Soft warning once the file passes 80% of the limit.
            if file_size > self.max_size_bytes * 0.8:
                warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")

            # Validate magic bytes
            if self.scan_content:
                try:
                    self._validate_magic_bytes(content, extension)
                except ValidationError as e:
                    errors.append(str(e.message))
                    return ValidationResult(is_valid=False, errors=errors, data=data)

            # MIME mismatches are warnings only here: libmagic commonly
            # reports Office Open XML files as plain ZIP archives.
            try:
                mime_type = self._detect_mime_type(content)
                data["mime_type"] = mime_type
                self._validate_mime_type(mime_type, extension)
            except ValidationError as e:
                warnings.append(f"MIME type warning: {e.message}")
            except Exception:
                warnings.append("Could not verify MIME type")

            data["original_filename"] = file.filename

            return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)

        except Exception as e:
            # Fail closed: any unexpected error marks the file invalid.
            logger.error(f"Validation error: {str(e)}")
            errors.append(f"Validation failed: {str(e)}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)

    async def validate(self, file: UploadFile) -> dict:
        """Validate an uploaded file, raising on the first problem.

        Returns:
            Dict with original/safe filename, extension, sizes and MIME type.

        Raises:
            ValidationError: if any check fails.
        """
        # Validate filename
        if not file.filename:
            raise ValidationError(
                "Filename is required",
                code="missing_filename"
            )

        # Sanitize filename
        safe_filename = self._sanitize_filename(file.filename)

        # Validate extension
        extension = self._validate_extension(safe_filename)

        # Read the whole upload once for size/signature/MIME checks.
        content = await file.read()
        await file.seek(0)  # Reset for later processing

        # Validate file size
        file_size = len(content)
        if file_size > self.max_size_bytes:
            raise ValidationError(
                f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
                code="file_too_large",
                details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
            )

        if file_size == 0:
            raise ValidationError(
                "File is empty",
                code="empty_file"
            )

        # Validate magic bytes (file signature)
        if self.scan_content:
            self._validate_magic_bytes(content, extension)

        # Unlike validate_async, a MIME mismatch here is a hard failure.
        mime_type = self._detect_mime_type(content)
        self._validate_mime_type(mime_type, extension)

        return {
            "original_filename": file.filename,
            "safe_filename": safe_filename,
            "extension": extension,
            "size_bytes": file_size,
            "size_mb": round(file_size / (1024*1024), 2),
            "mime_type": mime_type
        }

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks."""
        # Remove path components (defeats ../../ traversal)
        filename = Path(filename).name

        # Remove null bytes and control characters
        filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)

        # Replace potentially dangerous characters
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        # Cap total length at 255 chars (common filesystem limit) while
        # preserving the extension. The previous `name[:250] + '.' + ext`
        # could still exceed 255 when the extension itself was long.
        if len(filename) > 255:
            name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
            suffix = '.' + ext if ext else ''
            filename = name[:max(1, 255 - len(suffix))] + suffix

        # Ensure not empty after sanitization
        if not filename or filename.strip() == '':
            raise ValidationError(
                "Invalid filename",
                code="invalid_filename"
            )

        return filename

    def _validate_extension(self, filename: str) -> str:
        """Return the lower-cased extension, or raise if missing/unsupported."""
        if '.' not in filename:
            raise ValidationError(
                f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
                code="missing_extension",
                details={"allowed_extensions": list(self.allowed_extensions)}
            )

        extension = '.' + filename.rsplit('.', 1)[1].lower()

        if extension not in self.allowed_extensions:
            raise ValidationError(
                f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
                code="unsupported_file_type",
                details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
            )

        return extension

    def _validate_magic_bytes(self, content: bytes, extension: str):
        """Check the ZIP signature shared by all Office Open XML formats."""
        if not content.startswith(self.OFFICE_MAGIC_BYTES):
            raise ValidationError(
                "File content does not match expected format. The file may be corrupted or not a valid Office document.",
                code="invalid_file_content"
            )

    def _detect_mime_type(self, content: bytes) -> str:
        """Detect MIME type from content, with a signature-based fallback."""
        try:
            mime = magic.Magic(mime=True)
            return mime.from_buffer(content)
        except Exception:
            # libmagic may be unavailable; fall back to the ZIP signature.
            if content.startswith(self.OFFICE_MAGIC_BYTES):
                return "application/zip"
            return "application/octet-stream"

    def _validate_mime_type(self, mime_type: str, extension: str):
        """Ensure the detected MIME type is plausible for an Office document."""
        # Office Open XML files are frequently detected as generic ZIP, and
        # the fallback path may report octet-stream, so allow both.
        allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]

        if mime_type not in allowed_mimes:
            raise ValidationError(
                f"Invalid file type detected. Expected Office document, got: {mime_type}",
                code="invalid_mime_type",
                details={"detected_mime": mime_type}
            )
|
||||
|
||||
|
||||
class LanguageValidator:
    """Validates and normalizes language codes."""

    # ISO 639-1 codes (plus a few region-tagged / 639-2 entries) accepted.
    SUPPORTED_LANGUAGES = {
        # ISO 639-1 codes
        "af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
        "bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
        "da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
        "de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
        "is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
        "rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
        "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
        "ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
        "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
        "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
        "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
        "zu", "auto"
    }

    # Human-readable names for the most common codes.
    LANGUAGE_NAMES = {
        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
        "it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
        "zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
        "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
        "nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
        "da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
        "el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
        "uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
    }

    # Aliases and canonical casing applied after lower-casing the input.
    # The zh-cn/zh-tw entries fix a real bug: SUPPORTED_LANGUAGES stores
    # mixed-case tags ("zh-CN", "zh-TW"), so after lower() those valid
    # inputs failed the membership test and were rejected.
    _ALIASES = {
        "chinese": "zh-CN",
        "cn": "zh-CN",
        "chinese-traditional": "zh-TW",
        "tw": "zh-TW",
        "zh-cn": "zh-CN",
        "zh-tw": "zh-TW",
    }

    @classmethod
    def validate(cls, language_code: str, field_name: str = "language") -> str:
        """Validate and normalize a language code.

        Args:
            language_code: Raw user-supplied code (any case, may have spaces).
            field_name: Name used in the "required" error message.

        Returns:
            The normalized code as stored in SUPPORTED_LANGUAGES.

        Raises:
            ValidationError: if the code is missing or unsupported.
        """
        if not language_code:
            raise ValidationError(
                f"{field_name} is required",
                code="missing_language"
            )

        # Normalize case/whitespace, then restore canonical forms.
        normalized = language_code.strip().lower()
        normalized = cls._ALIASES.get(normalized, normalized)

        if normalized not in cls.SUPPORTED_LANGUAGES:
            raise ValidationError(
                f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
                code="unsupported_language",
                details={"language": language_code}
            )

        return normalized

    @classmethod
    def get_language_name(cls, code: str) -> str:
        """Return the human-readable name for a code, or the code upper-cased."""
        return cls.LANGUAGE_NAMES.get(code, code.upper())
|
||||
|
||||
|
||||
class ProviderValidator:
    """Validates translation provider configuration."""

    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm"}

    # Providers that cannot run without an API key:
    # provider -> (kwarg name, error message, error code)
    _REQUIRED_KEYS = {
        "deepl": (
            "deepl_api_key",
            "DeepL API key is required when using DeepL provider",
            "missing_deepl_key",
        ),
        "openai": (
            "openai_api_key",
            "OpenAI API key is required when using OpenAI provider",
            "missing_openai_key",
        ),
    }

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate a provider name plus any provider-specific settings."""
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )

        normalized = provider.strip().lower()

        if normalized not in cls.SUPPORTED_PROVIDERS:
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(cls.SUPPORTED_PROVIDERS)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": list(cls.SUPPORTED_PROVIDERS)}
            )

        # Provider-specific checks, table-driven for the key-holding providers.
        requirement = cls._REQUIRED_KEYS.get(normalized)
        if requirement is not None:
            kwarg_name, message, error_code = requirement
            if not kwargs.get(kwarg_name):
                raise ValidationError(message, code=error_code)
        elif normalized == "ollama":
            # Ollama needs no API key; a missing model just falls back to a default.
            if not kwargs.get("ollama_model", ""):
                logger.warning("No Ollama model specified, will use default")

        return {"provider": normalized, "validated": True}
|
||||
|
||||
|
||||
class InputSanitizer:
    """Sanitizes user inputs to guard against injection and oversized data."""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Strip null bytes, cap the length, and trim surrounding whitespace."""
        if not text:
            return ""
        # Null bytes are never legitimate in user-supplied text.
        cleaned = text.replace('\x00', '')[:max_length]
        return cleaned.strip()

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Reduce a language code to alphanumerics/hyphens; default to "auto"."""
        if not code:
            return "auto"
        # Keep only characters that can appear in a language tag, max 10 chars.
        cleaned = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())[:10]
        return cleaned.lower() if cleaned else "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Trim a URL, require an http(s) scheme, and drop trailing slashes."""
        if not url:
            return ""

        url = url.strip()

        if re.match(r'^https?://', url, re.IGNORECASE) is None:
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )

        return url.rstrip('/')

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Trim an API key; the value is deliberately never logged."""
        return key.strip() if key else ""
|
||||
|
||||
|
||||
# Default validators
# Module-level shared FileValidator built with its default settings
# (50MB size cap, .xlsx/.docx/.pptx extensions, content scanning on).
file_validator = FileValidator()
|
||||
Reference in New Issue
Block a user