feat: Add SaaS robustness middleware - Rate limiting with token bucket and sliding window algorithms - Input validation (file, language, provider) - Security headers middleware (CSP, XSS protection, etc.) - Automatic file cleanup with TTL tracking - Memory and disk monitoring - Enhanced health check and metrics endpoints - Request logging with unique IDs
This commit is contained in:
parent
8c7716bf4d
commit
500502440c
67
.env.example
67
.env.example
@ -1,5 +1,11 @@
|
|||||||
# Translation Service Configuration
|
# Document Translation API - Environment Configuration
|
||||||
TRANSLATION_SERVICE=google # Options: google, deepl, libre, ollama
|
# Copy this file to .env and configure your settings
|
||||||
|
|
||||||
|
# ============== Translation Services ==============
|
||||||
|
# Default provider: google, ollama, deepl, libre, openai
|
||||||
|
TRANSLATION_SERVICE=google
|
||||||
|
|
||||||
|
# DeepL API Key (required for DeepL provider)
|
||||||
DEEPL_API_KEY=your_deepl_api_key_here
|
DEEPL_API_KEY=your_deepl_api_key_here
|
||||||
|
|
||||||
# Ollama Configuration (for LLM-based translation)
|
# Ollama Configuration (for LLM-based translation)
|
||||||
@ -7,7 +13,58 @@ OLLAMA_BASE_URL=http://localhost:11434
|
|||||||
OLLAMA_MODEL=llama3
|
OLLAMA_MODEL=llama3
|
||||||
OLLAMA_VISION_MODEL=llava
|
OLLAMA_VISION_MODEL=llava
|
||||||
|
|
||||||
# API Configuration
|
# ============== File Limits ==============
|
||||||
|
# Maximum file size in MB
|
||||||
MAX_FILE_SIZE_MB=50
|
MAX_FILE_SIZE_MB=50
|
||||||
UPLOAD_DIR=./uploads
|
|
||||||
OUTPUT_DIR=./outputs
|
# ============== Rate Limiting (SaaS) ==============
|
||||||
|
# Enable/disable rate limiting
|
||||||
|
RATE_LIMIT_ENABLED=true
|
||||||
|
|
||||||
|
# Request limits
|
||||||
|
RATE_LIMIT_PER_MINUTE=30
|
||||||
|
RATE_LIMIT_PER_HOUR=200
|
||||||
|
|
||||||
|
# Translation-specific limits
|
||||||
|
TRANSLATIONS_PER_MINUTE=10
|
||||||
|
TRANSLATIONS_PER_HOUR=50
|
||||||
|
MAX_CONCURRENT_TRANSLATIONS=5
|
||||||
|
|
||||||
|
# ============== Cleanup Service ==============
|
||||||
|
# Enable automatic file cleanup
|
||||||
|
CLEANUP_ENABLED=true
|
||||||
|
|
||||||
|
# Cleanup interval in minutes
|
||||||
|
CLEANUP_INTERVAL_MINUTES=15
|
||||||
|
|
||||||
|
# File time-to-live in minutes
|
||||||
|
FILE_TTL_MINUTES=60
|
||||||
|
INPUT_FILE_TTL_MINUTES=30
|
||||||
|
OUTPUT_FILE_TTL_MINUTES=120
|
||||||
|
|
||||||
|
# Disk space warning thresholds (GB)
|
||||||
|
DISK_WARNING_THRESHOLD_GB=5.0
|
||||||
|
DISK_CRITICAL_THRESHOLD_GB=1.0
|
||||||
|
|
||||||
|
# ============== Security ==============
|
||||||
|
# Enable HSTS (only for HTTPS deployments)
|
||||||
|
ENABLE_HSTS=false
|
||||||
|
|
||||||
|
# CORS allowed origins (comma-separated)
|
||||||
|
CORS_ORIGINS=*
|
||||||
|
|
||||||
|
# Maximum request size in MB
|
||||||
|
MAX_REQUEST_SIZE_MB=100
|
||||||
|
|
||||||
|
# Request timeout in seconds
|
||||||
|
REQUEST_TIMEOUT_SECONDS=300
|
||||||
|
|
||||||
|
# ============== Monitoring ==============
|
||||||
|
# Log level: DEBUG, INFO, WARNING, ERROR
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# Enable request logging
|
||||||
|
ENABLE_REQUEST_LOGGING=true
|
||||||
|
|
||||||
|
# Memory usage threshold (percentage)
|
||||||
|
MAX_MEMORY_PERCENT=80
|
||||||
|
|||||||
48
config.py
48
config.py
@ -1,5 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Configuration module for the Document Translation API
|
Configuration module for the Document Translation API
|
||||||
|
SaaS-ready with comprehensive settings for production deployment
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -8,7 +9,7 @@ from dotenv import load_dotenv
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
# Translation Service
|
# ============== Translation Service ==============
|
||||||
TRANSLATION_SERVICE = os.getenv("TRANSLATION_SERVICE", "google")
|
TRANSLATION_SERVICE = os.getenv("TRANSLATION_SERVICE", "google")
|
||||||
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "")
|
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "")
|
||||||
|
|
||||||
@ -17,20 +18,51 @@ class Config:
|
|||||||
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3")
|
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3")
|
||||||
OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")
|
OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")
|
||||||
|
|
||||||
# File Upload Configuration
|
# ============== File Upload Configuration ==============
|
||||||
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50"))
|
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50"))
|
||||||
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
|
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
|
||||||
|
|
||||||
# Directories
|
# Directories
|
||||||
BASE_DIR = Path(__file__).parent.parent
|
BASE_DIR = Path(__file__).parent
|
||||||
UPLOAD_DIR = BASE_DIR / "uploads"
|
UPLOAD_DIR = BASE_DIR / "uploads"
|
||||||
OUTPUT_DIR = BASE_DIR / "outputs"
|
OUTPUT_DIR = BASE_DIR / "outputs"
|
||||||
TEMP_DIR = BASE_DIR / "temp"
|
TEMP_DIR = BASE_DIR / "temp"
|
||||||
|
LOGS_DIR = BASE_DIR / "logs"
|
||||||
|
|
||||||
# Supported file types
|
# Supported file types
|
||||||
SUPPORTED_EXTENSIONS = {".xlsx", ".docx", ".pptx"}
|
SUPPORTED_EXTENSIONS = {".xlsx", ".docx", ".pptx"}
|
||||||
|
|
||||||
# API Configuration
|
# ============== Rate Limiting (SaaS) ==============
|
||||||
|
RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"
|
||||||
|
RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "30"))
|
||||||
|
RATE_LIMIT_PER_HOUR = int(os.getenv("RATE_LIMIT_PER_HOUR", "200"))
|
||||||
|
TRANSLATIONS_PER_MINUTE = int(os.getenv("TRANSLATIONS_PER_MINUTE", "10"))
|
||||||
|
TRANSLATIONS_PER_HOUR = int(os.getenv("TRANSLATIONS_PER_HOUR", "50"))
|
||||||
|
MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5"))
|
||||||
|
|
||||||
|
# ============== Cleanup Service ==============
|
||||||
|
CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "true").lower() == "true"
|
||||||
|
CLEANUP_INTERVAL_MINUTES = int(os.getenv("CLEANUP_INTERVAL_MINUTES", "15"))
|
||||||
|
FILE_TTL_MINUTES = int(os.getenv("FILE_TTL_MINUTES", "60"))
|
||||||
|
INPUT_FILE_TTL_MINUTES = int(os.getenv("INPUT_FILE_TTL_MINUTES", "30"))
|
||||||
|
OUTPUT_FILE_TTL_MINUTES = int(os.getenv("OUTPUT_FILE_TTL_MINUTES", "120"))
|
||||||
|
|
||||||
|
# Disk space thresholds
|
||||||
|
DISK_WARNING_THRESHOLD_GB = float(os.getenv("DISK_WARNING_THRESHOLD_GB", "5.0"))
|
||||||
|
DISK_CRITICAL_THRESHOLD_GB = float(os.getenv("DISK_CRITICAL_THRESHOLD_GB", "1.0"))
|
||||||
|
|
||||||
|
# ============== Security ==============
|
||||||
|
ENABLE_HSTS = os.getenv("ENABLE_HSTS", "false").lower() == "true"
|
||||||
|
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||||
|
MAX_REQUEST_SIZE_MB = int(os.getenv("MAX_REQUEST_SIZE_MB", "100"))
|
||||||
|
REQUEST_TIMEOUT_SECONDS = int(os.getenv("REQUEST_TIMEOUT_SECONDS", "300"))
|
||||||
|
|
||||||
|
# ============== Monitoring ==============
|
||||||
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||||
|
ENABLE_REQUEST_LOGGING = os.getenv("ENABLE_REQUEST_LOGGING", "true").lower() == "true"
|
||||||
|
MAX_MEMORY_PERCENT = float(os.getenv("MAX_MEMORY_PERCENT", "80"))
|
||||||
|
|
||||||
|
# ============== API Configuration ==============
|
||||||
API_TITLE = "Document Translation API"
|
API_TITLE = "Document Translation API"
|
||||||
API_VERSION = "1.0.0"
|
API_VERSION = "1.0.0"
|
||||||
API_DESCRIPTION = """
|
API_DESCRIPTION = """
|
||||||
@ -40,6 +72,12 @@ class Config:
|
|||||||
- Excel (.xlsx) - Preserves cell formatting, formulas, merged cells, images
|
- Excel (.xlsx) - Preserves cell formatting, formulas, merged cells, images
|
||||||
- Word (.docx) - Preserves styles, tables, images, headers/footers
|
- Word (.docx) - Preserves styles, tables, images, headers/footers
|
||||||
- PowerPoint (.pptx) - Preserves layouts, animations, embedded media
|
- PowerPoint (.pptx) - Preserves layouts, animations, embedded media
|
||||||
|
|
||||||
|
SaaS Features:
|
||||||
|
- Rate limiting per client IP
|
||||||
|
- Automatic file cleanup
|
||||||
|
- Health monitoring
|
||||||
|
- Request logging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -48,5 +86,7 @@ class Config:
|
|||||||
cls.UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
|
cls.UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
|
||||||
cls.OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
|
cls.OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
|
||||||
cls.TEMP_DIR.mkdir(exist_ok=True, parents=True)
|
cls.TEMP_DIR.mkdir(exist_ok=True, parents=True)
|
||||||
|
cls.LOGS_DIR.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|||||||
216
main.py
216
main.py
@ -1,24 +1,55 @@
|
|||||||
"""
|
"""
|
||||||
Document Translation API
|
Document Translation API
|
||||||
FastAPI application for translating complex documents while preserving formatting
|
FastAPI application for translating complex documents while preserving formatting
|
||||||
|
SaaS-ready with rate limiting, validation, and robust error handling
|
||||||
"""
|
"""
|
||||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, Depends
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from config import config
|
from config import config
|
||||||
from translators import excel_translator, word_translator, pptx_translator
|
from translators import excel_translator, word_translator, pptx_translator
|
||||||
from utils import file_handler, handle_translation_error, DocumentProcessingError
|
from utils import file_handler, handle_translation_error, DocumentProcessingError
|
||||||
|
|
||||||
# Configure logging
|
# Import SaaS middleware
|
||||||
logging.basicConfig(level=logging.INFO)
|
from middleware.rate_limiting import RateLimitMiddleware, RateLimitManager, RateLimitConfig
|
||||||
|
from middleware.security import SecurityHeadersMiddleware, RequestLoggingMiddleware, ErrorHandlingMiddleware
|
||||||
|
from middleware.cleanup import FileCleanupManager, MemoryMonitor, HealthChecker, create_cleanup_manager
|
||||||
|
from middleware.validation import FileValidator, LanguageValidator, ProviderValidator, InputSanitizer, ValidationError
|
||||||
|
|
||||||
|
# Configure structured logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize SaaS components
|
||||||
|
rate_limit_config = RateLimitConfig(
|
||||||
|
requests_per_minute=int(os.getenv("RATE_LIMIT_PER_MINUTE", "30")),
|
||||||
|
requests_per_hour=int(os.getenv("RATE_LIMIT_PER_HOUR", "200")),
|
||||||
|
translations_per_minute=int(os.getenv("TRANSLATIONS_PER_MINUTE", "10")),
|
||||||
|
translations_per_hour=int(os.getenv("TRANSLATIONS_PER_HOUR", "50")),
|
||||||
|
max_concurrent_translations=int(os.getenv("MAX_CONCURRENT_TRANSLATIONS", "5")),
|
||||||
|
)
|
||||||
|
rate_limit_manager = RateLimitManager(rate_limit_config)
|
||||||
|
|
||||||
|
cleanup_manager = create_cleanup_manager(config)
|
||||||
|
memory_monitor = MemoryMonitor(max_memory_percent=float(os.getenv("MAX_MEMORY_PERCENT", "80")))
|
||||||
|
health_checker = HealthChecker(cleanup_manager, memory_monitor)
|
||||||
|
|
||||||
|
file_validator = FileValidator(
|
||||||
|
max_size_mb=config.MAX_FILE_SIZE_MB,
|
||||||
|
allowed_extensions=config.SUPPORTED_EXTENSIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_full_prompt(system_prompt: str, glossary: str) -> str:
|
def build_full_prompt(system_prompt: str, glossary: str) -> str:
|
||||||
"""Combine system prompt and glossary into a single prompt for LLM translation."""
|
"""Combine system prompt and glossary into a single prompt for LLM translation."""
|
||||||
@ -40,23 +71,47 @@ Always use the translations from this glossary when you encounter these terms.""
|
|||||||
return "\n\n".join(parts) if parts else ""
|
return "\n\n".join(parts) if parts else ""
|
||||||
|
|
||||||
|
|
||||||
# Ensure necessary directories exist
|
# Lifespan context manager for startup/shutdown
|
||||||
config.ensure_directories()
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Handle startup and shutdown events"""
|
||||||
|
# Startup
|
||||||
|
logger.info("Starting Document Translation API...")
|
||||||
|
config.ensure_directories()
|
||||||
|
await cleanup_manager.start()
|
||||||
|
logger.info("API ready to accept requests")
|
||||||
|
|
||||||
# Create FastAPI app
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
logger.info("Shutting down...")
|
||||||
|
await cleanup_manager.stop()
|
||||||
|
logger.info("Cleanup completed")
|
||||||
|
|
||||||
|
|
||||||
|
# Create FastAPI app with lifespan
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=config.API_TITLE,
|
title=config.API_TITLE,
|
||||||
version=config.API_VERSION,
|
version=config.API_VERSION,
|
||||||
description=config.API_DESCRIPTION
|
description=config.API_DESCRIPTION,
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add CORS middleware
|
# Add middleware (order matters - first added is outermost)
|
||||||
|
app.add_middleware(ErrorHandlingMiddleware)
|
||||||
|
app.add_middleware(RequestLoggingMiddleware, log_body=False)
|
||||||
|
app.add_middleware(SecurityHeadersMiddleware, config={"enable_hsts": os.getenv("ENABLE_HSTS", "false").lower() == "true"})
|
||||||
|
app.add_middleware(RateLimitMiddleware, rate_limit_manager=rate_limit_manager)
|
||||||
|
|
||||||
|
# CORS - configure for production
|
||||||
|
allowed_origins = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # Configure appropriately for production
|
allow_origins=allowed_origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
expose_headers=["X-Request-ID", "X-Original-Filename", "X-File-Size-MB", "X-Target-Language"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mount static files
|
# Mount static files
|
||||||
@ -65,6 +120,20 @@ if static_dir.exists():
|
|||||||
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
||||||
|
|
||||||
|
|
||||||
|
# Custom exception handler for ValidationError
|
||||||
|
@app.exception_handler(ValidationError)
|
||||||
|
async def validation_error_handler(request: Request, exc: ValidationError):
|
||||||
|
"""Handle validation errors with user-friendly messages"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"error": exc.code,
|
||||||
|
"message": exc.message,
|
||||||
|
"details": exc.details
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
"""Root endpoint with API information"""
|
"""Root endpoint with API information"""
|
||||||
@ -83,11 +152,24 @@ async def root():
|
|||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint"""
|
"""Health check endpoint with detailed system status"""
|
||||||
return {
|
health_status = await health_checker.check_health()
|
||||||
"status": "healthy",
|
status_code = 200 if health_status.get("status") == "healthy" else 503
|
||||||
"translation_service": config.TRANSLATION_SERVICE
|
|
||||||
}
|
return JSONResponse(
|
||||||
|
status_code=status_code,
|
||||||
|
content={
|
||||||
|
"status": health_status.get("status", "unknown"),
|
||||||
|
"translation_service": config.TRANSLATION_SERVICE,
|
||||||
|
"memory": health_status.get("memory", {}),
|
||||||
|
"disk": health_status.get("disk", {}),
|
||||||
|
"cleanup_service": health_status.get("cleanup_service", {}),
|
||||||
|
"rate_limits": {
|
||||||
|
"requests_per_minute": rate_limit_config.requests_per_minute,
|
||||||
|
"translations_per_minute": rate_limit_config.translations_per_minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/languages")
|
@app.get("/languages")
|
||||||
@ -128,6 +210,7 @@ async def get_supported_languages():
|
|||||||
|
|
||||||
@app.post("/translate")
|
@app.post("/translate")
|
||||||
async def translate_document(
|
async def translate_document(
|
||||||
|
request: Request,
|
||||||
file: UploadFile = File(..., description="Document file to translate (.xlsx, .docx, or .pptx)"),
|
file: UploadFile = File(..., description="Document file to translate (.xlsx, .docx, or .pptx)"),
|
||||||
target_language: str = Form(..., description="Target language code (e.g., 'es', 'fr', 'de')"),
|
target_language: str = Form(..., description="Target language code (e.g., 'es', 'fr', 'de')"),
|
||||||
source_language: str = Form(default="auto", description="Source language code (default: auto-detect)"),
|
source_language: str = Form(default="auto", description="Source language code (default: auto-detect)"),
|
||||||
@ -160,11 +243,38 @@ async def translate_document(
|
|||||||
"""
|
"""
|
||||||
input_path = None
|
input_path = None
|
||||||
output_path = None
|
output_path = None
|
||||||
|
request_id = getattr(request.state, 'request_id', 'unknown')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Validate inputs
|
||||||
|
sanitized_language = InputSanitizer.sanitize_language_code(target_language)
|
||||||
|
LanguageValidator.validate(sanitized_language)
|
||||||
|
ProviderValidator.validate(provider)
|
||||||
|
|
||||||
|
# Validate file before processing
|
||||||
|
validation_result = await file_validator.validate_async(file)
|
||||||
|
if not validation_result.is_valid:
|
||||||
|
raise ValidationError(
|
||||||
|
message=f"File validation failed: {'; '.join(validation_result.errors)}",
|
||||||
|
code="INVALID_FILE",
|
||||||
|
details={"errors": validation_result.errors, "warnings": validation_result.warnings}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log any warnings
|
||||||
|
if validation_result.warnings:
|
||||||
|
logger.warning(f"[{request_id}] File validation warnings: {validation_result.warnings}")
|
||||||
|
|
||||||
|
# Check rate limit for translations
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
if not await rate_limit_manager.check_translation_limit(client_ip):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="Translation rate limit exceeded. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
# Validate file extension
|
# Validate file extension
|
||||||
file_extension = file_handler.validate_file_extension(file.filename)
|
file_extension = file_handler.validate_file_extension(file.filename)
|
||||||
logger.info(f"Processing {file_extension} file: {file.filename}")
|
logger.info(f"[{request_id}] Processing {file_extension} file: {file.filename}")
|
||||||
|
|
||||||
# Validate file size
|
# Validate file size
|
||||||
file_handler.validate_file_size(file)
|
file_handler.validate_file_size(file)
|
||||||
@ -178,7 +288,11 @@ async def translate_document(
|
|||||||
output_path = config.OUTPUT_DIR / output_filename
|
output_path = config.OUTPUT_DIR / output_filename
|
||||||
|
|
||||||
await file_handler.save_upload_file(file, input_path)
|
await file_handler.save_upload_file(file, input_path)
|
||||||
logger.info(f"Saved input file to: {input_path}")
|
logger.info(f"[{request_id}] Saved input file to: {input_path}")
|
||||||
|
|
||||||
|
# Track file for cleanup
|
||||||
|
await cleanup_manager.track_file(input_path, ttl_minutes=30)
|
||||||
|
await cleanup_manager.track_file(output_path, ttl_minutes=60)
|
||||||
|
|
||||||
# Configure translation provider
|
# Configure translation provider
|
||||||
from services.translation_service import GoogleTranslationProvider, DeepLTranslationProvider, LibreTranslationProvider, OllamaTranslationProvider, OpenAITranslationProvider, translation_service
|
from services.translation_service import GoogleTranslationProvider, DeepLTranslationProvider, LibreTranslationProvider, OllamaTranslationProvider, OpenAITranslationProvider, translation_service
|
||||||
@ -657,6 +771,74 @@ async def reconstruct_document(
|
|||||||
raise HTTPException(status_code=500, detail=f"Failed to reconstruct document: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to reconstruct document: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============== SaaS Management Endpoints ==============
|
||||||
|
|
||||||
|
@app.get("/metrics")
|
||||||
|
async def get_metrics():
|
||||||
|
"""Get system metrics and statistics for monitoring"""
|
||||||
|
health_status = await health_checker.check_health()
|
||||||
|
cleanup_stats = cleanup_manager.get_stats()
|
||||||
|
rate_limit_stats = rate_limit_manager.get_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"system": {
|
||||||
|
"memory": health_status.get("memory", {}),
|
||||||
|
"disk": health_status.get("disk", {}),
|
||||||
|
"status": health_status.get("status", "unknown")
|
||||||
|
},
|
||||||
|
"cleanup": cleanup_stats,
|
||||||
|
"rate_limits": rate_limit_stats,
|
||||||
|
"config": {
|
||||||
|
"max_file_size_mb": config.MAX_FILE_SIZE_MB,
|
||||||
|
"supported_extensions": list(config.SUPPORTED_EXTENSIONS),
|
||||||
|
"translation_service": config.TRANSLATION_SERVICE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/rate-limit/status")
|
||||||
|
async def get_rate_limit_status(request: Request):
|
||||||
|
"""Get current rate limit status for the requesting client"""
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
status = await rate_limit_manager.get_client_status(client_ip)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"client_ip": client_ip,
|
||||||
|
"limits": {
|
||||||
|
"requests_per_minute": rate_limit_config.requests_per_minute,
|
||||||
|
"requests_per_hour": rate_limit_config.requests_per_hour,
|
||||||
|
"translations_per_minute": rate_limit_config.translations_per_minute,
|
||||||
|
"translations_per_hour": rate_limit_config.translations_per_hour
|
||||||
|
},
|
||||||
|
"current_usage": status
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/admin/cleanup/trigger")
|
||||||
|
async def trigger_cleanup():
|
||||||
|
"""Trigger manual cleanup of expired files"""
|
||||||
|
try:
|
||||||
|
cleaned = await cleanup_manager.cleanup_expired()
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"files_cleaned": cleaned,
|
||||||
|
"message": f"Cleaned up {cleaned} expired files"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Manual cleanup failed: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/admin/files/tracked")
|
||||||
|
async def get_tracked_files():
|
||||||
|
"""Get list of currently tracked files"""
|
||||||
|
tracked = cleanup_manager.get_tracked_files()
|
||||||
|
return {
|
||||||
|
"count": len(tracked),
|
||||||
|
"files": tracked
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
|
|||||||
62
middleware/__init__.py
Normal file
62
middleware/__init__.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Middleware package for SaaS robustness
|
||||||
|
|
||||||
|
This package provides:
|
||||||
|
- Rate limiting: Protect against abuse and ensure fair usage
|
||||||
|
- Validation: Validate all inputs before processing
|
||||||
|
- Security: Security headers, request logging, error handling
|
||||||
|
- Cleanup: Automatic file cleanup and resource management
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .rate_limiting import (
|
||||||
|
RateLimitConfig,
|
||||||
|
RateLimitManager,
|
||||||
|
RateLimitMiddleware,
|
||||||
|
ClientRateLimiter,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .validation import (
|
||||||
|
ValidationError,
|
||||||
|
ValidationResult,
|
||||||
|
FileValidator,
|
||||||
|
LanguageValidator,
|
||||||
|
ProviderValidator,
|
||||||
|
InputSanitizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .security import (
|
||||||
|
SecurityHeadersMiddleware,
|
||||||
|
RequestLoggingMiddleware,
|
||||||
|
ErrorHandlingMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .cleanup import (
|
||||||
|
FileCleanupManager,
|
||||||
|
MemoryMonitor,
|
||||||
|
HealthChecker,
|
||||||
|
create_cleanup_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Rate limiting
|
||||||
|
"RateLimitConfig",
|
||||||
|
"RateLimitManager",
|
||||||
|
"RateLimitMiddleware",
|
||||||
|
"ClientRateLimiter",
|
||||||
|
# Validation
|
||||||
|
"ValidationError",
|
||||||
|
"ValidationResult",
|
||||||
|
"FileValidator",
|
||||||
|
"LanguageValidator",
|
||||||
|
"ProviderValidator",
|
||||||
|
"InputSanitizer",
|
||||||
|
# Security
|
||||||
|
"SecurityHeadersMiddleware",
|
||||||
|
"RequestLoggingMiddleware",
|
||||||
|
"ErrorHandlingMiddleware",
|
||||||
|
# Cleanup
|
||||||
|
"FileCleanupManager",
|
||||||
|
"MemoryMonitor",
|
||||||
|
"HealthChecker",
|
||||||
|
"create_cleanup_manager",
|
||||||
|
]
|
||||||
400
middleware/cleanup.py
Normal file
400
middleware/cleanup.py
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
"""
|
||||||
|
Cleanup and Resource Management for SaaS robustness
|
||||||
|
Automatic cleanup of temporary files and resources
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Set
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileCleanupManager:
|
||||||
|
"""Manages automatic cleanup of temporary and output files"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
upload_dir: Path,
|
||||||
|
output_dir: Path,
|
||||||
|
temp_dir: Path,
|
||||||
|
max_file_age_hours: int = 1,
|
||||||
|
cleanup_interval_minutes: int = 10,
|
||||||
|
max_total_size_gb: float = 10.0
|
||||||
|
):
|
||||||
|
self.upload_dir = Path(upload_dir)
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
self.temp_dir = Path(temp_dir)
|
||||||
|
self.max_file_age_seconds = max_file_age_hours * 3600
|
||||||
|
self.cleanup_interval = cleanup_interval_minutes * 60
|
||||||
|
self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)
|
||||||
|
|
||||||
|
self._running = False
|
||||||
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
self._protected_files: Set[str] = set()
|
||||||
|
self._tracked_files: dict = {} # filepath -> {created, ttl_minutes}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._stats = {
|
||||||
|
"files_cleaned": 0,
|
||||||
|
"bytes_freed": 0,
|
||||||
|
"cleanup_runs": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
async def track_file(self, filepath: Path, ttl_minutes: int = 60):
|
||||||
|
"""Track a file for automatic cleanup after TTL expires"""
|
||||||
|
with self._lock:
|
||||||
|
self._tracked_files[str(filepath)] = {
|
||||||
|
"created": time.time(),
|
||||||
|
"ttl_minutes": ttl_minutes,
|
||||||
|
"expires_at": time.time() + (ttl_minutes * 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_tracked_files(self) -> list:
|
||||||
|
"""Get list of currently tracked files with their status"""
|
||||||
|
now = time.time()
|
||||||
|
result = []
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
for filepath, info in self._tracked_files.items():
|
||||||
|
remaining = info["expires_at"] - now
|
||||||
|
result.append({
|
||||||
|
"path": filepath,
|
||||||
|
"exists": Path(filepath).exists(),
|
||||||
|
"expires_in_seconds": max(0, int(remaining)),
|
||||||
|
"ttl_minutes": info["ttl_minutes"]
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def cleanup_expired(self) -> int:
|
||||||
|
"""Cleanup expired tracked files"""
|
||||||
|
now = time.time()
|
||||||
|
cleaned = 0
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
for filepath, info in list(self._tracked_files.items()):
|
||||||
|
if now > info["expires_at"]:
|
||||||
|
to_remove.append(filepath)
|
||||||
|
|
||||||
|
for filepath in to_remove:
|
||||||
|
try:
|
||||||
|
path = Path(filepath)
|
||||||
|
if path.exists() and not self.is_protected(path):
|
||||||
|
size = path.stat().st_size
|
||||||
|
path.unlink()
|
||||||
|
cleaned += 1
|
||||||
|
self._stats["files_cleaned"] += 1
|
||||||
|
self._stats["bytes_freed"] += size
|
||||||
|
logger.info(f"Cleaned expired file: {filepath}")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._tracked_files.pop(filepath, None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean expired file {filepath}: {e}")
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""Get cleanup statistics"""
|
||||||
|
disk_usage = self.get_disk_usage()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracked_count = len(self._tracked_files)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"files_cleaned_total": self._stats["files_cleaned"],
|
||||||
|
"bytes_freed_total_mb": round(self._stats["bytes_freed"] / (1024 * 1024), 2),
|
||||||
|
"cleanup_runs": self._stats["cleanup_runs"],
|
||||||
|
"tracked_files": tracked_count,
|
||||||
|
"disk_usage": disk_usage,
|
||||||
|
"is_running": self._running
|
||||||
|
}
|
||||||
|
|
||||||
|
def protect_file(self, filepath: Path):
|
||||||
|
"""Mark a file as protected (being processed)"""
|
||||||
|
with self._lock:
|
||||||
|
self._protected_files.add(str(filepath))
|
||||||
|
|
||||||
|
def unprotect_file(self, filepath: Path):
|
||||||
|
"""Remove protection from a file"""
|
||||||
|
with self._lock:
|
||||||
|
self._protected_files.discard(str(filepath))
|
||||||
|
|
||||||
|
def is_protected(self, filepath: Path) -> bool:
|
||||||
|
"""Check if a file is protected"""
|
||||||
|
with self._lock:
|
||||||
|
return str(filepath) in self._protected_files
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the cleanup background task"""
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
self._task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info("File cleanup manager started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the cleanup background task"""
|
||||||
|
self._running = False
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("File cleanup manager stopped")
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
"""Background loop for periodic cleanup"""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
await self.cleanup()
|
||||||
|
await self.cleanup_expired()
|
||||||
|
self._stats["cleanup_runs"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Cleanup error: {e}")
|
||||||
|
|
||||||
|
await asyncio.sleep(self.cleanup_interval)
|
||||||
|
|
||||||
|
async def cleanup(self) -> dict:
    """Delete unprotected files older than the TTL and enforce the size cap.

    Returns a stats dict with files_deleted, bytes_freed and errors.
    """
    stats = {
        "files_deleted": 0,
        "bytes_freed": 0,
        "errors": []
    }
    now = time.time()

    # Sweep every managed directory.
    for directory in (self.upload_dir, self.output_dir, self.temp_dir):
        if not directory.exists():
            continue
        for filepath in directory.iterdir():
            # Only plain files that are not currently being processed.
            if not filepath.is_file() or self.is_protected(filepath):
                continue
            try:
                file_age = now - filepath.stat().st_mtime
                if file_age > self.max_file_age_seconds:
                    file_size = filepath.stat().st_size
                    filepath.unlink()
                    stats["files_deleted"] += 1
                    stats["bytes_freed"] += file_size
                    logger.debug(f"Deleted old file: {filepath}")
            except Exception as e:
                stats["errors"].append(str(e))
                logger.warning(f"Failed to delete {filepath}: {e}")

    # Even when nothing expired, the total-size cap may force deletions.
    await self._enforce_size_limit(stats)

    if stats["files_deleted"] > 0:
        mb_freed = stats["bytes_freed"] / (1024 * 1024)
        logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")

    return stats
async def _enforce_size_limit(self, stats: dict):
|
||||||
|
"""Delete oldest files if total size exceeds limit"""
|
||||||
|
files_with_mtime = []
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||||
|
if not directory.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
for filepath in directory.iterdir():
|
||||||
|
if not filepath.is_file() or self.is_protected(filepath):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
stat = filepath.stat()
|
||||||
|
files_with_mtime.append((filepath, stat.st_mtime, stat.st_size))
|
||||||
|
total_size += stat.st_size
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If under limit, nothing to do
|
||||||
|
if total_size <= self.max_total_size_bytes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sort by modification time (oldest first)
|
||||||
|
files_with_mtime.sort(key=lambda x: x[1])
|
||||||
|
|
||||||
|
# Delete oldest files until under limit
|
||||||
|
for filepath, _, size in files_with_mtime:
|
||||||
|
if total_size <= self.max_total_size_bytes:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
filepath.unlink()
|
||||||
|
total_size -= size
|
||||||
|
stats["files_deleted"] += 1
|
||||||
|
stats["bytes_freed"] += size
|
||||||
|
logger.info(f"Deleted file to free space: {filepath}")
|
||||||
|
except Exception as e:
|
||||||
|
stats["errors"].append(str(e))
|
||||||
|
|
||||||
|
def get_disk_usage(self) -> dict:
    """Count files and total bytes across the managed directories."""
    total_files = 0
    total_size = 0

    for directory in (self.upload_dir, self.output_dir, self.temp_dir):
        if not directory.exists():
            continue
        for entry in directory.iterdir():
            if not entry.is_file():
                continue
            total_files += 1
            try:
                total_size += entry.stat().st_size
            except Exception:
                # File removed concurrently; skip its size.
                pass

    if self.max_total_size_bytes > 0:
        usage_percent = round((total_size / self.max_total_size_bytes) * 100, 1)
    else:
        usage_percent = 0

    return {
        "total_files": total_files,
        "total_size_mb": round(total_size / (1024 * 1024), 2),
        "max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
        "usage_percent": usage_percent,
        "directories": {
            "uploads": str(self.upload_dir),
            "outputs": str(self.output_dir),
            "temp": str(self.temp_dir)
        }
    }
class MemoryMonitor:
    """Monitors memory usage and triggers cleanup if needed"""

    def __init__(self, max_memory_percent: float = 80.0):
        # Threshold (% of system RAM) above which memory counts as high.
        self.max_memory_percent = max_memory_percent
        self._high_memory_callbacks = []

    def get_memory_usage(self) -> dict:
        """Return process and system memory figures, or {"error": ...} when
        psutil is missing or the probe fails."""
        try:
            import psutil
        except ImportError:
            return {"error": "psutil not installed"}

        try:
            proc_mem = psutil.Process().memory_info()
            sys_mem = psutil.virtual_memory()
        except Exception as e:
            return {"error": str(e)}

        mb = 1024 * 1024
        gb = 1024 * mb
        return {
            "process_rss_mb": round(proc_mem.rss / mb, 2),
            "process_vms_mb": round(proc_mem.vms / mb, 2),
            "system_total_gb": round(sys_mem.total / gb, 2),
            "system_available_gb": round(sys_mem.available / gb, 2),
            "system_percent": sys_mem.percent
        }

    def check_memory(self) -> bool:
        """True when system memory usage is below the configured threshold."""
        usage = self.get_memory_usage()
        if "error" in usage:
            # Cannot measure; assume healthy rather than blocking work.
            return True
        return usage.get("system_percent", 0) < self.max_memory_percent

    def on_high_memory(self, callback):
        """Register a callback to invoke in high-memory situations."""
        self._high_memory_callbacks.append(callback)
class HealthChecker:
    """Comprehensive health checking for the application.

    Aggregates memory, disk, uptime and translation success metrics into
    a single status document for the /health endpoint.
    """

    def __init__(self, cleanup_manager: "FileCleanupManager", memory_monitor: "MemoryMonitor"):
        self.cleanup_manager = cleanup_manager
        self.memory_monitor = memory_monitor
        self.start_time = datetime.now()
        self._translation_count = 0
        self._error_count = 0
        # Guards the two counters, which are updated from request handlers.
        self._lock = threading.Lock()

    def record_translation(self, success: bool = True):
        """Record a translation attempt (and its failure, if any)."""
        with self._lock:
            self._translation_count += 1
            if not success:
                self._error_count += 1

    async def check_health(self) -> dict:
        """Get comprehensive health status (async version)"""
        return self.get_health()

    def get_health(self) -> dict:
        """Get comprehensive health status.

        Fix: the counters are snapshotted under the same lock that
        record_translation() writes them with, so total/errors/success_rate
        are mutually consistent even under concurrent updates.
        """
        memory = self.memory_monitor.get_memory_usage()
        disk = self.cleanup_manager.get_disk_usage()

        with self._lock:
            total = self._translation_count
            errors = self._error_count

        # Determine overall status from memory and disk pressure.
        status = "healthy"
        issues = []

        if "error" not in memory:
            if memory.get("system_percent", 0) > 90:
                status = "degraded"
                issues.append("High memory usage")
            elif memory.get("system_percent", 0) > 80:
                issues.append("Memory usage elevated")

        if disk.get("usage_percent", 0) > 90:
            status = "degraded"
            issues.append("High disk usage")
        elif disk.get("usage_percent", 0) > 80:
            issues.append("Disk usage elevated")

        uptime = datetime.now() - self.start_time

        return {
            "status": status,
            "issues": issues,
            "uptime_seconds": int(uptime.total_seconds()),
            "uptime_human": str(uptime).split('.')[0],
            "translations": {
                "total": total,
                "errors": errors,
                "success_rate": round(
                    ((total - errors) / total * 100)
                    if total > 0 else 100, 1
                )
            },
            "memory": memory,
            "disk": disk,
            "cleanup_service": self.cleanup_manager.get_stats(),
            "timestamp": datetime.now().isoformat()
        }
# Create default instances
def create_cleanup_manager(config) -> FileCleanupManager:
    """Create cleanup manager with config.

    Required config attributes: UPLOAD_DIR, OUTPUT_DIR, TEMP_DIR.
    Optional settings fall back to defaults: 1 h max file age,
    10 min cleanup interval, 10 GB total-size cap.
    """
    return FileCleanupManager(
        upload_dir=config.UPLOAD_DIR,
        output_dir=config.OUTPUT_DIR,
        temp_dir=config.TEMP_DIR,
        max_file_age_hours=getattr(config, 'MAX_FILE_AGE_HOURS', 1),
        cleanup_interval_minutes=getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10),
        max_total_size_gb=getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0)
    )
