feat: revue de code, doc CODE_REVIEW, forfaits 2026, traduction LLM, providers avec modèle

Made-with: Cursor
This commit is contained in:
Sepehr Ramezani
2026-03-07 11:42:58 +01:00
parent 3d37ce4582
commit 473b3e26c7
181 changed files with 30617 additions and 7170 deletions

View File

@@ -1,17 +1,46 @@
"""
Database module for the Document Translation API
Provides PostgreSQL support with async SQLAlchemy
Provides PostgreSQL/SQLite support with async SQLAlchemy 2.0
"""
from database.connection import get_db, engine, SessionLocal, init_db
from database.models import User, Subscription, Translation, ApiKey
from database.connection import (
get_db,
get_db_session,
get_async_session,
engine,
AsyncSessionLocal,
init_db,
get_engine,
)
from database.models import (
Base,
User,
Translation,
ApiKey,
UsageLog,
PaymentHistory,
PlanType,
SubscriptionStatus,
Glossary,
GlossaryTerm,
)
__all__ = [
"get_db",
"engine",
"SessionLocal",
"get_db_session",
"get_async_session",
"engine",
"AsyncSessionLocal",
"init_db",
"get_engine",
"Base",
"User",
"Subscription",
"Translation",
"ApiKey"
"ApiKey",
"UsageLog",
"PaymentHistory",
"PlanType",
"SubscriptionStatus",
"Glossary",
"GlossaryTerm",
]

View File

