feat: Add PostgreSQL database infrastructure
- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration
This commit is contained in:
17
database/__init__.py
Normal file
17
database/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Database module for the Document Translation API
|
||||
Provides PostgreSQL support with async SQLAlchemy
|
||||
"""
|
||||
from database.connection import get_db, engine, SessionLocal, init_db
|
||||
from database.models import User, Subscription, Translation, ApiKey
|
||||
|
||||
__all__ = [
|
||||
"get_db",
|
||||
"engine",
|
||||
"SessionLocal",
|
||||
"init_db",
|
||||
"User",
|
||||
"Subscription",
|
||||
"Translation",
|
||||
"ApiKey"
|
||||
]
|
||||
139
database/connection.py
Normal file
139
database/connection.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Database connection and session management
|
||||
Supports both PostgreSQL (production) and SQLite (development/testing)
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from typing import Generator, Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.pool import QueuePool, StaticPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database URL from environment
|
||||
# PostgreSQL: postgresql://user:password@host:port/database
|
||||
# SQLite: sqlite:///./data/translate.db
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||
|
||||
# Determine if we're using SQLite or PostgreSQL
|
||||
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True
|
||||
|
||||
# Create engine based on database type
|
||||
if DATABASE_URL and not _is_sqlite:
|
||||
# PostgreSQL configuration
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
pool_timeout=30,
|
||||
pool_recycle=1800, # Recycle connections after 30 minutes
|
||||
pool_pre_ping=True, # Check connection health before use
|
||||
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||
)
|
||||
logger.info("✅ Database configured with PostgreSQL")
|
||||
else:
|
||||
# SQLite configuration (for development/testing or when no DATABASE_URL)
|
||||
sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
|
||||
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
|
||||
|
||||
sqlite_url = f"sqlite:///./{sqlite_path}"
|
||||
engine = create_engine(
|
||||
sqlite_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||
)
|
||||
|
||||
# Enable foreign keys for SQLite
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
if not DATABASE_URL:
|
||||
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
|
||||
else:
|
||||
logger.info(f"✅ Database configured with SQLite: {sqlite_path}")
|
||||
|
||||
# Session factory
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Dependency for FastAPI to get database session.
|
||||
Usage: db: Session = Depends(get_db)
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Context manager for database session.
|
||||
Usage: with get_db_session() as db: ...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
get_sync_session = get_db_session
|
||||
|
||||
|
||||
def init_db():
|
||||
"""
|
||||
Initialize database tables.
|
||||
Call this on application startup.
|
||||
"""
|
||||
from database.models import Base
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("✅ Database tables initialized")
|
||||
|
||||
|
||||
def check_db_connection() -> bool:
|
||||
"""
|
||||
Check if database connection is healthy.
|
||||
Returns True if connection works, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute("SELECT 1")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database connection check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Connection pool stats (for monitoring)
|
||||
def get_pool_stats() -> dict:
|
||||
"""Get database connection pool statistics"""
|
||||
if hasattr(engine.pool, 'status'):
|
||||
return {
|
||||
"pool_size": engine.pool.size(),
|
||||
"checked_in": engine.pool.checkedin(),
|
||||
"checked_out": engine.pool.checkedout(),
|
||||
"overflow": engine.pool.overflow(),
|
||||
}
|
||||
return {"status": "pool stats not available"}
|
||||
259
database/models.py
Normal file
259
database/models.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
SQLAlchemy models for the Document Translation API
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, Boolean, DateTime, Text,
|
||||
ForeignKey, Enum, Index, JSON, BigInteger
|
||||
)
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
||||
import enum
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a new UUID string"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class PlanType(str, enum.Enum):
|
||||
FREE = "free"
|
||||
STARTER = "starter"
|
||||
PRO = "pro"
|
||||
BUSINESS = "business"
|
||||
ENTERPRISE = "enterprise"
|
||||
|
||||
|
||||
class SubscriptionStatus(str, enum.Enum):
|
||||
ACTIVE = "active"
|
||||
CANCELED = "canceled"
|
||||
PAST_DUE = "past_due"
|
||||
TRIALING = "trialing"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""User model for authentication and billing"""
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
|
||||
# Account status
|
||||
email_verified = Column(Boolean, default=False)
|
||||
is_active = Column(Boolean, default=True)
|
||||
avatar_url = Column(String(500), nullable=True)
|
||||
|
||||
# Subscription info
|
||||
plan = Column(Enum(PlanType), default=PlanType.FREE)
|
||||
subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)
|
||||
|
||||
# Stripe integration
|
||||
stripe_customer_id = Column(String(255), nullable=True, index=True)
|
||||
stripe_subscription_id = Column(String(255), nullable=True)
|
||||
|
||||
# Usage tracking (reset monthly)
|
||||
docs_translated_this_month = Column(Integer, default=0)
|
||||
pages_translated_this_month = Column(Integer, default=0)
|
||||
api_calls_this_month = Column(Integer, default=0)
|
||||
extra_credits = Column(Integer, default=0) # Purchased credits
|
||||
usage_reset_date = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
last_login_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
translations = relationship("Translation", back_populates="user", lazy="dynamic")
|
||||
api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_users_email_active', 'email', 'is_active'),
|
||||
Index('ix_users_stripe_customer', 'stripe_customer_id'),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert user to dictionary for API response"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"email": self.email,
|
||||
"name": self.name,
|
||||
"avatar_url": self.avatar_url,
|
||||
"plan": self.plan.value if self.plan else "free",
|
||||
"subscription_status": self.subscription_status.value if self.subscription_status else "active",
|
||||
"docs_translated_this_month": self.docs_translated_this_month,
|
||||
"pages_translated_this_month": self.pages_translated_this_month,
|
||||
"api_calls_this_month": self.api_calls_this_month,
|
||||
"extra_credits": self.extra_credits,
|
||||
"email_verified": self.email_verified,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Translation(Base):
|
||||
"""Translation history for analytics and billing"""
|
||||
__tablename__ = "translations"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# File info
|
||||
original_filename = Column(String(255), nullable=False)
|
||||
file_type = Column(String(10), nullable=False) # xlsx, docx, pptx
|
||||
file_size_bytes = Column(BigInteger, default=0)
|
||||
page_count = Column(Integer, default=0)
|
||||
|
||||
# Translation details
|
||||
source_language = Column(String(10), default="auto")
|
||||
target_language = Column(String(10), nullable=False)
|
||||
provider = Column(String(50), nullable=False) # google, deepl, ollama, etc.
|
||||
|
||||
# Processing info
|
||||
status = Column(String(20), default="pending") # pending, processing, completed, failed
|
||||
error_message = Column(Text, nullable=True)
|
||||
processing_time_ms = Column(Integer, nullable=True)
|
||||
|
||||
# Cost tracking (for paid providers)
|
||||
characters_translated = Column(Integer, default=0)
|
||||
estimated_cost_usd = Column(Float, default=0.0)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="translations")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_translations_user_date', 'user_id', 'created_at'),
|
||||
Index('ix_translations_status', 'status'),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"original_filename": self.original_filename,
|
||||
"file_type": self.file_type,
|
||||
"file_size_bytes": self.file_size_bytes,
|
||||
"page_count": self.page_count,
|
||||
"source_language": self.source_language,
|
||||
"target_language": self.target_language,
|
||||
"provider": self.provider,
|
||||
"status": self.status,
|
||||
"processing_time_ms": self.processing_time_ms,
|
||||
"characters_translated": self.characters_translated,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
}
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
"""API keys for programmatic access"""
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# Key info
|
||||
name = Column(String(100), nullable=False) # User-friendly name
|
||||
key_hash = Column(String(255), nullable=False) # SHA256 of the key
|
||||
key_prefix = Column(String(10), nullable=False) # First 8 chars for identification
|
||||
|
||||
# Permissions
|
||||
is_active = Column(Boolean, default=True)
|
||||
scopes = Column(JSON, default=list) # ["translate", "read", "write"]
|
||||
|
||||
# Usage tracking
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
usage_count = Column(Integer, default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="api_keys")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_api_keys_prefix', 'key_prefix'),
|
||||
Index('ix_api_keys_hash', 'key_hash'),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"key_prefix": self.key_prefix,
|
||||
"is_active": self.is_active,
|
||||
"scopes": self.scopes,
|
||||
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
"usage_count": self.usage_count,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
}
|
||||
|
||||
|
||||
class UsageLog(Base):
|
||||
"""Daily usage aggregation for billing and analytics"""
|
||||
__tablename__ = "usage_logs"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# Date (for daily aggregation)
|
||||
date = Column(DateTime, nullable=False, index=True)
|
||||
|
||||
# Aggregated counts
|
||||
documents_count = Column(Integer, default=0)
|
||||
pages_count = Column(Integer, default=0)
|
||||
characters_count = Column(BigInteger, default=0)
|
||||
api_calls_count = Column(Integer, default=0)
|
||||
|
||||
# By provider breakdown (JSON)
|
||||
provider_breakdown = Column(JSON, default=dict)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
|
||||
)
|
||||
|
||||
|
||||
class PaymentHistory(Base):
|
||||
"""Payment and invoice history"""
|
||||
__tablename__ = "payment_history"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# Stripe info
|
||||
stripe_payment_intent_id = Column(String(255), nullable=True)
|
||||
stripe_invoice_id = Column(String(255), nullable=True)
|
||||
|
||||
# Payment details
|
||||
amount_cents = Column(Integer, nullable=False)
|
||||
currency = Column(String(3), default="usd")
|
||||
payment_type = Column(String(50), nullable=False) # subscription, credits, one_time
|
||||
status = Column(String(20), nullable=False) # succeeded, failed, pending, refunded
|
||||
|
||||
# Description
|
||||
description = Column(String(255), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('ix_payment_history_user', 'user_id', 'created_at'),
|
||||
)
|
||||
341
database/repositories.py
Normal file
341
database/repositories.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""
|
||||
Repository layer for database operations
|
||||
Provides clean interface for CRUD operations
|
||||
"""
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func, or_
|
||||
|
||||
from database.models import (
|
||||
User, Translation, ApiKey, UsageLog, PaymentHistory,
|
||||
PlanType, SubscriptionStatus
|
||||
)
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for User database operations"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_id(self, user_id: str) -> Optional[User]:
|
||||
"""Get user by ID"""
|
||||
return self.db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
def get_by_email(self, email: str) -> Optional[User]:
|
||||
"""Get user by email (case-insensitive)"""
|
||||
return self.db.query(User).filter(
|
||||
func.lower(User.email) == email.lower()
|
||||
).first()
|
||||
|
||||
def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
|
||||
"""Get user by Stripe customer ID"""
|
||||
return self.db.query(User).filter(
|
||||
User.stripe_customer_id == stripe_customer_id
|
||||
).first()
|
||||
|
||||
def create(
|
||||
self,
|
||||
email: str,
|
||||
name: str,
|
||||
password_hash: str,
|
||||
plan: PlanType = PlanType.FREE
|
||||
) -> User:
|
||||
"""Create a new user"""
|
||||
user = User(
|
||||
email=email.lower(),
|
||||
name=name,
|
||||
password_hash=password_hash,
|
||||
plan=plan,
|
||||
subscription_status=SubscriptionStatus.ACTIVE,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def update(self, user_id: str, **kwargs) -> Optional[User]:
|
||||
"""Update user fields"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(user, key):
|
||||
setattr(user, key, value)
|
||||
|
||||
user.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def delete(self, user_id: str) -> bool:
|
||||
"""Delete a user"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
self.db.delete(user)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def increment_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
docs: int = 0,
|
||||
pages: int = 0,
|
||||
api_calls: int = 0
|
||||
) -> Optional[User]:
|
||||
"""Increment usage counters"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Check if usage needs to be reset (monthly)
|
||||
if user.usage_reset_date:
|
||||
now = datetime.utcnow()
|
||||
if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
user.api_calls_this_month = 0
|
||||
user.usage_reset_date = now
|
||||
|
||||
user.docs_translated_this_month += docs
|
||||
user.pages_translated_this_month += pages
|
||||
user.api_calls_this_month += api_calls
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
|
||||
"""Add extra credits to user"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user.extra_credits += credits
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def use_credits(self, user_id: str, credits: int) -> bool:
|
||||
"""Use credits from user balance"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user or user.extra_credits < credits:
|
||||
return False
|
||||
|
||||
user.extra_credits -= credits
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def get_all_users(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
plan: Optional[PlanType] = None
|
||||
) -> List[User]:
|
||||
"""Get all users with pagination"""
|
||||
query = self.db.query(User)
|
||||
if plan:
|
||||
query = query.filter(User.plan == plan)
|
||||
return query.offset(skip).limit(limit).all()
|
||||
|
||||
def count_users(self, plan: Optional[PlanType] = None) -> int:
|
||||
"""Count total users"""
|
||||
query = self.db.query(func.count(User.id))
|
||||
if plan:
|
||||
query = query.filter(User.plan == plan)
|
||||
return query.scalar()
|
||||
|
||||
|
||||
class TranslationRepository:
|
||||
"""Repository for Translation database operations"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
original_filename: str,
|
||||
file_type: str,
|
||||
target_language: str,
|
||||
provider: str,
|
||||
source_language: str = "auto",
|
||||
file_size_bytes: int = 0,
|
||||
page_count: int = 0,
|
||||
) -> Translation:
|
||||
"""Create a new translation record"""
|
||||
translation = Translation(
|
||||
user_id=user_id,
|
||||
original_filename=original_filename,
|
||||
file_type=file_type,
|
||||
file_size_bytes=file_size_bytes,
|
||||
page_count=page_count,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
provider=provider,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(translation)
|
||||
self.db.commit()
|
||||
self.db.refresh(translation)
|
||||
return translation
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
translation_id: str,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
processing_time_ms: Optional[int] = None,
|
||||
characters_translated: Optional[int] = None,
|
||||
) -> Optional[Translation]:
|
||||
"""Update translation status"""
|
||||
translation = self.db.query(Translation).filter(
|
||||
Translation.id == translation_id
|
||||
).first()
|
||||
|
||||
if not translation:
|
||||
return None
|
||||
|
||||
translation.status = status
|
||||
if error_message:
|
||||
translation.error_message = error_message
|
||||
if processing_time_ms:
|
||||
translation.processing_time_ms = processing_time_ms
|
||||
if characters_translated:
|
||||
translation.characters_translated = characters_translated
|
||||
if status == "completed":
|
||||
translation.completed_at = datetime.utcnow()
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(translation)
|
||||
return translation
|
||||
|
||||
def get_user_translations(
|
||||
self,
|
||||
user_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
status: Optional[str] = None,
|
||||
) -> List[Translation]:
|
||||
"""Get user's translation history"""
|
||||
query = self.db.query(Translation).filter(Translation.user_id == user_id)
|
||||
if status:
|
||||
query = query.filter(Translation.status == status)
|
||||
return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
|
||||
"""Get user's translation statistics"""
|
||||
since = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
result = self.db.query(
|
||||
func.count(Translation.id).label("total_translations"),
|
||||
func.sum(Translation.page_count).label("total_pages"),
|
||||
func.sum(Translation.characters_translated).label("total_characters"),
|
||||
).filter(
|
||||
and_(
|
||||
Translation.user_id == user_id,
|
||||
Translation.created_at >= since,
|
||||
Translation.status == "completed",
|
||||
)
|
||||
).first()
|
||||
|
||||
return {
|
||||
"total_translations": result.total_translations or 0,
|
||||
"total_pages": result.total_pages or 0,
|
||||
"total_characters": result.total_characters or 0,
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
|
||||
class ApiKeyRepository:
|
||||
"""Repository for API Key database operations"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def hash_key(key: str) -> str:
|
||||
"""Hash an API key"""
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
scopes: List[str] = None,
|
||||
expires_in_days: Optional[int] = None,
|
||||
) -> tuple[ApiKey, str]:
|
||||
"""Create a new API key. Returns (ApiKey, raw_key)"""
|
||||
# Generate a secure random key
|
||||
raw_key = f"tr_{secrets.token_urlsafe(32)}"
|
||||
key_hash = self.hash_key(raw_key)
|
||||
key_prefix = raw_key[:10]
|
||||
|
||||
expires_at = None
|
||||
if expires_in_days:
|
||||
expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
|
||||
|
||||
api_key = ApiKey(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
scopes=scopes or ["translate"],
|
||||
expires_at=expires_at,
|
||||
)
|
||||
self.db.add(api_key)
|
||||
self.db.commit()
|
||||
self.db.refresh(api_key)
|
||||
|
||||
return api_key, raw_key
|
||||
|
||||
def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
|
||||
"""Get API key by raw key value"""
|
||||
key_hash = self.hash_key(raw_key)
|
||||
api_key = self.db.query(ApiKey).filter(
|
||||
and_(
|
||||
ApiKey.key_hash == key_hash,
|
||||
ApiKey.is_active == True,
|
||||
)
|
||||
).first()
|
||||
|
||||
if api_key:
|
||||
# Check expiration
|
||||
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
# Update last used
|
||||
api_key.last_used_at = datetime.utcnow()
|
||||
api_key.usage_count += 1
|
||||
self.db.commit()
|
||||
|
||||
return api_key
|
||||
|
||||
def get_user_keys(self, user_id: str) -> List[ApiKey]:
|
||||
"""Get all API keys for a user"""
|
||||
return self.db.query(ApiKey).filter(
|
||||
ApiKey.user_id == user_id
|
||||
).order_by(ApiKey.created_at.desc()).all()
|
||||
|
||||
def revoke(self, key_id: str, user_id: str) -> bool:
|
||||
"""Revoke an API key"""
|
||||
api_key = self.db.query(ApiKey).filter(
|
||||
and_(
|
||||
ApiKey.id == key_id,
|
||||
ApiKey.user_id == user_id,
|
||||
)
|
||||
).first()
|
||||
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
api_key.is_active = False
|
||||
self.db.commit()
|
||||
return True
|
||||
Reference in New Issue
Block a user