"""
Repository layer for database operations.

Provides a clean interface for CRUD operations on users, translations and
API keys. Every mutating method commits the session it was constructed with.
"""
import hashlib
import secrets
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any

from sqlalchemy.orm import Session
from sqlalchemy import and_, func, or_

from database.models import (
    User, Translation, ApiKey, UsageLog, PaymentHistory, PlanType, SubscriptionStatus
)


def _sql_increment(column, amount: int):
    """Return the SQL expression ``COALESCE(column, 0) + amount``.

    Assigning this to a mapped attribute makes the flush emit
    ``SET col = COALESCE(col, 0) + :amount``. The new value is then computed
    by the database from the current row instead of from the value loaded
    into Python, so concurrent increments of the same row are not lost.
    """
    return func.coalesce(column, 0) + amount


class UserRepository:
    """Repository for User database operations."""

    def __init__(self, db: Session):
        self.db = db

    def get_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with primary key ``user_id``, or None."""
        return self.db.query(User).filter(User.id == user_id).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email equals ``email`` ignoring case, or None."""
        return self.db.query(User).filter(
            func.lower(User.email) == email.lower()
        ).first()

    def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
        """Return the user linked to ``stripe_customer_id``, or None."""
        return self.db.query(User).filter(
            User.stripe_customer_id == stripe_customer_id
        ).first()

    def create(
        self,
        email: str,
        name: str,
        password_hash: str,
        plan: PlanType = PlanType.FREE,
    ) -> User:
        """Create, commit and return a new user with an ACTIVE subscription.

        The email is stored lower-cased.
        """
        user = User(
            email=email.lower(),
            name=name,
            password_hash=password_hash,
            plan=plan,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def update(self, user_id: str, **kwargs) -> Optional[User]:
        """Set the given attributes on a user and commit.

        Keys that are not attributes of the user object are silently ignored.
        A string ``email`` is lower-cased, the same normalization ``create``
        applies. ``updated_at`` is always refreshed.

        Returns:
            The refreshed user, or None if no user has ``user_id``.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None
        if isinstance(kwargs.get("email"), str):
            # Keep stored emails normalized consistently with create().
            kwargs["email"] = kwargs["email"].lower()
        for key, value in kwargs.items():
            if hasattr(user, key):
                setattr(user, key, value)
        user.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(user)
        return user

    def delete(self, user_id: str) -> bool:
        """Delete a user. Returns False if no user has ``user_id``."""
        user = self.get_by_id(user_id)
        if not user:
            return False
        self.db.delete(user)
        self.db.commit()
        return True

    def increment_usage(
        self,
        user_id: str,
        docs: int = 0,
        pages: int = 0,
        api_calls: int = 0,
    ) -> Optional[User]:
        """Add to the user's monthly usage counters.

        If the calendar month (per ``datetime.utcnow()``) differs from that of
        ``usage_reset_date``, the counters restart at this call's amounts and
        the reset date moves to now. A user with no ``usage_reset_date`` gets
        one set to now, so later calls can detect month rollovers.

        Outside a rollover, the counters are incremented in SQL so that
        concurrent calls do not overwrite each other's counts.

        Returns:
            The refreshed user, or None if no user has ``user_id``.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        now = datetime.utcnow()
        reset_date = user.usage_reset_date
        new_month = reset_date is not None and (
            (now.year, now.month) != (reset_date.year, reset_date.month)
        )

        if new_month:
            # NOTE(review): the rollover itself is still read-then-write; two
            # requests racing at the month boundary may both reset.
            user.docs_translated_this_month = docs
            user.pages_translated_this_month = pages
            user.api_calls_this_month = api_calls
            user.usage_reset_date = now
        else:
            if reset_date is None:
                # Previously a missing date meant the counters never reset.
                user.usage_reset_date = now
            user.docs_translated_this_month = _sql_increment(
                User.docs_translated_this_month, docs
            )
            user.pages_translated_this_month = _sql_increment(
                User.pages_translated_this_month, pages
            )
            user.api_calls_this_month = _sql_increment(
                User.api_calls_this_month, api_calls
            )

        self.db.commit()
        self.db.refresh(user)  # loads the values the database computed
        return user

    def add_credits(self, user_id: str, credits: int) -> Optional[User]:
        """Add ``credits`` to the user's extra-credit balance.

        The addition is computed in SQL so concurrent top-ups are not lost.

        Returns:
            The refreshed user, or None if no user has ``user_id``.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None
        user.extra_credits = _sql_increment(User.extra_credits, credits)
        self.db.commit()
        self.db.refresh(user)
        return user

    def use_credits(self, user_id: str, credits: int) -> bool:
        """Deduct ``credits`` from the user's extra-credit balance.

        The balance check and the deduction run as one conditional UPDATE, so
        two concurrent calls cannot both spend the same credits (the previous
        check-then-subtract could overdraw the balance).

        Returns:
            True if the deduction was applied. False if the user does not
            exist, the balance is insufficient, or ``credits`` is negative (a
            negative amount would otherwise add credits).
        """
        if credits < 0:
            return False
        updated = self.db.query(User).filter(
            User.id == user_id,
            User.extra_credits >= credits,
        ).update(
            {User.extra_credits: User.extra_credits - credits},
            synchronize_session="fetch",
        )
        # Commit even when nothing matched, so no transaction is left open.
        self.db.commit()
        return updated > 0

    def get_all_users(
        self,
        skip: int = 0,
        limit: int = 100,
        plan: Optional[PlanType] = None,
    ) -> List[User]:
        """Return a page of users, optionally restricted to one plan."""
        query = self.db.query(User)
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.offset(skip).limit(limit).all()

    def count_users(self, plan: Optional[PlanType] = None) -> int:
        """Return the number of users, optionally restricted to one plan."""
        query = self.db.query(func.count(User.id))
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.scalar()


class TranslationRepository:
    """Repository for Translation database operations."""

    def __init__(self, db: Session):
        self.db = db

    def create(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create, commit and return a translation record with status "pending"."""
        translation = Translation(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="pending",
        )
        self.db.add(translation)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def update_status(
        self,
        translation_id: str,
        status: str,
        error_message: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        characters_translated: Optional[int] = None,
    ) -> Optional[Translation]:
        """Set a translation's status and any supplied result fields.

        Optional fields left as None are not changed. An explicit 0 is stored;
        previously, truthiness checks silently dropped 0 values. A status of
        "completed" also stamps ``completed_at``.

        Returns:
            The refreshed translation, or None if ``translation_id`` is unknown.
        """
        translation = self.db.query(Translation).filter(
            Translation.id == translation_id
        ).first()
        if not translation:
            return None

        translation.status = status
        if error_message is not None:
            translation.error_message = error_message
        if processing_time_ms is not None:
            translation.processing_time_ms = processing_time_ms
        if characters_translated is not None:
            translation.characters_translated = characters_translated
        if status == "completed":
            translation.completed_at = datetime.utcnow()

        self.db.commit()
        self.db.refresh(translation)
        return translation

    def get_user_translations(
        self,
        user_id: str,
        skip: int = 0,
        limit: int = 50,
        status: Optional[str] = None,
    ) -> List[Translation]:
        """Return a page of the user's translations, newest first."""
        query = self.db.query(Translation).filter(Translation.user_id == user_id)
        if status:
            query = query.filter(Translation.status == status)
        return (
            query.order_by(Translation.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Aggregate the user's completed translations from the last ``days`` days.

        Returns a dict with ``total_translations``, ``total_pages``,
        ``total_characters`` (each 0 when there is no data) and ``period_days``.
        """
        since = datetime.utcnow() - timedelta(days=days)
        result = self.db.query(
            func.count(Translation.id).label("total_translations"),
            func.sum(Translation.page_count).label("total_pages"),
            func.sum(Translation.characters_translated).label("total_characters"),
        ).filter(
            and_(
                Translation.user_id == user_id,
                Translation.created_at >= since,
                Translation.status == "completed",
            )
        ).first()

        # SUM over zero rows is NULL, hence the `or 0` fallbacks.
        return {
            "total_translations": result.total_translations or 0,
            "total_pages": result.total_pages or 0,
            "total_characters": result.total_characters or 0,
            "period_days": days,
        }


class ApiKeyRepository:
    """Repository for API Key database operations."""

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def hash_key(key: str) -> str:
        """Return the hex SHA-256 digest of ``key``; only this hash is stored."""
        return hashlib.sha256(key.encode()).hexdigest()

    def create(
        self,
        user_id: str,
        name: str,
        scopes: Optional[List[str]] = None,
        expires_in_days: Optional[int] = None,
    ) -> tuple[ApiKey, str]:
        """Create a new API key for ``user_id``.

        ``scopes`` defaults to ``["translate"]`` when None or empty. A falsy
        ``expires_in_days`` means the key never expires.

        Returns:
            ``(api_key, raw_key)``. The raw key is not stored, so this is the
            only chance to show it to the user.
        """
        # Generate a secure random key; keep its first 10 chars for display.
        raw_key = f"tr_{secrets.token_urlsafe(32)}"
        key_hash = self.hash_key(raw_key)
        key_prefix = raw_key[:10]

        expires_at = None
        if expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=expires_in_days)

        api_key = ApiKey(
            user_id=user_id,
            name=name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            scopes=scopes or ["translate"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)
        return api_key, raw_key

    def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
        """Look up an active, unexpired API key by its raw value.

        On success, records the use (``last_used_at`` and ``usage_count``) and
        commits. Returns None for unknown, inactive or expired keys.
        """
        key_hash = self.hash_key(raw_key)
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.key_hash == key_hash,
                ApiKey.is_active == True,  # noqa: E712 - SQL comparison, not Python
            )
        ).first()

        if api_key:
            if api_key.expires_at and api_key.expires_at < datetime.utcnow():
                return None
            api_key.last_used_at = datetime.utcnow()
            # Incremented in SQL so concurrent requests are all counted.
            api_key.usage_count = _sql_increment(ApiKey.usage_count, 1)
            self.db.commit()

        return api_key

    def get_user_keys(self, user_id: str) -> List[ApiKey]:
        """Return all of the user's API keys (including revoked), newest first."""
        return self.db.query(ApiKey).filter(
            ApiKey.user_id == user_id
        ).order_by(ApiKey.created_at.desc()).all()

    def revoke(self, key_id: str, user_id: str) -> bool:
        """Deactivate key ``key_id`` if it belongs to ``user_id``.

        Returns False when no such key exists for that user.
        """
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.id == key_id,
                ApiKey.user_id == user_id,
            )
        ).first()

        if not api_key:
            return False

        api_key.is_active = False
        self.db.commit()
        return True