143 lines
4.8 KiB
Python
143 lines
4.8 KiB
Python
"""
|
|
HTTP middleware for SaaS robustness.

Provides security headers on all responses, per-request logging with a
request ID, and conversion of unhandled exceptions into JSON 500 responses.
|
|
"""
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
import logging
|
|
|
|
# Module-level logger used by all middleware classes in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security-related HTTP headers to every response.

    Recognised ``config`` keys (all optional):

    * ``content_security_policy`` (str): value of ``Content-Security-Policy``.
      Defaults to ``DEFAULT_CONTENT_SECURITY_POLICY``.
    * ``csp_exempt_paths`` (str or iterable of str): path prefixes that get no
      CSP header. Defaults to ``("/docs", "/redoc")``.
    * ``x_xss_protection`` (str): value of ``X-XSS-Protection``. Defaults to
      ``"0"`` (legacy auditor disabled, as OWASP recommends).
    * ``enable_hsts`` (bool): send ``Strict-Transport-Security``. Defaults to
      ``False``; enable only when the site is served exclusively over HTTPS.
    * ``hsts_max_age`` (int): HSTS ``max-age`` in seconds. Defaults to one year.
    * ``hsts_include_subdomains`` (bool): append ``includeSubDomains``.
      Defaults to ``True``.
    """

    # Development-friendly default, kept identical to the historical value.
    # 'unsafe-inline'/'unsafe-eval' and the localhost connect-src weaken XSS
    # protection considerably; production should pass a stricter
    # ``content_security_policy`` via config.
    DEFAULT_CONTENT_SECURITY_POLICY = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: blob:; "
        "font-src 'self' data:; "
        "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
        "worker-src 'self' blob:; "
    )

    # Presumably exempted because the interactive API docs pages load assets
    # that the default CSP would block — confirm if changing the CSP.
    DEFAULT_CSP_EXEMPT_PATHS = ("/docs", "/redoc")

    DEFAULT_HSTS_MAX_AGE = 31536000  # one year, in seconds

    def __init__(self, app, config: dict = None):
        super().__init__(app)
        self.config = config or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream app, then set security headers on its response."""
        response = await call_next(request)
        headers = response.headers
        config = self.config

        # Prevent clickjacking.
        headers["X-Frame-Options"] = "DENY"
        # Prevent MIME type sniffing.
        headers["X-Content-Type-Options"] = "nosniff"
        # The legacy XSS auditor is gone from modern browsers, and
        # "1; mode=block" could itself be abused in older ones; OWASP
        # recommends "0" and relying on CSP instead. Overridable via config.
        headers["X-XSS-Protection"] = config.get("x_xss_protection", "0")
        headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"

        exempt = config.get("csp_exempt_paths", self.DEFAULT_CSP_EXEMPT_PATHS)
        if isinstance(exempt, str):
            # A bare string would otherwise be split into single characters.
            exempt = (exempt,)
        if not request.url.path.startswith(tuple(exempt)):
            headers["Content-Security-Policy"] = config.get(
                "content_security_policy", self.DEFAULT_CONTENT_SECURITY_POLICY
            )

        # HSTS: only in production with HTTPS.
        if config.get("enable_hsts", False):
            max_age = int(config.get("hsts_max_age", self.DEFAULT_HSTS_MAX_AGE))
            hsts = f"max-age={max_age}"
            if config.get("hsts_include_subdomains", True):
                hsts += "; includeSubDomains"
            headers["Strict-Transport-Security"] = hsts

        return response
|
|
|
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log the start, completion, and failure of every request.

    Each request gets a short random ID, stored on ``request.state.request_id``
    and echoed back in the ``X-Request-ID`` response header so that log lines
    can be correlated with client reports.

    Args:
        app: The ASGI application to wrap.
        log_body: Accepted for API compatibility but currently unused; request
            bodies are never logged.
        trust_proxy_headers: When ``True`` (the default, matching previous
            behaviour), the client IP is read from ``X-Forwarded-For`` /
            ``X-Real-IP``. Those headers are client-controlled unless a trusted
            reverse proxy overwrites them, so pass ``False`` when the app is
            reachable directly.
    """

    def __init__(self, app, log_body: bool = False, trust_proxy_headers: bool = True):
        super().__init__(app)
        self.log_body = log_body
        self.trust_proxy_headers = trust_proxy_headers

    async def dispatch(self, request: Request, call_next) -> Response:
        """Assign a request ID, log the request, and time the downstream call.

        Exceptions from the downstream app are logged and re-raised unchanged.
        """
        import time
        import uuid

        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        client_ip = self._get_client_ip(request)
        method = request.method
        path = request.url.path

        # perf_counter is monotonic: unlike time.time() it cannot jump when the
        # wall clock is adjusted (e.g. by NTP), so durations stay meaningful.
        start = time.perf_counter()
        logger.info("[%s] %s %s from %s - Started", request_id, method, path, client_ip)

        try:
            response = await call_next(request)
        except Exception as exc:
            logger.error(
                "[%s] %s %s - ERROR in %.3fs: %s",
                request_id, method, path, time.perf_counter() - start, exc,
            )
            raise

        logger.info(
            "[%s] %s %s - %s in %.3fs",
            request_id, method, path, response.status_code, time.perf_counter() - start,
        )
        response.headers["X-Request-ID"] = request_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Return a best-effort client IP address for logging.

        With proxy headers trusted, the left-most ``X-Forwarded-For`` hop wins,
        then ``X-Real-IP``. Otherwise, or when those headers are absent or blank,
        the socket peer address is used. Returns ``"unknown"`` if nothing is
        available.
        """
        # getattr keeps subclasses that bypass our __init__ on the old behaviour.
        if getattr(self, "trust_proxy_headers", True):
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                first_hop = forwarded.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP", "").strip()
            if real_ip:
                return real_ip

        if request.client:
            return request.client.host

        return "unknown"
|
|
|
|
|
|
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert unhandled exceptions into a generic JSON 500 response.

    The full traceback is logged server-side. The client receives only a
    generic message plus the request ID, so the failure can be matched to the
    logs. The ID is set by ``RequestLoggingMiddleware`` on
    ``request.state.request_id``; it is ``"unknown"`` otherwise.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        from starlette.responses import JSONResponse

        try:
            return await call_next(request)
        except Exception as exc:
            request_id = getattr(request.state, "request_id", "unknown")
            # logger.exception records the traceback; lazy %-args skip the
            # formatting work when the record is filtered out.
            logger.exception("[%s] Unhandled exception: %s", request_id, exc)

            # When this middleware wraps RequestLoggingMiddleware, the latter
            # re-raises before it can set X-Request-ID. Set it here so the
            # error response still carries the header.
            headers = {"X-Request-ID": request_id} if request_id != "unknown" else None

            # Deliberately generic: internal error details never reach the client.
            return JSONResponse(
                status_code=500,
                content={
                    "error": "internal_server_error",
                    "message": "An unexpected error occurred. Please try again later.",
                    "request_id": request_id,
                },
                headers=headers,
            )
|