441 lines
16 KiB
Python
441 lines
16 KiB
Python
"""
Input validation module for SaaS robustness.

Validates all user inputs (uploaded files, language codes, translation
provider settings and free-text fields) before they are processed.
"""
|
|
import re
|
|
import magic
|
|
from pathlib import Path
|
|
from typing import Optional, List, Set
|
|
from fastapi import UploadFile, HTTPException
|
|
import logging
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ValidationError(Exception):
    """Raised when user input fails validation.

    Attributes:
        message: User-friendly explanation of the failure (also the ``str()`` of the exception).
        code: Machine-readable identifier; defaults to "validation_error".
        details: Extra structured context about the failure; an empty dict when none was given.
    """

    def __init__(self, message: str, code: str = "validation_error", details: Optional[dict] = None):
        super().__init__(message)
        self.message, self.code = message, code
        self.details = details if details else {}
|
|
|
|
|
|
class ValidationResult:
    """Result of a validation check.

    Attributes:
        is_valid: Whether the input passed validation.
        errors: Error messages describing why validation failed.
        warnings: Non-fatal issues noticed during validation.
        data: Extra information collected while validating.
    """

    def __init__(
        self,
        is_valid: bool = True,
        errors: Optional[List[str]] = None,
        warnings: Optional[List[str]] = None,
        data: Optional[dict] = None,
    ):
        self.is_valid = is_valid
        # `is None` instead of `or`: a caller-supplied (possibly empty) container
        # is kept as-is rather than silently swapped for a new one.
        self.errors = [] if errors is None else errors
        self.warnings = [] if warnings is None else warnings
        self.data = {} if data is None else data

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(is_valid={self.is_valid!r}, errors={self.errors!r}, "
            f"warnings={self.warnings!r}, data={self.data!r})"
        )