328
middleware/rate_limiting.py
Normal file
328
middleware/rate_limiting.py
Normal file
@ -0,0 +1,328 @@
|
|||||||
|
"""
|
||||||
|
Rate Limiting Middleware for SaaS robustness
|
||||||
|
Protects against abuse and ensures fair usage
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from fastapi import Request, HTTPException
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RateLimitConfig:
    """Configuration for rate limiting.

    All limits apply per client (clients are identified by IP address
    in RateLimitManager.get_client_id).
    """
    # General request quotas per sliding window.
    requests_per_minute: int = 30
    requests_per_hour: int = 200
    requests_per_day: int = 1000

    # Translation-specific limits
    translations_per_minute: int = 10
    translations_per_hour: int = 50
    max_concurrent_translations: int = 5

    # File size limits (MB)
    max_file_size_mb: int = 50
    max_total_size_per_hour_mb: int = 500

    # Burst protection
    burst_limit: int = 10  # Max requests in 1 second

    # Whitelist IPs (no rate limiting); defaults to IPv4/IPv6 localhost.
    whitelist_ips: list = field(default_factory=lambda: ["127.0.0.1", "::1"])
class TokenBucket:
    """Token bucket algorithm for rate limiting"""

    def __init__(self, capacity: int, refill_rate: float):
        # Holds at most `capacity` tokens, credited at `refill_rate`/second.
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Take `tokens` from the bucket; False when not enough available."""
        async with self._lock:
            self._refill()
            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True

    def _refill(self):
        """Credit tokens for the time elapsed, capped at capacity."""
        now = time.time()
        gained = (now - self.last_refill) * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + gained)
        self.last_refill = now
