""" Authentication service with JWT tokens and password hashing This service provides user authentication with automatic backend selection: - If DATABASE_URL is configured: Uses PostgreSQL database - Otherwise: Falls back to JSON file storage (development mode) """ import os import secrets import hashlib import uuid import time from datetime import datetime, timedelta, timezone from typing import Optional, Dict, Any import json from pathlib import Path import logging logger = logging.getLogger(__name__) # Try to import optional dependencies try: import jwt JWT_AVAILABLE = True except ImportError: JWT_AVAILABLE = False logger.warning("PyJWT not installed. Using fallback token encoding.") try: from passlib.context import CryptContext pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") PASSLIB_AVAILABLE = True except ImportError: PASSLIB_AVAILABLE = False logger.warning("passlib not installed. Using SHA256 fallback for password hashing.") # Check if database is configured DATABASE_URL = os.getenv("DATABASE_URL", "") USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql")) if USE_DATABASE: try: from database.repositories import UserRepository from database.connection import get_sync_session, init_db as _init_db from database import models as db_models DATABASE_AVAILABLE = True logger.info("Database backend enabled for authentication") except ImportError as e: DATABASE_AVAILABLE = False USE_DATABASE = False logger.warning(f"Database modules not available: {e}. Using JSON storage.") else: DATABASE_AVAILABLE = False logger.info( "Using JSON file storage for authentication (DATABASE_URL not configured)" ) from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS # Configuration _jwt_secret = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY")) if not _jwt_secret: _jwt_secret = secrets.token_urlsafe(32) logger.critical( "SECURITY: JWT_SECRET_KEY is not configured! Using an ephemeral random key. 
" "ALL JWT TOKENS WILL BE INVALIDATED ON EVERY RESTART. " "Set JWT_SECRET_KEY in your .env file immediately." ) SECRET_KEY = _jwt_secret ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 15 REFRESH_TOKEN_EXPIRE_DAYS = 7 # Simple file-based storage (used when database is not configured) USERS_FILE = Path("data/users.json") USERS_FILE.parent.mkdir(exist_ok=True) # Token blocklist: jti → expiry timestamp (Unix). # Uses Redis when available (persistent across restarts), falls back to in-memory. _revoked_jtis: dict[str, float] = {} _redis_blocklist_client = None def _get_blocklist_redis(): """Return Redis client for token blocklist, or None if unavailable.""" global _redis_blocklist_client if _redis_blocklist_client is not None: return _redis_blocklist_client if _redis_blocklist_client is not False else None redis_url = os.getenv("REDIS_URL", "") if not redis_url: _redis_blocklist_client = False return None try: import redis as redis_lib client = redis_lib.from_url(redis_url, decode_responses=True) client.ping() _redis_blocklist_client = client logger.info("Token blocklist using Redis (persistent across restarts)") return client except Exception as e: logger.warning(f"Redis unavailable for token blocklist, using in-memory: {e}") _redis_blocklist_client = False return None def revoke_token_jti(jti: str, expires_at: float) -> None: """Add a JTI to the blocklist (revoked until its expiry time).""" ttl = max(1, int(expires_at - time.time())) redis = _get_blocklist_redis() if redis: try: redis.setex(f"revoked_jti:{jti}", ttl, "1") return except Exception as e: logger.warning(f"Redis revoke failed, falling back to memory: {e}") _revoked_jtis[jti] = expires_at def is_token_revoked(jti: str) -> bool: """Return True if JTI is revoked. 
Lazy GC of expired in-memory entries.""" if not jti: return False redis = _get_blocklist_redis() if redis: try: return redis.exists(f"revoked_jti:{jti}") == 1 except Exception as e: logger.warning(f"Redis revoke check failed, falling back to memory: {e}") now = time.time() expired = [k for k, v in _revoked_jtis.items() if v < now] for k in expired: _revoked_jtis.pop(k, None) return jti in _revoked_jtis def hash_password(password: str) -> str: """Hash a password using bcrypt or fallback to SHA256""" if PASSLIB_AVAILABLE: return pwd_context.hash(password) else: # Fallback to SHA256 with salt salt = secrets.token_hex(16) hashed = hashlib.sha256(f"{salt}{password}".encode()).hexdigest() return f"sha256${salt}${hashed}" def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password against its hash""" if PASSLIB_AVAILABLE and not hashed_password.startswith("sha256$"): return pwd_context.verify(plain_password, hashed_password) else: # Fallback SHA256 verification parts = hashed_password.split("$") if len(parts) == 3 and parts[0] == "sha256": salt = parts[1] expected_hash = parts[2] actual_hash = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest() return secrets.compare_digest(actual_hash, expected_hash) return False def create_access_token( user_id: str, tier: str = "free", expires_delta: Optional[timedelta] = None ) -> str: """Create a JWT access token with tier claim for quick access""" if not JWT_AVAILABLE: token_data = { "user_id": user_id, "tier": tier, "exp": ( datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) ).isoformat(), } import base64 return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode() expire = datetime.now(timezone.utc) + ( expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) ) to_encode = { "sub": user_id, "tier": tier, "exp": expire, "type": "access", "jti": str(uuid.uuid4()), } return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) 
def create_refresh_token(
    user_id: str, expires_delta: Optional[timedelta] = None
) -> str:
    """Create a refresh token for ``user_id`` (``REFRESH_TOKEN_EXPIRE_DAYS`` by default).

    With PyJWT, this returns an HS256 JWT containing ``sub``, ``exp``,
    ``type="refresh"`` and a random ``jti``.

    NOTE(security): without PyJWT the token is unsigned base64 JSON
    (``user_id`` + ISO ``exp``) and trivially forgeable. Development use only.
    """
    if not JWT_AVAILABLE:
        import base64

        token_data = {
            "user_id": user_id,
            "exp": (
                datetime.now(timezone.utc)
                + (expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS))
            ).isoformat(),
        }
        return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()

    expire = datetime.now(timezone.utc) + (
        expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    )
    to_encode = {
        "sub": user_id,
        "exp": expire,
        "type": "refresh",
        "jti": str(uuid.uuid4()),
    }
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token and return its payload, or None if it is invalid.

    With PyJWT, this decodes and validates the HS256 signature and ``exp``. It
    returns None for any ``PyJWTError`` and for tokens whose ``jti`` is on the
    blocklist. The token ``type`` is NOT checked here; callers that need an
    access-only or refresh-only token must check ``payload["type"]`` themselves.

    Without PyJWT, this decodes the unsigned fallback format and returns
    ``{"sub": ...}``, plus ``"tier"`` when the token carries one.
    """
    if not JWT_AVAILABLE:
        try:
            import base64

            data = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
            exp = datetime.fromisoformat(data["exp"])
            if exp < datetime.now(timezone.utc):
                return None
            payload: Dict[str, Any] = {"sub": data["user_id"]}
            # create_access_token encodes a tier claim; don't drop it here.
            if "tier" in data:
                payload["tier"] = data["tier"]
            return payload
        except Exception:
            return None

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        # Base class of ExpiredSignatureError, InvalidSignatureError, DecodeError, ...
        return None

    jti = payload.get("jti")
    if jti and is_token_revoked(jti):
        return None
    return payload


def load_users() -> Dict[str, Dict]:
    """Load the ``{user_id: user_dict}`` mapping from USERS_FILE (JSON backend only).

    Returns {} when the file is missing or cannot be parsed.
    """
    if not USERS_FILE.exists():
        return {}
    try:
        with open(USERS_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        # NOTE(review): returning {} on a corrupt file means the next
        # save_users() call overwrites every stored user. Kept as best-effort;
        # save_users() below writes atomically so a crash mid-write cannot
        # itself produce a truncated file.
        logger.error(f"Failed to load users file: {e}")
        return {}


def save_users(users: Dict[str, Dict]) -> None:
    """Persist the user mapping to USERS_FILE (JSON backend only).

    The data is written to a temp file in the same directory and moved into
    place with ``os.replace``. Readers therefore see either the old or the new
    file, never a partially written one. Non-JSON values (datetimes, enums, ...)
    are serialised via ``str``.
    """
    import tempfile

    fd, tmp_name = tempfile.mkstemp(
        dir=USERS_FILE.parent, prefix=f".{USERS_FILE.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(users, f, indent=2, default=str)
        os.replace(tmp_name, USERS_FILE)
    except BaseException:
        # Remove the partial temp file; the existing users file is untouched.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise


def _db_user_to_model(db_user) -> User:
    """Convert a database user row into the Pydantic ``User`` model.

    Missing/None columns are replaced with defaults. Columns read via
    ``getattr`` may not exist on older schemas.
    """
    return User(
        id=str(db_user.id),
        email=db_user.email,
        name=db_user.name or "",
        password_hash=db_user.password_hash,
        avatar_url=db_user.avatar_url,
        plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
        subscription_status=SubscriptionStatus(db_user.subscription_status)
        if db_user.subscription_status
        else SubscriptionStatus.ACTIVE,
        stripe_customer_id=db_user.stripe_customer_id,
        stripe_subscription_id=db_user.stripe_subscription_id,
        docs_translated_this_month=db_user.docs_translated_this_month or 0,
        pages_translated_this_month=db_user.pages_translated_this_month or 0,
        api_calls_this_month=db_user.api_calls_this_month or 0,
        daily_translation_count=getattr(db_user, "daily_translation_count", 0) or 0,
        extra_credits=db_user.extra_credits or 0,
        usage_reset_date=db_user.usage_reset_date or datetime.now(timezone.utc),
        default_source_lang=getattr(db_user, "default_source_lang", None) or "en",
        default_target_lang=getattr(db_user, "default_target_lang", None) or "es",
        default_provider=getattr(db_user, "default_provider", None) or "google",
        created_at=db_user.created_at or datetime.now(timezone.utc),
        updated_at=db_user.updated_at,
    )


def get_user_by_email(email: str) -> Optional[User]:
    """Get a user by email, or None.

    The JSON backend matches case-insensitively. The DB backend delegates to
    ``UserRepository.get_by_email``.
    """
    if USE_DATABASE and DATABASE_AVAILABLE:
        from database.connection import get_sync_session
        from database.repositories import UserRepository

        with get_sync_session() as session:
            repo = UserRepository(session)
            db_user = repo.get_by_email(email)
            if db_user:
                return _db_user_to_model(db_user)
        return None

    target = email.lower()
    for user_data in load_users().values():
        # `or ""` tolerates records whose email is stored as null.
        if (user_data.get("email") or "").lower() == target:
            return User(**user_data)
    return None


def get_user_by_id(user_id: str) -> Optional[User]:
    """Get a user by ID, or None if not found."""
    if USE_DATABASE and DATABASE_AVAILABLE:
        from database.connection import get_sync_session
        from database.repositories import UserRepository

        with get_sync_session() as session:
            repo = UserRepository(session)
            db_user = repo.get_by_id(user_id)
            if db_user:
                return _db_user_to_model(db_user)
        return None

    users = load_users()
    if user_id in users:
        return User(**users[user_id])
    return None


def create_user(user_create: UserCreate) -> User:
    """Create a new free-plan user.

    Raises:
        ValueError: if the email is already registered.
    """
    # Check if email exists
    if get_user_by_email(user_create.email):
        raise ValueError("Email already registered")

    if USE_DATABASE and DATABASE_AVAILABLE:
        from database.connection import get_sync_session
        from database.repositories import UserRepository

        with get_sync_session() as session:
            repo = UserRepository(session)
            db_user = repo.create(
                email=user_create.email,
                name=user_create.name,
                hashed_password=hash_password(user_create.password),
                tier="free",
            )
            return _db_user_to_model(db_user)

    users = load_users()
    user_id = secrets.token_urlsafe(16)
    user = User(
        id=user_id,
        email=user_create.email,
        name=user_create.name,
        password_hash=hash_password(user_create.password),
        plan=PlanType.FREE,
        subscription_status=SubscriptionStatus.ACTIVE,
    )
    users[user_id] = user.model_dump()
    save_users(users)
    return user


def authenticate_user(email: str, password: str) -> Optional[User]:
    """Return the user if ``email``/``password`` are valid, else None."""
    user = get_user_by_email(email)
    if not user:
        return None
    if not verify_password(password, user.password_hash):
        return None
    return user


def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply ``updates`` to a user and return the updated model, or None if absent.

    The JSON backend also stamps ``updated_at`` with the current UTC time.
    """
    if USE_DATABASE and DATABASE_AVAILABLE:
        from database.connection import get_sync_session
        from database.repositories import UserRepository

        with get_sync_session() as session:
            repo = UserRepository(session)
            db_user = repo.update(user_id, **updates)
            if db_user:
                return _db_user_to_model(db_user)
        return None

    users = load_users()
    if user_id not in users:
        return None
    users[user_id].update(updates)
    users[user_id]["updated_at"] = datetime.now(timezone.utc).isoformat()
    save_users(users)
    return User(**users[user_id])