|
|
|
|
|
|
class FileValidator:
    """Validates uploaded files for security and compatibility.

    Two entry points run the same checks (filename, extension, size, ZIP
    signature, MIME type):

    * ``validate_async`` returns a ``ValidationResult`` and does not raise for
      validation failures; MIME-type problems are reported as warnings only.
    * ``validate`` raises ``ValidationError`` on the first failed check and
      returns a dict of file information on success.
    """

    # Allowed MIME types mapped to the extension each one corresponds to
    ALLOWED_MIME_TYPES = {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    }

    # Generic MIME types an Office Open XML file (a ZIP container) may be
    # detected as; accepted for any allowed extension.
    GENERIC_MIME_TYPES = ("application/zip", "application/octet-stream")

    # Magic bytes for Office Open XML files (ZIP format)
    OFFICE_MAGIC_BYTES = b"PK\x03\x04"

    # Maximum length, in characters, of a sanitized filename
    MAX_FILENAME_LENGTH = 255

    _BYTES_PER_MB = 1024 * 1024

    def __init__(
        self,
        max_size_mb: int = 50,
        allowed_extensions: Optional[Set[str]] = None,
        scan_content: bool = True
    ):
        """
        Args:
            max_size_mb: Maximum accepted upload size, in MB (1024 * 1024 bytes).
            allowed_extensions: Accepted extensions such as {".xlsx"}. Entries are
                lowercased and given a leading dot if missing. Defaults to
                .xlsx/.docx/.pptx when None or empty.
            scan_content: Whether to check the file's leading ZIP signature.
        """
        self.max_size_bytes = max_size_mb * self._BYTES_PER_MB
        self.max_size_mb = max_size_mb
        # _validate_extension lowercases the extension it extracts, so the
        # allow-list must be lowercase (with a dot) for membership tests to match.
        self.allowed_extensions = {
            ext.lower() if ext.startswith(".") else "." + ext.lower()
            for ext in (allowed_extensions or {".xlsx", ".docx", ".pptx"})
        }
        self.scan_content = scan_content

    async def validate_async(self, file: UploadFile) -> ValidationResult:
        """
        Validate an uploaded file without raising for validation failures.

        Returns:
            ValidationResult. On failure ``is_valid`` is False and ``errors``
            holds the message. On success ``data`` holds safe_filename,
            extension, size_bytes, size_mb, original_filename and (when
            detection succeeded) mime_type; ``warnings`` may mention a file
            above 80% of the size limit or an unverifiable MIME type.
            The file is rewound to the start once its content has been read.
        """
        errors: List[str] = []
        warnings: List[str] = []
        data: dict = {}

        try:
            # Validate filename
            if not file.filename:
                errors.append("Filename is required")
                return ValidationResult(is_valid=False, errors=errors)

            # Sanitize filename
            try:
                safe_filename = self._sanitize_filename(file.filename)
                data["safe_filename"] = safe_filename
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Validate extension
            try:
                extension = self._validate_extension(safe_filename)
                data["extension"] = extension
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Read (bounded) content; the file is rewound for later processing
            content, file_size = await self._read_content(file)
            data["size_bytes"] = file_size
            data["size_mb"] = round(file_size / self._BYTES_PER_MB, 2)

            if file_size > self.max_size_bytes:
                errors.append(
                    f"File too large. Maximum size is {self.max_size_mb}MB, "
                    f"got {file_size / self._BYTES_PER_MB:.1f}MB"
                )
                return ValidationResult(is_valid=False, errors=errors, data=data)

            if file_size == 0:
                errors.append("File is empty")
                return ValidationResult(is_valid=False, errors=errors, data=data)

            # Warn about files above 80% of the limit
            if file_size > self.max_size_bytes * 0.8:
                warnings.append(f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit")

            # Validate magic bytes
            if self.scan_content:
                try:
                    self._validate_magic_bytes(content, extension)
                except ValidationError as e:
                    errors.append(str(e.message))
                    return ValidationResult(is_valid=False, errors=errors, data=data)

            # Validate MIME type: in this lenient path problems are warnings only
            try:
                mime_type = self._detect_mime_type(content)
                data["mime_type"] = mime_type
                self._validate_mime_type(mime_type, extension)
            except ValidationError as e:
                warnings.append(f"MIME type warning: {e.message}")
            except Exception:
                warnings.append("Could not verify MIME type")

            data["original_filename"] = file.filename

            return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)

        except Exception as e:
            # Unexpected failure (e.g. reading the upload failed): keep the
            # traceback in the log and report a failed validation to the caller.
            logger.exception("Validation error: %s", e)
            errors.append(f"Validation failed: {str(e)}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)

    async def validate(self, file: UploadFile) -> dict:
        """
        Validate an uploaded file, raising on the first failed check.

        Returns:
            dict with original_filename, safe_filename, extension, size_bytes,
            size_mb and mime_type. The file is rewound to the start once its
            content has been read.

        Raises:
            ValidationError: with code missing_filename, invalid_filename,
                missing_extension, unsupported_file_type, file_too_large,
                empty_file, invalid_file_content, invalid_mime_type or
                mime_extension_mismatch.
        """
        # Validate filename
        if not file.filename:
            raise ValidationError(
                "Filename is required",
                code="missing_filename"
            )

        safe_filename = self._sanitize_filename(file.filename)
        extension = self._validate_extension(safe_filename)

        # Read (bounded) content; the file is rewound for later processing
        content, file_size = await self._read_content(file)
        size_mb = round(file_size / self._BYTES_PER_MB, 2)

        if file_size > self.max_size_bytes:
            raise ValidationError(
                f"File too large. Maximum size is {self.max_size_mb}MB, "
                f"got {file_size / self._BYTES_PER_MB:.1f}MB",
                code="file_too_large",
                details={"max_mb": self.max_size_mb, "actual_mb": size_mb}
            )

        if file_size == 0:
            raise ValidationError(
                "File is empty",
                code="empty_file"
            )

        # Validate magic bytes (file signature)
        if self.scan_content:
            self._validate_magic_bytes(content, extension)

        # Validate MIME type
        mime_type = self._detect_mime_type(content)
        self._validate_mime_type(mime_type, extension)

        return {
            "original_filename": file.filename,
            "safe_filename": safe_filename,
            "extension": extension,
            "size_bytes": file_size,
            "size_mb": size_mb,
            "mime_type": mime_type,
        }

    async def _read_content(self, file: UploadFile) -> tuple:
        """Read at most ``max_size_bytes + 1`` bytes of *file*, then rewind it.

        One byte past the limit is enough to detect an oversized upload without
        loading the whole upload into memory.

        Returns:
            (content, size_bytes). For an oversized upload ``size_bytes`` is the
            size reported by ``file.size`` when available, otherwise the number
            of bytes read (a lower bound on the real size).
        """
        content = await file.read(self.max_size_bytes + 1)
        await file.seek(0)  # Reset for later processing
        size = len(content)
        if size > self.max_size_bytes:
            reported = getattr(file, "size", None)
            if isinstance(reported, int) and reported > size:
                size = reported
        return content, size

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize a client-supplied filename to prevent path traversal and other attacks.

        Raises:
            ValidationError: code "invalid_filename" if nothing usable remains.
        """
        # Keep only the final path component. Backslashes are converted first so
        # Windows-style client paths ("C:\\dir\\a.xlsx") are reduced on every OS.
        filename = Path(filename.replace("\\", "/")).name

        # Remove null bytes and control characters
        filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)

        # Remove potentially dangerous characters
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        # Limit length, preserving the extension when it is short enough
        max_len = self.MAX_FILENAME_LENGTH
        if len(filename) > max_len:
            stem, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
            suffix = '.' + ext if ext else ''
            if len(suffix) < max_len:
                filename = stem[:max_len - len(suffix)] + suffix
            else:
                filename = filename[:max_len]

        # Ensure not empty after sanitization
        if not filename.strip():
            raise ValidationError(
                "Invalid filename",
                code="invalid_filename"
            )

        return filename

    def _validate_extension(self, filename: str) -> str:
        """Return the lowercased extension (with dot) of *filename* if it is allowed.

        Raises:
            ValidationError: code "missing_extension" or "unsupported_file_type".
        """
        # Sorted so messages and details are identical across processes
        # (set iteration order for strings varies with hash randomization).
        supported = sorted(self.allowed_extensions)

        if '.' not in filename:
            raise ValidationError(
                f"File must have an extension. Supported: {', '.join(supported)}",
                code="missing_extension",
                details={"allowed_extensions": supported}
            )

        extension = '.' + filename.rsplit('.', 1)[1].lower()

        if extension not in self.allowed_extensions:
            raise ValidationError(
                f"File type '{extension}' not supported. Supported types: {', '.join(supported)}",
                code="unsupported_file_type",
                details={"extension": extension, "allowed_extensions": supported}
            )

        return extension

    def _validate_magic_bytes(self, content: bytes, extension: str):
        """Check that *content* starts with the ZIP signature used by Office Open XML.

        Raises:
            ValidationError: code "invalid_file_content" otherwise.
        """
        # All supported formats are Office Open XML (ZIP-based)
        if not content.startswith(self.OFFICE_MAGIC_BYTES):
            raise ValidationError(
                "File content does not match expected format. The file may be corrupted or not a valid Office document.",
                code="invalid_file_content"
            )

    def _detect_mime_type(self, content: bytes) -> str:
        """Detect the MIME type of *content* with libmagic (python-magic).

        Best effort: if libmagic fails, falls back to "application/zip" for
        ZIP-signed content and "application/octet-stream" otherwise.
        """
        try:
            mime = magic.Magic(mime=True)
            return mime.from_buffer(content)
        except Exception:
            # Fallback to basic detection
            if content.startswith(self.OFFICE_MAGIC_BYTES):
                return "application/zip"
            return "application/octet-stream"

    def _validate_mime_type(self, mime_type: str, extension: str):
        """Check that the detected MIME type is acceptable for *extension*.

        Generic ZIP/octet-stream types are accepted for any extension, since
        Office Open XML files are ZIP containers and may be detected that way.
        A specific Office MIME type must correspond to the file's extension when
        that extension is one of the known Office types (so a .docx renamed to
        .xlsx is caught).

        Raises:
            ValidationError: code "invalid_mime_type" for a MIME type outside the
                allowed set, "mime_extension_mismatch" when the content is a
                different Office format than the extension claims.
        """
        if mime_type in self.GENERIC_MIME_TYPES:
            return

        expected_extension = self.ALLOWED_MIME_TYPES.get(mime_type)
        if expected_extension is None:
            raise ValidationError(
                f"Invalid file type detected. Expected Office document, got: {mime_type}",
                code="invalid_mime_type",
                details={"detected_mime": mime_type}
            )

        if extension in self.ALLOWED_MIME_TYPES.values() and expected_extension != extension:
            raise ValidationError(
                f"File content looks like a {expected_extension} document but the file has a {extension} extension",
                code="mime_extension_mismatch",
                details={
                    "detected_mime": mime_type,
                    "extension": extension,
                    "expected_extension": expected_extension,
                }
            )
|
|
|
|
|
|
class LanguageValidator:
    """Validates and normalizes language codes."""

    SUPPORTED_LANGUAGES = {
        # ISO 639-1 codes
        "af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
        "bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
        "da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
        "de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
        "is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
        "rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
        "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
        "ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
        "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
        "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
        "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
        "zu", "auto"
    }

    LANGUAGE_NAMES = {
        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
        "it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
        "zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
        "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
        "nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
        "da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
        "el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
        "uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian", "auto": "Auto-detect"
    }

    # Common variations (lowercase) mapped to their canonical code
    LANGUAGE_ALIASES = {
        "chinese": "zh-CN",
        "cn": "zh-CN",
        "chinese-traditional": "zh-TW",
        "tw": "zh-TW",
    }

    # Lowercased code -> canonical spelling. Needed because SUPPORTED_LANGUAGES
    # mixes cases ("zh-CN") while input is compared in lowercase.
    _CANONICAL_CODES = {code.lower(): code for code in SUPPORTED_LANGUAGES}

    @classmethod
    def validate(cls, language_code: str, field_name: str = "language") -> str:
        """Validate a language code and return its canonical spelling.

        Matching ignores case and surrounding whitespace and accepts "_" in
        place of "-": "zh-CN", "zh-cn" and "ZH_CN" all return "zh-CN", while
        plain codes come back lowercase ("EN" -> "en"). Entries of
        LANGUAGE_ALIASES (e.g. "chinese") are also accepted.

        Raises:
            ValidationError: code "missing_language" for empty input,
                "unsupported_language" for an unknown code.
        """
        if not language_code:
            raise ValidationError(
                f"{field_name} is required",
                code="missing_language"
            )

        # Normalize case, whitespace and separator style
        normalized = language_code.strip().lower().replace("_", "-")

        canonical = cls.LANGUAGE_ALIASES.get(normalized) or cls._CANONICAL_CODES.get(normalized)
        if canonical is None:
            raise ValidationError(
                f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
                code="unsupported_language",
                details={"language": language_code}
            )

        return canonical

    @classmethod
    def get_language_name(cls, code: str) -> str:
        """Return the human-readable name for *code*, or the code uppercased if unknown."""
        return cls.LANGUAGE_NAMES.get(code, code.upper())
|
|
|
|
|
|
class ProviderValidator:
    """Validates the translation provider name and its provider-specific settings."""

    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm", "openrouter"}

    # provider -> (keyword argument holding its API key, error message, error code)
    _REQUIRED_API_KEYS = {
        "deepl": (
            "deepl_api_key",
            "DeepL API key is required when using DeepL provider",
            "missing_deepl_key",
        ),
        "openai": (
            "openai_api_key",
            "OpenAI API key is required when using OpenAI provider",
            "missing_openai_key",
        ),
    }

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate *provider* (case- and whitespace-insensitive) and its settings.

        Args:
            provider: Provider name, e.g. "deepl".
            **kwargs: Provider settings; ``deepl_api_key`` / ``openai_api_key``
                are required for DeepL / OpenAI, ``ollama_model`` is optional.

        Returns:
            {"provider": <normalized name>, "validated": True}

        Raises:
            ValidationError: code "missing_provider", "unsupported_provider",
                "missing_deepl_key" or "missing_openai_key".
        """
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )

        normalized = provider.strip().lower()

        if normalized not in cls.SUPPORTED_PROVIDERS:
            # Sorted so the message and details do not vary between processes
            supported = sorted(cls.SUPPORTED_PROVIDERS)
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(supported)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": supported}
            )

        # Provider-specific validation
        required_key = cls._REQUIRED_API_KEYS.get(normalized)
        if required_key is not None:
            kwarg_name, message, code = required_key
            if not kwargs.get(kwarg_name):
                raise ValidationError(message, code=code)
        elif normalized == "ollama" and not kwargs.get("ollama_model", ""):
            # Ollama doesn't require an API key; a missing model is only warned about
            logger.warning("No Ollama model specified, will use default")

        return {"provider": normalized, "validated": True}