class SlidingWindowCounter:
    """Sliding window counter for accurate rate limiting"""

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self.requests: list = []  # timestamps of accepted requests
        self._lock = asyncio.Lock()

    async def is_allowed(self) -> bool:
        """Record and allow a request unless the window is already full."""
        async with self._lock:
            now = time.time()
            # Drop timestamps that have aged out of the window.
            cutoff = now - self.window_seconds
            self.requests = [ts for ts in self.requests if ts > cutoff]

            if len(self.requests) >= self.max_requests:
                return False
            self.requests.append(now)
            return True

    @property
    def current_count(self) -> int:
        """Number of accepted requests still inside the window."""
        cutoff = time.time() - self.window_seconds
        return sum(1 for ts in self.requests if ts > cutoff)
class ClientRateLimiter:
    """Per-client rate limiter combining several sliding windows, a burst
    bucket, a concurrency cap and hourly data-volume accounting."""

    def __init__(self, config: RateLimitConfig):
        self.config = config
        # General request windows.
        self.minute_counter = SlidingWindowCounter(60, config.requests_per_minute)
        self.hour_counter = SlidingWindowCounter(3600, config.requests_per_hour)
        self.day_counter = SlidingWindowCounter(86400, config.requests_per_day)
        # Burst protection: bucket refills at `burst_limit` tokens/second.
        self.burst_bucket = TokenBucket(config.burst_limit, config.burst_limit)
        # Translation-specific windows.
        self.translation_minute = SlidingWindowCounter(60, config.translations_per_minute)
        self.translation_hour = SlidingWindowCounter(3600, config.translations_per_hour)
        self.concurrent_translations = 0
        self.total_size_hour: list = []  # List of (timestamp, size_mb)
        self._lock = asyncio.Lock()

    async def check_request(self) -> tuple[bool, str]:
        """Check if request is allowed, return (allowed, reason).

        NOTE: each passing check consumes its window slot, so a request
        rejected by a later check still counts against earlier windows.
        """
        # Check burst limit
        if not await self.burst_bucket.consume():
            return False, "Too many requests. Please slow down."

        # Check minute limit
        if not await self.minute_counter.is_allowed():
            return False, f"Rate limit exceeded. Max {self.config.requests_per_minute} requests per minute."

        # Check hour limit
        if not await self.hour_counter.is_allowed():
            return False, f"Hourly limit exceeded. Max {self.config.requests_per_hour} requests per hour."

        # Check day limit
        if not await self.day_counter.is_allowed():
            return False, f"Daily limit exceeded. Max {self.config.requests_per_day} requests per day."

        return True, ""

    async def check_translation(self, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if a translation request is allowed.

        Fix: the original acquired the non-reentrant asyncio.Lock a second
        time for the size accounting while (per the code layout) already
        holding it — a deadlock, or at best a race window between the
        concurrency check and the bookkeeping. All checks and the size
        accounting now run in one critical section. (The window counters
        use their own internal locks, so awaiting them here is safe.)
        """
        async with self._lock:
            # Check concurrent limit
            if self.concurrent_translations >= self.config.max_concurrent_translations:
                return False, f"Too many concurrent translations. Max {self.config.max_concurrent_translations} at a time."

            # Check translation per minute
            if not await self.translation_minute.is_allowed():
                return False, f"Translation rate limit exceeded. Max {self.config.translations_per_minute} translations per minute."

            # Check translation per hour
            if not await self.translation_hour.is_allowed():
                return False, f"Hourly translation limit exceeded. Max {self.config.translations_per_hour} translations per hour."

            # Hourly data-volume accounting: drop entries older than an hour,
            # then verify this file still fits under the cap.
            now = time.time()
            self.total_size_hour = [(ts, size) for ts, size in self.total_size_hour if now - ts < 3600]
            total_size = sum(size for _, size in self.total_size_hour)

            if total_size + file_size_mb > self.config.max_total_size_per_hour_mb:
                return False, f"Hourly data limit exceeded. Max {self.config.max_total_size_per_hour_mb}MB per hour."

            self.total_size_hour.append((now, file_size_mb))

        return True, ""

    async def start_translation(self):
        """Mark start of translation"""
        async with self._lock:
            self.concurrent_translations += 1

    async def end_translation(self):
        """Mark end of translation (never drops the count below zero)."""
        async with self._lock:
            self.concurrent_translations = max(0, self.concurrent_translations - 1)

    def get_stats(self) -> dict:
        """Get current rate limit stats"""
        return {
            "requests_minute": self.minute_counter.current_count,
            "requests_hour": self.hour_counter.current_count,
            "requests_day": self.day_counter.current_count,
            "translations_minute": self.translation_minute.current_count,
            "translations_hour": self.translation_hour.current_count,
            "concurrent_translations": self.concurrent_translations,
        }
class RateLimitManager:
    """Manages rate limiters for all clients"""

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        # Lazily creates one limiter per client on first rate-limited access.
        self.clients: Dict[str, ClientRateLimiter] = defaultdict(lambda: ClientRateLimiter(self.config))
        self._cleanup_interval = 3600  # Cleanup old clients every hour (not yet wired up)
        self._last_cleanup = time.time()
        self._total_requests = 0
        self._total_translations = 0

    def get_client_id(self, request: Request) -> str:
        """Extract client identifier from request.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controllable
        unless a trusted proxy strips them — confirm the deployment does,
        otherwise limits can be bypassed by spoofing these headers.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # First hop in the chain is the original client.
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host

        return "unknown"

    def is_whitelisted(self, client_id: str) -> bool:
        """Check if client is exempt from rate limiting"""
        return client_id in self.config.whitelist_ips

    async def check_request(self, request: Request) -> tuple[bool, str, str]:
        """Check if request is allowed, return (allowed, reason, client_id)"""
        client_id = self.get_client_id(request)
        self._total_requests += 1

        if self.is_whitelisted(client_id):
            return True, "", client_id

        allowed, reason = await self.clients[client_id].check_request()
        return allowed, reason, client_id

    async def check_translation(self, request: Request, file_size_mb: float = 0) -> tuple[bool, str]:
        """Check if translation is allowed"""
        client_id = self.get_client_id(request)
        self._total_translations += 1

        if self.is_whitelisted(client_id):
            return True, ""

        return await self.clients[client_id].check_translation(file_size_mb)

    async def check_translation_limit(self, client_id: str, file_size_mb: float = 0) -> bool:
        """Check if translation is allowed for a specific client ID"""
        if self.is_whitelisted(client_id):
            return True

        allowed, _ = await self.clients[client_id].check_translation(file_size_mb)
        return allowed

    def get_client_stats(self, request: Request) -> dict:
        """Get rate limit stats for a client.

        Fix: a plain stats read no longer inserts a new limiter into
        self.clients — the defaultdict previously grew on every lookup of
        an unseen client (unbounded memory for a read-only endpoint, and
        inconsistent with get_client_status, which checks membership first).
        A transient limiter yields the same all-zero stats.
        """
        client_id = self.get_client_id(request)
        client = self.clients.get(client_id)
        if client is None:
            client = ClientRateLimiter(self.config)
        return {
            "client_id": client_id,
            "is_whitelisted": self.is_whitelisted(client_id),
            **client.get_stats()
        }

    async def get_client_status(self, client_id: str) -> dict:
        """Get current usage status for a client"""
        if client_id not in self.clients:
            return {"status": "no_activity", "requests": 0}

        stats = self.clients[client_id].get_stats()

        return {
            "requests_used_minute": stats["requests_minute"],
            "requests_used_hour": stats["requests_hour"],
            "translations_used_minute": stats["translations_minute"],
            "translations_used_hour": stats["translations_hour"],
            "concurrent_translations": stats["concurrent_translations"],
            "is_whitelisted": self.is_whitelisted(client_id)
        }

    def get_stats(self) -> dict:
        """Get global rate limiting statistics"""
        return {
            "total_requests": self._total_requests,
            "total_translations": self._total_translations,
            "active_clients": len(self.clients),
            "config": {
                "requests_per_minute": self.config.requests_per_minute,
                "requests_per_hour": self.config.requests_per_hour,
                "translations_per_minute": self.config.translations_per_minute,
                "translations_per_hour": self.config.translations_per_hour,
                "max_concurrent_translations": self.config.max_concurrent_translations
            }
        }
class RateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI middleware enforcing per-client rate limits."""

    # Paths that are never rate limited (health checks, docs, landing page).
    _EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json", "/redoc"})

    def __init__(self, app, rate_limit_manager: RateLimitManager):
        super().__init__(app)
        self.manager = rate_limit_manager

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        # Exempt paths and static assets bypass rate limiting entirely.
        if path in self._EXEMPT_PATHS or path.startswith("/static"):
            return await call_next(request)

        allowed, reason, client_id = await self.manager.check_request(request)

        if not allowed:
            logger.warning(f"Rate limit exceeded for {client_id}: {reason}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": reason,
                    "retry_after": 60
                },
                headers={"Retry-After": "60"}
            )

        # Expose client info to downstream endpoint handlers.
        request.state.client_id = client_id
        request.state.rate_limiter = self.manager.clients[client_id]

        return await call_next(request)
# Global rate limit manager
# Module-level singleton (default config) shared by the app's middleware
# and any endpoints that query rate-limit stats.
rate_limit_manager = RateLimitManager()
142
middleware/security.py
Normal file
142
middleware/security.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
"""
|
||||||
|
Security Headers Middleware for SaaS robustness
|
||||||
|
Adds security headers to all responses
|
||||||
|
"""
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add defensive HTTP security headers to every response."""

    # Headers applied unconditionally to all responses.
    _STATIC_HEADERS = {
        "X-Frame-Options": "DENY",                            # clickjacking
        "X-Content-Type-Options": "nosniff",                  # MIME sniffing
        "X-XSS-Protection": "1; mode=block",                  # legacy XSS filter
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    }

    # Content Security Policy (adjust for your frontend).
    _CSP = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: blob:; "
        "font-src 'self' data:; "
        "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
        "worker-src 'self' blob:; "
    )

    def __init__(self, app, config: dict = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)

        for name, value in self._STATIC_HEADERS.items():
            response.headers[name] = value

        # Swagger/ReDoc pages load external assets, so they skip the CSP.
        path = request.url.path
        if not (path.startswith("/docs") or path.startswith("/redoc")):
            response.headers["Content-Security-Policy"] = self._CSP

        # HSTS only makes sense behind HTTPS; opt in via config in production.
        if self.config.get("enable_hsts", False):
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests for monitoring and debugging"""

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        # Reserved flag; request-body logging is not implemented yet.
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next) -> Response:
        import time
        import uuid

        # Short unique ID for correlating the log lines of one request.
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        client_ip = self._get_client_ip(request)

        started_at = time.time()
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"from {client_ip} - Started"
        )

        try:
            response = await call_next(request)
        except Exception as e:
            duration = time.time() - started_at
            logger.error(
                f"[{request_id}] {request.method} {request.url.path} "
                f"- ERROR in {duration:.3f}s: {str(e)}"
            )
            raise

        duration = time.time() - started_at
        logger.info(
            f"[{request_id}] {request.method} {request.url.path} "
            f"- {response.status_code} in {duration:.3f}s"
        )

        # Echo the ID back so clients can quote it in bug reports.
        response.headers["X-Request-ID"] = request_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Best-effort client IP: proxy headers first, then the socket peer."""
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Catch all unhandled exceptions and return proper error responses"""

    async def dispatch(self, request: Request, call_next) -> Response:
        from starlette.responses import JSONResponse

        try:
            return await call_next(request)
        except Exception as e:
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.exception(f"[{request_id}] Unhandled exception: {str(e)}")

            # Generic payload so internal details are not leaked to clients.
            payload = {
                "error": "internal_server_error",
                "message": "An unexpected error occurred. Please try again later.",
                "request_id": request_id
            }
            return JSONResponse(status_code=500, content=payload)
440
middleware/validation.py
Normal file
440
middleware/validation.py
Normal file
@ -0,0 +1,440 @@
|
|||||||
|
"""
|
||||||
|
Input Validation Module for SaaS robustness
|
||||||
|
Validates all user inputs before processing
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import magic
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Set
|
||||||
|
from fastapi import UploadFile, HTTPException
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(Exception):
    """Custom validation error with user-friendly messages"""

    def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
        super().__init__(message)
        # Kept as attributes so API handlers can build structured responses.
        self.message = message
        self.code = code
        self.details = details or {}
