"""
Input Validation Module for SaaS robustness
Validates all user inputs before processing
"""
import logging
import re
from pathlib import Path
from typing import List, Optional, Set

import magic
from fastapi import HTTPException, UploadFile  # noqa: F401  (HTTPException kept for importers)

logger = logging.getLogger(__name__)

# Number of bytes in one megabyte, used for every size conversion below.
_BYTES_PER_MB = 1024 * 1024


class ValidationError(Exception):
    """Custom validation error with user-friendly messages.

    Attributes are set directly on the instance:
        message: Human-readable description, safe to show to the user.
        code: Machine-readable error identifier.
        details: Optional extra context; always a dict (empty when omitted).
    """

    def __init__(self, message: str, code: str = "validation_error",
                 details: Optional[dict] = None):
        self.message = message
        self.code = code
        self.details = details or {}
        super().__init__(message)


class ValidationResult:
    """Result of a validation check.

    ``errors``, ``warnings`` and ``data`` default to fresh empty containers
    when omitted (or falsy), so instances never share state.
    """

    def __init__(self, is_valid: bool = True,
                 errors: Optional[List[str]] = None,
                 warnings: Optional[List[str]] = None,
                 data: Optional[dict] = None):
        self.is_valid = is_valid
        self.errors = errors or []
        self.warnings = warnings or []
        self.data = data or {}


