From 500502440c6e30a1064bc12b00c69344831f4d82 Mon Sep 17 00:00:00 2001 From: Sepehr Date: Sun, 30 Nov 2025 19:25:09 +0100 Subject: [PATCH] feat: Add SaaS robustness middleware - Rate limiting with token bucket and sliding window algorithms - Input validation (file, language, provider) - Security headers middleware (CSP, XSS protection, etc.) - Automatic file cleanup with TTL tracking - Memory and disk monitoring - Enhanced health check and metrics endpoints - Request logging with unique IDs --- .env.example | 67 +++++- config.py | 48 +++- main.py | 216 ++++++++++++++++-- middleware/__init__.py | 62 +++++ middleware/cleanup.py | 400 ++++++++++++++++++++++++++++++++ middleware/rate_limiting.py | 328 +++++++++++++++++++++++++++ middleware/security.py | 142 ++++++++++++ middleware/validation.py | 440 ++++++++++++++++++++++++++++++++++++ requirements.txt | 4 + 9 files changed, 1681 insertions(+), 26 deletions(-) create mode 100644 middleware/__init__.py create mode 100644 middleware/cleanup.py create mode 100644 middleware/rate_limiting.py create mode 100644 middleware/security.py create mode 100644 middleware/validation.py diff --git a/.env.example b/.env.example index cd0b3d9..dd5f17f 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,11 @@ -# Translation Service Configuration -TRANSLATION_SERVICE=google # Options: google, deepl, libre, ollama +# Document Translation API - Environment Configuration +# Copy this file to .env and configure your settings + +# ============== Translation Services ============== +# Default provider: google, ollama, deepl, libre, openai +TRANSLATION_SERVICE=google + +# DeepL API Key (required for DeepL provider) DEEPL_API_KEY=your_deepl_api_key_here # Ollama Configuration (for LLM-based translation) @@ -7,7 +13,58 @@ OLLAMA_BASE_URL=http://localhost:11434 OLLAMA_MODEL=llama3 OLLAMA_VISION_MODEL=llava -# API Configuration +# ============== File Limits ============== +# Maximum file size in MB MAX_FILE_SIZE_MB=50 -UPLOAD_DIR=./uploads 
-OUTPUT_DIR=./outputs + +# ============== Rate Limiting (SaaS) ============== +# Enable/disable rate limiting +RATE_LIMIT_ENABLED=true + +# Request limits +RATE_LIMIT_PER_MINUTE=30 +RATE_LIMIT_PER_HOUR=200 + +# Translation-specific limits +TRANSLATIONS_PER_MINUTE=10 +TRANSLATIONS_PER_HOUR=50 +MAX_CONCURRENT_TRANSLATIONS=5 + +# ============== Cleanup Service ============== +# Enable automatic file cleanup +CLEANUP_ENABLED=true + +# Cleanup interval in minutes +CLEANUP_INTERVAL_MINUTES=15 + +# File time-to-live in minutes +FILE_TTL_MINUTES=60 +INPUT_FILE_TTL_MINUTES=30 +OUTPUT_FILE_TTL_MINUTES=120 + +# Disk space warning thresholds (GB) +DISK_WARNING_THRESHOLD_GB=5.0 +DISK_CRITICAL_THRESHOLD_GB=1.0 + +# ============== Security ============== +# Enable HSTS (only for HTTPS deployments) +ENABLE_HSTS=false + +# CORS allowed origins (comma-separated) +CORS_ORIGINS=* + +# Maximum request size in MB +MAX_REQUEST_SIZE_MB=100 + +# Request timeout in seconds +REQUEST_TIMEOUT_SECONDS=300 + +# ============== Monitoring ============== +# Log level: DEBUG, INFO, WARNING, ERROR +LOG_LEVEL=INFO + +# Enable request logging +ENABLE_REQUEST_LOGGING=true + +# Memory usage threshold (percentage) +MAX_MEMORY_PERCENT=80 diff --git a/config.py b/config.py index 8137950..d8363c9 100644 --- a/config.py +++ b/config.py @@ -1,5 +1,6 @@ """ Configuration module for the Document Translation API +SaaS-ready with comprehensive settings for production deployment """ import os from pathlib import Path @@ -8,7 +9,7 @@ from dotenv import load_dotenv load_dotenv() class Config: - # Translation Service + # ============== Translation Service ============== TRANSLATION_SERVICE = os.getenv("TRANSLATION_SERVICE", "google") DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "") @@ -17,20 +18,51 @@ class Config: OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3") OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava") - # File Upload Configuration + # ============== File Upload Configuration ============== 
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50")) MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024 # Directories - BASE_DIR = Path(__file__).parent.parent + BASE_DIR = Path(__file__).parent UPLOAD_DIR = BASE_DIR / "uploads" OUTPUT_DIR = BASE_DIR / "outputs" TEMP_DIR = BASE_DIR / "temp" + LOGS_DIR = BASE_DIR / "logs" # Supported file types SUPPORTED_EXTENSIONS = {".xlsx", ".docx", ".pptx"} - # API Configuration + # ============== Rate Limiting (SaaS) ============== + RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true" + RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "30")) + RATE_LIMIT_PER_HOUR = int(os.getenv("RATE_LIMIT_PER_HOUR", "200")) + TRANSLATIONS_PER_MINUTE = int(os.getenv("TRANSLATIONS_PER_MINUTE", "10")) + TRANSLATIONS_PER_HOUR = int(os.getenv("TRANSLATIONS_PER_HOUR", "50")) + MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5")) + + # ============== Cleanup Service ============== + CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "true").lower() == "true" + CLEANUP_INTERVAL_MINUTES = int(os.getenv("CLEANUP_INTERVAL_MINUTES", "15")) + FILE_TTL_MINUTES = int(os.getenv("FILE_TTL_MINUTES", "60")) + INPUT_FILE_TTL_MINUTES = int(os.getenv("INPUT_FILE_TTL_MINUTES", "30")) + OUTPUT_FILE_TTL_MINUTES = int(os.getenv("OUTPUT_FILE_TTL_MINUTES", "120")) + + # Disk space thresholds + DISK_WARNING_THRESHOLD_GB = float(os.getenv("DISK_WARNING_THRESHOLD_GB", "5.0")) + DISK_CRITICAL_THRESHOLD_GB = float(os.getenv("DISK_CRITICAL_THRESHOLD_GB", "1.0")) + + # ============== Security ============== + ENABLE_HSTS = os.getenv("ENABLE_HSTS", "false").lower() == "true" + CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",") + MAX_REQUEST_SIZE_MB = int(os.getenv("MAX_REQUEST_SIZE_MB", "100")) + REQUEST_TIMEOUT_SECONDS = int(os.getenv("REQUEST_TIMEOUT_SECONDS", "300")) + + # ============== Monitoring ============== + LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") + ENABLE_REQUEST_LOGGING = 
os.getenv("ENABLE_REQUEST_LOGGING", "true").lower() == "true" + MAX_MEMORY_PERCENT = float(os.getenv("MAX_MEMORY_PERCENT", "80")) + + # ============== API Configuration ============== API_TITLE = "Document Translation API" API_VERSION = "1.0.0" API_DESCRIPTION = """ @@ -40,6 +72,12 @@ class Config: - Excel (.xlsx) - Preserves cell formatting, formulas, merged cells, images - Word (.docx) - Preserves styles, tables, images, headers/footers - PowerPoint (.pptx) - Preserves layouts, animations, embedded media + + SaaS Features: + - Rate limiting per client IP + - Automatic file cleanup + - Health monitoring + - Request logging """ @classmethod @@ -48,5 +86,7 @@ class Config: cls.UPLOAD_DIR.mkdir(exist_ok=True, parents=True) cls.OUTPUT_DIR.mkdir(exist_ok=True, parents=True) cls.TEMP_DIR.mkdir(exist_ok=True, parents=True) + cls.LOGS_DIR.mkdir(exist_ok=True, parents=True) + config = Config() diff --git a/main.py b/main.py index 7d18675..1c52b75 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,55 @@ """ Document Translation API FastAPI application for translating complex documents while preserving formatting +SaaS-ready with rate limiting, validation, and robust error handling """ -from fastapi import FastAPI, UploadFile, File, Form, HTTPException +from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, Depends from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles +from contextlib import asynccontextmanager from pathlib import Path from typing import Optional import asyncio import logging +import os from config import config from translators import excel_translator, word_translator, pptx_translator from utils import file_handler, handle_translation_error, DocumentProcessingError -# Configure logging -logging.basicConfig(level=logging.INFO) +# Import SaaS middleware +from middleware.rate_limiting import RateLimitMiddleware, RateLimitManager, RateLimitConfig 
+from middleware.security import SecurityHeadersMiddleware, RequestLoggingMiddleware, ErrorHandlingMiddleware +from middleware.cleanup import FileCleanupManager, MemoryMonitor, HealthChecker, create_cleanup_manager +from middleware.validation import FileValidator, LanguageValidator, ProviderValidator, InputSanitizer, ValidationError + +# Configure structured logging +logging.basicConfig( + level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")), + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) logger = logging.getLogger(__name__) +# Initialize SaaS components +rate_limit_config = RateLimitConfig( + requests_per_minute=int(os.getenv("RATE_LIMIT_PER_MINUTE", "30")), + requests_per_hour=int(os.getenv("RATE_LIMIT_PER_HOUR", "200")), + translations_per_minute=int(os.getenv("TRANSLATIONS_PER_MINUTE", "10")), + translations_per_hour=int(os.getenv("TRANSLATIONS_PER_HOUR", "50")), + max_concurrent_translations=int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5")), +) +rate_limit_manager = RateLimitManager(rate_limit_config) + +cleanup_manager = create_cleanup_manager(config) +memory_monitor = MemoryMonitor(max_memory_percent=float(os.getenv("MAX_MEMORY_PERCENT", "80"))) +health_checker = HealthChecker(cleanup_manager, memory_monitor) + +file_validator = FileValidator( + max_size_mb=config.MAX_FILE_SIZE_MB, + allowed_extensions=config.SUPPORTED_EXTENSIONS +) + def build_full_prompt(system_prompt: str, glossary: str) -> str: """Combine system prompt and glossary into a single prompt for LLM translation.""" @@ -40,23 +71,47 @@ Always use the translations from this glossary when you encounter these terms."" return "\n\n".join(parts) if parts else "" -# Ensure necessary directories exist -config.ensure_directories() +# Lifespan context manager for startup/shutdown +@asynccontextmanager +async def lifespan(app: FastAPI): + """Handle startup and shutdown events""" + # Startup + logger.info("Starting Document Translation API...") + config.ensure_directories() + await 
cleanup_manager.start() + logger.info("API ready to accept requests") + + yield + + # Shutdown + logger.info("Shutting down...") + await cleanup_manager.stop() + logger.info("Cleanup completed") -# Create FastAPI app + +# Create FastAPI app with lifespan app = FastAPI( title=config.API_TITLE, version=config.API_VERSION, - description=config.API_DESCRIPTION + description=config.API_DESCRIPTION, + lifespan=lifespan ) -# Add CORS middleware +# Add middleware (order matters - first added is outermost) +app.add_middleware(ErrorHandlingMiddleware) +app.add_middleware(RequestLoggingMiddleware, log_body=False) +app.add_middleware(SecurityHeadersMiddleware, config={"enable_hsts": os.getenv("ENABLE_HSTS", "false").lower() == "true"}) +app.add_middleware(RateLimitMiddleware, rate_limit_manager=rate_limit_manager) + +# CORS - configure for production +allowed_origins = os.getenv("CORS_ORIGINS", "*").split(",") app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Configure appropriately for production + allow_origins=allowed_origins, allow_credentials=True, - allow_methods=["*"], + allow_methods=["GET", "POST", "DELETE", "OPTIONS"], allow_headers=["*"], + expose_headers=["X-Request-ID", "X-Original-Filename", "X-File-Size-MB", "X-Target-Language"] ) # Mount static files @@ -65,6 +120,20 @@ if static_dir.exists(): app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") +# Custom exception handler for ValidationError +@app.exception_handler(ValidationError) +async def validation_error_handler(request: Request, exc: ValidationError): + """Handle validation errors with user-friendly messages""" + return JSONResponse( + status_code=400, + content={ + "error": exc.code, + "message": exc.message, + "details": exc.details + } + ) + + @app.get("/") async def root(): """Root endpoint with API information""" @@ -83,11 +152,24 @@ async def root(): @app.get("/health") async def health_check(): - """Health check endpoint""" - return { - "status": "healthy", - 
"translation_service": config.TRANSLATION_SERVICE - } + """Health check endpoint with detailed system status""" + health_status = await health_checker.check_health() + status_code = 200 if health_status.get("status") == "healthy" else 503 + + return JSONResponse( + status_code=status_code, + content={ + "status": health_status.get("status", "unknown"), + "translation_service": config.TRANSLATION_SERVICE, + "memory": health_status.get("memory", {}), + "disk": health_status.get("disk", {}), + "cleanup_service": health_status.get("cleanup_service", {}), + "rate_limits": { + "requests_per_minute": rate_limit_config.requests_per_minute, + "translations_per_minute": rate_limit_config.translations_per_minute, + } + } + ) @app.get("/languages") @@ -128,6 +210,7 @@ async def get_supported_languages(): @app.post("/translate") async def translate_document( + request: Request, file: UploadFile = File(..., description="Document file to translate (.xlsx, .docx, or .pptx)"), target_language: str = Form(..., description="Target language code (e.g., 'es', 'fr', 'de')"), source_language: str = Form(default="auto", description="Source language code (default: auto-detect)"), @@ -160,11 +243,38 @@ async def translate_document( """ input_path = None output_path = None + request_id = getattr(request.state, 'request_id', 'unknown') try: + # Validate inputs + sanitized_language = InputSanitizer.sanitize_language_code(target_language) + LanguageValidator.validate(sanitized_language) + ProviderValidator.validate(provider) + + # Validate file before processing + validation_result = await file_validator.validate_async(file) + if not validation_result.is_valid: + raise ValidationError( + message=f"File validation failed: {'; '.join(validation_result.errors)}", + code="INVALID_FILE", + details={"errors": validation_result.errors, "warnings": validation_result.warnings} + ) + + # Log any warnings + if validation_result.warnings: + logger.warning(f"[{request_id}] File validation warnings: 
{validation_result.warnings}") + + # Check rate limit for translations + client_ip = request.client.host if request.client else "unknown" + if not await rate_limit_manager.check_translation_limit(client_ip): + raise HTTPException( + status_code=429, + detail="Translation rate limit exceeded. Please try again later." + ) + # Validate file extension file_extension = file_handler.validate_file_extension(file.filename) - logger.info(f"Processing {file_extension} file: {file.filename}") + logger.info(f"[{request_id}] Processing {file_extension} file: {file.filename}") # Validate file size file_handler.validate_file_size(file) @@ -178,7 +288,11 @@ async def translate_document( output_path = config.OUTPUT_DIR / output_filename await file_handler.save_upload_file(file, input_path) - logger.info(f"Saved input file to: {input_path}") + logger.info(f"[{request_id}] Saved input file to: {input_path}") + + # Track file for cleanup + await cleanup_manager.track_file(input_path, ttl_minutes=30) + await cleanup_manager.track_file(output_path, ttl_minutes=60) # Configure translation provider from services.translation_service import GoogleTranslationProvider, DeepLTranslationProvider, LibreTranslationProvider, OllamaTranslationProvider, OpenAITranslationProvider, translation_service @@ -657,6 +771,74 @@ async def reconstruct_document( raise HTTPException(status_code=500, detail=f"Failed to reconstruct document: {str(e)}") +# ============== SaaS Management Endpoints ============== + +@app.get("/metrics") +async def get_metrics(): + """Get system metrics and statistics for monitoring""" + health_status = await health_checker.check_health() + cleanup_stats = cleanup_manager.get_stats() + rate_limit_stats = rate_limit_manager.get_stats() + + return { + "system": { + "memory": health_status.get("memory", {}), + "disk": health_status.get("disk", {}), + "status": health_status.get("status", "unknown") + }, + "cleanup": cleanup_stats, + "rate_limits": rate_limit_stats, + "config": { + 
"max_file_size_mb": config.MAX_FILE_SIZE_MB, + "supported_extensions": list(config.SUPPORTED_EXTENSIONS), + "translation_service": config.TRANSLATION_SERVICE + } + } + + +@app.get("/rate-limit/status") +async def get_rate_limit_status(request: Request): + """Get current rate limit status for the requesting client""" + client_ip = request.client.host if request.client else "unknown" + status = await rate_limit_manager.get_client_status(client_ip) + + return { + "client_ip": client_ip, + "limits": { + "requests_per_minute": rate_limit_config.requests_per_minute, + "requests_per_hour": rate_limit_config.requests_per_hour, + "translations_per_minute": rate_limit_config.translations_per_minute, + "translations_per_hour": rate_limit_config.translations_per_hour + }, + "current_usage": status + } + + +@app.post("/admin/cleanup/trigger") +async def trigger_cleanup(): + """Trigger manual cleanup of expired files""" + try: + cleaned = await cleanup_manager.cleanup_expired() + return { + "status": "success", + "files_cleaned": cleaned, + "message": f"Cleaned up {cleaned} expired files" + } + except Exception as e: + logger.error(f"Manual cleanup failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}") + + +@app.get("/admin/files/tracked") +async def get_tracked_files(): + """Get list of currently tracked files""" + tracked = cleanup_manager.get_tracked_files() + return { + "count": len(tracked), + "files": tracked + } + + if __name__ == "__main__": import uvicorn uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) diff --git a/middleware/__init__.py b/middleware/__init__.py new file mode 100644 index 0000000..2d3c558 --- /dev/null +++ b/middleware/__init__.py @@ -0,0 +1,62 @@ +""" +Middleware package for SaaS robustness + +This package provides: +- Rate limiting: Protect against abuse and ensure fair usage +- Validation: Validate all inputs before processing +- Security: Security headers, request logging, error handling +- 
Cleanup: Automatic file cleanup and resource management +""" + +from .rate_limiting import ( + RateLimitConfig, + RateLimitManager, + RateLimitMiddleware, + ClientRateLimiter, +) + +from .validation import ( + ValidationError, + ValidationResult, + FileValidator, + LanguageValidator, + ProviderValidator, + InputSanitizer, +) + +from .security import ( + SecurityHeadersMiddleware, + RequestLoggingMiddleware, + ErrorHandlingMiddleware, +) + +from .cleanup import ( + FileCleanupManager, + MemoryMonitor, + HealthChecker, + create_cleanup_manager, +) + +__all__ = [ + # Rate limiting + "RateLimitConfig", + "RateLimitManager", + "RateLimitMiddleware", + "ClientRateLimiter", + # Validation + "ValidationError", + "ValidationResult", + "FileValidator", + "LanguageValidator", + "ProviderValidator", + "InputSanitizer", + # Security + "SecurityHeadersMiddleware", + "RequestLoggingMiddleware", + "ErrorHandlingMiddleware", + # Cleanup + "FileCleanupManager", + "MemoryMonitor", + "HealthChecker", + "create_cleanup_manager", +] diff --git a/middleware/cleanup.py b/middleware/cleanup.py new file mode 100644 index 0000000..ad32e91 --- /dev/null +++ b/middleware/cleanup.py @@ -0,0 +1,400 @@ +""" +Cleanup and Resource Management for SaaS robustness +Automatic cleanup of temporary files and resources +""" +import os +import time +import asyncio +import threading +from pathlib import Path +from datetime import datetime, timedelta +from typing import Optional, Set +import logging + +logger = logging.getLogger(__name__) + + +class FileCleanupManager: + """Manages automatic cleanup of temporary and output files""" + + def __init__( + self, + upload_dir: Path, + output_dir: Path, + temp_dir: Path, + max_file_age_hours: int = 1, + cleanup_interval_minutes: int = 10, + max_total_size_gb: float = 10.0 + ): + self.upload_dir = Path(upload_dir) + self.output_dir = Path(output_dir) + self.temp_dir = Path(temp_dir) + self.max_file_age_seconds = max_file_age_hours * 3600 + self.cleanup_interval = 
cleanup_interval_minutes * 60 + self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024) + + self._running = False + self._task: Optional[asyncio.Task] = None + self._protected_files: Set[str] = set() + self._tracked_files: dict = {} # filepath -> {created, ttl_minutes} + self._lock = threading.Lock() + self._stats = { + "files_cleaned": 0, + "bytes_freed": 0, + "cleanup_runs": 0 + } + + async def track_file(self, filepath: Path, ttl_minutes: int = 60): + """Track a file for automatic cleanup after TTL expires""" + with self._lock: + self._tracked_files[str(filepath)] = { + "created": time.time(), + "ttl_minutes": ttl_minutes, + "expires_at": time.time() + (ttl_minutes * 60) + } + + def get_tracked_files(self) -> list: + """Get list of currently tracked files with their status""" + now = time.time() + result = [] + + with self._lock: + for filepath, info in self._tracked_files.items(): + remaining = info["expires_at"] - now + result.append({ + "path": filepath, + "exists": Path(filepath).exists(), + "expires_in_seconds": max(0, int(remaining)), + "ttl_minutes": info["ttl_minutes"] + }) + + return result + + async def cleanup_expired(self) -> int: + """Cleanup expired tracked files""" + now = time.time() + cleaned = 0 + to_remove = [] + + with self._lock: + for filepath, info in list(self._tracked_files.items()): + if now > info["expires_at"]: + to_remove.append(filepath) + + for filepath in to_remove: + try: + path = Path(filepath) + if path.exists() and not self.is_protected(path): + size = path.stat().st_size + path.unlink() + cleaned += 1 + self._stats["files_cleaned"] += 1 + self._stats["bytes_freed"] += size + logger.info(f"Cleaned expired file: {filepath}") + + with self._lock: + self._tracked_files.pop(filepath, None) + + except Exception as e: + logger.warning(f"Failed to clean expired file {filepath}: {e}") + + return cleaned + + def get_stats(self) -> dict: + """Get cleanup statistics""" + disk_usage = self.get_disk_usage() + + with 
self._lock: + tracked_count = len(self._tracked_files) + + return { + "files_cleaned_total": self._stats["files_cleaned"], + "bytes_freed_total_mb": round(self._stats["bytes_freed"] / (1024 * 1024), 2), + "cleanup_runs": self._stats["cleanup_runs"], + "tracked_files": tracked_count, + "disk_usage": disk_usage, + "is_running": self._running + } + + def protect_file(self, filepath: Path): + """Mark a file as protected (being processed)""" + with self._lock: + self._protected_files.add(str(filepath)) + + def unprotect_file(self, filepath: Path): + """Remove protection from a file""" + with self._lock: + self._protected_files.discard(str(filepath)) + + def is_protected(self, filepath: Path) -> bool: + """Check if a file is protected""" + with self._lock: + return str(filepath) in self._protected_files + + async def start(self): + """Start the cleanup background task""" + if self._running: + return + + self._running = True + self._task = asyncio.create_task(self._cleanup_loop()) + logger.info("File cleanup manager started") + + async def stop(self): + """Stop the cleanup background task""" + self._running = False + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + logger.info("File cleanup manager stopped") + + async def _cleanup_loop(self): + """Background loop for periodic cleanup""" + while self._running: + try: + await self.cleanup() + await self.cleanup_expired() + self._stats["cleanup_runs"] += 1 + except Exception as e: + logger.error(f"Cleanup error: {e}") + + await asyncio.sleep(self.cleanup_interval) + + async def cleanup(self) -> dict: + """Perform cleanup of old files""" + stats = { + "files_deleted": 0, + "bytes_freed": 0, + "errors": [] + } + + now = time.time() + + # Cleanup each directory + for directory in [self.upload_dir, self.output_dir, self.temp_dir]: + if not directory.exists(): + continue + + for filepath in directory.iterdir(): + if not filepath.is_file(): + continue + + # Skip protected 
files + if self.is_protected(filepath): + continue + + try: + # Check file age + file_age = now - filepath.stat().st_mtime + + if file_age > self.max_file_age_seconds: + file_size = filepath.stat().st_size + filepath.unlink() + stats["files_deleted"] += 1 + stats["bytes_freed"] += file_size + logger.debug(f"Deleted old file: {filepath}") + + except Exception as e: + stats["errors"].append(str(e)) + logger.warning(f"Failed to delete {filepath}: {e}") + + # Force cleanup if total size exceeds limit + await self._enforce_size_limit(stats) + + if stats["files_deleted"] > 0: + mb_freed = stats["bytes_freed"] / (1024 * 1024) + logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB") + + return stats + + async def _enforce_size_limit(self, stats: dict): + """Delete oldest files if total size exceeds limit""" + files_with_mtime = [] + total_size = 0 + + for directory in [self.upload_dir, self.output_dir, self.temp_dir]: + if not directory.exists(): + continue + + for filepath in directory.iterdir(): + if not filepath.is_file() or self.is_protected(filepath): + continue + + try: + stat = filepath.stat() + files_with_mtime.append((filepath, stat.st_mtime, stat.st_size)) + total_size += stat.st_size + except Exception: + pass + + # If under limit, nothing to do + if total_size <= self.max_total_size_bytes: + return + + # Sort by modification time (oldest first) + files_with_mtime.sort(key=lambda x: x[1]) + + # Delete oldest files until under limit + for filepath, _, size in files_with_mtime: + if total_size <= self.max_total_size_bytes: + break + + try: + filepath.unlink() + total_size -= size + stats["files_deleted"] += 1 + stats["bytes_freed"] += size + logger.info(f"Deleted file to free space: {filepath}") + except Exception as e: + stats["errors"].append(str(e)) + + def get_disk_usage(self) -> dict: + """Get current disk usage statistics""" + total_files = 0 + total_size = 0 + + for directory in [self.upload_dir, self.output_dir, 
self.temp_dir]: + if not directory.exists(): + continue + + for filepath in directory.iterdir(): + if filepath.is_file(): + total_files += 1 + try: + total_size += filepath.stat().st_size + except Exception: + pass + + return { + "total_files": total_files, + "total_size_mb": round(total_size / (1024 * 1024), 2), + "max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024), + "usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0, + "directories": { + "uploads": str(self.upload_dir), + "outputs": str(self.output_dir), + "temp": str(self.temp_dir) + } + } + + +class MemoryMonitor: + """Monitors memory usage and triggers cleanup if needed""" + + def __init__(self, max_memory_percent: float = 80.0): + self.max_memory_percent = max_memory_percent + self._high_memory_callbacks = [] + + def get_memory_usage(self) -> dict: + """Get current memory usage""" + try: + import psutil + process = psutil.Process() + memory_info = process.memory_info() + system_memory = psutil.virtual_memory() + + return { + "process_rss_mb": round(memory_info.rss / (1024 * 1024), 2), + "process_vms_mb": round(memory_info.vms / (1024 * 1024), 2), + "system_total_gb": round(system_memory.total / (1024 * 1024 * 1024), 2), + "system_available_gb": round(system_memory.available / (1024 * 1024 * 1024), 2), + "system_percent": system_memory.percent + } + except ImportError: + return {"error": "psutil not installed"} + except Exception as e: + return {"error": str(e)} + + def check_memory(self) -> bool: + """Check if memory usage is within limits""" + usage = self.get_memory_usage() + if "error" in usage: + return True # Can't check, assume OK + + return usage.get("system_percent", 0) < self.max_memory_percent + + def on_high_memory(self, callback): + """Register callback for high memory situations""" + self._high_memory_callbacks.append(callback) + + +class HealthChecker: + """Comprehensive health checking for the application""" + + def 
__init__(self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor): + self.cleanup_manager = cleanup_manager + self.memory_monitor = memory_monitor + self.start_time = datetime.now() + self._translation_count = 0 + self._error_count = 0 + self._lock = threading.Lock() + + def record_translation(self, success: bool = True): + """Record a translation attempt""" + with self._lock: + self._translation_count += 1 + if not success: + self._error_count += 1 + + async def check_health(self) -> dict: + """Get comprehensive health status (async version)""" + return self.get_health() + + def get_health(self) -> dict: + """Get comprehensive health status""" + memory = self.memory_monitor.get_memory_usage() + disk = self.cleanup_manager.get_disk_usage() + + # Determine overall status + status = "healthy" + issues = [] + + if "error" not in memory: + if memory.get("system_percent", 0) > 90: + status = "degraded" + issues.append("High memory usage") + elif memory.get("system_percent", 0) > 80: + issues.append("Memory usage elevated") + + if disk.get("usage_percent", 0) > 90: + status = "degraded" + issues.append("High disk usage") + elif disk.get("usage_percent", 0) > 80: + issues.append("Disk usage elevated") + + uptime = datetime.now() - self.start_time + + return { + "status": status, + "issues": issues, + "uptime_seconds": int(uptime.total_seconds()), + "uptime_human": str(uptime).split('.')[0], + "translations": { + "total": self._translation_count, + "errors": self._error_count, + "success_rate": round( + ((self._translation_count - self._error_count) / self._translation_count * 100) + if self._translation_count > 0 else 100, 1 + ) + }, + "memory": memory, + "disk": disk, + "cleanup_service": self.cleanup_manager.get_stats(), + "timestamp": datetime.now().isoformat() + } + + +# Create default instances +def create_cleanup_manager(config) -> FileCleanupManager: + """Create cleanup manager with config""" + return FileCleanupManager( + 
upload_dir=config.UPLOAD_DIR, + output_dir=config.OUTPUT_DIR, + temp_dir=config.TEMP_DIR, + max_file_age_hours=getattr(config, 'MAX_FILE_AGE_HOURS', 1), + cleanup_interval_minutes=getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10), + max_total_size_gb=getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0) + ) diff --git a/middleware/rate_limiting.py b/middleware/rate_limiting.py new file mode 100644 index 0000000..0a4a948 --- /dev/null +++ b/middleware/rate_limiting.py @@ -0,0 +1,328 @@ +""" +Rate Limiting Middleware for SaaS robustness +Protects against abuse and ensures fair usage +""" +import time +import asyncio +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Optional +from fastapi import Request, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class RateLimitConfig: + """Configuration for rate limiting""" + # Requests per window + requests_per_minute: int = 30 + requests_per_hour: int = 200 + requests_per_day: int = 1000 + + # Translation-specific limits + translations_per_minute: int = 10 + translations_per_hour: int = 50 + max_concurrent_translations: int = 5 + + # File size limits (MB) + max_file_size_mb: int = 50 + max_total_size_per_hour_mb: int = 500 + + # Burst protection + burst_limit: int = 10 # Max requests in 1 second + + # Whitelist IPs (no rate limiting) + whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"]) + + +class TokenBucket: + """Token bucket algorithm for rate limiting""" + + def __init__(self, capacity: int, refill_rate: float): + self.capacity = capacity + self.refill_rate = refill_rate # tokens per second + self.tokens = capacity + self.last_refill = time.time() + self._lock = asyncio.Lock() + + async def consume(self, tokens: int = 1) -> bool: + """Try to consume tokens, return True if successful""" + async with self._lock: + 
class TokenBucket:
    """Token bucket used for burst control.

    NOTE(review): the class header and ``__init__`` fall outside this diff
    hunk; this reconstruction is inferred from the visible ``_refill`` body
    and from ``ClientRateLimiter`` constructing ``TokenBucket(burst, burst)``
    and awaiting ``consume()``. Confirm against the full file.
    """

    def __init__(self, capacity: float, refill_rate: float):
        self.capacity = capacity        # maximum number of tokens held
        self.refill_rate = refill_rate  # tokens added per second
        self.tokens = float(capacity)   # bucket starts full
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Take *tokens* from the bucket; return True if they were available."""
        async with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def _refill(self):
        """Refill tokens based on time elapsed since the last refill."""
        now = time.time()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now


class SlidingWindowCounter:
    """Sliding window counter for accurate rate limiting."""

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []  # timestamps of accepted requests
        self._lock = asyncio.Lock()

    async def is_allowed(self) -> bool:
        """Record and allow a new request iff the window has capacity."""
        async with self._lock:
            now = time.time()
            # Drop timestamps that have fallen out of the window.
            self.requests = [ts for ts in self.requests if now - ts < self.window_seconds]

            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True
            return False

    @property
    def current_count(self) -> int:
        """Current request count in the window (lock-free; advisory only)."""
        now = time.time()
        return len([ts for ts in self.requests if now - ts < self.window_seconds])


class ClientRateLimiter:
    """Per-client rate limiter combining a burst bucket, minute/hour/day
    request windows, translation-specific windows, a concurrency cap and an
    hourly data-volume cap."""

    def __init__(self, config: "RateLimitConfig"):
        # String annotation: RateLimitConfig is declared earlier in this module.
        self.config = config
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # list of (timestamp, size_mb)
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if request is allowed, return (allowed, reason).

        NOTE: each passing check records the request in its window, so a
        request rejected by a later window still consumed earlier budgets.
        """
        # Check burst limit
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        # Check minute limit
        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        # Check hour limit
        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        # Check day limit
        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation request is allowed, return (allowed, reason)."""
        async with self._lock:
            # Check concurrent limit
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

            # Check translation per minute
            if not await self.translation_minute.is_allowed():
                return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

            # Check translation per hour
            if not await self.translation_hour.is_allowed():
                return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

            # Check total size per hour.
            # BUG FIX: the original wrapped this section in a second
            # ``async with self._lock`` while the same lock was already held;
            # asyncio.Lock is NOT reentrant, so every call that reached this
            # point deadlocked. The lock is already held here — just do the
            # bookkeeping directly.
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)

            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."

            self.total_size_hour.append((now, file_size_mb))

            return True, ""

    async def start_translation(self):
        """Mark start of translation (increments the concurrency counter)."""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark end of translation (never lets the counter go negative)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Get current rate limit stats as a plain dict."""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
class RateLimitManager:
    """Manages rate limiters for all clients, keyed by best-effort client IP."""

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Cleanup old clients every hour
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Extract client identifier from request."""
        # Try to get real IP from headers (for proxied requests)
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is exempt from rate limiting."""
        return client_id in self.config.whitelist_ips

    def _maybe_cleanup(self) -> None:
        """Evict limiters for clients with no activity in the day window.

        BUG FIX: ``_cleanup_interval``/``_last_cleanup`` were initialised but
        never used, so ``self.clients`` grew without bound in a long-running
        process. Runs at most once per interval; cheap linear scan.
        """
        now = time.time()
        if now - self._last_cleanup < self._cleanup_interval:
            return
        self._last_cleanup = now
        stale = [
            cid for cid, client in self.clients.items()
            if client.day_counter.current_count == 0 and client.concurrent_translations == 0
        ]
        for cid in stale:
            del self.clients[cid]

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed, return (allowed, reason, client_id)."""
        self._maybe_cleanup()
        client_id = self.get_client_id(request)
        self._total_requests += 1

        if self.is_whitelisted(client_id):
            return True, "", client_id

        client = self.clients[client_id]
        allowed, reason = await client.check_request()

        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if translation is allowed for the requesting client."""
        client_id = self.get_client_id(request)
        # Counts attempts, not completions.
        self._total_translations += 1

        if self.is_whitelisted(client_id):
            return True, ""

        client = self.clients[client_id]
        return await client.check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if translation is allowed for a specific client ID."""
        if self.is_whitelisted(client_id):
            return True

        client = self.clients[client_id]
        allowed, _ = await client.check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for a client.

        Uses ``dict.get`` rather than defaultdict indexing so that merely
        querying stats does not create (and retain) a limiter for a client
        that never issued a request.
        """
        client_id = self.get_client_id(request)
        client = self.clients.get(client_id)
        if client is None:
            client = ClientRateLimiter(self.config)  # transient, all-zero stats
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client."""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        client = self.clients[client_id]
        stats = client.get_stats()

        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics."""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }


class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware that enforces per-client request rate limits."""

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        # Skip rate limiting for health checks and static files
        if request.url.path in ["/health", "/", "/docs", "/openapi.json", "/redoc"]:
            return await call_next(request)

        if request.url.path.startswith("/static"):
            return await call_next(request)

        # Check rate limit
        allowed, reason, client_id = await self.manager.check_request(request)

        if not allowed:
            logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )

        # Add client info to request state for use in endpoints
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]

        return await call_next(request)


