feat: Add PostgreSQL database infrastructure

- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory
- Add database connection management with PostgreSQL/SQLite support
- Add repository layer for CRUD operations
- Add Alembic migration setup with initial migration
- Update auth_service to automatically use database when DATABASE_URL is set
- Update docker-compose.yml with PostgreSQL service and Redis (non-optional)
- Add database migration script (scripts/migrate_to_db.py)
- Update .env.example with database configuration
This commit is contained in:
Sepehr 2025-12-31 10:56:19 +01:00
parent c4d6cae735
commit 550f3516db
15 changed files with 1712 additions and 63 deletions

View File

@ -72,9 +72,16 @@ MAX_REQUEST_SIZE_MB=100
# Request timeout in seconds
REQUEST_TIMEOUT_SECONDS=300
# ============== Database (Production) ==============
# ============== Database ==============
# PostgreSQL connection string (recommended for production)
# DATABASE_URL=postgresql://user:password@localhost:5432/translate_db
# Format: postgresql://user:password@host:port/database
# DATABASE_URL=postgresql://translate_user:secure_password@localhost:5432/translate_db
# SQLite path (used when DATABASE_URL is not set - for development)
SQLITE_PATH=data/translate.db
# Enable SQL query logging (for debugging)
DATABASE_ECHO=false
# Redis for sessions and caching (recommended for production)
# REDIS_URL=redis://localhost:6379/0

78
alembic.ini Normal file
View File

@ -0,0 +1,78 @@
# Alembic configuration file
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during the 'revision' command,
# regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without a source .py file
# to be detected as revisions in the versions/ directory
# sourceless = false
# version location specification; This defaults to alembic/versions
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator
# version_path_separator = :
# the output encoding used when revision files are written
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -q
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

86
alembic/env.py Normal file
View File

@ -0,0 +1,86 @@
"""
Alembic environment configuration

Resolves the database URL from the environment (DATABASE_URL, falling back
to a local SQLite file) and wires the project's SQLAlchemy metadata into
Alembic for autogenerate support.
"""
import os
import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context

# Add parent directory to path for imports (so the `database` package
# is importable when Alembic runs from the repo root)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# this is the Alembic Config object
config = context.config

# Interpret the config file for Python logging
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Import models for autogenerate support
from database.models import Base
target_metadata = Base.metadata

# Get database URL from environment (.env is loaded first so local
# development picks up the same settings as the app)
from dotenv import load_dotenv
load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL", "")
if not DATABASE_URL:
    # No explicit URL: fall back to the same SQLite file the app uses
    SQLITE_PATH = os.getenv("SQLITE_PATH", "data/translate.db")
    DATABASE_URL = f"sqlite:///./{SQLITE_PATH}"

