feat: Add PostgreSQL database infrastructure

- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory
- Add database connection management with PostgreSQL/SQLite support
- Add repository layer for CRUD operations
- Add Alembic migration setup with initial migration
- Update auth_service to automatically use database when DATABASE_URL is set
- Update docker-compose.yml with PostgreSQL service and Redis (non-optional)
- Add database migration script (scripts/migrate_to_db.py)
- Update .env.example with database configuration
This commit is contained in:
2025-12-31 10:56:19 +01:00
parent c4d6cae735
commit 550f3516db
15 changed files with 1712 additions and 63 deletions

17
database/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
"""
Database module for the Document Translation API
Provides PostgreSQL support with async SQLAlchemy
"""
from database.connection import get_db, engine, SessionLocal, init_db
from database.models import User, Subscription, Translation, ApiKey
__all__ = [
"get_db",
"engine",
"SessionLocal",
"init_db",
"User",
"Subscription",
"Translation",
"ApiKey"
]

139
database/connection.py Normal file
View File

@@ -0,0 +1,139 @@
"""
Database connection and session management
Supports both PostgreSQL (production) and SQLite (development/testing)
"""
import os
import logging
from typing import Generator, Optional
from contextlib import contextmanager
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, StaticPool
logger = logging.getLogger(__name__)
# Database URL from environment
# PostgreSQL: postgresql://user:password@host:port/database
# SQLite: sqlite:///./data/translate.db
DATABASE_URL = os.getenv("DATABASE_URL", "")
# Determine if we're using SQLite or PostgreSQL
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True
# Create engine based on database type
if DATABASE_URL and not _is_sqlite:
# PostgreSQL configuration
engine = create_engine(
DATABASE_URL,
poolclass=QueuePool,
pool_size=5,
max_overflow=10,
pool_timeout=30,
pool_recycle=1800, # Recycle connections after 30 minutes
pool_pre_ping=True, # Check connection health before use
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
)
logger.info("✅ Database configured with PostgreSQL")
else:
# SQLite configuration (for development/testing or when no DATABASE_URL)
sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
sqlite_url = f"sqlite:///./{sqlite_path}"
engine = create_engine(
sqlite_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
)
# Enable foreign keys for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
if not DATABASE_URL:
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
else:
logger.info(f"✅ Database configured with SQLite: {sqlite_path}")
# Session factory
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
expire_on_commit=False,
)
def get_db() -> Generator[Session, None, None]:
"""
Dependency for FastAPI to get database session.
Usage: db: Session = Depends(get_db)
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
"""
Context manager for database session.
Usage: with get_db_session() as db: ...
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
# Alias for backward compatibility
get_sync_session = get_db_session
def init_db():
"""
Initialize database tables.
Call this on application startup.
"""
from database.models import Base
Base.metadata.create_all(bind=engine)
logger.info("✅ Database tables initialized")
def check_db_connection() -> bool:
"""
Check if database connection is healthy.
Returns True if connection works, False otherwise.
"""
try:
with engine.connect() as conn:
conn.execute("SELECT 1")
return True
except Exception as e:
logger.error(f"Database connection check failed: {e}")
return False
# Connection pool stats (for monitoring)
def get_pool_stats() -> dict:
"""Get database connection pool statistics"""
if hasattr(engine.pool, 'status'):
return {
"pool_size": engine.pool.size(),
"checked_in": engine.pool.checkedin(),
"checked_out": engine.pool.checkedout(),
"overflow": engine.pool.overflow(),
}
return {"status": "pool stats not available"}

259
database/models.py Normal file
View File