def check_usage_limits(user: User) -> Dict[str, Any]:
    """Report the user's usage against their plan limits.

    If ``usage_reset_date`` falls in an earlier calendar month than now (UTC),
    the monthly counters are reset, both in storage and on ``user``.
    ``docs_remaining`` is -1 when the plan's ``docs_per_month`` is not positive.
    """
    plan = PLANS[user.plan]

    # Reset usage if it's a new month
    now = datetime.now(timezone.utc)
    if (
        user.usage_reset_date.month != now.month
        or user.usage_reset_date.year != now.year
    ):
        update_user(
            user.id,
            {
                "docs_translated_this_month": 0,
                "pages_translated_this_month": 0,
                "api_calls_this_month": 0,
                # JSON storage keeps ISO strings; the DB layer takes datetimes.
                "usage_reset_date": now.isoformat() if not USE_DATABASE else now,
            },
        )
        user.docs_translated_this_month = 0
        user.pages_translated_this_month = 0
        user.api_calls_this_month = 0
        # Keep the in-memory object consistent so it isn't "reset" again.
        user.usage_reset_date = now

    docs_limit = plan["docs_per_month"]
    docs_remaining = (
        max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1
    )

    return {
        "can_translate": docs_remaining != 0 or user.extra_credits > 0,
        "docs_used": user.docs_translated_this_month,
        "docs_limit": docs_limit,
        "docs_remaining": docs_remaining,
        "pages_used": user.pages_translated_this_month,
        "extra_credits": user.extra_credits,
        "max_pages_per_doc": plan["max_pages_per_doc"],
        "max_file_size_mb": plan["max_file_size_mb"],
        "allowed_providers": plan["providers"],
    }