# Global rate limit manager
rate_limit_manager = RateLimitManager()
import logging
import time
import uuid

from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses."""

    def __init__(self, app, config: dict = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)

        # Prevent clickjacking
        response.headers["X-Frame-Options"] = "DENY"

        # Prevent MIME type sniffing
        response.headers["X-Content-Type-Options"] = "nosniff"

        # Enable XSS filter
        # NOTE(review): X-XSS-Protection is deprecated in modern browsers and
        # OWASP now recommends "0"; kept as-is to preserve behavior — review.
        response.headers["X-XSS-Protection"] = "1; mode=block"

        # Referrer policy
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # Permissions policy
        response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"

        # Content Security Policy (adjust for your frontend); Swagger UI and
        # ReDoc need looser inline-script rules, so /docs and /redoc are exempt.
        if not request.url.path.startswith("/docs") and not request.url.path.startswith("/redoc"):
            response.headers["Content-Security-Policy"] = (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data: blob:; "
                "font-src 'self' data:; "
                "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
                "worker-src 'self' blob:; "
            )

        # HSTS (only in production with HTTPS)
        if self.config.get("enable_hsts", False):
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests for monitoring and debugging.

    Tags every request with a short unique ID that is also returned in the
    X-Request-ID response header so client reports can be correlated to logs.
    """

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next) -> Response:
        # FIX: `time` and `uuid` were imported inside dispatch on every
        # request; they are now module-level imports.

        # Generate request ID
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        # Get client info
        client_ip = self._get_client_ip(request)

        # Log request start
        start_time = time.time()
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"from {client_ip} - Started"
        )

        try:
            response = await call_next(request)

            # Log request completion
            duration = time.time() - start_time
            logger.info(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- {response.status_code} in {duration:.3f}s"
            )

            # Add request ID to response headers
            response.headers["X-Request-ID"] = request_id

            return response

        except Exception as e:
            duration = time.time() - start_time
            logger.error(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- ERROR in {duration:.3f}s: {str(e)}"
            )
            raise

    def _get_client_ip(self, request: Request) -> str:
        """Get real client IP from proxy headers or the direct connection."""
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Catch all unhandled exceptions and return proper error responses."""

    async def dispatch(self, request: Request, call_next) -> Response:
        # FIX: JSONResponse was imported inside the exception path on every
        # failure; it is now a module-level import.
        try:
            return await call_next(request)

        except Exception as e:
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.exception(f"[{request_id}] Unhandled exception: {str(e)}")

            # Don't expose internal errors in production
            return JSONResponse(
                status_code=500,
                content={
                    "error": "internal_server_error",
                    "message": "An unexpected error occurred. Please try again later.",
                    "request_id": request_id
                }
            )
"""
Input Validation Module for SaaS robustness
Validates all user inputs before processing
"""
import re
import logging
from pathlib import Path
from typing import Optional, List, Set, TYPE_CHECKING

# BUG FIX: a hard `import magic` made this whole module unimportable when
# libmagic is missing (a common failure on Windows / slim containers — the
# project's requirements.txt itself notes the platform split). python-magic
# is now optional; _detect_mime_type falls back to signature detection.
try:
    import magic
except ImportError:  # pragma: no cover - depends on deployment platform
    magic = None

# FastAPI names are only used in type annotations here; importing them lazily
# keeps the validation logic usable (and unit-testable) without the framework.
if TYPE_CHECKING:
    from fastapi import UploadFile, HTTPException

logger = logging.getLogger(__name__)


class ValidationError(Exception):
    """Custom validation error with user-friendly messages."""

    def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
        self.message = message    # human-readable description
        self.code = code          # machine-readable error code for API clients
        self.details = details or {}
        super().__init__(message)


class ValidationResult:
    """Result of a validation check: validity flag, errors, warnings, data."""

    def __init__(self, is_valid: bool = True, errors: List[str] = None, warnings: List[str] = None, data: dict = None):
        self.is_valid = is_valid
        self.errors = errors or []
        self.warnings = warnings or []
        self.data = data or {}


class FileValidator:
    """Validates uploaded files for security and compatibility."""

    # Allowed MIME types mapped to extensions
    ALLOWED_MIME_TYPES = {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    }

    # Magic bytes for Office Open XML files (ZIP format)
    OFFICE_MAGIC_BYTES = b"PK\x03\x04"

    def __init__(
        self,
        max_size_mb: int = 50,
        allowed_extensions: Set[str] = None,
        scan_content: bool = True
    ):
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.max_size_mb = max_size_mb
        self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
        self.scan_content = scan_content  # toggle magic-byte / MIME inspection

    async def validate_async(self, file: "UploadFile") -> ValidationResult:
        """
        Validate an uploaded file asynchronously.
        Returns ValidationResult with is_valid, errors, warnings.
        Never raises for validation failures; they are collected in errors.
        """
        errors = []
        warnings = []
        data = {}

        try:
            # Validate filename
            if not file.filename:
                errors.append("Filename is required")
                return ValidationResult(is_valid=False, errors=errors)

            # Sanitize filename
            try:
                safe_filename = self._sanitize_filename(file.filename)
                data["safe_filename"] = safe_filename
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Validate extension
            try:
                extension = self._validate_extension(safe_filename)
                data["extension"] = extension
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Read file content for validation
            content = await file.read()
            await file.seek(0)  # Reset for later processing

            # Validate file size
            file_size = len(content)
            data["size_bytes"] = file_size
            data["size_mb"] = round(file_size / (1024*1024), 2)

            if file_size > self.max_size_bytes:
                errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            if file_size == 0:
                errors.append("File is empty")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            # Warn about large files (over 80% of the limit)
            if file_size > self.max_size_bytes * 0.8:
                warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")

            # Validate magic bytes
            if self.scan_content:
                try:
                    self._validate_magic_bytes(content, extension)
                except ValidationError as e:
                    errors.append(str(e.message))
                    return ValidationResult(is_valid=False, errors=errors, data=data)

            # Validate MIME type (advisory only — failures become warnings)
            try:
                mime_type = self._detect_mime_type(content)
                data["mime_type"] = mime_type
                self._validate_mime_type(mime_type, extension)
            except ValidationError as e:
                warnings.append(f"MIME type warning: {e.message}")
            except Exception:
                warnings.append("Could not verify MIME type")

            data["original_filename"] = file.filename

            return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)

        except Exception as e:
            # Boundary catch-all: report, never propagate, so the endpoint can
            # return a clean 4xx instead of a 500.
            logger.error(f"Validation error: {str(e)}")
            errors.append(f"Validation failed: {str(e)}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)

    async def validate(self, file: "UploadFile") -> dict:
        """
        Validate an uploaded file.
        Returns validation info dict or raises ValidationError.
        (Raising counterpart of validate_async; keep the two in sync.)
        """
        # Validate filename
        if not file.filename:
            raise ValidationError(
                "Filename is required",
                code="missing_filename"
            )

        # Sanitize filename
        safe_filename = self._sanitize_filename(file.filename)

        # Validate extension
        extension = self._validate_extension(safe_filename)

        # Read file content for validation
        content = await file.read()
        await file.seek(0)  # Reset for later processing

        # Validate file size
        file_size = len(content)
        if file_size > self.max_size_bytes:
            raise ValidationError(
                f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
                code="file_too_large",
                details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
            )

        if file_size == 0:
            raise ValidationError(
                "File is empty",
                code="empty_file"
            )

        # Validate magic bytes (file signature)
        if self.scan_content:
            self._validate_magic_bytes(content, extension)

        # Validate MIME type
        mime_type = self._detect_mime_type(content)
        self._validate_mime_type(mime_type, extension)

        return {
            "original_filename": file.filename,
            "safe_filename": safe_filename,
            "extension": extension,
            "size_bytes": file_size,
            "size_mb": round(file_size / (1024*1024), 2),
            "mime_type": mime_type
        }

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks."""
        # Remove path components
        filename = Path(filename).name

        # Remove null bytes and control characters
        filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)

        # Remove potentially dangerous characters
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        # Limit length (255 is the common filesystem component limit)
        if len(filename) > 255:
            name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
            filename = name[:250] + ('.' + ext if ext else '')

        # Ensure not empty after sanitization
        if not filename or filename.strip() == '':
            raise ValidationError(
                "Invalid filename",
                code="invalid_filename"
            )

        return filename

    def _validate_extension(self, filename: str) -> str:
        """Validate and return the (lowercased) file extension."""
        if '.' not in filename:
            raise ValidationError(
                f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
                code="missing_extension",
                details={"allowed_extensions": list(self.allowed_extensions)}
            )

        extension = '.' + filename.rsplit('.', 1)[1].lower()

        if extension not in self.allowed_extensions:
            raise ValidationError(
                f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
                code="unsupported_file_type",
                details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
            )

        return extension

    def _validate_magic_bytes(self, content: bytes, extension: str):
        """Validate file magic bytes match the expected format."""
        # All supported formats are Office Open XML (ZIP-based)
        if not content.startswith(self.OFFICE_MAGIC_BYTES):
            raise ValidationError(
                "File content does not match expected format. The file may be corrupted or not a valid Office document.",
                code="invalid_file_content"
            )

    def _detect_mime_type(self, content: bytes) -> str:
        """Detect MIME type from file content (best effort)."""
        if magic is not None:
            try:
                mime = magic.Magic(mime=True)
                return mime.from_buffer(content)
            except Exception:
                pass
        # Fallback: signature-based detection when libmagic is unavailable
        if content.startswith(self.OFFICE_MAGIC_BYTES):
            return "application/zip"
        return "application/octet-stream"

    def _validate_mime_type(self, mime_type: str, extension: str):
        """Validate MIME type matches extension."""
        # Office Open XML files may be detected as ZIP
        allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]

        if mime_type not in allowed_mimes:
            raise ValidationError(
                f"Invalid file type detected. Expected Office document, got: {mime_type}",
                code="invalid_mime_type",
                details={"detected_mime": mime_type}
            )


