feat: Add SaaS robustness middleware - Rate limiting with token bucket and sliding window algorithms - Input validation (file, language, provider) - Security headers middleware (CSP, XSS protection, etc.) - Automatic file cleanup with TTL tracking - Memory and disk monitoring - Enhanced health check and metrics endpoints - Request logging with unique IDs
This commit is contained in:
parent
8c7716bf4d
commit
500502440c
67
.env.example
67
.env.example
@ -1,5 +1,11 @@
|
||||
# Translation Service Configuration
|
||||
TRANSLATION_SERVICE=google # Options: google, deepl, libre, ollama
|
||||
# Document Translation API - Environment Configuration
|
||||
# Copy this file to .env and configure your settings
|
||||
|
||||
# ============== Translation Services ==============
|
||||
# Default provider: google, ollama, deepl, libre, openai
|
||||
TRANSLATION_SERVICE=google
|
||||
|
||||
# DeepL API Key (required for DeepL provider)
|
||||
DEEPL_API_KEY=your_deepl_api_key_here
|
||||
|
||||
# Ollama Configuration (for LLM-based translation)
|
||||
@ -7,7 +13,58 @@ OLLAMA_BASE_URL=http://localhost:11434
|
||||
OLLAMA_MODEL=llama3
|
||||
OLLAMA_VISION_MODEL=llava
|
||||
|
||||
# API Configuration
|
||||
# ============== File Limits ==============
|
||||
# Maximum file size in MB
|
||||
MAX_FILE_SIZE_MB=50
|
||||
UPLOAD_DIR=./uploads
|
||||
OUTPUT_DIR=./outputs
|
||||
|
||||
# ============== Rate Limiting (SaaS) ==============
|
||||
# Enable/disable rate limiting
|
||||
RATE_LIMIT_ENABLED=true
|
||||
|
||||
# Request limits
|
||||
RATE_LIMIT_PER_MINUTE=30
|
||||
RATE_LIMIT_PER_HOUR=200
|
||||
|
||||
# Translation-specific limits
|
||||
TRANSLATIONS_PER_MINUTE=10
|
||||
TRANSLATIONS_PER_HOUR=50
|
||||
MAX_CONCURRENT_TRANSLATIONS=5
|
||||
|
||||
# ============== Cleanup Service ==============
|
||||
# Enable automatic file cleanup
|
||||
CLEANUP_ENABLED=true
|
||||
|
||||
# Cleanup interval in minutes
|
||||
CLEANUP_INTERVAL_MINUTES=15
|
||||
|
||||
# File time-to-live in minutes
|
||||
FILE_TTL_MINUTES=60
|
||||
INPUT_FILE_TTL_MINUTES=30
|
||||
OUTPUT_FILE_TTL_MINUTES=120
|
||||
|
||||
# Disk space warning thresholds (GB)
|
||||
DISK_WARNING_THRESHOLD_GB=5.0
|
||||
DISK_CRITICAL_THRESHOLD_GB=1.0
|
||||
|
||||
# ============== Security ==============
|
||||
# Enable HSTS (only for HTTPS deployments)
|
||||
ENABLE_HSTS=false
|
||||
|
||||
# CORS allowed origins (comma-separated)
|
||||
CORS_ORIGINS=*
|
||||
|
||||
# Maximum request size in MB
|
||||
MAX_REQUEST_SIZE_MB=100
|
||||
|
||||
# Request timeout in seconds
|
||||
REQUEST_TIMEOUT_SECONDS=300
|
||||
|
||||
# ============== Monitoring ==============
|
||||
# Log level: DEBUG, INFO, WARNING, ERROR
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Enable request logging
|
||||
ENABLE_REQUEST_LOGGING=true
|
||||
|
||||
# Memory usage threshold (percentage)
|
||||
MAX_MEMORY_PERCENT=80
|
||||
|
||||
48
config.py
48
config.py
@ -1,5 +1,6 @@
|
||||
"""
|
||||
Configuration module for the Document Translation API
|
||||
SaaS-ready with comprehensive settings for production deployment
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
@ -8,7 +9,7 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
class Config:
|
||||
# Translation Service
|
||||
# ============== Translation Service ==============
|
||||
TRANSLATION_SERVICE = os.getenv("TRANSLATION_SERVICE", "google")
|
||||
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "")
|
||||
|
||||
@ -17,20 +18,51 @@ class Config:
|
||||
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3")
|
||||
OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")
|
||||
|
||||
# File Upload Configuration
|
||||
# ============== File Upload Configuration ==============
|
||||
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50"))
|
||||
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
|
||||
|
||||
# Directories
|
||||
BASE_DIR = Path(__file__).parent.parent
|
||||
BASE_DIR = Path(__file__).parent
|
||||
UPLOAD_DIR = BASE_DIR / "uploads"
|
||||
OUTPUT_DIR = BASE_DIR / "outputs"
|
||||
TEMP_DIR = BASE_DIR / "temp"
|
||||
LOGS_DIR = BASE_DIR / "logs"
|
||||
|
||||
# Supported file types
|
||||
SUPPORTED_EXTENSIONS = {".xlsx", ".docx", ".pptx"}
|
||||
|
||||
# API Configuration
|
||||
# ============== Rate Limiting (SaaS) ==============
|
||||
RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"
|
||||
RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "30"))
|
||||
RATE_LIMIT_PER_HOUR = int(os.getenv("RATE_LIMIT_PER_HOUR", "200"))
|
||||
TRANSLATIONS_PER_MINUTE = int(os.getenv("TRANSLATIONS_PER_MINUTE", "10"))
|
||||
TRANSLATIONS_PER_HOUR = int(os.getenv("TRANSLATIONS_PER_HOUR", "50"))
|
||||
MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5"))
|
||||
|
||||
# ============== Cleanup Service ==============
|
||||
CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "true").lower() == "true"
|
||||
CLEANUP_INTERVAL_MINUTES = int(os.getenv("CLEANUP_INTERVAL_MINUTES", "15"))
|
||||
FILE_TTL_MINUTES = int(os.getenv("FILE_TTL_MINUTES", "60"))
|
||||
INPUT_FILE_TTL_MINUTES = int(os.getenv("INPUT_FILE_TTL_MINUTES", "30"))
|
||||
OUTPUT_FILE_TTL_MINUTES = int(os.getenv("OUTPUT_FILE_TTL_MINUTES", "120"))
|
||||
|
||||
# Disk space thresholds
|
||||
DISK_WARNING_THRESHOLD_GB = float(os.getenv("DISK_WARNING_THRESHOLD_GB", "5.0"))
|
||||
DISK_CRITICAL_THRESHOLD_GB = float(os.getenv("DISK_CRITICAL_THRESHOLD_GB", "1.0"))
|
||||
|
||||
# ============== Security ==============
|
||||
ENABLE_HSTS = os.getenv("ENABLE_HSTS", "false").lower() == "true"
|
||||
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||
MAX_REQUEST_SIZE_MB = int(os.getenv("MAX_REQUEST_SIZE_MB", "100"))
|
||||
REQUEST_TIMEOUT_SECONDS = int(os.getenv("REQUEST_TIMEOUT_SECONDS", "300"))
|
||||
|
||||
# ============== Monitoring ==============
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
ENABLE_REQUEST_LOGGING = os.getenv("ENABLE_REQUEST_LOGGING", "true").lower() == "true"
|
||||
MAX_MEMORY_PERCENT = float(os.getenv("MAX_MEMORY_PERCENT", "80"))
|
||||
|
||||
# ============== API Configuration ==============
|
||||
API_TITLE = "Document Translation API"
|
||||
API_VERSION = "1.0.0"
|
||||
API_DESCRIPTION = """
|
||||
@ -40,6 +72,12 @@ class Config:
|
||||
- Excel (.xlsx) - Preserves cell formatting, formulas, merged cells, images
|
||||
- Word (.docx) - Preserves styles, tables, images, headers/footers
|
||||
- PowerPoint (.pptx) - Preserves layouts, animations, embedded media
|
||||
|
||||
SaaS Features:
|
||||
- Rate limiting per client IP
|
||||
- Automatic file cleanup
|
||||
- Health monitoring
|
||||
- Request logging
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -48,5 +86,7 @@ class Config:
|
||||
cls.UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
|
||||
cls.OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
|
||||
cls.TEMP_DIR.mkdir(exist_ok=True, parents=True)
|
||||
cls.LOGS_DIR.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
config = Config()
|
||||
|
||||
212
main.py
212
main.py
@ -1,24 +1,55 @@
|
||||
"""
|
||||
Document Translation API
|
||||
FastAPI application for translating complex documents while preserving formatting
|
||||
SaaS-ready with rate limiting, validation, and robust error handling
|
||||
"""
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, Depends
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from config import config
|
||||
from translators import excel_translator, word_translator, pptx_translator
|
||||
from utils import file_handler, handle_translation_error, DocumentProcessingError
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# Import SaaS middleware
|
||||
from middleware.rate_limiting import RateLimitMiddleware, RateLimitManager, RateLimitConfig
|
||||
from middleware.security import SecurityHeadersMiddleware, RequestLoggingMiddleware, ErrorHandlingMiddleware
|
||||
from middleware.cleanup import FileCleanupManager, MemoryMonitor, HealthChecker, create_cleanup_manager
|
||||
from middleware.validation import FileValidator, LanguageValidator, ProviderValidator, InputSanitizer, ValidationError
|
||||
|
||||
# Configure structured logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize SaaS components
|
||||
rate_limit_config = RateLimitConfig(
|
||||
requests_per_minute=int(os.getenv("RATE_LIMIT_PER_MINUTE", "30")),
|
||||
requests_per_hour=int(os.getenv("RATE_LIMIT_PER_HOUR", "200")),
|
||||
translations_per_minute=int(os.getenv("TRANSLATIONS_PER_MINUTE", "10")),
|
||||
translations_per_hour=int(os.getenv("TRANSLATIONS_PER_HOUR", "50")),
|
||||
max_concurrent_translations=int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5")),
|
||||
)
|
||||
rate_limit_manager = RateLimitManager(rate_limit_config)
|
||||
|
||||
cleanup_manager = create_cleanup_manager(config)
|
||||
memory_monitor = MemoryMonitor(max_memory_percent=float(os.getenv("MAX_MEMORY_PERCENT", "80")))
|
||||
health_checker = HealthChecker(cleanup_manager, memory_monitor)
|
||||
|
||||
file_validator = FileValidator(
|
||||
max_size_mb=config.MAX_FILE_SIZE_MB,
|
||||
allowed_extensions=config.SUPPORTED_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
def build_full_prompt(system_prompt: str, glossary: str) -> str:
|
||||
"""Combine system prompt and glossary into a single prompt for LLM translation."""
|
||||
@ -40,23 +71,47 @@ Always use the translations from this glossary when you encounter these terms.""
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
# Ensure necessary directories exist
|
||||
# Lifespan context manager for startup/shutdown
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Handle startup and shutdown events"""
|
||||
# Startup
|
||||
logger.info("Starting Document Translation API...")
|
||||
config.ensure_directories()
|
||||
await cleanup_manager.start()
|
||||
logger.info("API ready to accept requests")
|
||||
|
||||
# Create FastAPI app
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down...")
|
||||
await cleanup_manager.stop()
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
# Create FastAPI app with lifespan
|
||||
app = FastAPI(
|
||||
title=config.API_TITLE,
|
||||
version=config.API_VERSION,
|
||||
description=config.API_DESCRIPTION
|
||||
description=config.API_DESCRIPTION,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
# Add middleware (order matters - first added is outermost)
|
||||
app.add_middleware(ErrorHandlingMiddleware)
|
||||
app.add_middleware(RequestLoggingMiddleware, log_body=False)
|
||||
app.add_middleware(SecurityHeadersMiddleware, config={"enable_hsts": os.getenv("ENABLE_HSTS", "false").lower() == "true"})
|
||||
app.add_middleware(RateLimitMiddleware, rate_limit_manager=rate_limit_manager)
|
||||
|
||||
# CORS - configure for production
|
||||
allowed_origins = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Configure appropriately for production
|
||||
allow_origins=allowed_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["X-Request-ID", "X-Original-Filename", "X-File-Size-MB", "X-Target-Language"]
|
||||
)
|
||||
|
||||
# Mount static files
|
||||
@ -65,6 +120,20 @@ if static_dir.exists():
|
||||
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
||||
|
||||
|
||||
# Custom exception handler for ValidationError
|
||||
@app.exception_handler(ValidationError)
|
||||
async def validation_error_handler(request: Request, exc: ValidationError):
|
||||
"""Handle validation errors with user-friendly messages"""
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": exc.code,
|
||||
"message": exc.message,
|
||||
"details": exc.details
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information"""
|
||||
@ -83,11 +152,24 @@ async def root():
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"translation_service": config.TRANSLATION_SERVICE
|
||||
"""Health check endpoint with detailed system status"""
|
||||
health_status = await health_checker.check_health()
|
||||
status_code = 200 if health_status.get("status") == "healthy" else 503
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"status": health_status.get("status", "unknown"),
|
||||
"translation_service": config.TRANSLATION_SERVICE,
|
||||
"memory": health_status.get("memory", {}),
|
||||
"disk": health_status.get("disk", {}),
|
||||
"cleanup_service": health_status.get("cleanup_service", {}),
|
||||
"rate_limits": {
|
||||
"requests_per_minute": rate_limit_config.requests_per_minute,
|
||||
"translations_per_minute": rate_limit_config.translations_per_minute,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/languages")
|
||||
@ -128,6 +210,7 @@ async def get_supported_languages():
|
||||
|
||||
@app.post("/translate")
|
||||
async def translate_document(
|
||||
request: Request,
|
||||
file: UploadFile = File(..., description="Document file to translate (.xlsx, .docx, or .pptx)"),
|
||||
target_language: str = Form(..., description="Target language code (e.g., 'es', 'fr', 'de')"),
|
||||
source_language: str = Form(default="auto", description="Source language code (default: auto-detect)"),
|
||||
@ -160,11 +243,38 @@ async def translate_document(
|
||||
"""
|
||||
input_path = None
|
||||
output_path = None
|
||||
request_id = getattr(request.state, 'request_id', 'unknown')
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
sanitized_language = InputSanitizer.sanitize_language_code(target_language)
|
||||
LanguageValidator.validate(sanitized_language)
|
||||
ProviderValidator.validate(provider)
|
||||
|
||||
# Validate file before processing
|
||||
validation_result = await file_validator.validate_async(file)
|
||||
if not validation_result.is_valid:
|
||||
raise ValidationError(
|
||||
message=f"File validation failed: {'; '.join(validation_result.errors)}",
|
||||
code="INVALID_FILE",
|
||||
details={"errors": validation_result.errors, "warnings": validation_result.warnings}
|
||||
)
|
||||
|
||||
# Log any warnings
|
||||
if validation_result.warnings:
|
||||
logger.warning(f"[{request_id}] File validation warnings: {validation_result.warnings}")
|
||||
|
||||
# Check rate limit for translations
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
if not await rate_limit_manager.check_translation_limit(client_ip):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Translation rate limit exceeded. Please try again later."
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
file_extension = file_handler.validate_file_extension(file.filename)
|
||||
logger.info(f"Processing {file_extension} file: {file.filename}")
|
||||
logger.info(f"[{request_id}] Processing {file_extension} file: {file.filename}")
|
||||
|
||||
# Validate file size
|
||||
file_handler.validate_file_size(file)
|
||||
@ -178,7 +288,11 @@ async def translate_document(
|
||||
output_path = config.OUTPUT_DIR / output_filename
|
||||
|
||||
await file_handler.save_upload_file(file, input_path)
|
||||
logger.info(f"Saved input file to: {input_path}")
|
||||
logger.info(f"[{request_id}] Saved input file to: {input_path}")
|
||||
|
||||
# Track file for cleanup
|
||||
await cleanup_manager.track_file(input_path, ttl_minutes=30)
|
||||
await cleanup_manager.track_file(output_path, ttl_minutes=60)
|
||||
|
||||
# Configure translation provider
|
||||
from services.translation_service import GoogleTranslationProvider, DeepLTranslationProvider, LibreTranslationProvider, OllamaTranslationProvider, OpenAITranslationProvider, translation_service
|
||||
@ -657,6 +771,74 @@ async def reconstruct_document(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to reconstruct document: {str(e)}")
|
||||
|
||||
|
||||
# ============== SaaS Management Endpoints ==============
|
||||
|
||||
@app.get("/metrics")
|
||||
async def get_metrics():
|
||||
"""Get system metrics and statistics for monitoring"""
|
||||
health_status = await health_checker.check_health()
|
||||
cleanup_stats = cleanup_manager.get_stats()
|
||||
rate_limit_stats = rate_limit_manager.get_stats()
|
||||
|
||||
return {
|
||||
"system": {
|
||||
"memory": health_status.get("memory", {}),
|
||||
"disk": health_status.get("disk", {}),
|
||||
"status": health_status.get("status", "unknown")
|
||||
},
|
||||
"cleanup": cleanup_stats,
|
||||
"rate_limits": rate_limit_stats,
|
||||
"config": {
|
||||
"max_file_size_mb": config.MAX_FILE_SIZE_MB,
|
||||
"supported_extensions": list(config.SUPPORTED_EXTENSIONS),
|
||||
"translation_service": config.TRANSLATION_SERVICE
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/rate-limit/status")
|
||||
async def get_rate_limit_status(request: Request):
|
||||
"""Get current rate limit status for the requesting client"""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
status = await rate_limit_manager.get_client_status(client_ip)
|
||||
|
||||
return {
|
||||
"client_ip": client_ip,
|
||||
"limits": {
|
||||
"requests_per_minute": rate_limit_config.requests_per_minute,
|
||||
"requests_per_hour": rate_limit_config.requests_per_hour,
|
||||
"translations_per_minute": rate_limit_config.translations_per_minute,
|
||||
"translations_per_hour": rate_limit_config.translations_per_hour
|
||||
},
|
||||
"current_usage": status
|
||||
}
|
||||
|
||||
|
||||
@app.post("/admin/cleanup/trigger")
|
||||
async def trigger_cleanup():
|
||||
"""Trigger manual cleanup of expired files"""
|
||||
try:
|
||||
cleaned = await cleanup_manager.cleanup_expired()
|
||||
return {
|
||||
"status": "success",
|
||||
"files_cleaned": cleaned,
|
||||
"message": f"Cleaned up {cleaned} expired files"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Manual cleanup failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/admin/files/tracked")
|
||||
async def get_tracked_files():
|
||||
"""Get list of currently tracked files"""
|
||||
tracked = cleanup_manager.get_tracked_files()
|
||||
return {
|
||||
"count": len(tracked),
|
||||
"files": tracked
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
|
||||
62
middleware/__init__.py
Normal file
62
middleware/__init__.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""
|
||||
Middleware package for SaaS robustness
|
||||
|
||||
This package provides:
|
||||
- Rate limiting: Protect against abuse and ensure fair usage
|
||||
- Validation: Validate all inputs before processing
|
||||
- Security: Security headers, request logging, error handling
|
||||
- Cleanup: Automatic file cleanup and resource management
|
||||
"""
|
||||
|
||||
from .rate_limiting import (
|
||||
RateLimitConfig,
|
||||
RateLimitManager,
|
||||
RateLimitMiddleware,
|
||||
ClientRateLimiter,
|
||||
)
|
||||
|
||||
from .validation import (
|
||||
ValidationError,
|
||||
ValidationResult,
|
||||
FileValidator,
|
||||
LanguageValidator,
|
||||
ProviderValidator,
|
||||
InputSanitizer,
|
||||
)
|
||||
|
||||
from .security import (
|
||||
SecurityHeadersMiddleware,
|
||||
RequestLoggingMiddleware,
|
||||
ErrorHandlingMiddleware,
|
||||
)
|
||||
|
||||
from .cleanup import (
|
||||
FileCleanupManager,
|
||||
MemoryMonitor,
|
||||
HealthChecker,
|
||||
create_cleanup_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Rate limiting
|
||||
"RateLimitConfig",
|
||||
"RateLimitManager",
|
||||
"RateLimitMiddleware",
|
||||
"ClientRateLimiter",
|
||||
# Validation
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"FileValidator",
|
||||
"LanguageValidator",
|
||||
"ProviderValidator",
|
||||
"InputSanitizer",
|
||||
# Security
|
||||
"SecurityHeadersMiddleware",
|
||||
"RequestLoggingMiddleware",
|
||||
"ErrorHandlingMiddleware",
|
||||
# Cleanup
|
||||
"FileCleanupManager",
|
||||
"MemoryMonitor",
|
||||
"HealthChecker",
|
||||
"create_cleanup_manager",
|
||||
]
|
||||
400
middleware/cleanup.py
Normal file
400
middleware/cleanup.py
Normal file
@ -0,0 +1,400 @@
|
||||
"""
|
||||
Cleanup and Resource Management for SaaS robustness
|
||||
Automatic cleanup of temporary files and resources
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCleanupManager:
|
||||
"""Manages automatic cleanup of temporary and output files"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upload_dir: Path,
|
||||
output_dir: Path,
|
||||
temp_dir: Path,
|
||||
max_file_age_hours: int = 1,
|
||||
cleanup_interval_minutes: int = 10,
|
||||
max_total_size_gb: float = 10.0
|
||||
):
|
||||
self.upload_dir = Path(upload_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
self.temp_dir = Path(temp_dir)
|
||||
self.max_file_age_seconds = max_file_age_hours * 3600
|
||||
self.cleanup_interval = cleanup_interval_minutes * 60
|
||||
self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)
|
||||
|
||||
self._running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._protected_files: Set[str] = set()
|
||||
self._tracked_files: dict = {} # filepath -> {created, ttl_minutes}
|
||||
self._lock = threading.Lock()
|
||||
self._stats = {
|
||||
"files_cleaned": 0,
|
||||
"bytes_freed": 0,
|
||||
"cleanup_runs": 0
|
||||
}
|
||||
|
||||
async def track_file(self, filepath: Path, ttl_minutes: int = 60):
|
||||
"""Track a file for automatic cleanup after TTL expires"""
|
||||
with self._lock:
|
||||
self._tracked_files[str(filepath)] = {
|
||||
"created": time.time(),
|
||||
"ttl_minutes": ttl_minutes,
|
||||
"expires_at": time.time() + (ttl_minutes * 60)
|
||||
}
|
||||
|
||||
def get_tracked_files(self) -> list:
|
||||
"""Get list of currently tracked files with their status"""
|
||||
now = time.time()
|
||||
result = []
|
||||
|
||||
with self._lock:
|
||||
for filepath, info in self._tracked_files.items():
|
||||
remaining = info["expires_at"] - now
|
||||
result.append({
|
||||
"path": filepath,
|
||||
"exists": Path(filepath).exists(),
|
||||
"expires_in_seconds": max(0, int(remaining)),
|
||||
"ttl_minutes": info["ttl_minutes"]
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""Cleanup expired tracked files"""
|
||||
now = time.time()
|
||||
cleaned = 0
|
||||
to_remove = []
|
||||
|
||||
with self._lock:
|
||||
for filepath, info in list(self._tracked_files.items()):
|
||||
if now > info["expires_at"]:
|
||||
to_remove.append(filepath)
|
||||
|
||||
for filepath in to_remove:
|
||||
try:
|
||||
path = Path(filepath)
|
||||
if path.exists() and not self.is_protected(path):
|
||||
size = path.stat().st_size
|
||||
path.unlink()
|
||||
cleaned += 1
|
||||
self._stats["files_cleaned"] += 1
|
||||
self._stats["bytes_freed"] += size
|
||||
logger.info(f"Cleaned expired file: {filepath}")
|
||||
|
||||
with self._lock:
|
||||
self._tracked_files.pop(filepath, None)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean expired file {filepath}: {e}")
|
||||
|
||||
return cleaned
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get cleanup statistics"""
|
||||
disk_usage = self.get_disk_usage()
|
||||
|
||||
with self._lock:
|
||||
tracked_count = len(self._tracked_files)
|
||||
|
||||
return {
|
||||
"files_cleaned_total": self._stats["files_cleaned"],
|
||||
"bytes_freed_total_mb": round(self._stats["bytes_freed"] / (1024 * 1024), 2),
|
||||
"cleanup_runs": self._stats["cleanup_runs"],
|
||||
"tracked_files": tracked_count,
|
||||
"disk_usage": disk_usage,
|
||||
"is_running": self._running
|
||||
}
|
||||
|
||||
def protect_file(self, filepath: Path):
|
||||
"""Mark a file as protected (being processed)"""
|
||||
with self._lock:
|
||||
self._protected_files.add(str(filepath))
|
||||
|
||||
def unprotect_file(self, filepath: Path):
|
||||
"""Remove protection from a file"""
|
||||
with self._lock:
|
||||
self._protected_files.discard(str(filepath))
|
||||
|
||||
def is_protected(self, filepath: Path) -> bool:
|
||||
"""Check if a file is protected"""
|
||||
with self._lock:
|
||||
return str(filepath) in self._protected_files
|
||||
|
||||
async def start(self):
|
||||
"""Start the cleanup background task"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("File cleanup manager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the cleanup background task"""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("File cleanup manager stopped")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background loop for periodic cleanup"""
|
||||
while self._running:
|
||||
try:
|
||||
await self.cleanup()
|
||||
await self.cleanup_expired()
|
||||
self._stats["cleanup_runs"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup error: {e}")
|
||||
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
|
||||
async def cleanup(self) -> dict:
|
||||
"""Perform cleanup of old files"""
|
||||
stats = {
|
||||
"files_deleted": 0,
|
||||
"bytes_freed": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
now = time.time()
|
||||
|
||||
# Cleanup each directory
|
||||
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||
if not directory.exists():
|
||||
continue
|
||||
|
||||
for filepath in directory.iterdir():
|
||||
if not filepath.is_file():
|
||||
continue
|
||||
|
||||
# Skip protected files
|
||||
if self.is_protected(filepath):
|
||||
continue
|
||||
|
||||
try:
|
||||
# Check file age
|
||||
file_age = now - filepath.stat().st_mtime
|
||||
|
||||
if file_age > self.max_file_age_seconds:
|
||||
file_size = filepath.stat().st_size
|
||||
filepath.unlink()
|
||||
stats["files_deleted"] += 1
|
||||
stats["bytes_freed"] += file_size
|
||||
logger.debug(f"Deleted old file: {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
stats["errors"].append(str(e))
|
||||
logger.warning(f"Failed to delete {filepath}: {e}")
|
||||
|
||||
# Force cleanup if total size exceeds limit
|
||||
await self._enforce_size_limit(stats)
|
||||
|
||||
if stats["files_deleted"] > 0:
|
||||
mb_freed = stats["bytes_freed"] / (1024 * 1024)
|
||||
logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")
|
||||
|
||||
return stats
|
||||
|
||||
async def _enforce_size_limit(self, stats: dict):
|
||||
"""Delete oldest files if total size exceeds limit"""
|
||||
files_with_mtime = []
|
||||
total_size = 0
|
||||
|
||||
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||
if not directory.exists():
|
||||
continue
|
||||
|
||||
for filepath in directory.iterdir():
|
||||
if not filepath.is_file() or self.is_protected(filepath):
|
||||
continue
|
||||
|
||||
try:
|
||||
stat = filepath.stat()
|
||||
files_with_mtime.append((filepath, stat.st_mtime, stat.st_size))
|
||||
total_size += stat.st_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If under limit, nothing to do
|
||||
if total_size <= self.max_total_size_bytes:
|
||||
return
|
||||
|
||||
# Sort by modification time (oldest first)
|
||||
files_with_mtime.sort(key=lambda x: x[1])
|
||||
|
||||
# Delete oldest files until under limit
|
||||
for filepath, _, size in files_with_mtime:
|
||||
if total_size <= self.max_total_size_bytes:
|
||||
break
|
||||
|
||||
try:
|
||||
filepath.unlink()
|
||||
total_size -= size
|
||||
stats["files_deleted"] += 1
|
||||
stats["bytes_freed"] += size
|
||||
logger.info(f"Deleted file to free space: {filepath}")
|
||||
except Exception as e:
|
||||
stats["errors"].append(str(e))
|
||||
|
||||
def get_disk_usage(self) -> dict:
|
||||
"""Get current disk usage statistics"""
|
||||
total_files = 0
|
||||
total_size = 0
|
||||
|
||||
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||
if not directory.exists():
|
||||
continue
|
||||
|
||||
for filepath in directory.iterdir():
|
||||
if filepath.is_file():
|
||||
total_files += 1
|
||||
try:
|
||||
total_size += filepath.stat().st_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"total_files": total_files,
|
||||
"total_size_mb": round(total_size / (1024 * 1024), 2),
|
||||
"max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
|
||||
"usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0,
|
||||
"directories": {
|
||||
"uploads": str(self.upload_dir),
|
||||
"outputs": str(self.output_dir),
|
||||
"temp": str(self.temp_dir)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
"""Monitors memory usage and triggers cleanup if needed"""
|
||||
|
||||
def __init__(self, max_memory_percent: float = 80.0):
|
||||
self.max_memory_percent = max_memory_percent
|
||||
self._high_memory_callbacks = []
|
||||
|
||||
def get_memory_usage(self) -> dict:
|
||||
"""Get current memory usage"""
|
||||
try:
|
||||
import psutil
|
||||
process = psutil.Process()
|
||||
memory_info = process.memory_info()
|
||||
system_memory = psutil.virtual_memory()
|
||||
|
||||
return {
|
||||
"process_rss_mb": round(memory_info.rss / (1024 * 1024), 2),
|
||||
"process_vms_mb": round(memory_info.vms / (1024 * 1024), 2),
|
||||
"system_total_gb": round(system_memory.total / (1024 * 1024 * 1024), 2),
|
||||
"system_available_gb": round(system_memory.available / (1024 * 1024 * 1024), 2),
|
||||
"system_percent": system_memory.percent
|
||||
}
|
||||
except ImportError:
|
||||
return {"error": "psutil not installed"}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def check_memory(self) -> bool:
|
||||
"""Check if memory usage is within limits"""
|
||||
usage = self.get_memory_usage()
|
||||
if "error" in usage:
|
||||
return True # Can't check, assume OK
|
||||
|
||||
return usage.get("system_percent", 0) < self.max_memory_percent
|
||||
|
||||
def on_high_memory(self, callback):
|
||||
"""Register callback for high memory situations"""
|
||||
self._high_memory_callbacks.append(callback)
|
||||
|
||||
|
||||
class HealthChecker:
    """Aggregates memory, disk, cleanup, and translation statistics into
    a single health report for the /health endpoint."""

    def __init__(self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor):
        self.cleanup_manager = cleanup_manager
        self.memory_monitor = memory_monitor
        self.start_time = datetime.now()
        self._translation_count = 0
        self._error_count = 0
        # Guards the counters against concurrent record_translation calls
        self._lock = threading.Lock()

    def record_translation(self, success: bool = True):
        """Record one translation attempt (and its failure, if any)."""
        with self._lock:
            self._translation_count += 1
            if not success:
                self._error_count += 1

    async def check_health(self) -> dict:
        """Async wrapper around get_health() for async call sites."""
        return self.get_health()

    def get_health(self) -> dict:
        """Build the full health report.

        Status is "healthy" unless memory or disk usage exceeds 90%
        ("degraded"); usage above 80% only adds an advisory issue.
        """
        memory_stats = self.memory_monitor.get_memory_usage()
        disk_stats = self.cleanup_manager.get_disk_usage()

        status = "healthy"
        issues = []

        # Memory thresholds are skipped entirely when psutil is unavailable
        if "error" not in memory_stats:
            mem_percent = memory_stats.get("system_percent", 0)
            if mem_percent > 90:
                status = "degraded"
                issues.append("High memory usage")
            elif mem_percent > 80:
                issues.append("Memory usage elevated")

        disk_percent = disk_stats.get("usage_percent", 0)
        if disk_percent > 90:
            status = "degraded"
            issues.append("High disk usage")
        elif disk_percent > 80:
            issues.append("Disk usage elevated")

        elapsed = datetime.now() - self.start_time
        # NOTE(review): counters are read without the lock; torn reads are
        # harmless for a monitoring endpoint but the numbers may lag.
        total = self._translation_count
        failed = self._error_count
        success_rate = round(
            ((total - failed) / total * 100) if total > 0 else 100, 1
        )

        return {
            "status": status,
            "issues": issues,
            "uptime_seconds": int(elapsed.total_seconds()),
            "uptime_human": str(elapsed).split('.')[0],
            "translations": {
                "total": total,
                "errors": failed,
                "success_rate": success_rate
            },
            "memory": memory_stats,
            "disk": disk_stats,
            "cleanup_service": self.cleanup_manager.get_stats(),
            "timestamp": datetime.now().isoformat()
        }
|
||||
|
||||
|
||||
# Create default instances
|
||||
def create_cleanup_manager(config) -> FileCleanupManager:
    """Build a FileCleanupManager from an application config object.

    The three directory settings are required attributes of *config*;
    the tuning knobs fall back to defaults when absent.
    """
    settings = {
        "upload_dir": config.UPLOAD_DIR,
        "output_dir": config.OUTPUT_DIR,
        "temp_dir": config.TEMP_DIR,
        # Optional knobs with conservative defaults
        "max_file_age_hours": getattr(config, 'MAX_FILE_AGE_HOURS', 1),
        "cleanup_interval_minutes": getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10),
        "max_total_size_gb": getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0),
    }
    return FileCleanupManager(**settings)
|
||||
328
middleware/rate_limiting.py
Normal file
328
middleware/rate_limiting.py
Normal file
@ -0,0 +1,328 @@
|
||||
"""
|
||||
Rate Limiting Middleware for SaaS robustness
|
||||
Protects against abuse and ensures fair usage
|
||||
"""
|
||||
import time
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RateLimitConfig:
    """Configuration for rate limiting.

    All limits apply per client unless noted otherwise; sizes are in MB.
    """
    # General request limits, per time window
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000

    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # File size limits (MB)
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Burst protection: maximum requests within a single second
    burst_limit: int = 10

    # Clients with these IPs bypass rate limiting entirely
    whitelist_ips: list = field(default_factory=["127.0.0.1", "::1"].copy)
|
||||
|
||||
|
||||
class TokenBucket:
    """Token bucket algorithm for rate limiting.

    Tokens accrue continuously at ``refill_rate`` per second up to
    ``capacity``; each request consumes tokens and is rejected once the
    bucket runs dry.
    """

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens added per second
        self.tokens = capacity          # bucket starts full
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Atomically take *tokens* from the bucket; False when unavailable."""
        async with self._lock:
            self._refill()
            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True

    def _refill(self):
        """Credit tokens for the time elapsed since the last refill."""
        now = time.time()
        gained = (now - self.last_refill) * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + gained)
        self.last_refill = now
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for accurate rate limiting"""
|
||||
|
||||
def __init__(self, window_seconds: int, max_requests: int):
|
||||
self.window_seconds = window_seconds
|
||||
self.max_requests = max_requests
|
||||
self.requests: list = []
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self) -> bool:
|
||||
"""Check if a new request is allowed"""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
# Remove old requests outside the window
|
||||
self.requests = [ts for ts in self.requests if now - ts < self.window_seconds]
|
||||
|
||||
if len(self.requests) < self.max_requests:
|
||||
self.requests.append(now)
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def current_count(self) -> int:
|
||||
"""Get current request count in window"""
|
||||
now = time.time()
|
||||
return len([ts for ts in self.requests if now - ts < self.window_seconds])
|
||||
|
||||
|
||||
class ClientRateLimiter:
    """Per-client rate limiter combining several windows.

    Tracks general request windows (minute/hour/day plus a burst token
    bucket) and translation-specific state (windows, concurrency count,
    hourly data volume).
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # List of (timestamp, size_mb)
        # Guards concurrent_translations and total_size_hour
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if a generic request is allowed.

        Returns (allowed, reason); reason is "" when allowed. Checks run
        burst -> minute -> hour -> day, failing fast on the first limit hit.
        """
        # Check burst limit
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        # Check minute limit
        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        # Check hour limit
        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        # Check day limit
        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation request is allowed.

        Returns (allowed, reason). Verifies the concurrency cap, the
        per-minute and per-hour translation windows, and the hourly data
        volume; on success the file size is recorded against the hourly
        quota.

        Fix: asyncio.Lock is NOT reentrant — the original re-acquired
        self._lock inside the already-held ``async with self._lock``
        block for the size check, deadlocking every call that got that
        far. All lock-guarded state is now handled in one critical
        section.
        """
        async with self._lock:
            # Check concurrent limit
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

            # Check translation per minute (the window has its own lock,
            # so awaiting it while holding self._lock is safe)
            if not await self.translation_minute.is_allowed():
                return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

            # Check translation per hour
            if not await self.translation_hour.is_allowed():
                return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

            # Check total size per hour, pruning entries older than 1h
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)

            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."

            self.total_size_hour.append((now, file_size_mb))

        return True, ""

    async def start_translation(self):
        """Mark start of translation (increments the concurrency count)."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark end of translation (never drops the count below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Get current rate limit stats (point-in-time snapshot)."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
|
||||
|
||||
|
||||
class RateLimitManager:
    """Manages per-client rate limiters and global counters.

    Limiters are created lazily per client ID (IP). Idle limiters are
    evicted periodically so the client map cannot grow without bound.
    """

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        # defaultdict creates a limiter the first time a client ID is seen
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Cleanup old clients every hour
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def _maybe_cleanup_clients(self):
        """Evict limiters for idle clients, at most once per cleanup interval.

        Fix: ``_cleanup_interval``/``_last_cleanup`` were set but never
        used, so one ClientRateLimiter per client IP ever seen was kept
        forever — unbounded memory growth for a public SaaS endpoint. A
        client is idle when it has no requests and no translations in the
        last hour and no in-flight translations.
        """
        now = time.time()
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        idle = [
            cid for cid, client in list(self.clients.items())
            if client.concurrent_translations == 0
            and client.hour_counter.current_count == 0
            and client.translation_hour.current_count == 0
        ]
        for cid in idle:
            del self.clients[cid]
        if idle:
            logger.info(f"Rate limiter cleanup: removed {len(idle)} idle clients")

    def get_client_id(self, request: Request) -> str:
        """Extract client identifier (IP) from the request.

        Prefers proxy headers (first X-Forwarded-For hop, then X-Real-IP)
        over the direct connection address.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is exempt from rate limiting."""
        return client_id in self.config.whitelist_ips

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed, return (allowed, reason, client_id)."""
        client_id = self.get_client_id(request)
        self._total_requests += 1
        self._maybe_cleanup_clients()

        if self.is_whitelisted(client_id):
            return True, "", client_id

        client = self.clients[client_id]
        allowed, reason = await client.check_request()

        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation for this request is allowed."""
        client_id = self.get_client_id(request)
        self._total_translations += 1

        if self.is_whitelisted(client_id):
            return True, ""

        client = self.clients[client_id]
        return await client.check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if translation is allowed for a specific client ID."""
        if self.is_whitelisted(client_id):
            return True

        client = self.clients[client_id]
        allowed, _ = await client.check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for the client issuing *request*.

        NOTE: looking up the defaultdict creates a limiter for a
        previously unseen client; idle entries are reclaimed by
        _maybe_cleanup_clients.
        """
        client_id = self.get_client_id(request)
        client = self.clients[client_id]
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client by ID (no lazy creation)."""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        client = self.clients[client_id]
        stats = client.get_stats()

        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics and the effective config."""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware for rate limiting."""

    # Paths that are never throttled (health probes, landing page, docs)
    _EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json", "/redoc"})

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        # Skip rate limiting for exempt endpoints and static assets
        if path in self._EXEMPT_PATHS or path.startswith("/static"):
            return await call_next(request)

        allowed, reason, client_id = await self.manager.check_request(request)

        if allowed:
            # Expose per-client limiter state to downstream endpoint handlers
            request.state.client_id = client_id
            request.state.rate_limiter = self.manager.clients[client_id]
            return await call_next(request)

        logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
        return JSONResponse(
            status_code=429,
            content={
                "error": "rate_limit_exceeded",
                "message": reason,
                "retry_after": 60
            },
            headers={"Retry-After": "60"}
        )
|
||||
|
||||
|
||||
# Module-level singleton shared by the middleware and any endpoint that
# needs to query or record rate-limit state.
rate_limit_manager = RateLimitManager()
|
||||
142
middleware/security.py
Normal file
142
middleware/security.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Security Headers Middleware for SaaS robustness
|
||||
Adds security headers to all responses
|
||||
"""
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses."""

    # Headers applied unconditionally to every response
    _STATIC_HEADERS = (
        ("X-Frame-Options", "DENY"),            # prevent clickjacking
        ("X-Content-Type-Options", "nosniff"),  # prevent MIME type sniffing
        ("X-XSS-Protection", "1; mode=block"),  # enable legacy XSS filter
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
    )

    # Content Security Policy (adjust for your frontend)
    _CSP = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: blob:; "
        "font-src 'self' data:; "
        "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
        "worker-src 'self' blob:; "
    )

    def __init__(self, app, config: dict = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)

        for name, value in self._STATIC_HEADERS:
            response.headers[name] = value

        # A CSP would break the interactive API docs, so skip it there
        path = request.url.path
        if not (path.startswith("/docs") or path.startswith("/redoc")):
            response.headers["Content-Security-Policy"] = self._CSP

        # HSTS only makes sense behind HTTPS; opt in via config
        if self.config.get("enable_hsts", False):
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log every request with a short unique ID for correlation."""

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        # NOTE(review): log_body is stored but body logging is not yet implemented
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next) -> Response:
        import time
        import uuid

        # Short ID ties together all log lines (and the response header)
        # for this one request
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        client_ip = self._get_client_ip(request)
        started = time.time()
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"from {client_ip} - Started"
        )

        try:
            response = await call_next(request)
        except Exception as e:
            elapsed = time.time() - started
            logger.error(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- ERROR in {elapsed:.3f}s: {str(e)}"
            )
            raise

        elapsed = time.time() - started
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"- {response.status_code} in {elapsed:.3f}s"
        )

        # Let clients/ops correlate responses with server logs
        response.headers["X-Request-ID"] = request_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Get real client IP from headers or connection.

        Prefers proxy headers: first X-Forwarded-For hop, then X-Real-IP,
        then the direct connection address.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Catch all unhandled exceptions and return proper error responses."""

    async def dispatch(self, request: Request, call_next) -> Response:
        try:
            return await call_next(request)
        except Exception as exc:
            # Request ID is set by RequestLoggingMiddleware; may be absent
            # if that middleware did not run first
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.exception(f"[{request_id}] Unhandled exception: {str(exc)}")
            return self._error_response(request_id)

    @staticmethod
    def _error_response(request_id: str):
        """Build the generic 500 payload; never leaks internals to clients."""
        from starlette.responses import JSONResponse

        return JSONResponse(
            status_code=500,
            content={
                "error": "internal_server_error",
                "message": "An unexpected error occurred. Please try again later.",
                "request_id": request_id
            }
        )
|
||||
440
middleware/validation.py
Normal file
440
middleware/validation.py
Normal file
@ -0,0 +1,440 @@
|
||||
"""
|
||||
Input Validation Module for SaaS robustness
|
||||
Validates all user inputs before processing
|
||||
"""
|
||||
import re
|
||||
import magic
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Set
|
||||
from fastapi import UploadFile, HTTPException
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationError(Exception):
    """Custom validation error with user-friendly messages.

    Attributes:
        message: human-readable description (also the exception text)
        code: machine-readable error code for API clients
        details: optional structured context about the failure
    """

    def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details if details else {}
|
||||
|
||||
|
||||
class ValidationResult:
    """Result of a validation check.

    Attributes:
        is_valid: overall verdict
        errors: blocking problems found
        warnings: non-blocking observations
        data: extracted metadata (safe filename, size, MIME type, ...)
    """

    def __init__(self, is_valid: bool = True, errors: List[str] = None, warnings: List[str] = None, data: dict = None):
        self.is_valid = is_valid
        # Fresh containers when nothing was supplied, so instances never
        # share mutable state
        self.errors = errors if errors else []
        self.warnings = warnings if warnings else []
        self.data = data if data else {}
|
||||
|
||||
|
||||
class FileValidator:
    """Validates uploaded files for security and compatibility.

    Checks run in order: filename presence, filename sanitization (path
    traversal, control characters), extension whitelist, size limits,
    ZIP magic bytes (all supported formats are Office Open XML), and MIME
    type. Two entry points: ``validate_async`` collects problems into a
    ValidationResult, while ``validate`` raises ValidationError on the
    first failure.
    """

    # Allowed MIME types mapped to extensions
    ALLOWED_MIME_TYPES = {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    }

    # Magic bytes for Office Open XML files (ZIP format)
    OFFICE_MAGIC_BYTES = b"PK\x03\x04"

    def __init__(
        self,
        max_size_mb: int = 50,
        allowed_extensions: Set[str] = None,
        scan_content: bool = True
    ):
        # scan_content toggles the magic-byte signature check
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.max_size_mb = max_size_mb
        self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
        self.scan_content = scan_content

    async def validate_async(self, file: UploadFile) -> ValidationResult:
        """
        Validate an uploaded file asynchronously
        Returns ValidationResult with is_valid, errors, warnings

        Unlike ``validate``, this never raises: each failed check returns
        a ValidationResult with is_valid=False, and a MIME mismatch is
        downgraded to a warning here.
        """
        errors = []
        warnings = []
        data = {}

        try:
            # Validate filename
            if not file.filename:
                errors.append("Filename is required")
                return ValidationResult(is_valid=False, errors=errors)

            # Sanitize filename
            try:
                safe_filename = self._sanitize_filename(file.filename)
                data["safe_filename"] = safe_filename
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Validate extension
            try:
                extension = self._validate_extension(safe_filename)
                data["extension"] = extension
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Read file content for validation
            content = await file.read()
            await file.seek(0)  # Reset for later processing

            # Validate file size
            file_size = len(content)
            data["size_bytes"] = file_size
            data["size_mb"] = round(file_size / (1024*1024), 2)

            if file_size > self.max_size_bytes:
                errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            if file_size == 0:
                errors.append("File is empty")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            # Warn about large files (above 80% of the limit)
            if file_size > self.max_size_bytes * 0.8:
                warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")

            # Validate magic bytes
            if self.scan_content:
                try:
                    self._validate_magic_bytes(content, extension)
                except ValidationError as e:
                    errors.append(str(e.message))
                    return ValidationResult(is_valid=False, errors=errors, data=data)

            # Validate MIME type — mismatches are warnings only here,
            # since OOXML files are often detected as plain ZIP
            try:
                mime_type = self._detect_mime_type(content)
                data["mime_type"] = mime_type
                self._validate_mime_type(mime_type, extension)
            except ValidationError as e:
                warnings.append(f"MIME type warning: {e.message}")
            except Exception:
                warnings.append("Could not verify MIME type")

            data["original_filename"] = file.filename

            return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)

        except Exception as e:
            # Unexpected failure anywhere above: surface it as an invalid
            # result rather than propagating
            logger.error(f"Validation error: {str(e)}")
            errors.append(f"Validation failed: {str(e)}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)

    async def validate(self, file: UploadFile) -> dict:
        """
        Validate an uploaded file
        Returns validation info dict or raises ValidationError

        Strict counterpart of ``validate_async``: the first failing check
        raises, including MIME mismatches.
        """
        # Validate filename
        if not file.filename:
            raise ValidationError(
                "Filename is required",
                code="missing_filename"
            )

        # Sanitize filename
        safe_filename = self._sanitize_filename(file.filename)

        # Validate extension
        extension = self._validate_extension(safe_filename)

        # Read file content for validation
        content = await file.read()
        await file.seek(0)  # Reset for later processing

        # Validate file size
        file_size = len(content)
        if file_size > self.max_size_bytes:
            raise ValidationError(
                f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
                code="file_too_large",
                details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
            )

        if file_size == 0:
            raise ValidationError(
                "File is empty",
                code="empty_file"
            )

        # Validate magic bytes (file signature)
        if self.scan_content:
            self._validate_magic_bytes(content, extension)

        # Validate MIME type
        mime_type = self._detect_mime_type(content)
        self._validate_mime_type(mime_type, extension)

        return {
            "original_filename": file.filename,
            "safe_filename": safe_filename,
            "extension": extension,
            "size_bytes": file_size,
            "size_mb": round(file_size / (1024*1024), 2),
            "mime_type": mime_type
        }

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks.

        Strips directory components, control characters, and characters
        that are unsafe on common filesystems; truncates to 255 chars.
        Raises ValidationError if nothing usable remains.
        """
        # Remove path components
        filename = Path(filename).name

        # Remove null bytes and control characters
        filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)

        # Remove potentially dangerous characters
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        # Limit length (preserving the extension when present)
        if len(filename) > 255:
            name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
            filename = name[:250] + ('.' + ext if ext else '')

        # Ensure not empty after sanitization
        if not filename or filename.strip() == '':
            raise ValidationError(
                "Invalid filename",
                code="invalid_filename"
            )

        return filename

    def _validate_extension(self, filename: str) -> str:
        """Validate and return the file extension (lowercased, with dot).

        Raises ValidationError when the extension is missing or not in
        the allowed set.
        """
        if '.' not in filename:
            raise ValidationError(
                f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
                code="missing_extension",
                details={"allowed_extensions": list(self.allowed_extensions)}
            )

        extension = '.' + filename.rsplit('.', 1)[1].lower()

        if extension not in self.allowed_extensions:
            raise ValidationError(
                f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
                code="unsupported_file_type",
                details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
            )

        return extension

    def _validate_magic_bytes(self, content: bytes, extension: str):
        """Validate file magic bytes match expected format.

        Raises ValidationError when the ZIP signature is absent.
        """
        # All supported formats are Office Open XML (ZIP-based)
        if not content.startswith(self.OFFICE_MAGIC_BYTES):
            raise ValidationError(
                "File content does not match expected format. The file may be corrupted or not a valid Office document.",
                code="invalid_file_content"
            )

    def _detect_mime_type(self, content: bytes) -> str:
        """Detect MIME type from file content.

        Uses libmagic when available; otherwise falls back to a crude
        ZIP-signature check.
        """
        try:
            mime = magic.Magic(mime=True)
            return mime.from_buffer(content)
        except Exception:
            # Fallback to basic detection
            if content.startswith(self.OFFICE_MAGIC_BYTES):
                return "application/zip"
            return "application/octet-stream"

    def _validate_mime_type(self, mime_type: str, extension: str):
        """Validate MIME type matches extension.

        Raises ValidationError for anything outside the allowed set.
        NOTE(review): *extension* is accepted but not used — the check is
        against the global allowed MIME list only.
        """
        # Office Open XML files may be detected as ZIP
        allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]

        if mime_type not in allowed_mimes:
            raise ValidationError(
                f"Invalid file type detected. Expected Office document, got: {mime_type}",
                code="invalid_mime_type",
                details={"detected_mime": mime_type}
            )
|
||||
|
||||
|
||||
class LanguageValidator:
|
||||
"""Validates language codes"""
|
||||
|
||||
SUPPORTED_LANGUAGES = {
|
||||
# ISO 639-1 codes
|
||||
"af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
|
||||
"bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
|
||||
"da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
|
||||
"de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
|
||||
"is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
|
||||
"rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
|
||||
"mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
|
||||
"ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
|
||||
"gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
|
||||
"su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
|
||||
"tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
|
||||
"zu", "auto"
|
||||
}
|
||||
|
||||
LANGUAGE_NAMES = {
|
||||
"en": "English", "es": "Spanish", "fr": "French", "de": "German",
|
||||
"it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
|
||||
"zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
|
||||
"ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
|
||||
"nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
|
||||
"da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
|
||||
"el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
|
||||
"uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate(cls, language_code: str, field_name: str = "language") -> str:
|
||||
"""Validate and normalize language code"""
|
||||
if not language_code:
|
||||
raise ValidationError(
|
||||
f"{field_name} is required",
|
||||
code="missing_language"
|
||||
)
|
||||
|
||||
# Normalize
|
||||
normalized = language_code.strip().lower()
|
||||
|
||||
# Handle common variations
|
||||
if normalized in ["chinese", "cn"]:
|
||||
normalized = "zh-CN"
|
||||
elif normalized in ["chinese-traditional", "tw"]:
|
||||
normalized = "zh-TW"
|
||||
|
||||
if normalized not in cls.SUPPORTED_LANGUAGES:
|
||||
raise ValidationError(
|
||||
f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
|
||||
code="unsupported_language",
|
||||
details={"language": language_code}
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
@classmethod
def get_language_name(cls, code: str) -> str:
    """Return the friendly display name for *code*; when the code has no
    entry in ``LANGUAGE_NAMES``, fall back to the uppercased code itself."""
    display_name = cls.LANGUAGE_NAMES.get(code)
    return display_name if display_name is not None else code.upper()
|
||||
|
||||
|
||||
class ProviderValidator:
    """Validates translation provider configuration."""

    # Providers the API knows how to dispatch to.
    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm"}

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate a provider name and its provider-specific configuration.

        Args:
            provider: Requested provider name (case-insensitive, trimmed).
            **kwargs: Provider-specific settings, e.g. ``deepl_api_key``,
                ``openai_api_key``, ``ollama_model``.

        Returns:
            ``{"provider": <normalized name>, "validated": True}``.

        Raises:
            ValidationError: If the provider is missing, unknown, or is
                missing a required API key.
        """
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )

        normalized = provider.strip().lower()

        if normalized not in cls.SUPPORTED_PROVIDERS:
            # BUG FIX: iterate a sorted copy so the error message and the
            # machine-readable "supported" list are deterministic — plain
            # set iteration order varies between processes (hash seed).
            supported = sorted(cls.SUPPORTED_PROVIDERS)
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(supported)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": supported}
            )

        # Provider-specific validation: hosted providers need an API key.
        if normalized == "deepl":
            if not kwargs.get("deepl_api_key"):
                raise ValidationError(
                    "DeepL API key is required when using DeepL provider",
                    code="missing_deepl_key"
                )

        elif normalized == "openai":
            if not kwargs.get("openai_api_key"):
                raise ValidationError(
                    "OpenAI API key is required when using OpenAI provider",
                    code="missing_openai_key"
                )

        elif normalized == "ollama":
            # Ollama runs locally and needs no API key; a missing model is
            # tolerated — the server default is used instead.
            model = kwargs.get("ollama_model", "")
            if not model:
                logger.warning("No Ollama model specified, will use default")

        return {"provider": normalized, "validated": True}
