- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration
342 lines
10 KiB
Python
342 lines
10 KiB
Python
"""
|
|
Repository layer for database operations
|
|
Provides clean interface for CRUD operations
|
|
"""
|
|
import hashlib
|
|
import secrets
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, List, Dict, Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import and_, func, or_
|
|
|
|
from database.models import (
|
|
User, Translation, ApiKey, UsageLog, PaymentHistory,
|
|
PlanType, SubscriptionStatus
|
|
)
|
|
|
|
|
|
class UserRepository:
    """Repository for User database operations.

    Wraps the SQLAlchemy session passed to ``__init__``; every mutating
    method commits that session before returning.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with ``user_id``, or None if no such user exists."""
        return self.db.query(User).filter(User.id == user_id).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the user with ``email`` (case-insensitive match), or None."""
        return self.db.query(User).filter(
            func.lower(User.email) == email.lower()
        ).first()

    def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
        """Return the user linked to a Stripe customer ID, or None."""
        return self.db.query(User).filter(
            User.stripe_customer_id == stripe_customer_id
        ).first()

    def create(
        self,
        email: str,
        name: str,
        password_hash: str,
        plan: PlanType = PlanType.FREE,
    ) -> User:
        """Create, commit and return a new user with an ACTIVE subscription.

        The email is stored lower-cased so it matches ``get_by_email``.
        """
        user = User(
            email=email.lower(),
            name=name,
            password_hash=password_hash,
            plan=plan,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def update(self, user_id: str, **kwargs) -> Optional[User]:
        """Set the given attributes on a user and commit.

        Keys that are not attributes of ``User``, and keys starting with an
        underscore, are ignored. An ``email`` value is lower-cased, matching
        ``create``. ``updated_at`` is always set to the current UTC time.

        Returns:
            The refreshed user, or None if ``user_id`` does not exist.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        for key, value in kwargs.items():
            # Never let callers clobber private / ORM-internal state such as
            # ``_sa_instance_state``; that would corrupt the session.
            if key.startswith("_") or not hasattr(user, key):
                continue
            if key == "email" and isinstance(value, str):
                # Keep stored emails normalised the same way create() does.
                value = value.lower()
            setattr(user, key, value)

        user.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(user)
        return user

    def delete(self, user_id: str) -> bool:
        """Delete a user and commit. Returns False if the user does not exist."""
        user = self.get_by_id(user_id)
        if not user:
            return False

        self.db.delete(user)
        self.db.commit()
        return True

    def increment_usage(
        self,
        user_id: str,
        docs: int = 0,
        pages: int = 0,
        api_calls: int = 0,
    ) -> Optional[User]:
        """Add to the user's monthly usage counters and commit.

        The counters are zeroed first when the year/month of
        ``datetime.utcnow()`` differs from that of ``usage_reset_date``.

        Returns:
            The refreshed user, or None if ``user_id`` does not exist.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        now = datetime.utcnow()
        reset_date = user.usage_reset_date
        if reset_date is None:
            # No period recorded yet. Start tracking from now, otherwise the
            # monthly reset below could never trigger for this user. The
            # existing counts are kept because their period is unknown.
            user.usage_reset_date = now
        elif (now.year, now.month) != (reset_date.year, reset_date.month):
            user.docs_translated_this_month = 0
            user.pages_translated_this_month = 0
            user.api_calls_this_month = 0
            user.usage_reset_date = now

        # `or 0` guards against NULL counters on legacy rows.
        user.docs_translated_this_month = (user.docs_translated_this_month or 0) + docs
        user.pages_translated_this_month = (user.pages_translated_this_month or 0) + pages
        user.api_calls_this_month = (user.api_calls_this_month or 0) + api_calls

        self.db.commit()
        self.db.refresh(user)
        return user

    def add_credits(self, user_id: str, credits: int) -> Optional[User]:
        """Add ``credits`` to the user's extra credit balance and commit.

        Returns:
            The refreshed user, or None if ``user_id`` does not exist.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None

        # Let the database do the addition (UPDATE ... SET x = x + n), so
        # concurrent top-ups cannot overwrite each other (no lost update).
        user.extra_credits = User.extra_credits + credits
        self.db.commit()
        self.db.refresh(user)
        return user

    def use_credits(self, user_id: str, credits: int) -> bool:
        """Deduct ``credits`` from the user's balance if enough are available.

        The balance check and the deduction run as one conditional UPDATE.
        Two concurrent requests therefore cannot both spend the same credits.

        Returns:
            True if the credits were deducted and committed. False if the user
            does not exist, the balance is insufficient, or ``credits`` is
            negative (a negative deduction would silently add credits).
        """
        if credits < 0:
            return False

        updated = (
            self.db.query(User)
            .filter(User.id == user_id, User.extra_credits >= credits)
            .update(
                {User.extra_credits: User.extra_credits - credits},
                synchronize_session="fetch",
            )
        )
        if not updated:
            return False

        self.db.commit()
        return True

    def get_all_users(
        self,
        skip: int = 0,
        limit: int = 100,
        plan: Optional[PlanType] = None,
    ) -> List[User]:
        """Return a page of users, optionally restricted to one plan."""
        query = self.db.query(User)
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.offset(skip).limit(limit).all()

    def count_users(self, plan: Optional[PlanType] = None) -> int:
        """Return the number of users, optionally restricted to one plan."""
        query = self.db.query(func.count(User.id))
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.scalar() or 0
|
|
|
|
|
|
class TranslationRepository:
    """Repository for Translation database operations.

    Wraps the SQLAlchemy session passed to ``__init__``; mutating methods
    commit that session before returning.
    """

    def __init__(self, db: Session):
        self.db = db

    def create(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create, commit and return a translation record with status "pending"."""
        translation = Translation(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="pending",
        )
        self.db.add(translation)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def update_status(
        self,
        translation_id: str,
        status: str,
        error_message: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        characters_translated: Optional[int] = None,
    ) -> Optional[Translation]:
        """Set a translation's status and any supplied metrics, then commit.

        Optional fields are written only when a value is passed (not None).
        An explicit ``0`` is therefore recorded rather than silently dropped.
        A status of "completed" also stamps ``completed_at`` with UTC now.

        Returns:
            The refreshed translation, or None if ``translation_id`` is unknown.
        """
        translation = self.db.query(Translation).filter(
            Translation.id == translation_id
        ).first()

        if not translation:
            return None

        translation.status = status
        # `is not None` (not truthiness): 0 ms / 0 characters are valid values.
        if error_message is not None:
            translation.error_message = error_message
        if processing_time_ms is not None:
            translation.processing_time_ms = processing_time_ms
        if characters_translated is not None:
            translation.characters_translated = characters_translated
        if status == "completed":
            translation.completed_at = datetime.utcnow()

        self.db.commit()
        self.db.refresh(translation)
        return translation

    def get_user_translations(
        self,
        user_id: str,
        skip: int = 0,
        limit: int = 50,
        status: Optional[str] = None,
    ) -> List[Translation]:
        """Return a page of the user's translations, newest first.

        If ``status`` is given, only translations with that status are returned.
        """
        query = self.db.query(Translation).filter(Translation.user_id == user_id)
        if status:
            query = query.filter(Translation.status == status)
        return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()

    def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Aggregate the user's completed translations from the last ``days`` days.

        Returns:
            A dict with ``total_translations``, ``total_pages``,
            ``total_characters`` (all ints, 0 when there is no data) and
            ``period_days``.
        """
        since = datetime.utcnow() - timedelta(days=days)

        result = self.db.query(
            func.count(Translation.id).label("total_translations"),
            func.sum(Translation.page_count).label("total_pages"),
            func.sum(Translation.characters_translated).label("total_characters"),
        ).filter(
            and_(
                Translation.user_id == user_id,
                Translation.created_at >= since,
                Translation.status == "completed",
            )
        ).first()

        # SUM() yields NULL for no rows. Depending on column type and driver
        # it may also come back as Decimal (e.g. PostgreSQL numeric), so
        # normalise to plain ints for callers / JSON encoding.
        return {
            "total_translations": int(result.total_translations or 0),
            "total_pages": int(result.total_pages or 0),
            "total_characters": int(result.total_characters or 0),
            "period_days": days,
        }
|
|
|
|
|
|
class ApiKeyRepository:
    """Repository for API Key database operations.

    Only the SHA-256 hash of each key is stored. The raw key is returned once,
    from ``create``.
    """

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def hash_key(key: str) -> str:
        """Return the hex SHA-256 digest of an API key."""
        return hashlib.sha256(key.encode()).hexdigest()

    def create(
        self,
        user_id: str,
        name: str,
        scopes: Optional[List[str]] = None,
        expires_in_days: Optional[int] = None,
    ) -> tuple[ApiKey, str]:
        """Create and commit a new API key.

        Args:
            user_id: Owner of the key.
            name: Human-readable label.
            scopes: Granted scopes. Falsy values (None or empty) default to
                ``["translate"]``.
            expires_in_days: Lifetime in days. Falsy values mean no expiry.

        Returns:
            ``(ApiKey, raw_key)``. The raw key is not persisted, so the caller
            must hand it to the user now.
        """
        # Generate a secure random key with a recognisable "tr_" prefix.
        raw_key = f"tr_{secrets.token_urlsafe(32)}"
        key_hash = self.hash_key(raw_key)
        key_prefix = raw_key[:10]

        expires_at = None
        if expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=expires_in_days)

        api_key = ApiKey(
            user_id=user_id,
            name=name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            # Copy so later mutation of the caller's list can't leak into the row.
            scopes=list(scopes) if scopes else ["translate"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)

        return api_key, raw_key

    def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
        """Look up an active, unexpired API key by its raw value.

        On a hit, ``last_used_at`` and ``usage_count`` are updated and
        committed.

        Returns:
            The refreshed ApiKey, or None if the key is empty, unknown,
            inactive or expired.
        """
        # The raw key comes from request input; reject empty/missing values
        # instead of crashing in hash_key().
        if not raw_key:
            return None

        key_hash = self.hash_key(raw_key)
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.key_hash == key_hash,
                ApiKey.is_active.is_(True),
            )
        ).first()

        if api_key:
            # Check expiration
            if api_key.expires_at and api_key.expires_at < datetime.utcnow():
                return None

            api_key.last_used_at = datetime.utcnow()
            # Increment in SQL so concurrent requests with the same key don't
            # lose counts. COALESCE also copes with a NULL counter.
            api_key.usage_count = func.coalesce(ApiKey.usage_count, 0) + 1
            self.db.commit()
            # Reload so the caller sees the value the database computed.
            self.db.refresh(api_key)

        return api_key

    def get_user_keys(self, user_id: str) -> List[ApiKey]:
        """Return all of a user's API keys (active or not), newest first."""
        return self.db.query(ApiKey).filter(
            ApiKey.user_id == user_id
        ).order_by(ApiKey.created_at.desc()).all()

    def revoke(self, key_id: str, user_id: str) -> bool:
        """Deactivate one of the user's API keys and commit.

        Returns:
            False if no key with ``key_id`` belongs to ``user_id``, else True.
        """
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.id == key_id,
                ApiKey.user_id == user_id,
            )
        ).first()

        if not api_key:
            return False

        api_key.is_active = False
        self.db.commit()
        return True
|