|
|
|
|
|
|
class InputSanitizer:
    """Sanitizes user inputs (text, language codes, URLs, API keys) to prevent injection attacks."""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Remove null bytes, trim surrounding whitespace and cap the length.

        Whitespace is trimmed before truncating so that leading blanks do not
        consume the ``max_length`` budget. Falsy input returns "".
        """
        if not text:
            return ""

        text = text.replace('\x00', '').strip()

        if len(text) > max_length:
            # The cut may leave trailing whitespace behind
            text = text[:max_length].rstrip()

        return text

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Keep only ASCII letters, digits and '-', cap at 10 chars and lowercase.

        Returns "auto" for empty input or when nothing survives sanitization.
        """
        if not code:
            return "auto"

        # Remove dangerous characters, keep only alphanumeric and hyphen
        code = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())

        # Limit length
        code = code[:10]

        return code.lower() if code else "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Validate an http(s) URL and strip surrounding whitespace and trailing slashes.

        Returns "" for falsy input.

        Raises:
            ValidationError: code "invalid_url" if the URL does not start with
                http:// or https://, contains control characters, or has no host.
        """
        from urllib.parse import urlsplit

        if not url:
            return ""

        url = url.strip()

        # Basic URL validation
        if not re.match(r'^https?://', url, re.IGNORECASE):
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )

        # Control characters (e.g. CR/LF) are never valid in a URL; rejecting
        # them keeps them out of any request later built from this value.
        if re.search(r'[\x00-\x1f\x7f]', url):
            raise ValidationError(
                "Invalid URL: contains control characters",
                code="invalid_url"
            )

        # "http://" alone (or a malformed authority) previously slipped through
        try:
            host = urlsplit(url).hostname
        except ValueError:
            host = None
        if not host:
            raise ValidationError(
                "Invalid URL: a host name is required",
                code="invalid_url"
            )

        # Remove trailing slashes
        return url.rstrip('/')

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Trim whitespace from an API key; never log the value. Falsy input returns ""."""
        if not key:
            return ""
        return key.strip()
|
|
|
|
|
|
# Default validators
# Shared FileValidator built with its constructor defaults, for callers
# that don't need custom size/extension settings.
file_validator: FileValidator = FileValidator()
|