# Override sqlalchemy.url with environment variable (the value in
# alembic.ini is only a placeholder)
config.set_main_option("sqlalchemy.url", DATABASE_URL)
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the context with just a URL rather than an Engine, so no
    DBAPI needs to be installed; Alembic emits the SQL as a script instead
    of executing it against a live connection.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the [alembic] config section and binds a live
    connection to the migration context before running migrations.
    """
    engine = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as conn:
        context.configure(
            connection=conn,
            target_metadata=target_metadata,
            compare_type=True,  # also detect column *type* changes in autogenerate
        )

        with context.begin_transaction():
            context.run_migrations()
# Alembic imports this module directly; pick the migration mode based on
# how the CLI was invoked (`alembic upgrade --sql` => offline).
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

27
alembic/script.py.mako Normal file
View File

@ -0,0 +1,27 @@
"""
Alembic migration script template
${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,126 @@
"""Initial database schema

Revision ID: 001_initial
Revises:
Create Date: 2024-12-31
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = '001_initial'
down_revision: Union[str, None] = None  # first migration in the chain
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the initial schema: users, translations, api_keys,
    usage_logs and payment_history, plus their secondary indexes.

    Tables are created parents-first (users before its dependents) so
    that the foreign keys resolve.
    """
    # Create users table
    # NOTE(review): plan / subscription_status are created as plain strings
    # here, while database/models.py declares them as sa.Enum columns --
    # confirm the two stay compatible (e.g. on PostgreSQL, which creates a
    # native enum type for sa.Enum).
    op.create_table(
        'users',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('email', sa.String(255), unique=True, nullable=False),
        sa.Column('name', sa.String(255), nullable=False),
        sa.Column('password_hash', sa.String(255), nullable=False),
        sa.Column('email_verified', sa.Boolean(), default=False),
        sa.Column('is_active', sa.Boolean(), default=True),
        sa.Column('avatar_url', sa.String(500), nullable=True),
        sa.Column('plan', sa.String(20), default='free'),
        sa.Column('subscription_status', sa.String(20), default='active'),
        sa.Column('stripe_customer_id', sa.String(255), nullable=True),
        sa.Column('stripe_subscription_id', sa.String(255), nullable=True),
        sa.Column('docs_translated_this_month', sa.Integer(), default=0),
        sa.Column('pages_translated_this_month', sa.Integer(), default=0),
        sa.Column('api_calls_this_month', sa.Integer(), default=0),
        sa.Column('extra_credits', sa.Integer(), default=0),
        sa.Column('usage_reset_date', sa.DateTime()),
        sa.Column('created_at', sa.DateTime()),
        sa.Column('updated_at', sa.DateTime()),
        sa.Column('last_login_at', sa.DateTime(), nullable=True),
    )
    op.create_index('ix_users_email', 'users', ['email'])
    op.create_index('ix_users_email_active', 'users', ['email', 'is_active'])
    op.create_index('ix_users_stripe_customer', 'users', ['stripe_customer_id'])

    # Create translations table (ON DELETE CASCADE: removing a user removes
    # their translation history)
    op.create_table(
        'translations',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('original_filename', sa.String(255), nullable=False),
        sa.Column('file_type', sa.String(10), nullable=False),
        sa.Column('file_size_bytes', sa.BigInteger(), default=0),
        sa.Column('page_count', sa.Integer(), default=0),
        sa.Column('source_language', sa.String(10), default='auto'),
        sa.Column('target_language', sa.String(10), nullable=False),
        sa.Column('provider', sa.String(50), nullable=False),
        sa.Column('status', sa.String(20), default='pending'),
        sa.Column('error_message', sa.Text(), nullable=True),
        sa.Column('processing_time_ms', sa.Integer(), nullable=True),
        sa.Column('characters_translated', sa.Integer(), default=0),
        sa.Column('estimated_cost_usd', sa.Float(), default=0.0),
        sa.Column('created_at', sa.DateTime()),
        sa.Column('completed_at', sa.DateTime(), nullable=True),
    )
    op.create_index('ix_translations_user_date', 'translations', ['user_id', 'created_at'])
    op.create_index('ix_translations_status', 'translations', ['status'])

    # Create api_keys table
    op.create_table(
        'api_keys',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('key_hash', sa.String(255), nullable=False),
        sa.Column('key_prefix', sa.String(10), nullable=False),
        sa.Column('is_active', sa.Boolean(), default=True),
        sa.Column('scopes', sa.JSON(), default=list),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('usage_count', sa.Integer(), default=0),
        sa.Column('created_at', sa.DateTime()),
        sa.Column('expires_at', sa.DateTime(), nullable=True),
    )
    op.create_index('ix_api_keys_prefix', 'api_keys', ['key_prefix'])
    op.create_index('ix_api_keys_hash', 'api_keys', ['key_hash'])

    # Create usage_logs table (one row per user per day -- enforced by the
    # unique index below)
    op.create_table(
        'usage_logs',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('date', sa.DateTime(), nullable=False),
        sa.Column('documents_count', sa.Integer(), default=0),
        sa.Column('pages_count', sa.Integer(), default=0),
        sa.Column('characters_count', sa.BigInteger(), default=0),
        sa.Column('api_calls_count', sa.Integer(), default=0),
        sa.Column('provider_breakdown', sa.JSON(), default=dict),
    )
    op.create_index('ix_usage_logs_user_date', 'usage_logs', ['user_id', 'date'], unique=True)

    # Create payment_history table
    op.create_table(
        'payment_history',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('stripe_payment_intent_id', sa.String(255), nullable=True),
        sa.Column('stripe_invoice_id', sa.String(255), nullable=True),
        sa.Column('amount_cents', sa.Integer(), nullable=False),
        sa.Column('currency', sa.String(3), default='usd'),
        sa.Column('payment_type', sa.String(50), nullable=False),
        sa.Column('status', sa.String(20), nullable=False),
        sa.Column('description', sa.String(255), nullable=True),
        sa.Column('created_at', sa.DateTime()),
    )
    op.create_index('ix_payment_history_user', 'payment_history', ['user_id', 'created_at'])
def downgrade() -> None:
    """Drop every table created by upgrade(), dependents before parents."""
    # Reverse of creation order so foreign-key constraints never dangle.
    for table_name in (
        'payment_history',
        'usage_logs',
        'api_keys',
        'translations',
        'users',
    ):
        op.drop_table(table_name)

17
database/__init__.py Normal file
View File

@ -0,0 +1,17 @@
"""
Database module for the Document Translation API

Provides PostgreSQL (production) and SQLite (development) support via
synchronous SQLAlchemy sessions.
"""
from database.connection import get_db, engine, SessionLocal, init_db
# BUGFIX: database.models defines no `Subscription` class, so importing it
# raised ImportError the moment this package was imported. Re-export the
# models that actually exist instead.
from database.models import (
    User,
    Translation,
    ApiKey,
    UsageLog,
    PaymentHistory,
)

__all__ = [
    "get_db",
    "engine",
    "SessionLocal",
    "init_db",
    "User",
    "Translation",
    "ApiKey",
    "UsageLog",
    "PaymentHistory",
]

139
database/connection.py Normal file
View File

@ -0,0 +1,139 @@
"""
Database connection and session management
Supports both PostgreSQL (production) and SQLite (development/testing)
"""
import os
import logging
from typing import Generator, Optional
from contextlib import contextmanager

from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, StaticPool

logger = logging.getLogger(__name__)

# Database URL from environment
# PostgreSQL: postgresql://user:password@host:port/database
# SQLite: sqlite:///./data/translate.db
DATABASE_URL = os.getenv("DATABASE_URL", "")

# Determine if we're using SQLite or PostgreSQL (no DATABASE_URL => SQLite)
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True

# Create engine based on database type
if DATABASE_URL and not _is_sqlite:
    # PostgreSQL configuration: pooled connections with health checks
    engine = create_engine(
        DATABASE_URL,
        poolclass=QueuePool,
        pool_size=5,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=1800,  # Recycle connections after 30 minutes
        pool_pre_ping=True,  # Check connection health before use
        echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
    )
    logger.info("✅ Database configured with PostgreSQL")
else:
    # SQLite configuration (for development/testing or when no DATABASE_URL)
    sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
    # BUGFIX: os.path.dirname() returns "" for a bare filename such as
    # SQLITE_PATH=translate.db, and os.makedirs("") raises FileNotFoundError.
    # Only create the directory when there is one.
    sqlite_dir = os.path.dirname(sqlite_path)
    if sqlite_dir:
        os.makedirs(sqlite_dir, exist_ok=True)
    sqlite_url = f"sqlite:///./{sqlite_path}"
    engine = create_engine(
        sqlite_url,
        # Allow the session to be used from FastAPI's worker threads
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
        echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
    )

    # Enable foreign keys for SQLite (off by default per connection)
    @event.listens_for(engine, "connect")
    def set_sqlite_pragma(dbapi_connection, connection_record):
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    if not DATABASE_URL:
        logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
    else:
        logger.info(f"✅ Database configured with SQLite: {sqlite_path}")

# Session factory; expire_on_commit=False keeps ORM objects usable after
# the session commits (e.g. for building API responses).
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
    expire_on_commit=False,
)
def get_db() -> Generator[Session, None, None]:
    """
    FastAPI dependency that yields a database session and always closes it.

    Usage: db: Session = Depends(get_db)
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Close even if the request handler raised.
        session.close()
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """
    Context manager yielding a session that commits on success and rolls
    back on any exception (which is then re-raised).

    Usage: with get_db_session() as db: ...
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()  # inside try: a failed commit is also rolled back
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

# Alias for backward compatibility
get_sync_session = get_db_session
def init_db():
    """
    Initialize database tables.
    Call this on application startup.

    Creates any missing tables from the ORM metadata; existing tables are
    left untouched (no migration -- schema changes go through Alembic).
    """
    # Imported here to avoid a circular import at module load time.
    from database.models import Base
    Base.metadata.create_all(bind=engine)
    logger.info("✅ Database tables initialized")
def check_db_connection() -> bool:
    """
    Check if database connection is healthy.

    Returns:
        True if a trivial probe query succeeds, False otherwise.
    """
    # BUGFIX: SQLAlchemy 2.0 removed Connection.execute() on raw strings;
    # the probe must be wrapped in text() (works on 1.4 and 2.x alike).
    from sqlalchemy import text
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except Exception as e:
        # Broad catch is deliberate: any driver/network error means unhealthy.
        logger.error(f"Database connection check failed: {e}")
        return False
# Connection pool stats (for monitoring)
def get_pool_stats() -> dict:
    """Return connection-pool statistics for monitoring endpoints."""
    pool = engine.pool
    # StaticPool (SQLite) has no status()/size() API; report that instead.
    if not hasattr(pool, 'status'):
        return {"status": "pool stats not available"}
    return {
        "pool_size": pool.size(),
        "checked_in": pool.checkedin(),
        "checked_out": pool.checkedout(),
        "overflow": pool.overflow(),
    }

