Files
office_translator/services/auth_service.py
2026-03-07 11:42:58 +01:00

572 lines
18 KiB
Python

"""
Authentication service with JWT tokens and password hashing
This service provides user authentication with automatic backend selection:
- If DATABASE_URL is configured: Uses PostgreSQL database
- Otherwise: Falls back to JSON file storage (development mode)
"""
import os
import secrets
import hashlib
import uuid
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
import json
from pathlib import Path
import logging
logger = logging.getLogger(__name__)
# Optional third-party dependencies. Each one has a reduced-security fallback
# elsewhere in this module, selected through the *_AVAILABLE flags set here.
try:
    import jwt
except ImportError:
    JWT_AVAILABLE = False
    logger.warning("PyJWT not installed. Using fallback token encoding.")
else:
    JWT_AVAILABLE = True

try:
    from passlib.context import CryptContext

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
except ImportError:
    PASSLIB_AVAILABLE = False
    logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")
else:
    PASSLIB_AVAILABLE = True
# Storage backend selection: PostgreSQL when DATABASE_URL has a "postgresql"
# scheme AND the project's database package imports cleanly; otherwise the
# JSON file store (development mode).
DATABASE_URL = os.getenv("DATABASE_URL", "")
USE_DATABASE = DATABASE_URL.startswith("postgresql")
DATABASE_AVAILABLE = False
if not USE_DATABASE:
    logger.info(
        "Using JSON file storage for authentication (DATABASE_URL not configured)"
    )
else:
    try:
        from database.repositories import UserRepository
        from database.connection import get_sync_session, init_db as _init_db
        from database import models as db_models
    except ImportError as import_error:
        # Downgrade to JSON storage rather than failing at import time.
        USE_DATABASE = False
        logger.warning(
            f"Database modules not available: {import_error}. Using JSON storage."
        )
    else:
        DATABASE_AVAILABLE = True
        logger.info("Database backend enabled for authentication")
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
# Configuration
# JWT_SECRET takes precedence over JWT_SECRET_KEY. `or` is used instead of
# getenv's default argument so that a variable which is set but empty (e.g. a
# blank `JWT_SECRET=` line in a .env template) does not shadow a real
# JWT_SECRET_KEY and silently trigger the ephemeral-key fallback below.
_jwt_secret = os.getenv("JWT_SECRET") or os.getenv("JWT_SECRET_KEY")
if not _jwt_secret:
    _jwt_secret = secrets.token_urlsafe(32)
    logger.critical(
        "SECURITY: JWT_SECRET_KEY is not configured! Using an ephemeral random key. "
        "ALL JWT TOKENS WILL BE INVALIDATED ON EVERY RESTART. "
        "Set JWT_SECRET_KEY in your .env file immediately."
    )
SECRET_KEY = _jwt_secret
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 15
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Simple file-based storage (used when database is not configured).
# NOTE: a relative path, resolved against the process working directory.
USERS_FILE = Path("data/users.json")
USERS_FILE.parent.mkdir(exist_ok=True)

# Token blocklist: jti -> expiry timestamp (Unix).
# Uses Redis when available (persistent across restarts), falls back to in-memory.
_revoked_jtis: dict[str, float] = {}
# Cached Redis client for the blocklist: None = not tried yet, False = unavailable.
_redis_blocklist_client = None
def _get_blocklist_redis():
    """Return the Redis client backing the token blocklist, or None.

    The first call decides, and caches the decision in
    ``_redis_blocklist_client``. The cached value is a pinged client when
    REDIS_URL is set and reachable. Otherwise it is ``False`` (no REDIS_URL,
    redis not importable, or connection/ping failure). Later calls reuse the
    cached decision without retrying.
    """
    global _redis_blocklist_client
    cached = _redis_blocklist_client
    if cached is False:
        return None
    if cached is not None:
        return cached

    redis_url = os.getenv("REDIS_URL", "")
    if not redis_url:
        _redis_blocklist_client = False
        return None

    try:
        import redis as redis_lib

        client = redis_lib.from_url(redis_url, decode_responses=True)
        client.ping()
    except Exception as exc:
        logger.warning(f"Redis unavailable for token blocklist, using in-memory: {exc}")
        _redis_blocklist_client = False
        return None

    _redis_blocklist_client = client
    logger.info("Token blocklist using Redis (persistent across restarts)")
    return client
