feat: Add SaaS robustness middleware - Rate limiting with token bucket and sliding window algorithms - Input validation (file, language, provider) - Security headers middleware (CSP, XSS protection, etc.) - Automatic file cleanup with TTL tracking - Memory and disk monitoring - Enhanced health check and metrics endpoints - Request logging with unique IDs

This commit is contained in:
Sepehr 2025-11-30 19:25:09 +01:00
parent 8c7716bf4d
commit 500502440c
9 changed files with 1681 additions and 26 deletions

View File

@ -1,5 +1,11 @@
# Translation Service Configuration # Document Translation API - Environment Configuration
TRANSLATION_SERVICE=google # Options: google, deepl, libre, ollama # Copy this file to .env and configure your settings
# ============== Translation Services ==============
# Default provider: google, ollama, deepl, libre, openai
TRANSLATION_SERVICE=google
# DeepL API Key (required for DeepL provider)
DEEPL_API_KEY=your_deepl_api_key_here DEEPL_API_KEY=your_deepl_api_key_here
# Ollama Configuration (for LLM-based translation) # Ollama Configuration (for LLM-based translation)
@ -7,7 +13,58 @@ OLLAMA_BASE_URL=http://localhost:11434
OLLAMA_MODEL=llama3 OLLAMA_MODEL=llama3
OLLAMA_VISION_MODEL=llava OLLAMA_VISION_MODEL=llava
# API Configuration # ============== File Limits ==============
# Maximum file size in MB
MAX_FILE_SIZE_MB=50 MAX_FILE_SIZE_MB=50
UPLOAD_DIR=./uploads
OUTPUT_DIR=./outputs # ============== Rate Limiting (SaaS) ==============
# Enable/disable rate limiting
RATE_LIMIT_ENABLED=true
# Request limits
RATE_LIMIT_PER_MINUTE=30
RATE_LIMIT_PER_HOUR=200
# Translation-specific limits
TRANSLATIONS_PER_MINUTE=10
TRANSLATIONS_PER_HOUR=50
MAX_CONCURRENT_TRANSLATIONS=5
# ============== Cleanup Service ==============
# Enable automatic file cleanup
CLEANUP_ENABLED=true
# Cleanup interval in minutes
CLEANUP_INTERVAL_MINUTES=15
# File time-to-live in minutes
FILE_TTL_MINUTES=60
INPUT_FILE_TTL_MINUTES=30
OUTPUT_FILE_TTL_MINUTES=120
# Disk space warning thresholds (GB)
DISK_WARNING_THRESHOLD_GB=5.0
DISK_CRITICAL_THRESHOLD_GB=1.0
# ============== Security ==============
# Enable HSTS (only for HTTPS deployments)
ENABLE_HSTS=false
# CORS allowed origins (comma-separated)
CORS_ORIGINS=*
# Maximum request size in MB
MAX_REQUEST_SIZE_MB=100
# Request timeout in seconds
REQUEST_TIMEOUT_SECONDS=300
# ============== Monitoring ==============
# Log level: DEBUG, INFO, WARNING, ERROR
LOG_LEVEL=INFO
# Enable request logging
ENABLE_REQUEST_LOGGING=true
# Memory usage threshold (percentage)
MAX_MEMORY_PERCENT=80

View File

@ -1,5 +1,6 @@
""" """
Configuration module for the Document Translation API Configuration module for the Document Translation API
SaaS-ready with comprehensive settings for production deployment
""" """
import os import os
from pathlib import Path from pathlib import Path
@ -8,7 +9,7 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
class Config: class Config:
# Translation Service # ============== Translation Service ==============
TRANSLATION_SERVICE = os.getenv("TRANSLATION_SERVICE", "google") TRANSLATION_SERVICE = os.getenv("TRANSLATION_SERVICE", "google")
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "") DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "")
@ -17,20 +18,51 @@ class Config:
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3") OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3")
OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava") OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")
# File Upload Configuration # ============== File Upload Configuration ==============
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50")) MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50"))
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024 MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
# Directories # Directories
BASE_DIR = Path(__file__).parent.parent BASE_DIR = Path(__file__).parent
UPLOAD_DIR = BASE_DIR / "uploads" UPLOAD_DIR = BASE_DIR / "uploads"
OUTPUT_DIR = BASE_DIR / "outputs" OUTPUT_DIR = BASE_DIR / "outputs"
TEMP_DIR = BASE_DIR / "temp" TEMP_DIR = BASE_DIR / "temp"
LOGS_DIR = BASE_DIR / "logs"
# Supported file types # Supported file types
SUPPORTED_EXTENSIONS = {".xlsx", ".docx", ".pptx"} SUPPORTED_EXTENSIONS = {".xlsx", ".docx", ".pptx"}
# API Configuration # ============== Rate Limiting (SaaS) ==============
RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"
RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "30"))
RATE_LIMIT_PER_HOUR = int(os.getenv("RATE_LIMIT_PER_HOUR", "200"))
TRANSLATIONS_PER_MINUTE = int(os.getenv("TRANSLATIONS_PER_MINUTE", "10"))
TRANSLATIONS_PER_HOUR = int(os.getenv("TRANSLATIONS_PER_HOUR", "50"))
MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5"))
# ============== Cleanup Service ==============
CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "true").lower() == "true"
CLEANUP_INTERVAL_MINUTES = int(os.getenv("CLEANUP_INTERVAL_MINUTES", "15"))
FILE_TTL_MINUTES = int(os.getenv("FILE_TTL_MINUTES", "60"))
INPUT_FILE_TTL_MINUTES = int(os.getenv("INPUT_FILE_TTL_MINUTES", "30"))
OUTPUT_FILE_TTL_MINUTES = int(os.getenv("OUTPUT_FILE_TTL_MINUTES", "120"))
# Disk space thresholds
DISK_WARNING_THRESHOLD_GB = float(os.getenv("DISK_WARNING_THRESHOLD_GB", "5.0"))
DISK_CRITICAL_THRESHOLD_GB = float(os.getenv("DISK_CRITICAL_THRESHOLD_GB", "1.0"))
# ============== Security ==============
ENABLE_HSTS = os.getenv("ENABLE_HSTS", "false").lower() == "true"
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
MAX_REQUEST_SIZE_MB = int(os.getenv("MAX_REQUEST_SIZE_MB", "100"))
REQUEST_TIMEOUT_SECONDS = int(os.getenv("REQUEST_TIMEOUT_SECONDS", "300"))
# ============== Monitoring ==============
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
ENABLE_REQUEST_LOGGING = os.getenv("ENABLE_REQUEST_LOGGING", "true").lower() == "true"
MAX_MEMORY_PERCENT = float(os.getenv("MAX_MEMORY_PERCENT", "80"))
# ============== API Configuration ==============
API_TITLE = "Document Translation API" API_TITLE = "Document Translation API"
API_VERSION = "1.0.0" API_VERSION = "1.0.0"
API_DESCRIPTION = """ API_DESCRIPTION = """
@ -40,6 +72,12 @@ class Config:
- Excel (.xlsx) - Preserves cell formatting, formulas, merged cells, images - Excel (.xlsx) - Preserves cell formatting, formulas, merged cells, images
- Word (.docx) - Preserves styles, tables, images, headers/footers - Word (.docx) - Preserves styles, tables, images, headers/footers
- PowerPoint (.pptx) - Preserves layouts, animations, embedded media - PowerPoint (.pptx) - Preserves layouts, animations, embedded media
SaaS Features:
- Rate limiting per client IP
- Automatic file cleanup
- Health monitoring
- Request logging
""" """
@classmethod @classmethod
@ -48,5 +86,7 @@ class Config:
cls.UPLOAD_DIR.mkdir(exist_ok=True, parents=True) cls.UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
cls.OUTPUT_DIR.mkdir(exist_ok=True, parents=True) cls.OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
cls.TEMP_DIR.mkdir(exist_ok=True, parents=True) cls.TEMP_DIR.mkdir(exist_ok=True, parents=True)
cls.LOGS_DIR.mkdir(exist_ok=True, parents=True)
config = Config() config = Config()