class ValidationResult:
    """Result of a validation check.

    is_valid: overall verdict; False when any blocking error was found.
    errors: fatal problems that stop processing.
    warnings: non-fatal observations (processing may continue).
    data: extra metadata gathered during validation (e.g. safe_filename).
    """
    def __init__(self, is_valid: bool = True, errors: Optional[List[str]] = None, warnings: Optional[List[str]] = None, data: Optional[dict] = None):
        self.is_valid = is_valid
        self.errors = errors or []
        self.warnings = warnings or []
        self.data = data or {}
class FileValidator:
    """Validates uploaded files for security and compatibility.

    Checks performed (in order): filename presence, filename sanitization,
    extension whitelist, size limits, magic-byte signature, and MIME type.
    Two entry points: `validate_async` collects failures into a
    ValidationResult; `validate` raises ValidationError on the first failure.
    """

    # Allowed MIME types mapped to their canonical extensions.
    ALLOWED_MIME_TYPES = {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    }

    # Magic bytes for Office Open XML files (all are ZIP archives).
    OFFICE_MAGIC_BYTES = b"PK\x03\x04"

    def __init__(
        self,
        max_size_mb: int = 50,
        allowed_extensions: Set[str] = None,
        scan_content: bool = True
    ):
        # Precompute the byte limit once; both MB and byte forms are kept
        # for limit checks and user-facing error messages respectively.
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.max_size_mb = max_size_mb
        self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
        self.scan_content = scan_content

    async def validate_async(self, file: UploadFile) -> ValidationResult:
        """
        Validate an uploaded file asynchronously.

        Unlike `validate`, this never raises for validation failures;
        it returns a ValidationResult with is_valid, errors, warnings,
        and a data dict (safe_filename, extension, size, mime_type).
        """
        errors = []
        warnings = []
        data = {}

        try:
            # Validate filename presence.
            if not file.filename:
                errors.append("Filename is required")
                return ValidationResult(is_valid=False, errors=errors)

            # Sanitize filename (path traversal, control characters).
            try:
                safe_filename = self._sanitize_filename(file.filename)
                data["safe_filename"] = safe_filename
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Validate extension against the whitelist.
            try:
                extension = self._validate_extension(safe_filename)
                data["extension"] = extension
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Read full file content for validation.
            content = await file.read()
            await file.seek(0)  # Reset for later processing

            # Validate file size.
            file_size = len(content)
            data["size_bytes"] = file_size
            data["size_mb"] = round(file_size / (1024*1024), 2)

            if file_size > self.max_size_bytes:
                errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            if file_size == 0:
                errors.append("File is empty")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            # Warn (do not fail) when within 80% of the size limit.
            if file_size > self.max_size_bytes * 0.8:
                warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")

            # Validate magic bytes when content scanning is enabled.
            if self.scan_content:
                try:
                    self._validate_magic_bytes(content, extension)
                except ValidationError as e:
                    errors.append(str(e.message))
                    return ValidationResult(is_valid=False, errors=errors, data=data)

            # MIME mismatches are downgraded to warnings here because
            # detection can be unreliable across platforms.
            try:
                mime_type = self._detect_mime_type(content)
                data["mime_type"] = mime_type
                self._validate_mime_type(mime_type, extension)
            except ValidationError as e:
                warnings.append(f"MIME type warning: {e.message}")
            except Exception:
                warnings.append("Could not verify MIME type")

            data["original_filename"] = file.filename

            return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)

        except Exception as e:
            # Unexpected failure (I/O, library errors): report as invalid.
            logger.error(f"Validation error: {str(e)}")
            errors.append(f"Validation failed: {str(e)}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)

    async def validate(self, file: UploadFile) -> dict:
        """
        Validate an uploaded file.

        Returns a validation info dict (original/safe filename, extension,
        size, mime_type) or raises ValidationError on the first failure.
        """
        # Validate filename presence.
        if not file.filename:
            raise ValidationError(
                "Filename is required",
                code="missing_filename"
            )

        # Sanitize filename.
        safe_filename = self._sanitize_filename(file.filename)

        # Validate extension.
        extension = self._validate_extension(safe_filename)

        # Read full file content for validation.
        content = await file.read()
        await file.seek(0)  # Reset for later processing

        # Validate file size.
        file_size = len(content)
        if file_size > self.max_size_bytes:
            raise ValidationError(
                f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
                code="file_too_large",
                details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
            )

        if file_size == 0:
            raise ValidationError(
                "File is empty",
                code="empty_file"
            )

        # Validate magic bytes (file signature).
        if self.scan_content:
            self._validate_magic_bytes(content, extension)

        # Validate MIME type (strict here, unlike validate_async).
        mime_type = self._detect_mime_type(content)
        self._validate_mime_type(mime_type, extension)

        return {
            "original_filename": file.filename,
            "safe_filename": safe_filename,
            "extension": extension,
            "size_bytes": file_size,
            "size_mb": round(file_size / (1024*1024), 2),
            "mime_type": mime_type
        }

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks."""
        # Remove path components — keeps only the final name segment.
        filename = Path(filename).name

        # Remove null bytes and control characters.
        filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)

        # Replace potentially dangerous characters with underscores.
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        # Limit length to the common 255-character filesystem maximum,
        # preserving the extension if one exists.
        if len(filename) > 255:
            name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
            filename = name[:250] + ('.' + ext if ext else '')

        # Ensure the result is not empty after sanitization.
        if not filename or filename.strip() == '':
            raise ValidationError(
                "Invalid filename",
                code="invalid_filename"
            )

        return filename

    def _validate_extension(self, filename: str) -> str:
        """Validate and return the lowercased, dot-prefixed file extension."""
        if '.' not in filename:
            raise ValidationError(
                f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
                code="missing_extension",
                details={"allowed_extensions": list(self.allowed_extensions)}
            )

        extension = '.' + filename.rsplit('.', 1)[1].lower()

        if extension not in self.allowed_extensions:
            raise ValidationError(
                f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
                code="unsupported_file_type",
                details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
            )

        return extension

    def _validate_magic_bytes(self, content: bytes, extension: str):
        """Validate file magic bytes match the expected format.

        All supported formats are Office Open XML (ZIP-based), so a single
        ZIP-signature prefix check covers .xlsx/.docx/.pptx.
        """
        if not content.startswith(self.OFFICE_MAGIC_BYTES):
            raise ValidationError(
                "File content does not match expected format. The file may be corrupted or not a valid Office document.",
                code="invalid_file_content"
            )

    def _detect_mime_type(self, content: bytes) -> str:
        """Detect MIME type from file content via libmagic, with a fallback."""
        try:
            mime = magic.Magic(mime=True)
            return mime.from_buffer(content)
        except Exception:
            # Fallback when libmagic is unavailable: basic signature check.
            if content.startswith(self.OFFICE_MAGIC_BYTES):
                return "application/zip"
            return "application/octet-stream"

    def _validate_mime_type(self, mime_type: str, extension: str):
        """Validate that the detected MIME type is one of the accepted values."""
        # Office Open XML files may be detected as plain ZIP (or fall back
        # to octet-stream), so those generic types are accepted too.
        allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]

        if mime_type not in allowed_mimes:
            raise ValidationError(
                f"Invalid file type detected. Expected Office document, got: {mime_type}",
                code="invalid_mime_type",
                details={"detected_mime": mime_type}
            )
|
||||||
|
|
||||||
|
|
||||||
|
class LanguageValidator:
    """Validates and normalizes ISO 639-1 language codes.

    `validate` accepts any supported code case-insensitively, plus a few
    common aliases ("chinese", "cn", "tw"), and returns the canonical form.
    `get_language_name` maps codes to human-readable names.
    """

    SUPPORTED_LANGUAGES = {
        # ISO 639-1 codes (plus "auto" for auto-detection)
        "af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
        "bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
        "da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
        "de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
        "is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
        "rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
        "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
        "ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
        "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
        "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
        "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
        "zu", "auto"
    }

    LANGUAGE_NAMES = {
        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
        "it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
        "zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
        "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
        "nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
        "da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
        "el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
        "uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
    }

    @classmethod
    def validate(cls, language_code: str, field_name: str = "language") -> str:
        """Validate and normalize a language code.

        Args:
            language_code: User-supplied code (any case, may be an alias).
            field_name: Name used in the "required" error message.

        Returns:
            The canonical code as stored in SUPPORTED_LANGUAGES.

        Raises:
            ValidationError: If the code is missing or unsupported.
        """
        if not language_code:
            raise ValidationError(
                f"{field_name} is required",
                code="missing_language"
            )

        # Normalize whitespace and case so comparison is case-insensitive.
        normalized = language_code.strip().lower()

        # Handle common aliases and restore canonical casing for regional
        # variants. BUG FIX: lower-casing turned valid inputs like "zh-CN"
        # into "zh-cn", which is not in SUPPORTED_LANGUAGES and was wrongly
        # rejected; the lowercase forms are mapped back here.
        if normalized in ("chinese", "cn", "zh-cn"):
            normalized = "zh-CN"
        elif normalized in ("chinese-traditional", "tw", "zh-tw"):
            normalized = "zh-TW"

        if normalized not in cls.SUPPORTED_LANGUAGES:
            raise ValidationError(
                f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
                code="unsupported_language",
                details={"language": language_code}
            )

        return normalized

    @classmethod
    def get_language_name(cls, code: str) -> str:
        """Get the human-readable language name, falling back to the upper-cased code."""
        return cls.LANGUAGE_NAMES.get(code, code.upper())
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderValidator:
    """Validates translation provider configuration"""

    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm"}

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate a provider name and its provider-specific settings.

        Args:
            provider: Provider identifier (any case, surrounding whitespace OK).
            **kwargs: Provider-specific config (e.g. deepl_api_key,
                openai_api_key, ollama_model).

        Returns:
            Dict with the normalized "provider" name and "validated": True.

        Raises:
            ValidationError: If the provider is missing, unknown, or lacks
                a required API key.
        """
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )

        name = provider.strip().lower()

        if name not in cls.SUPPORTED_PROVIDERS:
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(cls.SUPPORTED_PROVIDERS)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": list(cls.SUPPORTED_PROVIDERS)}
            )

        # Provider-specific requirements: hosted APIs must have a key;
        # Ollama only logs a warning when no model is given.
        if name == "deepl" and not kwargs.get("deepl_api_key"):
            raise ValidationError(
                "DeepL API key is required when using DeepL provider",
                code="missing_deepl_key"
            )

        if name == "openai" and not kwargs.get("openai_api_key"):
            raise ValidationError(
                "OpenAI API key is required when using OpenAI provider",
                code="missing_openai_key"
            )

        if name == "ollama" and not kwargs.get("ollama_model", ""):
            logger.warning("No Ollama model specified, will use default")

        return {"provider": name, "validated": True}
