"""
Security Headers Middleware for SaaS robustness.

Starlette middleware used to harden and observe the application:

* ``SecurityHeadersMiddleware`` adds security headers to all responses.
* ``RequestLoggingMiddleware`` tags each request with a short ID and logs
  its start, outcome and duration.
* ``ErrorHandlingMiddleware`` converts unhandled exceptions into a generic
  JSON 500 response that does not leak internal details.
"""

import logging
import time
import uuid
from typing import Optional

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

logger = logging.getLogger(__name__)

# Default Content-Security-Policy (adjust for your frontend).
# The directives are joined with "; " so the header value has no trailing
# whitespace: RFC 9110 field values must not end in whitespace, and strict
# HTTP implementations may reject such a header.
# NOTE(review): 'unsafe-inline'/'unsafe-eval' in script-src and the
# localhost-only connect-src look like development settings; production
# deployments can override them via config["content_security_policy"].
DEFAULT_CSP = "; ".join(
    (
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:",
        "style-src 'self' 'unsafe-inline'",
        "img-src 'self' data: blob:",
        "font-src 'self' data:",
        "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*",
        "worker-src 'self' blob:",
    )
)

# Path prefixes that receive no CSP header. Presumably these are the
# interactive API docs pages, which load assets the default policy would
# block -- confirm against the app's docs/redoc URLs.
DEFAULT_CSP_EXEMPT_PATHS = ("/docs", "/redoc")

# HSTS max-age in seconds (31536000 s = 365 days).
DEFAULT_HSTS_MAX_AGE = 31536000


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses.

    Args:
        app: The ASGI application to wrap.
        config: Optional settings. Recognised keys (all optional):

            * ``enable_hsts`` (bool, default ``False``): add
              ``Strict-Transport-Security``. Enable only when the site is
              served exclusively over HTTPS.
            * ``hsts_max_age`` (int, default ``31536000``): HSTS
              ``max-age`` in seconds.
            * ``content_security_policy`` (str): replaces the default
              ``Content-Security-Policy`` value; an empty string disables
              the CSP header entirely.
            * ``csp_exempt_paths`` (str or iterable of str, default
              ``("/docs", "/redoc")``): path prefixes that get no CSP.
    """

    def __init__(self, app, config: Optional[dict] = None):
        super().__init__(app)
        self.config = config or {}

        # Resolve settings once instead of on every request.
        self.csp = self.config.get("content_security_policy", DEFAULT_CSP)

        exempt = self.config.get("csp_exempt_paths", DEFAULT_CSP_EXEMPT_PATHS)
        if isinstance(exempt, str):
            # A bare string would otherwise be split into single characters
            # by tuple(), exempting every path that starts with "/".
            exempt = (exempt,)
        self.csp_exempt_paths = tuple(exempt)

        if self.config.get("enable_hsts", False):
            max_age = int(self.config.get("hsts_max_age", DEFAULT_HSTS_MAX_AGE))
            self.hsts = f"max-age={max_age}; includeSubDomains"
        else:
            self.hsts = None

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Run the downstream app, then stamp security headers on its response.

        Existing values for these header names are overwritten.
        """
        response = await call_next(request)
        headers = response.headers

        # Prevent clickjacking
        headers["X-Frame-Options"] = "DENY"
        # Prevent MIME type sniffing
        headers["X-Content-Type-Options"] = "nosniff"
        # Legacy browser XSS filter.
        # NOTE(review): OWASP now recommends "0" (or omitting the header)
        # because old XSS auditors could themselves be abused; the original
        # value is kept to avoid changing behaviour.
        headers["X-XSS-Protection"] = "1; mode=block"
        # Referrer policy
        headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Permissions policy
        headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"

        # Content Security Policy, skipped for the exempt path prefixes.
        if self.csp and not request.url.path.startswith(self.csp_exempt_paths):
            headers["Content-Security-Policy"] = self.csp

        # HSTS (only when enabled, i.e. production with HTTPS)
        if self.hsts is not None:
            headers["Strict-Transport-Security"] = self.hsts

        return response


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests for monitoring and debugging.

    Each request gets a short ID. The ID is stored on
    ``request.state.request_id`` and echoed in the ``X-Request-ID`` header
    of responses that complete without raising.

    Args:
        app: The ASGI application to wrap.
        log_body: Stored on the instance.
            NOTE(review): not read anywhere in this module, so request
            bodies are never logged; confirm whether other code relies on it.
    """

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        self.log_body = log_body

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Assign a request ID, log start/finish with timing, and re-raise errors."""
        # First 8 hex digits of a UUID4: short enough to read in logs, but
        # not globally unique, so treat it as a correlation aid only.
        request_id = uuid.uuid4().hex[:8]
        request.state.request_id = request_id

        method = request.method
        path = request.url.path
        client_ip = self._get_client_ip(request)

        # perf_counter() is monotonic, so durations are not skewed by
        # wall-clock adjustments the way time.time() differences can be.
        start = time.perf_counter()
        logger.info(
            "[%s] %s %s from %s - Started", request_id, method, path, client_ip
        )

        try:
            response = await call_next(request)
        except Exception as exc:
            duration = time.perf_counter() - start
            logger.error(
                "[%s] %s %s - ERROR in %.3fs: %s",
                request_id, method, path, duration, exc,
            )
            raise

        duration = time.perf_counter() - start
        logger.info(
            "[%s] %s %s - %s in %.3fs",
            request_id, method, path, response.status_code, duration,
        )

        # Add request ID to response headers
        response.headers["X-Request-ID"] = request_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Return the best-known client IP address for logging.

        Checks, in order: the first entry of ``X-Forwarded-For``,
        ``X-Real-IP``, then the socket peer. Returns ``"unknown"`` if none
        is available.

        NOTE(review): both headers are client-supplied unless a trusted
        proxy overwrites them, so the result should be used for logging
        only, not for access control or rate limiting.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",", 1)[0].strip()
            # A malformed header such as ", 1.2.3.4" falls through to the
            # next source instead of yielding an empty string.
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host
        return "unknown"


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Catch all unhandled exceptions and return proper error responses.

    The full traceback is logged. The client receives only a generic JSON
    500 body carrying the request ID, or ``"unknown"`` when no request ID
    was stored on ``request.state``.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        try:
            return await call_next(request)
        except Exception as exc:
            request_id = getattr(request.state, "request_id", "unknown")
            logger.exception("[%s] Unhandled exception: %s", request_id, exc)

            # Echo the ID as a header too. RequestLoggingMiddleware re-raises
            # before it can set its own X-Request-ID, so without this the
            # header is missing from error responses.
            headers = {"X-Request-ID": request_id} if request_id != "unknown" else None

            # Don't expose internal errors in production
            return JSONResponse(
                status_code=500,
                content={
                    "error": "internal_server_error",
                    "message": "An unexpected error occurred. Please try again later.",
                    "request_id": request_id,
                },
                headers=headers,
            )