def revoke_token_jti(jti: str, expires_at: float) -> None:
    """Blocklist *jti* until *expires_at* (a Unix timestamp).

    With Redis, the key ``revoked_jti:<jti>`` is written with a TTL equal to
    the remaining lifetime (at least one second). If Redis is not configured,
    or the write fails, the entry goes into the process-local
    ``_revoked_jtis`` map instead.
    """
    remaining = max(1, int(expires_at - time.time()))
    client = _get_blocklist_redis()
    if not client:
        _revoked_jtis[jti] = expires_at
        return
    try:
        client.setex(f"revoked_jti:{jti}", remaining, "1")
    except Exception as exc:
        logger.warning(f"Redis revoke failed, falling back to memory: {exc}")
        _revoked_jtis[jti] = expires_at
def is_token_revoked(jti: str) -> bool:
    """Return True if *jti* is on the revocation blocklist.

    Redis (when configured) is consulted first. The in-memory map is checked
    as well, because revoke_token_jti() stores entries there whenever a Redis
    write fails. Without that second check, a token revoked during a Redis
    outage would become valid again as soon as Redis answered. Expired
    in-memory entries are purged lazily on each call.

    An empty/None *jti* is never considered revoked.
    """
    if not jti:
        return False
    client = _get_blocklist_redis()
    if client:
        try:
            if client.exists(f"revoked_jti:{jti}") == 1:
                return True
        except Exception as e:
            logger.warning(f"Redis revoke check failed, falling back to memory: {e}")
    now = time.time()
    for stale in [key for key, expiry in _revoked_jtis.items() if expiry < now]:
        _revoked_jtis.pop(stale, None)
    return jti in _revoked_jtis
# Work factor for the stdlib PBKDF2 fallback (OWASP guidance for
# PBKDF2-HMAC-SHA256). It is stored inside every hash, so it can be raised
# later without invalidating existing passwords.
_PBKDF2_ITERATIONS = 600_000


def hash_password(password: str) -> str:
    """Hash *password* for storage.

    Uses passlib's bcrypt context when passlib is installed. Otherwise uses
    stdlib PBKDF2-HMAC-SHA256 with a random 16-byte salt, encoded as
    ``pbkdf2_sha256$<iterations>$<salt_hex>$<hash_hex>``. A single salted
    SHA-256 round is too fast to resist offline guessing, so it is no longer
    produced. Such legacy ``sha256$`` hashes are still accepted by
    verify_password.
    """
    if PASSLIB_AVAILABLE:
        return pwd_context.hash(password)
    salt = secrets.token_hex(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), bytes.fromhex(salt), _PBKDF2_ITERATIONS
    ).hex()
    return f"pbkdf2_sha256${_PBKDF2_ITERATIONS}${salt}${digest}"


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the stored *hashed_password*.

    Recognised formats:
      * ``pbkdf2_sha256$<iterations>$<salt_hex>$<hash_hex>`` (stdlib fallback)
      * ``sha256$<salt>$<hash_hex>`` (legacy fallback, one SHA-256 round)
      * anything else is passed to passlib (bcrypt) when it is installed.

    A missing, malformed, or unrecognised hash yields False instead of raising.
    """
    if not hashed_password:
        return False
    parts = hashed_password.split("$")
    scheme = parts[0]

    if scheme == "pbkdf2_sha256":
        if len(parts) != 4:
            return False
        try:
            iterations = int(parts[1])
            salt = bytes.fromhex(parts[2])
        except ValueError:
            return False
        if iterations <= 0:
            return False
        actual = hashlib.pbkdf2_hmac(
            "sha256", plain_password.encode("utf-8"), salt, iterations
        ).hex()
        # Compare as bytes so non-ASCII garbage in storage cannot raise TypeError.
        return secrets.compare_digest(actual.encode(), parts[3].encode())

    if scheme == "sha256":
        if len(parts) != 3:
            return False
        salt_text, expected_hash = parts[1], parts[2]
        actual = hashlib.sha256(f"{salt_text}{plain_password}".encode()).hexdigest()
        return secrets.compare_digest(actual.encode(), expected_hash.encode())

    if PASSLIB_AVAILABLE:
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except ValueError:
            # passlib raises ValueError (incl. UnknownHashError) for hashes it
            # cannot identify; treat that as a failed match.
            return False
    return False
def _b64url_encode(raw: bytes) -> str:
    """Return *raw* as URL-safe base64 text without ``=`` padding."""
    import base64

    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def _b64url_decode(text: str) -> bytes:
    """Inverse of ``_b64url_encode``: restore the padding, then decode."""
    import base64

    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def _fallback_signature(body: str) -> bytes:
    """HMAC-SHA256 of *body* keyed with SECRET_KEY (fallback tokens only)."""
    import hmac

    return hmac.new(SECRET_KEY.encode(), body.encode(), hashlib.sha256).digest()


def _encode_fallback_token(claims: Dict[str, Any]) -> str:
    """Serialize *claims* as ``<base64 JSON>.<base64 HMAC>`` (PyJWT missing).

    The HMAC stops clients from forging or editing tokens. An unsigned base64
    payload can be rewritten by anyone to impersonate any user.
    """
    body = _b64url_encode(json.dumps(claims, separators=(",", ":")).encode())
    return f"{body}.{_b64url_encode(_fallback_signature(body))}"


def _decode_fallback_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token produced by ``_encode_fallback_token``; return its claims.

    Returns None if the token is malformed, its signature does not match, or
    its numeric ``exp`` claim is in the past.
    """
    import hmac

    try:
        body, signature = token.split(".")
        if not hmac.compare_digest(_b64url_decode(signature), _fallback_signature(body)):
            return None
        claims = json.loads(_b64url_decode(body))
    except (AttributeError, TypeError, ValueError):
        # Wrong segment count, bad base64 / UTF-8 / JSON (all ValueError
        # subclasses), or a token that is not a str.
        return None
    if not isinstance(claims, dict):
        return None
    exp = claims.get("exp")
    if not isinstance(exp, (int, float)) or exp < time.time():
        return None
    return claims