259
database/models.py Normal file
View File

@ -0,0 +1,259 @@
"""
SQLAlchemy models for the Document Translation API
"""
import os
import uuid
from datetime import datetime
from typing import Optional, List
from sqlalchemy import (
    Column, String, Integer, Float, Boolean, DateTime, Text,
    ForeignKey, Enum, Index, JSON, BigInteger
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
import enum

# Shared declarative base for every ORM model in this module
Base = declarative_base()
def generate_uuid():
    """Return a fresh random UUID4 in its canonical 36-character string form."""
    return f"{uuid.uuid4()}"
class PlanType(str, enum.Enum):
    """Subscription plan tiers available to a user account."""
    FREE = "free"
    STARTER = "starter"
    PRO = "pro"
    BUSINESS = "business"
    ENTERPRISE = "enterprise"
class SubscriptionStatus(str, enum.Enum):
    """Lifecycle states of a user's subscription."""
    ACTIVE = "active"
    CANCELED = "canceled"
    PAST_DUE = "past_due"
    TRIALING = "trialing"
    PAUSED = "paused"
class User(Base):
    """User model for authentication and billing"""
    __tablename__ = "users"
    id = Column(String(36), primary_key=True, default=generate_uuid)
    email = Column(String(255), unique=True, nullable=False, index=True)
    name = Column(String(255), nullable=False)
    # Hashed credential only -- the hashing scheme lives in auth_service;
    # TODO confirm it fits in 255 chars for all configured algorithms.
    password_hash = Column(String(255), nullable=False)
    # Account status
    email_verified = Column(Boolean, default=False)
    is_active = Column(Boolean, default=True)
    avatar_url = Column(String(500), nullable=True)
    # Subscription info
    # NOTE(review): declared as Enum here but the initial Alembic migration
    # creates plain String(20) columns -- confirm the two stay compatible.
    plan = Column(Enum(PlanType), default=PlanType.FREE)
    subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)
    # Stripe integration
    stripe_customer_id = Column(String(255), nullable=True, index=True)
    stripe_subscription_id = Column(String(255), nullable=True)
    # Usage tracking (reset monthly; see UserRepository.increment_usage)
    docs_translated_this_month = Column(Integer, default=0)
    pages_translated_this_month = Column(Integer, default=0)
    api_calls_this_month = Column(Integer, default=0)
    extra_credits = Column(Integer, default=0)  # Purchased credits
    usage_reset_date = Column(DateTime, default=datetime.utcnow)
    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_login_at = Column(DateTime, nullable=True)
    # Relationships (lazy="dynamic": accessing them yields a query, not a list)
    translations = relationship("Translation", back_populates="user", lazy="dynamic")
    api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")
    # Indexes
    __table_args__ = (
        Index('ix_users_email_active', 'email', 'is_active'),
        Index('ix_users_stripe_customer', 'stripe_customer_id'),
    )

    def to_dict(self) -> dict:
        """Convert user to dictionary for API response.

        Deliberately omits password_hash and Stripe identifiers.
        """
        return {
            "id": self.id,
            "email": self.email,
            "name": self.name,
            "avatar_url": self.avatar_url,
            "plan": self.plan.value if self.plan else "free",
            "subscription_status": self.subscription_status.value if self.subscription_status else "active",
            "docs_translated_this_month": self.docs_translated_this_month,
            "pages_translated_this_month": self.pages_translated_this_month,
            "api_calls_this_month": self.api_calls_this_month,
            "extra_credits": self.extra_credits,
            "email_verified": self.email_verified,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }
