feat: revue de code, doc CODE_REVIEW, forfaits 2026, traduction LLM, providers avec modèle

Made-with: Cursor
This commit is contained in:
Sepehr Ramezani
2026-03-07 11:42:58 +01:00
parent 3d37ce4582
commit 473b3e26c7
181 changed files with 30617 additions and 7170 deletions

View File

@@ -27,6 +27,9 @@ from .validation import (
from .security import (
SecurityHeadersMiddleware,
RequestLoggingMiddleware,
)
from .error_handler import (
ErrorHandlingMiddleware,
)
@@ -37,7 +40,25 @@ from .cleanup import (
create_cleanup_manager,
)
from .api_key_auth import (
APIKeyError,
get_user_from_api_key,
get_authenticated_user,
get_authenticated_user_optional,
get_current_user_optional,
require_authenticated_user,
require_api_key,
)
__all__ = [
# API Key Authentication
"APIKeyError",
"get_user_from_api_key",
"get_authenticated_user",
"get_authenticated_user_optional",
"get_current_user_optional",
"require_authenticated_user",
"require_api_key",
# Rate limiting
"RateLimitConfig",
"RateLimitManager",

222
middleware/api_key_auth.py Normal file
View File

@@ -0,0 +1,222 @@
"""
API Key Authentication Middleware
Provides reusable dependencies for API key authentication across all endpoints.
Story 3.4: Authentification API via X-API-Key
"""
from typing import Optional, Any, Union
from fastapi import Header, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
security = HTTPBearer(auto_error=False)
class APIKeyError(Exception):
"""Exception for API key authentication errors with structured error codes."""
INVALID_API_KEY = "INVALID_API_KEY"
API_KEY_REVOKED = "API_KEY_REVOKED"
API_KEY_EXPIRED = "API_KEY_EXPIRED"
MISSING_API_KEY = "MISSING_API_KEY"
UNAUTHORIZED = "UNAUTHORIZED"
ERROR_MESSAGES = {
INVALID_API_KEY: "Clé API invalide ou non reconnue.",
API_KEY_REVOKED: "Cette clé API a été révoquée.",
API_KEY_EXPIRED: "Cette clé API a expiré.",
MISSING_API_KEY: "Clé API requise pour cet endpoint.",
UNAUTHORIZED: "Authentification requise. Utilisez X-API-Key ou Authorization: Bearer.",
}
def __init__(self, code: str, message: Optional[str] = None):
self.code = code
self.message = message or self.ERROR_MESSAGES.get(code, "Erreur d'authentification")
super().__init__(self.message)
def to_response(self, status_code: int = 401) -> JSONResponse:
"""Convert to JSONResponse for FastAPI."""
return JSONResponse(
status_code=status_code,
content={
"error": self.code,
"message": self.message,
},
)
def _raise_api_key_error(code: str, message: Optional[str] = None) -> None:
"""Raise an APIKeyError and convert it to JSONResponse for FastAPI."""
raise APIKeyError(code, message)
async def get_user_from_api_key(
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Optional[Any]:
"""
Get user from X-API-Key header if provided.
Returns:
User object if valid API key provided
None if no API key provided (caller should try other auth methods)
Raises:
APIKeyError: With structured error code if API key is invalid/revoked/expired
"""
if not x_api_key:
return None
try:
from services.auth_service import get_user_by_api_key
user = get_user_by_api_key(x_api_key)
return user
except ValueError as e:
# Handle revoked/expired API keys with specific error codes
error_code = str(e)
if error_code == "API_KEY_REVOKED":
raise APIKeyError("API_KEY_REVOKED", "Cette clé API a été révoquée.")
elif error_code == "API_KEY_EXPIRED":
raise APIKeyError("API_KEY_EXPIRED", "Cette clé API a expiré.")
else:
# Unknown error - treat as invalid
raise APIKeyError("INVALID_API_KEY", "Clé API invalide ou non reconnue.")
except Exception:
# Unexpected error - treat as invalid
raise APIKeyError("INVALID_API_KEY", "Clé API invalide ou non reconnue.")
async def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> Optional[Any]:
"""Get current user if authenticated via JWT, None otherwise."""
if not credentials:
return None
try:
from routes.auth_routes import get_current_user
user = await get_current_user(credentials)
return user
except Exception:
return None
async def get_authenticated_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Optional[Any]:
"""
Get authenticated user from API key or JWT (optional - returns None if not authenticated).
Priority:
1. X-API-Key header (automation users)
2. JWT Bearer token (web users)
3. None (unauthenticated)
Returns:
User object if authenticated, None otherwise (never raises for auth failures)
"""
# Try API key first (priority for automation)
if x_api_key:
try:
user = await get_user_from_api_key(x_api_key)
if user:
return user
except APIKeyError:
# Invalid API key, fall through to JWT
pass
# Fall back to JWT
if credentials:
try:
from routes.auth_routes import get_current_user
user = await get_current_user(credentials)
return user
except Exception:
pass
return None
async def get_authenticated_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Optional[Any]:
"""
Get authenticated user from API key or JWT.
Priority:
1. X-API-Key header (automation users)
2. JWT Bearer token (web users)
3. None (unauthenticated)
Returns:
User object if authenticated
None if not authenticated
Raises:
APIKeyError: If API key is provided but invalid/revoked/expired
"""
# Try API key first (priority for automation)
if x_api_key:
# get_user_from_api_key will raise APIKeyError for invalid keys
user = await get_user_from_api_key(x_api_key)
if user:
return user
# Should not reach here - get_user_from_api_key returns None only if no key provided
raise APIKeyError("INVALID_API_KEY", "Clé API invalide ou non reconnue.")
# Fall back to JWT
if credentials:
try:
from routes.auth_routes import get_current_user
user = await get_current_user(credentials)
return user
except Exception:
pass
return None
async def require_authenticated_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Any:
"""
Require authentication (API key or JWT).
Raises:
APIKeyError: 401 if not authenticated
Returns:
User object (guaranteed to be authenticated)
"""
user = await get_authenticated_user(credentials, x_api_key)
if not user:
raise APIKeyError("MISSING_API_KEY", "Authentification requise. Utilisez X-API-Key ou Authorization: Bearer.")
return user
async def require_api_key(
x_api_key: str = Header(..., alias="X-API-Key"),
) -> Any:
"""
Require API key authentication (no JWT fallback).
Use this for endpoints that MUST use API key (e.g., certain automation endpoints).
Raises:
APIKeyError: 401 if API key is missing, invalid, revoked, or expired
Returns:
User object (guaranteed to be authenticated via API key)
"""
return await get_user_from_api_key(x_api_key)

View File

@@ -2,6 +2,7 @@
Cleanup and Resource Management for SaaS robustness
Automatic cleanup of temporary files and resources
"""
import os
import time
import asyncio
@@ -10,77 +11,84 @@ from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Set
import logging
import json
from services.storage_tracker import _get_async_redis, KEY_PREFIX
logger = logging.getLogger(__name__)
try:
import structlog
logger = structlog.get_logger(__name__)
_HAS_STRUCTLOG = True
except ImportError:
logger = logging.getLogger(__name__)
_HAS_STRUCTLOG = False
class FileCleanupManager:
"""Manages automatic cleanup of temporary and output files"""
def __init__(
self,
upload_dir: Path,
output_dir: Path,
temp_dir: Path,
max_file_age_hours: int = 1,
cleanup_interval_minutes: int = 10,
max_total_size_gb: float = 10.0
max_file_age_minutes: int = 60,
cleanup_interval_minutes: int = 5,
max_total_size_gb: float = 10.0,
):
self.upload_dir = Path(upload_dir)
self.output_dir = Path(output_dir)
self.temp_dir = Path(temp_dir)
self.max_file_age_seconds = max_file_age_hours * 3600
self.max_file_age_seconds = max_file_age_minutes * 60
self.cleanup_interval = cleanup_interval_minutes * 60
self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)
self._running = False
self._task: Optional[asyncio.Task] = None
self._protected_files: Set[str] = set()
self._tracked_files: dict = {} # filepath -> {created, ttl_minutes}
self._lock = threading.Lock()
self._stats = {
"files_cleaned": 0,
"bytes_freed": 0,
"cleanup_runs": 0
}
self._stats = {"files_cleaned": 0, "bytes_freed": 0, "cleanup_runs": 0}
async def track_file(self, filepath: Path, ttl_minutes: int = 60):
"""Track a file for automatic cleanup after TTL expires"""
with self._lock:
self._tracked_files[str(filepath)] = {
"created": time.time(),
"ttl_minutes": ttl_minutes,
"expires_at": time.time() + (ttl_minutes * 60)
"expires_at": time.time() + (ttl_minutes * 60),
}
def get_tracked_files(self) -> list:
"""Get list of currently tracked files with their status"""
now = time.time()
result = []
with self._lock:
for filepath, info in self._tracked_files.items():
remaining = info["expires_at"] - now
result.append({
"path": filepath,
"exists": Path(filepath).exists(),
"expires_in_seconds": max(0, int(remaining)),
"ttl_minutes": info["ttl_minutes"]
})
result.append(
{
"path": filepath,
"exists": Path(filepath).exists(),
"expires_in_seconds": max(0, int(remaining)),
"ttl_minutes": info["ttl_minutes"],
}
)
return result
async def cleanup_expired(self) -> int:
"""Cleanup expired tracked files"""
now = time.time()
cleaned = 0
to_remove = []
with self._lock:
for filepath, info in list(self._tracked_files.items()):
if now > info["expires_at"]:
to_remove.append(filepath)
for filepath in to_remove:
try:
path = Path(filepath)
@@ -91,55 +99,57 @@ class FileCleanupManager:
self._stats["files_cleaned"] += 1
self._stats["bytes_freed"] += size
logger.info(f"Cleaned expired file: {filepath}")
with self._lock:
self._tracked_files.pop(filepath, None)
except Exception as e:
logger.warning(f"Failed to clean expired file {filepath}: {e}")
return cleaned
def get_stats(self) -> dict:
"""Get cleanup statistics"""
disk_usage = self.get_disk_usage()
with self._lock:
tracked_count = len(self._tracked_files)
return {
"files_cleaned_total": self._stats["files_cleaned"],
"bytes_freed_total_mb": round(self._stats["bytes_freed"] / (1024 * 1024), 2),
"bytes_freed_total_mb": round(
self._stats["bytes_freed"] / (1024 * 1024), 2
),
"cleanup_runs": self._stats["cleanup_runs"],
"tracked_files": tracked_count,
"disk_usage": disk_usage,
"is_running": self._running
"is_running": self._running,
}
def protect_file(self, filepath: Path):
"""Mark a file as protected (being processed)"""
with self._lock:
self._protected_files.add(str(filepath))
def unprotect_file(self, filepath: Path):
"""Remove protection from a file"""
with self._lock:
self._protected_files.discard(str(filepath))
def is_protected(self, filepath: Path) -> bool:
"""Check if a file is protected"""
with self._lock:
return str(filepath) in self._protected_files
async def start(self):
"""Start the cleanup background task"""
if self._running:
return
self._running = True
self._task = asyncio.create_task(self._cleanup_loop())
logger.info("File cleanup manager started")
async def stop(self):
"""Stop the cleanup background task"""
self._running = False
@@ -150,7 +160,7 @@ class FileCleanupManager:
except asyncio.CancelledError:
pass
logger.info("File cleanup manager stopped")
async def _cleanup_loop(self):
"""Background loop for periodic cleanup"""
while self._running:
@@ -160,88 +170,146 @@ class FileCleanupManager:
self._stats["cleanup_runs"] += 1
except Exception as e:
logger.error(f"Cleanup error: {e}")
await asyncio.sleep(self.cleanup_interval)
async def cleanup(self) -> dict:
"""Perform cleanup of old files"""
"""Perform cleanup of old files and orphans"""
stats = {
"files_deleted": 0,
"bytes_freed": 0,
"errors": []
"orphaned_deleted": 0,
"errors": [],
}
now = time.time()
# Cleanup each directory
# Get tracked paths from Redis to identify orphans
tracked_paths = set()
redis_client = _get_async_redis()
redis_available = redis_client is not None
if redis_client:
try:
keys = await redis_client.keys(f"{KEY_PREFIX}:*")
for key in keys:
data = await redis_client.get(key)
if data:
metadata = json.loads(data)
if "file_path" in metadata:
# Normalize path to absolute string for comparison
path_str = str(Path(metadata["file_path"]).absolute())
tracked_paths.add(path_str)
except Exception as e:
logger.warning(f"Failed to fetch tracked paths from Redis: {e}")
redis_available = False
else:
logger.warning(
"Redis unavailable - orphan detection disabled, using age-based cleanup only"
)
# Cleanup each directory (collect files first to avoid race condition)
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
if not directory.exists():
continue
for filepath in directory.iterdir():
try:
files_to_check = list(directory.iterdir())
except OSError as e:
logger.warning(f"Failed to list directory {directory}: {e}")
continue
for filepath in files_to_check:
if not filepath.is_file():
continue
# Skip protected files
if self.is_protected(filepath):
continue
try:
# Check if it's an orphan (only if Redis is available)
abs_path = str(filepath.absolute())
is_orphan = redis_available and abs_path not in tracked_paths
# Check file age
file_age = now - filepath.stat().st_mtime
if file_age > self.max_file_age_seconds:
should_delete = False
reason = ""
if is_orphan:
should_delete = True
reason = "orphan"
elif file_age > self.max_file_age_seconds:
should_delete = True
reason = "expired"
if should_delete:
file_size = filepath.stat().st_size
filepath.unlink()
stats["files_deleted"] += 1
stats["bytes_freed"] += file_size
logger.debug(f"Deleted old file: {filepath}")
if reason == "orphan":
stats["orphaned_deleted"] += 1
logger.info(f"Deleted {reason} file: {filepath}")
except Exception as e:
stats["errors"].append(str(e))
logger.warning(f"Failed to delete {filepath}: {e}")
# Force cleanup if total size exceeds limit
await self._enforce_size_limit(stats)
if stats["files_deleted"] > 0:
mb_freed = stats["bytes_freed"] / (1024 * 1024)
logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")
mb_freed = stats["bytes_freed"] / (1024 * 1024)
cleanup_timestamp = datetime.now().isoformat()
# Structured logging (AC: #5)
log_data = {
"files_deleted": stats["files_deleted"],
"bytes_freed_mb": round(mb_freed, 2),
"orphaned_deleted": stats["orphaned_deleted"],
"cleanup_run_timestamp": cleanup_timestamp,
}
if _HAS_STRUCTLOG:
logger.info("cleanup_completed", **log_data)
else:
logger.info(f"Cleanup completed: {log_data}")
return stats
async def _enforce_size_limit(self, stats: dict):
"""Delete oldest files if total size exceeds limit"""
files_with_mtime = []
total_size = 0
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
if not directory.exists():
continue
for filepath in directory.iterdir():
if not filepath.is_file() or self.is_protected(filepath):
continue
try:
stat = filepath.stat()
files_with_mtime.append((filepath, stat.st_mtime, stat.st_size))
total_size += stat.st_size
except Exception:
pass
# If under limit, nothing to do
if total_size <= self.max_total_size_bytes:
return
# Sort by modification time (oldest first)
files_with_mtime.sort(key=lambda x: x[1])
# Delete oldest files until under limit
for filepath, _, size in files_with_mtime:
if total_size <= self.max_total_size_bytes:
break
try:
filepath.unlink()
total_size -= size
@@ -250,16 +318,16 @@ class FileCleanupManager:
logger.info(f"Deleted file to free space: {filepath}")
except Exception as e:
stats["errors"].append(str(e))
def get_disk_usage(self) -> dict:
"""Get current disk usage statistics"""
total_files = 0
total_size = 0
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
if not directory.exists():
continue
for filepath in directory.iterdir():
if filepath.is_file():
total_files += 1
@@ -267,55 +335,60 @@ class FileCleanupManager:
total_size += filepath.stat().st_size
except Exception:
pass
return {
"total_files": total_files,
"total_size_mb": round(total_size / (1024 * 1024), 2),
"max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
"usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0,
"usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1)
if self.max_total_size_bytes > 0
else 0,
"directories": {
"uploads": str(self.upload_dir),
"outputs": str(self.output_dir),
"temp": str(self.temp_dir)
}
"temp": str(self.temp_dir),
},
}
class MemoryMonitor:
"""Monitors memory usage and triggers cleanup if needed"""
def __init__(self, max_memory_percent: float = 80.0):
self.max_memory_percent = max_memory_percent
self._high_memory_callbacks = []
def get_memory_usage(self) -> dict:
"""Get current memory usage"""
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
system_memory = psutil.virtual_memory()
return {
"process_rss_mb": round(memory_info.rss / (1024 * 1024), 2),
"process_vms_mb": round(memory_info.vms / (1024 * 1024), 2),
"system_total_gb": round(system_memory.total / (1024 * 1024 * 1024), 2),
"system_available_gb": round(system_memory.available / (1024 * 1024 * 1024), 2),
"system_percent": system_memory.percent
"system_available_gb": round(
system_memory.available / (1024 * 1024 * 1024), 2
),
"system_percent": system_memory.percent,
}
except ImportError:
return {"error": "psutil not installed"}
except Exception as e:
return {"error": str(e)}
def check_memory(self) -> bool:
"""Check if memory usage is within limits"""
usage = self.get_memory_usage()
if "error" in usage:
return True # Can't check, assume OK
return usage.get("system_percent", 0) < self.max_memory_percent
def on_high_memory(self, callback):
"""Register callback for high memory situations"""
self._high_memory_callbacks.append(callback)
@@ -323,67 +396,75 @@ class MemoryMonitor:
class HealthChecker:
"""Comprehensive health checking for the application"""
def __init__(self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor):
def __init__(
self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor
):
self.cleanup_manager = cleanup_manager
self.memory_monitor = memory_monitor
self.start_time = datetime.now()
self._translation_count = 0
self._error_count = 0
self._lock = threading.Lock()
def record_translation(self, success: bool = True):
"""Record a translation attempt"""
with self._lock:
self._translation_count += 1
if not success:
self._error_count += 1
async def check_health(self) -> dict:
"""Get comprehensive health status (async version)"""
return self.get_health()
def get_health(self) -> dict:
"""Get comprehensive health status"""
memory = self.memory_monitor.get_memory_usage()
disk = self.cleanup_manager.get_disk_usage()
# Determine overall status
status = "healthy"
issues = []
if "error" not in memory:
if memory.get("system_percent", 0) > 90:
status = "degraded"
issues.append("High memory usage")
elif memory.get("system_percent", 0) > 80:
issues.append("Memory usage elevated")
if disk.get("usage_percent", 0) > 90:
status = "degraded"
issues.append("High disk usage")
elif disk.get("usage_percent", 0) > 80:
issues.append("Disk usage elevated")
uptime = datetime.now() - self.start_time
return {
"status": status,
"issues": issues,
"uptime_seconds": int(uptime.total_seconds()),
"uptime_human": str(uptime).split('.')[0],
"uptime_human": str(uptime).split(".")[0],
"translations": {
"total": self._translation_count,
"errors": self._error_count,
"success_rate": round(
((self._translation_count - self._error_count) / self._translation_count * 100)
if self._translation_count > 0 else 100, 1
)
(
(self._translation_count - self._error_count)
/ self._translation_count
* 100
)
if self._translation_count > 0
else 100,
1,
),
},
"memory": memory,
"disk": disk,
"cleanup_service": self.cleanup_manager.get_stats(),
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now().isoformat(),
}
@@ -394,7 +475,7 @@ def create_cleanup_manager(config) -> FileCleanupManager:
upload_dir=config.UPLOAD_DIR,
output_dir=config.OUTPUT_DIR,
temp_dir=config.TEMP_DIR,
max_file_age_hours=getattr(config, 'MAX_FILE_AGE_HOURS', 1),
cleanup_interval_minutes=getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10),
max_total_size_gb=getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0)
max_file_age_minutes=getattr(config, "FILE_TTL_MINUTES", 60),
cleanup_interval_minutes=getattr(config, "CLEANUP_INTERVAL_MINUTES", 5),
max_total_size_gb=getattr(config, "MAX_TOTAL_SIZE_GB", 10.0),
)

107
middleware/error_handler.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Global Error Handling Middleware
Catches all unhandled exceptions and standardizes API error responses.
"""
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from fastapi import HTTPException
from starlette.exceptions import HTTPException as StarletteHTTPException
# Import APIKeyError for handling
from middleware.api_key_auth import APIKeyError
try:
import structlog
logger = structlog.get_logger(__name__)
except ImportError:
logger = logging.getLogger(__name__)
def format_error_response(
status_code: int,
message: str,
error_code: str = None,
details: dict = None,
request_id: str = "unknown",
headers: dict = None,
) -> JSONResponse:
"""
Standardizes the error response format.
Format: {error: "CODE", message: "...", details: {...}}
"""
if not error_code:
error_code = _map_http_status_to_code(status_code)
content = {"error": error_code, "message": message, "details": details or {}}
# Always include request_id in details if not present
if "request_id" not in content["details"]:
content["details"]["request_id"] = request_id
return JSONResponse(status_code=status_code, content=content, headers=headers)
def _map_http_status_to_code(status_code: int) -> str:
"""Map HTTP status codes to architectural error codes."""
mapping = {
400: "INVALID_FORMAT",
401: "UNAUTHORIZED",
403: "FORBIDDEN",
404: "NOT_FOUND",
405: "METHOD_NOT_ALLOWED",
413: "FILE_TOO_LARGE",
422: "VALIDATION_ERROR",
429: "QUOTA_EXCEEDED",
502: "PROVIDER_ERROR",
503: "SERVICE_UNAVAILABLE",
}
return mapping.get(status_code, "INTERNAL_ERROR")
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
"""
Catch all unhandled exceptions (Exception) that bubble up to the top.
Note: HTTPException is often caught by FastAPI handlers before reaching here.
"""
async def dispatch(self, request: Request, call_next) -> Response:
try:
return await call_next(request)
except APIKeyError as e:
# Handle APIKeyError with structured response using to_response()
request_id = getattr(request.state, "request_id", "unknown")
logger.info(f"[{request_id}] API Key authentication error: {e.code}")
return e.to_response()
except Exception as e:
request_id = getattr(request.state, "request_id", "unknown")
# If it's already an HTTPException, we might want to handle it specifically if it leaked through
if isinstance(e, (HTTPException, StarletteHTTPException)):
detail = e.detail if hasattr(e, "detail") and e.detail else {}
if isinstance(detail, dict):
return format_error_response(
status_code=e.status_code,
message=detail.get("message", "Une erreur s'est produite."),
error_code=detail.get("error"),
request_id=request_id,
)
return format_error_response(
status_code=e.status_code,
message=str(detail) if detail else "Une erreur s'est produite.",
request_id=request_id,
)
# Log the full stack trace for internal debugging
logger.exception(f"[{request_id}] Unhandled internal exception: {str(e)}")
# Return generic error in French to user (AC4, AC5)
return format_error_response(
status_code=500,
message="Une erreur inattendue s'est produite. Veuillez réessayer plus tard.",
error_code="INTERNAL_ERROR",
request_id=request_id,
)

View File

@@ -116,27 +116,3 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
return request.client.host
return "unknown"
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
"""Catch all unhandled exceptions and return proper error responses"""
async def dispatch(self, request: Request, call_next) -> Response:
from starlette.responses import JSONResponse
try:
return await call_next(request)
except Exception as e:
request_id = getattr(request.state, 'request_id', 'unknown')
logger.exception(f"[{request_id}] Unhandled exception: {str(e)}")
# Don't expose internal errors in production
return JSONResponse(
status_code=500,
content={
"error": "internal_server_error",
"message": "An unexpected error occurred. Please try again later.",
"request_id": request_id
}
)

180
middleware/tier_quota.py Normal file
View File

@@ -0,0 +1,180 @@
"""
Tier-based daily translation quota (Story 1.6).
Uses Redis sliding-window daily counter per user; fallback in-memory when Redis unavailable.
Coexists with IP-based rate limiting in rate_limiting.py.
Source of truth: Redis (key per user per UTC date) is the authority for quota enforcement.
User.daily_translation_count in DB is kept in sync on each successful translation for
reporting/analytics; reset at midnight UTC is automatic in Redis (new key per day). DB
reset can be done by a scheduled job at midnight UTC if needed.
"""
from __future__ import annotations
import os
import logging
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Optional
logger = logging.getLogger(__name__)
# Free tier: 5 translations per day (UTC). Pro (and equivalent) tiers: no daily cap.
FREE_TIER_DAILY_LIMIT = 5
KEY_PREFIX = "rate_limit:daily"
def _utc_date_str(dt: Optional[datetime] = None) -> str:
"""Current date in UTC as YYYY-MM-DD."""
t = dt or datetime.now(timezone.utc)
return t.strftime("%Y-%m-%d")
def _next_midnight_utc(dt: Optional[datetime] = None) -> datetime:
"""Next midnight UTC after the given time (or now)."""
now = dt or datetime.now(timezone.utc)
tomorrow = (now.date() + timedelta(days=1))
return datetime(tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc)
def _seconds_until_midnight_utc(dt: Optional[datetime] = None) -> int:
"""Seconds until next midnight UTC."""
now = dt or datetime.now(timezone.utc)
return max(0, int((_next_midnight_utc(now) - now).total_seconds()))
@dataclass
class QuotaResult:
"""Result of a quota check."""
allowed: bool
remaining: int # -1 for pro (unlimited)
reset_at_utc: datetime
current_usage: int = 0
limit: int = FREE_TIER_DAILY_LIMIT
# ---------------------------------------------------------------------------
# Redis backend
# ---------------------------------------------------------------------------
_async_redis = None
def _get_async_redis():
"""Return async Redis client or None. Uses REDIS_URL from env. Single shared client."""
global _async_redis
if _async_redis is not None:
return _async_redis if _async_redis is not False else None
url = os.getenv("REDIS_URL", "").strip()
if not url:
_async_redis = False
return None
try:
import redis.asyncio as redis
_async_redis = redis.Redis.from_url(url, decode_responses=True)
logger.info("Tier quota: using Redis for daily quota")
return _async_redis
except Exception as e:
logger.warning("Tier quota: Redis unavailable (%s), using in-memory fallback", e)
_async_redis = False
return None
# ---------------------------------------------------------------------------
# In-memory fallback (per process; not shared across workers). Documented as fallback.
# ---------------------------------------------------------------------------
_memory_usage: dict[tuple[str, str], int] = {} # (user_id, date_utc_str) -> count
def _memory_get(user_id: str, date_str: str) -> int:
return _memory_usage.get((user_id, date_str), 0)
def _memory_incr(user_id: str, date_str: str) -> int:
key = (user_id, date_str)
_memory_usage[key] = _memory_usage.get(key, 0) + 1
return _memory_usage[key]
# ---------------------------------------------------------------------------
# TierQuotaService
# ---------------------------------------------------------------------------
class TierQuotaService:
"""
Daily translation quota per user by tier.
Redis key pattern: rate_limit:daily:{user_id}:{YYYY-MM-DD}, TTL 25h.
If Redis is unavailable, uses in-memory dict (documented fallback).
"""
def __init__(self):
self._redis = None # Lazy init on first use
def _redis_client(self):
if self._redis is None:
self._redis = _get_async_redis()
return self._redis
def _date_str(self, dt: Optional[datetime] = None) -> str:
return _utc_date_str(dt)
async def check_quota(self, user_id: str, tier: str) -> QuotaResult:
"""
Check if user has quota for one more translation today (UTC).
tier "free" -> limit 5/day; "pro" (or equivalent) -> unlimited.
"""
reset_at = _next_midnight_utc()
tier_lower = (tier or "free").lower()
if tier_lower in ("pro", "business", "enterprise", "starter"):
return QuotaResult(
allowed=True,
remaining=-1,
reset_at_utc=reset_at,
current_usage=0,
limit=0,
)
# Free tier
date_str = self._date_str()
redis_client = self._redis_client()
if redis_client:
try:
key = f"{KEY_PREFIX}:{user_id}:{date_str}"
count = await redis_client.get(key)
count = int(count or 0)
except Exception as e:
logger.warning("Tier quota Redis get failed: %s, using in-memory", e)
count = _memory_get(user_id, date_str)
else:
count = _memory_get(user_id, date_str)
remaining = max(0, FREE_TIER_DAILY_LIMIT - count)
return QuotaResult(
allowed=count < FREE_TIER_DAILY_LIMIT,
remaining=remaining,
reset_at_utc=reset_at,
current_usage=count,
limit=FREE_TIER_DAILY_LIMIT,
)
async def increment_on_success(self, user_id: str) -> None:
"""Increment daily translation count for user (call after successful translation)."""
date_str = self._date_str()
redis_client = self._redis_client()
if redis_client:
try:
key = f"{KEY_PREFIX}:{user_id}:{date_str}"
pipe = redis_client.pipeline()
pipe.incr(key)
pipe.expire(key, 25 * 3600) # 25h so key expires after midnight UTC
await pipe.execute()
return
except Exception as e:
logger.warning("Tier quota Redis increment failed: %s, using in-memory", e)
_memory_incr(user_id, date_str)
def seconds_until_reset(self) -> int:
"""Seconds until next midnight UTC (for Retry-After header)."""
return _seconds_until_midnight_utc()
# Singleton for app use
tier_quota_service = TierQuotaService()

View File

@@ -2,10 +2,14 @@
Input Validation Module for SaaS robustness
Validates all user inputs before processing
"""
import re
import magic
import ipaddress
import socket
from pathlib import Path
from typing import Optional, List, Set
from urllib.parse import urlparse
from typing import Optional, List, Set, Tuple
from fastapi import UploadFile, HTTPException
import logging
@@ -14,7 +18,13 @@ logger = logging.getLogger(__name__)
class ValidationError(Exception):
"""Custom validation error with user-friendly messages"""
def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
def __init__(
self,
message: str,
code: str = "validation_error",
details: Optional[dict] = None,
):
self.message = message
self.code = code
self.details = details or {}
@@ -23,37 +33,46 @@ class ValidationError(Exception):
class ValidationResult:
"""Result of a validation check"""
def __init__(self, is_valid: bool = True, errors: List[str] = None, warnings: List[str] = None, data: dict = None):
def __init__(
self,
is_valid: bool = True,
errors: Optional[List[str]] = None,
warnings: Optional[List[str]] = None,
data: Optional[dict] = None,
error_code: Optional[str] = None,
):
self.is_valid = is_valid
self.errors = errors or []
self.warnings = warnings or []
self.data = data or {}
self.error_code = error_code
class FileValidator:
"""Validates uploaded files for security and compatibility"""
# Allowed MIME types mapped to extensions
ALLOWED_MIME_TYPES = {
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
}
# Magic bytes for Office Open XML files (ZIP format)
OFFICE_MAGIC_BYTES = b"PK\x03\x04"
def __init__(
self,
max_size_mb: int = 50,
allowed_extensions: Set[str] = None,
scan_content: bool = True
allowed_extensions: Optional[Set[str]] = None,
scan_content: bool = True,
):
self.max_size_bytes = max_size_mb * 1024 * 1024
self.max_size_mb = max_size_mb
self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
self.scan_content = scan_content
async def validate_async(self, file: UploadFile) -> ValidationResult:
"""
Validate an uploaded file asynchronously
@@ -62,77 +81,105 @@ class FileValidator:
errors = []
warnings = []
data = {}
try:
# Validate filename
if not file.filename:
errors.append("Filename is required")
return ValidationResult(is_valid=False, errors=errors)
errors.append("Le nom de fichier est requis")
return ValidationResult(
is_valid=False, errors=errors, error_code="missing_filename"
)
# Sanitize filename
try:
safe_filename = self._sanitize_filename(file.filename)
data["safe_filename"] = safe_filename
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors)
return ValidationResult(
is_valid=False, errors=errors, error_code=e.code
)
# Validate extension
try:
extension = self._validate_extension(safe_filename)
data["extension"] = extension
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors)
return ValidationResult(
is_valid=False, errors=errors, error_code=e.code
)
# Read file content for validation
content = await file.read()
await file.seek(0) # Reset for later processing
# Validate file size
file_size = len(content)
data["size_bytes"] = file_size
data["size_mb"] = round(file_size / (1024*1024), 2)
data["size_mb"] = round(file_size / (1024 * 1024), 2)
if file_size > self.max_size_bytes:
errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
return ValidationResult(is_valid=False, errors=errors, data=data)
errors.append(
f"Fichier trop volumineux. La taille maximale est de {self.max_size_mb}Mo, "
f"vous avez envoye {file_size / (1024 * 1024):.1f}Mo"
)
return ValidationResult(
is_valid=False,
errors=errors,
data=data,
error_code="file_too_large",
)
if file_size == 0:
errors.append("File is empty")
return ValidationResult(is_valid=False, errors=errors, data=data)
errors.append("Le fichier est vide")
return ValidationResult(
is_valid=False, errors=errors, data=data, error_code="empty_file"
)
# Warn about large files
if file_size > self.max_size_bytes * 0.8:
warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")
warnings.append(
f"Le fichier fait {data['size_mb']}Mo, approchant la limite de {self.max_size_mb}Mo"
)
# Validate magic bytes
if self.scan_content:
try:
self._validate_magic_bytes(content, extension)
except ValidationError as e:
errors.append(str(e.message))
return ValidationResult(is_valid=False, errors=errors, data=data)
return ValidationResult(
is_valid=False, errors=errors, data=data, error_code=e.code
)
# Validate MIME type
try:
mime_type = self._detect_mime_type(content)
data["mime_type"] = mime_type
self._validate_mime_type(mime_type, extension)
except ValidationError as e:
warnings.append(f"MIME type warning: {e.message}")
warnings.append(f"Avertissement MIME: {e.message}")
except Exception:
warnings.append("Could not verify MIME type")
warnings.append("Impossible de verifier le type MIME")
data["original_filename"] = file.filename
return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)
return ValidationResult(
is_valid=True, errors=errors, warnings=warnings, data=data
)
except Exception as e:
logger.error(f"Validation error: {str(e)}")
errors.append(f"Validation failed: {str(e)}")
return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)
errors.append(f"Erreur de validation: {str(e)}")
return ValidationResult(
is_valid=False,
errors=errors,
warnings=warnings,
data=data,
error_code="validation_error",
)
async def validate(self, file: UploadFile) -> dict:
"""
Validate an uploaded file
@@ -141,106 +188,107 @@ class FileValidator:
# Validate filename
if not file.filename:
raise ValidationError(
"Filename is required",
code="missing_filename"
"Le nom de fichier est requis", code="missing_filename"
)
# Sanitize filename
safe_filename = self._sanitize_filename(file.filename)
# Validate extension
extension = self._validate_extension(safe_filename)
# Read file content for validation
content = await file.read()
await file.seek(0) # Reset for later processing
# Validate file size
file_size = len(content)
if file_size > self.max_size_bytes:
raise ValidationError(
f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
f"Fichier trop volumineux. La taille maximale est de {self.max_size_mb}Mo, "
f"vous avez envoye {file_size / (1024 * 1024):.1f}Mo",
code="file_too_large",
details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
details={
"max_mb": self.max_size_mb,
"actual_mb": round(file_size / (1024 * 1024), 2),
},
)
if file_size == 0:
raise ValidationError(
"File is empty",
code="empty_file"
)
raise ValidationError("Le fichier est vide", code="empty_file")
# Validate magic bytes (file signature)
if self.scan_content:
self._validate_magic_bytes(content, extension)
# Validate MIME type
mime_type = self._detect_mime_type(content)
self._validate_mime_type(mime_type, extension)
return {
"original_filename": file.filename,
"safe_filename": safe_filename,
"extension": extension,
"size_bytes": file_size,
"size_mb": round(file_size / (1024*1024), 2),
"mime_type": mime_type
"size_mb": round(file_size / (1024 * 1024), 2),
"mime_type": mime_type,
}
def _sanitize_filename(self, filename: str) -> str:
"""Sanitize filename to prevent path traversal and other attacks"""
# Remove path components
filename = Path(filename).name
# Remove null bytes and control characters
filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)
filename = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", filename)
# Remove potentially dangerous characters
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
filename = re.sub(r'[<>:"/\\|?*]', "_", filename)
# Limit length
if len(filename) > 255:
name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
filename = name[:250] + ('.' + ext if ext else '')
name, ext = filename.rsplit(".", 1) if "." in filename else (filename, "")
filename = name[:250] + ("." + ext if ext else "")
# Ensure not empty after sanitization
if not filename or filename.strip() == '':
raise ValidationError(
"Invalid filename",
code="invalid_filename"
)
if not filename or filename.strip() == "":
raise ValidationError("Nom de fichier invalide", code="invalid_filename")
return filename
def _validate_extension(self, filename: str) -> str:
"""Validate and return the file extension"""
if '.' not in filename:
if "." not in filename:
raise ValidationError(
f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
f"Le fichier doit avoir une extension. Formats supportes : {', '.join(self.allowed_extensions)}",
code="missing_extension",
details={"allowed_extensions": list(self.allowed_extensions)}
details={"allowed_extensions": list(self.allowed_extensions)},
)
extension = '.' + filename.rsplit('.', 1)[1].lower()
extension = "." + filename.rsplit(".", 1)[1].lower()
if extension not in self.allowed_extensions:
raise ValidationError(
f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
f"Format de fichier '{extension}' non supporte. Formats acceptes : {', '.join(self.allowed_extensions)}",
code="unsupported_file_type",
details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
details={
"extension": extension,
"allowed_extensions": list(self.allowed_extensions),
},
)
return extension
def _validate_magic_bytes(self, content: bytes, extension: str):
"""Validate file magic bytes match expected format"""
# All supported formats are Office Open XML (ZIP-based)
if not content.startswith(self.OFFICE_MAGIC_BYTES):
raise ValidationError(
"File content does not match expected format. The file may be corrupted or not a valid Office document.",
code="invalid_file_content"
"Le contenu du fichier ne correspond pas au format attendu. "
"Le fichier est peut-etre corrompu ou n'est pas un document Office valide.",
code="invalid_file_content",
)
def _detect_mime_type(self, content: bytes) -> str:
"""Detect MIME type from file content"""
try:
@@ -251,77 +299,198 @@ class FileValidator:
if content.startswith(self.OFFICE_MAGIC_BYTES):
return "application/zip"
return "application/octet-stream"
def _validate_mime_type(self, mime_type: str, extension: str):
"""Validate MIME type matches extension"""
# Office Open XML files may be detected as ZIP
allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]
allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + [
"application/zip",
"application/octet-stream",
]
if mime_type not in allowed_mimes:
raise ValidationError(
f"Invalid file type detected. Expected Office document, got: {mime_type}",
f"Type de fichier invalide detecte. Document Office attendu, recu : {mime_type}",
code="invalid_mime_type",
details={"detected_mime": mime_type}
details={"detected_mime": mime_type},
)
class LanguageValidator:
"""Validates language codes"""
SUPPORTED_LANGUAGES = {
# ISO 639-1 codes
"af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
"bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
"da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
"de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
"is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
"rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
"mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
"ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
"gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
"su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
"tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
"zu", "auto"
"af",
"sq",
"am",
"ar",
"hy",
"az",
"eu",
"be",
"bn",
"bs",
"bg",
"ca",
"ceb",
"zh",
"zh-CN",
"zh-TW",
"co",
"hr",
"cs",
"da",
"nl",
"en",
"eo",
"et",
"fi",
"fr",
"fy",
"gl",
"ka",
"de",
"el",
"gu",
"ht",
"ha",
"haw",
"he",
"hi",
"hmn",
"hu",
"is",
"ig",
"id",
"ga",
"it",
"ja",
"jv",
"kn",
"kk",
"km",
"rw",
"ko",
"ku",
"ky",
"lo",
"la",
"lv",
"lt",
"lb",
"mk",
"mg",
"ms",
"ml",
"mt",
"mi",
"mr",
"mn",
"my",
"ne",
"no",
"ny",
"or",
"ps",
"fa",
"pl",
"pt",
"pa",
"ro",
"ru",
"sm",
"gd",
"sr",
"st",
"sn",
"sd",
"si",
"sk",
"sl",
"so",
"es",
"su",
"sw",
"sv",
"tl",
"tg",
"ta",
"tt",
"te",
"th",
"tr",
"tk",
"uk",
"ur",
"ug",
"uz",
"vi",
"cy",
"xh",
"yi",
"yo",
"zu",
"auto",
}
LANGUAGE_NAMES = {
"en": "English", "es": "Spanish", "fr": "French", "de": "German",
"it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
"zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
"ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
"nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
"da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
"el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
"uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
"en": "English",
"es": "Spanish",
"fr": "French",
"de": "German",
"it": "Italian",
"pt": "Portuguese",
"ru": "Russian",
"zh": "Chinese",
"zh-CN": "Chinese (Simplified)",
"zh-TW": "Chinese (Traditional)",
"ja": "Japanese",
"ko": "Korean",
"ar": "Arabic",
"hi": "Hindi",
"nl": "Dutch",
"pl": "Polish",
"tr": "Turkish",
"sv": "Swedish",
"da": "Danish",
"no": "Norwegian",
"fi": "Finnish",
"cs": "Czech",
"el": "Greek",
"th": "Thai",
"vi": "Vietnamese",
"id": "Indonesian",
"uk": "Ukrainian",
"ro": "Romanian",
"hu": "Hungarian",
"auto": "Auto-detect",
}
@classmethod
def validate(cls, language_code: str, field_name: str = "language") -> str:
"""Validate and normalize language code"""
if not language_code:
raise ValidationError(
f"{field_name} is required",
code="missing_language"
)
raise ValidationError(f"{field_name} est requis", code="missing_language")
# Normalize
normalized = language_code.strip().lower()
# Handle common variations
if normalized in ["chinese", "cn"]:
normalized = "zh-CN"
elif normalized in ["chinese-traditional", "tw"]:
normalized = "zh-TW"
if normalized not in cls.SUPPORTED_LANGUAGES:
raise ValidationError(
f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
f"Code langue non supporte: '{language_code}'. Consultez /languages pour les codes supportes.",
code="unsupported_language",
details={"language": language_code}
details={"language": language_code},
)
return normalized
@classmethod
def get_language_name(cls, code: str) -> str:
"""Get human-readable language name"""
@@ -330,104 +499,116 @@ class LanguageValidator:
class ProviderValidator:
"""Validates translation provider configuration"""
SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm", "openrouter"}
SUPPORTED_PROVIDERS = {
"google",
"ollama",
"deepl",
"libre",
"openai",
"webllm",
"openrouter",
"classic",
"llm",
}
@classmethod
def validate(cls, provider: str, **kwargs) -> dict:
"""Validate provider and its required configuration"""
if not provider:
raise ValidationError(
"Translation provider is required",
code="missing_provider"
"Le fournisseur de traduction est requis", code="missing_provider"
)
normalized = provider.strip().lower()
if normalized not in cls.SUPPORTED_PROVIDERS:
raise ValidationError(
f"Unsupported provider: '{provider}'. Supported: {', '.join(cls.SUPPORTED_PROVIDERS)}",
f"Fournisseur non supporte: '{provider}'. Supportes: {', '.join(cls.SUPPORTED_PROVIDERS)}",
code="unsupported_provider",
details={"provider": provider, "supported": list(cls.SUPPORTED_PROVIDERS)}
details={
"provider": provider,
"supported": list(cls.SUPPORTED_PROVIDERS),
},
)
# Provider-specific validation
if normalized == "deepl":
if not kwargs.get("deepl_api_key"):
raise ValidationError(
"DeepL API key is required when using DeepL provider",
code="missing_deepl_key"
"La cle API DeepL est requise pour utiliser le fournisseur DeepL",
code="missing_deepl_key",
)
elif normalized == "openai":
if not kwargs.get("openai_api_key"):
raise ValidationError(
"OpenAI API key is required when using OpenAI provider",
code="missing_openai_key"
"La cle API OpenAI est requise pour utiliser le fournisseur OpenAI",
code="missing_openai_key",
)
elif normalized == "ollama":
# Ollama doesn't require API key but may need model
model = kwargs.get("ollama_model", "")
if not model:
logger.warning("No Ollama model specified, will use default")
return {"provider": normalized, "validated": True}
class InputSanitizer:
"""Sanitizes user inputs to prevent injection attacks"""
@staticmethod
def sanitize_text(text: str, max_length: int = 10000) -> str:
"""Sanitize text input"""
if not text:
return ""
# Remove null bytes
text = text.replace('\x00', '')
text = text.replace("\x00", "")
# Limit length
if len(text) > max_length:
text = text[:max_length]
return text.strip()
@staticmethod
def sanitize_language_code(code: str) -> str:
"""Sanitize and normalize language code"""
if not code:
return "auto"
# Remove dangerous characters, keep only alphanumeric and hyphen
code = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())
code = re.sub(r"[^a-zA-Z0-9\-]", "", code.strip())
# Limit length
if len(code) > 10:
code = code[:10]
return code.lower() if code else "auto"
@staticmethod
def sanitize_url(url: str) -> str:
"""Sanitize URL input"""
if not url:
return ""
url = url.strip()
# Basic URL validation
if not re.match(r'^https?://', url, re.IGNORECASE):
if not re.match(r"^https?://", url, re.IGNORECASE):
raise ValidationError(
"Invalid URL format. Must start with http:// or https://",
code="invalid_url"
"Format d'URL invalide. Doit commencer par http:// ou https://",
code="invalid_url",
)
# Remove trailing slashes
url = url.rstrip('/')
url = url.rstrip("/")
return url
@staticmethod
def sanitize_api_key(key: str) -> str:
"""Sanitize API key (just trim, no logging)"""
@@ -436,5 +617,117 @@ class InputSanitizer:
return key.strip()
class WebhookURLValidator:
"""
Validator for webhook URLs with security checks.
Prevents SSRF attacks by blocking private IPs and localhost.
Story 3.7: Webhook - Spécification URL
"""
# Allowed URL schemes
ALLOWED_SCHEMES = ("http", "https")
# Blocked hostnames
BLOCKED_HOSTNAMES = {"localhost", "127.0.0.1", "::1", "0.0.0.0"}
def __init__(
self,
allowed_schemes: Tuple[str, ...] = ALLOWED_SCHEMES,
block_private_ips: bool = True
):
self.allowed_schemes = allowed_schemes
self.block_private_ips = block_private_ips
def validate(self, url: Optional[str]) -> Tuple[bool, Optional[str], Optional[dict]]:
"""
Validate webhook URL format and security.
Args:
url: The webhook URL to validate (can be None or empty for optional parameter)
Returns:
Tuple of (is_valid, error_message, details)
"""
# Empty or None URLs are valid (optional parameter)
if not url:
return True, None, None
try:
parsed = urlparse(url)
# Check scheme
if parsed.scheme.lower() not in self.allowed_schemes:
return False, (
f"L'URL doit utiliser {' ou '.join(self.allowed_schemes)}"
), {
"field": "webhook_url",
"allowed_schemes": list(self.allowed_schemes),
"detected_scheme": parsed.scheme or "none"
}
# Check for credentials in URL
if parsed.username or parsed.password:
return False, (
"L'URL ne doit pas contenir d'identifiants (credentials)"
), {"field": "webhook_url", "reason": "credentials_in_url"}
# Check hostname
hostname = parsed.hostname
if not hostname:
return False, (
"URL invalide: nom d'hôte manquant"
), {"field": "webhook_url", "reason": "missing_hostname"}
# Block localhost and common local addresses
if hostname.lower() in self.BLOCKED_HOSTNAMES:
return False, (
"Les URLs localhost ne sont pas autorisées"
), {"field": "webhook_url", "reason": "localhost_blocked"}
# Check for private IPs (SSRF protection)
if self.block_private_ips:
try:
# Try to parse as IP directly
try:
ip = ipaddress.ip_address(hostname)
if self._is_blocked_ip(ip):
return False, (
"Les adresses IP privées ne sont pas autorisées"
), {"field": "webhook_url", "reason": "private_ip_blocked"}
except ValueError:
# Not an IP, try DNS resolution
ip_str = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(ip_str)
if self._is_blocked_ip(ip):
return False, (
"Les adresses IP privées ne sont pas autorisées"
), {"field": "webhook_url", "reason": "private_ip_blocked"}
except socket.gaierror:
# DNS resolution failed - let it through
# Will fail at webhook send time
pass
except Exception:
pass
return True, None, None
except Exception as e:
return False, (
f"Format d'URL invalide: {str(e)}"
), {"field": "webhook_url", "error": str(e)}
def _is_blocked_ip(self, ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Check if IP is private, loopback, or link-local."""
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast
)
# Default validators
file_validator = FileValidator()
webhook_validator = WebhookURLValidator()