def _issue_token(
    user_id: str, token_type: str, lifetime: timedelta, **extra_claims: Any
) -> str:
    """Build the standard claim set and encode it (PyJWT or signed fallback).

    Claims, in order: ``sub`` (user id), any *extra_claims*, ``exp``, ``type``
    and a random ``jti`` that revoke_token_jti()/is_token_revoked() key on.
    """
    expire = datetime.now(timezone.utc) + lifetime
    claims: Dict[str, Any] = {
        "sub": user_id,
        **extra_claims,
        "exp": expire,
        "type": token_type,
        "jti": str(uuid.uuid4()),
    }
    if JWT_AVAILABLE:
        return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    # JSON has no datetime type; use the integer Unix time PyJWT would emit.
    claims["exp"] = int(expire.timestamp())
    return _encode_fallback_token(claims)


def create_access_token(
    user_id: str, tier: str = "free", expires_delta: Optional[timedelta] = None
) -> str:
    """Create an access token with a ``tier`` claim for quick access.

    Args:
        user_id: Subject of the token (``sub`` claim).
        tier: Tier embedded so callers can read it without a user lookup.
        expires_delta: Lifetime; ACCESS_TOKEN_EXPIRE_MINUTES when None. An
            explicit zero/negative delta yields an already-expired token.

    Returns:
        An HS256 JWT, or an HMAC-signed fallback token when PyJWT is missing.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return _issue_token(user_id, "access", expires_delta, tier=tier)


def create_refresh_token(
    user_id: str, expires_delta: Optional[timedelta] = None
) -> str:
    """Create a refresh token (REFRESH_TOKEN_EXPIRE_DAYS by default).

    Args:
        user_id: Subject of the token (``sub`` claim).
        expires_delta: Lifetime; REFRESH_TOKEN_EXPIRE_DAYS when None.

    Returns:
        An HS256 JWT, or an HMAC-signed fallback token when PyJWT is missing.
    """
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return _issue_token(user_id, "refresh", expires_delta)


def verify_token(
    token: str, expected_type: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Verify *token* and return its claims (``sub``, ``exp``, ``type``, ``jti``...).

    Returns None when the token is empty, fails signature/format checks, has
    expired, or has a revoked ``jti``. If *expected_type* is given (e.g.
    "access" or "refresh"), None is also returned when the token's ``type``
    claim differs. Without it, a refresh token is accepted wherever an access
    token is.
    """
    if not token:
        return None
    if JWT_AVAILABLE:
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.PyJWTError:  # also covers ExpiredSignatureError
            return None
    else:
        payload = _decode_fallback_token(token)
        if payload is None:
            return None
    jti = payload.get("jti")
    if jti and is_token_revoked(jti):
        return None
    if expected_type is not None and payload.get("type") != expected_type:
        return None
    return payload