class Translation(Base):
    """Translation history for analytics and billing"""
    __tablename__ = "translations"
    id = Column(String(36), primary_key=True, default=generate_uuid)
    # Rows are removed with their owner (ondelete CASCADE)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    # File info
    original_filename = Column(String(255), nullable=False)
    file_type = Column(String(10), nullable=False)  # xlsx, docx, pptx
    file_size_bytes = Column(BigInteger, default=0)
    page_count = Column(Integer, default=0)
    # Translation details
    source_language = Column(String(10), default="auto")
    target_language = Column(String(10), nullable=False)
    provider = Column(String(50), nullable=False)  # google, deepl, ollama, etc.
    # Processing info
    status = Column(String(20), default="pending")  # pending, processing, completed, failed
    error_message = Column(Text, nullable=True)
    processing_time_ms = Column(Integer, nullable=True)
    # Cost tracking (for paid providers)
    characters_translated = Column(Integer, default=0)
    estimated_cost_usd = Column(Float, default=0.0)
    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime, nullable=True)
    # Relationship
    user = relationship("User", back_populates="translations")
    # Indexes
    __table_args__ = (
        Index('ix_translations_user_date', 'user_id', 'created_at'),
        Index('ix_translations_status', 'status'),
    )

    def to_dict(self) -> dict:
        """Serialize for API responses; omits error_message and cost fields."""
        return {
            "id": self.id,
            "original_filename": self.original_filename,
            "file_type": self.file_type,
            "file_size_bytes": self.file_size_bytes,
            "page_count": self.page_count,
            "source_language": self.source_language,
            "target_language": self.target_language,
            "provider": self.provider,
            "status": self.status,
            "processing_time_ms": self.processing_time_ms,
            "characters_translated": self.characters_translated,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
class ApiKey(Base):
    """API keys for programmatic access.

    Only a SHA-256 hash of the key is stored; the raw key is shown to the
    user once at creation time (see ApiKeyRepository.create).
    """
    __tablename__ = "api_keys"
    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    # Key info
    name = Column(String(100), nullable=False)  # User-friendly name
    key_hash = Column(String(255), nullable=False)  # SHA256 of the key
    key_prefix = Column(String(10), nullable=False)  # First chars for identification
    # Permissions
    is_active = Column(Boolean, default=True)
    scopes = Column(JSON, default=list)  # ["translate", "read", "write"]
    # Usage tracking
    last_used_at = Column(DateTime, nullable=True)
    usage_count = Column(Integer, default=0)
    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    expires_at = Column(DateTime, nullable=True)  # None = never expires
    # Relationship
    user = relationship("User", back_populates="api_keys")
    # Indexes
    __table_args__ = (
        Index('ix_api_keys_prefix', 'key_prefix'),
        Index('ix_api_keys_hash', 'key_hash'),
    )

    def to_dict(self) -> dict:
        """Serialize for API responses; never exposes key_hash."""
        return {
            "id": self.id,
            "name": self.name,
            "key_prefix": self.key_prefix,
            "is_active": self.is_active,
            "scopes": self.scopes,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
            "usage_count": self.usage_count,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
        }
class UsageLog(Base):
    """Daily usage aggregation for billing and analytics.

    One row per user per day, enforced by the unique (user_id, date) index.
    """
    __tablename__ = "usage_logs"
    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    # Date (for daily aggregation)
    date = Column(DateTime, nullable=False, index=True)
    # Aggregated counts
    documents_count = Column(Integer, default=0)
    pages_count = Column(Integer, default=0)
    # BigInteger: character totals can exceed a 32-bit int
    characters_count = Column(BigInteger, default=0)
    api_calls_count = Column(Integer, default=0)
    # By provider breakdown (JSON)
    provider_breakdown = Column(JSON, default=dict)
    # Indexes
    __table_args__ = (
        Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
    )
class PaymentHistory(Base):
    """Payment and invoice history (one row per Stripe payment event)."""
    __tablename__ = "payment_history"
    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    # Stripe info
    stripe_payment_intent_id = Column(String(255), nullable=True)
    stripe_invoice_id = Column(String(255), nullable=True)
    # Payment details (amount stored in cents to avoid float rounding)
    amount_cents = Column(Integer, nullable=False)
    currency = Column(String(3), default="usd")  # ISO 4217 code
    payment_type = Column(String(50), nullable=False)  # subscription, credits, one_time
    status = Column(String(20), nullable=False)  # succeeded, failed, pending, refunded
    # Description
    description = Column(String(255), nullable=True)
    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    # Indexes
    __table_args__ = (
        Index('ix_payment_history_user', 'user_id', 'created_at'),
    )

341
database/repositories.py Normal file
View File

@ -0,0 +1,341 @@
"""
Repository layer for database operations
Provides clean interface for CRUD operations
"""
import hashlib
import secrets
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, func, or_
from database.models import (
User, Translation, ApiKey, UsageLog, PaymentHistory,
PlanType, SubscriptionStatus
)
class UserRepository:
    """Repository for User database operations.

    Every mutating method commits immediately; callers do not need to
    manage the transaction, but mixing these calls with an external
    transaction will commit it too.
    """

    def __init__(self, db: Session):
        # Session is owned by the caller; the repository never closes it.
        self.db = db

    def get_by_id(self, user_id: str) -> Optional[User]:
        """Get user by ID"""
        return self.db.query(User).filter(User.id == user_id).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """Get user by email (case-insensitive)"""
        return self.db.query(User).filter(
            func.lower(User.email) == email.lower()
        ).first()

    def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
        """Get user by Stripe customer ID"""
        return self.db.query(User).filter(
            User.stripe_customer_id == stripe_customer_id
        ).first()

    def create(
        self,
        email: str,
        name: str,
        password_hash: str,
        plan: PlanType = PlanType.FREE
    ) -> User:
        """Create a new user.

        The email is lower-cased before storage so get_by_email lookups match.
        """
        user = User(
            email=email.lower(),
            name=name,
            password_hash=password_hash,
            plan=plan,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def update(self, user_id: str, **kwargs) -> Optional[User]:
        """Update user fields.

        Silently ignores kwargs that are not User attributes; returns None
        when the user does not exist.
        """
        user = self.get_by_id(user_id)
        if not user:
            return None
        for key, value in kwargs.items():
            if hasattr(user, key):
                setattr(user, key, value)
        user.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(user)
        return user

    def delete(self, user_id: str) -> bool:
        """Delete a user (FK cascade removes their dependent rows)."""
        user = self.get_by_id(user_id)
        if not user:
            return False
        self.db.delete(user)
        self.db.commit()
        return True

    def increment_usage(
        self,
        user_id: str,
        docs: int = 0,
        pages: int = 0,
        api_calls: int = 0
    ) -> Optional[User]:
        """Increment usage counters.

        Lazily resets the monthly counters the first time usage is recorded
        in a new calendar month (compared against usage_reset_date).
        """
        user = self.get_by_id(user_id)
        if not user:
            return None
        # Check if usage needs to be reset (monthly)
        if user.usage_reset_date:
            now = datetime.utcnow()
            if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
                user.docs_translated_this_month = 0
                user.pages_translated_this_month = 0
                user.api_calls_this_month = 0
                user.usage_reset_date = now
        user.docs_translated_this_month += docs
        user.pages_translated_this_month += pages
        user.api_calls_this_month += api_calls
        self.db.commit()
        self.db.refresh(user)
        return user

    def add_credits(self, user_id: str, credits: int) -> Optional[User]:
        """Add extra credits to user"""
        user = self.get_by_id(user_id)
        if not user:
            return None
        user.extra_credits += credits
        self.db.commit()
        self.db.refresh(user)
        return user

    def use_credits(self, user_id: str, credits: int) -> bool:
        """Use credits from user balance.

        Returns False (without deducting) when the user is missing or the
        balance is insufficient.
        """
        user = self.get_by_id(user_id)
        if not user or user.extra_credits < credits:
            return False
        user.extra_credits -= credits
        self.db.commit()
        return True

    def get_all_users(
        self,
        skip: int = 0,
        limit: int = 100,
        plan: Optional[PlanType] = None
    ) -> List[User]:
        """Get all users with pagination, optionally filtered by plan."""
        query = self.db.query(User)
        if plan:
            query = query.filter(User.plan == plan)
        return query.offset(skip).limit(limit).all()

    def count_users(self, plan: Optional[PlanType] = None) -> int:
        """Count total users, optionally filtered by plan."""
        query = self.db.query(func.count(User.id))
        if plan:
            query = query.filter(User.plan == plan)
        return query.scalar()
class TranslationRepository:
    """Repository for Translation database operations.

    Mutating methods commit immediately on the injected session.
    """

    def __init__(self, db: Session):
        # Session is owned by the caller; the repository never closes it.
        self.db = db

    def create(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create a new translation record in 'pending' status."""
        translation = Translation(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="pending",
        )
        self.db.add(translation)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def update_status(
        self,
        translation_id: str,
        status: str,
        error_message: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        characters_translated: Optional[int] = None,
    ) -> Optional[Translation]:
        """Update translation status and optional result metrics.

        Returns the updated Translation, or None when the id is unknown.
        Stamps completed_at when the new status is "completed".
        """
        translation = self.db.query(Translation).filter(
            Translation.id == translation_id
        ).first()
        if not translation:
            return None
        translation.status = status
        if error_message:
            translation.error_message = error_message
        # BUGFIX: compare against None instead of truthiness -- a legitimate
        # value of 0 (instant processing, zero characters) was silently
        # dropped by the previous `if processing_time_ms:` checks.
        if processing_time_ms is not None:
            translation.processing_time_ms = processing_time_ms
        if characters_translated is not None:
            translation.characters_translated = characters_translated
        if status == "completed":
            translation.completed_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def get_user_translations(
        self,
        user_id: str,
        skip: int = 0,
        limit: int = 50,
        status: Optional[str] = None,
    ) -> List[Translation]:
        """Get user's translation history, newest first."""
        query = self.db.query(Translation).filter(Translation.user_id == user_id)
        if status:
            query = query.filter(Translation.status == status)
        return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()

    def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Get user's completed-translation statistics for the last `days` days."""
        since = datetime.utcnow() - timedelta(days=days)
        result = self.db.query(
            func.count(Translation.id).label("total_translations"),
            func.sum(Translation.page_count).label("total_pages"),
            func.sum(Translation.characters_translated).label("total_characters"),
        ).filter(
            and_(
                Translation.user_id == user_id,
                Translation.created_at >= since,
                Translation.status == "completed",
            )
        ).first()
        # SUM() over zero rows yields NULL -> coalesce to 0 for the API.
        return {
            "total_translations": result.total_translations or 0,
            "total_pages": result.total_pages or 0,
            "total_characters": result.total_characters or 0,
            "period_days": days,
        }
class ApiKeyRepository:
    """Repository for API Key database operations.

    Raw keys are never stored; only their SHA-256 hex digest is persisted,
    plus a short prefix so users can recognize keys in listings.
    """

    def __init__(self, db: Session):
        # Session is owned by the caller; the repository never closes it.
        self.db = db

    @staticmethod
    def hash_key(key: str) -> str:
        """Hash an API key (SHA-256 hex digest of the UTF-8 bytes)."""
        return hashlib.sha256(key.encode()).hexdigest()

    def create(
        self,
        user_id: str,
        name: str,
        scopes: Optional[List[str]] = None,
        expires_in_days: Optional[int] = None,
    ) -> tuple[ApiKey, str]:
        """Create a new API key. Returns (ApiKey, raw_key).

        The raw key is returned exactly once here; only its hash is stored.
        scopes defaults to ["translate"]; expires_in_days=None means no expiry.
        """
        # Generate a secure random key ("tr_" prefix + 32 urlsafe bytes)
        raw_key = f"tr_{secrets.token_urlsafe(32)}"
        key_hash = self.hash_key(raw_key)
        key_prefix = raw_key[:10]
        expires_at = None
        if expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
        api_key = ApiKey(
            user_id=user_id,
            name=name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            scopes=scopes or ["translate"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)
        return api_key, raw_key

    def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
        """Get API key by raw key value.

        Returns None for unknown, inactive, or expired keys. On a hit this
        also records the usage (last_used_at / usage_count) and commits.
        """
        key_hash = self.hash_key(raw_key)
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.key_hash == key_hash,
                ApiKey.is_active == True,
            )
        ).first()
        if api_key:
            # Check expiration
            if api_key.expires_at and api_key.expires_at < datetime.utcnow():
                return None
            # Update last used
            api_key.last_used_at = datetime.utcnow()
            api_key.usage_count += 1
            self.db.commit()
        return api_key

    def get_user_keys(self, user_id: str) -> List[ApiKey]:
        """Get all API keys for a user, newest first."""
        return self.db.query(ApiKey).filter(
            ApiKey.user_id == user_id
        ).order_by(ApiKey.created_at.desc()).all()

    def revoke(self, key_id: str, user_id: str) -> bool:
        """Revoke an API key.

        user_id must match the key's owner, so one user cannot revoke
        another user's key. Returns False when no matching key exists.
        """
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.id == key_id,
                ApiKey.user_id == user_id,
            )
        ).first()
        if not api_key:
            return False
        api_key.is_active = False
        self.db.commit()
        return True