|
||||
|
||||
|
||||
class InputSanitizer:
    """Sanitizes user inputs to prevent injection attacks."""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Strip null bytes, cap the text at *max_length* characters, and
        trim surrounding whitespace. Empty/None input yields ""."""
        if not text:
            return ""
        cleaned = text.replace('\x00', '')
        # Slicing is a no-op when the text is already short enough.
        return cleaned[:max_length].strip()

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Reduce *code* to ASCII letters, digits and hyphens (lowercased,
        at most 10 chars); return "auto" when nothing usable remains."""
        if not code:
            return "auto"
        cleaned = re.sub(r'[^0-9A-Za-z-]+', '', code.strip())[:10]
        # "" is falsy, so a fully-stripped code falls back to "auto".
        return cleaned.lower() or "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Trim whitespace and trailing slashes from *url*.

        Raises:
            ValidationError: If the URL does not use an http(s) scheme.
        """
        if not url:
            return ""
        candidate = url.strip()
        # Scheme check is case-insensitive, matching the previous regex.
        if not candidate.lower().startswith(("http://", "https://")):
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )
        return candidate.rstrip('/')

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Trim surrounding whitespace from an API key (never logged)."""
        return key.strip() if key else ""
|
||||
|
||||
|
||||
# Default validators
# Module-level singleton so request handlers can share one FileValidator
# instance instead of constructing their own per request.
file_validator = FileValidator()
|
||||
@ -14,3 +14,7 @@ pandas==2.1.4
|
||||
requests==2.31.0
|
||||
ipykernel==6.27.1
|
||||
openai>=1.0.0
|
||||
|
||||
# SaaS robustness dependencies
|
||||
psutil==5.9.8
|
||||
python-magic-bin==0.4.14 # For Windows, use python-magic on Linux
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user