class FileValidator:
    """Validates uploaded files for security and compatibility"""

    # Allowed MIME types mapped to extensions
    ALLOWED_MIME_TYPES = {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    }

    # Magic bytes for Office Open XML files (ZIP format)
    OFFICE_MAGIC_BYTES = b"PK\x03\x04"

    # Generic types also accepted, because Office Open XML files may be
    # detected as a plain ZIP / unknown binary rather than a specific type.
    GENERIC_MIME_TYPES = ("application/zip", "application/octet-stream")

    # Upper bound on the sanitized filename length.
    MAX_FILENAME_LENGTH = 255

    def __init__(
        self,
        max_size_mb: int = 50,
        allowed_extensions: Optional[Set[str]] = None,
        scan_content: bool = True
    ):
        """Configure the validator.

        Args:
            max_size_mb: Maximum accepted upload size in megabytes.
            allowed_extensions: Accepted extensions. Entries are normalised to
                lowercase with a leading dot, so ``"XLSX"``, ``"xlsx"`` and
                ``".xlsx"`` are equivalent. Defaults to .xlsx/.docx/.pptx when
                omitted or empty.
            scan_content: When True, the file signature (magic bytes) is
                checked in addition to name, size and MIME type.
        """
        self.max_size_bytes = max_size_mb * _BYTES_PER_MB
        self.max_size_mb = max_size_mb
        # Extensions extracted from filenames are lowercased and dotted, so the
        # configured set has to be in the same form or it could never match.
        normalized = {
            "." + ext.strip().lstrip(".").lower()
            for ext in (allowed_extensions or ())
            if ext and ext.strip()
        }
        self.allowed_extensions = normalized or {".xlsx", ".docx", ".pptx"}
        self.scan_content = scan_content

    async def validate_async(self, file: UploadFile) -> ValidationResult:
        """
        Validate an uploaded file asynchronously.

        Never raises: every failure is reported through the returned
        ValidationResult (is_valid, errors, warnings, data). MIME type problems
        are reported as warnings only; name, extension, size and signature
        problems are errors. The upload is read fully and rewound to offset 0.
        """
        errors: List[str] = []
        warnings: List[str] = []
        data: dict = {}

        try:
            # Validate filename
            if not file.filename:
                errors.append("Filename is required")
                return ValidationResult(is_valid=False, errors=errors)

            # Sanitize filename, then validate its extension
            try:
                safe_filename = self._sanitize_filename(file.filename)
                data["safe_filename"] = safe_filename
                extension = self._validate_extension(safe_filename)
                data["extension"] = extension
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors)

            # Read file content for validation
            content = await file.read()
            await file.seek(0)  # Reset for later processing

            # Validate file size
            file_size = len(content)
            data["size_bytes"] = file_size
            data["size_mb"] = round(file_size / _BYTES_PER_MB, 2)
            try:
                self._validate_size(file_size)
            except ValidationError as e:
                errors.append(str(e.message))
                return ValidationResult(is_valid=False, errors=errors, data=data)

            # Warn about large files
            if file_size > self.max_size_bytes * 0.8:
                warnings.append(
                    f"File is {data['size_mb']}MB, approaching the {self.max_size_mb}MB limit"
                )

            # Validate magic bytes
            if self.scan_content:
                try:
                    self._validate_magic_bytes(content, extension)
                except ValidationError as e:
                    errors.append(str(e.message))
                    return ValidationResult(is_valid=False, errors=errors, data=data)

            # Validate MIME type (best effort: problems only produce warnings)
            try:
                mime_type = self._detect_mime_type(content)
                data["mime_type"] = mime_type
                self._validate_mime_type(mime_type, extension)
            except ValidationError as e:
                warnings.append(f"MIME type warning: {e.message}")
            except Exception:
                warnings.append("Could not verify MIME type")

            data["original_filename"] = file.filename

            return ValidationResult(is_valid=True, errors=errors, warnings=warnings, data=data)

        except Exception as e:
            # Boundary handler: anything unexpected (e.g. a failing read) is
            # logged with its traceback and converted into a failed result.
            logger.exception("Validation error: %s", e)
            errors.append(f"Validation failed: {str(e)}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings, data=data)

    async def validate(self, file: UploadFile) -> dict:
        """
        Validate an uploaded file.

        Returns:
            Dict with original_filename, safe_filename, extension, size_bytes,
            size_mb and mime_type.

        Raises:
            ValidationError: On the first failed check (filename, extension,
                size, signature or MIME type).
        """
        # Validate filename
        if not file.filename:
            raise ValidationError(
                "Filename is required",
                code="missing_filename"
            )

        # Sanitize filename
        safe_filename = self._sanitize_filename(file.filename)

        # Validate extension
        extension = self._validate_extension(safe_filename)

        # Read file content for validation
        content = await file.read()
        await file.seek(0)  # Reset for later processing

        # Validate file size
        file_size = len(content)
        self._validate_size(file_size)

        # Validate magic bytes (file signature)
        if self.scan_content:
            self._validate_magic_bytes(content, extension)

        # Validate MIME type
        mime_type = self._detect_mime_type(content)
        self._validate_mime_type(mime_type, extension)

        return {
            "original_filename": file.filename,
            "safe_filename": safe_filename,
            "extension": extension,
            "size_bytes": file_size,
            "size_mb": round(file_size / _BYTES_PER_MB, 2),
            "mime_type": mime_type
        }

    def _validate_size(self, file_size: int) -> None:
        """Raise ValidationError when the size is over the limit or zero."""
        if file_size > self.max_size_bytes:
            raise ValidationError(
                f"File too large. Maximum size is {self.max_size_mb}MB, got {file_size / _BYTES_PER_MB:.1f}MB",
                code="file_too_large",
                details={
                    "max_mb": self.max_size_mb,
                    "actual_mb": round(file_size / _BYTES_PER_MB, 2)
                }
            )

        if file_size == 0:
            raise ValidationError(
                "File is empty",
                code="empty_file"
            )

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks.

        Returns the bare file name with control characters removed, reserved
        characters replaced by ``_`` and the length capped at
        MAX_FILENAME_LENGTH (keeping the extension when possible).

        Raises:
            ValidationError: If nothing usable is left (code "invalid_filename").
        """
        # Remove path components. Backslashes are turned into forward slashes
        # first so Windows-style paths are stripped on POSIX hosts too, where
        # Path() does not treat "\" as a separator.
        filename = Path(filename.replace("\\", "/")).name

        # Remove null bytes and control characters
        filename = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', filename)

        # Remove potentially dangerous characters
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        # Limit length, budgeting for the extension so the total never exceeds
        # the cap regardless of how long the extension is.
        limit = self.MAX_FILENAME_LENGTH
        if len(filename) > limit:
            stem, dot, ext = filename.rpartition('.')
            if dot and ext and len(ext) < limit - 1:
                filename = stem[:limit - len(ext) - 1] + '.' + ext
            else:
                filename = filename[:limit]

        # Ensure not empty after sanitization
        if not filename or filename.strip() == '':
            raise ValidationError(
                "Invalid filename",
                code="invalid_filename"
            )

        return filename

    def _validate_extension(self, filename: str) -> str:
        """Validate and return the file extension (lowercase, with leading dot).

        Raises:
            ValidationError: If the name has no extension ("missing_extension")
                or the extension is not allowed ("unsupported_file_type").
        """
        # Sorted so messages and details are stable across runs; iteration
        # order of a set of strings is not.
        allowed = sorted(self.allowed_extensions)
        supported = ', '.join(allowed)

        if '.' not in filename:
            raise ValidationError(
                f"File must have an extension. Supported: {supported}",
                code="missing_extension",
                details={"allowed_extensions": allowed}
            )

        extension = '.' + filename.rsplit('.', 1)[1].lower()

        if extension not in self.allowed_extensions:
            raise ValidationError(
                f"File type '{extension}' not supported. Supported types: {supported}",
                code="unsupported_file_type",
                details={"extension": extension, "allowed_extensions": allowed}
            )

        return extension

    def _validate_magic_bytes(self, content: bytes, extension: str):
        """Validate file magic bytes match expected format.

        ``extension`` is currently unused: every supported format shares the
        same ZIP signature.

        Raises:
            ValidationError: If the content does not start with the signature
                (code "invalid_file_content").
        """
        # All supported formats are Office Open XML (ZIP-based)
        if not content.startswith(self.OFFICE_MAGIC_BYTES):
            raise ValidationError(
                "File content does not match expected format. The file may be corrupted or not a valid Office document.",
                code="invalid_file_content"
            )

    def _detect_mime_type(self, content: bytes) -> str:
        """Detect MIME type from file content.

        Uses python-magic; if detection raises for any reason, falls back to
        "application/zip" for ZIP-signed content and
        "application/octet-stream" otherwise.
        """
        try:
            mime = magic.Magic(mime=True)
            return mime.from_buffer(content)
        except Exception:
            # Fallback to basic detection
            if content.startswith(self.OFFICE_MAGIC_BYTES):
                return "application/zip"
            return "application/octet-stream"

    def _validate_mime_type(self, mime_type: str, extension: str):
        """Validate MIME type matches extension.

        Generic ZIP/binary types are accepted as-is. A specific Office type is
        additionally required to agree with ``extension`` when the extension
        is one of the known Office Open XML extensions.

        Raises:
            ValidationError: "invalid_mime_type" for an unexpected type, or
                "mime_extension_mismatch" when content and extension disagree.
        """
        expected_extension = self.ALLOWED_MIME_TYPES.get(mime_type)

        if expected_extension is None:
            # Office Open XML files may be detected as ZIP
            if mime_type in self.GENERIC_MIME_TYPES:
                return
            raise ValidationError(
                f"Invalid file type detected. Expected Office document, got: {mime_type}",
                code="invalid_mime_type",
                details={"detected_mime": mime_type}
            )

        # Only compare for extensions this table knows about, so custom
        # allowed_extensions are not rejected by a table that cannot describe them.
        known_extensions = set(self.ALLOWED_MIME_TYPES.values())
        if extension in known_extensions and extension != expected_extension:
            raise ValidationError(
                f"File content looks like a '{expected_extension}' document but the file is named '{extension}'",
                code="mime_extension_mismatch",
                details={
                    "detected_mime": mime_type,
                    "expected_extension": expected_extension,
                    "extension": extension
                }
            )


class LanguageValidator:
    """Validates language codes"""

    SUPPORTED_LANGUAGES = {
        # ISO 639-1 codes
        "af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs",
        "bg", "ca", "ceb", "zh", "zh-CN", "zh-TW", "co", "hr", "cs",
        "da", "nl", "en", "eo", "et", "fi", "fr", "fy", "gl", "ka",
        "de", "el", "gu", "ht", "ha", "haw", "he", "hi", "hmn", "hu",
        "is", "ig", "id", "ga", "it", "ja", "jv", "kn", "kk", "km",
        "rw", "ko", "ku", "ky", "lo", "la", "lv", "lt", "lb", "mk",
        "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no",
        "ny", "or", "ps", "fa", "pl", "pt", "pa", "ro", "ru", "sm",
        "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es",
        "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr",
        "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "yo",
        "zu", "auto"
    }

    LANGUAGE_NAMES = {
        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
        "it": "Italian", "pt": "Portuguese", "ru": "Russian", "zh": "Chinese",
        "zh-CN": "Chinese (Simplified)", "zh-TW": "Chinese (Traditional)",
        "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
        "nl": "Dutch", "pl": "Polish", "tr": "Turkish", "sv": "Swedish",
        "da": "Danish", "no": "Norwegian", "fi": "Finnish", "cs": "Czech",
        "el": "Greek", "th": "Thai", "vi": "Vietnamese", "id": "Indonesian",
        "uk": "Ukrainian", "ro": "Romanian", "hu": "Hungarian",
        "auto": "Auto-detect"
    }

    # Common variations (lowercase) mapped to their supported code.
    _ALIASES = {
        "chinese": "zh-CN",
        "cn": "zh-CN",
        "chinese-traditional": "zh-TW",
        "tw": "zh-TW",
    }

    @classmethod
    def _canonical_code(cls, lowered: str) -> Optional[str]:
        """Return the supported code equal to ``lowered`` ignoring case, or None."""
        if lowered in cls.SUPPORTED_LANGUAGES:
            return lowered
        # Some supported codes are mixed-case ("zh-CN", "zh-TW"), so a plain
        # membership test on lowercased input would never find them.
        for code in cls.SUPPORTED_LANGUAGES:
            if code.lower() == lowered:
                return code
        return None

    @classmethod
    def validate(cls, language_code: str, field_name: str = "language") -> str:
        """Validate and normalize language code.

        Matching ignores case and surrounding whitespace; the returned value is
        the code exactly as spelled in SUPPORTED_LANGUAGES (e.g. "zh-cn" ->
        "zh-CN").

        Raises:
            ValidationError: "missing_language" for empty input,
                "unsupported_language" for an unknown code.
        """
        if not language_code:
            raise ValidationError(
                f"{field_name} is required",
                code="missing_language"
            )

        # Normalize
        lowered = language_code.strip().lower()

        # Handle common variations
        normalized = cls._ALIASES.get(lowered) or cls._canonical_code(lowered)

        if normalized is None or normalized not in cls.SUPPORTED_LANGUAGES:
            raise ValidationError(
                f"Unsupported language code: '{language_code}'. See /languages for supported codes.",
                code="unsupported_language",
                details={"language": language_code}
            )

        return normalized

    @classmethod
    def get_language_name(cls, code: str) -> str:
        """Get human-readable language name (the upper-cased code if unknown)."""
        return cls.LANGUAGE_NAMES.get(code, code.upper())


class ProviderValidator:
    """Validates translation provider configuration"""

    SUPPORTED_PROVIDERS = {"google", "ollama", "deepl", "libre", "openai", "webllm", "openrouter"}

    @classmethod
    def validate(cls, provider: str, **kwargs) -> dict:
        """Validate provider and its required configuration.

        Args:
            provider: Provider name; matched ignoring case and surrounding
                whitespace.
            **kwargs: Provider settings read here: ``deepl_api_key``,
                ``openai_api_key`` and ``ollama_model``.

        Returns:
            ``{"provider": <normalized name>, "validated": True}``.

        Raises:
            ValidationError: For a missing or unsupported provider, or a
                missing DeepL / OpenAI API key.
        """
        if not provider:
            raise ValidationError(
                "Translation provider is required",
                code="missing_provider"
            )

        normalized = provider.strip().lower()

        if normalized not in cls.SUPPORTED_PROVIDERS:
            # Sorted so the message and details are stable across runs.
            supported = sorted(cls.SUPPORTED_PROVIDERS)
            raise ValidationError(
                f"Unsupported provider: '{provider}'. Supported: {', '.join(supported)}",
                code="unsupported_provider",
                details={"provider": provider, "supported": supported}
            )

        # Provider-specific validation
        if normalized == "deepl":
            if not kwargs.get("deepl_api_key"):
                raise ValidationError(
                    "DeepL API key is required when using DeepL provider",
                    code="missing_deepl_key"
                )

        elif normalized == "openai":
            if not kwargs.get("openai_api_key"):
                raise ValidationError(
                    "OpenAI API key is required when using OpenAI provider",
                    code="missing_openai_key"
                )

        elif normalized == "ollama":
            # Ollama doesn't require API key but may need model
            model = kwargs.get("ollama_model", "")
            if not model:
                logger.warning("No Ollama model specified, will use default")

        return {"provider": normalized, "validated": True}


class InputSanitizer:
    """Sanitizes user inputs to prevent injection attacks"""

    @staticmethod
    def sanitize_text(text: str, max_length: int = 10000) -> str:
        """Sanitize text input.

        Removes null bytes, truncates to ``max_length`` characters, then strips
        surrounding whitespace. Falsy input yields "".
        """
        if not text:
            return ""

        # Remove null bytes
        text = text.replace('\x00', '')

        # Limit length
        if len(text) > max_length:
            text = text[:max_length]

        return text.strip()

    @staticmethod
    def sanitize_language_code(code: str) -> str:
        """Sanitize and normalize language code.

        Keeps only ASCII letters, digits and hyphens, caps the length at 10 and
        lowercases the result. Returns "auto" when nothing is left.
        """
        if not code:
            return "auto"

        # Remove dangerous characters, keep only alphanumeric and hyphen
        code = re.sub(r'[^a-zA-Z0-9\-]', '', code.strip())

        # Limit length
        if len(code) > 10:
            code = code[:10]

        return code.lower() if code else "auto"

    @staticmethod
    def sanitize_url(url: str) -> str:
        """Sanitize URL input.

        Returns the trimmed URL without trailing slashes, or "" for falsy input.

        Raises:
            ValidationError: If the URL does not start with http:// or
                https:// (code "invalid_url").
        """
        if not url:
            return ""

        url = url.strip()

        # Basic URL validation
        if not re.match(r'^https?://', url, re.IGNORECASE):
            raise ValidationError(
                "Invalid URL format. Must start with http:// or https://",
                code="invalid_url"
            )

        # Remove trailing slashes
        url = url.rstrip('/')

        return url

    @staticmethod
    def sanitize_api_key(key: str) -> str:
        """Sanitize API key (just trim, no logging)"""
        if not key:
            return ""
        return key.strip()


# Default validators
file_validator = FileValidator()