feat: revue de code, doc CODE_REVIEW, forfaits 2026, traduction LLM, providers avec modèle
Made-with: Cursor
This commit is contained in:
@@ -1,17 +1,46 @@
|
||||
"""
|
||||
Database module for the Document Translation API
|
||||
Provides PostgreSQL support with async SQLAlchemy
|
||||
Provides PostgreSQL/SQLite support with async SQLAlchemy 2.0
|
||||
"""
|
||||
from database.connection import get_db, engine, SessionLocal, init_db
|
||||
from database.models import User, Subscription, Translation, ApiKey
|
||||
|
||||
from database.connection import (
|
||||
get_db,
|
||||
get_db_session,
|
||||
get_async_session,
|
||||
engine,
|
||||
AsyncSessionLocal,
|
||||
init_db,
|
||||
get_engine,
|
||||
)
|
||||
from database.models import (
|
||||
Base,
|
||||
User,
|
||||
Translation,
|
||||
ApiKey,
|
||||
UsageLog,
|
||||
PaymentHistory,
|
||||
PlanType,
|
||||
SubscriptionStatus,
|
||||
Glossary,
|
||||
GlossaryTerm,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_db",
|
||||
"engine",
|
||||
"SessionLocal",
|
||||
"get_db_session",
|
||||
"get_async_session",
|
||||
"engine",
|
||||
"AsyncSessionLocal",
|
||||
"init_db",
|
||||
"get_engine",
|
||||
"Base",
|
||||
"User",
|
||||
"Subscription",
|
||||
"Translation",
|
||||
"ApiKey"
|
||||
"ApiKey",
|
||||
"UsageLog",
|
||||
"PaymentHistory",
|
||||
"PlanType",
|
||||
"SubscriptionStatus",
|
||||
"Glossary",
|
||||
"GlossaryTerm",
|
||||
]
|
||||
|
||||
@@ -1,135 +1,174 @@
|
||||
"""
|
||||
Database connection and session management
|
||||
Supports both PostgreSQL (production) and SQLite (development/testing)
|
||||
Async SQLAlchemy 2.0 implementation
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Generator, Optional
|
||||
from typing import AsyncGenerator, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import text, create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
AsyncEngine,
|
||||
)
|
||||
from sqlalchemy.pool import QueuePool, StaticPool
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.pool import QueuePool, StaticPool
|
||||
from database.utils import convert_to_async_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database URL from environment
|
||||
# PostgreSQL: postgresql://user:password@host:port/database
|
||||
# SQLite: sqlite:///./data/translate.db
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||
|
||||
# Determine if we're using SQLite or PostgreSQL
|
||||
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True
|
||||
_is_postgres = DATABASE_URL.startswith("postgres") if DATABASE_URL else False
|
||||
|
||||
# Create engine based on database type
|
||||
if DATABASE_URL and not _is_sqlite:
|
||||
# PostgreSQL configuration
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
|
||||
if DATABASE_URL and _is_postgres:
|
||||
async_database_url = convert_to_async_url(DATABASE_URL)
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
async_database_url,
|
||||
poolclass=QueuePool,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
pool_timeout=30,
|
||||
pool_recycle=1800, # Recycle connections after 30 minutes
|
||||
pool_pre_ping=True, # Check connection health before use
|
||||
pool_recycle=1800,
|
||||
pool_pre_ping=True,
|
||||
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||
)
|
||||
logger.info("✅ Database configured with PostgreSQL")
|
||||
logger.info("✅ Database configured with PostgreSQL (async)")
|
||||
else:
|
||||
# SQLite configuration (for development/testing or when no DATABASE_URL)
|
||||
sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
|
||||
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
|
||||
|
||||
sqlite_url = f"sqlite:///./{sqlite_path}"
|
||||
engine = create_engine(
|
||||
sqlite_url,
|
||||
os.makedirs(
|
||||
os.path.dirname(sqlite_path) if os.path.dirname(sqlite_path) else ".",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
async_database_url = f"sqlite+aiosqlite:///./{sqlite_path}"
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
async_database_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||
)
|
||||
|
||||
# Enable foreign keys for SQLite
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
if not DATABASE_URL:
|
||||
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
|
||||
else:
|
||||
logger.info(f"✅ Database configured with SQLite: {sqlite_path}")
|
||||
|
||||
# Session factory
|
||||
SessionLocal = sessionmaker(
|
||||
if not DATABASE_URL:
|
||||
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development (async)")
|
||||
else:
|
||||
logger.info(f"✅ Database configured with SQLite: {sqlite_path} (async)")
|
||||
|
||||
# Sync engine and session for repositories (auth, translation log).
|
||||
# Kept for backward compatibility until all callers use async; see story 1-1.
|
||||
# Prefer get_db() / AsyncSessionLocal for new code.
|
||||
if DATABASE_URL and _is_postgres:
|
||||
sync_engine = create_engine(
|
||||
DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=True,
|
||||
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||
)
|
||||
else:
|
||||
_sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
|
||||
_sync_sqlite_url = f"sqlite:///./{_sqlite_path}"
|
||||
sync_engine = create_engine(
|
||||
_sync_sqlite_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
SyncSessionLocal = sessionmaker(bind=sync_engine, autocommit=False, autoflush=False, expire_on_commit=False)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_sync_session():
|
||||
"""Sync session context manager for use with sync repositories (auth_service, translation log)."""
|
||||
session = SyncSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Dependency for FastAPI to get database session.
|
||||
Usage: db: Session = Depends(get_db)
|
||||
Async dependency for FastAPI to get database session.
|
||||
Usage: db: AsyncSession = Depends(get_db)
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Context manager for database session.
|
||||
Usage: with get_db_session() as db: ...
|
||||
Async context manager for database session.
|
||||
Usage: async with get_db_session() as db: ...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
get_sync_session = get_db_session
|
||||
get_async_session = get_db_session
|
||||
|
||||
|
||||
def init_db():
|
||||
async def init_db():
|
||||
"""
|
||||
Initialize database tables.
|
||||
Initialize database tables asynchronously.
|
||||
Call this on application startup.
|
||||
"""
|
||||
from database.models import Base
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("✅ Database tables initialized")
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("✅ Database tables initialized (async)")
|
||||
|
||||
|
||||
def check_db_connection() -> bool:
|
||||
async def check_db_connection() -> bool:
|
||||
"""
|
||||
Check if database connection is healthy.
|
||||
Returns True if connection works, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute("SELECT 1")
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database connection check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Connection pool stats (for monitoring)
|
||||
def get_pool_stats() -> dict:
|
||||
"""Get database connection pool statistics"""
|
||||
if hasattr(engine.pool, 'status'):
|
||||
if hasattr(engine.pool, "status"):
|
||||
return {
|
||||
"pool_size": engine.pool.size(),
|
||||
"checked_in": engine.pool.checkedin(),
|
||||
@@ -137,3 +176,8 @@ def get_pool_stats() -> dict:
|
||||
"overflow": engine.pool.overflow(),
|
||||
}
|
||||
return {"status": "pool stats not available"}
|
||||
|
||||
|
||||
def get_engine() -> AsyncEngine:
|
||||
"""Get the async engine instance"""
|
||||
return engine
|
||||
|
||||
@@ -1,19 +1,34 @@
|
||||
"""
|
||||
SQLAlchemy models for the Document Translation API
|
||||
"""
|
||||
import os
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timezone
|
||||
import warnings
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, Boolean, DateTime, Text,
|
||||
ForeignKey, Enum, Index, JSON, BigInteger
|
||||
Column,
|
||||
String,
|
||||
Integer,
|
||||
Float,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Text,
|
||||
ForeignKey,
|
||||
Enum,
|
||||
Index,
|
||||
JSON,
|
||||
BigInteger,
|
||||
CheckConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
||||
import enum
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -22,6 +37,11 @@ def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def generate_uuid_value():
|
||||
"""Generate a new UUID value for PostgreSQL UUID column"""
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
class PlanType(str, enum.Enum):
|
||||
FREE = "free"
|
||||
STARTER = "starter"
|
||||
@@ -40,57 +60,78 @@ class SubscriptionStatus(str, enum.Enum):
|
||||
|
||||
class User(Base):
|
||||
"""User model for authentication and billing"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
|
||||
# Account status
|
||||
hashed_password = Column(String(255), nullable=False)
|
||||
|
||||
tier = Column(String(10), default="free", nullable=False)
|
||||
daily_translation_count = Column(Integer, default=0, nullable=False)
|
||||
|
||||
email_verified = Column(Boolean, default=False)
|
||||
is_active = Column(Boolean, default=True)
|
||||
avatar_url = Column(String(500), nullable=True)
|
||||
|
||||
# Subscription info
|
||||
|
||||
plan = Column(Enum(PlanType), default=PlanType.FREE)
|
||||
subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)
|
||||
|
||||
# Stripe integration
|
||||
subscription_status = Column(
|
||||
Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE
|
||||
)
|
||||
|
||||
stripe_customer_id = Column(String(255), nullable=True, index=True)
|
||||
stripe_subscription_id = Column(String(255), nullable=True)
|
||||
|
||||
# Usage tracking (reset monthly)
|
||||
|
||||
docs_translated_this_month = Column(Integer, default=0)
|
||||
pages_translated_this_month = Column(Integer, default=0)
|
||||
api_calls_this_month = Column(Integer, default=0)
|
||||
extra_credits = Column(Integer, default=0) # Purchased credits
|
||||
usage_reset_date = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
extra_credits = Column(Integer, default=0)
|
||||
usage_reset_date = Column(DateTime, default=_utcnow)
|
||||
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
|
||||
last_login_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
translations = relationship("Translation", back_populates="user", lazy="dynamic")
|
||||
api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")
|
||||
|
||||
# Indexes
|
||||
|
||||
translations = relationship("Translation", back_populates="user", lazy="select")
|
||||
api_keys = relationship("ApiKey", back_populates="user", lazy="select")
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_users_email_active', 'email', 'is_active'),
|
||||
Index('ix_users_stripe_customer', 'stripe_customer_id'),
|
||||
CheckConstraint("tier IN ('free', 'pro')", name="ck_users_tier"),
|
||||
Index("ix_users_email_active", "email", "is_active"),
|
||||
Index("ix_users_stripe_customer", "stripe_customer_id"),
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def password_hash(self) -> str:
|
||||
warnings.warn(
|
||||
"password_hash is deprecated, use hashed_password instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.hashed_password
|
||||
|
||||
@password_hash.setter
|
||||
def password_hash(self, value: str) -> None:
|
||||
warnings.warn(
|
||||
"password_hash is deprecated, use hashed_password instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.hashed_password = value
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert user to dictionary for API response"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"email": self.email,
|
||||
"name": self.name,
|
||||
"avatar_url": self.avatar_url,
|
||||
"tier": self.tier,
|
||||
"plan": self.plan.value if self.plan else "free",
|
||||
"subscription_status": self.subscription_status.value if self.subscription_status else "active",
|
||||
"subscription_status": self.subscription_status.value
|
||||
if self.subscription_status
|
||||
else "active",
|
||||
"daily_translation_count": self.daily_translation_count,
|
||||
"docs_translated_this_month": self.docs_translated_this_month,
|
||||
"pages_translated_this_month": self.pages_translated_this_month,
|
||||
"api_calls_this_month": self.api_calls_this_month,
|
||||
@@ -102,44 +143,49 @@ class User(Base):
|
||||
|
||||
class Translation(Base):
|
||||
"""Translation history for analytics and billing"""
|
||||
|
||||
__tablename__ = "translations"
|
||||
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# File info
|
||||
original_filename = Column(String(255), nullable=False)
|
||||
file_type = Column(String(10), nullable=False) # xlsx, docx, pptx
|
||||
file_type = Column(String(20), nullable=False) # xlsx, docx, pptx
|
||||
file_size_bytes = Column(BigInteger, default=0)
|
||||
page_count = Column(Integer, default=0)
|
||||
|
||||
|
||||
# Translation details
|
||||
source_language = Column(String(10), default="auto")
|
||||
target_language = Column(String(10), nullable=False)
|
||||
provider = Column(String(50), nullable=False) # google, deepl, ollama, etc.
|
||||
|
||||
|
||||
# Processing info
|
||||
status = Column(String(20), default="pending") # pending, processing, completed, failed
|
||||
status = Column(
|
||||
String(20), default="pending"
|
||||
) # pending, processing, completed, failed
|
||||
error_message = Column(Text, nullable=True)
|
||||
processing_time_ms = Column(Integer, nullable=True)
|
||||
|
||||
|
||||
# Cost tracking (for paid providers)
|
||||
characters_translated = Column(Integer, default=0)
|
||||
estimated_cost_usd = Column(Float, default=0.0)
|
||||
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="translations")
|
||||
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_translations_user_date', 'user_id', 'created_at'),
|
||||
Index('ix_translations_status', 'status'),
|
||||
Index("ix_translations_user_date", "user_id", "created_at"),
|
||||
Index("ix_translations_status", "status"),
|
||||
)
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
@@ -154,43 +200,49 @@ class Translation(Base):
|
||||
"processing_time_ms": self.processing_time_ms,
|
||||
"characters_translated": self.characters_translated,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
"""API keys for programmatic access"""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Key info
|
||||
name = Column(String(100), nullable=False) # User-friendly name
|
||||
key_hash = Column(String(255), nullable=False) # SHA256 of the key
|
||||
key_prefix = Column(String(10), nullable=False) # First 8 chars for identification
|
||||
|
||||
|
||||
# Permissions
|
||||
is_active = Column(Boolean, default=True)
|
||||
scopes = Column(JSON, default=list) # ["translate", "read", "write"]
|
||||
|
||||
|
||||
# Usage tracking
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
usage_count = Column(Integer, default=0)
|
||||
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
revoked_at = Column(DateTime, nullable=True) # Set when is_active=False
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="api_keys")
|
||||
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_api_keys_prefix', 'key_prefix'),
|
||||
Index('ix_api_keys_hash', 'key_hash'),
|
||||
Index("ix_api_keys_prefix", "key_prefix"),
|
||||
Index("ix_api_keys_hash", "key_hash"),
|
||||
)
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
@@ -198,7 +250,9 @@ class ApiKey(Base):
|
||||
"key_prefix": self.key_prefix,
|
||||
"is_active": self.is_active,
|
||||
"scopes": self.scopes,
|
||||
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"usage_count": self.usage_count,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
@@ -207,53 +261,149 @@ class ApiKey(Base):
|
||||
|
||||
class UsageLog(Base):
|
||||
"""Daily usage aggregation for billing and analytics"""
|
||||
|
||||
__tablename__ = "usage_logs"
|
||||
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Date (for daily aggregation)
|
||||
date = Column(DateTime, nullable=False, index=True)
|
||||
|
||||
|
||||
# Aggregated counts
|
||||
documents_count = Column(Integer, default=0)
|
||||
pages_count = Column(Integer, default=0)
|
||||
characters_count = Column(BigInteger, default=0)
|
||||
api_calls_count = Column(Integer, default=0)
|
||||
|
||||
|
||||
# By provider breakdown (JSON)
|
||||
provider_breakdown = Column(JSON, default=dict)
|
||||
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
|
||||
)
|
||||
__table_args__ = (Index("ix_usage_logs_user_date", "user_id", "date", unique=True),)
|
||||
|
||||
|
||||
class PaymentHistory(Base):
|
||||
"""Payment and invoice history"""
|
||||
|
||||
__tablename__ = "payment_history"
|
||||
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Stripe info
|
||||
stripe_payment_intent_id = Column(String(255), nullable=True)
|
||||
stripe_invoice_id = Column(String(255), nullable=True)
|
||||
|
||||
|
||||
# Payment details
|
||||
amount_cents = Column(Integer, nullable=False)
|
||||
currency = Column(String(3), default="usd")
|
||||
payment_type = Column(String(50), nullable=False) # subscription, credits, one_time
|
||||
status = Column(String(20), nullable=False) # succeeded, failed, pending, refunded
|
||||
|
||||
|
||||
# Description
|
||||
description = Column(String(255), nullable=True)
|
||||
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_payment_history_user', 'user_id', 'created_at'),
|
||||
__table_args__ = (Index("ix_payment_history_user", "user_id", "created_at"),)
|
||||
|
||||
|
||||
class Glossary(Base):
|
||||
"""User's glossary containing source->target term pairs.
|
||||
Story 3.9: Glossaires - Endpoint CRUD
|
||||
"""
|
||||
|
||||
__tablename__ = "glossaries"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
name = Column(String(255), nullable=False)
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
|
||||
|
||||
# Relationship
|
||||
terms = relationship(
|
||||
"GlossaryTerm", back_populates="glossary", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (Index("ix_glossaries_user_id", "user_id"),)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"user_id": self.user_id,
|
||||
"name": self.name,
|
||||
"terms": [term.to_dict() for term in self.terms] if self.terms else [],
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class GlossaryTerm(Base):
|
||||
"""Single term pair in a glossary.
|
||||
Story 3.9: Glossaires - Endpoint CRUD
|
||||
"""
|
||||
|
||||
__tablename__ = "glossary_terms"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
glossary_id = Column(
|
||||
String(36), ForeignKey("glossaries.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
source = Column(String(500), nullable=False)
|
||||
target = Column(String(500), nullable=False)
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
|
||||
# Relationship
|
||||
glossary = relationship("Glossary", back_populates="terms")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (Index("ix_glossary_terms_glossary_id", "glossary_id"),)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"source": self.source,
|
||||
"target": self.target,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class CustomPrompt(Base):
|
||||
"""User's custom prompts for LLM translation context.
|
||||
Story 3.11: Custom Prompts - Endpoint CRUD
|
||||
"""
|
||||
|
||||
__tablename__ = "custom_prompts"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(
|
||||
String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
name = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
created_at = Column(DateTime, default=_utcnow)
|
||||
updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (Index("ix_custom_prompts_user_id", "user_id"),)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"user_id": self.user_id,
|
||||
"name": self.name,
|
||||
"content": self.content,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
@@ -2,54 +2,64 @@
|
||||
Repository layer for database operations
|
||||
Provides clean interface for CRUD operations
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func, or_
|
||||
|
||||
from database.models import (
|
||||
User, Translation, ApiKey, UsageLog, PaymentHistory,
|
||||
PlanType, SubscriptionStatus
|
||||
User,
|
||||
Translation,
|
||||
ApiKey,
|
||||
UsageLog,
|
||||
PaymentHistory,
|
||||
PlanType,
|
||||
SubscriptionStatus,
|
||||
)
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for User database operations"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
|
||||
def get_by_id(self, user_id: str) -> Optional[User]:
|
||||
"""Get user by ID"""
|
||||
return self.db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
|
||||
def get_by_email(self, email: str) -> Optional[User]:
|
||||
"""Get user by email (case-insensitive)"""
|
||||
return self.db.query(User).filter(
|
||||
func.lower(User.email) == email.lower()
|
||||
).first()
|
||||
|
||||
return (
|
||||
self.db.query(User).filter(func.lower(User.email) == email.lower()).first()
|
||||
)
|
||||
|
||||
def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
|
||||
"""Get user by Stripe customer ID"""
|
||||
return self.db.query(User).filter(
|
||||
User.stripe_customer_id == stripe_customer_id
|
||||
).first()
|
||||
|
||||
return (
|
||||
self.db.query(User)
|
||||
.filter(User.stripe_customer_id == stripe_customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
email: str,
|
||||
name: str,
|
||||
password_hash: str,
|
||||
plan: PlanType = PlanType.FREE
|
||||
self,
|
||||
email: str,
|
||||
name: str,
|
||||
hashed_password: str,
|
||||
tier: str = "free",
|
||||
) -> User:
|
||||
"""Create a new user"""
|
||||
"""Create a new user. Uses hashed_password and tier (story 1-1 refactor)."""
|
||||
plan = PlanType.PRO if tier == "pro" else PlanType.FREE
|
||||
user = User(
|
||||
email=email.lower(),
|
||||
name=name,
|
||||
password_hash=password_hash,
|
||||
hashed_password=hashed_password,
|
||||
tier=tier,
|
||||
plan=plan,
|
||||
subscription_status=SubscriptionStatus.ACTIVE,
|
||||
)
|
||||
@@ -57,94 +67,90 @@ class UserRepository:
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def update(self, user_id: str, **kwargs) -> Optional[User]:
|
||||
"""Update user fields"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(user, key):
|
||||
setattr(user, key, value)
|
||||
|
||||
user.updated_at = datetime.utcnow()
|
||||
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def delete(self, user_id: str) -> bool:
|
||||
"""Delete a user"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
|
||||
self.db.delete(user)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def increment_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
docs: int = 0,
|
||||
pages: int = 0,
|
||||
api_calls: int = 0
|
||||
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
|
||||
) -> Optional[User]:
|
||||
"""Increment usage counters"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
|
||||
# Check if usage needs to be reset (monthly)
|
||||
if user.usage_reset_date:
|
||||
now = datetime.utcnow()
|
||||
if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
|
||||
now = datetime.now(timezone.utc)
|
||||
if (
|
||||
now.month != user.usage_reset_date.month
|
||||
or now.year != user.usage_reset_date.year
|
||||
):
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
user.api_calls_this_month = 0
|
||||
user.usage_reset_date = now
|
||||
|
||||
|
||||
user.docs_translated_this_month += docs
|
||||
user.pages_translated_this_month += pages
|
||||
user.api_calls_this_month += api_calls
|
||||
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
|
||||
"""Add extra credits to user"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
|
||||
user.extra_credits += credits
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def use_credits(self, user_id: str, credits: int) -> bool:
|
||||
"""Use credits from user balance"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user or user.extra_credits < credits:
|
||||
return False
|
||||
|
||||
|
||||
user.extra_credits -= credits
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def get_all_users(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
plan: Optional[PlanType] = None
|
||||
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
|
||||
) -> List[User]:
|
||||
"""Get all users with pagination"""
|
||||
query = self.db.query(User)
|
||||
if plan:
|
||||
query = query.filter(User.plan == plan)
|
||||
return query.offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
def count_users(self, plan: Optional[PlanType] = None) -> int:
|
||||
"""Count total users"""
|
||||
query = self.db.query(func.count(User.id))
|
||||
@@ -155,10 +161,10 @@ class UserRepository:
|
||||
|
||||
class TranslationRepository:
|
||||
"""Repository for Translation database operations"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -186,7 +192,36 @@ class TranslationRepository:
|
||||
self.db.commit()
|
||||
self.db.refresh(translation)
|
||||
return translation
|
||||
|
||||
|
||||
def create_completed(
|
||||
self,
|
||||
user_id: str,
|
||||
original_filename: str,
|
||||
file_type: str,
|
||||
target_language: str,
|
||||
provider: str,
|
||||
source_language: str = "auto",
|
||||
file_size_bytes: int = 0,
|
||||
page_count: int = 0,
|
||||
) -> Translation:
|
||||
"""Create a translation record directly in completed status (Story 1.8 - billing log)."""
|
||||
translation = Translation(
|
||||
user_id=user_id,
|
||||
original_filename=original_filename,
|
||||
file_type=file_type,
|
||||
file_size_bytes=file_size_bytes,
|
||||
page_count=page_count,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
provider=provider,
|
||||
status="completed",
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.db.add(translation)
|
||||
self.db.commit()
|
||||
self.db.refresh(translation)
|
||||
return translation
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
translation_id: str,
|
||||
@@ -196,13 +231,13 @@ class TranslationRepository:
|
||||
characters_translated: Optional[int] = None,
|
||||
) -> Optional[Translation]:
|
||||
"""Update translation status"""
|
||||
translation = self.db.query(Translation).filter(
|
||||
Translation.id == translation_id
|
||||
).first()
|
||||
|
||||
translation = (
|
||||
self.db.query(Translation).filter(Translation.id == translation_id).first()
|
||||
)
|
||||
|
||||
if not translation:
|
||||
return None
|
||||
|
||||
|
||||
translation.status = status
|
||||
if error_message:
|
||||
translation.error_message = error_message
|
||||
@@ -211,12 +246,12 @@ class TranslationRepository:
|
||||
if characters_translated:
|
||||
translation.characters_translated = characters_translated
|
||||
if status == "completed":
|
||||
translation.completed_at = datetime.utcnow()
|
||||
|
||||
translation.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(translation)
|
||||
return translation
|
||||
|
||||
|
||||
def get_user_translations(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -228,24 +263,33 @@ class TranslationRepository:
|
||||
query = self.db.query(Translation).filter(Translation.user_id == user_id)
|
||||
if status:
|
||||
query = query.filter(Translation.status == status)
|
||||
return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
return (
|
||||
query.order_by(Translation.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
|
||||
"""Get user's translation statistics"""
|
||||
since = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
result = self.db.query(
|
||||
func.count(Translation.id).label("total_translations"),
|
||||
func.sum(Translation.page_count).label("total_pages"),
|
||||
func.sum(Translation.characters_translated).label("total_characters"),
|
||||
).filter(
|
||||
and_(
|
||||
Translation.user_id == user_id,
|
||||
Translation.created_at >= since,
|
||||
Translation.status == "completed",
|
||||
since = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
result = (
|
||||
self.db.query(
|
||||
func.count(Translation.id).label("total_translations"),
|
||||
func.sum(Translation.page_count).label("total_pages"),
|
||||
func.sum(Translation.characters_translated).label("total_characters"),
|
||||
)
|
||||
).first()
|
||||
|
||||
.filter(
|
||||
and_(
|
||||
Translation.user_id == user_id,
|
||||
Translation.created_at >= since,
|
||||
Translation.status == "completed",
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return {
|
||||
"total_translations": result.total_translations or 0,
|
||||
"total_pages": result.total_pages or 0,
|
||||
@@ -256,15 +300,15 @@ class TranslationRepository:
|
||||
|
||||
class ApiKeyRepository:
|
||||
"""Repository for API Key database operations"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
|
||||
@staticmethod
|
||||
def hash_key(key: str) -> str:
|
||||
"""Hash an API key"""
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -277,11 +321,11 @@ class ApiKeyRepository:
|
||||
raw_key = f"tr_{secrets.token_urlsafe(32)}"
|
||||
key_hash = self.hash_key(raw_key)
|
||||
key_prefix = raw_key[:10]
|
||||
|
||||
|
||||
expires_at = None
|
||||
if expires_in_days:
|
||||
expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
|
||||
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
|
||||
|
||||
api_key = ApiKey(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
@@ -293,49 +337,60 @@ class ApiKeyRepository:
|
||||
self.db.add(api_key)
|
||||
self.db.commit()
|
||||
self.db.refresh(api_key)
|
||||
|
||||
|
||||
return api_key, raw_key
|
||||
|
||||
|
||||
def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
|
||||
"""Get API key by raw key value"""
|
||||
key_hash = self.hash_key(raw_key)
|
||||
api_key = self.db.query(ApiKey).filter(
|
||||
and_(
|
||||
ApiKey.key_hash == key_hash,
|
||||
ApiKey.is_active == True,
|
||||
api_key = (
|
||||
self.db.query(ApiKey)
|
||||
.filter(
|
||||
and_(
|
||||
ApiKey.key_hash == key_hash,
|
||||
ApiKey.is_active == True,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
.first()
|
||||
)
|
||||
|
||||
if api_key:
|
||||
# Check expiration
|
||||
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
|
||||
if api_key.expires_at and api_key.expires_at < datetime.now(timezone.utc):
|
||||
return None
|
||||
|
||||
|
||||
# Update last used
|
||||
api_key.last_used_at = datetime.utcnow()
|
||||
api_key.last_used_at = datetime.now(timezone.utc)
|
||||
api_key.usage_count += 1
|
||||
self.db.commit()
|
||||
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
def get_user_keys(self, user_id: str) -> List[ApiKey]:
|
||||
"""Get all API keys for a user"""
|
||||
return self.db.query(ApiKey).filter(
|
||||
ApiKey.user_id == user_id
|
||||
).order_by(ApiKey.created_at.desc()).all()
|
||||
|
||||
return (
|
||||
self.db.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.order_by(ApiKey.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def revoke(self, key_id: str, user_id: str) -> bool:
|
||||
"""Revoke an API key"""
|
||||
api_key = self.db.query(ApiKey).filter(
|
||||
and_(
|
||||
ApiKey.id == key_id,
|
||||
ApiKey.user_id == user_id,
|
||||
api_key = (
|
||||
self.db.query(ApiKey)
|
||||
.filter(
|
||||
and_(
|
||||
ApiKey.id == key_id,
|
||||
ApiKey.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
|
||||
api_key.is_active = False
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
14
database/utils.py
Normal file
14
database/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Shared database utilities
|
||||
"""
|
||||
|
||||
|
||||
def convert_to_async_url(url: str) -> str:
|
||||
"""Convert a sync database URL to its async driver equivalent."""
|
||||
if url.startswith("postgresql://"):
|
||||
return url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
elif url.startswith("postgres://"):
|
||||
return url.replace("postgres://", "postgresql+asyncpg://", 1)
|
||||
elif url.startswith("sqlite:///"):
|
||||
return url.replace("sqlite:///", "sqlite+aiosqlite:///", 1)
|
||||
return url
|
||||
Reference in New Issue
Block a user