"""
Database-backed authentication service.

Replaces JSON file storage with SQLAlchemy. Provides password hashing,
access/refresh token issuing and verification, and thin wrappers around
UserRepository for user, credit and usage bookkeeping.
"""
import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

# Try to import optional dependencies
try:
    import jwt

    JWT_AVAILABLE = True
except ImportError:
    JWT_AVAILABLE = False

try:
    from passlib.context import CryptContext

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    PASSLIB_AVAILABLE = True
except ImportError:
    PASSLIB_AVAILABLE = False

from database.connection import get_db_session
from database.repositories import UserRepository
from database.models import User, PlanType, SubscriptionStatus
from models.subscription import PLANS

logger = logging.getLogger(__name__)

# Configuration
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
if not SECRET_KEY:
    # Without a configured key every process invents its own: tokens stop
    # validating after a restart and are not accepted by other workers.
    SECRET_KEY = secrets.token_urlsafe(32)
    logger.warning(
        "JWT_SECRET_KEY is not set; using a random per-process key. "
        "Issued tokens will not survive a restart or be valid across workers."
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
REFRESH_TOKEN_EXPIRE_DAYS = 30

# Legacy fallback hash format "sha256$<salt>$<hexdigest>" (one salted SHA-256
# round). Still accepted by verify_password, no longer produced.
_LEGACY_SHA256_PREFIX = "sha256$"
# Current fallback hash format "pbkdf2_sha256$<iterations>$<salt>$<hexdigest>".
_PBKDF2_PREFIX = "pbkdf2_sha256$"
# OWASP Password Storage Cheat Sheet recommendation for PBKDF2-HMAC-SHA256.
_PBKDF2_ITERATIONS = 600_000


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()``; kept naive so it
    compares with, and is stored like, the timestamps this module already uses.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _pbkdf2_digest(password: str, salt: str, iterations: int) -> str:
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* with *salt*."""
    return hashlib.pbkdf2_hmac(
        "sha256", password.encode(), salt.encode(), iterations
    ).hex()


def hash_password(password: str) -> str:
    """Hash a password for storage.

    Uses passlib's bcrypt when passlib is installed; otherwise falls back to
    salted PBKDF2-HMAC-SHA256 from the standard library. The result carries
    its own format prefix so verify_password can pick the right algorithm.
    """
    if PASSLIB_AVAILABLE:
        return pwd_context.hash(password)
    salt = secrets.token_hex(16)
    digest = _pbkdf2_digest(password, salt, _PBKDF2_ITERATIONS)
    return f"{_PBKDF2_PREFIX}{_PBKDF2_ITERATIONS}${salt}${digest}"


def _verify_pbkdf2(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a ``pbkdf2_sha256$`` fallback hash."""
    try:
        iterations, salt, expected = hashed_password[len(_PBKDF2_PREFIX):].split("$")
        actual = _pbkdf2_digest(plain_password, salt, int(iterations))
        # Compare bytes: compare_digest rejects non-ASCII str arguments.
        return hmac.compare_digest(actual.encode(), expected.encode())
    except (ValueError, OverflowError):
        # Wrong field count, non-numeric / non-positive / oversized iterations.
        return False


def _verify_legacy_sha256(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a legacy ``sha256$`` fallback hash."""
    parts = hashed_password.split("$")
    if len(parts) != 3:
        return False
    _, salt, expected = parts
    actual = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
    return hmac.compare_digest(actual.encode(), expected.encode())


def verify_password(plain_password: str, hashed_password: Optional[str]) -> bool:
    """Verify a password against a stored hash.

    Accepts the stdlib fallback formats (current PBKDF2 and legacy SHA-256)
    and, when passlib is installed, anything ``pwd_context`` recognizes.
    Returns False instead of raising for an empty, unrecognized or malformed
    stored hash.
    """
    if not hashed_password:
        # No stored hash (None or empty string): nothing can match.
        return False
    if hashed_password.startswith(_PBKDF2_PREFIX):
        return _verify_pbkdf2(plain_password, hashed_password)
    if hashed_password.startswith(_LEGACY_SHA256_PREFIX):
        return _verify_legacy_sha256(plain_password, hashed_password)
    if PASSLIB_AVAILABLE:
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except ValueError:
            # passlib raises ValueError subclasses for hashes it cannot
            # identify or parse; treat them as a failed match.
            return False
    return False


def _sign_fallback_payload(payload_b64: str) -> str:
    """Return the hex HMAC-SHA256 of a fallback token payload under SECRET_KEY."""
    return hmac.new(
        SECRET_KEY.encode(), payload_b64.encode(), hashlib.sha256
    ).hexdigest()


def _encode_fallback_token(user_id: str, expire: datetime, token_type: str) -> str:
    """Build a signed ``<base64 JSON>.<hex HMAC>`` token for use without PyJWT.

    The signature prevents clients from forging or editing the payload;
    urlsafe base64 never contains ".", so the separator is unambiguous.
    """
    token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": token_type}
    payload_b64 = base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
    return f"{payload_b64}.{_sign_fallback_payload(payload_b64)}"


def _decode_fallback_token(token: str) -> Optional[Dict[str, Any]]:
    """Validate a token from _encode_fallback_token; None if forged, bad or expired."""
    try:
        payload_b64, signature = token.rsplit(".", 1)
        expected = _sign_fallback_payload(payload_b64)
        if not hmac.compare_digest(expected.encode(), signature.encode()):
            return None
        data = json.loads(base64.urlsafe_b64decode(payload_b64.encode()).decode())
        if datetime.fromisoformat(data["exp"]) < _utcnow():
            return None
        return {"sub": data["user_id"], "type": data.get("type", "access")}
    except (ValueError, KeyError, TypeError):
        # Missing separator, bad base64/UTF-8/JSON/ISO date, or unexpected shape.
        return None


def _create_token(user_id: str, expire: datetime, token_type: str) -> str:
    """Encode a *token_type* token for *user_id* expiring at naive-UTC *expire*."""
    if not JWT_AVAILABLE:
        return _encode_fallback_token(user_id, expire, token_type)
    token = jwt.encode(
        {"sub": user_id, "exp": expire, "type": token_type},
        SECRET_KEY,
        algorithm=ALGORITHM,
    )
    # PyJWT < 2.0 returns bytes; normalize to the declared str return type.
    return token.decode() if isinstance(token, bytes) else token


def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create an access token for *user_id*.

    Args:
        user_id: Stored as the ``sub`` claim.
        expires_delta: Lifetime of the token; ACCESS_TOKEN_EXPIRE_HOURS when
            omitted. An explicit ``timedelta(0)`` is honored (not replaced by
            the default).
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    return _create_token(user_id, _utcnow() + expires_delta, "access")


def create_refresh_token(user_id: str) -> str:
    """Create a refresh token for *user_id* valid for REFRESH_TOKEN_EXPIRE_DAYS."""
    expire = _utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_token(user_id, expire, "refresh")


def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token and return its payload, or None if it is not valid.

    With PyJWT the full decoded claim set is returned; in fallback mode only
    ``sub`` and ``type`` are returned. Token type is not checked here —
    callers must inspect ``payload["type"]`` themselves.
    """
    if not token:
        return None
    if not JWT_AVAILABLE:
        return _decode_fallback_token(token)
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError:
        # PyJWT's base class for all decode failures, ExpiredSignatureError
        # included. (``jwt.JWTError`` is python-jose's name; it does not
        # exist in PyJWT and referencing it raised AttributeError.)
        return None


# NOTE(review): the functions below return ORM objects after the
# get_db_session() context has exited. Whether their attributes are still
# loadable depends on how get_db_session configures the session
# (e.g. expire_on_commit) — confirm before relying on lazy attributes.


def create_user(email: str, name: str, password: str) -> User:
    """Create a new free-plan user.

    Raises:
        ValueError: if a user with *email* already exists.
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        if repo.get_by_email(email):
            raise ValueError("Email already registered")
        return repo.create(
            email=email,
            name=name,
            password_hash=hash_password(password),
            plan=PlanType.FREE,
        )


def authenticate_user(email: str, password: str) -> Optional[User]:
    """Return the user for *email* if *password* matches, else None.

    On success the user's ``last_login_at`` is updated to the current UTC time.
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        user = repo.get_by_email(email)
        if not user:
            return None
        if not verify_password(password, user.password_hash):
            return None
        # Update last login
        repo.update(user.id, last_login_at=_utcnow())
        return user


def get_user_by_id(user_id: str) -> Optional[User]:
    """Get user by ID from database"""
    with get_db_session() as db:
        return UserRepository(db).get_by_id(user_id)


def get_user_by_email(email: str) -> Optional[User]:
    """Get user by email from database"""
    with get_db_session() as db:
        return UserRepository(db).get_by_email(email)


def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply *updates* (column name -> value) to the user via UserRepository.update."""
    with get_db_session() as db:
        return UserRepository(db).update(user_id, **updates)


def add_credits(user_id: str, credits: int) -> bool:
    """Add credits to the user's account; True if the repository returned a result."""
    with get_db_session() as db:
        return UserRepository(db).add_credits(user_id, credits) is not None


def use_credits(user_id: str, credits: int) -> bool:
    """Consume credits from the user's account; returns the repository's result as-is."""
    with get_db_session() as db:
        return UserRepository(db).use_credits(user_id, credits)


def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool:
    """Increment monthly usage counters; True if the repository returned a result."""
    with get_db_session() as db:
        result = UserRepository(db).increment_usage(
            user_id, docs=docs, pages=pages, api_calls=api_calls
        )
        return result is not None


def _plan_config(user: User) -> Dict[str, Any]:
    """Return the PLANS entry for the user's plan, defaulting to the free plan."""
    return PLANS.get(user.plan, PLANS[PlanType.FREE])


def _docs_remaining(docs_limit: int, docs_used: int) -> int:
    """Return documents left this month, or -1 when the limit is <= 0 (unlimited)."""
    if docs_limit <= 0:
        return -1
    return max(0, docs_limit - docs_used)


def check_usage_limits(user_id: str) -> Dict[str, Any]:
    """Check whether the user may translate another document.

    The request is refused only when the monthly document limit is positive,
    has been reached, and the user has no extra credits left.
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        user = repo.get_by_id(user_id)
        if not user:
            return {"allowed": False, "reason": "User not found"}

        docs_limit = _plan_config(user)["docs_per_month"]
        docs_used = user.docs_translated_this_month
        if docs_limit > 0 and docs_used >= docs_limit and user.extra_credits <= 0:
            return {
                "allowed": False,
                "reason": "Monthly document limit reached",
                "limit": docs_limit,
                "used": docs_used,
            }
        return {
            "allowed": True,
            "docs_remaining": _docs_remaining(docs_limit, docs_used),
            "extra_credits": user.extra_credits,
        }


def get_user_usage_stats(user_id: str) -> Dict[str, Any]:
    """Return detailed usage statistics for a user, or {} if the user does not exist."""
    with get_db_session() as db:
        repo = UserRepository(db)
        user = repo.get_by_id(user_id)
        if not user:
            return {}

        plan_config = _plan_config(user)
        docs_limit = plan_config["docs_per_month"]
        return {
            "docs_used": user.docs_translated_this_month,
            "docs_limit": docs_limit,
            "docs_remaining": _docs_remaining(docs_limit, user.docs_translated_this_month),
            "pages_used": user.pages_translated_this_month,
            "extra_credits": user.extra_credits,
            "max_pages_per_doc": plan_config["max_pages_per_doc"],
            "max_file_size_mb": plan_config["max_file_size_mb"],
            "allowed_providers": plan_config["providers"],
            "api_access": plan_config.get("api_access", False),
            "api_calls_used": user.api_calls_this_month if plan_config.get("api_access") else 0,
            "api_calls_limit": plan_config.get("api_calls_per_month", 0),
        }