View File

@ -4,6 +4,35 @@
version: '3.8'
services:
# ===========================================
# PostgreSQL Database
# ===========================================
postgres:
image: postgres:16-alpine
container_name: translate-postgres
restart: unless-stopped
environment:
- POSTGRES_USER=${POSTGRES_USER:-translate}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-translate_secret_123}
- POSTGRES_DB=${POSTGRES_DB:-translate_db}
- PGDATA=/var/lib/postgresql/data/pgdata
volumes:
- postgres_data:/var/lib/postgresql/data
networks:
- translate-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-translate} -d ${POSTGRES_DB:-translate_db}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
deploy:
resources:
limits:
memory: 512M
reservations:
memory: 128M
# ===========================================
# Backend API Service
# ===========================================
@ -14,23 +43,45 @@ services:
container_name: translate-backend
restart: unless-stopped
environment:
# Database
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-translate}:${POSTGRES_PASSWORD:-translate_secret_123}@postgres:5432/${POSTGRES_DB:-translate_db}
# Redis
- REDIS_URL=redis://redis:6379/0
# Translation Services
- TRANSLATION_SERVICE=${TRANSLATION_SERVICE:-ollama}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://ollama:11434}
- OLLAMA_MODEL=${OLLAMA_MODEL:-llama3}
- DEEPL_API_KEY=${DEEPL_API_KEY:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
# File Limits
- MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
# Rate Limiting
- RATE_LIMIT_REQUESTS_PER_MINUTE=${RATE_LIMIT_REQUESTS_PER_MINUTE:-60}
- RATE_LIMIT_TRANSLATIONS_PER_MINUTE=${RATE_LIMIT_TRANSLATIONS_PER_MINUTE:-10}
- ADMIN_USERNAME=${ADMIN_USERNAME:-admin}
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-changeme123}
- CORS_ORIGINS=${CORS_ORIGINS:-*}
# Admin Auth (CHANGE IN PRODUCTION!)
- ADMIN_USERNAME=${ADMIN_USERNAME}
- ADMIN_PASSWORD=${ADMIN_PASSWORD}
# Security
- JWT_SECRET=${JWT_SECRET}
- CORS_ORIGINS=${CORS_ORIGINS:-https://yourdomain.com}
# Stripe Payments
- STRIPE_SECRET_KEY=${STRIPE_SECRET_KEY:-}
- STRIPE_WEBHOOK_SECRET=${STRIPE_WEBHOOK_SECRET:-}
- STRIPE_STARTER_PRICE_ID=${STRIPE_STARTER_PRICE_ID:-}
- STRIPE_PRO_PRICE_ID=${STRIPE_PRO_PRICE_ID:-}
- STRIPE_BUSINESS_PRICE_ID=${STRIPE_BUSINESS_PRICE_ID:-}
volumes:
- uploads_data:/app/uploads
- outputs_data:/app/outputs
- logs_data:/app/logs
networks:
- translate-network
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
@ -117,7 +168,7 @@ services:
- with-ollama
# ===========================================
# Redis (Optional - For caching & sessions)
# Redis (Caching & Sessions)
# ===========================================
redis:
image: redis:7-alpine
@ -130,11 +181,9 @@ services:
- translate-network
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 30s
timeout: 10s
retries: 3
profiles:
- with-cache
interval: 10s
timeout: 5s
retries: 5
# ===========================================
# Prometheus (Optional - Monitoring)
@ -190,6 +239,8 @@ networks:
# Volumes
# ===========================================
volumes:
postgres_data:
driver: local
uploads_data:
driver: local
outputs_data:

12
main.py
View File

@ -204,6 +204,18 @@ async def lifespan(app: FastAPI):
# Startup
logger.info("Starting Document Translation API...")
config.ensure_directories()
# Initialize database
try:
from database.connection import init_db, check_db_connection
init_db()
if check_db_connection():
logger.info("✅ Database connection verified")
else:
logger.warning("⚠️ Database connection check failed")
except Exception as e:
logger.warning(f"⚠️ Database initialization skipped: {e}")
await cleanup_manager.start()
logger.info("API ready to accept requests")

View File

@ -28,8 +28,7 @@ stripe==7.0.0
# Session storage & caching (optional but recommended for production)
redis==5.0.1
# Database (optional but recommended for production)
# sqlalchemy==2.0.25
# asyncpg==0.29.0 # PostgreSQL async driver
# alembic==1.13.1 # Database migrations
# Database (recommended for production)
sqlalchemy==2.0.25
psycopg2-binary==2.9.9 # PostgreSQL driver
alembic==1.13.1 # Database migrations

160
scripts/migrate_to_db.py Normal file
View File

@ -0,0 +1,160 @@
"""
Migration script to move data from JSON files to database
Run this once to migrate existing users to the new database system
"""
import json
import os
import sys
from pathlib import Path
from datetime import datetime
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from database.connection import init_db, get_db_session
from database.repositories import UserRepository
from database.models import PlanType, SubscriptionStatus
def migrate_users_from_json():
    """Migrate users from data/users.json into the database.

    Existing emails are skipped, plan/status strings are mapped onto
    their enums (unknown values fall back to FREE/ACTIVE), and original
    user IDs and timestamps are preserved.  On success the JSON file is
    renamed to a ``.json.bak`` backup.

    Returns:
        The number of users migrated.
    """
    json_path = Path("data/users.json")
    if not json_path.exists():
        print("No users.json found, nothing to migrate")
        return 0

    # Initialize database
    print("Initializing database tables...")
    init_db()

    # Load JSON data
    with open(json_path, 'r') as f:
        users_data = json.load(f)

    print(f"Found {len(users_data)} users to migrate")
    migrated = 0
    skipped = 0
    errors = 0

    with get_db_session() as db:
        repo = UserRepository(db)
        for user_id, user_data in users_data.items():
            try:
                # Check if user already exists
                existing = repo.get_by_email(user_data.get('email', ''))
                if existing:
                    print(f" Skipping {user_data.get('email')} - already exists")
                    skipped += 1
                    continue

                # Map plan string to enum (unknown -> FREE)
                plan_str = user_data.get('plan', 'free')
                try:
                    plan = PlanType(plan_str)
                except ValueError:
                    plan = PlanType.FREE

                # Map subscription status (unknown -> ACTIVE)
                status_str = user_data.get('subscription_status', 'active')
                try:
                    status = SubscriptionStatus(status_str)
                except ValueError:
                    status = SubscriptionStatus.ACTIVE

                # Create user with original ID
                from database.models import User
                user = User(
                    id=user_id,
                    email=user_data.get('email', '').lower(),
                    name=user_data.get('name', ''),
                    password_hash=user_data.get('password_hash', ''),
                    email_verified=user_data.get('email_verified', False),
                    avatar_url=user_data.get('avatar_url'),
                    plan=plan,
                    subscription_status=status,
                    stripe_customer_id=user_data.get('stripe_customer_id'),
                    stripe_subscription_id=user_data.get('stripe_subscription_id'),
                    docs_translated_this_month=user_data.get('docs_translated_this_month', 0),
                    pages_translated_this_month=user_data.get('pages_translated_this_month', 0),
                    api_calls_this_month=user_data.get('api_calls_this_month', 0),
                    extra_credits=user_data.get('extra_credits', 0),
                )

                # Parse timestamps; a malformed value is ignored rather
                # than aborting the user.  BUG FIX: the bare `except:`
                # clauses also swallowed KeyboardInterrupt/SystemExit —
                # narrowed to the exceptions fromisoformat/replace raise.
                if user_data.get('created_at'):
                    try:
                        user.created_at = datetime.fromisoformat(
                            user_data['created_at'].replace('Z', '+00:00'))
                    except (TypeError, ValueError, AttributeError):
                        pass
                if user_data.get('updated_at'):
                    try:
                        user.updated_at = datetime.fromisoformat(
                            user_data['updated_at'].replace('Z', '+00:00'))
                    except (TypeError, ValueError, AttributeError):
                        pass

                db.add(user)
                db.commit()
                print(f" Migrated: {user.email}")
                migrated += 1
            except Exception as e:
                print(f" Error migrating {user_data.get('email', user_id)}: {e}")
                errors += 1
                db.rollback()

    print("\nMigration complete:")  # f-prefix removed: no placeholders
    print(f" Migrated: {migrated}")
    print(f" Skipped: {skipped}")
    print(f" Errors: {errors}")

    # Backup original file
    if migrated > 0:
        backup_path = json_path.with_suffix('.json.bak')
        os.rename(json_path, backup_path)
        print(f"\nOriginal file backed up to: {backup_path}")

    return migrated
def verify_migration():
    """Verify the migration was successful"""
    from database.connection import get_db_session
    from database.repositories import UserRepository

    with get_db_session() as db:
        repo = UserRepository(db)
        total = repo.count_users()
        print(f"\nDatabase now contains {total} users")

        # Show a small sample as a sanity check
        sample = repo.get_all_users(limit=5)
        if sample:
            print("\nSample users:")
            for u in sample:
                print(f" - {u.email} ({u.plan.value})")
if __name__ == "__main__":
    print("=" * 50)
    print("JSON to Database Migration Script")
    print("=" * 50)

    # Report which backend DATABASE_URL selects.
    # BUG FIX (cosmetic): these were f-strings with no placeholders.
    db_url = os.getenv("DATABASE_URL", "")
    if db_url:
        print("Database: PostgreSQL")
    else:
        print("Database: SQLite (development)")
    print()

    # Run migration
    migrate_users_from_json()

    # Verify
    verify_migration()

View File

@ -1,5 +1,9 @@
"""
Authentication service with JWT tokens and password hashing
This service provides user authentication with automatic backend selection:
- If DATABASE_URL is configured: Uses PostgreSQL database
- Otherwise: Falls back to JSON file storage (development mode)
"""
import os
import secrets
@ -8,6 +12,9 @@ from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import json
from pathlib import Path
import logging
logger = logging.getLogger(__name__)
# Try to import optional dependencies
try:
@ -15,6 +22,7 @@ try:
JWT_AVAILABLE = True
except ImportError:
JWT_AVAILABLE = False
logger.warning("PyJWT not installed. Using fallback token encoding.")
try:
from passlib.context import CryptContext
@ -22,17 +30,37 @@ try:
PASSLIB_AVAILABLE = True
except ImportError:
PASSLIB_AVAILABLE = False
logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")
# Check if database is configured
DATABASE_URL = os.getenv("DATABASE_URL", "")
USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql"))
if USE_DATABASE:
try:
from database.repositories import UserRepository
from database.connection import get_sync_session, init_db as _init_db
from database import models as db_models
DATABASE_AVAILABLE = True
logger.info("Database backend enabled for authentication")
except ImportError as e:
DATABASE_AVAILABLE = False
USE_DATABASE = False
logger.warning(f"Database modules not available: {e}. Using JSON storage.")
else:
DATABASE_AVAILABLE = False
logger.info("Using JSON file storage for authentication (DATABASE_URL not configured)")
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
# Configuration
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
SECRET_KEY = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32)))
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
REFRESH_TOKEN_EXPIRE_DAYS = 30
# Simple file-based storage (replace with database in production)
# Simple file-based storage (used when database is not configured)
USERS_FILE = Path("data/users.json")
USERS_FILE.parent.mkdir(exist_ok=True)
@ -117,7 +145,7 @@ def verify_token(token: str) -> Optional[Dict[str, Any]]:
def load_users() -> Dict[str, Dict]:
"""Load users from file storage"""
"""Load users from file storage (JSON backend only)"""
if USERS_FILE.exists():
try:
with open(USERS_FILE, 'r') as f:
@ -128,13 +156,46 @@ def load_users() -> Dict[str, Dict]:
def save_users(users: Dict[str, Dict]):
"""Save users to file storage"""
"""Save users to file storage (JSON backend only)"""
with open(USERS_FILE, 'w') as f:
json.dump(users, f, indent=2, default=str)
def _db_user_to_model(db_user) -> User:
    """Convert database user model to Pydantic User model.

    Maps each persisted column onto the API-facing ``User`` model,
    substituting safe defaults for NULL columns: FREE plan, ACTIVE
    status, zeroed usage counters, "en"/"es"/"google" preferences,
    and the current UTC time for missing timestamps.
    """
    return User(
        id=str(db_user.id),
        email=db_user.email,
        name=db_user.name or "",
        password_hash=db_user.password_hash,
        avatar_url=db_user.avatar_url,
        plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
        subscription_status=SubscriptionStatus(db_user.subscription_status) if db_user.subscription_status else SubscriptionStatus.ACTIVE,
        stripe_customer_id=db_user.stripe_customer_id,
        stripe_subscription_id=db_user.stripe_subscription_id,
        docs_translated_this_month=db_user.docs_translated_this_month or 0,
        pages_translated_this_month=db_user.pages_translated_this_month or 0,
        api_calls_this_month=db_user.api_calls_this_month or 0,
        extra_credits=db_user.extra_credits or 0,
        usage_reset_date=db_user.usage_reset_date or datetime.utcnow(),
        default_source_lang=db_user.default_source_lang or "en",
        default_target_lang=db_user.default_target_lang or "es",
        default_provider=db_user.default_provider or "google",
        created_at=db_user.created_at or datetime.utcnow(),
        updated_at=db_user.updated_at,
    )
