feat: Add PostgreSQL database infrastructure
- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory
- Add database connection management with PostgreSQL/SQLite support
- Add repository layer for CRUD operations
- Add Alembic migration setup with initial migration
- Update auth_service to automatically use database when DATABASE_URL is set
- Update docker-compose.yml with PostgreSQL service and Redis (non-optional)
- Add database migration script (scripts/migrate_to_db.py)
- Update .env.example with database configuration
This commit is contained in:
parent
c4d6cae735
commit
550f3516db
11
.env.example
11
.env.example
@ -72,9 +72,16 @@ MAX_REQUEST_SIZE_MB=100
|
||||
# Request timeout in seconds
|
||||
REQUEST_TIMEOUT_SECONDS=300
|
||||
|
||||
# ============== Database (Production) ==============
|
||||
# ============== Database ==============
|
||||
# PostgreSQL connection string (recommended for production)
|
||||
# DATABASE_URL=postgresql://user:password@localhost:5432/translate_db
|
||||
# Format: postgresql://user:password@host:port/database
|
||||
# DATABASE_URL=postgresql://translate_user:secure_password@localhost:5432/translate_db
|
||||
|
||||
# SQLite path (used when DATABASE_URL is not set - for development)
|
||||
SQLITE_PATH=data/translate.db
|
||||
|
||||
# Enable SQL query logging (for debugging)
|
||||
DATABASE_ECHO=false
|
||||
|
||||
# Redis for sessions and caching (recommended for production)
|
||||
# REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
78
alembic.ini
Normal file
78
alembic.ini
Normal file
@ -0,0 +1,78 @@
|
||||
# Alembic configuration file

[alembic]
# path to migration scripts
script_location = alembic

# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during the 'revision' command,
# regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without a source .py file
# to be detected as revisions in the versions/ directory
# sourceless = false

# version location specification; This defaults to alembic/versions
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions

# version path separator
# version_path_separator = :

# the output encoding used when revision files are written
# output_encoding = utf-8

# NOTE: placeholder only — alembic/env.py overrides this value at runtime
# with DATABASE_URL from the environment (or the SQLite fallback).
sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -q

# --- Logging (standard Python logging.fileConfig format, read by env.py) ---
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
|
||||
86
alembic/env.py
Normal file
86
alembic/env.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""
Alembic environment configuration

Resolves the database URL from the environment (DATABASE_URL, falling back
to a local SQLite file) and wires the project's ORM metadata in for
autogenerate support.
"""
import os
import sys
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# Add parent directory to path so `database.models` is importable when
# alembic is run from the project root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# this is the Alembic Config object
config = context.config

# Interpret the config file for Python logging
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Import models for autogenerate support
from database.models import Base
target_metadata = Base.metadata

# Get database URL from environment (.env files supported via python-dotenv)
from dotenv import load_dotenv
load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL", "")
if not DATABASE_URL:
    SQLITE_PATH = os.getenv("SQLITE_PATH", "data/translate.db")
    DATABASE_URL = f"sqlite:///./{SQLITE_PATH}"

# Override sqlalchemy.url with the environment value.
# BUGFIX: set_main_option values pass through ConfigParser interpolation,
# so a literal '%' (common in URL-encoded passwords) must be escaped as
# '%%' or alembic fails with an InterpolationSyntaxError.
config.set_main_option("sqlalchemy.url", DATABASE_URL.replace("%", "%%"))


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            # compare_type=True lets autogenerate detect column type changes
            compare_type=True,
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
27
alembic/script.py.mako
Normal file
27
alembic/script.py.mako
Normal file
@ -0,0 +1,27 @@
|
||||
## Mako template rendered by "alembic revision" to produce new migration
## files; the ${...} placeholders are substituted at generation time.
"""
Alembic migration script template
${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
|
||||
126
alembic/versions/001_initial.py
Normal file
126
alembic/versions/001_initial.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""Initial database schema

Revision ID: 001_initial
Revises:
Create Date: 2024-12-31

Creates the five core tables (users, translations, api_keys, usage_logs,
payment_history) plus their secondary indexes.
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '001_initial'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Create users table
    # NOTE(review): 'plan' and 'subscription_status' are created as plain
    # String here, while database/models.py declares them as Enum columns —
    # confirm this is intended; autogenerate may report a type diff on
    # PostgreSQL.
    op.create_table(
        'users',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('email', sa.String(255), unique=True, nullable=False),
        sa.Column('name', sa.String(255), nullable=False),
        sa.Column('password_hash', sa.String(255), nullable=False),
        sa.Column('email_verified', sa.Boolean(), default=False),
        sa.Column('is_active', sa.Boolean(), default=True),
        sa.Column('avatar_url', sa.String(500), nullable=True),
        sa.Column('plan', sa.String(20), default='free'),
        sa.Column('subscription_status', sa.String(20), default='active'),
        sa.Column('stripe_customer_id', sa.String(255), nullable=True),
        sa.Column('stripe_subscription_id', sa.String(255), nullable=True),
        sa.Column('docs_translated_this_month', sa.Integer(), default=0),
        sa.Column('pages_translated_this_month', sa.Integer(), default=0),
        sa.Column('api_calls_this_month', sa.Integer(), default=0),
        sa.Column('extra_credits', sa.Integer(), default=0),
        sa.Column('usage_reset_date', sa.DateTime()),
        sa.Column('created_at', sa.DateTime()),
        sa.Column('updated_at', sa.DateTime()),
        sa.Column('last_login_at', sa.DateTime(), nullable=True),
    )
    op.create_index('ix_users_email', 'users', ['email'])
    op.create_index('ix_users_email_active', 'users', ['email', 'is_active'])
    op.create_index('ix_users_stripe_customer', 'users', ['stripe_customer_id'])

    # Create translations table (history rows; cascade-deleted with the user)
    op.create_table(
        'translations',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('original_filename', sa.String(255), nullable=False),
        sa.Column('file_type', sa.String(10), nullable=False),
        sa.Column('file_size_bytes', sa.BigInteger(), default=0),
        sa.Column('page_count', sa.Integer(), default=0),
        sa.Column('source_language', sa.String(10), default='auto'),
        sa.Column('target_language', sa.String(10), nullable=False),
        sa.Column('provider', sa.String(50), nullable=False),
        sa.Column('status', sa.String(20), default='pending'),
        sa.Column('error_message', sa.Text(), nullable=True),
        sa.Column('processing_time_ms', sa.Integer(), nullable=True),
        sa.Column('characters_translated', sa.Integer(), default=0),
        sa.Column('estimated_cost_usd', sa.Float(), default=0.0),
        sa.Column('created_at', sa.DateTime()),
        sa.Column('completed_at', sa.DateTime(), nullable=True),
    )
    op.create_index('ix_translations_user_date', 'translations', ['user_id', 'created_at'])
    op.create_index('ix_translations_status', 'translations', ['status'])

    # Create api_keys table (stores only the hash + display prefix of a key)
    op.create_table(
        'api_keys',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('key_hash', sa.String(255), nullable=False),
        sa.Column('key_prefix', sa.String(10), nullable=False),
        sa.Column('is_active', sa.Boolean(), default=True),
        sa.Column('scopes', sa.JSON(), default=list),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('usage_count', sa.Integer(), default=0),
        sa.Column('created_at', sa.DateTime()),
        sa.Column('expires_at', sa.DateTime(), nullable=True),
    )
    op.create_index('ix_api_keys_prefix', 'api_keys', ['key_prefix'])
    op.create_index('ix_api_keys_hash', 'api_keys', ['key_hash'])

    # Create usage_logs table (one row per user per day — see unique index)
    op.create_table(
        'usage_logs',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('date', sa.DateTime(), nullable=False),
        sa.Column('documents_count', sa.Integer(), default=0),
        sa.Column('pages_count', sa.Integer(), default=0),
        sa.Column('characters_count', sa.BigInteger(), default=0),
        sa.Column('api_calls_count', sa.Integer(), default=0),
        sa.Column('provider_breakdown', sa.JSON(), default=dict),
    )
    op.create_index('ix_usage_logs_user_date', 'usage_logs', ['user_id', 'date'], unique=True)

    # Create payment_history table
    op.create_table(
        'payment_history',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
        sa.Column('stripe_payment_intent_id', sa.String(255), nullable=True),
        sa.Column('stripe_invoice_id', sa.String(255), nullable=True),
        sa.Column('amount_cents', sa.Integer(), nullable=False),
        sa.Column('currency', sa.String(3), default='usd'),
        sa.Column('payment_type', sa.String(50), nullable=False),
        sa.Column('status', sa.String(20), nullable=False),
        sa.Column('description', sa.String(255), nullable=True),
        sa.Column('created_at', sa.DateTime()),
    )
    op.create_index('ix_payment_history_user', 'payment_history', ['user_id', 'created_at'])


def downgrade() -> None:
    # Drop children before 'users' so FK constraints are never violated
    op.drop_table('payment_history')
    op.drop_table('usage_logs')
    op.drop_table('api_keys')
    op.drop_table('translations')
    op.drop_table('users')
|
||||
17
database/__init__.py
Normal file
17
database/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""
Database module for the Document Translation API

Provides PostgreSQL (production) and SQLite (development) support via
synchronous SQLAlchemy sessions.
"""
from database.connection import get_db, engine, SessionLocal, init_db
# BUGFIX: the previous import list included `Subscription`, which does not
# exist in database.models (it defines User, Translation, ApiKey, UsageLog,
# PaymentHistory) — importing this package raised ImportError. Export the
# models that actually exist instead.
from database.models import User, Translation, ApiKey, UsageLog, PaymentHistory

__all__ = [
    "get_db",
    "engine",
    "SessionLocal",
    "init_db",
    "User",
    "Translation",
    "ApiKey",
    "UsageLog",
    "PaymentHistory",
]
|
||||
139
database/connection.py
Normal file
139
database/connection.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""
Database connection and session management
Supports both PostgreSQL (production) and SQLite (development/testing)
"""
import os
import logging
from typing import Generator, Optional
from contextlib import contextmanager

from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, StaticPool


logger = logging.getLogger(__name__)

# Database URL from environment
# PostgreSQL: postgresql://user:password@host:port/database
# SQLite:     sqlite:///./data/translate.db
DATABASE_URL = os.getenv("DATABASE_URL", "")

# SQLite is the fallback whenever DATABASE_URL is unset or is itself sqlite://
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True

# Create engine based on database type
if DATABASE_URL and not _is_sqlite:
    # PostgreSQL: pooled connections with periodic recycling + health checks
    engine = create_engine(
        DATABASE_URL,
        poolclass=QueuePool,
        pool_size=5,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=1800,   # Recycle connections after 30 minutes
        pool_pre_ping=True,  # Check connection health before use
        echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
    )
    logger.info("✅ Database configured with PostgreSQL")
else:
    # SQLite configuration (for development/testing or when no DATABASE_URL)
    sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
    # BUGFIX: os.path.dirname() returns "" for a bare filename
    # (e.g. SQLITE_PATH=translate.db) and os.makedirs("") raises
    # FileNotFoundError — only create the directory when there is one.
    sqlite_dir = os.path.dirname(sqlite_path)
    if sqlite_dir:
        os.makedirs(sqlite_dir, exist_ok=True)

    sqlite_url = f"sqlite:///./{sqlite_path}"
    engine = create_engine(
        sqlite_url,
        # allow the single SQLite connection to be used from worker threads
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
        echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
    )

    # SQLite does not enforce foreign keys unless the pragma is enabled
    # on every new DBAPI connection.
    @event.listens_for(engine, "connect")
    def set_sqlite_pragma(dbapi_connection, connection_record):
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    if not DATABASE_URL:
        logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
    else:
        logger.info(f"✅ Database configured with SQLite: {sqlite_path}")

# Session factory (expire_on_commit=False keeps ORM objects usable after commit)
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
    expire_on_commit=False,
)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
    """
    FastAPI dependency that yields a database session.

    Usage: db: Session = Depends(get_db)

    The session is always closed after the request, even on error.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||||
|
||||
|
||||
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """
    Context manager yielding a database session.

    Commits when the block exits cleanly; rolls back and re-raises on any
    exception (including a failed commit); always closes the session.

    Usage: with get_db_session() as db: ...
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()  # inside try: a commit failure must also roll back
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


