"""
Security & request logging middlewares.

RequestLoggingMiddleware is responsible for:
- Assigning a request_id to each request
- Binding request_id and user_id into structlog context so all logs include
  these fields in JSON output (Story 6.4).
"""

import time
import uuid
from typing import Optional

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from core.logging import get_logger, bind_request_context, clear_request_context

logger = get_logger(__name__)

# Headers applied unconditionally to every response, in this order.
_STATIC_SECURITY_HEADERS = {
    # Prevent clickjacking
    "X-Frame-Options": "DENY",
    # Prevent MIME type sniffing
    "X-Content-Type-Options": "nosniff",
    # Enable XSS filter.
    # NOTE(review): this header is deprecated and ignored by modern browsers;
    # current guidance (OWASP/MDN) is "0" or omitting it — kept for compatibility.
    "X-XSS-Protection": "1; mode=block",
    # Referrer policy
    "Referrer-Policy": "strict-origin-when-cross-origin",
    # Permissions policy: deny access to these browser features
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}

# Path prefixes of the interactive API docs pages; the CSP below is not applied
# to them (presumably because they load scripts/styles the policy would block).
_CSP_EXEMPT_PREFIXES = ("/docs", "/redoc")

# Default Content Security Policy (adjust for your frontend, or override via
# the "content_security_policy" config key).
# NOTE(review): connect-src only allows localhost origins besides 'self' —
# looks development-oriented; confirm this is intended in production.
DEFAULT_CONTENT_SECURITY_POLICY = (
    "default-src 'self'; "
    "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
    "style-src 'self' 'unsafe-inline'; "
    "img-src 'self' data: blob:; "
    "font-src 'self' data:; "
    "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
    "worker-src 'self' blob:; "
)

# Default HSTS max-age: 31536000 seconds = 365 days.
DEFAULT_HSTS_MAX_AGE = 31536000


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses.

    Optional ``config`` keys:
    - ``enable_hsts`` (bool, default False): emit ``Strict-Transport-Security``.
      Only enable in production behind HTTPS.
    - ``hsts_max_age`` (int, default ``DEFAULT_HSTS_MAX_AGE``): HSTS max-age
      in seconds.
    - ``content_security_policy`` (str, default
      ``DEFAULT_CONTENT_SECURITY_POLICY``): CSP header value.
    """

    def __init__(self, app, config: Optional[dict] = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream app, then add security headers to its response."""
        response = await call_next(request)

        for name, value in _STATIC_SECURITY_HEADERS.items():
            response.headers[name] = value

        # The docs UIs are exempt from the CSP.
        if not request.url.path.startswith(_CSP_EXEMPT_PREFIXES):
            response.headers["Content-Security-Policy"] = self.config.get(
                "content_security_policy", DEFAULT_CONTENT_SECURITY_POLICY
            )

        # HSTS (only in production with HTTPS)
        if self.config.get("enable_hsts", False):
            max_age = int(self.config.get("hsts_max_age", DEFAULT_HSTS_MAX_AGE))
            response.headers["Strict-Transport-Security"] = (
                f"max-age={max_age}; includeSubDomains"
            )

        return response


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests for monitoring and debugging with structured context.

    For each request this:
    - generates a short request_id, stores it on ``request.state.request_id``
      and returns it in the ``X-Request-ID`` response header;
    - binds request_id / user_id into the logging context and always clears
      it when the request finishes (success or failure);
    - logs ``request_started``, then ``request_completed`` or, if the
      downstream app raises, ``request_error`` with the traceback before
      re-raising the exception.
    """

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        # NOTE: not read anywhere in this middleware; kept for API compatibility.
        self.log_body = log_body

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Milliseconds elapsed since ``start`` (a ``time.perf_counter()`` value)."""
        return round((time.perf_counter() - start) * 1000, 3)

    async def dispatch(self, request: Request, call_next) -> Response:
        # Short id: first 8 hex chars of a UUID4, for correlating log lines.
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        # Attempt to get user id if auth has already populated it.
        user_id = getattr(getattr(request.state, "user", None), "id", None)

        # Bind context so all downstream logs include request_id/user_id.
        # `is not None` so falsy-but-valid ids (e.g. 0) are not dropped.
        bind_request_context(
            request_id=request_id,
            user_id=str(user_id) if user_id is not None else None,
        )
        # Everything after binding sits inside this try so the context is
        # cleared even if IP extraction or the start log itself raises.
        try:
            client_ip = self._get_client_ip(request)

            # perf_counter is monotonic, so durations are immune to clock jumps.
            start_time = time.perf_counter()
            logger.info(
                "request_started",
                method=request.method,
                path=request.url.path,
                client_ip=client_ip,
            )

            try:
                response = await call_next(request)

                logger.info(
                    "request_completed",
                    method=request.method,
                    path=request.url.path,
                    status_code=response.status_code,
                    duration_ms=self._elapsed_ms(start_time),
                )

                # Add request ID to response headers
                response.headers["X-Request-ID"] = request_id
                return response
            except Exception as e:
                logger.exception(
                    "request_error",
                    method=request.method,
                    path=request.url.path,
                    duration_ms=self._elapsed_ms(start_time),
                    error=str(e),
                )
                raise
        finally:
            # Clear contextvars so they don't leak across requests
            clear_request_context()

    def _get_client_ip(self, request: Request) -> str:
        """Return a best-effort client IP for logging.

        Checks, in order: the first hop of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the connection's peer address. Returns
        ``"unknown"`` when none is available.

        NOTE: the forwarding headers are client-supplied unless a trusted
        reverse proxy overwrites them, so this value can be spoofed. Use it
        for logging only, not for access control or rate limiting.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop

        real_ip = (request.headers.get("X-Real-IP") or "").strip()
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"