def get_user_by_email(email: str) -> Optional[User]:
"""Get a user by email"""
if USE_DATABASE and DATABASE_AVAILABLE:
with get_sync_session() as session:
repo = UserRepository(session)
db_user = repo.get_by_email(email)
if db_user:
return _db_user_to_model(db_user)
return None
else:
users = load_users()
for user_data in users.values():
if user_data.get("email", "").lower() == email.lower():
@ -144,6 +205,14 @@ def get_user_by_email(email: str) -> Optional[User]:
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get a user by ID"""
if USE_DATABASE and DATABASE_AVAILABLE:
with get_sync_session() as session:
repo = UserRepository(session)
db_user = repo.get_by_id(user_id)
if db_user:
return _db_user_to_model(db_user)
return None
else:
users = load_users()
if user_id in users:
return User(**users[user_id])
@ -152,12 +221,26 @@ def get_user_by_id(user_id: str) -> Optional[User]:
def create_user(user_create: UserCreate) -> User:
"""Create a new user"""
users = load_users()
# Check if email exists
if get_user_by_email(user_create.email):
raise ValueError("Email already registered")
if USE_DATABASE and DATABASE_AVAILABLE:
with get_sync_session() as session:
repo = UserRepository(session)
db_user = repo.create(
email=user_create.email,
name=user_create.name,
password_hash=hash_password(user_create.password),
plan=PlanType.FREE.value,
subscription_status=SubscriptionStatus.ACTIVE.value
)
session.commit()
session.refresh(db_user)
return _db_user_to_model(db_user)
else:
users = load_users()
# Generate user ID
user_id = secrets.token_urlsafe(16)
@ -190,6 +273,16 @@ def authenticate_user(email: str, password: str) -> Optional[User]:
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
"""Update a user's data"""
if USE_DATABASE and DATABASE_AVAILABLE:
with get_sync_session() as session:
repo = UserRepository(session)
db_user = repo.update(user_id, updates)
if db_user:
session.commit()
session.refresh(db_user)
return _db_user_to_model(db_user)
return None
else:
users = load_users()
if user_id not in users:
return None
@ -212,7 +305,7 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
"docs_translated_this_month": 0,
"pages_translated_this_month": 0,
"api_calls_this_month": 0,
"usage_reset_date": now.isoformat()
"usage_reset_date": now.isoformat() if not USE_DATABASE else now
})
user.docs_translated_this_month = 0
user.pages_translated_this_month = 0
@ -248,8 +341,8 @@ def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> b
if use_credits:
updates["extra_credits"] = max(0, user.extra_credits - pages_count)
update_user(user_id, updates)
return True
result = update_user(user_id, updates)
return result is not None
def add_credits(user_id: str, credits: int) -> bool:
@ -258,5 +351,14 @@ def add_credits(user_id: str, credits: int) -> bool:
if not user:
return False
update_user(user_id, {"extra_credits": user.extra_credits + credits})
return True
result = update_user(user_id, {"extra_credits": user.extra_credits + credits})
return result is not None
def init_database():
    """Initialize the database (call on application startup)"""
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        logger.info("Using JSON file storage")
        return
    _init_db()
    logger.info("Database initialized successfully")