def load_users() -> Dict[str, Dict]:
    """Read all stored user records from USERS_FILE (JSON backend only).

    The mapping is keyed by user ID. Returns an empty dict when the file does
    not exist or cannot be read/parsed; such failures are logged, not raised.
    """
    if not USERS_FILE.exists():
        return {}
    try:
        with open(USERS_FILE, "r") as handle:
            return json.load(handle)
    except Exception as exc:
        logger.error(f"Failed to load users file: {exc}")
        return {}
def save_users(users: Dict[str, Dict]):
    """Save users to file storage (JSON backend only).

    *users* is written as indented JSON (non-JSON values are converted with
    str()) to a temporary file in the same directory. That file is fsynced
    and then atomically moved over USERS_FILE with os.replace().

    A crash or serialization error mid-write therefore leaves the previous file
    intact. With a plain truncating write, a half-written file would make
    load_users() return {}, and the next save would discard every stored
    account.

    The file is created by tempfile.mkstemp(), so it is readable and writable
    only by the owning user.
    """
    import tempfile

    fd, tmp_path = tempfile.mkstemp(
        prefix=f".{USERS_FILE.name}.", suffix=".tmp", dir=USERS_FILE.parent
    )
    try:
        with os.fdopen(fd, "w") as handle:
            json.dump(users, handle, indent=2, default=str)
            handle.flush()
            os.fsync(handle.fileno())
        os.replace(tmp_path, USERS_FILE)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def _db_user_to_model(db_user) -> User:
    """Convert an ORM user row into the Pydantic ``User`` model.

    Empty/null columns fall back to defaults:
      * plan FREE, subscription ACTIVE;
      * usage counters and credits 0;
      * source/target languages "en"/"es", provider "google";
      * the current UTC time for usage_reset_date and created_at.

    Columns read through getattr() may be missing from the ORM object
    altogether.
    """
    if db_user.plan:
        plan = PlanType(db_user.plan)
    else:
        plan = PlanType.FREE
    status = (
        SubscriptionStatus(db_user.subscription_status)
        if db_user.subscription_status
        else SubscriptionStatus.ACTIVE
    )

    return User(
        id=str(db_user.id),
        email=db_user.email,
        name=db_user.name or "",
        password_hash=db_user.password_hash,
        avatar_url=db_user.avatar_url,
        plan=plan,
        subscription_status=status,
        stripe_customer_id=db_user.stripe_customer_id,
        stripe_subscription_id=db_user.stripe_subscription_id,
        docs_translated_this_month=db_user.docs_translated_this_month or 0,
        pages_translated_this_month=db_user.pages_translated_this_month or 0,
        api_calls_this_month=db_user.api_calls_this_month or 0,
        daily_translation_count=getattr(db_user, "daily_translation_count", 0) or 0,
        extra_credits=db_user.extra_credits or 0,
        usage_reset_date=db_user.usage_reset_date or datetime.now(timezone.utc),
        default_source_lang=getattr(db_user, "default_source_lang", None) or "en",
        default_target_lang=getattr(db_user, "default_target_lang", None) or "es",
        default_provider=getattr(db_user, "default_provider", None) or "google",
        created_at=db_user.created_at or datetime.now(timezone.utc),
        updated_at=db_user.updated_at,
    )