# Alias for backward compatibility
get_sync_session = get_db_session
|
||||
|
||||
|
||||
def init_db():
    """
    Create all tables defined on the ORM metadata.

    Intended to be called once at application startup.
    """
    # Local import avoids a circular dependency at module load time.
    from database.models import Base

    Base.metadata.create_all(bind=engine)
    logger.info("✅ Database tables initialized")
|
||||
|
||||
|
||||
def check_db_connection() -> bool:
    """
    Check whether the database connection is healthy.

    Returns:
        True if a trivial query succeeds, False otherwise. Errors are
        logged, never raised.
    """
    # Local import keeps this fix self-contained; text() is available in
    # both SQLAlchemy 1.4 and 2.x.
    from sqlalchemy import text

    try:
        with engine.connect() as conn:
            # BUGFIX: SQLAlchemy 2.x rejects raw SQL strings passed to
            # Connection.execute() (ObjectNotExecutableError); statements
            # must be wrapped in text().
            conn.execute(text("SELECT 1"))
        return True
    except Exception as e:
        logger.error(f"Database connection check failed: {e}")
        return False
|
||||
|
||||
|
||||
# Connection pool stats (for monitoring)
def get_pool_stats() -> dict:
    """Return connection-pool statistics for monitoring dashboards.

    Only pools that expose a status() interface report detailed numbers;
    otherwise a fallback message is returned.
    """
    pool = engine.pool
    if not hasattr(pool, 'status'):
        return {"status": "pool stats not available"}
    return {
        "pool_size": pool.size(),
        "checked_in": pool.checkedin(),
        "checked_out": pool.checkedout(),
        "overflow": pool.overflow(),
    }
