96 lines
2.6 KiB
Python
96 lines
2.6 KiB
Python
"""
|
|
API dependencies.
|
|
|
|
This module provides common dependencies for API endpoints.
|
|
"""
|
|
|
|
from typing import Optional
|
|
from fastapi import Depends, HTTPException, status, Header
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import get_db
|
|
from app.services.apiKeyService import ApiKeyService
|
|
|
|
|
|
def _unauthorized(code: str, message: str) -> HTTPException:
    """
    Build a 401 Unauthorized error with a structured detail payload.

    The ``WWW-Authenticate`` header is attached because RFC 7235 requires a
    401 response to carry an authentication challenge; it tells clients the
    request failed for lack of valid credentials.

    Args:
        code: Machine-readable error code, placed in ``detail["code"]``.
        message: Human-readable explanation, placed in ``detail["message"]``.

    Returns:
        An ``HTTPException`` ready to be raised.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={"code": code, "message": message},
        headers={"WWW-Authenticate": "APIKey"},
    )


async def get_api_key(
    x_api_key: Optional[str] = Header(None, description="API Key for authentication"),
    db: Session = Depends(get_db)
) -> int:
    """
    Dependency to validate API key from header.

    Args:
        x_api_key: API key from the X-API-Key header (FastAPI maps the
            parameter name ``x_api_key`` to the ``x-api-key`` header).
        db: Database session

    Returns:
        User ID associated with valid API key

    Raises:
        HTTPException: 401 with code ``MISSING_API_KEY`` if the header is
            absent or empty, or ``INVALID_API_KEY`` if
            ``ApiKeyService.validate_api_key`` returns a falsy record.

    Example:
        from fastapi import APIRouter, Depends

        router = APIRouter()

        @router.get("/protected")
        async def protected_endpoint(user_id: int = Depends(get_api_key)):
            return {"user_id": user_id}
    """
    # An empty header value is treated the same as a missing one.
    if not x_api_key:
        raise _unauthorized(
            "MISSING_API_KEY",
            "X-API-Key header is required for this endpoint",
        )

    # NOTE(review): validate_api_key is called synchronously inside an async
    # dependency, so any DB I/O it performs runs on the event-loop thread.
    # A plain ``def`` dependency would be run in FastAPI's threadpool instead;
    # only switch once no caller awaits this function directly — TODO confirm.
    api_key_record = ApiKeyService(db).validate_api_key(x_api_key)
    if not api_key_record:
        raise _unauthorized("INVALID_API_KEY", "Invalid API key")

    return api_key_record.user_id
|
|
|
|
|
|
async def get_api_key_optional(
    x_api_key: Optional[str] = Header(None, description="Optional API Key for authentication"),
    db: Session = Depends(get_db)
) -> Optional[int]:
    """
    Resolve the caller's user ID from the X-API-Key header when possible.

    This dependency does not raise a 401 itself. Both of the following yield
    ``None``, leaving the endpoint to decide how to treat anonymous callers:
    an absent or empty header, and a key for which
    ``ApiKeyService.validate_api_key`` returns a falsy record.

    Args:
        x_api_key: Value of the X-API-Key header, or None when it was not sent.
        db: Database session passed to ``ApiKeyService`` for the lookup.

    Returns:
        The ``user_id`` of the validated API key record, or None.

    Example:
        from typing import Optional

        from fastapi import APIRouter, Depends

        router = APIRouter()

        @router.get("/semi-public")
        async def semi_public_endpoint(user_id: Optional[int] = Depends(get_api_key_optional)):
            if user_id:
                return {"message": "Authenticated", "user_id": user_id}
            return {"message": "Unauthenticated"}
    """
    if x_api_key:
        record = ApiKeyService(db).validate_api_key(x_api_key)
        if record:
            return record.user_id

    # Missing, empty, or unrecognised key: treat the caller as anonymous.
    return None
|