def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> bool:
    """Record one translated document of ``pages_count`` pages.

    When ``use_credits`` is True, ``pages_count`` is also deducted from
    ``extra_credits``, floored at 0. Returns False if the user doesn't exist or
    the update fails.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    updates = {
        "docs_translated_this_month": user.docs_translated_this_month + 1,
        "pages_translated_this_month": user.pages_translated_this_month + pages_count,
    }
    if use_credits:
        updates["extra_credits"] = max(0, user.extra_credits - pages_count)

    return update_user(user_id, updates) is not None


def add_credits(user_id: str, credits: int) -> bool:
    """Add ``credits`` to the user's ``extra_credits``. Returns False if not found."""
    user = get_user_by_id(user_id)
    if not user:
        return False
    result = update_user(user_id, {"extra_credits": user.extra_credits + credits})
    return result is not None


# Valid plan values for admin tier change (Story 1.7)
VALID_PLAN_VALUES = {"free", "starter", "pro", "business", "enterprise"}


def update_user_plan(user_id: str, plan: str) -> Optional[User]:
    """
    Update a user's plan/tier (admin only). Keeps User.plan and User.tier in sync.

    tier is set to 'pro' for pro/business/enterprise, 'free' otherwise (DB
    constraint). Returns None for an unknown plan value or a missing user.
    """
    plan_lower = (plan or "").strip().lower()
    if plan_lower not in VALID_PLAN_VALUES:
        return None

    plan_enum = PlanType(plan_lower)
    tier = (
        "pro"
        if plan_enum in (PlanType.PRO, PlanType.BUSINESS, PlanType.ENTERPRISE)
        else "free"
    )

    if USE_DATABASE and DATABASE_AVAILABLE:
        updates = {"plan": plan_enum, "tier": tier}
    else:
        updates = {"plan": plan_lower, "tier": tier}
    return update_user(user_id, updates)