245
services/auth_service_db.py Normal file
View File

@ -0,0 +1,245 @@
"""
Database-backed authentication service
Replaces JSON file storage with SQLAlchemy
"""
import os
import secrets
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import logging
# Try to import optional dependencies
try:
import jwt
JWT_AVAILABLE = True
except ImportError:
JWT_AVAILABLE = False
import json
import base64
try:
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
PASSLIB_AVAILABLE = True
except ImportError:
PASSLIB_AVAILABLE = False
from database.connection import get_db_session
from database.repositories import UserRepository
from database.models import User, PlanType, SubscriptionStatus
from models.subscription import PLANS
logger = logging.getLogger(__name__)
# Configuration
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
REFRESH_TOKEN_EXPIRE_DAYS = 30
def hash_password(password: str) -> str:
    """Hash a password using bcrypt or fallback to SHA256"""
    if not PASSLIB_AVAILABLE:
        # Fallback: salted SHA-256 encoded as "sha256$<salt>$<digest>".
        salt = secrets.token_hex(16)
        digest = hashlib.sha256(f"{salt}{password}".encode()).hexdigest()
        return f"sha256${salt}${digest}"
    return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash"""
    # bcrypt path only when passlib is present AND the stored hash is
    # not in our SHA-256 fallback format.
    if PASSLIB_AVAILABLE and not hashed_password.startswith("sha256$"):
        return pwd_context.verify(plain_password, hashed_password)
    pieces = hashed_password.split("$")
    if len(pieces) != 3 or pieces[0] != "sha256":
        return False
    _, salt, expected = pieces
    candidate = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
    # Constant-time comparison to avoid timing leaks.
    return secrets.compare_digest(candidate, expected)
def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create a JWT access token"""
    lifetime = expires_delta or timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    expire = datetime.utcnow() + lifetime
    if JWT_AVAILABLE:
        claims = {"sub": user_id, "exp": expire, "type": "access"}
        return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    # Unsigned fallback when PyJWT is missing: base64-encoded JSON.
    payload = {"user_id": user_id, "exp": expire.isoformat(), "type": "access"}
    return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
