feat: revue de code, doc CODE_REVIEW, forfaits 2026, traduction LLM, providers avec modèle
Made-with: Cursor
This commit is contained in:
@@ -27,6 +27,9 @@ from .validation import (
|
||||
from .security import (
|
||||
SecurityHeadersMiddleware,
|
||||
RequestLoggingMiddleware,
|
||||
)
|
||||
|
||||
from .error_handler import (
|
||||
ErrorHandlingMiddleware,
|
||||
)
|
||||
|
||||
@@ -37,7 +40,25 @@ from .cleanup import (
|
||||
create_cleanup_manager,
|
||||
)
|
||||
|
||||
from .api_key_auth import (
|
||||
APIKeyError,
|
||||
get_user_from_api_key,
|
||||
get_authenticated_user,
|
||||
get_authenticated_user_optional,
|
||||
get_current_user_optional,
|
||||
require_authenticated_user,
|
||||
require_api_key,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# API Key Authentication
|
||||
"APIKeyError",
|
||||
"get_user_from_api_key",
|
||||
"get_authenticated_user",
|
||||
"get_authenticated_user_optional",
|
||||
"get_current_user_optional",
|
||||
"require_authenticated_user",
|
||||
"require_api_key",
|
||||
# Rate limiting
|
||||
"RateLimitConfig",
|
||||
"RateLimitManager",
|
||||
|
||||
222
middleware/api_key_auth.py
Normal file
222
middleware/api_key_auth.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
API Key Authentication Middleware
|
||||
|
||||
Provides reusable dependencies for API key authentication across all endpoints.
|
||||
Story 3.4: Authentification API via X-API-Key
|
||||
"""
|
||||
|
||||
from typing import Optional, Any, Union
|
||||
from fastapi import Header, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class APIKeyError(Exception):
    """Authentication failure that carries a machine-readable error code.

    Attributes:
        code: One of the string codes declared on this class (or any other
            string supplied by the caller).
        message: Human-readable message; defaults to the entry for ``code``
            in ``ERROR_MESSAGES`` when none is given.
    """

    # Machine-readable codes placed in the "error" field of the response.
    INVALID_API_KEY = "INVALID_API_KEY"
    API_KEY_REVOKED = "API_KEY_REVOKED"
    API_KEY_EXPIRED = "API_KEY_EXPIRED"
    MISSING_API_KEY = "MISSING_API_KEY"
    UNAUTHORIZED = "UNAUTHORIZED"

    # Default user-facing message for each code.
    ERROR_MESSAGES = {
        INVALID_API_KEY: "Clé API invalide ou non reconnue.",
        API_KEY_REVOKED: "Cette clé API a été révoquée.",
        API_KEY_EXPIRED: "Cette clé API a expiré.",
        MISSING_API_KEY: "Clé API requise pour cet endpoint.",
        UNAUTHORIZED: "Authentification requise. Utilisez X-API-Key ou Authorization: Bearer.",
    }

    def __init__(self, code: str, message: Optional[str] = None):
        """Store ``code`` and resolve the message (explicit one wins)."""
        if not message:
            # Fall back to the catalogue, then to a generic message.
            message = self.ERROR_MESSAGES.get(code, "Erreur d'authentification")
        self.code = code
        self.message = message
        super().__init__(message)

    def to_response(self, status_code: int = 401) -> JSONResponse:
        """Build the JSON response ``{"error": code, "message": message}``.

        Args:
            status_code: HTTP status of the response (401 by default).
        """
        payload = {"error": self.code, "message": self.message}
        return JSONResponse(status_code=status_code, content=payload)
|
||||
|
||||
|
||||
def _raise_api_key_error(code: str, message: Optional[str] = None) -> None:
    """Raise an ``APIKeyError`` built from ``code`` and the optional ``message``.

    This helper always raises and never returns normally. It does not build
    a response itself; it only raises the exception.
    """
    error = APIKeyError(code, message)
    raise error
|
||||
|
||||
|
||||
async def get_user_from_api_key(
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Optional[Any]:
    """Resolve the user that owns the ``X-API-Key`` header value.

    Args:
        x_api_key: Raw API key from the ``X-API-Key`` header, if any.

    Returns:
        The user object for a valid key, or ``None`` when no key was
        provided (the caller should then try another auth method).

    Raises:
        APIKeyError: ``API_KEY_REVOKED`` / ``API_KEY_EXPIRED`` when the auth
            service reports that state, ``INVALID_API_KEY`` for any other
            failure, including a lookup that yields no user.
    """
    if not x_api_key:
        return None

    try:
        # Imported lazily, as everywhere else in this module.
        from services.auth_service import get_user_by_api_key

        user = get_user_by_api_key(x_api_key)
    except ValueError as exc:
        # The service reports revoked/expired keys through the ValueError text.
        code = str(exc)
        if code not in (APIKeyError.API_KEY_REVOKED, APIKeyError.API_KEY_EXPIRED):
            code = APIKeyError.INVALID_API_KEY
        raise APIKeyError(code) from exc
    except Exception as exc:
        # Still reported to the client as an invalid key, but no longer
        # silently: log the real cause (never the key itself) and chain it.
        import logging

        logging.getLogger(__name__).exception(
            "Unexpected error while validating API key"
        )
        raise APIKeyError(APIKeyError.INVALID_API_KEY) from exc

    if not user:
        # A key was supplied but matched no user: honour the documented
        # contract that None means "no key provided" only.
        raise APIKeyError(APIKeyError.INVALID_API_KEY)
    return user
|
||||
|
||||
|
||||
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> Optional[Any]:
    """Return the JWT-authenticated user, or ``None`` when there is none.

    Best-effort: missing credentials and any failure while resolving the
    user both yield ``None`` instead of raising.
    """
    if credentials:
        try:
            # Lazy import kept inside the try, so an import failure also
            # degrades to "not authenticated".
            from routes.auth_routes import get_current_user

            return await get_current_user(credentials)
        except Exception:
            pass
    return None
|
||||
|
||||
|
||||
async def get_authenticated_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Optional[Any]:
    """Resolve the caller from an API key or a JWT without ever failing.

    Lookup order:
        1. ``X-API-Key`` header
        2. JWT bearer credentials
        3. ``None`` (unauthenticated)

    Returns:
        The user object when one of the methods succeeds, otherwise ``None``.
        Authentication failures are swallowed, never raised.
    """
    if x_api_key:
        try:
            api_user = await get_user_from_api_key(x_api_key)
        except APIKeyError:
            # A bad key is not fatal here; fall through to the JWT attempt.
            api_user = None
        if api_user:
            return api_user

    # The JWT branch has the same best-effort semantics as the sibling helper.
    return await get_current_user_optional(credentials)
|
||||
|
||||
|
||||
async def get_authenticated_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Optional[Any]:
    """Resolve the caller from an API key or, failing that, a JWT.

    Lookup order:
        1. ``X-API-Key`` header
        2. JWT bearer credentials
        3. ``None`` (unauthenticated)

    Returns:
        The user object, or ``None`` when no usable credentials were sent.

    Raises:
        APIKeyError: When an API key is supplied but cannot be resolved to a
            user. A supplied key is never ignored in favour of the JWT.
    """
    if x_api_key:
        # Propagates APIKeyError for invalid/revoked/expired keys.
        api_user = await get_user_from_api_key(x_api_key)
        if not api_user:
            # Defensive: a key was sent but no user came back.
            raise APIKeyError("INVALID_API_KEY", "Clé API invalide ou non reconnue.")
        return api_user

    # No API key: JWT is attempted best-effort (errors become None).
    return await get_current_user_optional(credentials)
|
||||
|
||||
|
||||
async def require_authenticated_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> Any:
    """Dependency that insists on an authenticated caller (API key or JWT).

    Returns:
        The authenticated user object.

    Raises:
        APIKeyError: ``MISSING_API_KEY`` when neither method yields a user;
            errors from API-key validation propagate unchanged.
    """
    user = await get_authenticated_user(credentials, x_api_key)
    if user:
        return user

    # NOTE(review): the message matches the UNAUTHORIZED catalogue entry while
    # the code is MISSING_API_KEY - confirm which code clients expect.
    raise APIKeyError("MISSING_API_KEY", "Authentification requise. Utilisez X-API-Key ou Authorization: Bearer.")
|
||||
|
||||
|
||||
async def require_api_key(
    x_api_key: str = Header(..., alias="X-API-Key"),
) -> Any:
    """Dependency that accepts API-key authentication only (no JWT fallback).

    Use this for endpoints that must be called with an API key, such as
    automation endpoints.

    Returns:
        The user object resolved from the API key; never ``None``.

    Raises:
        APIKeyError: ``MISSING_API_KEY`` when the header is present but empty,
            ``INVALID_API_KEY`` / ``API_KEY_REVOKED`` / ``API_KEY_EXPIRED``
            when the key cannot be resolved to a user.
    """
    if not x_api_key:
        # An empty header satisfies the "required" check but is not a key;
        # without this guard the lookup returns None and the endpoint would
        # run without a user.
        raise APIKeyError(APIKeyError.MISSING_API_KEY)

    user = await get_user_from_api_key(x_api_key)
    if not user:
        # Never hand None to an endpoint that relies on an authenticated user.
        raise APIKeyError(APIKeyError.INVALID_API_KEY)
    return user
|
||||
@@ -2,6 +2,7 @@
|
||||
Cleanup and Resource Management for SaaS robustness
|
||||
Automatic cleanup of temporary files and resources
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
@@ -10,77 +11,84 @@ from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Set
|
||||
import logging
|
||||
import json
|
||||
from services.storage_tracker import _get_async_redis, KEY_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
_HAS_STRUCTLOG = True
|
||||
except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
_HAS_STRUCTLOG = False
|
||||
|
||||
|
||||
class FileCleanupManager:
|
||||
"""Manages automatic cleanup of temporary and output files"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upload_dir: Path,
|
||||
output_dir: Path,
|
||||
temp_dir: Path,
|
||||
max_file_age_hours: int = 1,
|
||||
cleanup_interval_minutes: int = 10,
|
||||
max_total_size_gb: float = 10.0
|
||||
max_file_age_minutes: int = 60,
|
||||
cleanup_interval_minutes: int = 5,
|
||||
max_total_size_gb: float = 10.0,
|
||||
):
|
||||
self.upload_dir = Path(upload_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
self.temp_dir = Path(temp_dir)
|
||||
self.max_file_age_seconds = max_file_age_hours * 3600
|
||||
self.max_file_age_seconds = max_file_age_minutes * 60
|
||||
self.cleanup_interval = cleanup_interval_minutes * 60
|
||||
self.max_total_size_bytes = int(max_total_size_gb * 1024 * 1024 * 1024)
|
||||
|
||||
|
||||
self._running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._protected_files: Set[str] = set()
|
||||
self._tracked_files: dict = {} # filepath -> {created, ttl_minutes}
|
||||
self._lock = threading.Lock()
|
||||
self._stats = {
|
||||
"files_cleaned": 0,
|
||||
"bytes_freed": 0,
|
||||
"cleanup_runs": 0
|
||||
}
|
||||
|
||||
self._stats = {"files_cleaned": 0, "bytes_freed": 0, "cleanup_runs": 0}
|
||||
|
||||
async def track_file(self, filepath: Path, ttl_minutes: int = 60):
    """Register ``filepath`` for automatic cleanup once its TTL has elapsed.

    Args:
        filepath: File to track; stored under ``str(filepath)``.
        ttl_minutes: Lifetime in minutes, counted from now.
    """
    # Read the clock once so expires_at is exactly created + TTL.
    created = time.time()
    entry = {
        "created": created,
        "ttl_minutes": ttl_minutes,
        "expires_at": created + (ttl_minutes * 60),
    }
    with self._lock:
        self._tracked_files[str(filepath)] = entry
|
||||
|
||||
|
||||
def get_tracked_files(self) -> list:
    """Describe every tracked file and how long it has left.

    Returns:
        One dict per tracked file with the keys ``path``, ``exists``,
        ``expires_in_seconds`` (clamped at 0) and ``ttl_minutes``.
    """
    reference_time = time.time()

    with self._lock:
        return [
            {
                "path": tracked_path,
                "exists": Path(tracked_path).exists(),
                "expires_in_seconds": max(
                    0, int(meta["expires_at"] - reference_time)
                ),
                "ttl_minutes": meta["ttl_minutes"],
            }
            for tracked_path, meta in self._tracked_files.items()
        ]
|
||||
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""Cleanup expired tracked files"""
|
||||
now = time.time()
|
||||
cleaned = 0
|
||||
to_remove = []
|
||||
|
||||
|
||||
with self._lock:
|
||||
for filepath, info in list(self._tracked_files.items()):
|
||||
if now > info["expires_at"]:
|
||||
to_remove.append(filepath)
|
||||
|
||||
|
||||
for filepath in to_remove:
|
||||
try:
|
||||
path = Path(filepath)
|
||||
@@ -91,55 +99,57 @@ class FileCleanupManager:
|
||||
self._stats["files_cleaned"] += 1
|
||||
self._stats["bytes_freed"] += size
|
||||
logger.info(f"Cleaned expired file: {filepath}")
|
||||
|
||||
|
||||
with self._lock:
|
||||
self._tracked_files.pop(filepath, None)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean expired file {filepath}: {e}")
|
||||
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def get_stats(self) -> dict:
    """Summarise cleanup activity together with current disk usage.

    Returns:
        Dict with cumulative counters (files cleaned, MB freed, runs), the
        number of tracked files, the ``get_disk_usage()`` report and whether
        the background task is flagged as running.
    """
    usage_report = self.get_disk_usage()

    with self._lock:
        tracked_total = len(self._tracked_files)

    freed_mb = round(self._stats["bytes_freed"] / (1024 * 1024), 2)
    return {
        "files_cleaned_total": self._stats["files_cleaned"],
        "bytes_freed_total_mb": freed_mb,
        "cleanup_runs": self._stats["cleanup_runs"],
        "tracked_files": tracked_total,
        "disk_usage": usage_report,
        "is_running": self._running,
    }
|
||||
|
||||
|
||||
def protect_file(self, filepath: Path):
    """Flag ``filepath`` as in use by adding it to the protected set."""
    key = str(filepath)
    with self._lock:
        self._protected_files.add(key)
|
||||
|
||||
|
||||
def unprotect_file(self, filepath: Path):
    """Drop ``filepath`` from the protected set (no-op if it is not there)."""
    key = str(filepath)
    with self._lock:
        self._protected_files.discard(key)
|
||||
|
||||
|
||||
def is_protected(self, filepath: Path) -> bool:
    """Return ``True`` when ``filepath`` is currently in the protected set."""
    key = str(filepath)
    with self._lock:
        found = key in self._protected_files
    return found
|
||||
|
||||
|
||||
async def start(self):
    """Launch the periodic cleanup task unless it is already running."""
    if not self._running:
        self._running = True
        # Keep a handle on the task in self._task.
        self._task = asyncio.create_task(self._cleanup_loop())
        logger.info("File cleanup manager started")
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the cleanup background task"""
|
||||
self._running = False
|
||||
@@ -150,7 +160,7 @@ class FileCleanupManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("File cleanup manager stopped")
|
||||
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background loop for periodic cleanup"""
|
||||
while self._running:
|
||||
@@ -160,88 +170,146 @@ class FileCleanupManager:
|
||||
self._stats["cleanup_runs"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup error: {e}")
|
||||
|
||||
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
|
||||
|
||||
async def cleanup(self) -> dict:
|
||||
"""Perform cleanup of old files"""
|
||||
"""Perform cleanup of old files and orphans"""
|
||||
stats = {
|
||||
"files_deleted": 0,
|
||||
"bytes_freed": 0,
|
||||
"errors": []
|
||||
"orphaned_deleted": 0,
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
|
||||
now = time.time()
|
||||
|
||||
# Cleanup each directory
|
||||
|
||||
# Get tracked paths from Redis to identify orphans
|
||||
tracked_paths = set()
|
||||
redis_client = _get_async_redis()
|
||||
redis_available = redis_client is not None
|
||||
if redis_client:
|
||||
try:
|
||||
keys = await redis_client.keys(f"{KEY_PREFIX}:*")
|
||||
for key in keys:
|
||||
data = await redis_client.get(key)
|
||||
if data:
|
||||
metadata = json.loads(data)
|
||||
if "file_path" in metadata:
|
||||
# Normalize path to absolute string for comparison
|
||||
path_str = str(Path(metadata["file_path"]).absolute())
|
||||
tracked_paths.add(path_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch tracked paths from Redis: {e}")
|
||||
redis_available = False
|
||||
else:
|
||||
logger.warning(
|
||||
"Redis unavailable - orphan detection disabled, using age-based cleanup only"
|
||||
)
|
||||
|
||||
# Cleanup each directory (collect files first to avoid race condition)
|
||||
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||
if not directory.exists():
|
||||
continue
|
||||
|
||||
for filepath in directory.iterdir():
|
||||
|
||||
try:
|
||||
files_to_check = list(directory.iterdir())
|
||||
except OSError as e:
|
||||
logger.warning(f"Failed to list directory {directory}: {e}")
|
||||
continue
|
||||
|
||||
for filepath in files_to_check:
|
||||
if not filepath.is_file():
|
||||
continue
|
||||
|
||||
|
||||
# Skip protected files
|
||||
if self.is_protected(filepath):
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# Check if it's an orphan (only if Redis is available)
|
||||
abs_path = str(filepath.absolute())
|
||||
is_orphan = redis_available and abs_path not in tracked_paths
|
||||
|
||||
# Check file age
|
||||
file_age = now - filepath.stat().st_mtime
|
||||
|
||||
if file_age > self.max_file_age_seconds:
|
||||
|
||||
should_delete = False
|
||||
reason = ""
|
||||
|
||||
if is_orphan:
|
||||
should_delete = True
|
||||
reason = "orphan"
|
||||
elif file_age > self.max_file_age_seconds:
|
||||
should_delete = True
|
||||
reason = "expired"
|
||||
|
||||
if should_delete:
|
||||
file_size = filepath.stat().st_size
|
||||
filepath.unlink()
|
||||
stats["files_deleted"] += 1
|
||||
stats["bytes_freed"] += file_size
|
||||
logger.debug(f"Deleted old file: {filepath}")
|
||||
|
||||
if reason == "orphan":
|
||||
stats["orphaned_deleted"] += 1
|
||||
logger.info(f"Deleted {reason} file: {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
stats["errors"].append(str(e))
|
||||
logger.warning(f"Failed to delete {filepath}: {e}")
|
||||
|
||||
|
||||
# Force cleanup if total size exceeds limit
|
||||
await self._enforce_size_limit(stats)
|
||||
|
||||
if stats["files_deleted"] > 0:
|
||||
mb_freed = stats["bytes_freed"] / (1024 * 1024)
|
||||
logger.info(f"Cleanup: deleted {stats['files_deleted']} files, freed {mb_freed:.2f}MB")
|
||||
|
||||
|
||||
mb_freed = stats["bytes_freed"] / (1024 * 1024)
|
||||
cleanup_timestamp = datetime.now().isoformat()
|
||||
|
||||
# Structured logging (AC: #5)
|
||||
log_data = {
|
||||
"files_deleted": stats["files_deleted"],
|
||||
"bytes_freed_mb": round(mb_freed, 2),
|
||||
"orphaned_deleted": stats["orphaned_deleted"],
|
||||
"cleanup_run_timestamp": cleanup_timestamp,
|
||||
}
|
||||
|
||||
if _HAS_STRUCTLOG:
|
||||
logger.info("cleanup_completed", **log_data)
|
||||
else:
|
||||
logger.info(f"Cleanup completed: {log_data}")
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def _enforce_size_limit(self, stats: dict):
|
||||
"""Delete oldest files if total size exceeds limit"""
|
||||
files_with_mtime = []
|
||||
total_size = 0
|
||||
|
||||
|
||||
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||
if not directory.exists():
|
||||
continue
|
||||
|
||||
|
||||
for filepath in directory.iterdir():
|
||||
if not filepath.is_file() or self.is_protected(filepath):
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
stat = filepath.stat()
|
||||
files_with_mtime.append((filepath, stat.st_mtime, stat.st_size))
|
||||
total_size += stat.st_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# If under limit, nothing to do
|
||||
if total_size <= self.max_total_size_bytes:
|
||||
return
|
||||
|
||||
|
||||
# Sort by modification time (oldest first)
|
||||
files_with_mtime.sort(key=lambda x: x[1])
|
||||
|
||||
|
||||
# Delete oldest files until under limit
|
||||
for filepath, _, size in files_with_mtime:
|
||||
if total_size <= self.max_total_size_bytes:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
filepath.unlink()
|
||||
total_size -= size
|
||||
@@ -250,16 +318,16 @@ class FileCleanupManager:
|
||||
logger.info(f"Deleted file to free space: {filepath}")
|
||||
except Exception as e:
|
||||
stats["errors"].append(str(e))
|
||||
|
||||
|
||||
def get_disk_usage(self) -> dict:
|
||||
"""Get current disk usage statistics"""
|
||||
total_files = 0
|
||||
total_size = 0
|
||||
|
||||
|
||||
for directory in [self.upload_dir, self.output_dir, self.temp_dir]:
|
||||
if not directory.exists():
|
||||
continue
|
||||
|
||||
|
||||
for filepath in directory.iterdir():
|
||||
if filepath.is_file():
|
||||
total_files += 1
|
||||
@@ -267,55 +335,60 @@ class FileCleanupManager:
|
||||
total_size += filepath.stat().st_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
return {
|
||||
"total_files": total_files,
|
||||
"total_size_mb": round(total_size / (1024 * 1024), 2),
|
||||
"max_size_gb": self.max_total_size_bytes / (1024 * 1024 * 1024),
|
||||
"usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1) if self.max_total_size_bytes > 0 else 0,
|
||||
"usage_percent": round((total_size / self.max_total_size_bytes) * 100, 1)
|
||||
if self.max_total_size_bytes > 0
|
||||
else 0,
|
||||
"directories": {
|
||||
"uploads": str(self.upload_dir),
|
||||
"outputs": str(self.output_dir),
|
||||
"temp": str(self.temp_dir)
|
||||
}
|
||||
"temp": str(self.temp_dir),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
"""Monitors memory usage and triggers cleanup if needed"""
|
||||
|
||||
|
||||
def __init__(self, max_memory_percent: float = 80.0):
|
||||
self.max_memory_percent = max_memory_percent
|
||||
self._high_memory_callbacks = []
|
||||
|
||||
|
||||
def get_memory_usage(self) -> dict:
    """Report process and system memory figures via ``psutil``.

    Returns:
        Dict with ``process_rss_mb``, ``process_vms_mb``, ``system_total_gb``,
        ``system_available_gb`` and ``system_percent``; or ``{"error": ...}``
        when ``psutil`` is missing or the query fails.
    """
    mib = 1024 * 1024
    gib = 1024 * 1024 * 1024

    try:
        # Imported here so a missing psutil yields an error dict, not a crash.
        import psutil

        proc_mem = psutil.Process().memory_info()
        sys_mem = psutil.virtual_memory()

        return {
            "process_rss_mb": round(proc_mem.rss / mib, 2),
            "process_vms_mb": round(proc_mem.vms / mib, 2),
            "system_total_gb": round(sys_mem.total / gib, 2),
            "system_available_gb": round(sys_mem.available / gib, 2),
            "system_percent": sys_mem.percent,
        }
    except ImportError:
        return {"error": "psutil not installed"}
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
def check_memory(self) -> bool:
    """Tell whether system memory usage is below ``max_memory_percent``.

    Returns ``True`` as well when usage cannot be measured (the report
    contains an ``"error"`` key), i.e. the check fails open.
    """
    snapshot = self.get_memory_usage()
    if "error" not in snapshot:
        return snapshot.get("system_percent", 0) < self.max_memory_percent
    # Can't measure: assume memory is fine.
    return True
|
||||
|
||||
|
||||
def on_high_memory(self, callback):
|
||||
"""Register callback for high memory situations"""
|
||||
self._high_memory_callbacks.append(callback)
|
||||
@@ -323,67 +396,75 @@ class MemoryMonitor:
|
||||
|
||||
class HealthChecker:
|
||||
"""Comprehensive health checking for the application"""
|
||||
|
||||
def __init__(self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor):
|
||||
|
||||
def __init__(
|
||||
self, cleanup_manager: FileCleanupManager, memory_monitor: MemoryMonitor
|
||||
):
|
||||
self.cleanup_manager = cleanup_manager
|
||||
self.memory_monitor = memory_monitor
|
||||
self.start_time = datetime.now()
|
||||
self._translation_count = 0
|
||||
self._error_count = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
|
||||
def record_translation(self, success: bool = True):
    """Count one translation attempt, and one error when it did not succeed."""
    failed = not success
    with self._lock:
        self._translation_count += 1
        if failed:
            self._error_count += 1
|
||||
|
||||
|
||||
async def check_health(self) -> dict:
    """Async wrapper returning the same report as ``get_health()``."""
    report = self.get_health()
    return report
|
||||
|
||||
|
||||
def get_health(self) -> dict:
    """Build the full health report for the application.

    The status is ``"degraded"`` when system memory or disk usage exceeds
    90 %; usage above 80 % only adds an entry to ``issues``.

    Returns:
        Dict with ``status``, ``issues``, uptime (seconds and human
        readable), translation counters with success rate, the memory and
        disk reports, cleanup statistics and an ISO timestamp.
    """
    memory = self.memory_monitor.get_memory_usage()
    disk = self.cleanup_manager.get_disk_usage()

    # Snapshot both counters under the lock record_translation() holds, so
    # total and errors always belong to the same moment.
    with self._lock:
        total = self._translation_count
        errors = self._error_count

    status = "healthy"
    issues = []

    # Memory is only judged when it could actually be measured.
    if "error" not in memory:
        memory_percent = memory.get("system_percent", 0)
        if memory_percent > 90:
            status = "degraded"
            issues.append("High memory usage")
        elif memory_percent > 80:
            issues.append("Memory usage elevated")

    disk_percent = disk.get("usage_percent", 0)
    if disk_percent > 90:
        status = "degraded"
        issues.append("High disk usage")
    elif disk_percent > 80:
        issues.append("Disk usage elevated")

    uptime = datetime.now() - self.start_time

    # With no attempts recorded yet the success rate is reported as 100.
    success_rate = round(
        ((total - errors) / total * 100) if total > 0 else 100,
        1,
    )

    return {
        "status": status,
        "issues": issues,
        "uptime_seconds": int(uptime.total_seconds()),
        # Drop the microseconds part of the timedelta representation.
        "uptime_human": str(uptime).split(".")[0],
        "translations": {
            "total": total,
            "errors": errors,
            "success_rate": success_rate,
        },
        "memory": memory,
        "disk": disk,
        "cleanup_service": self.cleanup_manager.get_stats(),
        "timestamp": datetime.now().isoformat(),
    }
|
||||
|
||||
|
||||
@@ -394,7 +475,7 @@ def create_cleanup_manager(config) -> FileCleanupManager:
|
||||
upload_dir=config.UPLOAD_DIR,
|
||||
output_dir=config.OUTPUT_DIR,
|
||||
temp_dir=config.TEMP_DIR,
|
||||
max_file_age_hours=getattr(config, 'MAX_FILE_AGE_HOURS', 1),
|
||||
cleanup_interval_minutes=getattr(config, 'CLEANUP_INTERVAL_MINUTES', 10),
|
||||
max_total_size_gb=getattr(config, 'MAX_TOTAL_SIZE_GB', 10.0)
|
||||
max_file_age_minutes=getattr(config, "FILE_TTL_MINUTES", 60),
|
||||
cleanup_interval_minutes=getattr(config, "CLEANUP_INTERVAL_MINUTES", 5),
|
||||
max_total_size_gb=getattr(config, "MAX_TOTAL_SIZE_GB", 10.0),
|
||||
)
|
||||
|
||||
107
middleware/error_handler.py
Normal file
107
middleware/error_handler.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Global Error Handling Middleware
|
||||
Catches all unhandled exceptions and standardizes API error responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
from fastapi import HTTPException
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
# Import APIKeyError for handling
|
||||
from middleware.api_key_auth import APIKeyError
|
||||
|
||||
try:
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_error_response(
    status_code: int,
    message: str,
    error_code: str = None,
    details: dict = None,
    request_id: str = "unknown",
    headers: dict = None,
) -> JSONResponse:
    """Build an error response in the standard API format.

    Body format: ``{"error": "CODE", "message": "...", "details": {...}}``.

    Args:
        status_code: HTTP status of the response.
        message: Human-readable message.
        error_code: Error code; derived from ``status_code`` when falsy.
        details: Optional extra details. The dict is copied, never mutated.
        request_id: Added to ``details`` unless it already has a
            ``request_id`` key.
        headers: Optional HTTP headers for the response.

    Returns:
        The ``JSONResponse`` carrying the formatted error.
    """
    if not error_code:
        error_code = _map_http_status_to_code(status_code)

    # Work on a copy: injecting request_id must not leak into the caller's dict.
    payload_details = dict(details) if details else {}
    payload_details.setdefault("request_id", request_id)

    content = {"error": error_code, "message": message, "details": payload_details}
    return JSONResponse(status_code=status_code, content=content, headers=headers)
|
||||
|
||||
|
||||
def _map_http_status_to_code(status_code: int) -> str:
|
||||
"""Map HTTP status codes to architectural error codes."""
|
||||
mapping = {
|
||||
400: "INVALID_FORMAT",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
405: "METHOD_NOT_ALLOWED",
|
||||
413: "FILE_TOO_LARGE",
|
||||
422: "VALIDATION_ERROR",
|
||||
429: "QUOTA_EXCEEDED",
|
||||
502: "PROVIDER_ERROR",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
}
|
||||
return mapping.get(status_code, "INTERNAL_ERROR")
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert exceptions that reach the middleware into standard responses.

    Note: ``HTTPException`` is often handled by FastAPI's own handlers before
    it reaches this middleware; the branch below covers those that leak
    through.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream app and convert any raised exception.

        - ``APIKeyError``: returned through its own ``to_response()``.
        - ``HTTPException`` (FastAPI/Starlette): reformatted, keeping its
          status code and headers.
        - Anything else: logged with traceback and answered with a generic
          500 so internals are not exposed.
        """
        try:
            return await call_next(request)
        except APIKeyError as e:
            request_id = getattr(request.state, "request_id", "unknown")
            logger.info(f"[{request_id}] API Key authentication error: {e.code}")
            return e.to_response()
        except Exception as e:
            request_id = getattr(request.state, "request_id", "unknown")

            if isinstance(e, (HTTPException, StarletteHTTPException)):
                detail = e.detail if hasattr(e, "detail") and e.detail else {}
                # Forward the exception's headers instead of dropping them.
                exc_headers = getattr(e, "headers", None)
                if isinstance(detail, dict):
                    return format_error_response(
                        status_code=e.status_code,
                        message=detail.get("message", "Une erreur s'est produite."),
                        error_code=detail.get("error"),
                        request_id=request_id,
                        headers=exc_headers,
                    )
                return format_error_response(
                    status_code=e.status_code,
                    message=str(detail) if detail else "Une erreur s'est produite.",
                    request_id=request_id,
                    headers=exc_headers,
                )

            # Log the full stack trace for internal debugging.
            logger.exception(f"[{request_id}] Unhandled internal exception: {str(e)}")

            # Generic French message for the user (AC4, AC5).
            return format_error_response(
                status_code=500,
                message="Une erreur inattendue s'est produite. Veuillez réessayer plus tard.",
                error_code="INTERNAL_ERROR",
                request_id=request_id,
            )
|
||||
@@ -116,27 +116,3 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Turn any unhandled exception into a generic JSON 500 response."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream app; on any exception log it and answer 500."""
        from starlette.responses import JSONResponse

        try:
            response = await call_next(request)
        except Exception as e:
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.exception(f"[{request_id}] Unhandled exception: {str(e)}")

            # Generic payload only: internal details must not reach the client.
            payload = {
                "error": "internal_server_error",
                "message": "An unexpected error occurred. Please try again later.",
                "request_id": request_id,
            }
            return JSONResponse(status_code=500, content=payload)

        return response
|
||||
|
||||
180
middleware/tier_quota.py
Normal file
180
middleware/tier_quota.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Tier-based daily translation quota (Story 1.6).
|
||||
Uses a Redis daily counter per user (fixed window: one key per UTC date); falls back to in-memory when Redis is unavailable.
|
||||
Coexists with IP-based rate limiting in rate_limiting.py.
|
||||
|
||||
Source of truth: Redis (key per user per UTC date) is the authority for quota enforcement.
|
||||
User.daily_translation_count in DB is kept in sync on each successful translation for
|
||||
reporting/analytics; reset at midnight UTC is automatic in Redis (new key per day). DB
|
||||
reset can be done by a scheduled job at midnight UTC if needed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Free tier: 5 translations per day (UTC). Pro (and equivalent) tiers: no daily cap.
|
||||
FREE_TIER_DAILY_LIMIT = 5
|
||||
KEY_PREFIX = "rate_limit:daily"
|
||||
|
||||
|
||||
def _utc_date_str(dt: Optional[datetime] = None) -> str:
|
||||
"""Current date in UTC as YYYY-MM-DD."""
|
||||
t = dt or datetime.now(timezone.utc)
|
||||
return t.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def _next_midnight_utc(dt: Optional[datetime] = None) -> datetime:
|
||||
"""Next midnight UTC after the given time (or now)."""
|
||||
now = dt or datetime.now(timezone.utc)
|
||||
tomorrow = (now.date() + timedelta(days=1))
|
||||
return datetime(tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _seconds_until_midnight_utc(dt: Optional[datetime] = None) -> int:
    """Return whole seconds from *dt* (default: now) to the next midnight UTC.

    Args:
        dt: Reference moment. Naive values are treated as UTC; aware values
            are converted to UTC.

    Returns:
        A non-negative number of seconds, truncated toward zero.
    """
    now = dt if dt is not None else datetime.now(timezone.utc)
    if now.utcoffset() is None:
        # _next_midnight_utc returns an aware value; subtracting a naive
        # datetime from it would raise TypeError.
        now = now.replace(tzinfo=timezone.utc)
    else:
        now = now.astimezone(timezone.utc)
    remaining = _next_midnight_utc(now) - now
    return max(0, int(remaining.total_seconds()))
|
||||
|
||||
|
||||
@dataclass
class QuotaResult:
    """Outcome of one daily-quota lookup for a user."""

    # True when one more translation may be performed.
    allowed: bool
    # Translations left today; -1 for pro (unlimited).
    remaining: int
    # Moment the daily counter rolls over (check_quota passes the next midnight UTC).
    reset_at_utc: datetime
    # Translations already counted today.
    current_usage: int = 0
    # Daily cap that was applied; check_quota passes 0 for uncapped tiers.
    limit: int = FREE_TIER_DAILY_LIMIT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Redis backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_async_redis = None
|
||||
|
||||
|
||||
def _get_async_redis():
    """Return the process-wide async Redis client, or None when unavailable.

    The result is memoised in the module-level ``_async_redis``:
    ``None`` means "not resolved yet", ``False`` means "resolved, no client"
    (so the lookup is not retried), and any other value is the client itself.
    The URL is read from the ``REDIS_URL`` environment variable.
    """
    global _async_redis

    # Already resolved on an earlier call.
    if _async_redis is False:
        return None
    if _async_redis is not None:
        return _async_redis

    redis_url = os.getenv("REDIS_URL", "").strip()
    if not redis_url:
        _async_redis = False
        return None

    try:
        import redis.asyncio as redis_asyncio

        client = redis_asyncio.Redis.from_url(redis_url, decode_responses=True)
        _async_redis = client
        logger.info("Tier quota: using Redis for daily quota")
        return client
    except Exception as exc:
        # Covers a missing redis package as well as a URL from_url rejects.
        logger.warning("Tier quota: Redis unavailable (%s), using in-memory fallback", exc)
        _async_redis = False
        return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory fallback (per process; not shared across workers). Documented as fallback.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_memory_usage: dict[tuple[str, str], int] = {} # (user_id, date_utc_str) -> count
|
||||
|
||||
|
||||
def _memory_get(user_id: str, date_str: str) -> int:
|
||||
return _memory_usage.get((user_id, date_str), 0)
|
||||
|
||||
|
||||
def _memory_incr(user_id: str, date_str: str) -> int:
|
||||
key = (user_id, date_str)
|
||||
_memory_usage[key] = _memory_usage.get(key, 0) + 1
|
||||
return _memory_usage[key]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TierQuotaService
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TierQuotaService:
    """
    Daily translation quota per user by tier.

    Redis key pattern: rate_limit:daily:{user_id}:{YYYY-MM-DD}, TTL 25h.
    When no Redis client is available, or a Redis call raises, the
    per-process in-memory counters are used instead.

    NOTE(review): check_quota() and increment_on_success() are separate
    steps, so concurrent requests from one user can each pass the check
    before either increments. Confirm whether stricter enforcement is needed.
    """

    def __init__(self):
        self._redis = None  # Lazy init on first use

    def _redis_client(self):
        """Return the shared async Redis client, or None when unavailable."""
        if self._redis is None:
            self._redis = _get_async_redis()
        return self._redis

    def _date_str(self, dt: Optional[datetime] = None) -> str:
        """Return the UTC date string used in quota keys (default: today)."""
        return _utc_date_str(dt)

    @staticmethod
    def _redis_key(user_id: str, date_str: str) -> str:
        """Build the per-user, per-day Redis key."""
        return f"{KEY_PREFIX}:{user_id}:{date_str}"

    async def check_quota(self, user_id: str, tier: str) -> QuotaResult:
        """
        Check if user has quota for one more translation today (UTC).

        Args:
            user_id: Identifier used in the counter key.
            tier: Plan name, compared case-insensitively. "pro", "business",
                "enterprise" and "starter" are uncapped; any other value
                (including empty/None) is treated as the free tier.

        Returns:
            QuotaResult describing whether another translation is allowed.
            This method only reads the counter; it never increments it.
        """
        # Read the clock once so the date key and the reset time always
        # describe the same UTC day, even for a call that straddles midnight.
        now = datetime.now(timezone.utc)
        reset_at = _next_midnight_utc(now)
        tier_lower = (tier or "free").lower()
        if tier_lower in ("pro", "business", "enterprise", "starter"):
            return QuotaResult(
                allowed=True,
                remaining=-1,
                reset_at_utc=reset_at,
                current_usage=0,
                limit=0,
            )
        # Free tier
        date_str = self._date_str(now)
        redis_client = self._redis_client()
        if redis_client:
            try:
                count = await redis_client.get(self._redis_key(user_id, date_str))
                count = int(count or 0)  # missing key -> None -> 0
            except Exception as e:
                logger.warning("Tier quota Redis get failed: %s, using in-memory", e)
                count = _memory_get(user_id, date_str)
        else:
            count = _memory_get(user_id, date_str)
        remaining = max(0, FREE_TIER_DAILY_LIMIT - count)
        return QuotaResult(
            allowed=count < FREE_TIER_DAILY_LIMIT,
            remaining=remaining,
            reset_at_utc=reset_at,
            current_usage=count,
            limit=FREE_TIER_DAILY_LIMIT,
        )

    async def increment_on_success(self, user_id: str) -> None:
        """Increment daily translation count for user (call after successful translation)."""
        date_str = self._date_str()
        redis_client = self._redis_client()
        if redis_client:
            try:
                key = self._redis_key(user_id, date_str)
                pipe = redis_client.pipeline()
                pipe.incr(key)
                pipe.expire(key, 25 * 3600)  # 25h so key expires after midnight UTC
                await pipe.execute()
                return
            except Exception as e:
                logger.warning("Tier quota Redis increment failed: %s, using in-memory", e)
        # Reached when there is no client or the pipeline raised.
        _memory_incr(user_id, date_str)

    def seconds_until_reset(self) -> int:
        """Seconds until next midnight UTC (for Retry-After header)."""
        return _seconds_until_midnight_utc()
|
||||
|
||||
|
||||
# Singleton for app use
|
||||
tier_quota_service = TierQuotaService()
|
||||
@@ -2,10 +2,14 @@
|
||||
Input Validation Module for SaaS robustness
|
||||
Validates all user inputs before processing
|
||||
"""
|
||||
|
||||
import re
|
||||
import magic
|
||||
import ipaddress
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Set
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, List, Set, Tuple
|
||||
from fastapi import UploadFile, HTTPException
|
||||
import logging
|
||||
|
||||
@@ -14,7 +18,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""Custom validation error with user-friendly messages"""
|
||||
def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: str = "validation_error",
|
||||
details: Optional[dict] = None,
|
||||
):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.details = details or {}
|
||||
@@ -23,37 +33,46 @@ class ValidationError(Exception):
|
||||
|
||||
class ValidationResult:
    """Result of a validation check.

    Only the post-change constructor is kept; the superseded one-line
    signature that appeared beside it is dropped.
    """

    def __init__(
        self,
        is_valid: bool = True,
        errors: Optional[List[str]] = None,
        warnings: Optional[List[str]] = None,
        data: Optional[dict] = None,
        error_code: Optional[str] = None,
    ):
        """Store the outcome of a validation run.

        Args:
            is_valid: Whether validation passed.
            errors: Error messages; a falsy value becomes a new empty list.
            warnings: Warning messages; a falsy value becomes a new empty list.
            data: Extra metadata; a falsy value becomes a new empty dict.
            error_code: Optional code identifying the failure.
        """
        self.is_valid = is_valid
        # `or` gives every instance its own fresh container when none is supplied.
        self.errors = errors or []
        self.warnings = warnings or []
        self.data = data or {}
        self.error_code = error_code
|
||||
|
||||
|
||||
class FileValidator:
|
||||
"""Validates uploaded files for security and compatibility"""
|
||||
|
||||
|
||||
# Allowed MIME types mapped to extensions
|
||||
ALLOWED_MIME_TYPES = {
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
}
|
||||
|
||||
|
||||
# Magic bytes for Office Open XML files (ZIP format)
|
||||
OFFICE_MAGIC_BYTES = b"PK\x03\x04"
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size_mb: int = 50,
|
||||
allowed_extensions: Set[str] = None,
|
||||
scan_content: bool = True
|
||||
allowed_extensions: Optional[Set[str]] = None,
|
||||
scan_content: bool = True,
|
||||
):
|
||||
self.max_size_bytes = max_size_mb * 1024 * 1024
|
||||
self.max_size_mb = max_size_mb
|
||||
self.allowed_extensions = allowed_extensions or {".xlsx", ".docx", ".pptx"}
|
||||
self.scan_content = scan_content
|
||||
|
||||
|
||||
async def validate_async(self, file: UploadFile) -> ValidationResult:
|
||||
"""
|
||||
Validate an uploaded file asynchronously
|
||||
@@ -62,77 +81,105 @@ class FileValidator:
|
||||
errors = []
|
||||
warnings = []
|
||||
data = {}
|
||||
|
||||
|
||||
try:
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
errors.append("Filename is required")
|
||||
return ValidationResult(is_valid=False, errors=errors)
|
||||
|
||||
errors.append("Le nom de fichier est requis")
|
||||
return ValidationResult(
|
||||
is_valid=False, errors=errors, error_code="missing_filename"
|
||||
)
|
||||
|
||||
# Sanitize filename
|
||||
try:
|
||||
safe_filename = self._sanitize_filename(file.filename)
|
||||
data["safe_filename"] = safe_filename
|
||||
except ValidationError as e:
|
||||
errors.append(str(e.message))
|
||||
return ValidationResult(is_valid=False, errors=errors)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=False, errors=errors, error_code=e.code
|
||||
)
|
||||
|
||||
# Validate extension
|
||||
try:
|
||||
extension = self._validate_extension(safe_filename)
|
||||
data["extension"] = extension
|
||||
except ValidationError as e:
|
||||
errors.append(str(e.message))
|
||||
return ValidationResult(is_valid=False, errors=errors)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=False, errors=errors, error_code=e.code
|
||||
)
|
||||
|
||||
# Read file content for validation
|
||||
content = await file.read()
|
||||
await file.seek(0) # Reset for later processing
|
||||
|
||||
|
||||
# Validate file size
|
||||
file_size = len(content)
|
||||
data["size_bytes"] = file_size
|
||||
data["size_mb"] = round(file_size / (1024*1024), 2)
|
||||
|
||||
data["size_mb"] = round(file_size / (1024 * 1024), 2)
|
||||
|
||||
if file_size > self.max_size_bytes:
|
||||
errors.append(f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB")
|
||||
return ValidationResult(is_valid=False, errors=errors, data=data)
|
||||
|
||||
errors.append(
|
||||
f"Fichier trop volumineux. La taille maximale est de {self.max_size_mb}Mo, "
|
||||
f"vous avez envoye {file_size / (1024 * 1024):.1f}Mo"
|
||||
)
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
errors=errors,
|
||||
data=data,
|
||||
error_code="file_too_large",
|
||||
)
|
||||
|
||||
if file_size == 0:
|
||||
errors.append("File is empty")
|
||||
return ValidationResult(is_valid=False, errors=errors, data=data)
|
||||
|
||||
errors.append("Le fichier est vide")
|
||||
return ValidationResult(
|
||||
is_valid=False, errors=errors, data=data, error_code="empty_file"
|
||||
)
|
||||
|
||||
# Warn about large files
|
||||
if file_size > self.max_size_bytes * 0.8:
|
||||
warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")
|
||||
|
||||
warnings.append(
|
||||
f"Le fichier fait {data['size_mb']}Mo, approchant la limite de {self.max_size_mb}Mo"
|
||||
)
|
||||
|
||||
# Validate magic bytes
|
||||
if self.scan_content:
|
||||
try:
|
||||
self._validate_magic_bytes(content, extension)
|
||||
except ValidationError as e:
|
||||
errors.append(str(e.message))
|
||||
return ValidationResult(is_valid=False, errors=errors, data=data)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=False, errors=errors, data=data, error_code=e.code
|
||||
)
|
||||
|
||||
# Validate MIME type
|
||||
try:
|
||||
mime_type = self._detect_mime_type(content)
|
||||
data["mime_type"] = mime_type
|
||||
self._validate_mime_type(mime_type, extension)
|
||||
except ValidationError as e:
|
||||
warnings.append(f"MIME type warning: {e.message}")
|
||||
warnings.append(f"Avertissement MIME: {e.message}")
|
||||
except Exception:
|
||||
warnings.append("Could not verify MIME type")
|
||||
|
||||
warnings.append("Impossible de verifier le type MIME")
|
||||
|
||||
data["original_filename"] = file.filename
|
||||
|
||||
return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)
|
||||
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=True, errors=errors, warnings=warnings, data=data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
errors.append(f"Validation failed: {str(e)}")
|
||||
return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)
|
||||
|
||||
errors.append(f"Erreur de validation: {str(e)}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
data=data,
|
||||
error_code="validation_error",
|
||||
)
|
||||
|
||||
async def validate(self, file: UploadFile) -> dict:
|
||||
"""
|
||||
Validate an uploaded file
|
||||
@@ -141,106 +188,107 @@ class FileValidator:
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise ValidationError(
|
||||
"Filename is required",
|
||||
code="missing_filename"
|
||||
"Le nom de fichier est requis", code="missing_filename"
|
||||
)
|
||||
|
||||
|
||||
# Sanitize filename
|
||||
safe_filename = self._sanitize_filename(file.filename)
|
||||
|
||||
|
||||
# Validate extension
|
||||
extension = self._validate_extension(safe_filename)
|
||||
|
||||
|
||||
# Read file content for validation
|
||||
content = await file.read()
|
||||
await file.seek(0) # Reset for later processing
|
||||
|
||||
|
||||
# Validate file size
|
||||
file_size = len(content)
|
||||
if file_size > self.max_size_bytes:
|
||||
raise ValidationError(
|
||||
f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / (1024*1024):.1f}MB",
|
||||
f"Fichier trop volumineux. La taille maximale est de {self.max_size_mb}Mo, "
|
||||
f"vous avez envoye {file_size / (1024 * 1024):.1f}Mo",
|
||||
code="file_too_large",
|
||||
details={"max_mb": self.max_size_mb, "actual_mb": round(file_size / (1024*1024), 2)}
|
||||
details={
|
||||
"max_mb": self.max_size_mb,
|
||||
"actual_mb": round(file_size / (1024 * 1024), 2),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if file_size == 0:
|
||||
raise ValidationError(
|
||||
"File is empty",
|
||||
code="empty_file"
|
||||
)
|
||||
|
||||
raise ValidationError("Le fichier est vide", code="empty_file")
|
||||
|
||||
# Validate magic bytes (file signature)
|
||||
if self.scan_content:
|
||||
self._validate_magic_bytes(content, extension)
|
||||
|
||||
|
||||
# Validate MIME type
|
||||
mime_type = self._detect_mime_type(content)
|
||||
self._validate_mime_type(mime_type, extension)
|
||||
|
||||
|
||||
return {
|
||||
"original_filename": file.filename,
|
||||
"safe_filename": safe_filename,
|
||||
"extension": extension,
|
||||
"size_bytes": file_size,
|
||||
"size_mb": round(file_size / (1024*1024), 2),
|
||||
"mime_type": mime_type
|
||||
"size_mb": round(file_size / (1024 * 1024), 2),
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_filename(self, filename: str) -> str:
|
||||
"""Sanitize filename to prevent path traversal and other attacks"""
|
||||
# Remove path components
|
||||
filename = Path(filename).name
|
||||
|
||||
|
||||
# Remove null bytes and control characters
|
||||
filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)
|
||||
|
||||
filename = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", filename)
|
||||
|
||||
# Remove potentially dangerous characters
|
||||
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
|
||||
|
||||
filename = re.sub(r'[<>:"/\\|?*]', "_", filename)
|
||||
|
||||
# Limit length
|
||||
if len(filename) > 255:
|
||||
name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
|
||||
filename = name[:250] + ('.' + ext if ext else '')
|
||||
|
||||
name, ext = filename.rsplit(".", 1) if "." in filename else (filename, "")
|
||||
filename = name[:250] + ("." + ext if ext else "")
|
||||
|
||||
# Ensure not empty after sanitization
|
||||
if not filename or filename.strip() == '':
|
||||
raise ValidationError(
|
||||
"Invalid filename",
|
||||
code="invalid_filename"
|
||||
)
|
||||
|
||||
if not filename or filename.strip() == "":
|
||||
raise ValidationError("Nom de fichier invalide", code="invalid_filename")
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def _validate_extension(self, filename: str) -> str:
|
||||
"""Validate and return the file extension"""
|
||||
if '.' not in filename:
|
||||
if "." not in filename:
|
||||
raise ValidationError(
|
||||
f"File must have an extension. Supported: {', '.join(self.allowed_extensions)}",
|
||||
f"Le fichier doit avoir une extension. Formats supportes : {', '.join(self.allowed_extensions)}",
|
||||
code="missing_extension",
|
||||
details={"allowed_extensions": list(self.allowed_extensions)}
|
||||
details={"allowed_extensions": list(self.allowed_extensions)},
|
||||
)
|
||||
|
||||
extension = '.' + filename.rsplit('.', 1)[1].lower()
|
||||
|
||||
|
||||
extension = "." + filename.rsplit(".", 1)[1].lower()
|
||||
|
||||
if extension not in self.allowed_extensions:
|
||||
raise ValidationError(
|
||||
f"File type '{extension}' not supported. Supported types: {', '.join(self.allowed_extensions)}",
|
||||
f"Format de fichier '{extension}' non supporte. Formats acceptes : {', '.join(self.allowed_extensions)}",
|
||||
code="unsupported_file_type",
|
||||
details={"extension": extension, "allowed_extensions": list(self.allowed_extensions)}
|
||||
details={
|
||||
"extension": extension,
|
||||
"allowed_extensions": list(self.allowed_extensions),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
return extension
|
||||
|
||||
|
||||
def _validate_magic_bytes(self, content: bytes, extension: str):
|
||||
"""Validate file magic bytes match expected format"""
|
||||
# All supported formats are Office Open XML (ZIP-based)
|
||||
if not content.startswith(self.OFFICE_MAGIC_BYTES):
|
||||
raise ValidationError(
|
||||
"File content does not match expected format. The file may be corrupted or not a valid Office document.",
|
||||
code="invalid_file_content"
|
||||
"Le contenu du fichier ne correspond pas au format attendu. "
|
||||
"Le fichier est peut-etre corrompu ou n'est pas un document Office valide.",
|
||||
code="invalid_file_content",
|
||||
)
|
||||
|
||||
|
||||
def _detect_mime_type(self, content: bytes) -> str:
|
||||
"""Detect MIME type from file content"""
|
||||
try:
|
||||
@@ -251,77 +299,198 @@ class FileValidator:
|
||||
if content.startswith(self.OFFICE_MAGIC_BYTES):
|
||||
return "application/zip"
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
def _validate_mime_type(self, mime_type: str, extension: str):
|
||||
"""Validate MIME type matches extension"""
|
||||
# Office Open XML files may be detected as ZIP
|
||||
allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + ["application/zip", "application/octet-stream"]
|
||||
|
||||
allowed_mimes = list(self.ALLOWED_MIME_TYPES.keys()) + [
|
||||
"application/zip",
|
||||
"application/octet-stream",
|
||||
]
|
||||
|
||||
if mime_type not in allowed_mimes:
|
||||
raise ValidationError(
|
||||
f"Invalid file type detected. Expected Office document, got: {mime_type}",
|
||||
f"Type de fichier invalide detecte. Document Office attendu, recu : {mime_type}",
|
||||
code="invalid_mime_type",
|
||||
details={"detected_mime": mime_type}
|
||||
details={"detected_mime": mime_type},
|
||||
)
|
||||
|
||||
|
||||
class LanguageValidator:
|
||||
"""Validates language codes"""
|
||||
|
||||
|
||||
SUPPORTED_LANGUAGES = {
|
||||
# ISO 639-1 codes
|
||||
"af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
|
||||
"bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
|
||||
"da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
|
||||
"de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
|
||||
"is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
|
||||
"rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
|
||||
"mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
|
||||
"ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
|
||||
"gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
|
||||
"su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
|
||||
"tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
|
||||
"zu", "auto"
|
||||
"af",
|
||||
"sq",
|
||||
"am",
|
||||
"ar",
|
||||
"hy",
|
||||
"az",
|
||||
"eu",
|
||||
"be",
|
||||
"bn",
|
||||
"bs",
|
||||
"bg",
|
||||
"ca",
|
||||
"ceb",
|
||||
"zh",
|
||||
"zh-CN",
|
||||
"zh-TW",
|
||||
"co",
|
||||
"hr",
|
||||
"cs",
|
||||
"da",
|
||||
"nl",
|
||||
"en",
|
||||
"eo",
|
||||
"et",
|
||||
"fi",
|
||||
"fr",
|
||||
"fy",
|
||||
"gl",
|
||||
"ka",
|
||||
"de",
|
||||
"el",
|
||||
"gu",
|
||||
"ht",
|
||||
"ha",
|
||||
"haw",
|
||||
"he",
|
||||
"hi",
|
||||
"hmn",
|
||||
"hu",
|
||||
"is",
|
||||
"ig",
|
||||
"id",
|
||||
"ga",
|
||||
"it",
|
||||
"ja",
|
||||
"jv",
|
||||
"kn",
|
||||
"kk",
|
||||
"km",
|
||||
"rw",
|
||||
"ko",
|
||||
"ku",
|
||||
"ky",
|
||||
"lo",
|
||||
"la",
|
||||
"lv",
|
||||
"lt",
|
||||
"lb",
|
||||
"mk",
|
||||
"mg",
|
||||
"ms",
|
||||
"ml",
|
||||
"mt",
|
||||
"mi",
|
||||
"mr",
|
||||
"mn",
|
||||
"my",
|
||||
"ne",
|
||||
"no",
|
||||
"ny",
|
||||
"or",
|
||||
"ps",
|
||||
"fa",
|
||||
"pl",
|
||||
"pt",
|
||||
"pa",
|
||||
"ro",
|
||||
"ru",
|
||||
"sm",
|
||||
"gd",
|
||||
"sr",
|
||||
"st",
|
||||
"sn",
|
||||
"sd",
|
||||
"si",
|
||||
"sk",
|
||||
"sl",
|
||||
"so",
|
||||
"es",
|
||||
"su",
|
||||
"sw",
|
||||
"sv",
|
||||
"tl",
|
||||
"tg",
|
||||
"ta",
|
||||
"tt",
|
||||
"te",
|
||||
"th",
|
||||
"tr",
|
||||
"tk",
|
||||
"uk",
|
||||
"ur",
|
||||
"ug",
|
||||
"uz",
|
||||
"vi",
|
||||
"cy",
|
||||
"xh",
|
||||
"yi",
|
||||
"yo",
|
||||
"zu",
|
||||
"auto",
|
||||
}
|
||||
|
||||
|
||||
LANGUAGE_NAMES = {
|
||||
"en": "English", "es": "Spanish", "fr": "French", "de": "German",
|
||||
"it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
|
||||
"zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
|
||||
"ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
|
||||
"nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
|
||||
"da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
|
||||
"el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
|
||||
"uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
|
||||
"en": "English",
|
||||
"es": "Spanish",
|
||||
"fr": "French",
|
||||
"de": "German",
|
||||
"it": "Italian",
|
||||
"pt": "Portuguese",
|
||||
"ru": "Russian",
|
||||
"zh": "Chinese",
|
||||
"zh-CN": "Chinese (Simplified)",
|
||||
"zh-TW": "Chinese (Traditional)",
|
||||
"ja": "Japanese",
|
||||
"ko": "Korean",
|
||||
"ar": "Arabic",
|
||||
"hi": "Hindi",
|
||||
"nl": "Dutch",
|
||||
"pl": "Polish",
|
||||
"tr": "Turkish",
|
||||
"sv": "Swedish",
|
||||
"da": "Danish",
|
||||
"no": "Norwegian",
|
||||
"fi": "Finnish",
|
||||
"cs": "Czech",
|
||||
"el": "Greek",
|
||||
"th": "Thai",
|
||||
"vi": "Vietnamese",
|
||||
"id": "Indonesian",
|
||||
"uk": "Ukrainian",
|
||||
"ro": "Romanian",
|
||||
"hu": "Hungarian",
|
||||
"auto": "Auto-detect",
|
||||
}
|
||||
|
||||
|
||||
@classmethod
def validate(cls, language_code: str, field_name: str = "language") -> str:
    """Validate a language code and return its canonical spelling.

    Matching is case-insensitive. The returned value is the spelling stored
    in ``SUPPORTED_LANGUAGES`` (e.g. ``"zh-CN"``), so region-qualified codes
    are accepted however the caller cased them. Previously the input was
    lower-cased and compared directly against the mixed-case set, which
    rejected ``"zh-CN"`` and ``"zh-TW"`` themselves.

    Args:
        language_code: Code or alias supplied by the caller.
        field_name: Name used in the "required" error message.

    Returns:
        The canonical supported code.

    Raises:
        ValidationError: ``missing_language`` when the code is empty,
            ``unsupported_language`` when it is not recognised.
    """
    if not language_code:
        raise ValidationError(f"{field_name} est requis", code="missing_language")

    # Normalize
    normalized = language_code.strip().lower()

    # Handle common variations
    if normalized in ("chinese", "cn"):
        normalized = "zh-CN"
    elif normalized in ("chinese-traditional", "tw"):
        normalized = "zh-TW"

    # Compare on the lower-cased form, hand back the canonical spelling.
    canonical_by_lower = {code.lower(): code for code in cls.SUPPORTED_LANGUAGES}
    canonical = canonical_by_lower.get(normalized.lower())
    if canonical is None:
        raise ValidationError(
            f"Code langue non supporte: '{language_code}'. Consultez /languages pour les codes supportes.",
            code="unsupported_language",
            details={"language": language_code},
        )

    return canonical
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_language_name(cls, code: str) -> str:
|
||||
"""Get human-readable language name"""
|
||||
@@ -330,104 +499,116 @@ class LanguageValidator:
|
||||
|
||||
class ProviderValidator:
    """Validates translation provider configuration"""

    SUPPORTED_PROVIDERS = {
        "google",
        "ollama",
        "deepl",
        "libre",
        "openai",
        "webllm",
        "openrouter",
        "classic",
        "llm",
    }

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate provider and its required configuration.

        Args:
            provider: Provider name; surrounding whitespace and case are ignored.
            **kwargs: Provider-specific settings. ``deepl_api_key`` is
                required for "deepl" and ``openai_api_key`` for "openai";
                ``ollama_model`` is optional for "ollama".

        Returns:
            ``{"provider": <normalized name>, "validated": True}``.

        Raises:
            ValidationError: ``missing_provider``, ``unsupported_provider``,
                ``missing_deepl_key`` or ``missing_openai_key``.
        """
        if not provider:
            raise ValidationError(
                "Le fournisseur de traduction est requis", code="missing_provider"
            )

        normalized = provider.strip().lower()

        if normalized not in cls.SUPPORTED_PROVIDERS:
            # Set iteration order varies between runs; sorting keeps the
            # message and the details payload stable.
            supported = sorted(cls.SUPPORTED_PROVIDERS)
            raise ValidationError(
                f"Fournisseur non supporte: '{provider}'. Supportes: {', '.join(supported)}",
                code="unsupported_provider",
                details={
                    "provider": provider,
                    "supported": supported,
                },
            )

        # Provider-specific validation
        if normalized == "deepl":
            if not kwargs.get("deepl_api_key"):
                raise ValidationError(
                    "La cle API DeepL est requise pour utiliser le fournisseur DeepL",
                    code="missing_deepl_key",
                )

        elif normalized == "openai":
            if not kwargs.get("openai_api_key"):
                raise ValidationError(
                    "La cle API OpenAI est requise pour utiliser le fournisseur OpenAI",
                    code="missing_openai_key",
                )

        elif normalized == "ollama":
            # Ollama doesn't require API key but may need model
            model = kwargs.get("ollama_model", "")
            if not model:
                logger.warning("No Ollama model specified, will use default")

        return {"provider": normalized, "validated": True}
|
||||
|
||||
|
||||
class InputSanitizer:
|
||||
"""Sanitizes user inputs to prevent injection attacks"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def sanitize_text(text: str, max_length: int = 10000) -> str:
|
||||
"""Sanitize text input"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
|
||||
# Remove null bytes
|
||||
text = text.replace('\x00', '')
|
||||
|
||||
text = text.replace("\x00", "")
|
||||
|
||||
# Limit length
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length]
|
||||
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def sanitize_language_code(code: str) -> str:
|
||||
"""Sanitize and normalize language code"""
|
||||
if not code:
|
||||
return "auto"
|
||||
|
||||
|
||||
# Remove dangerous characters, keep only alphanumeric and hyphen
|
||||
code = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())
|
||||
|
||||
code = re.sub(r"[^a-zA-Z0-9\-]", "", code.strip())
|
||||
|
||||
# Limit length
|
||||
if len(code) > 10:
|
||||
code = code[:10]
|
||||
|
||||
|
||||
return code.lower() if code else "auto"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def sanitize_url(url: str) -> str:
|
||||
"""Sanitize URL input"""
|
||||
if not url:
|
||||
return ""
|
||||
|
||||
|
||||
url = url.strip()
|
||||
|
||||
|
||||
# Basic URL validation
|
||||
if not re.match(r'^https?://', url, re.IGNORECASE):
|
||||
if not re.match(r"^https?://", url, re.IGNORECASE):
|
||||
raise ValidationError(
|
||||
"Invalid URL format. Must start with http:// or https://",
|
||||
code="invalid_url"
|
||||
"Format d'URL invalide. Doit commencer par http:// ou https://",
|
||||
code="invalid_url",
|
||||
)
|
||||
|
||||
|
||||
# Remove trailing slashes
|
||||
url = url.rstrip('/')
|
||||
|
||||
url = url.rstrip("/")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@staticmethod
|
||||
def sanitize_api_key(key: str) -> str:
|
||||
"""Sanitize API key (just trim, no logging)"""
|
||||
@@ -436,5 +617,117 @@ class InputSanitizer:
|
||||
return key.strip()
|
||||
|
||||
|
||||
class WebhookURLValidator:
    """
    Validator for webhook URLs with security checks.

    Prevents SSRF attacks by blocking private IPs and localhost.
    Story 3.7: Webhook - URL specification

    NOTE(review): the host is resolved at validation time only. If the
    sender resolves it again later, the answer may differ; confirm whether
    the sender needs to re-check the address it actually connects to.
    """

    # Allowed URL schemes
    ALLOWED_SCHEMES = ("http", "https")

    # Blocked hostnames
    BLOCKED_HOSTNAMES = {"localhost", "127.0.0.1", "::1", "0.0.0.0"}

    def __init__(
        self,
        allowed_schemes: Tuple[str, ...] = ALLOWED_SCHEMES,
        block_private_ips: bool = True,
    ):
        """
        Args:
            allowed_schemes: URL schemes that are accepted.
            block_private_ips: When True, reject hosts that are, or resolve
                to, a non-public address.
        """
        self.allowed_schemes = allowed_schemes
        self.block_private_ips = block_private_ips

    def validate(self, url: Optional[str]) -> Tuple[bool, Optional[str], Optional[dict]]:
        """
        Validate webhook URL format and security.

        Args:
            url: The webhook URL to validate (can be None or empty for optional parameter)

        Returns:
            Tuple of (is_valid, error_message, details)
        """
        # Empty or None URLs are valid (optional parameter)
        if not url:
            return True, None, None

        try:
            parsed = urlparse(url)

            # Check scheme
            if parsed.scheme.lower() not in self.allowed_schemes:
                return False, (
                    f"L'URL doit utiliser {' ou '.join(self.allowed_schemes)}"
                ), {
                    "field": "webhook_url",
                    "allowed_schemes": list(self.allowed_schemes),
                    "detected_scheme": parsed.scheme or "none",
                }

            # Check for credentials in URL
            if parsed.username or parsed.password:
                return False, (
                    "L'URL ne doit pas contenir d'identifiants (credentials)"
                ), {"field": "webhook_url", "reason": "credentials_in_url"}

            # Check hostname
            hostname = parsed.hostname
            if not hostname:
                return False, (
                    "URL invalide: nom d'hôte manquant"
                ), {"field": "webhook_url", "reason": "missing_hostname"}

            # Block localhost and common local addresses. A trailing dot
            # ("localhost.") names the same host, so it is ignored here.
            if hostname.lower().rstrip(".") in self.BLOCKED_HOSTNAMES:
                return False, (
                    "Les URLs localhost ne sont pas autorisées"
                ), {"field": "webhook_url", "reason": "localhost_blocked"}

            # Check for private IPs (SSRF protection)
            if self.block_private_ips and self._resolves_to_blocked_ip(hostname):
                return False, (
                    "Les adresses IP privées ne sont pas autorisées"
                ), {"field": "webhook_url", "reason": "private_ip_blocked"}

            return True, None, None

        except Exception as e:
            # Anything unexpected (e.g. urlparse rejecting the input) is
            # reported as an invalid URL rather than accepted.
            return False, (
                f"Format d'URL invalide: {str(e)}"
            ), {"field": "webhook_url", "error": str(e)}

    def _resolves_to_blocked_ip(self, hostname: str) -> bool:
        """Return True when *hostname* is, or resolves to, a blocked address.

        IP literals are checked directly. Names are resolved with
        ``socket.getaddrinfo`` and every returned address is checked, so a
        name with an AAAA record or a second A record pointing at a private
        range is caught. This is a synchronous DNS lookup.
        """
        try:
            literal = ipaddress.ip_address(hostname)
        except ValueError:
            literal = None
        if literal is not None:
            return self._is_blocked_ip(literal)

        try:
            resolved = socket.getaddrinfo(hostname, None)
        except (OSError, UnicodeError):
            # DNS resolution failed - let it through
            # Will fail at webhook send time
            return False

        for entry in resolved:
            # sockaddr is entry[4]; its first item is the address string.
            # IPv6 results may carry a "%scope" suffix that ip_address
            # does not accept on every Python version.
            address = str(entry[4][0]).split("%", 1)[0]
            try:
                candidate = ipaddress.ip_address(address)
            except ValueError:
                continue
            if self._is_blocked_ip(candidate):
                return True
        return False

    def _is_blocked_ip(self, ip: "ipaddress.IPv4Address | ipaddress.IPv6Address") -> bool:
        """Check if IP is private, loopback, link-local, reserved, multicast or unspecified."""
        return (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_reserved
            or ip.is_multicast
            or ip.is_unspecified
        )
|
||||
|
||||
|
||||
# Default validators
|
||||
file_validator = FileValidator()
|
||||
webhook_validator = WebhookURLValidator()
|
||||
|
||||
Reference in New Issue
Block a user