|
||||||
|
|
||||||
|
|
||||||
|
class InputSanitizer:
    """Sanitizes user inputs to prevent injection attacks"""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Strip null bytes, enforce a maximum length, and trim whitespace."""
        if not text:
            return ""
        # Drop null bytes, truncate, then trim surrounding whitespace.
        cleaned = text.replace('\x00', '')
        return cleaned[:max_length].strip()

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Reduce a language code to alphanumerics/hyphens, capped at 10 chars.

        Falls back to "auto" when the input is empty or sanitizes to nothing.
        """
        if not code:
            return "auto"
        # Keep only alphanumerics and hyphens, then enforce the length cap.
        stripped = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())[:10]
        return stripped.lower() if stripped else "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Validate the scheme and strip trailing slashes from a URL.

        Raises:
            ValidationError: If the URL does not start with http:// or https://.
        """
        if not url:
            return ""
        candidate = url.strip()
        if not re.match(r'^https?://', candidate, re.IGNORECASE):
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )
        return candidate.rstrip('/')

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Trim surrounding whitespace; deliberately never logs key material."""
        return key.strip() if key else ""
|
||||||
|
|
||||||
|
|
||||||
|
# Default validators: shared module-level FileValidator instance built with
# the default limits (50 MB maximum, Office Open XML extensions only).
file_validator = FileValidator()
|
||||||
@ -14,3 +14,7 @@ pandas==2.1.4
|
|||||||
requests==2.31.0
|
requests==2.31.0
|
||||||
ipykernel==6.27.1
|
ipykernel==6.27.1
|
||||||
openai>=1.0.0
|
openai>=1.0.0
|
||||||
|
|
||||||
|
# SaaS robustness dependencies
|
||||||
|
psutil==5.9.8
|
||||||
|
python-magic-bin==0.4.14 # For Windows, use python-magic on Linux
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user