|
||||
259
database/models.py
Normal file
259
database/models.py
Normal file
@ -0,0 +1,259 @@
|
||||
"""
|
||||
SQLAlchemy models for the Document Translation API
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, Boolean, DateTime, Text,
|
||||
ForeignKey, Enum, Index, JSON, BigInteger
|
||||
)
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
||||
import enum
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def generate_uuid():
    """Return a fresh random UUID4 rendered as a 36-character string."""
    return str(uuid.uuid4())
|
||||
|
||||
|
||||
class PlanType(str, enum.Enum):
    """Subscription tiers an account can be on (stored as their string value)."""

    FREE = "free"
    STARTER = "starter"
    PRO = "pro"
    BUSINESS = "business"
    ENTERPRISE = "enterprise"
|
||||
|
||||
|
||||
class SubscriptionStatus(str, enum.Enum):
    """Lifecycle state of a subscription (stored as its string value)."""

    ACTIVE = "active"
    CANCELED = "canceled"
    PAST_DUE = "past_due"
    TRIALING = "trialing"
    PAUSED = "paused"
|
||||
|
||||
|
||||
class User(Base):
    """User model for authentication and billing.

    One row per account; monthly usage counters are denormalized onto the
    row and reset by UserRepository.increment_usage.
    """
    __tablename__ = "users"

    # String UUIDs keep the schema portable between SQLite and PostgreSQL
    id = Column(String(36), primary_key=True, default=generate_uuid)
    email = Column(String(255), unique=True, nullable=False, index=True)
    name = Column(String(255), nullable=False)
    password_hash = Column(String(255), nullable=False)

    # Account status
    email_verified = Column(Boolean, default=False)
    is_active = Column(Boolean, default=True)
    avatar_url = Column(String(500), nullable=True)

    # Subscription info
    # NOTE(review): the initial migration creates these columns as
    # String(20), not Enum — confirm intended; autogenerate may flag a diff.
    plan = Column(Enum(PlanType), default=PlanType.FREE)
    subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)

    # Stripe integration
    stripe_customer_id = Column(String(255), nullable=True, index=True)
    stripe_subscription_id = Column(String(255), nullable=True)

    # Usage tracking (reset monthly)
    docs_translated_this_month = Column(Integer, default=0)
    pages_translated_this_month = Column(Integer, default=0)
    api_calls_this_month = Column(Integer, default=0)
    extra_credits = Column(Integer, default=0)  # Purchased credits
    usage_reset_date = Column(DateTime, default=datetime.utcnow)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_login_at = Column(DateTime, nullable=True)

    # Relationships (lazy="dynamic" returns a Query, not a loaded list)
    translations = relationship("Translation", back_populates="user", lazy="dynamic")
    api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")

    # Indexes
    __table_args__ = (
        Index('ix_users_email_active', 'email', 'is_active'),
        Index('ix_users_stripe_customer', 'stripe_customer_id'),
    )

    def to_dict(self) -> dict:
        """Convert user to dictionary for API response.

        Deliberately omits password_hash and Stripe identifiers.
        """
        return {
            "id": self.id,
            "email": self.email,
            "name": self.name,
            "avatar_url": self.avatar_url,
            # .value unwraps the Enum members into their string form
            "plan": self.plan.value if self.plan else "free",
            "subscription_status": self.subscription_status.value if self.subscription_status else "active",
            "docs_translated_this_month": self.docs_translated_this_month,
            "pages_translated_this_month": self.pages_translated_this_month,
            "api_calls_this_month": self.api_calls_this_month,
            "extra_credits": self.extra_credits,
            "email_verified": self.email_verified,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }
|
||||
|
||||
|
||||
class Translation(Base):
    """Translation history for analytics and billing.

    Rows are cascade-deleted with their owning user.
    """
    __tablename__ = "translations"

    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)

    # File info
    original_filename = Column(String(255), nullable=False)
    file_type = Column(String(10), nullable=False)  # xlsx, docx, pptx
    file_size_bytes = Column(BigInteger, default=0)
    page_count = Column(Integer, default=0)

    # Translation details
    source_language = Column(String(10), default="auto")
    target_language = Column(String(10), nullable=False)
    provider = Column(String(50), nullable=False)  # google, deepl, ollama, etc.

    # Processing info
    status = Column(String(20), default="pending")  # pending, processing, completed, failed
    error_message = Column(Text, nullable=True)
    processing_time_ms = Column(Integer, nullable=True)

    # Cost tracking (for paid providers)
    characters_translated = Column(Integer, default=0)
    estimated_cost_usd = Column(Float, default=0.0)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime, nullable=True)

    # Relationship
    user = relationship("User", back_populates="translations")

    # Indexes
    __table_args__ = (
        Index('ix_translations_user_date', 'user_id', 'created_at'),
        Index('ix_translations_status', 'status'),
    )

    def to_dict(self) -> dict:
        """Serialize for API responses (omits user_id, error details and cost)."""
        return {
            "id": self.id,
            "original_filename": self.original_filename,
            "file_type": self.file_type,
            "file_size_bytes": self.file_size_bytes,
            "page_count": self.page_count,
            "source_language": self.source_language,
            "target_language": self.target_language,
            "provider": self.provider,
            "status": self.status,
            "processing_time_ms": self.processing_time_ms,
            "characters_translated": self.characters_translated,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
|
||||
|
||||
|
||||
class ApiKey(Base):
    """API keys for programmatic access.

    The raw key is never stored — only its hash plus a short display prefix.
    """
    __tablename__ = "api_keys"

    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)

    # Key info
    name = Column(String(100), nullable=False)  # User-friendly name
    key_hash = Column(String(255), nullable=False)  # SHA256 of the key
    key_prefix = Column(String(10), nullable=False)  # First 8 chars for identification

    # Permissions
    is_active = Column(Boolean, default=True)
    scopes = Column(JSON, default=list)  # ["translate", "read", "write"]

    # Usage tracking
    last_used_at = Column(DateTime, nullable=True)
    usage_count = Column(Integer, default=0)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    expires_at = Column(DateTime, nullable=True)  # None = never expires

    # Relationship
    user = relationship("User", back_populates="api_keys")

    # Indexes
    __table_args__ = (
        Index('ix_api_keys_prefix', 'key_prefix'),
        Index('ix_api_keys_hash', 'key_hash'),
    )

    def to_dict(self) -> dict:
        """Serialize for API responses; exposes only the prefix, never the hash."""
        return {
            "id": self.id,
            "name": self.name,
            "key_prefix": self.key_prefix,
            "is_active": self.is_active,
            "scopes": self.scopes,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
            "usage_count": self.usage_count,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
        }
|
||||
|
||||
|
||||
class UsageLog(Base):
    """Daily usage aggregation for billing and analytics.

    Exactly one row per (user_id, date) — enforced by the unique index below.
    """
    __tablename__ = "usage_logs"

    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)

    # Date (for daily aggregation)
    date = Column(DateTime, nullable=False, index=True)

    # Aggregated counts
    documents_count = Column(Integer, default=0)
    pages_count = Column(Integer, default=0)
    characters_count = Column(BigInteger, default=0)
    api_calls_count = Column(Integer, default=0)

    # By provider breakdown (JSON)
    provider_breakdown = Column(JSON, default=dict)

    # Indexes
    __table_args__ = (
        Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
    )
|
||||
|
||||
|
||||
class PaymentHistory(Base):
    """Payment and invoice history (one row per Stripe payment event)."""
    __tablename__ = "payment_history"

    id = Column(String(36), primary_key=True, default=generate_uuid)
    user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)

    # Stripe info (either/both may be absent depending on payment_type)
    stripe_payment_intent_id = Column(String(255), nullable=True)
    stripe_invoice_id = Column(String(255), nullable=True)

    # Payment details — amounts stored in integer cents to avoid float rounding
    amount_cents = Column(Integer, nullable=False)
    currency = Column(String(3), default="usd")
    payment_type = Column(String(50), nullable=False)  # subscription, credits, one_time
    status = Column(String(20), nullable=False)  # succeeded, failed, pending, refunded

    # Description
    description = Column(String(255), nullable=True)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)

    # Indexes
    __table_args__ = (
        Index('ix_payment_history_user', 'user_id', 'created_at'),
    )
|
||||
341
database/repositories.py
Normal file
341
database/repositories.py
Normal file
@ -0,0 +1,341 @@
|
||||
"""
|
||||
Repository layer for database operations
|
||||
Provides clean interface for CRUD operations
|
||||
"""
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func, or_
|
||||
|
||||
from database.models import (
|
||||
User, Translation, ApiKey, UsageLog, PaymentHistory,
|
||||
PlanType, SubscriptionStatus
|
||||
)
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for User database operations"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_id(self, user_id: str) -> Optional[User]:
|
||||
"""Get user by ID"""
|
||||
return self.db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
def get_by_email(self, email: str) -> Optional[User]:
|
||||
"""Get user by email (case-insensitive)"""
|
||||
return self.db.query(User).filter(
|
||||
func.lower(User.email) == email.lower()
|
||||
).first()
|
||||
|
||||
def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
|
||||
"""Get user by Stripe customer ID"""
|
||||
return self.db.query(User).filter(
|
||||
User.stripe_customer_id == stripe_customer_id
|
||||
).first()
|
||||
|
||||
def create(
|
||||
self,
|
||||
email: str,
|
||||
name: str,
|
||||
password_hash: str,
|
||||
plan: PlanType = PlanType.FREE
|
||||
) -> User:
|
||||
"""Create a new user"""
|
||||
user = User(
|
||||
email=email.lower(),
|
||||
name=name,
|
||||
password_hash=password_hash,
|
||||
plan=plan,
|
||||
subscription_status=SubscriptionStatus.ACTIVE,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def update(self, user_id: str, **kwargs) -> Optional[User]:
|
||||
"""Update user fields"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(user, key):
|
||||
setattr(user, key, value)
|
||||
|
||||
user.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def delete(self, user_id: str) -> bool:
|
||||
"""Delete a user"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
self.db.delete(user)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def increment_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
docs: int = 0,
|
||||
pages: int = 0,
|
||||
api_calls: int = 0
|
||||
) -> Optional[User]:
|
||||
"""Increment usage counters"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Check if usage needs to be reset (monthly)
|
||||
if user.usage_reset_date:
|
||||
now = datetime.utcnow()
|
||||
if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
user.api_calls_this_month = 0
|
||||
user.usage_reset_date = now
|
||||
|
||||
user.docs_translated_this_month += docs
|
||||
user.pages_translated_this_month += pages
|
||||
user.api_calls_this_month += api_calls
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
    """Increase the user's extra-credit balance; None when user is absent."""
    target = self.get_by_id(user_id)
    if not target:
        return None

    target.extra_credits += credits
    self.db.commit()
    self.db.refresh(target)
    return target