def get_user_by_email(email: str) -> Optional[User]:
    """Look up a user by email address; return None when nobody matches.

    JSON backend: case-insensitive comparison against each stored record's
    ``email``, first match wins. Database backend: delegated to
    ``UserRepository.get_by_email`` (its case handling is not visible here).
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        records = load_users().values()
        found = next(
            (
                record
                for record in records
                if record.get("email", "").lower() == email.lower()
            ),
            None,
        )
        return None if found is None else User(**found)

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        db_user = UserRepository(session).get_by_email(email)
        # Convert while the session is still open.
        return _db_user_to_model(db_user) if db_user else None
def get_user_by_id(user_id: str) -> Optional[User]:
    """Fetch a user by ID from the active backend; None if no such user."""
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        users = load_users()
        if user_id not in users:
            return None
        return User(**users[user_id])

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        found = UserRepository(session).get_by_id(user_id)
        return _db_user_to_model(found) if found else None
def create_user(user_create: UserCreate) -> User:
    """Register a new account on the FREE plan with ACTIVE status.

    Raises:
        ValueError: "Email already registered" when get_user_by_email() finds
            an existing user.

    NOTE(review): the existence check and the insert are separate steps, so
    concurrent registrations for the same email are not excluded here.
    """
    if get_user_by_email(user_create.email):
        raise ValueError("Email already registered")

    if not (USE_DATABASE and DATABASE_AVAILABLE):
        users = load_users()
        new_id = secrets.token_urlsafe(16)
        new_user = User(
            id=new_id,
            email=user_create.email,
            name=user_create.name,
            password_hash=hash_password(user_create.password),
            plan=PlanType.FREE,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        users[new_id] = new_user.model_dump()
        save_users(users)
        return new_user

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        db_user = UserRepository(session).create(
            email=user_create.email,
            name=user_create.name,
            hashed_password=hash_password(user_create.password),
            tier="free",
        )
        return _db_user_to_model(db_user)
def authenticate_user(email: str, password: str) -> Optional[User]:
    """Return the user identified by *email* and *password*, or None.

    None is returned for an unknown email, an account whose ``password_hash``
    is empty/None, or a wrong password. The empty-hash check keeps such
    accounts (presumably password-less sign-ups — confirm) from crashing
    verify_password().
    """
    user = get_user_by_email(email)
    if not user or not user.password_hash:
        return None
    return user if verify_password(password, user.password_hash) else None
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply *updates* to a user and return the refreshed ``User``.

    Args:
        user_id: ID of the user to modify.
        updates: Field name -> new value. Database backend: passed to
            ``UserRepository.update``. JSON backend: merged into the stored
            record as-is, and ``updated_at`` is set to now (ISO string).

    Returns:
        The updated user, or None if no user has *user_id*.

    Raises:
        Whatever ``User(...)`` raises for invalid data (JSON backend). The
        merged record is validated before it is written, so a rejected
        update leaves the stored file unchanged.
    """
    if USE_DATABASE and DATABASE_AVAILABLE:
        from database.connection import get_sync_session
        from database.repositories import UserRepository

        with get_sync_session() as session:
            repo = UserRepository(session)
            db_user = repo.update(user_id, **updates)
            if db_user:
                return _db_user_to_model(db_user)
            return None

    users = load_users()
    if user_id not in users:
        return None
    record = {
        **users[user_id],
        **updates,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    # Validate before persisting. Saving an invalid record would make every
    # later User(**...) load of this user fail.
    user = User(**record)
    users[user_id] = record
    save_users(users)
    return user
def check_usage_limits(user: User) -> Dict[str, Any]:
    """Check if user has exceeded their plan limits.

    Side effect: when ``user.usage_reset_date`` lies in an earlier calendar
    month than now (UTC), the monthly counters (docs, pages, API calls) are
    reset to 0. The reset is applied in storage via update_user() and also on
    the passed-in *user* object, whose ``usage_reset_date`` moves to now.

    Returns:
        A dict with:
          * ``can_translate``: True when monthly docs remain, the plan has no
            positive doc limit, or the user has extra credits;
          * the monthly doc quota: ``docs_used``, ``docs_limit`` and
            ``docs_remaining`` (-1 when ``docs_limit`` is not positive);
          * ``pages_used`` and ``extra_credits``;
          * the plan's ``max_pages_per_doc``, ``max_file_size_mb`` and
            ``allowed_providers``.
    """
    plan = PLANS[user.plan]

    now = datetime.now(timezone.utc)
    last_reset = user.usage_reset_date
    if (last_reset.year, last_reset.month) != (now.year, now.month):
        update_user(
            user.id,
            {
                "docs_translated_this_month": 0,
                "pages_translated_this_month": 0,
                "api_calls_this_month": 0,
                # The JSON store keeps timestamps as ISO strings.
                "usage_reset_date": now.isoformat() if not USE_DATABASE else now,
            },
        )
        user.docs_translated_this_month = 0
        user.pages_translated_this_month = 0
        user.api_calls_this_month = 0
        # Mirror the stored reset date as well. Otherwise calling this again
        # with the same (stale) object would issue another reset and wipe any
        # usage recorded in storage since the first call.
        user.usage_reset_date = now

    docs_limit = plan["docs_per_month"]
    # NOTE(review): a non-positive limit is treated as unlimited (-1); confirm
    # PLANS never uses 0 to mean "no documents allowed".
    docs_remaining = (
        max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1
    )
    return {
        "can_translate": docs_remaining != 0 or user.extra_credits > 0,
        "docs_used": user.docs_translated_this_month,
        "docs_limit": docs_limit,
        "docs_remaining": docs_remaining,
        "pages_used": user.pages_translated_this_month,
        "extra_credits": user.extra_credits,
        "max_pages_per_doc": plan["max_pages_per_doc"],
        "max_file_size_mb": plan["max_file_size_mb"],
        "allowed_providers": plan["providers"],
    }
def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> bool:
    """Record one translated document of *pages_count* pages for a user.

    Adds 1 to the monthly document counter and *pages_count* to the monthly
    page counter. When *use_credits* is set, *pages_count* is also deducted
    from ``extra_credits`` (never below 0).

    Returns False if the user does not exist or update_user() returns None.
    NOTE: read-then-write, so concurrent calls for one user can lose updates.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    changes: Dict[str, Any] = dict(
        docs_translated_this_month=user.docs_translated_this_month + 1,
        pages_translated_this_month=user.pages_translated_this_month + pages_count,
    )
    if use_credits:
        changes["extra_credits"] = max(0, user.extra_credits - pages_count)
    return update_user(user_id, changes) is not None
def add_credits(user_id: str, credits: int) -> bool:
    """Increase a user's ``extra_credits`` balance by *credits*.

    Returns False if the user does not exist or update_user() returns None.
    No sign check is applied, so a negative *credits* lowers the balance.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False
    new_balance = user.extra_credits + credits
    return update_user(user_id, {"extra_credits": new_balance}) is not None
# Valid plan values for admin tier change (Story 1.7)
VALID_PLAN_VALUES = {"free", "starter", "pro", "business", "enterprise"}


def update_user_plan(user_id: str, plan: str) -> Optional[User]:
    """
    Update a user's plan/tier (admin only). Keeps User.plan and User.tier in sync.

    tier is set to 'pro' for pro/business/enterprise, 'free' otherwise (DB constraint).

    Args:
        user_id: ID of the user to change.
        plan: Plan name; case and surrounding whitespace are ignored.

    Returns:
        The updated user, or None if *plan* is not an accepted plan value or
        update_user() finds no such user.
    """
    plan_lower = (plan or "").strip().lower()
    if plan_lower not in VALID_PLAN_VALUES:
        return None
    try:
        plan_enum = PlanType(plan_lower)
    except ValueError:
        # VALID_PLAN_VALUES is maintained separately from PlanType. If the two
        # drift apart, reject the value like any other invalid plan instead of
        # raising out of the admin endpoint.
        logger.warning(f"Plan {plan_lower!r} is in VALID_PLAN_VALUES but not in PlanType")
        return None

    is_pro_tier = plan_enum in (PlanType.PRO, PlanType.BUSINESS, PlanType.ENTERPRISE)
    tier = "pro" if is_pro_tier else "free"
    # The database layer receives the enum; the JSON store keeps a plain string.
    stored_plan = plan_enum if USE_DATABASE and DATABASE_AVAILABLE else plan_lower
    return update_user(user_id, {"plan": stored_plan, "tier": tier})
def get_user_by_api_key(api_key: str) -> Optional[User]:
    """
    Get the user that owns an API key (database backend only).

    The key is matched by its SHA-256 hex digest against ``ApiKey.key_hash``.
    When a matching key is active and not expired, its ``last_used_at`` and
    ``usage_count`` are updated and committed.

    Returns:
        The owning user. Returns None when *api_key* is empty, no record
        matches, or the JSON storage backend is in use (it has no API keys).

    Raises:
        ValueError: "API_KEY_REVOKED" if the key exists but is inactive;
            "API_KEY_EXPIRED" if expires_at is set and already in the past.
    """
    if not api_key:
        return None
    # Only database backend supports API keys
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        return None

    from database.connection import get_sync_session
    from database.models import ApiKey

    # Keys are stored hashed; hash the presented key the same way.
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    with get_sync_session() as session:
        api_key_record = (
            session.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
        )
        if not api_key_record:
            return None
        # Revocation check (Story 3.2)
        if not api_key_record.is_active:
            raise ValueError("API_KEY_REVOKED")

        now = datetime.now(timezone.utc)
        expires_at = api_key_record.expires_at
        if expires_at:
            if expires_at.tzinfo is None:
                # A naive datetime cannot be compared with the aware `now`
                # (TypeError). NOTE(review): assumes naive values are UTC.
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at < now:
                raise ValueError("API_KEY_EXPIRED")

        api_key_record.last_used_at = now
        api_key_record.usage_count = (api_key_record.usage_count or 0) + 1
        session.commit()

        user_id = api_key_record.user_id
        return get_user_by_id(str(user_id))
def init_database():
    """Run the database package's init routine at startup (no-op for JSON storage)."""
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        logger.info("Using JSON file storage")
        return
    _init_db()
    logger.info("Database initialized successfully")