"""
Repository layer for database operations.

Provides a clean interface for CRUD operations.
"""
|
|
|
|
import hashlib
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional, List, Dict, Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import and_, func, or_
|
|
|
|
from database.models import (
|
|
User,
|
|
Translation,
|
|
ApiKey,
|
|
UsageLog,
|
|
PaymentHistory,
|
|
PlanType,
|
|
SubscriptionStatus,
|
|
)
|
|
|
|
|
|
class UserRepository:
    """Repository for User database operations.

    Every mutating method commits the session passed to ``__init__`` before
    returning.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_by_id(self, user_id: str) -> Optional[User]:
        """Return the user whose primary key is ``user_id``, or ``None``."""
        return self.db.query(User).filter(User.id == user_id).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the user matching ``email`` case-insensitively, or ``None``."""
        return (
            self.db.query(User).filter(func.lower(User.email) == email.lower()).first()
        )

    def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
        """Return the user linked to ``stripe_customer_id``, or ``None``."""
        return (
            self.db.query(User)
            .filter(User.stripe_customer_id == stripe_customer_id)
            .first()
        )

    def create(
        self,
        email: str,
        name: str,
        hashed_password: str,
        tier: str = "free",
    ) -> User:
        """Create and persist a new user (story 1-1 refactor).

        Args:
            email: Stored lower-cased so ``get_by_email`` lookups match.
            name: Display name.
            hashed_password: Already-hashed password; no hashing happens here.
            tier: ``"pro"`` maps to ``PlanType.PRO``; any other value maps to
                ``PlanType.FREE``. The raw string is also stored in ``tier``.

        Returns:
            The committed and refreshed ``User``, with an ACTIVE subscription.
        """
        plan = PlanType.PRO if tier == "pro" else PlanType.FREE
        user = User(
            email=email.lower(),
            name=name,
            hashed_password=hashed_password,
            tier=tier,
            plan=plan,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def update(self, user_id: str, **kwargs) -> Optional[User]:
        """Set the given attributes on a user and bump ``updated_at``.

        Keyword names that are not attributes of ``User`` are silently ignored.

        Returns:
            The refreshed user, or ``None`` if no user has ``user_id``.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        for key, value in kwargs.items():
            if hasattr(user, key):
                setattr(user, key, value)

        user.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(user)
        return user

    def delete(self, user_id: str) -> bool:
        """Delete a user. Returns ``True`` if it existed, ``False`` otherwise."""
        user = self.get_by_id(user_id)
        if not user:
            return False

        self.db.delete(user)
        self.db.commit()
        return True

    def increment_usage(
        self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
    ) -> Optional[User]:
        """Add to the user's monthly usage counters, resetting them on a new month.

        The counters are zeroed when ``usage_reset_date`` falls in a different
        calendar month/year than the current UTC time. If ``usage_reset_date``
        has never been set, the tracking window starts now and the existing
        counters are kept.

        Returns:
            The refreshed user, or ``None`` if no user has ``user_id``.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        now = datetime.now(timezone.utc)
        reset_date = user.usage_reset_date
        if reset_date is None:
            # Previously this case never reset and never recorded a date, so
            # counters grew forever. Start the monthly window now instead.
            user.usage_reset_date = now
        elif (reset_date.year, reset_date.month) != (now.year, now.month):
            # NOTE(review): assumes the stored date is UTC (as written here);
            # only year/month are compared, so naive vs aware is not an issue.
            user.docs_translated_this_month = 0
            user.pages_translated_this_month = 0
            user.api_calls_this_month = 0
            user.usage_reset_date = now

        # `or 0` guards against counters that were never initialised.
        user.docs_translated_this_month = (user.docs_translated_this_month or 0) + docs
        user.pages_translated_this_month = (
            user.pages_translated_this_month or 0
        ) + pages
        user.api_calls_this_month = (user.api_calls_this_month or 0) + api_calls

        self.db.commit()
        self.db.refresh(user)
        return user

    def add_credits(self, user_id: str, credits: int) -> Optional[User]:
        """Add ``credits`` to the user's ``extra_credits`` balance.

        Returns:
            The refreshed user, or ``None`` if no user has ``user_id``.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        user.extra_credits += credits
        self.db.commit()
        self.db.refresh(user)
        return user

    def use_credits(self, user_id: str, credits: int) -> bool:
        """Deduct ``credits`` from the user's ``extra_credits`` balance.

        Returns:
            ``True`` if the deduction was committed. ``False`` if the user does
            not exist, the balance is insufficient, or ``credits`` is negative.
            A negative amount would otherwise *increase* the balance through
            the subtraction below.
        """
        if credits < 0:
            return False

        user = self.get_by_id(user_id)
        if not user or user.extra_credits < credits:
            return False

        # NOTE(review): read-then-write is not atomic; concurrent callers could
        # both pass the balance check. Consider a conditional UPDATE if that matters.
        user.extra_credits -= credits
        self.db.commit()
        return True

    def get_all_users(
        self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
    ) -> List[User]:
        """Return a page of users, optionally filtered by ``plan``."""
        query = self.db.query(User)
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.offset(skip).limit(limit).all()

    def count_users(self, plan: Optional[PlanType] = None) -> int:
        """Return the number of users, optionally filtered by ``plan``."""
        query = self.db.query(func.count(User.id))
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.scalar()
|
|
|
|
|
|
class TranslationRepository:
    """Repository for Translation database operations."""

    def __init__(self, db: Session):
        self.db = db

    def _save(self, translation: Translation) -> Translation:
        """Add ``translation`` to the session, commit, and reload it from the DB."""
        self.db.add(translation)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def create(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create and persist a translation record in ``"pending"`` status."""
        translation = Translation(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="pending",
        )
        return self._save(translation)

    def create_completed(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create a translation record directly in completed status (Story 1.8 - billing log).

        ``completed_at`` is set to the current UTC time.
        """
        translation = Translation(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="completed",
            completed_at=datetime.now(timezone.utc),
        )
        return self._save(translation)

    def update_status(
        self,
        translation_id: str,
        status: str,
        error_message: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        characters_translated: Optional[int] = None,
    ) -> Optional[Translation]:
        """Update a translation's status and any supplied result fields.

        Optional fields are written only when they are not ``None``. Explicit
        zeros (e.g. ``processing_time_ms=0``) are therefore recorded rather
        than dropped. ``completed_at`` is stamped with the current UTC time
        when ``status == "completed"``.

        Returns:
            The refreshed translation, or ``None`` if ``translation_id`` is unknown.
        """
        translation = (
            self.db.query(Translation).filter(Translation.id == translation_id).first()
        )
        if not translation:
            return None

        translation.status = status
        if error_message is not None:
            translation.error_message = error_message
        if processing_time_ms is not None:
            translation.processing_time_ms = processing_time_ms
        if characters_translated is not None:
            translation.characters_translated = characters_translated
        if status == "completed":
            translation.completed_at = datetime.now(timezone.utc)

        self.db.commit()
        self.db.refresh(translation)
        return translation

    def get_user_translations(
        self,
        user_id: str,
        skip: int = 0,
        limit: int = 50,
        status: Optional[str] = None,
    ) -> List[Translation]:
        """Return a page of the user's translations, newest first.

        If ``status`` is given, only translations with that status are returned.
        """
        query = self.db.query(Translation).filter(Translation.user_id == user_id)
        if status:
            query = query.filter(Translation.status == status)
        return (
            query.order_by(Translation.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Aggregate the user's completed translations from the last ``days`` days.

        Returns:
            A dict with ``total_translations``, ``total_pages`` and
            ``total_characters``, each 0 when there are no matching rows, plus
            ``period_days`` echoing ``days``.
        """
        since = datetime.now(timezone.utc) - timedelta(days=days)

        result = (
            self.db.query(
                func.count(Translation.id).label("total_translations"),
                func.sum(Translation.page_count).label("total_pages"),
                func.sum(Translation.characters_translated).label("total_characters"),
            )
            .filter(
                and_(
                    Translation.user_id == user_id,
                    Translation.created_at >= since,
                    Translation.status == "completed",
                )
            )
            .first()
        )

        # SUM over zero rows yields NULL, hence the `or 0` fallbacks.
        return {
            "total_translations": result.total_translations or 0,
            "total_pages": result.total_pages or 0,
            "total_characters": result.total_characters or 0,
            "period_days": days,
        }
|
|
|
|
|
|
class ApiKeyRepository:
    """Repository for API Key database operations.

    Only the SHA-256 hash and a short prefix of each key are stored. The raw
    key is returned once, from ``create``.
    """

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def hash_key(key: str) -> str:
        """Return the hex SHA-256 digest of ``key``."""
        return hashlib.sha256(key.encode()).hexdigest()

    def create(
        self,
        user_id: str,
        name: str,
        scopes: Optional[List[str]] = None,
        expires_in_days: Optional[int] = None,
    ) -> tuple[ApiKey, str]:
        """Create a new API key.

        Args:
            user_id: Owner of the key.
            name: Human-readable label.
            scopes: Granted scopes. Defaults to ``["translate"]`` when ``None``
                or empty.
            expires_in_days: If truthy, the key expires this many days from now
                (UTC). Otherwise it never expires.

        Returns:
            ``(ApiKey, raw_key)``. ``raw_key`` is not stored and cannot be
            recovered later.
        """
        # "tr_" prefix plus a URL-safe random token generated with `secrets`.
        raw_key = f"tr_{secrets.token_urlsafe(32)}"
        key_hash = self.hash_key(raw_key)
        key_prefix = raw_key[:10]

        expires_at = None
        if expires_in_days:
            expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)

        api_key = ApiKey(
            user_id=user_id,
            name=name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            scopes=scopes or ["translate"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)

        return api_key, raw_key

    def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
        """Look up an active, unexpired API key by its raw value.

        On success, records ``last_used_at`` and increments ``usage_count``
        (committed).

        Returns:
            The matching key, or ``None`` if it is unknown, inactive or expired.
        """
        key_hash = self.hash_key(raw_key)
        api_key = (
            self.db.query(ApiKey)
            .filter(
                and_(
                    ApiKey.key_hash == key_hash,
                    ApiKey.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
                )
            )
            .first()
        )
        if api_key is None:
            return None

        now = datetime.now(timezone.utc)
        expires_at = api_key.expires_at
        if expires_at is not None:
            if expires_at.tzinfo is None:
                # Some backends (e.g. SQLite with a plain DateTime column) return
                # naive datetimes. Comparing naive with aware raises TypeError.
                # `create` writes UTC, so interpret naive values as UTC.
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at < now:
                return None

        api_key.last_used_at = now
        api_key.usage_count = (api_key.usage_count or 0) + 1
        self.db.commit()

        return api_key

    def get_user_keys(self, user_id: str) -> List[ApiKey]:
        """Return all of the user's API keys (active or not), newest first."""
        return (
            self.db.query(ApiKey)
            .filter(ApiKey.user_id == user_id)
            .order_by(ApiKey.created_at.desc())
            .all()
        )

    def revoke(self, key_id: str, user_id: str) -> bool:
        """Deactivate the key ``key_id`` if it belongs to ``user_id``.

        Returns:
            ``True`` if the key was found and deactivated, ``False`` otherwise.
        """
        api_key = (
            self.db.query(ApiKey)
            .filter(
                and_(
                    ApiKey.id == key_id,
                    ApiKey.user_id == user_id,
                )
            )
            .first()
        )
        if not api_key:
            return False

        api_key.is_active = False
        self.db.commit()
        return True
|