|
||||
def use_credits(self, user_id: str, credits: int) -> bool:
    """Deduct credits from the balance.

    Returns False when the user is missing or has insufficient credits;
    the balance is only mutated when the deduction succeeds.
    """
    target = self.get_by_id(user_id)
    if not target or target.extra_credits < credits:
        return False

    target.extra_credits -= credits
    self.db.commit()
    return True
|
||||
def get_all_users(
    self,
    skip: int = 0,
    limit: int = 100,
    plan: Optional[PlanType] = None
) -> List[User]:
    """Page through users, optionally restricted to a single plan."""
    query = self.db.query(User)
    if plan:
        # Optional plan filter; omitted when no plan is requested.
        query = query.filter(User.plan == plan)
    return query.offset(skip).limit(limit).all()
|
||||
def count_users(self, plan: Optional[PlanType] = None) -> int:
    """Return the number of users, optionally for a single plan."""
    query = self.db.query(func.count(User.id))
    if plan:
        query = query.filter(User.plan == plan)
    return query.scalar()
|
||||
|
||||
class TranslationRepository:
    """Repository for Translation database operations."""

    def __init__(self, db: Session):
        # The session is caller-owned; this repository never closes it.
        self.db = db

    def create(
        self,
        user_id: str,
        original_filename: str,
        file_type: str,
        target_language: str,
        provider: str,
        source_language: str = "auto",
        file_size_bytes: int = 0,
        page_count: int = 0,
    ) -> Translation:
        """Create a new translation record in the "pending" state.

        Returns the persisted, refreshed Translation instance.
        """
        translation = Translation(
            user_id=user_id,
            original_filename=original_filename,
            file_type=file_type,
            file_size_bytes=file_size_bytes,
            page_count=page_count,
            source_language=source_language,
            target_language=target_language,
            provider=provider,
            status="pending",
        )
        self.db.add(translation)
        self.db.commit()
        self.db.refresh(translation)
        return translation

    def update_status(
        self,
        translation_id: str,
        status: str,
        error_message: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        characters_translated: Optional[int] = None,
    ) -> Optional[Translation]:
        """Update a translation's status and optional result metadata.

        Returns None when the translation does not exist. ``completed_at``
        is stamped only when the status transitions to "completed".

        BUGFIX: the optional numeric fields were guarded with truthiness
        checks, so legitimate zero values (0 ms, 0 characters) were
        silently dropped; they are now written whenever explicitly passed.
        """
        translation = self.db.query(Translation).filter(
            Translation.id == translation_id
        ).first()

        if not translation:
            return None

        translation.status = status
        if error_message is not None:
            translation.error_message = error_message
        if processing_time_ms is not None:
            translation.processing_time_ms = processing_time_ms
        if characters_translated is not None:
            translation.characters_translated = characters_translated
        if status == "completed":
            translation.completed_at = datetime.utcnow()

        self.db.commit()
        self.db.refresh(translation)
        return translation

    def get_user_translations(
        self,
        user_id: str,
        skip: int = 0,
        limit: int = 50,
        status: Optional[str] = None,
    ) -> List[Translation]:
        """Return a page of the user's translations, newest first."""
        query = self.db.query(Translation).filter(Translation.user_id == user_id)
        if status:
            query = query.filter(Translation.status == status)
        return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()

    def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Aggregate the user's completed translations over a time window.

        Args:
            user_id: Owner of the translations.
            days: Look-back window size in days.

        Returns:
            Totals for translations, pages, and characters. SQL SUM()
            yields NULL when no rows match, hence the ``or 0`` fallbacks.
        """
        since = datetime.utcnow() - timedelta(days=days)

        result = self.db.query(
            func.count(Translation.id).label("total_translations"),
            func.sum(Translation.page_count).label("total_pages"),
            func.sum(Translation.characters_translated).label("total_characters"),
        ).filter(
            and_(
                Translation.user_id == user_id,
                Translation.created_at >= since,
                Translation.status == "completed",
            )
        ).first()

        return {
            "total_translations": result.total_translations or 0,
            "total_pages": result.total_pages or 0,
            "total_characters": result.total_characters or 0,
            "period_days": days,
        }
|
||||
|
||||
class ApiKeyRepository:
    """Repository for API Key database operations.

    Raw keys are never stored — only a SHA-256 hash plus a short display
    prefix; the raw key is returned to the caller exactly once at
    creation time.
    """

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def hash_key(key: str) -> str:
        """Return the SHA-256 hex digest of a raw API key."""
        return hashlib.sha256(key.encode()).hexdigest()

    def create(
        self,
        user_id: str,
        name: str,
        scopes: Optional[List[str]] = None,  # BUGFIX: was annotated List[str] with None default
        expires_in_days: Optional[int] = None,
    ) -> tuple[ApiKey, str]:
        """Create a new API key.

        Args:
            user_id: Owner of the key.
            name: Human-readable label.
            scopes: Permission scopes; defaults to ["translate"].
            expires_in_days: Optional lifetime; None means no expiry.

        Returns:
            (ApiKey, raw_key) — the raw key is only available here.
        """
        # Generate a secure random key with a recognizable "tr_" prefix.
        raw_key = f"tr_{secrets.token_urlsafe(32)}"
        key_hash = self.hash_key(raw_key)
        key_prefix = raw_key[:10]

        expires_at = None
        if expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=expires_in_days)

        api_key = ApiKey(
            user_id=user_id,
            name=name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            scopes=scopes or ["translate"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)

        return api_key, raw_key

    def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
        """Look up an active, unexpired key by its raw value.

        Side effect: bumps ``last_used_at`` and ``usage_count`` on a
        successful lookup. Expired keys are treated as missing.
        """
        key_hash = self.hash_key(raw_key)
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.key_hash == key_hash,
                ApiKey.is_active == True,  # noqa: E712 — SQLAlchemy needs ==, not `is`
            )
        ).first()

        if api_key:
            # Check expiration before recording usage.
            if api_key.expires_at and api_key.expires_at < datetime.utcnow():
                return None

            # Update last-used bookkeeping.
            api_key.last_used_at = datetime.utcnow()
            api_key.usage_count += 1
            self.db.commit()

        return api_key

    def get_user_keys(self, user_id: str) -> List[ApiKey]:
        """Return all of a user's keys (revoked included), newest first."""
        return self.db.query(ApiKey).filter(
            ApiKey.user_id == user_id
        ).order_by(ApiKey.created_at.desc()).all()

    def revoke(self, key_id: str, user_id: str) -> bool:
        """Deactivate a key owned by the given user.

        The user_id filter prevents revoking another user's key; returns
        False when no matching key exists.
        """
        api_key = self.db.query(ApiKey).filter(
            and_(
                ApiKey.id == key_id,
                ApiKey.user_id == user_id,
            )
        ).first()

        if not api_key:
            return False

        api_key.is_active = False
        self.db.commit()
        return True
