Files
office_translator/database/repositories.py
sepehr 45e44dd7b2
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s
fix(billing): unify quota counters, fix Stripe webhooks, tier/plan sync
2026-06-14 17:39:34 +02:00

415 lines
12 KiB
Python

"""
Repository layer for database operations
Provides clean interface for CRUD operations
"""
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, func, or_
from database.models import (
User,
Translation,
ApiKey,
UsageLog,
PaymentHistory,
PlanType,
SubscriptionStatus,
)
class UserRepository:
    """Repository for User database operations.

    Every mutating method commits on the session passed to ``__init__``.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with primary key ``user_id``, or None."""
        return self.db.query(User).filter(User.id == user_id).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email matches ``email`` case-insensitively, or None."""
        return (
            self.db.query(User).filter(func.lower(User.email) == email.lower()).first()
        )

    def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
        """Return the user linked to the given Stripe customer ID, or None."""
        return (
            self.db.query(User)
            .filter(User.stripe_customer_id == stripe_customer_id)
            .first()
        )

    def create(
        self,
        email: str,
        name: str,
        hashed_password: str,
        tier: str = "free",
    ) -> User:
        """Create, commit and return a new user with ACTIVE subscription status.

        The email is stored lower-cased. ``plan`` is derived from ``tier``:
        ``"pro"`` maps to ``PlanType.PRO`` and every other tier maps to
        ``PlanType.FREE`` (story 1-1 refactor).
        """
        plan = PlanType.PRO if tier == "pro" else PlanType.FREE
        user = User(
            email=email.lower(),
            name=name,
            hashed_password=hashed_password,
            tier=tier,
            plan=plan,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def update(self, user_id: str, **kwargs) -> Optional[User]:
        """Set the given attributes on a user, bump ``updated_at`` and commit.

        Keys that are not attributes of ``User`` are silently ignored.
        Returns the refreshed user, or None if no user has this ID.

        NOTE(review): any existing attribute (including ``id`` or
        ``hashed_password``) can be overwritten, so callers must not forward
        untrusted input here unfiltered.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None
        for key, value in kwargs.items():
            if hasattr(user, key):
                setattr(user, key, value)
        user.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(user)
        return user

    def delete(self, user_id: str) -> bool:
        """Delete a user; return True if it existed, False otherwise."""
        user = self.get_by_id(user_id)
        if not user:
            return False
        self.db.delete(user)
        self.db.commit()
        return True

    def reset_usage_if_needed(self, user_id: str) -> Optional[User]:
        """Zero the monthly usage counters when a new calendar month has begun.

        A reset happens when ``usage_reset_date`` is None or falls in a
        different month/year than the current UTC time. The three counters are
        then set to 0, ``usage_reset_date`` and ``updated_at`` are set to now,
        and the change is committed. Returns the user, or None if no user has
        this ID.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None
        now = datetime.now(timezone.utc)
        reset_date = user.usage_reset_date
        if reset_date is None or reset_date.month != now.month or reset_date.year != now.year:
            user.docs_translated_this_month = 0
            user.pages_translated_this_month = 0
            user.api_calls_this_month = 0
            user.usage_reset_date = now
            user.updated_at = now
            self.db.commit()
            self.db.refresh(user)
        return user

    def increment_usage(
        self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
    ) -> Optional[User]:
        """Add the given deltas to the monthly usage counters atomically.

        A single SQL UPDATE computes ``column + delta`` in the database, so
        concurrent requests cannot lose increments. Returns the freshly
        reloaded user, or None if no user has this ID.
        """
        from sqlalchemy import update

        now = datetime.now(timezone.utc)
        self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(
                docs_translated_this_month=User.docs_translated_this_month + docs,
                pages_translated_this_month=User.pages_translated_this_month + pages,
                api_calls_this_month=User.api_calls_this_month + api_calls,
                updated_at=now,
            )
            .execution_options(synchronize_session=False)
        )
        self.db.commit()
        user = self.get_by_id(user_id)
        if user is not None:
            # synchronize_session=False leaves any already-loaded User instance
            # untouched; if the session does not expire objects on commit, that
            # instance still holds the old counters. Reload it from the DB.
            self.db.refresh(user)
        return user

    def add_credits(self, user_id: str, credits: int) -> Optional[User]:
        """Add ``credits`` to a user's ``extra_credits`` balance atomically.

        Like ``use_credits``, this uses a single SQL UPDATE rather than a
        Python read-modify-write, so a concurrent debit/credit (e.g. a payment
        webhook racing a translation) cannot overwrite this change with a stale
        balance. A NULL balance is treated as 0. Returns the refreshed user, or
        None if no user has this ID.
        """
        from sqlalchemy import update

        now = datetime.now(timezone.utc)
        result = self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(
                extra_credits=func.coalesce(User.extra_credits, 0) + credits,
                updated_at=now,
            )
            .execution_options(synchronize_session=False)
        )
        self.db.commit()
        if result.rowcount == 0:
            return None
        user = self.get_by_id(user_id)
        if user is not None:
            # Bulk UPDATE bypassed the identity map; reload the new balance.
            self.db.refresh(user)
        return user

    def use_credits(self, user_id: str, credits: int) -> bool:
        """Deduct ``credits`` from the balance atomically, refusing overdrafts.

        The ``extra_credits >= credits`` guard lives in the UPDATE's WHERE
        clause, so the check and the deduction happen in one statement.
        Returns True if a row was updated, False if the user does not exist or
        has an insufficient balance.

        NOTE(review): a negative ``credits`` would pass the guard and increase
        the balance; callers presumably pass positive amounts — confirm.
        """
        from sqlalchemy import update

        now = datetime.now(timezone.utc)
        result = self.db.execute(
            update(User)
            .where(User.id == user_id, User.extra_credits >= credits)
            .values(
                extra_credits=User.extra_credits - credits,
                updated_at=now,
            )
            .execution_options(synchronize_session=False)
        )
        self.db.commit()
        return result.rowcount > 0

    def get_all_users(
        self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
    ) -> List[User]:
        """Return up to ``limit`` users after skipping ``skip``, optionally filtered by plan."""
        query = self.db.query(User)
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.offset(skip).limit(limit).all()

    def count_users(self, plan: Optional[PlanType] = None) -> int:
        """Return the number of users, optionally restricted to one plan."""
        query = self.db.query(func.count(User.id))
        if plan is not None:
            query = query.filter(User.plan == plan)
        return query.scalar()
class TranslationRepository:
    """Repository for Translation database operations"""

    def __init__(self, db: Session):
        self.db = db

    def _insert(self, **fields: Any) -> Translation:
        """Build a Translation from ``fields``, add it, commit and return it refreshed."""
        translation = Translation(**fields)
        self.db.add(translation)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def create(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create and commit a new translation record in ``"pending"`` status."""
        return self._insert(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="pending",
        )

    def create_completed(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create a translation record directly in ``"completed"`` status.

        ``completed_at`` is set to the current UTC time (Story 1.8 - billing log).
        """
        return self._insert(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="completed",
            completed_at=datetime.now(timezone.utc),
        )

    def update_status(
        self,
        translation_id: str,
        status: str,
        error_message: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        characters_translated: Optional[int] = None,
    ) -> Optional[Translation]:
        """Set a translation's status and any supplied metrics, then commit.

        Optional arguments left as None are not written. When ``status`` is
        ``"completed"``, ``completed_at`` is set to the current UTC time.
        Returns the refreshed translation, or None if the ID is unknown.
        """
        translation = (
            self.db.query(Translation).filter(Translation.id == translation_id).first()
        )
        if not translation:
            return None
        translation.status = status
        if error_message:
            translation.error_message = error_message
        # Compare against None, not truthiness: 0 ms / 0 characters (e.g. an
        # empty document) are real measurements and must be recorded.
        if processing_time_ms is not None:
            translation.processing_time_ms = processing_time_ms
        if characters_translated is not None:
            translation.characters_translated = characters_translated
        if status == "completed":
            translation.completed_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def get_user_translations(
        self,
        user_id: str,
        skip: int = 0,
        limit: int = 50,
        status: Optional[str] = None,
    ) -> List[Translation]:
        """Return a user's translations, newest first, optionally filtered by status."""
        query = self.db.query(Translation).filter(Translation.user_id == user_id)
        if status:
            query = query.filter(Translation.status == status)
        return (
            query.order_by(Translation.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Aggregate a user's completed translations created in the last ``days`` days.

        Returns a dict with ``total_translations``, ``total_pages``,
        ``total_characters`` (each 0 when there is nothing to sum) and
        ``period_days``.
        """
        since = datetime.now(timezone.utc) - timedelta(days=days)
        result = (
            self.db.query(
                func.count(Translation.id).label("total_translations"),
                func.sum(Translation.page_count).label("total_pages"),
                func.sum(Translation.characters_translated).label("total_characters"),
            )
            .filter(
                and_(
                    Translation.user_id == user_id,
                    Translation.created_at >= since,
                    Translation.status == "completed",
                )
            )
            .first()
        )
        # SUM over zero rows yields NULL, hence the `or 0` fallbacks.
        return {
            "total_translations": result.total_translations or 0,
            "total_pages": result.total_pages or 0,
            "total_characters": result.total_characters or 0,
            "period_days": days,
        }
class ApiKeyRepository:
    """Repository for API Key database operations.

    Only the SHA-256 hash of each key is stored; the raw key is returned once,
    from ``create``.
    """

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def hash_key(key: str) -> str:
        """Return the hex SHA-256 digest of an API key."""
        return hashlib.sha256(key.encode()).hexdigest()

    @staticmethod
    def _is_expired(expires_at: Optional[datetime], now: datetime) -> bool:
        """Return True if ``expires_at`` is set and earlier than the aware ``now``.

        Expiry timestamps are written in UTC by ``create``, but a DateTime column
        without timezone support may hand them back naive; comparing naive with
        aware datetimes raises TypeError, so naive values are interpreted as UTC.
        """
        if expires_at is None:
            return False
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return expires_at < now

    def create(
        self,
        user_id: str,
        name: str,
        scopes: Optional[List[str]] = None,
        expires_in_days: Optional[int] = None,
    ) -> tuple[ApiKey, str]:
        """Create and commit a new API key. Returns ``(ApiKey, raw_key)``.

        ``raw_key`` has the form ``tr_<random>`` and is never stored; only its
        hash and its first 10 characters (``key_prefix``) are persisted.
        ``scopes`` defaults to ``["translate"]``. A falsy ``expires_in_days``
        (None or 0) creates a key that never expires.
        """
        # Generate a secure random key
        raw_key = f"tr_{secrets.token_urlsafe(32)}"
        key_hash = self.hash_key(raw_key)
        key_prefix = raw_key[:10]
        expires_at = None
        if expires_in_days:
            expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
        api_key = ApiKey(
            user_id=user_id,
            name=name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            scopes=scopes or ["translate"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)
        return api_key, raw_key

    def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
        """Look up an active, unexpired API key by its raw value.

        On success, records the use (``last_used_at`` = now, ``usage_count`` + 1)
        and commits. Returns None for unknown, revoked or expired keys.
        """
        key_hash = self.hash_key(raw_key)
        api_key = (
            self.db.query(ApiKey)
            .filter(
                and_(
                    ApiKey.key_hash == key_hash,
                    ApiKey.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                )
            )
            .first()
        )
        if api_key is None:
            return None
        now = datetime.now(timezone.utc)
        if self._is_expired(api_key.expires_at, now):
            return None
        # Update last used
        api_key.last_used_at = now
        api_key.usage_count += 1
        self.db.commit()
        return api_key

    def get_user_keys(self, user_id: str) -> List[ApiKey]:
        """Return all of a user's API keys (active or not), newest first."""
        return (
            self.db.query(ApiKey)
            .filter(ApiKey.user_id == user_id)
            .order_by(ApiKey.created_at.desc())
            .all()
        )

    def revoke(self, key_id: str, user_id: str) -> bool:
        """Deactivate one of ``user_id``'s keys.

        The key must belong to ``user_id``, so users cannot revoke each other's
        keys. Returns True on success, False if no such key exists for that user.
        """
        api_key = (
            self.db.query(ApiKey)
            .filter(
                and_(
                    ApiKey.id == key_id,
                    ApiKey.user_id == user_id,
                )
            )
            .first()
        )
        if not api_key:
            return False
        api_key.is_active = False
        self.db.commit()
        return True