class LanguageValidator:
    """Validates language codes."""

    SUPPORTED_LANGUAGES = {
        # ISO 639-1 codes
        "af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
        "bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
        "da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
        "de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
        "is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
        "rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
        "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
        "ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
        "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
        "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
        "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
        "zu", "auto"
    }

    LANGUAGE_NAMES = {
        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
        "it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
        "zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
        "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
        "nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
        "da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
        "el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
        "uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
    }

    @classmethod
    def validate(cls, language_code: str, field_name: str = "language") -> str:
        """Validate and normalize language code; raises ValidationError."""
        if not language_code:
            raise ValidationError(
                f"{field_name} is required",
                code="missing_language"
            )

        # Normalize
        normalized = language_code.strip().lower()

        # Handle common variations
        if normalized in ["chinese", "cn"]:
            normalized = "zh-CN"
        elif normalized in ["chinese-traditional", "tw"]:
            normalized = "zh-TW"

        if normalized not in cls.SUPPORTED_LANGUAGES:
            raise ValidationError(
                f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
                code="unsupported_language",
                details={"language": language_code}
            )

        return normalized

    @classmethod
    def get_language_name(cls, code: str) -> str:
        """Get human-readable language name (falls back to upper-cased code)."""
        return cls.LANGUAGE_NAMES.get(code, code.upper())