|
||||
@ -4,6 +4,35 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# ===========================================
|
||||
# PostgreSQL Database
|
||||
# ===========================================
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
container_name: translate-postgres
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-translate}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-translate_secret_123}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-translate_db}
|
||||
- PGDATA=/var/lib/postgresql/data/pgdata
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
networks:
|
||||
- translate-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-translate} -d ${POSTGRES_DB:-translate_db}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 512M
|
||||
reservations:
|
||||
memory: 128M
|
||||
|
||||
# ===========================================
|
||||
# Backend API Service
|
||||
# ===========================================
|
||||
@ -14,23 +43,45 @@ services:
|
||||
container_name: translate-backend
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
      # Database
      # NOTE(review): this URL selects the asyncpg driver, but requirements.txt
      # installs psycopg2-binary (sync). Confirm the SQLAlchemy driver suffix
      # matches the installed driver package before deploying.
      - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-translate}:${POSTGRES_PASSWORD:-translate_secret_123}@postgres:5432/${POSTGRES_DB:-translate_db}
|
||||
# Redis
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
# Translation Services
|
||||
- TRANSLATION_SERVICE=${TRANSLATION_SERVICE:-ollama}
|
||||
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://ollama:11434}
|
||||
- OLLAMA_MODEL=${OLLAMA_MODEL:-llama3}
|
||||
- DEEPL_API_KEY=${DEEPL_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
|
||||
# File Limits
|
||||
- MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
|
||||
# Rate Limiting
|
||||
- RATE_LIMIT_REQUESTS_PER_MINUTE=${RATE_LIMIT_REQUESTS_PER_MINUTE:-60}
|
||||
- RATE_LIMIT_TRANSLATIONS_PER_MINUTE=${RATE_LIMIT_TRANSLATIONS_PER_MINUTE:-10}
|
||||
- ADMIN_USERNAME=${ADMIN_USERNAME:-admin}
|
||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-changeme123}
|
||||
- CORS_ORIGINS=${CORS_ORIGINS:-*}
|
||||
      # Admin Auth (CHANGE IN PRODUCTION!)
      # NOTE(review): ADMIN_USERNAME/ADMIN_PASSWORD are already defined earlier
      # in this environment list (with defaults); these later duplicates take
      # precedence, so the defaults above never apply — drop one pair.
      - ADMIN_USERNAME=${ADMIN_USERNAME}
      - ADMIN_PASSWORD=${ADMIN_PASSWORD}
