Files
office_translator/middleware/security.py
Sepehr Ramezani 26bd096a06 feat: production deployment - full update with providers, admin, glossaries, pricing, tests
Major changes across backend, frontend, infrastructure:
- Provider system with model selection (Google, DeepL, OpenAI, Ollama, Google Cloud)
- Admin panel: user management, pricing, settings
- Glossary system with CSV import/export
- Subscription and tier quota management
- Security hardening (rate limiting, API key auth, path traversal fixes)
- Docker compose for dev, prod, and IONOS deployment
- Alembic migrations for new tables
- Frontend: dashboard, pricing page, landing page, i18n (en/fr)
- Test suite and verification scripts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-04-25 15:01:47 +02:00

141 lines
4.6 KiB
Python

"""
Security & request logging middlewares.

SecurityHeadersMiddleware adds a fixed set of security response headers
(X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, Referrer-Policy,
Permissions-Policy, a Content-Security-Policy except on the API docs pages,
and optionally Strict-Transport-Security).

RequestLoggingMiddleware is responsible for:
- Assigning a request_id to each request
- Binding request_id and user_id into structlog context so all logs
  include these fields in JSON output (Story 6.4).
"""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from core.logging import get_logger, bind_request_context, clear_request_context
# Module-level logger shared by the middlewares below.
logger = get_logger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses.

    Optional keys read from ``config``:

    - ``enable_hsts`` (bool, default ``False``): also send
      ``Strict-Transport-Security``. Only enable when served over HTTPS.
    - ``content_security_policy`` (str): replaces ``DEFAULT_CSP``.
    - ``csp_exempt_paths`` (iterable of str): paths that receive no
      Content-Security-Policy header. A request path matches an entry when it
      is equal to it or lies underneath it (``entry + "/..."``). Defaults to
      ``DEFAULT_CSP_EXEMPT_PATHS``.
    """

    # Default Content Security Policy (adjust for your frontend). Note that
    # 'unsafe-inline'/'unsafe-eval' and the localhost connect-src sources are
    # permissive; production deployments can tighten them through
    # config["content_security_policy"].
    DEFAULT_CSP = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: blob:; "
        "font-src 'self' data:; "
        "connect-src 'self' http://localhost:* https://localhost:* ws://localhost:*; "
        "worker-src 'self' blob:; "
    )

    # API documentation pages skipped by the CSP — presumably because the
    # default policy would break them; confirm if the docs setup changes.
    DEFAULT_CSP_EXEMPT_PATHS = ("/docs", "/redoc")

    def __init__(self, app, config: dict = None):
        super().__init__(app)
        self.config = config or {}
        self.csp = self.config.get("content_security_policy", self.DEFAULT_CSP)
        # Normalise trailing slashes so "/docs/" and "/docs" behave the same.
        self.csp_exempt_paths = tuple(
            prefix.rstrip("/")
            for prefix in self.config.get("csp_exempt_paths", self.DEFAULT_CSP_EXEMPT_PATHS)
        )

    def _is_csp_exempt(self, path: str) -> bool:
        """Return True if *path* is an exempt path or lies underneath one.

        Matching is segment-aware: "/docs" and "/docs/oauth2-redirect" match
        the "/docs" entry, but "/documents" does not (a plain ``startswith``
        would wrongly strip the CSP from such unrelated routes).
        """
        return any(
            path == prefix or path.startswith(prefix + "/")
            for prefix in self.csp_exempt_paths
        )

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        headers = response.headers
        # Prevent clickjacking
        headers["X-Frame-Options"] = "DENY"
        # Prevent MIME type sniffing
        headers["X-Content-Type-Options"] = "nosniff"
        # Legacy browser XSS filter. NOTE(review): this header is deprecated and
        # the OWASP Secure Headers Project recommends "0"; the value is kept
        # unchanged here to avoid altering behaviour — confirm and update.
        headers["X-XSS-Protection"] = "1; mode=block"
        # Referrer policy
        headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Permissions policy
        headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
        # Content Security Policy, skipped on the exempt (docs) pages.
        if not self._is_csp_exempt(request.url.path):
            headers["Content-Security-Policy"] = self.csp
        # HSTS (only in production with HTTPS)
        if self.config.get("enable_hsts", False):
            headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        return response
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests for monitoring and debugging with structured context.

    For each request this middleware:

    - generates an 8-character request id, stores it on
      ``request.state.request_id`` and returns it in the ``X-Request-ID``
      response header;
    - binds ``request_id``/``user_id`` into the logging context and clears it
      again when the request finishes;
    - logs ``request_started``, then ``request_completed`` or, if the
      downstream app raises, ``request_error`` (the exception is re-raised).

    Args:
        app: The ASGI application to wrap.
        log_body: Stored on the instance; not read by this class.
        trust_proxy_headers: When True (the default, matching the previous
            behaviour) the logged client IP is taken from ``X-Forwarded-For``
            / ``X-Real-IP``. Those headers are client-controlled unless a
            trusted reverse proxy overwrites them, so pass False when the app
            is reachable without such a proxy.
    """

    def __init__(self, app, log_body: bool = False, trust_proxy_headers: bool = True):
        super().__init__(app)
        self.log_body = log_body
        self.trust_proxy_headers = trust_proxy_headers

    async def dispatch(self, request: Request, call_next) -> Response:
        import time
        import uuid

        # Generate request ID
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        # Attempt to get user id if an earlier layer (e.g. auth) already
        # populated request.state.user; otherwise this stays None.
        user_id = getattr(getattr(request.state, "user", None), "id", None)
        # Bind context so all downstream logs include request_id/user_id.
        # Compare with None so a legitimate falsy id (e.g. 0) is not dropped.
        bind_request_context(
            request_id=request_id,
            user_id=str(user_id) if user_id is not None else None,
        )

        client_ip = self._get_client_ip(request)

        # perf_counter is monotonic, so durations are unaffected by
        # wall-clock adjustments (NTP, DST, manual changes).
        start_time = time.perf_counter()
        logger.info(
            "request_started",
            method=request.method,
            path=request.url.path,
            client_ip=client_ip,
        )
        try:
            response = await call_next(request)
            duration = time.perf_counter() - start_time
            logger.info(
                "request_completed",
                method=request.method,
                path=request.url.path,
                status_code=response.status_code,
                duration_ms=round(duration * 1000, 3),
            )
            # Echo the request ID so clients can correlate with server logs.
            response.headers["X-Request-ID"] = request_id
            return response
        except Exception as e:
            duration = time.perf_counter() - start_time
            logger.exception(
                "request_error",
                method=request.method,
                path=request.url.path,
                duration_ms=round(duration * 1000, 3),
                error=str(e),
            )
            raise
        finally:
            # Clear contextvars so they don't leak across requests
            clear_request_context()

    def _get_client_ip(self, request: Request) -> str:
        """Return a best-effort client IP for logging.

        When ``trust_proxy_headers`` is enabled, the first (left-most)
        non-empty ``X-Forwarded-For`` entry is preferred, then ``X-Real-IP``.
        Otherwise, or if neither header yields a value, the socket peer
        address is used. Returns ``"unknown"`` when nothing is available.
        """
        if self.trust_proxy_headers:
            forwarded = request.headers.get("X-Forwarded-For", "")
            first_hop = forwarded.split(",")[0].strip()
            # An empty first hop (e.g. " , 1.2.3.4") falls through instead of
            # being logged as "".
            if first_hop:
                return first_hop
            real_ip = request.headers.get("X-Real-IP", "").strip()
            if real_ip:
                return real_ip
        if request.client:
            return request.client.host
        return "unknown"