216
main.py
View File

@ -1,24 +1,55 @@
""" """
Document Translation API Document Translation API
FastAPI application for translating complex documents while preserving formatting FastAPI application for translating complex documents while preserving formatting
SaaS-ready with rate limiting, validation, and robust error handling
""" """
from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, Depends
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import asyncio import asyncio
import logging import logging
import os
from config import config from config import config
from translators import excel_translator, word_translator, pptx_translator from translators import excel_translator, word_translator, pptx_translator
from utils import file_handler, handle_translation_error, DocumentProcessingError from utils import file_handler, handle_translation_error, DocumentProcessingError
# Configure logging # Import SaaS middleware
logging.basicConfig(level=logging.INFO) from middleware.rate_limiting import RateLimitMiddleware, RateLimitManager, RateLimitConfig
from middleware.security import SecurityHeadersMiddleware, RequestLoggingMiddleware, ErrorHandlingMiddleware
from middleware.cleanup import FileCleanupManager, MemoryMonitor, HealthChecker, create_cleanup_manager
from middleware.validation import FileValidator, LanguageValidator, ProviderValidator, InputSanitizer, ValidationError
# Configure structured logging
logging.basicConfig(
level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Initialize SaaS components
rate_limit_config = RateLimitConfig(
requests_per_minute=int(os.getenv("RATE_LIMIT_PER_MINUTE", "30")),
requests_per_hour=int(os.getenv("RATE_LIMIT_PER_HOUR", "200")),
translations_per_minute=int(os.getenv("TRANSLATIONS_PER_MINUTE", "10")),
translations_per_hour=int(os.getenv("TRANSLATIONS_PER_HOUR", "50")),
max_concurrent_translations=int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5")),
)
rate_limit_manager = RateLimitManager(rate_limit_config)
cleanup_manager = create_cleanup_manager(config)
memory_monitor = MemoryMonitor(max_memory_percent=float(os.getenv("MAX_MEMORY_PERCENT", "80")))
health_checker = HealthChecker(cleanup_manager, memory_monitor)
file_validator = FileValidator(
max_size_mb=config.MAX_FILE_SIZE_MB,
allowed_extensions=config.SUPPORTED_EXTENSIONS
)
def build_full_prompt(system_prompt: str, glossary: str) -> str: def build_full_prompt(system_prompt: str, glossary: str) -> str:
"""Combine system prompt and glossary into a single prompt for LLM translation.""" """Combine system prompt and glossary into a single prompt for LLM translation."""
@ -40,23 +71,47 @@ Always use the translations from this glossary when you encounter these terms.""
return "\n\n".join(parts) if parts else "" return "\n\n".join(parts) if parts else ""
# Ensure necessary directories exist # Lifespan context manager for startup/shutdown
config.ensure_directories() @asynccontextmanager
async def lifespan(app: FastAPI):
"""Handle startup and shutdown events"""
# Startup
logger.info("Starting Document Translation API...")
config.ensure_directories()
await cleanup_manager.start()
logger.info("API ready to accept requests")
yield
# Shutdown
logger.info("Shutting down...")
await cleanup_manager.stop()
logger.info("Cleanup completed")
# Create FastAPI app
# Create FastAPI app with lifespan
app = FastAPI( app = FastAPI(
title=config.API_TITLE, title=config.API_TITLE,
version=config.API_VERSION, version=config.API_VERSION,
description=config.API_DESCRIPTION description=config.API_DESCRIPTION,
lifespan=lifespan
) )
# Add CORS middleware # Add middleware (order matters - first added is outermost)
app.add_middleware(ErrorHandlingMiddleware)
app.add_middleware(RequestLoggingMiddleware, log_body=False)
app.add_middleware(SecurityHeadersMiddleware, config={"enable_hsts": os.getenv("ENABLE_HSTS", "false").lower() == "true"})
app.add_middleware(RateLimitMiddleware, rate_limit_manager=rate_limit_manager)
# CORS - configure for production
allowed_origins = os.getenv("CORS_ORIGINS", "*").split(",")
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], # Configure appropriately for production allow_origins=allowed_origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["*"], allow_headers=["*"],
expose_headers=["X-Request-ID", "X-Original-Filename", "X-File-Size-MB", "X-Target-Language"]
) )
# Mount static files # Mount static files
@ -65,6 +120,20 @@ if static_dir.exists():
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
# Custom exception handler for ValidationError
@app.exception_handler(ValidationError)
async def validation_error_handler(request: Request, exc: ValidationError):
"""Handle validation errors with user-friendly messages"""
return JSONResponse(
status_code=400,
content={
"error": exc.code,
"message": exc.message,
"details": exc.details
}
)
@app.get("/") @app.get("/")
async def root(): async def root():
"""Root endpoint with API information""" """Root endpoint with API information"""
@ -83,11 +152,24 @@ async def root():
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint with detailed system status"""
return { health_status = await health_checker.check_health()
"status": "healthy", status_code = 200 if health_status.get("status") == "healthy" else 503
"translation_service": config.TRANSLATION_SERVICE
} return JSONResponse(
status_code=status_code,
content={
"status": health_status.get("status", "unknown"),
"translation_service": config.TRANSLATION_SERVICE,
"memory": health_status.get("memory", {}),
"disk": health_status.get("disk", {}),
"cleanup_service": health_status.get("cleanup_service", {}),
"rate_limits": {
"requests_per_minute": rate_limit_config.requests_per_minute,
"translations_per_minute": rate_limit_config.translations_per_minute,
}
}
)
@app.get("/languages") @app.get("/languages")
@ -128,6 +210,7 @@ async def get_supported_languages():
@app.post("/translate") @app.post("/translate")
async def translate_document( async def translate_document(
request: Request,
file: UploadFile = File(..., description="Document file to translate (.xlsx, .docx, or .pptx)"), file: UploadFile = File(..., description="Document file to translate (.xlsx, .docx, or .pptx)"),
target_language: str = Form(..., description="Target language code (e.g., 'es', 'fr', 'de')"), target_language: str = Form(..., description="Target language code (e.g., 'es', 'fr', 'de')"),
source_language: str = Form(default="auto", description="Source language code (default: auto-detect)"), source_language: str = Form(default="auto", description="Source language code (default: auto-detect)"),
@ -160,11 +243,38 @@ async def translate_document(
""" """
input_path = None input_path = None
output_path = None output_path = None
request_id = getattr(request.state, 'request_id', 'unknown')
try: try:
# Validate inputs
sanitized_language = InputSanitizer.sanitize_language_code(target_language)
LanguageValidator.validate(sanitized_language)
ProviderValidator.validate(provider)
# Validate file before processing
validation_result = await file_validator.validate_async(file)
if not validation_result.is_valid:
raise ValidationError(
message=f"File validation failed: {'; '.join(validation_result.errors)}",
code="INVALID_FILE",
details={"errors": validation_result.errors, "warnings": validation_result.warnings}
)
# Log any warnings
if validation_result.warnings:
logger.warning(f"[{request_id}] File validation warnings: {validation_result.warnings}")
# Check rate limit for translations
client_ip = request.client.host if request.client else "unknown"
if not await rate_limit_manager.check_translation_limit(client_ip):
raise HTTPException(
status_code=429,
detail="Translation rate limit exceeded. Please try again later."
)
# Validate file extension # Validate file extension
file_extension = file_handler.validate_file_extension(file.filename) file_extension = file_handler.validate_file_extension(file.filename)
logger.info(f"Processing {file_extension} file: {file.filename}") logger.info(f"[{request_id}] Processing {file_extension} file: {file.filename}")
# Validate file size # Validate file size
file_handler.validate_file_size(file) file_handler.validate_file_size(file)
@ -178,7 +288,11 @@ async def translate_document(
output_path = config.OUTPUT_DIR / output_filename output_path = config.OUTPUT_DIR / output_filename
await file_handler.save_upload_file(file, input_path) await file_handler.save_upload_file(file, input_path)
logger.info(f"Saved input file to: {input_path}") logger.info(f"[{request_id}] Saved input file to: {input_path}")
# Track file for cleanup
await cleanup_manager.track_file(input_path, ttl_minutes=30)
await cleanup_manager.track_file(output_path, ttl_minutes=60)
# Configure translation provider # Configure translation provider
from services.translation_service import GoogleTranslationProvider, DeepLTranslationProvider, LibreTranslationProvider, OllamaTranslationProvider, OpenAITranslationProvider, translation_service from services.translation_service import GoogleTranslationProvider, DeepLTranslationProvider, LibreTranslationProvider, OllamaTranslationProvider, OpenAITranslationProvider, translation_service
@ -657,6 +771,74 @@ async def reconstruct_document(
raise HTTPException(status_code=500, detail=f"Failed to reconstruct document: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to reconstruct document: {str(e)}")
# ============== SaaS Management Endpoints ==============
@app.get("/metrics")
async def get_metrics():
"""Get system metrics and statistics for monitoring"""
health_status = await health_checker.check_health()
cleanup_stats = cleanup_manager.get_stats()
rate_limit_stats = rate_limit_manager.get_stats()
return {
"system": {
"memory": health_status.get("memory", {}),
"disk": health_status.get("disk", {}),
"status": health_status.get("status", "unknown")
},
"cleanup": cleanup_stats,
"rate_limits": rate_limit_stats,
"config": {
"max_file_size_mb": config.MAX_FILE_SIZE_MB,
"supported_extensions": list(config.SUPPORTED_EXTENSIONS),
"translation_service": config.TRANSLATION_SERVICE
}
}
@app.get("/rate-limit/status")
async def get_rate_limit_status(request: Request):
"""Get current rate limit status for the requesting client"""
client_ip = request.client.host if request.client else "unknown"
status = await rate_limit_manager.get_client_status(client_ip)
return {
"client_ip": client_ip,
"limits": {
"requests_per_minute": rate_limit_config.requests_per_minute,
"requests_per_hour": rate_limit_config.requests_per_hour,
"translations_per_minute": rate_limit_config.translations_per_minute,
"translations_per_hour": rate_limit_config.translations_per_hour
},
"current_usage": status
}
@app.post("/admin/cleanup/trigger")
async def trigger_cleanup():
"""Trigger manual cleanup of expired files"""
try:
cleaned = await cleanup_manager.cleanup_expired()
return {
"status": "success",
"files_cleaned": cleaned,
"message": f"Cleaned up {cleaned} expired files"
}
except Exception as e:
logger.error(f"Manual cleanup failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
@app.get("/admin/files/tracked")
async def get_tracked_files():
"""Get list of currently tracked files"""
tracked = cleanup_manager.get_tracked_files()
return {
"count": len(tracked),
"files": tracked
}
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

62
middleware/__init__.py Normal file
View File

@ -0,0 +1,62 @@
"""
Middleware package for SaaS robustness
This package provides:
- Rate limiting: Protect against abuse and ensure fair usage
- Validation: Validate all inputs before processing
- Security: Security headers, request logging, error handling
- Cleanup: Automatic file cleanup and resource management
"""
from .rate_limiting import (
RateLimitConfig,
RateLimitManager,
RateLimitMiddleware,
ClientRateLimiter,
)
from .validation import (
ValidationError,
ValidationResult,
FileValidator,
LanguageValidator,
ProviderValidator,
InputSanitizer,
)
from .security import (
SecurityHeadersMiddleware,
RequestLoggingMiddleware,
ErrorHandlingMiddleware,
)
from .cleanup import (
FileCleanupManager,
MemoryMonitor,
HealthChecker,
create_cleanup_manager,
)
__all__ = [
# Rate limiting
"RateLimitConfig",
"RateLimitManager",
"RateLimitMiddleware",
"ClientRateLimiter",
# Validation
"ValidationError",
"ValidationResult",
"FileValidator",
"LanguageValidator",
"ProviderValidator",
"InputSanitizer",
# Security
"SecurityHeadersMiddleware",
"RequestLoggingMiddleware",
"ErrorHandlingMiddleware",
# Cleanup
"FileCleanupManager",
"MemoryMonitor",
"HealthChecker",
"create_cleanup_manager",
]

400
middleware/cleanup.py Normal file
View File

@ -0,0 +1,400 @@
"""
Cleanup and Resource Management for SaaS robustness
Automatic cleanup of temporary files and resources
"""
import os
import time
import asyncio
import threading
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Set
import logging
logger = logging.getLogger(__name__)
class FileCleanupManager:
"""Manages automatic cleanup of temporary and output files"""
def __init__(
self,
upload_dir: Path,
output_dir: Path,
temp_dir: Path,
max_file_age_hours: int = 1,
cleanup_interval_minutes: int = 10,
max_total_size_gb: float = 10.0
):
self.upload_dir = Path(upload_dir)
self.output_dir = Path(output_dir)
self.temp_dir = Path(temp_dir)
self.max_file_age_seconds = max_file_age_hours * 3600
self.cleanup_interval = cleanup_interval_minutes * 60
self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)
self._running = False
self._task: Optional[asyncio.Task] = None
self._protected_files: Set[str] = set()
self._tracked_files: dict = {} # filepath -> {created, ttl_minutes}
self._lock = threading.Lock()
self._stats = {
"files_cleaned": 0,
"bytes_freed": 0,
"cleanup_runs": 0
}
async def track_file(self, filepath: Path, ttl_minutes: int = 60):
"""Track a file for automatic cleanup after TTL expires"""
with self._lock:
self._tracked_files[str(filepath)] = {
"created": time.time(),
"ttl_minutes": ttl_minutes,
"expires_at": time.time() + (ttl_minutes * 60)
}
def get_tracked_files(self) -> list:
"""Get list of currently tracked files with their status"""
now = time.time()
result = []
with self._lock:
for filepath, info in self._tracked_files.items():
remaining = info["expires_at"] - now
result.append({
"path": filepath,
"exists": Path(filepath).exists(),
"expires_in_seconds": max(0, int(remaining)),
"ttl_minutes": info["ttl_minutes"]
})
return result
async def cleanup_expired(self) -> int:
"""Cleanup expired tracked files"""
now = time.time()
cleaned = 0
to_remove = []
with self._lock:
for filepath, info in list(self._tracked_files.items()):
if now > info["expires_at"]:
to_remove.append(filepath)
for filepath in to_remove:
try:
path = Path(filepath)
if path.exists() and not self.is_protected(path):
size = path.stat().st_size
path.unlink()
cleaned += 1
self._stats["files_cleaned"] += 1
self._stats["bytes_freed"] += size
logger.info(f"Cleaned expired file: {filepath}")
with self._lock:
self._tracked_files.pop(filepath, None)
except Exception as e:
logger.warning(f"Failed to clean expired file {filepath}: {e}")
return cleaned
def get_stats(self) -> dict:
"""Get cleanup statistics"""
disk_usage = self.get_disk_usage()
with self._lock:
tracked_count = len(self._tracked_files)
return {
"files_cleaned_total": self._stats["files_cleaned"],
"bytes_freed_total_mb": round(self._stats["bytes_freed"] / (1024 * 1024), 2),
"cleanup_runs": self._stats["cleanup_runs"],
"tracked_files": tracked_count,
"disk_usage": disk_usage,
"is_running": self._running
}
def protect_file(self, filepath: Path):
"""Mark a file as protected (being processed)"""
with self._lock:
self._protected_files.add(str(filepath))
def unprotect_file(self, filepath: Path):
"""Remove protection from a file"""
with self._lock:
self._protected_files.discard(str(filepath))
def is_protected(self, filepath: Path) -> bool:
"""Check if a file is protected"""
with self._lock:
return str(filepath) in self._protected_files
async def start(self):
"""Start the cleanup background task"""
if self._running:
return
self._running = True
self._task = asyncio.create_task(self._cleanup_loop())
logger.info("File cleanup manager started")
async def stop(self):
"""Stop the cleanup background task"""
self._running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
logger.info("File cleanup manager stopped")
async def _cleanup_loop(self):
"""Background loop for periodic cleanup"""
while self._running:
try:
await self.cleanup()
await self.cleanup_expired()
self._stats["cleanup_runs"] += 1
except Exception as e:
logger.error(f"Cleanup error: {e}")
await asyncio.sleep(self.cleanup_interval)
async def cleanup(self) -> dict:
"""Perform cleanup of old files"""
stats = {
"files_deleted": 0,
"bytes_freed": 0,
"errors": []
}
now = time.time()
# Cleanup each directory
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
if not directory.exists():
continue
for filepath in directory.iterdir():
if not filepath.is_file():
continue
# Skip protected files
if self.is_protected(filepath):
continue
try:
# Check file age
file_age = now - filepath.stat().st_mtime
if file_age > self.max_file_age_seconds:
file_size = filepath.stat().st_size
filepath.unlink()
stats["files_deleted"] += 1
stats["bytes_freed"] += file_size
logger.debug(f"Deleted old file: {filepath}")
except Exception as e:
stats["errors"].append(str(e))
logger.warning(f"Failed to delete {filepath}: {e}")
# Force cleanup if total size exceeds limit
await self._enforce_size_limit(stats)
if stats["files_deleted"] > 0:
mb_freed = stats["bytes_freed"] / (1024 * 1024)
logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")
return stats
async def _enforce_size_limit(self, stats: dict):
"""Delete oldest files if total size exceeds limit"""
files_with_mtime = []
total_size = 0
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
if not directory.exists():
continue
for filepath in directory.iterdir():
if not filepath.is_file() or self.is_protected(filepath):
continue
try:
stat = filepath.stat()
files_with_mtime.append((filepath, stat.st_mtime, stat.st_size))
total_size += stat.st_size
except Exception:
pass
# If under limit, nothing to do
if total_size <= self.max_total_size_bytes:
return
# Sort by modification time (oldest first)
files_with_mtime.sort(key=lambda x: x[1])
# Delete oldest files until under limit
for filepath, _, size in files_with_mtime:
if total_size <= self.max_total_size_bytes:
break
try:
filepath.unlink()
total_size -= size
stats["files_deleted"] += 1
stats["bytes_freed"] += size
logger.info(f"Deleted file to free space: {filepath}")
except Exception as e:
stats["errors"].append(str(e))
def get_disk_usage(self) -> dict:
"""Get current disk usage statistics"""
total_files = 0
total_size = 0
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
if not directory.exists():
continue
for filepath in directory.iterdir():
if filepath.is_file():
total_files += 1
try:
total_size += filepath.stat().st_size
except Exception:
pass
return {
"total_files": total_files,
"total_size_mb": round(total_size / (1024 * 1024), 2),
"max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
"usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0,
"directories": {
"uploads": str(self.upload_dir),
"outputs": str(self.output_dir),
"temp": str(self.temp_dir)
}
}
class MemoryMonitor:
"""Monitors memory usage and triggers cleanup if needed"""
def __init__(self, max_memory_percent: float = 80.0):
self.max_memory_percent = max_memory_percent
self._high_memory_callbacks = []
def get_memory_usage(self) -> dict:
"""Get current memory usage"""
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
system_memory = psutil.virtual_memory()
return {
"process_rss_mb": round(memory_info.rss / (1024 * 1024), 2),
"process_vms_mb": round(memory_info.vms / (1024 * 1024), 2),
"system_total_gb": round(system_memory.total / (1024 * 1024 * 1024), 2),
"system_available_gb": round(system_memory.available / (1024 * 1024 * 1024), 2),
"system_percent": system_memory.percent
}
except ImportError:
return {"error": "psutil not installed"}
except Exception as e:
return {"error": str(e)}
def check_memory(self) -> bool:
"""Check if memory usage is within limits"""
usage = self.get_memory_usage()
if "error" in usage:
return True # Can't check, assume OK
return usage.get("system_percent", 0) < self.max_memory_percent
def on_high_memory(self, callback):
"""Register callback for high memory situations"""
self._high_memory_callbacks.append(callback)
class HealthChecker:
"""Comprehensive health checking for the application"""
def __init__(self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor):
self.cleanup_manager = cleanup_manager
self.memory_monitor = memory_monitor
self.start_time = datetime.now()
self._translation_count = 0
self._error_count = 0
self._lock = threading.Lock()
def record_translation(self, success: bool = True):
"""Record a translation attempt"""
with self._lock:
self._translation_count += 1
if not success:
self._error_count += 1
async def check_health(self) -> dict:
"""Get comprehensive health status (async version)"""
return self.get_health()
def get_health(self) -> dict:
"""Get comprehensive health status"""
memory = self.memory_monitor.get_memory_usage()
disk = self.cleanup_manager.get_disk_usage()
# Determine overall status
status = "healthy"
issues = []
if "error" not in memory:
if memory.get("system_percent", 0) > 90:
status = "degraded"
issues.append("High memory usage")
elif memory.get("system_percent", 0) > 80:
issues.append("Memory usage elevated")
if disk.get("usage_percent", 0) > 90:
status = "degraded"
issues.append("High disk usage")
elif disk.get("usage_percent", 0) > 80:
issues.append("Disk usage elevated")
uptime = datetime.now() - self.start_time
return {
"status": status,
"issues": issues,
"uptime_seconds": int(uptime.total_seconds()),
"uptime_human": str(uptime).split('.')[0],
"translations": {
"total": self._translation_count,
"errors": self._error_count,
"success_rate": round(
((self._translation_count - self._error_count) / self._translation_count * 100)
if self._translation_count > 0 else 100, 1
)
},
"memory": memory,
"disk": disk,
"cleanup_service": self.cleanup_manager.get_stats(),
"timestamp": datetime.now().isoformat()
}
# Create default instances
def create_cleanup_manager(config) -> FileCleanupManager:
"""Create cleanup manager with config"""
return FileCleanupManager(
upload_dir=config.UPLOAD_DIR,
output_dir=config.OUTPUT_DIR,
temp_dir=config.TEMP_DIR,
max_file_age_hours=getattr(config, 'MAX_FILE_AGE_HOURS', 1),
cleanup_interval_minutes=getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10),
max_total_size_gb=getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0)
)

328
middleware/rate_limiting.py Normal file
View File

@ -0,0 +1,328 @@
"""
Rate Limiting Middleware for SaaS robustness
Protects against abuse and ensures fair usage
"""
import time
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)
@dataclass
class RateLimitConfig:
"""Configuration for rate limiting"""
# Requests per window
requests_per_minute: int = 30
requests_per_hour: int = 200
requests_per_day: int = 1000
# Translation-specific limits
translations_per_minute: int = 10
translations_per_hour: int = 50
max_concurrent_translations: int = 5
# File size limits (MB)
max_file_size_mb: int = 50
max_total_size_per_hour_mb: int = 500
# Burst protection
burst_limit: int = 10 # Max requests in 1 second
# Whitelist IPs (no rate limiting)
whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])
class TokenBucket:
"""Token bucket algorithm for rate limiting"""
def __init__(self, capacity: int, refill_rate: float):
self.capacity = capacity
self.refill_rate = refill_rate # tokens per second
self.tokens = capacity
self.last_refill = time.time()
self._lock = asyncio.Lock()
async def consume(self, tokens: int = 1) -> bool:
"""Try to consume tokens, return True if successful"""
async with self._lock:
self._refill()
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
def _refill(self):
"""Refill tokens based on time elapsed"""
now = time.time()
elapsed = now - self.last_refill
self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
self.last_refill = now
class SlidingWindowCounter:
"""Sliding window counter for accurate rate limiting"""
def __init__(self, window_seconds: int, max_requests: int):
self.window_seconds = window_seconds
self.max_requests = max_requests
self.requests: list = []
self._lock = asyncio.Lock()
async def is_allowed(self) -> bool:
"""Check if a new request is allowed"""
async with self._lock:
now = time.time()
# Remove old requests outside the window
self.requests = [ts for ts in self.requests if now - ts < self.window_seconds]
if len(self.requests) < self.max_requests:
self.requests.append(now)
return True
return False
@property
def current_count(self) -> int:
"""Get current request count in window"""
now = time.time()
return len([ts for ts in self.requests if now - ts < self.window_seconds])
class ClientRateLimiter:
"""Per-client rate limiter with multiple windows"""
def __init__(self, config: RateLimitConfig):
self.config = config
self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
self.concurrent_translations = 0
self.total_size_hour: list = [] # List of (timestamp, size_mb)
self._lock = asyncio.Lock()
async def check_request(self) -> tuple[bool, str]:
"""Check if request is allowed, return (allowed, reason)"""
# Check burst limit
if not await self.burst_bucket.consume():
return False, "Too many requests. Please slow down."
# Check minute limit
if not await self.minute_counter.is_allowed():
return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."
# Check hour limit
if not await self.hour_counter.is_allowed():
return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."
# Check day limit
if not await self.day_counter.is_allowed():
return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."
return True, ""
async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
"""Check if translation request is allowed"""
async with self._lock:
# Check concurrent limit
if self.concurrent_translations >= self.config.max_concurrent_translations:
return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."
# Check translation per minute
if not await self.translation_minute.is_allowed():
return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."
# Check translation per hour
if not await self.translation_hour.is_allowed():
return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."
# Check total size per hour
async with self._lock:
now = time.time()
self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
total_size = sum(size for _, size in self.total_size_hour)
if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."
self.total_size_hour.append((now, file_size_mb))
return True, ""
async def start_translation(self):
"""Mark start of translation"""
async with self._lock:
self.concurrent_translations += 1
async def end_translation(self):
"""Mark end of translation"""
async with self._lock:
self.concurrent_translations = max(0, self.concurrent_translations - 1)
def get_stats(self) -> dict:
"""Get current rate limit stats"""
return {
"requests_minute": self.minute_counter.current_count,
"requests_hour": self.hour_counter.current_count,
"requests_day": self.day_counter.current_count,
"translations_minute": self.translation_minute.current_count,
"translations_hour": self.translation_hour.current_count,
"concurrent_translations": self.concurrent_translations,
}
class RateLimitManager:
"""Manages rate limiters for all clients"""
def __init__(self, config: Optional[RateLimitConfig] = None):
self.config = config or RateLimitConfig()
self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
self._cleanup_interval = 3600 # Cleanup old clients every hour
self._last_cleanup = time.time()
self._total_requests = 0
self._total_translations = 0
def get_client_id(self, request: Request) -> str:
"""Extract client identifier from request"""
# Try to get real IP from headers (for proxied requests)
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fall back to direct client IP
if request.client:
return request.client.host
return "unknown"
def is_whitelisted(self, client_id: str) -> bool:
"""Check if client is whitelisted"""
return client_id in self.config.whitelist_ips
async def check_request(self, request: Request) -> tuple[bool, str, str]:
"""Check if request is allowed, return (allowed, reason, client_id)"""
client_id = self.get_client_id(request)
self._total_requests += 1
if self.is_whitelisted(client_id):
return True, "", client_id
client = self.clients[client_id]
allowed, reason = await client.check_request()
return allowed, reason, client_id
async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
"""Check if translation is allowed"""
client_id = self.get_client_id(request)
self._total_translations += 1
if self.is_whitelisted(client_id):
return True, ""
client = self.clients[client_id]
return await client.check_translation(file_size_mb)
async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
"""Check if translation is allowed for a specific client ID"""
if self.is_whitelisted(client_id):
return True
client = self.clients[client_id]
allowed, _ = await client.check_translation(file_size_mb)
return allowed
def get_client_stats(self, request: Request) -> dict:
"""Get rate limit stats for a client"""
client_id = self.get_client_id(request)
client = self.clients[client_id]
return {
"client_id": client_id,
"is_whitelisted": self.is_whitelisted(client_id),
**client.get_stats()
}
async def get_client_status(self, client_id: str) -> dict:
"""Get current usage status for a client"""
if client_id not in self.clients:
return {"status": "no_activity", "requests": 0}
client = self.clients[client_id]
stats = client.get_stats()
return {
"requests_used_minute": stats["requests_minute"],
"requests_used_hour": stats["requests_hour"],
"translations_used_minute": stats["translations_minute"],
"translations_used_hour": stats["translations_hour"],
"concurrent_translations": stats["concurrent_translations"],
"is_whitelisted": self.is_whitelisted(client_id)
}
def get_stats(self) -> dict:
"""Get global rate limiting statistics"""
return {
"total_requests": self._total_requests,
"total_translations": self._total_translations,
"active_clients": len(self.clients),
"config": {
"requests_per_minute": self.config.requests_per_minute,
"requests_per_hour": self.config.requests_per_hour,
"translations_per_minute": self.config.translations_per_minute,
"translations_per_hour": self.config.translations_per_hour,
"max_concurrent_translations": self.config.max_concurrent_translations
}
}
class RateLimitMiddleware(BaseHTTPMiddleware):
"""FastAPI middleware for rate limiting"""
def __init__(self, app, rate_limit_manager: RateLimitManager):
super().__init__(app)
self.manager = rate_limit_manager
async def dispatch(self, request: Request, call_next):
# Skip rate limiting for health checks and static files
if request.url.path in ["/health", "/", "/docs", "/openapi.json", "/redoc"]:
return await call_next(request)
if request.url.path.startswith("/static"):
return await call_next(request)
# Check rate limit
allowed, reason, client_id = await self.manager.check_request(request)
if not allowed:
logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
return JSONResponse(
status_code=429,
content={
"error": "rate_limit_exceeded",
"message": reason,
"retry_after": 60
},
headers={"Retry-After": "60"}
)
# Add client info to request state for use in endpoints
request.state.client_id = client_id
request.state.rate_limiter = self.manager.clients[client_id]
return await call_next(request)
# Global rate limit manager
rate_limit_manager = RateLimitManager()

142
middleware/security.py Normal file
View File

@ -0,0 +1,142 @@
"""
Security Headers Middleware for SaaS robustness
Adds security headers to all responses
"""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import logging
logger = logging.getLogger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Add security headers to all responses"""
def __init__(self, app, config: dict = None):
super().__init__(app)
self.config = config or {}
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Prevent clickjacking
response.headers["X-Frame-Options"] = "DENY"
# Prevent MIME type sniffing
response.headers["X-Content-Type-Options"] = "nosniff"
# Enable XSS filter
response.headers["X-XSS-Protection"] = "1; mode=block"
# Referrer policy
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Permissions policy
response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
# Content Security Policy (adjust for your frontend)
if not request.url.path.startswith("/docs") and not request.url.path.startswith("/redoc"):
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: blob:; "
"font-src 'self' data:; "
"connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
"worker-src 'self' blob:; "
)
# HSTS (only in production with HTTPS)
if self.config.get("enable_hsts", False):
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response
class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""Log all requests for monitoring and debugging"""
def __init__(self, app, log_body: bool = False):
super().__init__(app)
self.log_body = log_body
async def dispatch(self, request: Request, call_next) -> Response:
import time
import uuid
# Generate request ID
request_id = str(uuid.uuid4())[:8]
request.state.request_id = request_id
# Get client info
client_ip = self._get_client_ip(request)
# Log request start
start_time = time.time()
logger.info(
f"[{request_id}] {request.method} {request.url.path} "
f"from {client_ip} - Started"
)
try:
response = await call_next(request)
# Log request completion
duration = time.time() - start_time
logger.info(
f"[{request_id}] {request.method} {request.url.path} "
f"- {response.status_code} in {duration:.3f}s"
)
# Add request ID to response headers
response.headers["X-Request-ID"] = request_id
return response
except Exception as e:
duration = time.time() - start_time
logger.error(
f"[{request_id}] {request.method} {request.url.path} "
f"- ERROR in {duration:.3f}s: {str(e)}"
)
raise
def _get_client_ip(self, request: Request) -> str:
"""Get real client IP from headers or connection"""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
"""Catch all unhandled exceptions and return proper error responses"""
async def dispatch(self, request: Request, call_next) -> Response:
from starlette.responses import JSONResponse
try:
return await call_next(request)
except Exception as e:
request_id = getattr(request.state, 'request_id', 'unknown')
logger.exception(f"[{request_id}] Unhandled exception: {str(e)}")
# Don't expose internal errors in production
return JSONResponse(
status_code=500,
content={
"error": "internal_server_error",
"message": "An unexpected error occurred. Please try again later.",
"request_id": request_id
}
)

440
middleware/validation.py Normal file
View File

@ -0,0 +1,440 @@
"""
Input Validation Module for SaaS robustness
Validates all user inputs before processing
"""
import re
import magic
from pathlib import Path
from typing import Optional, List, Set
from fastapi import UploadFile, HTTPException
import logging
logger = logging.getLogger(__name__)
class ValidationError(Exception):
"""Custom validation error with user-friendly messages"""
def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
self.message = message
self.code = code
self.details = details or {}
super().__init__(message)
class ValidationResult:
"""Result of a validation check"""
def __init__(self, is_valid: bool = True, errors: List[str] = None, warnings: List[str] = None, data: dict = None):
self.is_valid = is_valid
self.errors = errors or []
self.warnings = warnings or []
self.data = data or {}
class FileValidator:
"""Validates uploaded files for security and compatibility"""
# Allowed MIME types mapped to extensions
ALLOWED_MIME_TYPES = {
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
}
# Magic bytes for Office Open XML files (ZIP format)
OFFICE_MAGIC_BYTES = b"PK\x03\x04"
def __init__(
self,
max_size_mb: int = 50,
allowed_extensions: Set[str] = None,
scan_content: bool = True
):
self.max_size_bytes = max_size_mb * 1024 * 1024
self.max_size_mb = max_size_mb
self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
self.scan_content = scan_content
async def validate_async(self, file: UploadFile) -> ValidationResult:
"""
Validate an uploaded file asynchronously
Returns ValidationResult with is_valid, errors, warnings
"""
errors = []
warnings = []
data = {}
try:
# Validate filename
if not file.filename:
errors.append("Filename is required")
return ValidationResult(is_valid=False, errors=errors)
# Sanitize filename
try:
safe_filename = self._sanitize_filename(file.filename)
data["safe_filename"] = safe_filename
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors)
# Validate extension
try:
extension = self._validate_extension(safe_filename)
data["extension"] = extension
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors)
# Read file content for validation
content = await file.read()
await file.seek(0) # Reset for later processing
# Validate file size
file_size = len(content)
data["size_bytes"] = file_size
data["size_mb"] = round(file_size / (1024*1024), 2)
if file_size > self.max_size_bytes:
errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
return ValidationResult(is_valid=False, errors=errors, data=data)
if file_size == 0:
errors.append("File is empty")
return ValidationResult(is_valid=False, errors=errors, data=data)
# Warn about large files
if file_size > self.max_size_bytes * 0.8:
warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")
# Validate magic bytes
if self.scan_content:
try:
self._validate_magic_bytes(content, extension)
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors, data=data)
# Validate MIME type
try:
mime_type = self._detect_mime_type(content)
data["mime_type"] = mime_type
self._validate_mime_type(mime_type, extension)
except ValidationError as e:
warnings.append(f"MIME type warning: {e.message}")
except Exception:
warnings.append("Could not verify MIME type")
data["original_filename"] = file.filename
return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)
except Exception as e:
logger.error(f"Validation error: {str(e)}")
errors.append(f"Validation failed: {str(e)}")
return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)
async def validate(self, file: UploadFile) -> dict:
"""
Validate an uploaded file
Returns validation info dict or raises ValidationError
"""
# Validate filename
if not file.filename:
raise ValidationError(
"Filename is required",
code="missing_filename"
)
# Sanitize filename
safe_filename = self._sanitize_filename(file.filename)
# Validate extension
extension = self._validate_extension(safe_filename)
# Read file content for validation
content = await file.read()
await file.seek(0) # Reset for later processing
# Validate file size
file_size = len(content)
if file_size > self.max_size_bytes:
raise ValidationError(
f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
code="file_too_large",
details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
)
if file_size == 0:
raise ValidationError(
"File is empty",
code="empty_file"
)
# Validate magic bytes (file signature)
if self.scan_content:
self._validate_magic_bytes(content, extension)
# Validate MIME type
mime_type = self._detect_mime_type(content)
self._validate_mime_type(mime_type, extension)
return {
"original_filename": file.filename,
"safe_filename": safe_filename,
"extension": extension,
"size_bytes": file_size,
"size_mb": round(file_size / (1024*1024), 2),
"mime_type": mime_type
}
def _sanitize_filename(self, filename: str) -> str:
"""Sanitize filename to prevent path traversal and other attacks"""
# Remove path components
filename = Path(filename).name
# Remove null bytes and control characters
filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)
# Remove potentially dangerous characters
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
# Limit length
if len(filename) > 255:
name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
filename = name[:250] + ('.' + ext if ext else '')
# Ensure not empty after sanitization
if not filename or filename.strip() == '':
raise ValidationError(
"Invalid filename",
code="invalid_filename"
)
return filename
def _validate_extension(self, filename: str) -> str:
"""Validate and return the file extension"""
if '.' not in filename:
raise ValidationError(
f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
code="missing_extension",
details={"allowed_extensions": list(self.allowed_extensions)}
)
extension = '.' + filename.rsplit('.', 1)[1].lower()
if extension not in self.allowed_extensions:
raise ValidationError(
f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
code="unsupported_file_type",
details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
)
return extension
def _validate_magic_bytes(self, content: bytes, extension: str):
"""Validate file magic bytes match expected format"""
# All supported formats are Office Open XML (ZIP-based)
if not content.startswith(self.OFFICE_MAGIC_BYTES):
raise ValidationError(
"File content does not match expected format. The file may be corrupted or not a valid Office document.",
code="invalid_file_content"
)
def _detect_mime_type(self, content: bytes) -> str:
"""Detect MIME type from file content"""
try:
mime = magic.Magic(mime=True)
return mime.from_buffer(content)
except Exception:
# Fallback to basic detection
if content.startswith(self.OFFICE_MAGIC_BYTES):
return "application/zip"
return "application/octet-stream"
def _validate_mime_type(self, mime_type: str, extension: str):
"""Validate MIME type matches extension"""
# Office Open XML files may be detected as ZIP
allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]
if mime_type not in allowed_mimes:
raise ValidationError(
f"Invalid file type detected. Expected Office document, got: {mime_type}",
code="invalid_mime_type",
details={"detected_mime": mime_type}
)
class LanguageValidator:
"""Validates language codes"""
SUPPORTED_LANGUAGES = {
# ISO 639-1 codes
"af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
"bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
"da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
"de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
"is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
"rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
"mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
"ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
"gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
"su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
"tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
"zu", "auto"
}
LANGUAGE_NAMES = {
"en": "English", "es": "Spanish", "fr": "French", "de": "German",
"it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
"zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
"ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
"nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
"da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
"el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
"uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
}
@classmethod
def validate(cls, language_code: str, field_name: str = "language") -> str:
"""Validate and normalize language code"""
if not language_code:
raise ValidationError(
f"{field_name} is required",
code="missing_language"
)
# Normalize
normalized = language_code.strip().lower()
# Handle common variations
if normalized in ["chinese", "cn"]:
normalized = "zh-CN"
elif normalized in ["chinese-traditional", "tw"]:
normalized = "zh-TW"
if normalized not in cls.SUPPORTED_LANGUAGES:
raise ValidationError(
f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
code="unsupported_language",
details={"language": language_code}
)
return normalized
@classmethod
def get_language_name(cls, code: str) -> str:
"""Get human-readable language name"""
return cls.LANGUAGE_NAMES.get(code, code.upper())
class ProviderValidator:
"""Validates translation provider configuration"""
SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm"}
@classmethod
def validate(cls, provider: str, **kwargs) -> dict:
"""Validate provider and its required configuration"""
if not provider:
raise ValidationError(
"Translation provider is required",
code="missing_provider"
)
normalized = provider.strip().lower()
if normalized not in cls.SUPPORTED_PROVIDERS:
raise ValidationError(
f"Unsupported provider: '{provider}'. Supported: {', '.join(cls.SUPPORTED_PROVIDERS)}",
code="unsupported_provider",
details={"provider": provider, "supported": list(cls.SUPPORTED_PROVIDERS)}
)
# Provider-specific validation
if normalized == "deepl":
if not kwargs.get("deepl_api_key"):
raise ValidationError(
"DeepL API key is required when using DeepL provider",
code="missing_deepl_key"
)
elif normalized == "openai":
if not kwargs.get("openai_api_key"):
raise ValidationError(
"OpenAI API key is required when using OpenAI provider",
code="missing_openai_key"
)
elif normalized == "ollama":
# Ollama doesn't require API key but may need model
model = kwargs.get("ollama_model", "")
if not model:
logger.warning("No Ollama model specified, will use default")
return {"provider": normalized, "validated": True}
class InputSanitizer:
"""Sanitizes user inputs to prevent injection attacks"""
@staticmethod
def sanitize_text(text: str, max_length: int = 10000) -> str:
"""Sanitize text input"""
if not text:
return ""
# Remove null bytes
text = text.replace('\x00', '')
# Limit length
if len(text) > max_length:
text = text[:max_length]
return text.strip()
@staticmethod
def sanitize_language_code(code: str) -> str:
"""Sanitize and normalize language code"""
if not code:
return "auto"
# Remove dangerous characters, keep only alphanumeric and hyphen
code = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())
# Limit length
if len(code) > 10:
code = code[:10]
return code.lower() if code else "auto"
@staticmethod
def sanitize_url(url: str) -> str:
"""Sanitize URL input"""
if not url:
return ""
url = url.strip()
# Basic URL validation
if not re.match(r'^https?://', url, re.IGNORECASE):
raise ValidationError(
"Invalid URL format. Must start with http:// or https://",
code="invalid_url"
)
# Remove trailing slashes
url = url.rstrip('/')
return url
@staticmethod
def sanitize_api_key(key: str) -> str:
"""Sanitize API key (just trim, no logging)"""
if not key:
return ""
return key.strip()
# Default validators
file_validator = FileValidator()

View File

@ -14,3 +14,7 @@ pandas==2.1.4
requests==2.31.0 requests==2.31.0
ipykernel==6.27.1 ipykernel==6.27.1
openai>=1.0.0 openai>=1.0.0
# SaaS robustness dependencies
psutil==5.9.8
python-magic-bin==0.4.14 # For Windows, use python-magic on Linux