|
||||
# Security
|
||||
- JWT_SECRET=${JWT_SECRET}
|
||||
- CORS_ORIGINS=${CORS_ORIGINS:-https://yourdomain.com}
|
||||
# Stripe Payments
|
||||
- STRIPE_SECRET_KEY=${STRIPE_SECRET_KEY:-}
|
||||
- STRIPE_WEBHOOK_SECRET=${STRIPE_WEBHOOK_SECRET:-}
|
||||
- STRIPE_STARTER_PRICE_ID=${STRIPE_STARTER_PRICE_ID:-}
|
||||
- STRIPE_PRO_PRICE_ID=${STRIPE_PRO_PRICE_ID:-}
|
||||
- STRIPE_BUSINESS_PRICE_ID=${STRIPE_BUSINESS_PRICE_ID:-}
|
||||
volumes:
|
||||
- uploads_data:/app/uploads
|
||||
- outputs_data:/app/outputs
|
||||
- logs_data:/app/logs
|
||||
networks:
|
||||
- translate-network
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
@ -117,7 +168,7 @@ services:
|
||||
- with-ollama
|
||||
|
||||
# ===========================================
|
||||
# Redis (Optional - For caching & sessions)
|
||||
# Redis (Caching & Sessions)
|
||||
# ===========================================
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
@ -130,11 +181,9 @@ services:
|
||||
- translate-network
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
profiles:
|
||||
- with-cache
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
# ===========================================
|
||||
# Prometheus (Optional - Monitoring)
|
||||
@ -190,6 +239,8 @@ networks:
|
||||
# Volumes
|
||||
# ===========================================
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
uploads_data:
|
||||
driver: local
|
||||
outputs_data:
|
||||
|
||||
12
main.py
12
main.py
@ -204,6 +204,18 @@ async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
logger.info("Starting Document Translation API...")
|
||||
config.ensure_directories()
|
||||
|
||||
# Initialize database
|
||||
try:
|
||||
from database.connection import init_db, check_db_connection
|
||||
init_db()
|
||||
if check_db_connection():
|
||||
logger.info("✅ Database connection verified")
|
||||
else:
|
||||
logger.warning("⚠️ Database connection check failed")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Database initialization skipped: {e}")
|
||||
|
||||
await cleanup_manager.start()
|
||||
logger.info("API ready to accept requests")
|
||||
|
||||
|
||||
@ -28,8 +28,7 @@ stripe==7.0.0
|
||||
# Session storage & caching (optional but recommended for production)
|
||||
redis==5.0.1
|
||||
|
||||
# Database (optional but recommended for production)
|
||||
# sqlalchemy==2.0.25
|
||||
# asyncpg==0.29.0 # PostgreSQL async driver
|
||||
# alembic==1.13.1 # Database migrations
|
||||
|
||||
# Database (recommended for production)
|
||||
sqlalchemy==2.0.25
|
||||
psycopg2-binary==2.9.9 # PostgreSQL driver
|
||||
alembic==1.13.1 # Database migrations
|
||||
|
||||
160
scripts/migrate_to_db.py
Normal file
160
scripts/migrate_to_db.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""
|
||||
Migration script to move data from JSON files to database
|
||||
Run this once to migrate existing users to the new database system
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from database.connection import init_db, get_db_session
|
||||
from database.repositories import UserRepository
|
||||
from database.models import PlanType, SubscriptionStatus
|
||||
|
||||
|
||||
def migrate_users_from_json():
    """Migrate users from data/users.json into the database.

    Idempotent: users whose email already exists are skipped. On success
    the JSON file is renamed to ``.json.bak`` so re-runs don't re-read it.

    Returns:
        int: the number of users migrated.
    """
    json_path = Path("data/users.json")

    if not json_path.exists():
        print("No users.json found, nothing to migrate")
        return 0

    # Ensure tables exist before inserting.
    print("Initializing database tables...")
    init_db()

    # Load JSON data
    with open(json_path, 'r') as f:
        users_data = json.load(f)

    print(f"Found {len(users_data)} users to migrate")

    migrated = 0
    skipped = 0
    errors = 0

    with get_db_session() as db:
        repo = UserRepository(db)

        for user_id, user_data in users_data.items():
            try:
                # Skip users that already exist (idempotent re-runs).
                existing = repo.get_by_email(user_data.get('email', ''))
                if existing:
                    print(f"  Skipping {user_data.get('email')} - already exists")
                    skipped += 1
                    continue

                # Map plan string to enum, defaulting to FREE on unknown values.
                plan_str = user_data.get('plan', 'free')
                try:
                    plan = PlanType(plan_str)
                except ValueError:
                    plan = PlanType.FREE

                # Map subscription status, defaulting to ACTIVE on unknown values.
                status_str = user_data.get('subscription_status', 'active')
                try:
                    status = SubscriptionStatus(status_str)
                except ValueError:
                    status = SubscriptionStatus.ACTIVE

                # Create user preserving the original ID so existing
                # references stay valid.
                from database.models import User
                user = User(
                    id=user_id,
                    email=user_data.get('email', '').lower(),
                    name=user_data.get('name', ''),
                    password_hash=user_data.get('password_hash', ''),
                    email_verified=user_data.get('email_verified', False),
                    avatar_url=user_data.get('avatar_url'),
                    plan=plan,
                    subscription_status=status,
                    stripe_customer_id=user_data.get('stripe_customer_id'),
                    stripe_subscription_id=user_data.get('stripe_subscription_id'),
                    docs_translated_this_month=user_data.get('docs_translated_this_month', 0),
                    pages_translated_this_month=user_data.get('pages_translated_this_month', 0),
                    api_calls_this_month=user_data.get('api_calls_this_month', 0),
                    extra_credits=user_data.get('extra_credits', 0),
                )

                # Parse ISO timestamps; unparseable values are ignored rather
                # than failing the whole migration.
                # BUGFIX: the bare ``except:`` clauses also swallowed
                # KeyboardInterrupt/SystemExit; narrowed to the exceptions
                # fromisoformat/replace can actually raise.
                if user_data.get('created_at'):
                    try:
                        user.created_at = datetime.fromisoformat(user_data['created_at'].replace('Z', '+00:00'))
                    except (ValueError, TypeError, AttributeError):
                        pass

                if user_data.get('updated_at'):
                    try:
                        user.updated_at = datetime.fromisoformat(user_data['updated_at'].replace('Z', '+00:00'))
                    except (ValueError, TypeError, AttributeError):
                        pass

                db.add(user)
                db.commit()

                print(f"  Migrated: {user.email}")
                migrated += 1

            except Exception as e:
                # Best-effort migration: report and continue with next user.
                print(f"  Error migrating {user_data.get('email', user_id)}: {e}")
                errors += 1
                db.rollback()

    print("\nMigration complete:")
    print(f"  Migrated: {migrated}")
    print(f"  Skipped: {skipped}")
    print(f"  Errors: {errors}")

    # Backup original file
    if migrated > 0:
        backup_path = json_path.with_suffix('.json.bak')
        os.rename(json_path, backup_path)
        print(f"\nOriginal file backed up to: {backup_path}")

    return migrated
||||
|
||||
|
||||
def verify_migration():
    """Print a summary of the database contents after migration."""
    from database.connection import get_db_session
    from database.repositories import UserRepository

    with get_db_session() as db:
        repo = UserRepository(db)
        total = repo.count_users()
        print(f"\nDatabase now contains {total} users")

        # Show up to five accounts as a quick sanity check.
        sample = repo.get_all_users(limit=5)
        if sample:
            print("\nSample users:")
            for account in sample:
                print(f"  - {account.email} ({account.plan.value})")
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: report the configured backend, migrate, then verify.
    print("=" * 50)
    print("JSON to Database Migration Script")
    print("=" * 50)

    # DATABASE_URL selects PostgreSQL; otherwise SQLite is used.
    # BUGFIX: dropped pointless f-string prefixes on constant messages (F541).
    db_url = os.getenv("DATABASE_URL", "")
    if db_url:
        print("Database: PostgreSQL")
    else:
        print("Database: SQLite (development)")

    print()

    # Run migration
    migrate_users_from_json()

    # Verify
    verify_migration()
|
||||
@ -1,5 +1,9 @@
|
||||
"""
|
||||
Authentication service with JWT tokens and password hashing
|
||||
|
||||
This service provides user authentication with automatic backend selection:
|
||||
- If DATABASE_URL is configured: Uses PostgreSQL database
|
||||
- Otherwise: Falls back to JSON file storage (development mode)
|
||||
"""
|
||||
import os
|
||||
import secrets
|
||||
@ -8,6 +12,9 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
@ -15,6 +22,7 @@ try:
|
||||
JWT_AVAILABLE = True
|
||||
except ImportError:
|
||||
JWT_AVAILABLE = False
|
||||
logger.warning("PyJWT not installed. Using fallback token encoding.")
|
||||
|
||||
try:
|
||||
from passlib.context import CryptContext
|
||||
@ -22,17 +30,37 @@ try:
|
||||
PASSLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
PASSLIB_AVAILABLE = False
|
||||
logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")
|
||||
|
||||
# Check if database is configured
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||
USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql"))
|
||||
|
||||
if USE_DATABASE:
|
||||
try:
|
||||
from database.repositories import UserRepository
|
||||
from database.connection import get_sync_session, init_db as _init_db
|
||||
from database import models as db_models
|
||||
DATABASE_AVAILABLE = True
|
||||
logger.info("Database backend enabled for authentication")
|
||||
except ImportError as e:
|
||||
DATABASE_AVAILABLE = False
|
||||
USE_DATABASE = False
|
||||
logger.warning(f"Database modules not available: {e}. Using JSON storage.")
|
||||
else:
|
||||
DATABASE_AVAILABLE = False
|
||||
logger.info("Using JSON file storage for authentication (DATABASE_URL not configured)")
|
||||
|
||||
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
|
||||
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
|
||||
SECRET_KEY = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32)))
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
# Simple file-based storage (replace with database in production)
|
||||
# Simple file-based storage (used when database is not configured)
|
||||
USERS_FILE = Path("data/users.json")
|
||||
USERS_FILE.parent.mkdir(exist_ok=True)
|
||||
|
||||
@ -117,7 +145,7 @@ def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
def load_users() -> Dict[str, Dict]:
|
||||
"""Load users from file storage"""
|
||||
"""Load users from file storage (JSON backend only)"""
|
||||
if USERS_FILE.exists():
|
||||
try:
|
||||
with open(USERS_FILE, 'r') as f:
|
||||
@ -128,13 +156,46 @@ def load_users() -> Dict[str, Dict]:
|
||||
|
||||
|
||||
def save_users(users: Dict[str, Dict]):
|
||||
"""Save users to file storage"""
|
||||
"""Save users to file storage (JSON backend only)"""
|
||||
with open(USERS_FILE, 'w') as f:
|
||||
json.dump(users, f, indent=2, default=str)
|
||||
|
||||
|
||||
def _db_user_to_model(db_user) -> User:
    """Convert a SQLAlchemy user row into the Pydantic ``User`` model.

    Nullable columns are mapped to sensible defaults so the resulting
    model always validates.
    """
    mapped_plan = PlanType(db_user.plan) if db_user.plan else PlanType.FREE
    mapped_status = (
        SubscriptionStatus(db_user.subscription_status)
        if db_user.subscription_status
        else SubscriptionStatus.ACTIVE
    )
    return User(
        id=str(db_user.id),
        email=db_user.email,
        name=db_user.name or "",
        password_hash=db_user.password_hash,
        avatar_url=db_user.avatar_url,
        plan=mapped_plan,
        subscription_status=mapped_status,
        stripe_customer_id=db_user.stripe_customer_id,
        stripe_subscription_id=db_user.stripe_subscription_id,
        docs_translated_this_month=db_user.docs_translated_this_month or 0,
        pages_translated_this_month=db_user.pages_translated_this_month or 0,
        api_calls_this_month=db_user.api_calls_this_month or 0,
        extra_credits=db_user.extra_credits or 0,
        usage_reset_date=db_user.usage_reset_date or datetime.utcnow(),
        default_source_lang=db_user.default_source_lang or "en",
        default_target_lang=db_user.default_target_lang or "es",
        default_provider=db_user.default_provider or "google",
        created_at=db_user.created_at or datetime.utcnow(),
        updated_at=db_user.updated_at,
    )