@@ -0,0 +1,259 @@
"""
SQLAlchemy models for the Document Translation API
"""
import os
import uuid
from datetime import datetime
from typing import Optional, List
from sqlalchemy import (
Column, String, Integer, Float, Boolean, DateTime, Text,
ForeignKey, Enum, Index, JSON, BigInteger
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
import enum
Base = declarative_base()
def generate_uuid():
"""Generate a new UUID string"""
return str(uuid.uuid4())
class PlanType(str, enum.Enum):
FREE = "free"
STARTER = "starter"
PRO = "pro"
BUSINESS = "business"
ENTERPRISE = "enterprise"
class SubscriptionStatus(str, enum.Enum):
ACTIVE = "active"
CANCELED = "canceled"
PAST_DUE = "past_due"
TRIALING = "trialing"
PAUSED = "paused"
class User(Base):
"""User model for authentication and billing"""
__tablename__ = "users"
id = Column(String(36), primary_key=True, default=generate_uuid)
email = Column(String(255), unique=True, nullable=False, index=True)
name = Column(String(255), nullable=False)
password_hash = Column(String(255), nullable=False)
# Account status
email_verified = Column(Boolean, default=False)
is_active = Column(Boolean, default=True)
avatar_url = Column(String(500), nullable=True)
# Subscription info
plan = Column(Enum(PlanType), default=PlanType.FREE)
subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)
# Stripe integration
stripe_customer_id = Column(String(255), nullable=True, index=True)
stripe_subscription_id = Column(String(255), nullable=True)
# Usage tracking (reset monthly)
docs_translated_this_month = Column(Integer, default=0)
pages_translated_this_month = Column(Integer, default=0)
api_calls_this_month = Column(Integer, default=0)
extra_credits = Column(Integer, default=0) # Purchased credits
usage_reset_date = Column(DateTime, default=datetime.utcnow)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_login_at = Column(DateTime, nullable=True)
# Relationships
translations = relationship("Translation", back_populates="user", lazy="dynamic")
api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")
# Indexes
__table_args__ = (
Index('ix_users_email_active', 'email', 'is_active'),
Index('ix_users_stripe_customer', 'stripe_customer_id'),
)
def to_dict(self) -> dict:
"""Convert user to dictionary for API response"""
return {
"id": self.id,
"email": self.email,
"name": self.name,
"avatar_url": self.avatar_url,
"plan": self.plan.value if self.plan else "free",
"subscription_status": self.subscription_status.value if self.subscription_status else "active",
"docs_translated_this_month": self.docs_translated_this_month,
"pages_translated_this_month": self.pages_translated_this_month,
"api_calls_this_month": self.api_calls_this_month,
"extra_credits": self.extra_credits,
"email_verified": self.email_verified,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
class Translation(Base):
"""Translation history for analytics and billing"""
__tablename__ = "translations"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
# File info
original_filename = Column(String(255), nullable=False)
file_type = Column(String(10), nullable=False) # xlsx, docx, pptx
file_size_bytes = Column(BigInteger, default=0)
page_count = Column(Integer, default=0)
# Translation details
source_language = Column(String(10), default="auto")
target_language = Column(String(10), nullable=False)
provider = Column(String(50), nullable=False) # google, deepl, ollama, etc.
# Processing info
status = Column(String(20), default="pending") # pending, processing, completed, failed
error_message = Column(Text, nullable=True)
processing_time_ms = Column(Integer, nullable=True)
# Cost tracking (for paid providers)
characters_translated = Column(Integer, default=0)
estimated_cost_usd = Column(Float, default=0.0)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
completed_at = Column(DateTime, nullable=True)
# Relationship
user = relationship("User", back_populates="translations")
# Indexes
__table_args__ = (
Index('ix_translations_user_date', 'user_id', 'created_at'),
Index('ix_translations_status', 'status'),
)
def to_dict(self) -> dict:
return {
"id": self.id,
"original_filename": self.original_filename,
"file_type": self.file_type,
"file_size_bytes": self.file_size_bytes,
"page_count": self.page_count,
"source_language": self.source_language,
"target_language": self.target_language,
"provider": self.provider,
"status": self.status,
"processing_time_ms": self.processing_time_ms,
"characters_translated": self.characters_translated,
"created_at": self.created_at.isoformat() if self.created_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
}
class ApiKey(Base):
"""API keys for programmatic access"""
__tablename__ = "api_keys"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
# Key info
name = Column(String(100), nullable=False) # User-friendly name
key_hash = Column(String(255), nullable=False) # SHA256 of the key
key_prefix = Column(String(10), nullable=False) # First 8 chars for identification
# Permissions
is_active = Column(Boolean, default=True)
scopes = Column(JSON, default=list) # ["translate", "read", "write"]
# Usage tracking
last_used_at = Column(DateTime, nullable=True)
usage_count = Column(Integer, default=0)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime, nullable=True)
# Relationship
user = relationship("User", back_populates="api_keys")
# Indexes
__table_args__ = (
Index('ix_api_keys_prefix', 'key_prefix'),
Index('ix_api_keys_hash', 'key_hash'),
)
def to_dict(self) -> dict:
return {
"id": self.id,
"name": self.name,
"key_prefix": self.key_prefix,
"is_active": self.is_active,
"scopes": self.scopes,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"usage_count": self.usage_count,
"created_at": self.created_at.isoformat() if self.created_at else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
}
class UsageLog(Base):
"""Daily usage aggregation for billing and analytics"""
__tablename__ = "usage_logs"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
# Date (for daily aggregation)
date = Column(DateTime, nullable=False, index=True)
# Aggregated counts
documents_count = Column(Integer, default=0)
pages_count = Column(Integer, default=0)
characters_count = Column(BigInteger, default=0)
api_calls_count = Column(Integer, default=0)
# By provider breakdown (JSON)
provider_breakdown = Column(JSON, default=dict)
# Indexes
__table_args__ = (
Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
)
class PaymentHistory(Base):
"""Payment and invoice history"""
__tablename__ = "payment_history"
id = Column(String(36), primary_key=True, default=generate_uuid)
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
# Stripe info
stripe_payment_intent_id = Column(String(255), nullable=True)
stripe_invoice_id = Column(String(255), nullable=True)
# Payment details
amount_cents = Column(Integer, nullable=False)
currency = Column(String(3), default="usd")
payment_type = Column(String(50), nullable=False) # subscription, credits, one_time
status = Column(String(20), nullable=False) # succeeded, failed, pending, refunded
# Description
description = Column(String(255), nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
# Indexes
__table_args__ = (
Index('ix_payment_history_user', 'user_id', 'created_at'),
)

341
database/repositories.py Normal file
View File

@@ -0,0 +1,341 @@
"""
Repository layer for database operations
Provides clean interface for CRUD operations
"""
import hashlib
import secrets
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, func, or_
from database.models import (
User, Translation, ApiKey, UsageLog, PaymentHistory,
PlanType, SubscriptionStatus
)
class UserRepository:
"""Repository for User database operations"""
def __init__(self, db: Session):
self.db = db
def get_by_id(self, user_id: str) -> Optional[User]:
"""Get user by ID"""
return self.db.query(User).filter(User.id == user_id).first()
def get_by_email(self, email: str) -> Optional[User]:
"""Get user by email (case-insensitive)"""
return self.db.query(User).filter(
func.lower(User.email) == email.lower()
).first()
def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
"""Get user by Stripe customer ID"""
return self.db.query(User).filter(
User.stripe_customer_id == stripe_customer_id
).first()
def create(
self,
email: str,
name: str,
password_hash: str,
plan: PlanType = PlanType.FREE
) -> User:
"""Create a new user"""
user = User(
email=email.lower(),
name=name,
password_hash=password_hash,
plan=plan,
subscription_status=SubscriptionStatus.ACTIVE,
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
return user
def update(self, user_id: str, **kwargs) -> Optional[User]:
"""Update user fields"""
user = self.get_by_id(user_id)
if not user:
return None
for key, value in kwargs.items():
if hasattr(user, key):
setattr(user, key, value)
user.updated_at = datetime.utcnow()
self.db.commit()
self.db.refresh(user)
return user
def delete(self, user_id: str) -> bool:
"""Delete a user"""
user = self.get_by_id(user_id)
if not user:
return False
self.db.delete(user)
self.db.commit()
return True
def increment_usage(
self,
user_id: str,
docs: int = 0,
pages: int = 0,
api_calls: int = 0
) -> Optional[User]:
"""Increment usage counters"""
user = self.get_by_id(user_id)
if not user:
return None
# Check if usage needs to be reset (monthly)
if user.usage_reset_date:
now = datetime.utcnow()
if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
user.docs_translated_this_month = 0
user.pages_translated_this_month = 0
user.api_calls_this_month = 0
user.usage_reset_date = now
user.docs_translated_this_month += docs
user.pages_translated_this_month += pages
user.api_calls_this_month += api_calls
self.db.commit()
self.db.refresh(user)
return user
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
"""Add extra credits to user"""
user = self.get_by_id(user_id)
if not user:
return None
user.extra_credits += credits
self.db.commit()
self.db.refresh(user)
return user
def use_credits(self, user_id: str, credits: int) -> bool:
"""Use credits from user balance"""
user = self.get_by_id(user_id)
if not user or user.extra_credits < credits:
return False
user.extra_credits -= credits
self.db.commit()
return True
def get_all_users(
self,
skip: int = 0,
limit: int = 100,
plan: Optional[PlanType] = None
) -> List[User]:
"""Get all users with pagination"""
query = self.db.query(User)
if plan:
query = query.filter(User.plan == plan)
return query.offset(skip).limit(limit).all()
def count_users(self, plan: Optional[PlanType] = None) -> int:
"""Count total users"""
query = self.db.query(func.count(User.id))
if plan:
query = query.filter(User.plan == plan)
return query.scalar()
class TranslationRepository:
"""Repository for Translation database operations"""
def __init__(self, db: Session):
self.db = db
def create(
self,
user_id: str,
original_filename: str,
file_type: str,
target_language: str,
provider: str,
source_language: str = "auto",
file_size_bytes: int = 0,
page_count: int = 0,
) -> Translation:
"""Create a new translation record"""
translation = Translation(
user_id=user_id,
original_filename=original_filename,
file_type=file_type,
file_size_bytes=file_size_bytes,
page_count=page_count,
source_language=source_language,
target_language=target_language,
provider=provider,
status="pending",
)
self.db.add(translation)
self.db.commit()
self.db.refresh(translation)
return translation
def update_status(
self,
translation_id: str,
status: str,
error_message: Optional[str] = None,
processing_time_ms: Optional[int] = None,
characters_translated: Optional[int] = None,
) -> Optional[Translation]:
"""Update translation status"""
translation = self.db.query(Translation).filter(
Translation.id == translation_id
).first()
if not translation:
return None
translation.status = status
if error_message:
translation.error_message = error_message
if processing_time_ms:
translation.processing_time_ms = processing_time_ms
if characters_translated:
translation.characters_translated = characters_translated
if status == "completed":
translation.completed_at = datetime.utcnow()
self.db.commit()
self.db.refresh(translation)
return translation
def get_user_translations(
self,
user_id: str,
skip: int = 0,
limit: int = 50,
status: Optional[str] = None,
) -> List[Translation]:
"""Get user's translation history"""
query = self.db.query(Translation).filter(Translation.user_id == user_id)
if status:
query = query.filter(Translation.status == status)
return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()
def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
"""Get user's translation statistics"""
since = datetime.utcnow() - timedelta(days=days)
result = self.db.query(
func.count(Translation.id).label("total_translations"),
func.sum(Translation.page_count).label("total_pages"),
func.sum(Translation.characters_translated).label("total_characters"),
).filter(
and_(
Translation.user_id == user_id,
Translation.created_at >= since,
Translation.status == "completed",
)
).first()
return {
"total_translations": result.total_translations or 0,
"total_pages": result.total_pages or 0,
"total_characters": result.total_characters or 0,
"period_days": days,
}
class ApiKeyRepository:
"""Repository for API Key database operations"""
def __init__(self, db: Session):
self.db = db
@staticmethod
def hash_key(key: str) -> str:
"""Hash an API key"""
return hashlib.sha256(key.encode()).hexdigest()
def create(
self,
user_id: str,
name: str,
scopes: List[str] = None,
expires_in_days: Optional[int] = None,
) -> tuple[ApiKey, str]:
"""Create a new API key. Returns (ApiKey, raw_key)"""
# Generate a secure random key
raw_key = f"tr_{secrets.token_urlsafe(32)}"
key_hash = self.hash_key(raw_key)
key_prefix = raw_key[:10]
expires_at = None
if expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
api_key = ApiKey(
user_id=user_id,
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
scopes=scopes or ["translate"],
expires_at=expires_at,
)
self.db.add(api_key)
self.db.commit()
self.db.refresh(api_key)
return api_key, raw_key
def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
"""Get API key by raw key value"""
key_hash = self.hash_key(raw_key)
api_key = self.db.query(ApiKey).filter(
and_(
ApiKey.key_hash == key_hash,
ApiKey.is_active == True,
)
).first()
if api_key:
# Check expiration
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
return None
# Update last used
api_key.last_used_at = datetime.utcnow()
api_key.usage_count += 1
self.db.commit()
return api_key
def get_user_keys(self, user_id: str) -> List[ApiKey]:
"""Get all API keys for a user"""
return self.db.query(ApiKey).filter(
ApiKey.user_id == user_id
).order_by(ApiKey.created_at.desc()).all()
def revoke(self, key_id: str, user_id: str) -> bool:
"""Revoke an API key"""
api_key = self.db.query(ApiKey).filter(
and_(
ApiKey.id == key_id,
ApiKey.user_id == user_id,
)
).first()
if not api_key:
return False
api_key.is_active = False
self.db.commit()
return True