@@ -1,135 +1,174 @@
"""
Database connection and session management
Supports both PostgreSQL (production) and SQLite (development/testing)
Async SQLAlchemy 2.0 implementation
"""
import os
import logging
from typing import Generator, Optional
from typing import AsyncGenerator, Optional
from contextlib import asynccontextmanager
from sqlalchemy import text, create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncSession,
async_sessionmaker,
AsyncEngine,
)
from sqlalchemy.pool import QueuePool, StaticPool
from contextlib import contextmanager
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, StaticPool
from database.utils import convert_to_async_url
logger = logging.getLogger(__name__)
# Database URL from environment
# PostgreSQL: postgresql://user:password@host:port/database
# SQLite: sqlite:///./data/translate.db
DATABASE_URL = os.getenv("DATABASE_URL", "")
# Determine if we're using SQLite or PostgreSQL
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True
_is_postgres = DATABASE_URL.startswith("postgres") if DATABASE_URL else False
# Create engine based on database type
if DATABASE_URL and not _is_sqlite:
# PostgreSQL configuration
engine = create_engine(
DATABASE_URL,
if DATABASE_URL and _is_postgres:
async_database_url = convert_to_async_url(DATABASE_URL)
engine: AsyncEngine = create_async_engine(
async_database_url,
poolclass=QueuePool,
pool_size=5,
max_overflow=10,
pool_timeout=30,
pool_recycle=1800, # Recycle connections after 30 minutes
pool_pre_ping=True, # Check connection health before use
pool_recycle=1800,
pool_pre_ping=True,
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
)
logger.info("✅ Database configured with PostgreSQL")
logger.info("✅ Database configured with PostgreSQL (async)")
else:
# SQLite configuration (for development/testing or when no DATABASE_URL)
sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
sqlite_url = f"sqlite:///./{sqlite_path}"
engine = create_engine(
sqlite_url,
os.makedirs(
os.path.dirname(sqlite_path) if os.path.dirname(sqlite_path) else ".",
exist_ok=True,
)
async_database_url = f"sqlite+aiosqlite:///./{sqlite_path}"
engine: AsyncEngine = create_async_engine(
async_database_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
)
# Enable foreign keys for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
if not DATABASE_URL:
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
else:
logger.info(f"✅ Database configured with SQLite: {sqlite_path}")
# Session factory
SessionLocal = sessionmaker(
if not DATABASE_URL:
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development (async)")
else:
logger.info(f"✅ Database configured with SQLite: {sqlite_path} (async)")
# Sync engine and session for repositories (auth, translation log).
# Kept for backward compatibility until all callers use async; see story 1-1.
# Prefer get_db() / AsyncSessionLocal for new code.
if DATABASE_URL and _is_postgres:
sync_engine = create_engine(
DATABASE_URL,
poolclass=QueuePool,
pool_size=5,
max_overflow=10,
pool_pre_ping=True,
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
)
else:
_sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
_sync_sqlite_url = f"sqlite:///./{_sqlite_path}"
sync_engine = create_engine(
_sync_sqlite_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SyncSessionLocal = sessionmaker(bind=sync_engine, autocommit=False, autoflush=False, expire_on_commit=False)
@contextmanager
def get_sync_session():
"""Sync session context manager for use with sync repositories (auth_service, translation log)."""
session = SyncSessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
AsyncSessionLocal = async_sessionmaker(
bind=engine,
class_=AsyncSession,
autocommit=False,
autoflush=False,
bind=engine,
expire_on_commit=False,
)
def get_db() -> Generator[Session, None, None]:
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
Dependency for FastAPI to get database session.
Usage: db: Session = Depends(get_db)
Async dependency for FastAPI to get database session.
Usage: db: AsyncSession = Depends(get_db)
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""
Context manager for database session.
Usage: with get_db_session() as db: ...
Async context manager for database session.
Usage: async with get_db_session() as db: ...
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
# Alias for backward compatibility
get_sync_session = get_db_session
get_async_session = get_db_session
def init_db():
async def init_db():
"""
Initialize database tables.
Initialize database tables asynchronously.
Call this on application startup.
"""
from database.models import Base
Base.metadata.create_all(bind=engine)
logger.info("✅ Database tables initialized")
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("✅ Database tables initialized (async)")
def check_db_connection() -> bool:
async def check_db_connection() -> bool:
"""
Check if database connection is healthy.
Returns True if connection works, False otherwise.
"""
try:
with engine.connect() as conn:
conn.execute("SELECT 1")
async with engine.connect() as conn:
await conn.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Database connection check failed: {e}")
return False
# Connection pool stats (for monitoring)
def get_pool_stats() -> dict:
"""Get database connection pool statistics"""
if hasattr(engine.pool, 'status'):
if hasattr(engine.pool, "status"):
return {
"pool_size": engine.pool.size(),
"checked_in": engine.pool.checkedin(),
@@ -137,3 +176,8 @@ def get_pool_stats() -> dict:
"overflow": engine.pool.overflow(),
}
return {"status": "pool stats not available"}
def get_engine() -> AsyncEngine:
"""Get the async engine instance"""
return engine

View File

@@ -1,19 +1,34 @@
"""
SQLAlchemy models for the Document Translation API
"""
import os
import uuid
from datetime import datetime
from typing import Optional, List
from datetime import datetime, timezone
import warnings
from sqlalchemy import (
Column, String, Integer, Float, Boolean, DateTime, Text,
ForeignKey, Enum, Index, JSON, BigInteger
Column,
String,
Integer,
Float,
Boolean,
DateTime,
Text,
ForeignKey,
Enum,
Index,
JSON,
BigInteger,
CheckConstraint,
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
import enum
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
Base = declarative_base()
@@ -22,6 +37,11 @@ def generate_uuid():
return str(uuid.uuid4())
def generate_uuid_value():
"""Generate a new UUID value for PostgreSQL UUID column"""
return uuid.uuid4()
class PlanType(str, enum.Enum):
FREE = "free"
STARTER = "starter"
@@ -40,57 +60,78 @@ class SubscriptionStatus(str, enum.Enum):
class User(Base):
"""User model for authentication and billing"""
__tablename__ = "users"
id = Column(String(36), primary_key=True, default=generate_uuid)
email = Column(String(255), unique=True, nullable=False, index=True)
name = Column(String(255), nullable=False)
password_hash = Column(String(255), nullable=False)
# Account status
hashed_password = Column(String(255), nullable=False)
tier = Column(String(10), default="free", nullable=False)
daily_translation_count = Column(Integer, default=0, nullable=False)
email_verified = Column(Boolean, default=False)
is_active = Column(Boolean, default=True)
avatar_url = Column(String(500), nullable=True)
# Subscription info
plan = Column(Enum(PlanType), default=PlanType.FREE)
subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)
# Stripe integration
subscription_status = Column(
Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE
)
stripe_customer_id = Column(String(255), nullable=True, index=True)
stripe_subscription_id = Column(String(255), nullable=True)
# Usage tracking (reset monthly)
docs_translated_this_month = Column(Integer, default=0)
pages_translated_this_month = Column(Integer, default=0)
api_calls_this_month = Column(Integer, default=0)
extra_credits = Column(Integer, default=0) # Purchased credits
usage_reset_date = Column(DateTime, default=datetime.utcnow)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
extra_credits = Column(Integer, default=0)
usage_reset_date = Column(DateTime, default=_utcnow)
created_at = Column(DateTime, default=_utcnow)
updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
last_login_at = Column(DateTime, nullable=True)
# Relationships
translations = relationship("Translation", back_populates="user", lazy="dynamic")
api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")
# Indexes
translations = relationship("Translation", back_populates="user", lazy="select")
api_keys = relationship("ApiKey", back_populates="user", lazy="select")
__table_args__ = (
Index('ix_users_email_active', 'email', 'is_active'),
Index('ix_users_stripe_customer', 'stripe_customer_id'),
CheckConstraint("tier IN ('free', 'pro')", name="ck_users_tier"),
Index("ix_users_email_active", "email", "is_active"),
Index("ix_users_stripe_customer", "stripe_customer_id"),
)
@property
def password_hash(self) -> str:
warnings.warn(
"password_hash is deprecated, use hashed_password instead",
DeprecationWarning,
stacklevel=2,
)
return self.hashed_password
@password_hash.setter
def password_hash(self, value: str) -> None:
warnings.warn(
"password_hash is deprecated, use hashed_password instead",
DeprecationWarning,
stacklevel=2,
)
self.hashed_password = value
def to_dict(self) -> dict:
"""Convert user to dictionary for API response"""
return {
"id": self.id,
"email": self.email,
"name": self.name,
"avatar_url": self.avatar_url,
"tier": self.tier,
"plan": self.plan.value if self.plan else "free",
"subscription_status": self.subscription_status.value if self.subscription_status else "active",
"subscription_status": self.subscription_status.value
if self.subscription_status
else "active",
"daily_translation_count": self.daily_translation_count,
"docs_translated_this_month": self.docs_translated_this_month,
"pages_translated_this_month": self.pages_translated_this_month,
"api_calls_this_month": self.api_calls_this_month,
@@ -102,44 +143,49 @@ class User(Base):
class Translation(Base):
"""Translation history for analytics and billing"""
__tablename__ = "translations"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
user_id = Column(
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
# File info
original_filename = Column(String(255), nullable=False)
file_type = Column(String(10), nullable=False) # xlsx, docx, pptx
file_type = Column(String(20), nullable=False) # xlsx, docx, pptx
file_size_bytes = Column(BigInteger, default=0)
page_count = Column(Integer, default=0)
# Translation details
source_language = Column(String(10), default="auto")
target_language = Column(String(10), nullable=False)
provider = Column(String(50), nullable=False) # google, deepl, ollama, etc.
# Processing info
status = Column(String(20), default="pending") # pending, processing, completed, failed
status = Column(
String(20), default="pending"
) # pending, processing, completed, failed
error_message = Column(Text, nullable=True)
processing_time_ms = Column(Integer, nullable=True)
# Cost tracking (for paid providers)
characters_translated = Column(Integer, default=0)
estimated_cost_usd = Column(Float, default=0.0)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime, default=_utcnow)
completed_at = Column(DateTime, nullable=True)
# Relationship
user = relationship("User", back_populates="translations")
# Indexes
__table_args__ = (
Index('ix_translations_user_date', 'user_id', 'created_at'),
Index('ix_translations_status', 'status'),
Index("ix_translations_user_date", "user_id", "created_at"),
Index("ix_translations_status", "status"),
)
def to_dict(self) -> dict:
return {
"id": self.id,
@@ -154,43 +200,49 @@ class Translation(Base):
"processing_time_ms": self.processing_time_ms,
"characters_translated": self.characters_translated,
"created_at": self.created_at.isoformat() if self.created_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
}
class ApiKey(Base):
"""API keys for programmatic access"""
__tablename__ = "api_keys"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
user_id = Column(
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
# Key info
name = Column(String(100), nullable=False) # User-friendly name
key_hash = Column(String(255), nullable=False) # SHA256 of the key
key_prefix = Column(String(10), nullable=False) # First 8 chars for identification
# Permissions
is_active = Column(Boolean, default=True)
scopes = Column(JSON, default=list) # ["translate", "read", "write"]
# Usage tracking
last_used_at = Column(DateTime, nullable=True)
usage_count = Column(Integer, default=0)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime, default=_utcnow)
expires_at = Column(DateTime, nullable=True)
revoked_at = Column(DateTime, nullable=True) # Set when is_active=False
# Relationship
user = relationship("User", back_populates="api_keys")
# Indexes
__table_args__ = (
Index('ix_api_keys_prefix', 'key_prefix'),
Index('ix_api_keys_hash', 'key_hash'),
Index("ix_api_keys_prefix", "key_prefix"),
Index("ix_api_keys_hash", "key_hash"),
)
def to_dict(self) -> dict:
return {
"id": self.id,
@@ -198,7 +250,9 @@ class ApiKey(Base):
"key_prefix": self.key_prefix,
"is_active": self.is_active,
"scopes": self.scopes,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"usage_count": self.usage_count,
"created_at": self.created_at.isoformat() if self.created_at else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
@@ -207,53 +261,149 @@ class ApiKey(Base):
class UsageLog(Base):
"""Daily usage aggregation for billing and analytics"""
__tablename__ = "usage_logs"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
user_id = Column(
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
# Date (for daily aggregation)
date = Column(DateTime, nullable=False, index=True)
# Aggregated counts
documents_count = Column(Integer, default=0)
pages_count = Column(Integer, default=0)
characters_count = Column(BigInteger, default=0)
api_calls_count = Column(Integer, default=0)
# By provider breakdown (JSON)
provider_breakdown = Column(JSON, default=dict)
# Indexes
__table_args__ = (
Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
)
__table_args__ = (Index("ix_usage_logs_user_date", "user_id", "date", unique=True),)
class PaymentHistory(Base):
"""Payment and invoice history"""
__tablename__ = "payment_history"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
user_id = Column(
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
# Stripe info
stripe_payment_intent_id = Column(String(255), nullable=True)
stripe_invoice_id = Column(String(255), nullable=True)
# Payment details
amount_cents = Column(Integer, nullable=False)
currency = Column(String(3), default="usd")
payment_type = Column(String(50), nullable=False) # subscription, credits, one_time
status = Column(String(20), nullable=False) # succeeded, failed, pending, refunded
# Description
description = Column(String(255), nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime, default=_utcnow)
# Indexes
__table_args__ = (
Index('ix_payment_history_user', 'user_id', 'created_at'),
__table_args__ = (Index("ix_payment_history_user", "user_id", "created_at"),)
class Glossary(Base):
"""User's glossary containing source->target term pairs.
Story 3.9: Glossaires - Endpoint CRUD
"""
__tablename__ = "glossaries"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
name = Column(String(255), nullable=False)
created_at = Column(DateTime, default=_utcnow)
updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
# Relationship
terms = relationship(
"GlossaryTerm", back_populates="glossary", cascade="all, delete-orphan"
)
# Indexes
__table_args__ = (Index("ix_glossaries_user_id", "user_id"),)
def to_dict(self) -> dict:
return {
"id": self.id,
"user_id": self.user_id,
"name": self.name,
"terms": [term.to_dict() for term in self.terms] if self.terms else [],
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class GlossaryTerm(Base):
"""Single term pair in a glossary.
Story 3.9: Glossaires - Endpoint CRUD
"""
__tablename__ = "glossary_terms"
id = Column(String(36), primary_key=True, default=generate_uuid)
glossary_id = Column(
String(36), ForeignKey("glossaries.id", ondelete="CASCADE"), nullable=False
)
source = Column(String(500), nullable=False)
target = Column(String(500), nullable=False)
created_at = Column(DateTime, default=_utcnow)
# Relationship
glossary = relationship("Glossary", back_populates="terms")
# Indexes
__table_args__ = (Index("ix_glossary_terms_glossary_id", "glossary_id"),)
def to_dict(self) -> dict:
return {
"id": self.id,
"source": self.source,
"target": self.target,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
class CustomPrompt(Base):
"""User's custom prompts for LLM translation context.
Story 3.11: Custom Prompts - Endpoint CRUD
"""
__tablename__ = "custom_prompts"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
name = Column(String(255), nullable=False)
content = Column(Text, nullable=False)
created_at = Column(DateTime, default=_utcnow)
updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
# Indexes
__table_args__ = (Index("ix_custom_prompts_user_id", "user_id"),)
def to_dict(self) -> dict:
return {
"id": self.id,
"user_id": self.user_id,
"name": self.name,
"content": self.content,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}

View File

@@ -2,54 +2,64 @@
Repository layer for database operations
Provides clean interface for CRUD operations
"""
import hashlib
import secrets
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, func, or_
from database.models import (
User, Translation, ApiKey, UsageLog, PaymentHistory,
PlanType, SubscriptionStatus
User,
Translation,
ApiKey,
UsageLog,
PaymentHistory,
PlanType,
SubscriptionStatus,
)
class UserRepository:
"""Repository for User database operations"""
def __init__(self, db: Session):
self.db = db
def get_by_id(self, user_id: str) -> Optional[User]:
"""Get user by ID"""
return self.db.query(User).filter(User.id == user_id).first()
def get_by_email(self, email: str) -> Optional[User]:
"""Get user by email (case-insensitive)"""
return self.db.query(User).filter(
func.lower(User.email) == email.lower()
).first()
return (
self.db.query(User).filter(func.lower(User.email) == email.lower()).first()
)
def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
"""Get user by Stripe customer ID"""
return self.db.query(User).filter(
User.stripe_customer_id == stripe_customer_id
).first()
return (
self.db.query(User)
.filter(User.stripe_customer_id == stripe_customer_id)
.first()
)
def create(
self,
email: str,
name: str,
password_hash: str,
plan: PlanType = PlanType.FREE
self,
email: str,
name: str,
hashed_password: str,
tier: str = "free",
) -> User:
"""Create a new user"""
"""Create a new user. Uses hashed_password and tier (story 1-1 refactor)."""
plan = PlanType.PRO if tier == "pro" else PlanType.FREE
user = User(
email=email.lower(),
name=name,
password_hash=password_hash,
hashed_password=hashed_password,
tier=tier,
plan=plan,
subscription_status=SubscriptionStatus.ACTIVE,
)
@@ -57,94 +67,90 @@ class UserRepository:
self.db.commit()
self.db.refresh(user)
return user
def update(self, user_id: str, **kwargs) -> Optional[User]:
"""Update user fields"""
user = self.get_by_id(user_id)
if not user:
return None
for key, value in kwargs.items():
if hasattr(user, key):
setattr(user, key, value)
user.updated_at = datetime.utcnow()
user.updated_at = datetime.now(timezone.utc)
self.db.commit()
self.db.refresh(user)
return user
def delete(self, user_id: str) -> bool:
"""Delete a user"""
user = self.get_by_id(user_id)
if not user:
return False
self.db.delete(user)
self.db.commit()
return True
def increment_usage(
self,
user_id: str,
docs: int = 0,
pages: int = 0,
api_calls: int = 0
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
) -> Optional[User]:
"""Increment usage counters"""
user = self.get_by_id(user_id)
if not user:
return None
# Check if usage needs to be reset (monthly)
if user.usage_reset_date:
now = datetime.utcnow()
if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
now = datetime.now(timezone.utc)
if (
now.month != user.usage_reset_date.month
or now.year != user.usage_reset_date.year
):
user.docs_translated_this_month = 0
user.pages_translated_this_month = 0
user.api_calls_this_month = 0
user.usage_reset_date = now
user.docs_translated_this_month += docs
user.pages_translated_this_month += pages
user.api_calls_this_month += api_calls
self.db.commit()
self.db.refresh(user)
return user
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
"""Add extra credits to user"""
user = self.get_by_id(user_id)
if not user:
return None
user.extra_credits += credits
self.db.commit()
self.db.refresh(user)
return user
def use_credits(self, user_id: str, credits: int) -> bool:
"""Use credits from user balance"""
user = self.get_by_id(user_id)
if not user or user.extra_credits < credits:
return False
user.extra_credits -= credits
self.db.commit()
return True
def get_all_users(
self,
skip: int = 0,
limit: int = 100,
plan: Optional[PlanType] = None
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
) -> List[User]:
"""Get all users with pagination"""
query = self.db.query(User)
if plan:
query = query.filter(User.plan == plan)
return query.offset(skip).limit(limit).all()
def count_users(self, plan: Optional[PlanType] = None) -> int:
"""Count total users"""
query = self.db.query(func.count(User.id))
@@ -155,10 +161,10 @@ class UserRepository:
class TranslationRepository:
"""Repository for Translation database operations"""
def __init__(self, db: Session):
self.db = db
def create(
self,
user_id: str,
@@ -186,7 +192,36 @@ class TranslationRepository:
self.db.commit()
self.db.refresh(translation)
return translation
def create_completed(
self,
user_id: str,
original_filename: str,
file_type: str,
target_language: str,
provider: str,
source_language: str = "auto",
file_size_bytes: int = 0,
page_count: int = 0,
) -> Translation:
"""Create a translation record directly in completed status (Story 1.8 - billing log)."""
translation = Translation(
user_id=user_id,
original_filename=original_filename,
file_type=file_type,
file_size_bytes=file_size_bytes,
page_count=page_count,
source_language=source_language,
target_language=target_language,
provider=provider,
status="completed",
completed_at=datetime.now(timezone.utc),
)
self.db.add(translation)
self.db.commit()
self.db.refresh(translation)
return translation
def update_status(
self,
translation_id: str,
@@ -196,13 +231,13 @@ class TranslationRepository:
characters_translated: Optional[int] = None,
) -> Optional[Translation]:
"""Update translation status"""
translation = self.db.query(Translation).filter(
Translation.id == translation_id
).first()
translation = (
self.db.query(Translation).filter(Translation.id == translation_id).first()
)
if not translation:
return None
translation.status = status
if error_message:
translation.error_message = error_message
@@ -211,12 +246,12 @@ class TranslationRepository:
if characters_translated:
translation.characters_translated = characters_translated
if status == "completed":
translation.completed_at = datetime.utcnow()
translation.completed_at = datetime.now(timezone.utc)
self.db.commit()
self.db.refresh(translation)
return translation
def get_user_translations(
self,
user_id: str,
@@ -228,24 +263,33 @@ class TranslationRepository:
query = self.db.query(Translation).filter(Translation.user_id == user_id)
if status:
query = query.filter(Translation.status == status)
return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()
return (
query.order_by(Translation.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
"""Get user's translation statistics"""
since = datetime.utcnow() - timedelta(days=days)
result = self.db.query(
func.count(Translation.id).label("total_translations"),
func.sum(Translation.page_count).label("total_pages"),
func.sum(Translation.characters_translated).label("total_characters"),
).filter(
and_(
Translation.user_id == user_id,
Translation.created_at >= since,
Translation.status == "completed",
since = datetime.now(timezone.utc) - timedelta(days=days)
result = (
self.db.query(
func.count(Translation.id).label("total_translations"),
func.sum(Translation.page_count).label("total_pages"),
func.sum(Translation.characters_translated).label("total_characters"),
)
).first()
.filter(
and_(
Translation.user_id == user_id,
Translation.created_at >= since,
Translation.status == "completed",
)
)
.first()
)
return {
"total_translations": result.total_translations or 0,
"total_pages": result.total_pages or 0,
@@ -256,15 +300,15 @@ class TranslationRepository:
class ApiKeyRepository:
"""Repository for API Key database operations"""
def __init__(self, db: Session):
self.db = db
@staticmethod
def hash_key(key: str) -> str:
"""Hash an API key"""
return hashlib.sha256(key.encode()).hexdigest()
def create(
self,
user_id: str,
@@ -277,11 +321,11 @@ class ApiKeyRepository:
raw_key = f"tr_{secrets.token_urlsafe(32)}"
key_hash = self.hash_key(raw_key)
key_prefix = raw_key[:10]
expires_at = None
if expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
api_key = ApiKey(
user_id=user_id,
name=name,
@@ -293,49 +337,60 @@ class ApiKeyRepository:
self.db.add(api_key)
self.db.commit()
self.db.refresh(api_key)
return api_key, raw_key
def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
"""Get API key by raw key value"""
key_hash = self.hash_key(raw_key)
api_key = self.db.query(ApiKey).filter(
and_(
ApiKey.key_hash == key_hash,
ApiKey.is_active == True,
api_key = (
self.db.query(ApiKey)
.filter(
and_(
ApiKey.key_hash == key_hash,
ApiKey.is_active == True,
)
)
).first()
.first()
)
if api_key:
# Check expiration
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
if api_key.expires_at and api_key.expires_at < datetime.now(timezone.utc):
return None
# Update last used
api_key.last_used_at = datetime.utcnow()
api_key.last_used_at = datetime.now(timezone.utc)
api_key.usage_count += 1
self.db.commit()
return api_key
def get_user_keys(self, user_id: str) -> List[ApiKey]:
"""Get all API keys for a user"""
return self.db.query(ApiKey).filter(
ApiKey.user_id == user_id
).order_by(ApiKey.created_at.desc()).all()
return (
self.db.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.order_by(ApiKey.created_at.desc())
.all()
)
def revoke(self, key_id: str, user_id: str) -> bool:
"""Revoke an API key"""
api_key = self.db.query(ApiKey).filter(
and_(
ApiKey.id == key_id,
ApiKey.user_id == user_id,
api_key = (
self.db.query(ApiKey)
.filter(
and_(
ApiKey.id == key_id,
ApiKey.user_id == user_id,
)
)
).first()
.first()
)
if not api_key:
return False
api_key.is_active = False
self.db.commit()
return True

14
database/utils.py Normal file
View File

@@ -0,0 +1,14 @@
"""
Shared database utilities
"""
def convert_to_async_url(url: str) -> str:
"""Convert a sync database URL to its async driver equivalent."""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://", 1)
elif url.startswith("postgres://"):
return url.replace("postgres://", "postgresql+asyncpg://", 1)
elif url.startswith("sqlite:///"):
return url.replace("sqlite:///", "sqlite+aiosqlite:///", 1)
return url