||||
|
||||
|
||||
def get_user_by_email(email: str) -> Optional[User]:
|
||||
"""Get a user by email"""
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.get_by_email(email)
|
||||
if db_user:
|
||||
return _db_user_to_model(db_user)
|
||||
return None
|
||||
else:
|
||||
users = load_users()
|
||||
for user_data in users.values():
|
||||
if user_data.get("email", "").lower() == email.lower():
|
||||
@ -144,6 +205,14 @@ def get_user_by_email(email: str) -> Optional[User]:
|
||||
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get a user by ID"""
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.get_by_id(user_id)
|
||||
if db_user:
|
||||
return _db_user_to_model(db_user)
|
||||
return None
|
||||
else:
|
||||
users = load_users()
|
||||
if user_id in users:
|
||||
return User(**users[user_id])
|
||||
@ -152,12 +221,26 @@ def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
|
||||
def create_user(user_create: UserCreate) -> User:
|
||||
"""Create a new user"""
|
||||
users = load_users()
|
||||
|
||||
# Check if email exists
|
||||
if get_user_by_email(user_create.email):
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.create(
|
||||
email=user_create.email,
|
||||
name=user_create.name,
|
||||
password_hash=hash_password(user_create.password),
|
||||
plan=PlanType.FREE.value,
|
||||
subscription_status=SubscriptionStatus.ACTIVE.value
|
||||
)
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return _db_user_to_model(db_user)
|
||||
else:
|
||||
users = load_users()
|
||||
|
||||
# Generate user ID
|
||||
user_id = secrets.token_urlsafe(16)
|
||||
|
||||
@ -190,6 +273,16 @@ def authenticate_user(email: str, password: str) -> Optional[User]:
|
||||
|
||||
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
||||
"""Update a user's data"""
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.update(user_id, updates)
|
||||
if db_user:
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return _db_user_to_model(db_user)
|
||||
return None
|
||||
else:
|
||||
users = load_users()
|
||||
if user_id not in users:
|
||||
return None
|
||||
@ -212,7 +305,7 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||
"docs_translated_this_month": 0,
|
||||
"pages_translated_this_month": 0,
|
||||
"api_calls_this_month": 0,
|
||||
"usage_reset_date": now.isoformat()
|
||||
"usage_reset_date": now.isoformat() if not USE_DATABASE else now
|
||||
})
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
@ -248,8 +341,8 @@ def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> b
|
||||
if use_credits:
|
||||
updates["extra_credits"] = max(0, user.extra_credits - pages_count)
|
||||
|
||||
update_user(user_id, updates)
|
||||
return True
|
||||
result = update_user(user_id, updates)
|
||||
return result is not None
|
||||
|
||||
|
||||
def add_credits(user_id: str, credits: int) -> bool:
|
||||
@ -258,5 +351,14 @@ def add_credits(user_id: str, credits: int) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
|
||||
update_user(user_id, {"extra_credits": user.extra_credits + credits})
|
||||
return True
|
||||
result = update_user(user_id, {"extra_credits": user.extra_credits + credits})
|
||||
return result is not None
|
||||
|
||||
|
||||
def init_database():
    """Initialize the persistence backend (call once on app startup).

    Creates the database tables when the PostgreSQL backend is active;
    the JSON file store requires no setup.
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        logger.info("Using JSON file storage")
        return
    _init_db()
    logger.info("Database initialized successfully")
||||
|
||||
245
services/auth_service_db.py
Normal file
245
services/auth_service_db.py
Normal file
@ -0,0 +1,245 @@
|
||||
"""
|
||||
Database-backed authentication service
|
||||
Replaces JSON file storage with SQLAlchemy
|
||||
"""
|
||||
import os
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import jwt
|
||||
JWT_AVAILABLE = True
|
||||
except ImportError:
|
||||
JWT_AVAILABLE = False
|
||||
import json
|
||||
import base64
|
||||
|
||||
try:
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
PASSLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
PASSLIB_AVAILABLE = False
|
||||
|
||||
from database.connection import get_db_session
|
||||
from database.repositories import UserRepository
|
||||
from database.models import User, PlanType, SubscriptionStatus
|
||||
from models.subscription import PLANS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash a password with bcrypt, falling back to salted SHA-256.

    The fallback format is ``sha256$<salt>$<digest>`` so that
    verify_password can tell it apart from a bcrypt hash.
    """
    if not PASSLIB_AVAILABLE:
        salt = secrets.token_hex(16)
        digest = hashlib.sha256(f"{salt}{password}".encode()).hexdigest()
        return f"sha256${salt}${digest}"
    return pwd_context.hash(password)
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Supports both bcrypt hashes and the ``sha256$salt$digest`` fallback
    format produced by hash_password.
    """
    if PASSLIB_AVAILABLE and not hashed_password.startswith("sha256$"):
        return pwd_context.verify(plain_password, hashed_password)

    parts = hashed_password.split("$")
    if len(parts) == 3 and parts[0] == "sha256":
        salt, expected = parts[1], parts[2]
        candidate = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
        # Constant-time comparison avoids timing side channels.
        return secrets.compare_digest(candidate, expected)
    return False
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create an access token for the given user.

    Uses a signed JWT when PyJWT is installed; otherwise falls back to
    base64-encoded JSON. NOTE: the fallback token carries no signature —
    development convenience only.
    """
    expire = datetime.utcnow() + (expires_delta or timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS))

    if JWT_AVAILABLE:
        claims = {"sub": user_id, "exp": expire, "type": "access"}
        return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

    payload = {"user_id": user_id, "exp": expire.isoformat(), "type": "access"}
    return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