def get_user_by_api_key(api_key: str) -> Optional[User]:
    """
    Get a user by API key (database backend only; JSON backend returns None).

    Verifies that:
    - The key exists in the database (matched by SHA-256 hex digest)
    - The key is active (is_active=True)
    - The key hasn't expired (expires_at is None or in the future)

    On success, updates ``last_used_at`` / ``usage_count`` and returns the owner.

    Raises:
        ValueError: "API_KEY_REVOKED" if the key exists but is inactive.
        ValueError: "API_KEY_EXPIRED" if the key's ``expires_at`` has passed.
    """
    if not api_key:
        return None

    # Only database backend supports API keys
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        return None

    from database.connection import get_sync_session
    from database.models import ApiKey

    # Hash the provided key to compare with stored hash
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()

    with get_sync_session() as session:
        api_key_record = (
            session.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
        )
        if not api_key_record:
            return None

        # Check if key is active (Story 3.2 - Revocation check)
        if not api_key_record.is_active:
            raise ValueError("API_KEY_REVOKED")

        now = datetime.now(timezone.utc)
        expires_at = api_key_record.expires_at
        if expires_at:
            # Comparing a naive datetime with an aware one raises TypeError.
            # NOTE(review): assumes naive DB timestamps are UTC -- confirm.
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at < now:
                raise ValueError("API_KEY_EXPIRED")

        api_key_record.last_used_at = now
        api_key_record.usage_count = (api_key_record.usage_count or 0) + 1
        # Read the FK before commit so we don't rely on a post-commit refresh.
        user_id = api_key_record.user_id
        session.commit()

    return get_user_by_id(str(user_id))


def init_database():
    """Initialize the database (call on application startup); no-op for JSON storage."""
    if USE_DATABASE and DATABASE_AVAILABLE:
        _init_db()
        logger.info("Database initialized successfully")
    else:
        logger.info("Using JSON file storage")