def create_refresh_token(user_id: str) -> str:
    """Create a JWT refresh token"""
    expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    if JWT_AVAILABLE:
        claims = {"sub": user_id, "exp": expire, "type": "refresh"}
        return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    # Unsigned fallback when PyJWT is missing: base64-encoded JSON.
    payload = {"user_id": user_id, "exp": expire.isoformat(), "type": "refresh"}
    return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a JWT token and return payload.

    Returns the decoded claims dict on success, or None when the token
    is expired, malformed, or has an invalid signature.
    """
    if not JWT_AVAILABLE:
        # Fallback decoder for the unsigned base64 token format.
        try:
            data = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
            exp = datetime.fromisoformat(data["exp"])
            if exp < datetime.utcnow():
                return None
            return {"sub": data["user_id"], "type": data.get("type", "access")}
        except Exception:
            return None
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload
    except jwt.ExpiredSignatureError:
        return None
    except jwt.InvalidTokenError:
        # BUG FIX: was `except jwt.JWTError:` — PyJWT has no JWTError
        # (that name belongs to python-jose), so any invalid token raised
        # AttributeError instead of returning None.  InvalidTokenError is
        # PyJWT's base exception for all decode failures.
        return None
def create_user(email: str, name: str, password: str) -> User:
    """Create a new user in the database.

    The password is hashed before storage; the new user starts on the
    FREE plan.

    Raises:
        ValueError: if the email is already registered.

    NOTE(review): the ORM object is returned after the session context
    exits; it may be detached — confirm get_db_session is configured so
    returned instances remain usable (e.g. expire_on_commit=False).
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        # Check if email already exists
        existing = repo.get_by_email(email)
        if existing:
            raise ValueError("Email already registered")
        password_hash = hash_password(password)
        user = repo.create(
            email=email,
            name=name,
            password_hash=password_hash,
            plan=PlanType.FREE,
        )
        return user
def authenticate_user(email: str, password: str) -> Optional[User]:
    """Authenticate user and return user object if valid.

    Returns None for an unknown email or a wrong password.  On success
    the user's ``last_login_at`` is stamped with the current UTC time
    before the user is returned.
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        user = repo.get_by_email(email)
        if not user:
            return None
        if not verify_password(password, user.password_hash):
            return None
        # Update last login
        repo.update(user.id, last_login_at=datetime.utcnow())
        return user
def get_user_by_id(user_id: str) -> Optional[User]:
    """Get user by ID from database"""
    with get_db_session() as session:
        return UserRepository(session).get_by_id(user_id)
def get_user_by_email(email: str) -> Optional[User]:
    """Get user by email from database"""
    with get_db_session() as session:
        return UserRepository(session).get_by_email(email)
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Update user fields in database"""
    with get_db_session() as session:
        return UserRepository(session).update(user_id, **updates)
def add_credits(user_id: str, credits: int) -> bool:
    """Add credits to user account"""
    with get_db_session() as session:
        outcome = UserRepository(session).add_credits(user_id, credits)
    return outcome is not None
def use_credits(user_id: str, credits: int) -> bool:
    """Use credits from user account"""
    with get_db_session() as session:
        return UserRepository(session).use_credits(user_id, credits)
def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool:
    """Increment user usage counters"""
    with get_db_session() as session:
        updated = UserRepository(session).increment_usage(
            user_id, docs=docs, pages=pages, api_calls=api_calls
        )
    return updated is not None
def check_usage_limits(user_id: str) -> Dict[str, Any]:
    """Check if user is within their plan limits.

    Returns a dict with an "allowed" flag plus either a denial reason
    (with the limit and current usage) or the remaining allowance.
    A docs limit <= 0 is treated as unlimited and reported as
    ``docs_remaining == -1``.  A user at their limit is still allowed
    while they hold extra credits.
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        user = repo.get_by_id(user_id)
        if not user:
            return {"allowed": False, "reason": "User not found"}
        plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE])
        # Check document limit
        docs_limit = plan_config["docs_per_month"]
        if docs_limit > 0 and user.docs_translated_this_month >= docs_limit:
            # Check if user has extra credits
            if user.extra_credits <= 0:
                return {
                    "allowed": False,
                    "reason": "Monthly document limit reached",
                    "limit": docs_limit,
                    "used": user.docs_translated_this_month,
                }
        return {
            "allowed": True,
            "docs_remaining": max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1,
            "extra_credits": user.extra_credits,
        }
def get_user_usage_stats(user_id: str) -> Dict[str, Any]:
    """Get detailed usage statistics for a user.

    Returns an empty dict for an unknown user.  ``docs_remaining`` is
    -1 when the plan's docs limit is <= 0 (unlimited); API-call usage
    is reported as 0 when the plan has no API access.
    """
    with get_db_session() as db:
        repo = UserRepository(db)
        user = repo.get_by_id(user_id)
        if not user:
            return {}
        plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE])
        return {
            "docs_used": user.docs_translated_this_month,
            "docs_limit": plan_config["docs_per_month"],
            "docs_remaining": max(0, plan_config["docs_per_month"] - user.docs_translated_this_month) if plan_config["docs_per_month"] > 0 else -1,
            "pages_used": user.pages_translated_this_month,
            "extra_credits": user.extra_credits,
            "max_pages_per_doc": plan_config["max_pages_per_doc"],
            "max_file_size_mb": plan_config["max_file_size_mb"],
            "allowed_providers": plan_config["providers"],
            "api_access": plan_config.get("api_access", False),
            "api_calls_used": user.api_calls_this_month if plan_config.get("api_access") else 0,
            "api_calls_limit": plan_config.get("api_calls_per_month", 0),
        }