||||
|
||||
|
||||
def create_refresh_token(user_id: str) -> str:
    """Create a long-lived refresh token for the given user.

    Same backends as create_access_token: signed JWT when PyJWT is
    installed, unsigned base64 JSON fallback otherwise.
    """
    expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    if JWT_AVAILABLE:
        claims = {"sub": user_id, "exp": expire, "type": "refresh"}
        return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

    payload = {"user_id": user_id, "exp": expire.isoformat(), "type": "refresh"}
    return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token and return its payload, or None when invalid/expired.

    Handles both real JWTs and the unsigned base64 fallback tokens
    produced when PyJWT is unavailable.
    """
    if not JWT_AVAILABLE:
        try:
            data = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
            exp = datetime.fromisoformat(data["exp"])
            if exp < datetime.utcnow():
                return None
            return {"sub": data["user_id"], "type": data.get("type", "access")}
        except Exception:
            # Any malformed fallback token is treated as invalid.
            return None

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload
    except jwt.ExpiredSignatureError:
        return None
    except jwt.InvalidTokenError:
        # BUGFIX: this previously caught ``jwt.JWTError``, which does not
        # exist in PyJWT (it is a python-jose name) — so any malformed
        # token raised AttributeError instead of returning None.
        # InvalidTokenError is PyJWT's base class for all decode failures.
        return None
||||
|
||||
|
||||
def create_user(email: str, name: str, password: str) -> User:
    """Create a new user in the database.

    Raises:
        ValueError: if the email is already registered.

    NOTE(review): the returned ORM instance may be detached once the
    session context exits — confirm callers only read attributes that
    were loaded inside the session.
    """
    with get_db_session() as db:
        repo = UserRepository(db)

        # Reject duplicate registrations up front.
        if repo.get_by_email(email):
            raise ValueError("Email already registered")

        return repo.create(
            email=email,
            name=name,
            password_hash=hash_password(password),
            plan=PlanType.FREE,
        )
||||
|
||||
|
||||
def authenticate_user(email: str, password: str) -> Optional[User]:
    """Return the user when email and password are valid, else None."""
    with get_db_session() as db:
        repo = UserRepository(db)
        account = repo.get_by_email(email)

        if account is None or not verify_password(password, account.password_hash):
            return None

        # Record the successful login timestamp.
        repo.update(account.id, last_login_at=datetime.utcnow())
        return account
||||
|
||||
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
    """Fetch a user by primary key, or None when absent."""
    with get_db_session() as db:
        return UserRepository(db).get_by_id(user_id)
||||
|
||||
|
||||
def get_user_by_email(email: str) -> Optional[User]:
    """Fetch a user by email address, or None when absent."""
    with get_db_session() as db:
        return UserRepository(db).get_by_email(email)
||||
|
||||
|
||||
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply a dict of field updates to a user; None when not found."""
    with get_db_session() as db:
        # Unpack the dict into the repository's keyword interface.
        return UserRepository(db).update(user_id, **updates)
||||
|
||||
|
||||
def add_credits(user_id: str, credits: int) -> bool:
    """Add credits to a user's balance; False when the user is missing."""
    with get_db_session() as db:
        updated = UserRepository(db).add_credits(user_id, credits)
    return updated is not None
||||
|
||||
|
||||
def use_credits(user_id: str, credits: int) -> bool:
    """Spend credits from a user's balance; result comes straight from the repository."""
    with get_db_session() as session:
        return UserRepository(session).use_credits(user_id, credits)
def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool:
    """Bump the user's monthly usage counters.

    Args:
        user_id: Primary key of the user.
        docs: Documents translated to add.
        pages: Pages translated to add.
        api_calls: API calls to add.

    Returns:
        True when the user row was found and updated, False otherwise.
    """
    with get_db_session() as session:
        updated = UserRepository(session).increment_usage(
            user_id, docs=docs, pages=pages, api_calls=api_calls
        )
        return updated is not None
def check_usage_limits(user_id: str) -> Dict[str, Any]:
    """Decide whether the user may translate another document this month.

    A docs_per_month limit of 0 or less is treated as unlimited
    (docs_remaining reported as -1). A user over the limit is still
    allowed while they hold extra credits.

    Returns:
        {"allowed": True, "docs_remaining": ..., "extra_credits": ...}
        on success, or {"allowed": False, "reason": ..., ...} when denied.
    """
    with get_db_session() as session:
        user = UserRepository(session).get_by_id(user_id)
        if not user:
            return {"allowed": False, "reason": "User not found"}

        # Unknown plans fall back to FREE limits.
        plan = PLANS.get(user.plan, PLANS[PlanType.FREE])
        docs_limit = plan["docs_per_month"]
        docs_used = user.docs_translated_this_month

        # Deny only when a finite cap is exhausted AND no paid credits remain.
        if 0 < docs_limit <= docs_used and user.extra_credits <= 0:
            return {
                "allowed": False,
                "reason": "Monthly document limit reached",
                "limit": docs_limit,
                "used": docs_used,
            }

        remaining = max(0, docs_limit - docs_used) if docs_limit > 0 else -1
        return {
            "allowed": True,
            "docs_remaining": remaining,
            "extra_credits": user.extra_credits,
        }
def get_user_usage_stats(user_id: str) -> Dict[str, Any]:
    """Build a detailed usage/quota snapshot for a user.

    Returns:
        A dict of usage counters, remaining quota, and plan capabilities,
        or {} when the user does not exist. docs_remaining is -1 when the
        plan's document cap is unlimited (limit <= 0).
    """
    with get_db_session() as session:
        user = UserRepository(session).get_by_id(user_id)
        if not user:
            return {}

        # Unknown plans fall back to FREE limits.
        plan = PLANS.get(user.plan, PLANS[PlanType.FREE])
        docs_limit = plan["docs_per_month"]
        docs_used = user.docs_translated_this_month
        has_api = plan.get("api_access", False)

        return {
            "docs_used": docs_used,
            "docs_limit": docs_limit,
            "docs_remaining": max(0, docs_limit - docs_used) if docs_limit > 0 else -1,
            "pages_used": user.pages_translated_this_month,
            "extra_credits": user.extra_credits,
            "max_pages_per_doc": plan["max_pages_per_doc"],
            "max_file_size_mb": plan["max_file_size_mb"],
            "allowed_providers": plan["providers"],
            "api_access": has_api,
            # API counters are only meaningful on plans with API access.
            "api_calls_used": user.api_calls_this_month if has_api else 0,
            "api_calls_limit": plan.get("api_calls_per_month", 0),
        }
Loading…
x
Reference in New Issue
Block a user