class ProviderValidator:
    """Validates translation provider configuration."""

    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm"}

    # Providers that cannot run without an API key, mapped to the kwarg that
    # must carry the key plus the error message/code raised when it is absent.
    _REQUIRED_KEYS = {
        "deepl": ("deepl_api_key", "DeepL API key is required when using DeepL provider", "missing_deepl_key"),
        "openai": ("openai_api_key", "OpenAI API key is required when using OpenAI provider", "missing_openai_key"),
    }

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate the provider name and its provider-specific settings."""
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )

        name = provider.strip().lower()

        if name not in cls.SUPPORTED_PROVIDERS:
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(cls.SUPPORTED_PROVIDERS)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": list(cls.SUPPORTED_PROVIDERS)}
            )

        # Provider-specific validation via the key table above.
        key_spec = cls._REQUIRED_KEYS.get(name)
        if key_spec is not None:
            kwarg_name, message, error_code = key_spec
            if not kwargs.get(kwarg_name):
                raise ValidationError(message, code=error_code)
        elif name == "ollama":
            # Ollama needs no API key, but warn when no model was chosen.
            if not kwargs.get("ollama_model", ""):
                logger.warning("No Ollama model specified, will use default")

        return {"provider": name, "validated": True}


class InputSanitizer:
    """Sanitizes user inputs to prevent injection attacks."""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Strip null bytes, cap the length, then trim surrounding whitespace."""
        if not text:
            return ""
        # Null bytes out first, then truncate, then trim — this ordering
        # matches the original behavior (trim applies to the truncated value).
        return text.replace('\x00', '')[:max_length].strip()

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Sanitize and normalize a language code; defaults to 'auto'."""
        if not code:
            return "auto"
        # Keep only alphanumerics and hyphen, cap at 10 characters.
        cleaned = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())[:10]
        return cleaned.lower() if cleaned else "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Require an http(s) scheme and strip trailing slashes."""
        if not url:
            return ""

        cleaned = url.strip()

        if not re.match(r'^https?://', cleaned, re.IGNORECASE):
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )

        return cleaned.rstrip('/')

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Trim surrounding whitespace; never log the key itself."""
        return key.strip() if key else ""


# Default validators
file_validator = FileValidator()