From 550f3516db0a28a352e85ff99fc6eb57da6c316f Mon Sep 17 00:00:00 2001 From: Sepehr Date: Wed, 31 Dec 2025 10:56:19 +0100 Subject: [PATCH] feat: Add PostgreSQL database infrastructure - Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration --- .env.example | 11 +- alembic.ini | 78 ++++++++ alembic/env.py | 86 ++++++++ alembic/script.py.mako | 27 +++ alembic/versions/001_initial.py | 126 ++++++++++++ database/__init__.py | 17 ++ database/connection.py | 139 +++++++++++++ database/models.py | 259 ++++++++++++++++++++++++ database/repositories.py | 341 ++++++++++++++++++++++++++++++++ docker-compose.yml | 69 ++++++- main.py | 12 ++ requirements.txt | 9 +- scripts/migrate_to_db.py | 160 +++++++++++++++ services/auth_service.py | 196 +++++++++++++----- services/auth_service_db.py | 245 +++++++++++++++++++++++ 15 files changed, 1712 insertions(+), 63 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/001_initial.py create mode 100644 database/__init__.py create mode 100644 database/connection.py create mode 100644 database/models.py create mode 100644 database/repositories.py create mode 100644 scripts/migrate_to_db.py create mode 100644 services/auth_service_db.py diff --git a/.env.example b/.env.example index e5f6ad4..693fdff 100644 --- a/.env.example +++ b/.env.example @@ -72,9 +72,16 @@ MAX_REQUEST_SIZE_MB=100 # Request timeout in seconds REQUEST_TIMEOUT_SECONDS=300 -# ============== Database (Production) ============== +# ============== Database ============== # PostgreSQL connection string (recommended for production) -# DATABASE_URL=postgresql://user:password@localhost:5432/translate_db +# Format: postgresql://user:password@host:port/database +# DATABASE_URL=postgresql://translate_user:secure_password@localhost:5432/translate_db + +# SQLite path (used when DATABASE_URL is not set - for development) +SQLITE_PATH=data/translate.db + +# Enable SQL query logging (for debugging) +DATABASE_ECHO=false # Redis for sessions and caching (recommended for production) # REDIS_URL=redis://localhost:6379/0 diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..3fd5d0b --- /dev/null +++ b/alembic.ini @@ -0,0 +1,78 @@ +# Alembic configuration file + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
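+# Leave unset for local time; e.g. timezone = UTC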
+# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during the 'revision' command, +# regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without a source .py file +# to be detected as revisions in the versions/ directory +# sourceless = false + +# version location specification; This defaults to alembic/versions +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator +# version_path_separator = : + +# the output encoding used when revision files are written +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -q + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..ec9c559 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,86 @@ +""" +Alembic environment configuration +""" +import os +import sys +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# this is the Alembic Config object +config = context.config + +# Interpret the config file for Python logging +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Import models for autogenerate support +from database.models import Base +target_metadata = Base.metadata + +# Get database URL from environment +from dotenv import load_dotenv +load_dotenv() + +DATABASE_URL = os.getenv("DATABASE_URL", "") +if not DATABASE_URL: + SQLITE_PATH = os.getenv("SQLITE_PATH", "data/translate.db") + DATABASE_URL = f"sqlite:///./{SQLITE_PATH}" + +# Override sqlalchemy.url with environment variable +config.set_main_option("sqlalchemy.url", DATABASE_URL) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
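+
+    Typical workflow once DATABASE_URL (or the SQLite fallback) is
+    configured, as a sketch:
+
+        alembic revision --autogenerate -m "describe the change"
+        alembic upgrade head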
+ """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..5d7f79f --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,27 @@ +""" +Alembic migration script template +${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/001_initial.py b/alembic/versions/001_initial.py new file mode 100644 index 0000000..0a1f7ef --- /dev/null +++ b/alembic/versions/001_initial.py @@ -0,0 +1,126 @@ +"""Initial database schema + +Revision ID: 001_initial +Revises: +Create Date: 2024-12-31 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '001_initial' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create users table + op.create_table( + 'users', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('email', sa.String(255), unique=True, nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('password_hash', sa.String(255), nullable=False), + sa.Column('email_verified', sa.Boolean(), default=False), + sa.Column('is_active', sa.Boolean(), default=True), + sa.Column('avatar_url', sa.String(500), nullable=True), + sa.Column('plan', sa.String(20), default='free'), + sa.Column('subscription_status', sa.String(20), default='active'), + sa.Column('stripe_customer_id', sa.String(255), nullable=True), + sa.Column('stripe_subscription_id', sa.String(255), nullable=True), + sa.Column('docs_translated_this_month', sa.Integer(), default=0), + sa.Column('pages_translated_this_month', sa.Integer(), default=0), + sa.Column('api_calls_this_month', sa.Integer(), default=0), + sa.Column('extra_credits', sa.Integer(), default=0), + sa.Column('usage_reset_date', sa.DateTime()), + sa.Column('created_at', sa.DateTime()), + sa.Column('updated_at', sa.DateTime()), + sa.Column('last_login_at', sa.DateTime(), nullable=True), + ) + op.create_index('ix_users_email', 'users', ['email']) + op.create_index('ix_users_email_active', 'users', ['email', 'is_active']) + op.create_index('ix_users_stripe_customer', 'users', ['stripe_customer_id']) + + # Create translations table + op.create_table( + 'translations', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + 
sa.Column('original_filename', sa.String(255), nullable=False), + sa.Column('file_type', sa.String(10), nullable=False), + sa.Column('file_size_bytes', sa.BigInteger(), default=0), + sa.Column('page_count', sa.Integer(), default=0), + sa.Column('source_language', sa.String(10), default='auto'), + sa.Column('target_language', sa.String(10), nullable=False), + sa.Column('provider', sa.String(50), nullable=False), + sa.Column('status', sa.String(20), default='pending'), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('processing_time_ms', sa.Integer(), nullable=True), + sa.Column('characters_translated', sa.Integer(), default=0), + sa.Column('estimated_cost_usd', sa.Float(), default=0.0), + sa.Column('created_at', sa.DateTime()), + sa.Column('completed_at', sa.DateTime(), nullable=True), + ) + op.create_index('ix_translations_user_date', 'translations', ['user_id', 'created_at']) + op.create_index('ix_translations_status', 'translations', ['status']) + + # Create api_keys table + op.create_table( + 'api_keys', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('name', sa.String(100), nullable=False), + sa.Column('key_hash', sa.String(255), nullable=False), + sa.Column('key_prefix', sa.String(10), nullable=False), + sa.Column('is_active', sa.Boolean(), default=True), + sa.Column('scopes', sa.JSON(), default=list), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('usage_count', sa.Integer(), default=0), + sa.Column('created_at', sa.DateTime()), + sa.Column('expires_at', sa.DateTime(), nullable=True), + ) + op.create_index('ix_api_keys_prefix', 'api_keys', ['key_prefix']) + op.create_index('ix_api_keys_hash', 'api_keys', ['key_hash']) + + # Create usage_logs table + op.create_table( + 'usage_logs', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('date', sa.DateTime(), nullable=False), + sa.Column('documents_count', sa.Integer(), default=0), + sa.Column('pages_count', sa.Integer(), default=0), + sa.Column('characters_count', sa.BigInteger(), default=0), + sa.Column('api_calls_count', sa.Integer(), default=0), + sa.Column('provider_breakdown', sa.JSON(), default=dict), + ) + op.create_index('ix_usage_logs_user_date', 'usage_logs', ['user_id', 'date'], unique=True) + + # Create payment_history table + op.create_table( + 'payment_history', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('stripe_payment_intent_id', sa.String(255), nullable=True), + sa.Column('stripe_invoice_id', sa.String(255), nullable=True), + sa.Column('amount_cents', sa.Integer(), nullable=False), + sa.Column('currency', sa.String(3), default='usd'), + sa.Column('payment_type', sa.String(50), nullable=False), + sa.Column('status', sa.String(20), nullable=False), + sa.Column('description', sa.String(255), nullable=True), + sa.Column('created_at', sa.DateTime()), + ) + op.create_index('ix_payment_history_user', 'payment_history', ['user_id', 'created_at']) + + +def downgrade() -> None: + op.drop_table('payment_history') + op.drop_table('usage_logs') + op.drop_table('api_keys') + op.drop_table('translations') + op.drop_table('users') diff --git a/database/__init__.py b/database/__init__.py new file mode 100644 index 0000000..b45b2ab --- /dev/null 
+++ b/database/__init__.py
@@ -0,0 +1,18 @@
+"""
+Database module for the Document Translation API
+Provides PostgreSQL support (with a SQLite fallback) via synchronous SQLAlchemy
+"""
+from database.connection import get_db, engine, SessionLocal, init_db
+from database.models import User, Translation, ApiKey, UsageLog, PaymentHistory
+
+__all__ = [
+    "get_db",
+    "engine",
+    "SessionLocal",
+    "init_db",
+    "User",
+    "Translation",
+    "ApiKey",
+    "UsageLog",
+    "PaymentHistory",
+]
diff --git a/database/connection.py b/database/connection.py
new file mode 100644
index 0000000..5a1ab2f
--- /dev/null
+++ b/database/connection.py
@@ -0,0 +1,146 @@
+"""
+Database connection and session management
+Supports both PostgreSQL (production) and SQLite (development/testing)
+"""
+import os
+import logging
+from typing import Generator, Optional
+from contextlib import contextmanager
+
+from sqlalchemy import create_engine, event, text
+from sqlalchemy.orm import sessionmaker, Session
+from sqlalchemy.pool import QueuePool, StaticPool
+
+logger = logging.getLogger(__name__)
+
+# Database URL from environment
+# PostgreSQL: postgresql://user:password@host:port/database
+# SQLite: sqlite:///./data/translate.db
+DATABASE_URL = os.getenv("DATABASE_URL", "")
+
+# Determine if we're using SQLite or PostgreSQL
+_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True
+
+# Create engine based on database type
+if DATABASE_URL and not _is_sqlite:
+    # PostgreSQL configuration
+    engine = create_engine(
+        DATABASE_URL,
+        poolclass=QueuePool,
+        pool_size=5,
+        max_overflow=10,
+        pool_timeout=30,
+        pool_recycle=1800,  # Recycle connections after 30 minutes
+        pool_pre_ping=True,  # Check connection health before use
+        echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
+    )
+    logger.info("✅ Database configured with PostgreSQL")
+else:
+    # SQLite configuration (for development/testing or when no DATABASE_URL)
+    sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
+    os.makedirs(os.path.dirname(sqlite_path) or ".", exist_ok=True)
+
+    sqlite_url = f"sqlite:///./{sqlite_path}"
+    engine = create_engine(
+        sqlite_url,
+        connect_args={"check_same_thread": False},
+        poolclass=StaticPool,
+        echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
+    )
+
+    # Enable foreign keys for SQLite
+    @event.listens_for(engine, "connect")
+    def set_sqlite_pragma(dbapi_connection, connection_record):
+        cursor = dbapi_connection.cursor()
+        cursor.execute("PRAGMA foreign_keys=ON")
+        cursor.close()
+
+    if not DATABASE_URL:
+        logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
+    else:
+        logger.info(f"✅ Database configured with SQLite: {sqlite_path}")
+
+# Session factory
+SessionLocal = sessionmaker(
+    autocommit=False,
+    autoflush=False,
+    bind=engine,
+    expire_on_commit=False,
+)
+
+
+def get_db() -> Generator[Session, None, None]:
+    """
+    Dependency for FastAPI to get database session.
+    Usage: db: Session = Depends(get_db)
+    """
+    db = SessionLocal()
+    try:
+        yield db
+    finally:
+        db.close()
+
+
+@contextmanager
+def get_db_session() -> Generator[Session, None, None]:
+    """
+    Context manager for database session.
+    Usage: with get_db_session() as db: ...
+    """
+    db = SessionLocal()
+    try:
+        yield db
+        db.commit()
+    except Exception:
+        db.rollback()
+        raise
+    finally:
+        db.close()
+
+
+# Alias for backward compatibility
+get_sync_session = get_db_session
+
+
+def init_db():
+    """
+    Initialize database tables.
+    Call this on application startup.
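+
+    Example (uses the SQLite fallback when DATABASE_URL is unset):
+
+        from database.connection import init_db, get_db_session
+        init_db()
+        with get_db_session() as db:
+            ...  # query the freshly created tables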
+ """ + from database.models import Base + Base.metadata.create_all(bind=engine) + logger.info("✅ Database tables initialized") + + +def check_db_connection() -> bool: + """ + Check if database connection is healthy. + Returns True if connection works, False otherwise. + """ + try: + with engine.connect() as conn: + conn.execute("SELECT 1") + return True + except Exception as e: + logger.error(f"Database connection check failed: {e}") + return False + + +# Connection pool stats (for monitoring) +def get_pool_stats() -> dict: + """Get database connection pool statistics""" + if hasattr(engine.pool, 'status'): + return { + "pool_size": engine.pool.size(), + "checked_in": engine.pool.checkedin(), + "checked_out": engine.pool.checkedout(), + "overflow": engine.pool.overflow(), + } + return {"status": "pool stats not available"} diff --git a/database/models.py b/database/models.py new file mode 100644 index 0000000..a5ad791 --- /dev/null +++ b/database/models.py @@ -0,0 +1,259 @@ +""" +SQLAlchemy models for the Document Translation API +""" +import os +import uuid +from datetime import datetime +from typing import Optional, List + +from sqlalchemy import ( + Column, String, Integer, Float, Boolean, DateTime, Text, + ForeignKey, Enum, Index, JSON, BigInteger +) +from sqlalchemy.orm import relationship, declarative_base +from sqlalchemy.dialects.postgresql import UUID as PG_UUID +import enum + +Base = declarative_base() + + +def generate_uuid(): + """Generate a new UUID string""" + return str(uuid.uuid4()) + + +class PlanType(str, enum.Enum): + FREE = "free" + STARTER = "starter" + PRO = "pro" + BUSINESS = "business" + ENTERPRISE = "enterprise" + + +class SubscriptionStatus(str, enum.Enum): + ACTIVE = "active" + CANCELED = "canceled" + PAST_DUE = "past_due" + TRIALING = "trialing" + PAUSED = "paused" + + +class User(Base): + """User model for authentication and billing""" + __tablename__ = "users" + + id = Column(String(36), primary_key=True, default=generate_uuid) + email = Column(String(255), unique=True, nullable=False, index=True) + name = Column(String(255), nullable=False) + password_hash = Column(String(255), nullable=False) + + # Account status + email_verified = Column(Boolean, default=False) + is_active = Column(Boolean, default=True) + avatar_url = Column(String(500), nullable=True) + + # Subscription info + plan = Column(Enum(PlanType), default=PlanType.FREE) + subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE) + + # Stripe integration + stripe_customer_id = Column(String(255), nullable=True, index=True) + stripe_subscription_id = Column(String(255), nullable=True) + + # Usage tracking (reset monthly) + docs_translated_this_month = Column(Integer, default=0) + pages_translated_this_month = Column(Integer, default=0) + api_calls_this_month = Column(Integer, default=0) + extra_credits = Column(Integer, default=0) # Purchased credits + usage_reset_date = Column(DateTime, default=datetime.utcnow) + + # Timestamps + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + last_login_at = Column(DateTime, nullable=True) + + # Relationships + translations = relationship("Translation", back_populates="user", lazy="dynamic") + api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic") + + # Indexes + __table_args__ = ( + Index('ix_users_email_active', 'email', 'is_active'), + Index('ix_users_stripe_customer', 'stripe_customer_id'), + ) + + def to_dict(self) -> 
dict: + """Convert user to dictionary for API response""" + return { + "id": self.id, + "email": self.email, + "name": self.name, + "avatar_url": self.avatar_url, + "plan": self.plan.value if self.plan else "free", + "subscription_status": self.subscription_status.value if self.subscription_status else "active", + "docs_translated_this_month": self.docs_translated_this_month, + "pages_translated_this_month": self.pages_translated_this_month, + "api_calls_this_month": self.api_calls_this_month, + "extra_credits": self.extra_credits, + "email_verified": self.email_verified, + "created_at": self.created_at.isoformat() if self.created_at else None, + } + + +class Translation(Base): + """Translation history for analytics and billing""" + __tablename__ = "translations" + + id = Column(String(36), primary_key=True, default=generate_uuid) + user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + + # File info + original_filename = Column(String(255), nullable=False) + file_type = Column(String(10), nullable=False) # xlsx, docx, pptx + file_size_bytes = Column(BigInteger, default=0) + page_count = Column(Integer, default=0) + + # Translation details + source_language = Column(String(10), default="auto") + target_language = Column(String(10), nullable=False) + provider = Column(String(50), nullable=False) # google, deepl, ollama, etc. + + # Processing info + status = Column(String(20), default="pending") # pending, processing, completed, failed + error_message = Column(Text, nullable=True) + processing_time_ms = Column(Integer, nullable=True) + + # Cost tracking (for paid providers) + characters_translated = Column(Integer, default=0) + estimated_cost_usd = Column(Float, default=0.0) + + # Timestamps + created_at = Column(DateTime, default=datetime.utcnow) + completed_at = Column(DateTime, nullable=True) + + # Relationship + user = relationship("User", back_populates="translations") + + # Indexes + __table_args__ = ( + Index('ix_translations_user_date', 'user_id', 'created_at'), + Index('ix_translations_status', 'status'), + ) + + def to_dict(self) -> dict: + return { + "id": self.id, + "original_filename": self.original_filename, + "file_type": self.file_type, + "file_size_bytes": self.file_size_bytes, + "page_count": self.page_count, + "source_language": self.source_language, + "target_language": self.target_language, + "provider": self.provider, + "status": self.status, + "processing_time_ms": self.processing_time_ms, + "characters_translated": self.characters_translated, + "created_at": self.created_at.isoformat() if self.created_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + } + + +class ApiKey(Base): + """API keys for programmatic access""" + __tablename__ = "api_keys" + + id = Column(String(36), primary_key=True, default=generate_uuid) + user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + + # Key info + name = Column(String(100), nullable=False) # User-friendly name + key_hash = Column(String(255), nullable=False) # SHA256 of the key + key_prefix = Column(String(10), nullable=False) # First 8 chars for identification + + # Permissions + is_active = Column(Boolean, default=True) + scopes = Column(JSON, default=list) # ["translate", "read", "write"] + + # Usage tracking + last_used_at = Column(DateTime, nullable=True) + usage_count = Column(Integer, default=0) + + # Timestamps + created_at = Column(DateTime, default=datetime.utcnow) + expires_at = Column(DateTime, 
nullable=True) + + # Relationship + user = relationship("User", back_populates="api_keys") + + # Indexes + __table_args__ = ( + Index('ix_api_keys_prefix', 'key_prefix'), + Index('ix_api_keys_hash', 'key_hash'), + ) + + def to_dict(self) -> dict: + return { + "id": self.id, + "name": self.name, + "key_prefix": self.key_prefix, + "is_active": self.is_active, + "scopes": self.scopes, + "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None, + "usage_count": self.usage_count, + "created_at": self.created_at.isoformat() if self.created_at else None, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + } + + +class UsageLog(Base): + """Daily usage aggregation for billing and analytics""" + __tablename__ = "usage_logs" + + id = Column(String(36), primary_key=True, default=generate_uuid) + user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + + # Date (for daily aggregation) + date = Column(DateTime, nullable=False, index=True) + + # Aggregated counts + documents_count = Column(Integer, default=0) + pages_count = Column(Integer, default=0) + characters_count = Column(BigInteger, default=0) + api_calls_count = Column(Integer, default=0) + + # By provider breakdown (JSON) + provider_breakdown = Column(JSON, default=dict) + + # Indexes + __table_args__ = ( + Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True), + ) + + +class PaymentHistory(Base): + """Payment and invoice history""" + __tablename__ = "payment_history" + + id = Column(String(36), primary_key=True, default=generate_uuid) + user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + + # Stripe info + stripe_payment_intent_id = Column(String(255), nullable=True) + stripe_invoice_id = Column(String(255), nullable=True) + + # Payment details + amount_cents = Column(Integer, nullable=False) + currency = Column(String(3), default="usd") + payment_type = Column(String(50), nullable=False) # subscription, credits, one_time + status = Column(String(20), nullable=False) # succeeded, failed, pending, refunded + + # Description + description = Column(String(255), nullable=True) + + # Timestamps + created_at = Column(DateTime, default=datetime.utcnow) + + # Indexes + __table_args__ = ( + Index('ix_payment_history_user', 'user_id', 'created_at'), + ) diff --git a/database/repositories.py b/database/repositories.py new file mode 100644 index 0000000..240aa84 --- /dev/null +++ b/database/repositories.py @@ -0,0 +1,341 @@ +""" +Repository layer for database operations +Provides clean interface for CRUD operations +""" +import hashlib +import secrets +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any + +from sqlalchemy.orm import Session +from sqlalchemy import and_, func, or_ + +from database.models import ( + User, Translation, ApiKey, UsageLog, PaymentHistory, + PlanType, SubscriptionStatus +) + + +class UserRepository: + """Repository for User database operations""" + + def __init__(self, db: Session): + self.db = db + + def get_by_id(self, user_id: str) -> Optional[User]: + """Get user by ID""" + return self.db.query(User).filter(User.id == user_id).first() + + def get_by_email(self, email: str) -> Optional[User]: + """Get user by email (case-insensitive)""" + return self.db.query(User).filter( + func.lower(User.email) == email.lower() + ).first() + + def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]: + """Get user by Stripe customer ID""" + return 
self.db.query(User).filter( + User.stripe_customer_id == stripe_customer_id + ).first() + + def create( + self, + email: str, + name: str, + password_hash: str, + plan: PlanType = PlanType.FREE + ) -> User: + """Create a new user""" + user = User( + email=email.lower(), + name=name, + password_hash=password_hash, + plan=plan, + subscription_status=SubscriptionStatus.ACTIVE, + ) + self.db.add(user) + self.db.commit() + self.db.refresh(user) + return user + + def update(self, user_id: str, **kwargs) -> Optional[User]: + """Update user fields""" + user = self.get_by_id(user_id) + if not user: + return None + + for key, value in kwargs.items(): + if hasattr(user, key): + setattr(user, key, value) + + user.updated_at = datetime.utcnow() + self.db.commit() + self.db.refresh(user) + return user + + def delete(self, user_id: str) -> bool: + """Delete a user""" + user = self.get_by_id(user_id) + if not user: + return False + + self.db.delete(user) + self.db.commit() + return True + + def increment_usage( + self, + user_id: str, + docs: int = 0, + pages: int = 0, + api_calls: int = 0 + ) -> Optional[User]: + """Increment usage counters""" + user = self.get_by_id(user_id) + if not user: + return None + + # Check if usage needs to be reset (monthly) + if user.usage_reset_date: + now = datetime.utcnow() + if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year: + user.docs_translated_this_month = 0 + user.pages_translated_this_month = 0 + user.api_calls_this_month = 0 + user.usage_reset_date = now + + user.docs_translated_this_month += docs + user.pages_translated_this_month += pages + user.api_calls_this_month += api_calls + + self.db.commit() + self.db.refresh(user) + return user + + def add_credits(self, user_id: str, credits: int) -> Optional[User]: + """Add extra credits to user""" + user = self.get_by_id(user_id) + if not user: + return None + + user.extra_credits += credits + self.db.commit() + self.db.refresh(user) + return user + + def use_credits(self, user_id: str, credits: int) -> bool: + """Use credits from user balance""" + user = self.get_by_id(user_id) + if not user or user.extra_credits < credits: + return False + + user.extra_credits -= credits + self.db.commit() + return True + + def get_all_users( + self, + skip: int = 0, + limit: int = 100, + plan: Optional[PlanType] = None + ) -> List[User]: + """Get all users with pagination""" + query = self.db.query(User) + if plan: + query = query.filter(User.plan == plan) + return query.offset(skip).limit(limit).all() + + def count_users(self, plan: Optional[PlanType] = None) -> int: + """Count total users""" + query = self.db.query(func.count(User.id)) + if plan: + query = query.filter(User.plan == plan) + return query.scalar() + + +class TranslationRepository: + """Repository for Translation database operations""" + + def __init__(self, db: Session): + self.db = db + + def create( + self, + user_id: str, + original_filename: str, + file_type: str, + target_language: str, + provider: str, + source_language: str = "auto", + file_size_bytes: int = 0, + page_count: int = 0, + ) -> Translation: + """Create a new translation record""" + translation = Translation( + user_id=user_id, + original_filename=original_filename, + file_type=file_type, + file_size_bytes=file_size_bytes, + page_count=page_count, + source_language=source_language, + target_language=target_language, + provider=provider, + status="pending", + ) + self.db.add(translation) + self.db.commit() + self.db.refresh(translation) + return translation + + 
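+    # Typical lifecycle (sketch): create a pending record, then update it as
+    # the translation job progresses, e.g.
+    #
+    #   t = repo.create(user_id=uid, original_filename="report.docx",
+    #                   file_type="docx", target_language="de", provider="deepl")
+    #   repo.update_status(t.id, "completed", processing_time_ms=1200)
+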
def update_status( + self, + translation_id: str, + status: str, + error_message: Optional[str] = None, + processing_time_ms: Optional[int] = None, + characters_translated: Optional[int] = None, + ) -> Optional[Translation]: + """Update translation status""" + translation = self.db.query(Translation).filter( + Translation.id == translation_id + ).first() + + if not translation: + return None + + translation.status = status + if error_message: + translation.error_message = error_message + if processing_time_ms: + translation.processing_time_ms = processing_time_ms + if characters_translated: + translation.characters_translated = characters_translated + if status == "completed": + translation.completed_at = datetime.utcnow() + + self.db.commit() + self.db.refresh(translation) + return translation + + def get_user_translations( + self, + user_id: str, + skip: int = 0, + limit: int = 50, + status: Optional[str] = None, + ) -> List[Translation]: + """Get user's translation history""" + query = self.db.query(Translation).filter(Translation.user_id == user_id) + if status: + query = query.filter(Translation.status == status) + return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all() + + def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]: + """Get user's translation statistics""" + since = datetime.utcnow() - timedelta(days=days) + + result = self.db.query( + func.count(Translation.id).label("total_translations"), + func.sum(Translation.page_count).label("total_pages"), + func.sum(Translation.characters_translated).label("total_characters"), + ).filter( + and_( + Translation.user_id == user_id, + Translation.created_at >= since, + Translation.status == "completed", + ) + ).first() + + return { + "total_translations": result.total_translations or 0, + "total_pages": result.total_pages or 0, + "total_characters": result.total_characters or 0, + "period_days": days, + } + + +class ApiKeyRepository: + """Repository for API Key database operations""" + + def __init__(self, db: Session): + self.db = db + + @staticmethod + def hash_key(key: str) -> str: + """Hash an API key""" + return hashlib.sha256(key.encode()).hexdigest() + + def create( + self, + user_id: str, + name: str, + scopes: List[str] = None, + expires_in_days: Optional[int] = None, + ) -> tuple[ApiKey, str]: + """Create a new API key. 
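Only a SHA-256 hash of the key is stored, so surface the raw key to the caller once, at creation time.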
Returns (ApiKey, raw_key)""" + # Generate a secure random key + raw_key = f"tr_{secrets.token_urlsafe(32)}" + key_hash = self.hash_key(raw_key) + key_prefix = raw_key[:10] + + expires_at = None + if expires_in_days: + expires_at = datetime.utcnow() + timedelta(days=expires_in_days) + + api_key = ApiKey( + user_id=user_id, + name=name, + key_hash=key_hash, + key_prefix=key_prefix, + scopes=scopes or ["translate"], + expires_at=expires_at, + ) + self.db.add(api_key) + self.db.commit() + self.db.refresh(api_key) + + return api_key, raw_key + + def get_by_key(self, raw_key: str) -> Optional[ApiKey]: + """Get API key by raw key value""" + key_hash = self.hash_key(raw_key) + api_key = self.db.query(ApiKey).filter( + and_( + ApiKey.key_hash == key_hash, + ApiKey.is_active == True, + ) + ).first() + + if api_key: + # Check expiration + if api_key.expires_at and api_key.expires_at < datetime.utcnow(): + return None + + # Update last used + api_key.last_used_at = datetime.utcnow() + api_key.usage_count += 1 + self.db.commit() + + return api_key + + def get_user_keys(self, user_id: str) -> List[ApiKey]: + """Get all API keys for a user""" + return self.db.query(ApiKey).filter( + ApiKey.user_id == user_id + ).order_by(ApiKey.created_at.desc()).all() + + def revoke(self, key_id: str, user_id: str) -> bool: + """Revoke an API key""" + api_key = self.db.query(ApiKey).filter( + and_( + ApiKey.id == key_id, + ApiKey.user_id == user_id, + ) + ).first() + + if not api_key: + return False + + api_key.is_active = False + self.db.commit() + return True diff --git a/docker-compose.yml b/docker-compose.yml index e307cd7..f76f2a0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,6 +4,35 @@ version: '3.8' services: + # =========================================== + # PostgreSQL Database + # =========================================== + postgres: + image: postgres:16-alpine + container_name: translate-postgres + restart: unless-stopped + environment: + - POSTGRES_USER=${POSTGRES_USER:-translate} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-translate_secret_123} + - POSTGRES_DB=${POSTGRES_DB:-translate_db} + - PGDATA=/var/lib/postgresql/data/pgdata + volumes: + - postgres_data:/var/lib/postgresql/data + networks: + - translate-network + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-translate} -d ${POSTGRES_DB:-translate_db}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + deploy: + resources: + limits: + memory: 512M + reservations: + memory: 128M + # =========================================== # Backend API Service # =========================================== @@ -14,23 +43,45 @@ services: container_name: translate-backend restart: unless-stopped environment: + # Database + - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-translate}:${POSTGRES_PASSWORD:-translate_secret_123}@postgres:5432/${POSTGRES_DB:-translate_db} + # Redis + - REDIS_URL=redis://redis:6379/0 + # Translation Services - TRANSLATION_SERVICE=${TRANSLATION_SERVICE:-ollama} - OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://ollama:11434} - OLLAMA_MODEL=${OLLAMA_MODEL:-llama3} - DEEPL_API_KEY=${DEEPL_API_KEY:-} - OPENAI_API_KEY=${OPENAI_API_KEY:-} + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-} + # File Limits - MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50} + # Rate Limiting - RATE_LIMIT_REQUESTS_PER_MINUTE=${RATE_LIMIT_REQUESTS_PER_MINUTE:-60} - RATE_LIMIT_TRANSLATIONS_PER_MINUTE=${RATE_LIMIT_TRANSLATIONS_PER_MINUTE:-10} - - ADMIN_USERNAME=${ADMIN_USERNAME:-admin} - - ADMIN_PASSWORD=${ADMIN_PASSWORD:-changeme123} 
- - CORS_ORIGINS=${CORS_ORIGINS:-*} + # Admin Auth (CHANGE IN PRODUCTION!) + - ADMIN_USERNAME=${ADMIN_USERNAME} + - ADMIN_PASSWORD=${ADMIN_PASSWORD} + # Security + - JWT_SECRET=${JWT_SECRET} + - CORS_ORIGINS=${CORS_ORIGINS:-https://yourdomain.com} + # Stripe Payments + - STRIPE_SECRET_KEY=${STRIPE_SECRET_KEY:-} + - STRIPE_WEBHOOK_SECRET=${STRIPE_WEBHOOK_SECRET:-} + - STRIPE_STARTER_PRICE_ID=${STRIPE_STARTER_PRICE_ID:-} + - STRIPE_PRO_PRICE_ID=${STRIPE_PRO_PRICE_ID:-} + - STRIPE_BUSINESS_PRICE_ID=${STRIPE_BUSINESS_PRICE_ID:-} volumes: - uploads_data:/app/uploads - outputs_data:/app/outputs - logs_data:/app/logs networks: - translate-network + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8000/health"] interval: 30s @@ -117,7 +168,7 @@ services: - with-ollama # =========================================== - # Redis (Optional - For caching & sessions) + # Redis (Caching & Sessions) # =========================================== redis: image: redis:7-alpine @@ -130,11 +181,9 @@ services: - translate-network healthcheck: test: ["CMD", "redis-cli", "ping"] - interval: 30s - timeout: 10s - retries: 3 - profiles: - - with-cache + interval: 10s + timeout: 5s + retries: 5 # =========================================== # Prometheus (Optional - Monitoring) @@ -190,6 +239,8 @@ networks: # Volumes # =========================================== volumes: + postgres_data: + driver: local uploads_data: driver: local outputs_data: diff --git a/main.py b/main.py index 7e1ed53..a4895c4 100644 --- a/main.py +++ b/main.py @@ -204,6 +204,18 @@ async def lifespan(app: FastAPI): # Startup logger.info("Starting Document Translation API...") config.ensure_directories() + + # Initialize database + try: + from database.connection import init_db, check_db_connection + init_db() + if check_db_connection(): + logger.info("✅ Database connection verified") + else: + logger.warning("⚠️ Database connection check failed") + except Exception as e: + logger.warning(f"⚠️ Database initialization skipped: {e}") + await cleanup_manager.start() logger.info("API ready to accept requests") diff --git a/requirements.txt b/requirements.txt index 6211f77..10eda51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,8 +28,7 @@ stripe==7.0.0 # Session storage & caching (optional but recommended for production) redis==5.0.1 -# Database (optional but recommended for production) -# sqlalchemy==2.0.25 -# asyncpg==0.29.0 # PostgreSQL async driver -# alembic==1.13.1 # Database migrations - +# Database (recommended for production) +sqlalchemy==2.0.25 +psycopg2-binary==2.9.9 # PostgreSQL driver +alembic==1.13.1 # Database migrations diff --git a/scripts/migrate_to_db.py b/scripts/migrate_to_db.py new file mode 100644 index 0000000..a29d99f --- /dev/null +++ b/scripts/migrate_to_db.py @@ -0,0 +1,160 @@ +""" +Migration script to move data from JSON files to database +Run this once to migrate existing users to the new database system +""" +import json +import os +import sys +from pathlib import Path +from datetime import datetime + +# Add parent directory to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from database.connection import init_db, get_db_session +from database.repositories import UserRepository +from database.models import PlanType, SubscriptionStatus + + +def migrate_users_from_json(): + """Migrate users from JSON file to database""" + json_path = Path("data/users.json") + + if not json_path.exists(): + print("No 
users.json found, nothing to migrate") + return 0 + + # Initialize database + print("Initializing database tables...") + init_db() + + # Load JSON data + with open(json_path, 'r') as f: + users_data = json.load(f) + + print(f"Found {len(users_data)} users to migrate") + + migrated = 0 + skipped = 0 + errors = 0 + + with get_db_session() as db: + repo = UserRepository(db) + + for user_id, user_data in users_data.items(): + try: + # Check if user already exists + existing = repo.get_by_email(user_data.get('email', '')) + if existing: + print(f" Skipping {user_data.get('email')} - already exists") + skipped += 1 + continue + + # Map plan string to enum + plan_str = user_data.get('plan', 'free') + try: + plan = PlanType(plan_str) + except ValueError: + plan = PlanType.FREE + + # Map subscription status + status_str = user_data.get('subscription_status', 'active') + try: + status = SubscriptionStatus(status_str) + except ValueError: + status = SubscriptionStatus.ACTIVE + + # Create user with original ID + from database.models import User + user = User( + id=user_id, + email=user_data.get('email', '').lower(), + name=user_data.get('name', ''), + password_hash=user_data.get('password_hash', ''), + email_verified=user_data.get('email_verified', False), + avatar_url=user_data.get('avatar_url'), + plan=plan, + subscription_status=status, + stripe_customer_id=user_data.get('stripe_customer_id'), + stripe_subscription_id=user_data.get('stripe_subscription_id'), + docs_translated_this_month=user_data.get('docs_translated_this_month', 0), + pages_translated_this_month=user_data.get('pages_translated_this_month', 0), + api_calls_this_month=user_data.get('api_calls_this_month', 0), + extra_credits=user_data.get('extra_credits', 0), + ) + + # Parse dates + if user_data.get('created_at'): + try: + user.created_at = datetime.fromisoformat(user_data['created_at'].replace('Z', '+00:00')) + except: + pass + + if user_data.get('updated_at'): + try: + user.updated_at = datetime.fromisoformat(user_data['updated_at'].replace('Z', '+00:00')) + except: + pass + + db.add(user) + db.commit() + + print(f" Migrated: {user.email}") + migrated += 1 + + except Exception as e: + print(f" Error migrating {user_data.get('email', user_id)}: {e}") + errors += 1 + db.rollback() + + print(f"\nMigration complete:") + print(f" Migrated: {migrated}") + print(f" Skipped: {skipped}") + print(f" Errors: {errors}") + + # Backup original file + if migrated > 0: + backup_path = json_path.with_suffix('.json.bak') + os.rename(json_path, backup_path) + print(f"\nOriginal file backed up to: {backup_path}") + + return migrated + + +def verify_migration(): + """Verify the migration was successful""" + from database.connection import get_db_session + from database.repositories import UserRepository + + with get_db_session() as db: + repo = UserRepository(db) + count = repo.count_users() + print(f"\nDatabase now contains {count} users") + + # List first 5 users + users = repo.get_all_users(limit=5) + if users: + print("\nSample users:") + for user in users: + print(f" - {user.email} ({user.plan.value})") + + +if __name__ == "__main__": + print("=" * 50) + print("JSON to Database Migration Script") + print("=" * 50) + + # Check environment + db_url = os.getenv("DATABASE_URL", "") + if db_url: + print(f"Database: PostgreSQL") + else: + print(f"Database: SQLite (development)") + + print() + + # Run migration + migrate_users_from_json() + + # Verify + verify_migration() diff --git a/services/auth_service.py b/services/auth_service.py index 
68a3ed2..26c5fa0 100644 --- a/services/auth_service.py +++ b/services/auth_service.py @@ -1,5 +1,9 @@ """ Authentication service with JWT tokens and password hashing + +This service provides user authentication with automatic backend selection: +- If DATABASE_URL is configured: Uses PostgreSQL database +- Otherwise: Falls back to JSON file storage (development mode) """ import os import secrets @@ -8,6 +12,9 @@ from datetime import datetime, timedelta from typing import Optional, Dict, Any import json from pathlib import Path +import logging + +logger = logging.getLogger(__name__) # Try to import optional dependencies try: @@ -15,6 +22,7 @@ try: JWT_AVAILABLE = True except ImportError: JWT_AVAILABLE = False + logger.warning("PyJWT not installed. Using fallback token encoding.") try: from passlib.context import CryptContext @@ -22,17 +30,37 @@ try: PASSLIB_AVAILABLE = True except ImportError: PASSLIB_AVAILABLE = False + logger.warning("passlib not installed. Using SHA256 fallback for password hashing.") + +# Check if database is configured +DATABASE_URL = os.getenv("DATABASE_URL", "") +USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql")) + +if USE_DATABASE: + try: + from database.repositories import UserRepository + from database.connection import get_sync_session, init_db as _init_db + from database import models as db_models + DATABASE_AVAILABLE = True + logger.info("Database backend enabled for authentication") + except ImportError as e: + DATABASE_AVAILABLE = False + USE_DATABASE = False + logger.warning(f"Database modules not available: {e}. Using JSON storage.") +else: + DATABASE_AVAILABLE = False + logger.info("Using JSON file storage for authentication (DATABASE_URL not configured)") from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS # Configuration -SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32)) +SECRET_KEY = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))) ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_HOURS = 24 REFRESH_TOKEN_EXPIRE_DAYS = 30 -# Simple file-based storage (replace with database in production) +# Simple file-based storage (used when database is not configured) USERS_FILE = Path("data/users.json") USERS_FILE.parent.mkdir(exist_ok=True) @@ -117,7 +145,7 @@ def verify_token(token: str) -> Optional[Dict[str, Any]]: def load_users() -> Dict[str, Dict]: - """Load users from file storage""" + """Load users from file storage (JSON backend only)""" if USERS_FILE.exists(): try: with open(USERS_FILE, 'r') as f: @@ -128,54 +156,109 @@ def load_users() -> Dict[str, Dict]: def save_users(users: Dict[str, Dict]): - """Save users to file storage""" + """Save users to file storage (JSON backend only)""" with open(USERS_FILE, 'w') as f: json.dump(users, f, indent=2, default=str) +def _db_user_to_model(db_user) -> User: + """Convert database user model to Pydantic User model""" + return User( + id=str(db_user.id), + email=db_user.email, + name=db_user.name or "", + password_hash=db_user.password_hash, + avatar_url=db_user.avatar_url, + plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE, + subscription_status=SubscriptionStatus(db_user.subscription_status) if db_user.subscription_status else SubscriptionStatus.ACTIVE, + stripe_customer_id=db_user.stripe_customer_id, + stripe_subscription_id=db_user.stripe_subscription_id, + docs_translated_this_month=db_user.docs_translated_this_month or 0, + pages_translated_this_month=db_user.pages_translated_this_month or 0, + 
api_calls_this_month=db_user.api_calls_this_month or 0,
+        extra_credits=db_user.extra_credits or 0,
+        usage_reset_date=db_user.usage_reset_date or datetime.utcnow(),
+        default_source_lang=getattr(db_user, "default_source_lang", None) or "en",  # column not in DB model yet
+        default_target_lang=getattr(db_user, "default_target_lang", None) or "es",
+        default_provider=getattr(db_user, "default_provider", None) or "google",
+        created_at=db_user.created_at or datetime.utcnow(),
+        updated_at=db_user.updated_at,
+    )
+
+
 def get_user_by_email(email: str) -> Optional[User]:
     """Get a user by email"""
-    users = load_users()
-    for user_data in users.values():
-        if user_data.get("email", "").lower() == email.lower():
-            return User(**user_data)
-    return None
+    if USE_DATABASE and DATABASE_AVAILABLE:
+        with get_sync_session() as session:
+            repo = UserRepository(session)
+            db_user = repo.get_by_email(email)
+            if db_user:
+                return _db_user_to_model(db_user)
+            return None
+    else:
+        users = load_users()
+        for user_data in users.values():
+            if user_data.get("email", "").lower() == email.lower():
+                return User(**user_data)
+        return None
 
 
 def get_user_by_id(user_id: str) -> Optional[User]:
     """Get a user by ID"""
-    users = load_users()
-    if user_id in users:
-        return User(**users[user_id])
-    return None
+    if USE_DATABASE and DATABASE_AVAILABLE:
+        with get_sync_session() as session:
+            repo = UserRepository(session)
+            db_user = repo.get_by_id(user_id)
+            if db_user:
+                return _db_user_to_model(db_user)
+            return None
+    else:
+        users = load_users()
+        if user_id in users:
+            return User(**users[user_id])
+        return None
 
 
 def create_user(user_create: UserCreate) -> User:
     """Create a new user"""
-    users = load_users()
-
     # Check if email exists
     if get_user_by_email(user_create.email):
         raise ValueError("Email already registered")
 
-    # Generate user ID
-    user_id = secrets.token_urlsafe(16)
-
-    # Create user
-    user = User(
-        id=user_id,
-        email=user_create.email,
-        name=user_create.name,
-        password_hash=hash_password(user_create.password),
-        plan=PlanType.FREE,
-        subscription_status=SubscriptionStatus.ACTIVE,
-    )
-
-    # Save to storage
-    users[user_id] = user.model_dump()
-    save_users(users)
-
-    return user
+    if USE_DATABASE and DATABASE_AVAILABLE:
+        with get_sync_session() as session:
+            repo = UserRepository(session)
+            db_user = repo.create(
+                email=user_create.email,
+                name=user_create.name,
+                password_hash=hash_password(user_create.password),
+                plan=db_models.PlanType.FREE,
+                # UserRepository.create() sets subscription_status=ACTIVE itself
+            )
+            session.commit()
+            session.refresh(db_user)
+            return _db_user_to_model(db_user)
+    else:
+        users = load_users()
+
+        # Generate user ID
+        user_id = secrets.token_urlsafe(16)
+
+        # Create user
+        user = User(
+            id=user_id,
+            email=user_create.email,
+            name=user_create.name,
+            password_hash=hash_password(user_create.password),
+            plan=PlanType.FREE,
+            subscription_status=SubscriptionStatus.ACTIVE,
+        )
+
+        # Save to storage
+        users[user_id] = user.model_dump()
+        save_users(users)
+
+        return user
 
 
 def authenticate_user(email: str, password: str) -> Optional[User]:
@@ -190,15 +273,25 @@ def authenticate_user(email: str, password: str) -> Optional[User]:
 
 def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
     """Update a user's data"""
-    users = load_users()
-    if user_id not in users:
-        return None
-
-    users[user_id].update(updates)
-    users[user_id]["updated_at"] = datetime.utcnow().isoformat()
-    save_users(users)
-
-    return User(**users[user_id])
+    if USE_DATABASE and DATABASE_AVAILABLE:
+        with get_sync_session() as session:
+            repo = UserRepository(session)
+            db_user = repo.update(user_id, **updates)
+            if db_user:
+                session.commit()
+                session.refresh(db_user)
+                return _db_user_to_model(db_user)
+            return None
+    else:
+        users = load_users()
+        if user_id not in users:
+            return None
+
+        users[user_id].update(updates)
+        users[user_id]["updated_at"] = datetime.utcnow().isoformat()
+        save_users(users)
+
+        return User(**users[user_id])
 
 
 def check_usage_limits(user: User) -> Dict[str, Any]:
@@ -212,7 +305,7 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
             "docs_translated_this_month": 0,
             "pages_translated_this_month": 0,
             "api_calls_this_month": 0,
-            "usage_reset_date": now.isoformat()
+            "usage_reset_date": now.isoformat() if not USE_DATABASE else now
         })
         user.docs_translated_this_month = 0
         user.pages_translated_this_month = 0
@@ -248,8 +341,8 @@ def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> bool
     if use_credits:
         updates["extra_credits"] = max(0, user.extra_credits - pages_count)
 
-    update_user(user_id, updates)
-    return True
+    result = update_user(user_id, updates)
+    return result is not None
 
 
 def add_credits(user_id: str, credits: int) -> bool:
@@ -258,5 +351,14 @@ def add_credits(user_id: str, credits: int) -> bool:
     if not user:
         return False
 
-    update_user(user_id, {"extra_credits": user.extra_credits + credits})
-    return True
+    result = update_user(user_id, {"extra_credits": user.extra_credits + credits})
+    return result is not None
+
+
+def init_database():
+    """Initialize the database (call on application startup)"""
+    if USE_DATABASE and DATABASE_AVAILABLE:
+        _init_db()
+        logger.info("Database initialized successfully")
+    else:
+        logger.info("Using JSON file storage")
diff --git a/services/auth_service_db.py b/services/auth_service_db.py
new file mode 100644
index 0000000..c678890
--- /dev/null
+++ b/services/auth_service_db.py
@@ -0,0 +1,261 @@
+"""
+Database-backed authentication service
+Replaces JSON file storage with SQLAlchemy
+"""
+import os
+import secrets
+import hashlib
+from datetime import datetime, timedelta
+from typing import Optional, Dict, Any
+import logging
+
+# Try to import optional dependencies
+try:
+    import jwt
+    JWT_AVAILABLE = True
+except ImportError:
+    JWT_AVAILABLE = False
+    import json
+    import base64
+
+try:
+    from passlib.context import CryptContext
+    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+    PASSLIB_AVAILABLE = True
+except ImportError:
+    PASSLIB_AVAILABLE = False
+
+from database.connection import get_db_session
+from database.repositories import UserRepository
+from database.models import User, PlanType, SubscriptionStatus
+from models.subscription import PLANS
+
+logger = logging.getLogger(__name__)
+
+# Configuration
+SECRET_KEY = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32)))  # same lookup order as auth_service
+ALGORITHM = "HS256"
+ACCESS_TOKEN_EXPIRE_HOURS = 24
+REFRESH_TOKEN_EXPIRE_DAYS = 30
+
+
+def hash_password(password: str) -> str:
+    """Hash a password using bcrypt, or fall back to salted SHA-256"""
+    if PASSLIB_AVAILABLE:
+        return pwd_context.hash(password)
+    else:
+        salt = secrets.token_hex(16)
+        hashed = hashlib.sha256(f"{salt}{password}".encode()).hexdigest()
+        return f"sha256${salt}${hashed}"
+
+
+def verify_password(plain_password: str, hashed_password: str) -> bool:
+    """Verify a password against its hash"""
+    if PASSLIB_AVAILABLE and not hashed_password.startswith("sha256$"):
+        return pwd_context.verify(plain_password, hashed_password)
+    else:
+        parts = hashed_password.split("$")
+        if len(parts) == 3 and parts[0] == "sha256":
+            salt = parts[1]
+            expected_hash = parts[2]
hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest() + return secrets.compare_digest(actual_hash, expected_hash) + return False + + +def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str: + """Create a JWT access token""" + expire = datetime.utcnow() + (expires_delta or timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)) + + if not JWT_AVAILABLE: + token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": "access"} + return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode() + + to_encode = {"sub": user_id, "exp": expire, "type": "access"} + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + +def create_refresh_token(user_id: str) -> str: + """Create a JWT refresh token""" + expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + + if not JWT_AVAILABLE: + token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": "refresh"} + return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode() + + to_encode = {"sub": user_id, "exp": expire, "type": "refresh"} + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + +def verify_token(token: str) -> Optional[Dict[str, Any]]: + """Verify a JWT token and return payload""" + if not JWT_AVAILABLE: + try: + data = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) + exp = datetime.fromisoformat(data["exp"]) + if exp < datetime.utcnow(): + return None + return {"sub": data["user_id"], "type": data.get("type", "access")} + except Exception: + return None + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + return payload + except jwt.ExpiredSignatureError: + return None + except jwt.JWTError: + return None + + +def create_user(email: str, name: str, password: str) -> User: + """Create a new user in the database""" + with get_db_session() as db: + repo = UserRepository(db) + + # Check if email already exists + existing = repo.get_by_email(email) + if existing: + raise ValueError("Email already registered") + + password_hash = hash_password(password) + user = repo.create( + email=email, + name=name, + password_hash=password_hash, + plan=PlanType.FREE, + ) + return user + + +def authenticate_user(email: str, password: str) -> Optional[User]: + """Authenticate user and return user object if valid""" + with get_db_session() as db: + repo = UserRepository(db) + user = repo.get_by_email(email) + + if not user: + return None + + if not verify_password(password, user.password_hash): + return None + + # Update last login + repo.update(user.id, last_login_at=datetime.utcnow()) + return user + + +def get_user_by_id(user_id: str) -> Optional[User]: + """Get user by ID from database""" + with get_db_session() as db: + repo = UserRepository(db) + return repo.get_by_id(user_id) + + +def get_user_by_email(email: str) -> Optional[User]: + """Get user by email from database""" + with get_db_session() as db: + repo = UserRepository(db) + return repo.get_by_email(email) + + +def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]: + """Update user fields in database""" + with get_db_session() as db: + repo = UserRepository(db) + return repo.update(user_id, **updates) + + +def add_credits(user_id: str, credits: int) -> bool: + """Add credits to user account""" + with get_db_session() as db: + repo = UserRepository(db) + result = repo.add_credits(user_id, credits) + return result is not None + + +def use_credits(user_id: str, credits: int) -> bool: + """Use credits from user account""" + with 
get_db_session() as db: + repo = UserRepository(db) + return repo.use_credits(user_id, credits) + + +def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool: + """Increment user usage counters""" + with get_db_session() as db: + repo = UserRepository(db) + result = repo.increment_usage(user_id, docs=docs, pages=pages, api_calls=api_calls) + return result is not None + + +def check_usage_limits(user_id: str) -> Dict[str, Any]: + """Check if user is within their plan limits""" + with get_db_session() as db: + repo = UserRepository(db) + user = repo.get_by_id(user_id) + + if not user: + return {"allowed": False, "reason": "User not found"} + + plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE]) + + # Check document limit + docs_limit = plan_config["docs_per_month"] + if docs_limit > 0 and user.docs_translated_this_month >= docs_limit: + # Check if user has extra credits + if user.extra_credits <= 0: + return { + "allowed": False, + "reason": "Monthly document limit reached", + "limit": docs_limit, + "used": user.docs_translated_this_month, + } + + return { + "allowed": True, + "docs_remaining": max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1, + "extra_credits": user.extra_credits, + } + + +def get_user_usage_stats(user_id: str) -> Dict[str, Any]: + """Get detailed usage statistics for a user""" + with get_db_session() as db: + repo = UserRepository(db) + user = repo.get_by_id(user_id) + + if not user: + return {} + + plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE]) + + return { + "docs_used": user.docs_translated_this_month, + "docs_limit": plan_config["docs_per_month"], + "docs_remaining": max(0, plan_config["docs_per_month"] - user.docs_translated_this_month) if plan_config["docs_per_month"] > 0 else -1, + "pages_used": user.pages_translated_this_month, + "extra_credits": user.extra_credits, + "max_pages_per_doc": plan_config["max_pages_per_doc"], + "max_file_size_mb": plan_config["max_file_size_mb"], + "allowed_providers": plan_config["providers"], + "api_access": plan_config.get("api_access", False), + "api_calls_used": user.api_calls_this_month if plan_config.get("api_access") else 0, + "api_calls_limit": plan_config.get("api_calls_per_month", 0), + }
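+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch, not part of the service API). Assumes the
+    # SQLite fallback in database.connection when DATABASE_URL is unset, and
+    # that models.subscription.PLANS is importable.
+    from database.connection import init_db
+    init_db()
+    try:
+        demo = create_user("demo@example.com", "Demo User", "change-me-123")
+    except ValueError:
+        demo = get_user_by_email("demo@example.com")
+    token = create_access_token(demo.id)
+    payload = verify_token(token)
+    assert payload is not None and payload["sub"] == demo.id
+    print(f"auth_service_db smoke test OK for {demo.email}")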