feat: Add PostgreSQL database infrastructure
- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration
This commit is contained in:
parent
c4d6cae735
commit
550f3516db
11
.env.example
11
.env.example
@ -72,9 +72,16 @@ MAX_REQUEST_SIZE_MB=100
|
|||||||
# Request timeout in seconds
|
# Request timeout in seconds
|
||||||
REQUEST_TIMEOUT_SECONDS=300
|
REQUEST_TIMEOUT_SECONDS=300
|
||||||
|
|
||||||
# ============== Database (Production) ==============
|
# ============== Database ==============
|
||||||
# PostgreSQL connection string (recommended for production)
|
# PostgreSQL connection string (recommended for production)
|
||||||
# DATABASE_URL=postgresql://user:password@localhost:5432/translate_db
|
# Format: postgresql://user:password@host:port/database
|
||||||
|
# DATABASE_URL=postgresql://translate_user:secure_password@localhost:5432/translate_db
|
||||||
|
|
||||||
|
# SQLite path (used when DATABASE_URL is not set - for development)
|
||||||
|
SQLITE_PATH=data/translate.db
|
||||||
|
|
||||||
|
# Enable SQL query logging (for debugging)
|
||||||
|
DATABASE_ECHO=false
|
||||||
|
|
||||||
# Redis for sessions and caching (recommended for production)
|
# Redis for sessions and caching (recommended for production)
|
||||||
# REDIS_URL=redis://localhost:6379/0
|
# REDIS_URL=redis://localhost:6379/0
|
||||||
|
|||||||
78
alembic.ini
Normal file
78
alembic.ini
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Alembic configuration file
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
script_location = alembic
|
||||||
|
|
||||||
|
# template used to generate migration files
|
||||||
|
# file_template = %%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during the 'revision' command,
|
||||||
|
# regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without a source .py file
|
||||||
|
# to be detected as revisions in the versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults to alembic/versions
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||||
|
|
||||||
|
# version path separator
|
||||||
|
# version_path_separator = :
|
||||||
|
|
||||||
|
# the output encoding used when revision files are written
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -q
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
86
alembic/env.py
Normal file
86
alembic/env.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
Alembic environment configuration
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config
|
||||||
|
from sqlalchemy import pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
# this is the Alembic Config object
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Import models for autogenerate support
|
||||||
|
from database.models import Base
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
# Get database URL from environment
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||||
|
if not DATABASE_URL:
|
||||||
|
SQLITE_PATH = os.getenv("SQLITE_PATH", "data/translate.db")
|
||||||
|
DATABASE_URL = f"sqlite:///./{SQLITE_PATH}"
|
||||||
|
|
||||||
|
# Override sqlalchemy.url with environment variable
|
||||||
|
config.set_main_option("sqlalchemy.url", DATABASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
"""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
"""
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
27
alembic/script.py.mako
Normal file
27
alembic/script.py.mako
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
Alembic migration script template
|
||||||
|
${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
126
alembic/versions/001_initial.py
Normal file
126
alembic/versions/001_initial.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
"""Initial database schema
|
||||||
|
|
||||||
|
Revision ID: 001_initial
|
||||||
|
Revises:
|
||||||
|
Create Date: 2024-12-31
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '001_initial'
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create users table
|
||||||
|
op.create_table(
|
||||||
|
'users',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('email', sa.String(255), unique=True, nullable=False),
|
||||||
|
sa.Column('name', sa.String(255), nullable=False),
|
||||||
|
sa.Column('password_hash', sa.String(255), nullable=False),
|
||||||
|
sa.Column('email_verified', sa.Boolean(), default=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), default=True),
|
||||||
|
sa.Column('avatar_url', sa.String(500), nullable=True),
|
||||||
|
sa.Column('plan', sa.String(20), default='free'),
|
||||||
|
sa.Column('subscription_status', sa.String(20), default='active'),
|
||||||
|
sa.Column('stripe_customer_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('stripe_subscription_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('docs_translated_this_month', sa.Integer(), default=0),
|
||||||
|
sa.Column('pages_translated_this_month', sa.Integer(), default=0),
|
||||||
|
sa.Column('api_calls_this_month', sa.Integer(), default=0),
|
||||||
|
sa.Column('extra_credits', sa.Integer(), default=0),
|
||||||
|
sa.Column('usage_reset_date', sa.DateTime()),
|
||||||
|
sa.Column('created_at', sa.DateTime()),
|
||||||
|
sa.Column('updated_at', sa.DateTime()),
|
||||||
|
sa.Column('last_login_at', sa.DateTime(), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index('ix_users_email', 'users', ['email'])
|
||||||
|
op.create_index('ix_users_email_active', 'users', ['email', 'is_active'])
|
||||||
|
op.create_index('ix_users_stripe_customer', 'users', ['stripe_customer_id'])
|
||||||
|
|
||||||
|
# Create translations table
|
||||||
|
op.create_table(
|
||||||
|
'translations',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
|
||||||
|
sa.Column('original_filename', sa.String(255), nullable=False),
|
||||||
|
sa.Column('file_type', sa.String(10), nullable=False),
|
||||||
|
sa.Column('file_size_bytes', sa.BigInteger(), default=0),
|
||||||
|
sa.Column('page_count', sa.Integer(), default=0),
|
||||||
|
sa.Column('source_language', sa.String(10), default='auto'),
|
||||||
|
sa.Column('target_language', sa.String(10), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(50), nullable=False),
|
||||||
|
sa.Column('status', sa.String(20), default='pending'),
|
||||||
|
sa.Column('error_message', sa.Text(), nullable=True),
|
||||||
|
sa.Column('processing_time_ms', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('characters_translated', sa.Integer(), default=0),
|
||||||
|
sa.Column('estimated_cost_usd', sa.Float(), default=0.0),
|
||||||
|
sa.Column('created_at', sa.DateTime()),
|
||||||
|
sa.Column('completed_at', sa.DateTime(), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index('ix_translations_user_date', 'translations', ['user_id', 'created_at'])
|
||||||
|
op.create_index('ix_translations_status', 'translations', ['status'])
|
||||||
|
|
||||||
|
# Create api_keys table
|
||||||
|
op.create_table(
|
||||||
|
'api_keys',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
|
||||||
|
sa.Column('name', sa.String(100), nullable=False),
|
||||||
|
sa.Column('key_hash', sa.String(255), nullable=False),
|
||||||
|
sa.Column('key_prefix', sa.String(10), nullable=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), default=True),
|
||||||
|
sa.Column('scopes', sa.JSON(), default=list),
|
||||||
|
sa.Column('last_used_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('usage_count', sa.Integer(), default=0),
|
||||||
|
sa.Column('created_at', sa.DateTime()),
|
||||||
|
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index('ix_api_keys_prefix', 'api_keys', ['key_prefix'])
|
||||||
|
op.create_index('ix_api_keys_hash', 'api_keys', ['key_hash'])
|
||||||
|
|
||||||
|
# Create usage_logs table
|
||||||
|
op.create_table(
|
||||||
|
'usage_logs',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
|
||||||
|
sa.Column('date', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('documents_count', sa.Integer(), default=0),
|
||||||
|
sa.Column('pages_count', sa.Integer(), default=0),
|
||||||
|
sa.Column('characters_count', sa.BigInteger(), default=0),
|
||||||
|
sa.Column('api_calls_count', sa.Integer(), default=0),
|
||||||
|
sa.Column('provider_breakdown', sa.JSON(), default=dict),
|
||||||
|
)
|
||||||
|
op.create_index('ix_usage_logs_user_date', 'usage_logs', ['user_id', 'date'], unique=True)
|
||||||
|
|
||||||
|
# Create payment_history table
|
||||||
|
op.create_table(
|
||||||
|
'payment_history',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
|
||||||
|
sa.Column('stripe_payment_intent_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('stripe_invoice_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('amount_cents', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('currency', sa.String(3), default='usd'),
|
||||||
|
sa.Column('payment_type', sa.String(50), nullable=False),
|
||||||
|
sa.Column('status', sa.String(20), nullable=False),
|
||||||
|
sa.Column('description', sa.String(255), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime()),
|
||||||
|
)
|
||||||
|
op.create_index('ix_payment_history_user', 'payment_history', ['user_id', 'created_at'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table('payment_history')
|
||||||
|
op.drop_table('usage_logs')
|
||||||
|
op.drop_table('api_keys')
|
||||||
|
op.drop_table('translations')
|
||||||
|
op.drop_table('users')
|
||||||
17
database/__init__.py
Normal file
17
database/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
Database module for the Document Translation API
|
||||||
|
Provides PostgreSQL support with async SQLAlchemy
|
||||||
|
"""
|
||||||
|
from database.connection import get_db, engine, SessionLocal, init_db
|
||||||
|
from database.models import User, Subscription, Translation, ApiKey
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_db",
|
||||||
|
"engine",
|
||||||
|
"SessionLocal",
|
||||||
|
"init_db",
|
||||||
|
"User",
|
||||||
|
"Subscription",
|
||||||
|
"Translation",
|
||||||
|
"ApiKey"
|
||||||
|
]
|
||||||
139
database/connection.py
Normal file
139
database/connection.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Database connection and session management
|
||||||
|
Supports both PostgreSQL (production) and SQLite (development/testing)
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Generator, Optional
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, event
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
from sqlalchemy.pool import QueuePool, StaticPool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Database URL from environment
|
||||||
|
# PostgreSQL: postgresql://user:password@host:port/database
|
||||||
|
# SQLite: sqlite:///./data/translate.db
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||||
|
|
||||||
|
# Determine if we're using SQLite or PostgreSQL
|
||||||
|
_is_sqlite = DATABASE_URL.startswith("sqlite") if DATABASE_URL else True
|
||||||
|
|
||||||
|
# Create engine based on database type
|
||||||
|
if DATABASE_URL and not _is_sqlite:
|
||||||
|
# PostgreSQL configuration
|
||||||
|
engine = create_engine(
|
||||||
|
DATABASE_URL,
|
||||||
|
poolclass=QueuePool,
|
||||||
|
pool_size=5,
|
||||||
|
max_overflow=10,
|
||||||
|
pool_timeout=30,
|
||||||
|
pool_recycle=1800, # Recycle connections after 30 minutes
|
||||||
|
pool_pre_ping=True, # Check connection health before use
|
||||||
|
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||||
|
)
|
||||||
|
logger.info("✅ Database configured with PostgreSQL")
|
||||||
|
else:
|
||||||
|
# SQLite configuration (for development/testing or when no DATABASE_URL)
|
||||||
|
sqlite_path = os.getenv("SQLITE_PATH", "data/translate.db")
|
||||||
|
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
|
||||||
|
|
||||||
|
sqlite_url = f"sqlite:///./{sqlite_path}"
|
||||||
|
engine = create_engine(
|
||||||
|
sqlite_url,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
echo=os.getenv("DATABASE_ECHO", "false").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enable foreign keys for SQLite
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
if not DATABASE_URL:
|
||||||
|
logger.warning("⚠️ DATABASE_URL not set, using SQLite for development")
|
||||||
|
else:
|
||||||
|
logger.info(f"✅ Database configured with SQLite: {sqlite_path}")
|
||||||
|
|
||||||
|
# Session factory
|
||||||
|
SessionLocal = sessionmaker(
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
bind=engine,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db() -> Generator[Session, None, None]:
|
||||||
|
"""
|
||||||
|
Dependency for FastAPI to get database session.
|
||||||
|
Usage: db: Session = Depends(get_db)
|
||||||
|
"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_db_session() -> Generator[Session, None, None]:
|
||||||
|
"""
|
||||||
|
Context manager for database session.
|
||||||
|
Usage: with get_db_session() as db: ...
|
||||||
|
"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for backward compatibility
|
||||||
|
get_sync_session = get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""
|
||||||
|
Initialize database tables.
|
||||||
|
Call this on application startup.
|
||||||
|
"""
|
||||||
|
from database.models import Base
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
logger.info("✅ Database tables initialized")
|
||||||
|
|
||||||
|
|
||||||
|
def check_db_connection() -> bool:
|
||||||
|
"""
|
||||||
|
Check if database connection is healthy.
|
||||||
|
Returns True if connection works, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute("SELECT 1")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database connection check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Connection pool stats (for monitoring)
|
||||||
|
def get_pool_stats() -> dict:
|
||||||
|
"""Get database connection pool statistics"""
|
||||||
|
if hasattr(engine.pool, 'status'):
|
||||||
|
return {
|
||||||
|
"pool_size": engine.pool.size(),
|
||||||
|
"checked_in": engine.pool.checkedin(),
|
||||||
|
"checked_out": engine.pool.checkedout(),
|
||||||
|
"overflow": engine.pool.overflow(),
|
||||||
|
}
|
||||||
|
return {"status": "pool stats not available"}
|
||||||
259
database/models.py
Normal file
259
database/models.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
"""
|
||||||
|
SQLAlchemy models for the Document Translation API
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column, String, Integer, Float, Boolean, DateTime, Text,
|
||||||
|
ForeignKey, Enum, Index, JSON, BigInteger
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship, declarative_base
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
||||||
|
import enum
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uuid():
|
||||||
|
"""Generate a new UUID string"""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class PlanType(str, enum.Enum):
|
||||||
|
FREE = "free"
|
||||||
|
STARTER = "starter"
|
||||||
|
PRO = "pro"
|
||||||
|
BUSINESS = "business"
|
||||||
|
ENTERPRISE = "enterprise"
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionStatus(str, enum.Enum):
|
||||||
|
ACTIVE = "active"
|
||||||
|
CANCELED = "canceled"
|
||||||
|
PAST_DUE = "past_due"
|
||||||
|
TRIALING = "trialing"
|
||||||
|
PAUSED = "paused"
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
"""User model for authentication and billing"""
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||||
|
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
name = Column(String(255), nullable=False)
|
||||||
|
password_hash = Column(String(255), nullable=False)
|
||||||
|
|
||||||
|
# Account status
|
||||||
|
email_verified = Column(Boolean, default=False)
|
||||||
|
is_active = Column(Boolean, default=True)
|
||||||
|
avatar_url = Column(String(500), nullable=True)
|
||||||
|
|
||||||
|
# Subscription info
|
||||||
|
plan = Column(Enum(PlanType), default=PlanType.FREE)
|
||||||
|
subscription_status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.ACTIVE)
|
||||||
|
|
||||||
|
# Stripe integration
|
||||||
|
stripe_customer_id = Column(String(255), nullable=True, index=True)
|
||||||
|
stripe_subscription_id = Column(String(255), nullable=True)
|
||||||
|
|
||||||
|
# Usage tracking (reset monthly)
|
||||||
|
docs_translated_this_month = Column(Integer, default=0)
|
||||||
|
pages_translated_this_month = Column(Integer, default=0)
|
||||||
|
api_calls_this_month = Column(Integer, default=0)
|
||||||
|
extra_credits = Column(Integer, default=0) # Purchased credits
|
||||||
|
usage_reset_date = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
last_login_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
translations = relationship("Translation", back_populates="user", lazy="dynamic")
|
||||||
|
api_keys = relationship("ApiKey", back_populates="user", lazy="dynamic")
|
||||||
|
|
||||||
|
# Indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ix_users_email_active', 'email', 'is_active'),
|
||||||
|
Index('ix_users_stripe_customer', 'stripe_customer_id'),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert user to dictionary for API response"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"email": self.email,
|
||||||
|
"name": self.name,
|
||||||
|
"avatar_url": self.avatar_url,
|
||||||
|
"plan": self.plan.value if self.plan else "free",
|
||||||
|
"subscription_status": self.subscription_status.value if self.subscription_status else "active",
|
||||||
|
"docs_translated_this_month": self.docs_translated_this_month,
|
||||||
|
"pages_translated_this_month": self.pages_translated_this_month,
|
||||||
|
"api_calls_this_month": self.api_calls_this_month,
|
||||||
|
"extra_credits": self.extra_credits,
|
||||||
|
"email_verified": self.email_verified,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Translation(Base):
|
||||||
|
"""Translation history for analytics and billing"""
|
||||||
|
__tablename__ = "translations"
|
||||||
|
|
||||||
|
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||||
|
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
|
||||||
|
# File info
|
||||||
|
original_filename = Column(String(255), nullable=False)
|
||||||
|
file_type = Column(String(10), nullable=False) # xlsx, docx, pptx
|
||||||
|
file_size_bytes = Column(BigInteger, default=0)
|
||||||
|
page_count = Column(Integer, default=0)
|
||||||
|
|
||||||
|
# Translation details
|
||||||
|
source_language = Column(String(10), default="auto")
|
||||||
|
target_language = Column(String(10), nullable=False)
|
||||||
|
provider = Column(String(50), nullable=False) # google, deepl, ollama, etc.
|
||||||
|
|
||||||
|
# Processing info
|
||||||
|
status = Column(String(20), default="pending") # pending, processing, completed, failed
|
||||||
|
error_message = Column(Text, nullable=True)
|
||||||
|
processing_time_ms = Column(Integer, nullable=True)
|
||||||
|
|
||||||
|
# Cost tracking (for paid providers)
|
||||||
|
characters_translated = Column(Integer, default=0)
|
||||||
|
estimated_cost_usd = Column(Float, default=0.0)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
completed_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
user = relationship("User", back_populates="translations")
|
||||||
|
|
||||||
|
# Indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ix_translations_user_date', 'user_id', 'created_at'),
|
||||||
|
Index('ix_translations_status', 'status'),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"original_filename": self.original_filename,
|
||||||
|
"file_type": self.file_type,
|
||||||
|
"file_size_bytes": self.file_size_bytes,
|
||||||
|
"page_count": self.page_count,
|
||||||
|
"source_language": self.source_language,
|
||||||
|
"target_language": self.target_language,
|
||||||
|
"provider": self.provider,
|
||||||
|
"status": self.status,
|
||||||
|
"processing_time_ms": self.processing_time_ms,
|
||||||
|
"characters_translated": self.characters_translated,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKey(Base):
|
||||||
|
"""API keys for programmatic access"""
|
||||||
|
__tablename__ = "api_keys"
|
||||||
|
|
||||||
|
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||||
|
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
|
||||||
|
# Key info
|
||||||
|
name = Column(String(100), nullable=False) # User-friendly name
|
||||||
|
key_hash = Column(String(255), nullable=False) # SHA256 of the key
|
||||||
|
key_prefix = Column(String(10), nullable=False) # First 8 chars for identification
|
||||||
|
|
||||||
|
# Permissions
|
||||||
|
is_active = Column(Boolean, default=True)
|
||||||
|
scopes = Column(JSON, default=list) # ["translate", "read", "write"]
|
||||||
|
|
||||||
|
# Usage tracking
|
||||||
|
last_used_at = Column(DateTime, nullable=True)
|
||||||
|
usage_count = Column(Integer, default=0)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
expires_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
user = relationship("User", back_populates="api_keys")
|
||||||
|
|
||||||
|
# Indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ix_api_keys_prefix', 'key_prefix'),
|
||||||
|
Index('ix_api_keys_hash', 'key_hash'),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"key_prefix": self.key_prefix,
|
||||||
|
"is_active": self.is_active,
|
||||||
|
"scopes": self.scopes,
|
||||||
|
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
|
||||||
|
"usage_count": self.usage_count,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UsageLog(Base):
|
||||||
|
"""Daily usage aggregation for billing and analytics"""
|
||||||
|
__tablename__ = "usage_logs"
|
||||||
|
|
||||||
|
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||||
|
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
|
||||||
|
# Date (for daily aggregation)
|
||||||
|
date = Column(DateTime, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Aggregated counts
|
||||||
|
documents_count = Column(Integer, default=0)
|
||||||
|
pages_count = Column(Integer, default=0)
|
||||||
|
characters_count = Column(BigInteger, default=0)
|
||||||
|
api_calls_count = Column(Integer, default=0)
|
||||||
|
|
||||||
|
# By provider breakdown (JSON)
|
||||||
|
provider_breakdown = Column(JSON, default=dict)
|
||||||
|
|
||||||
|
# Indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ix_usage_logs_user_date', 'user_id', 'date', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PaymentHistory(Base):
|
||||||
|
"""Payment and invoice history"""
|
||||||
|
__tablename__ = "payment_history"
|
||||||
|
|
||||||
|
id = Column(String(36), primary_key=True, default=generate_uuid)
|
||||||
|
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
|
||||||
|
# Stripe info
|
||||||
|
stripe_payment_intent_id = Column(String(255), nullable=True)
|
||||||
|
stripe_invoice_id = Column(String(255), nullable=True)
|
||||||
|
|
||||||
|
# Payment details
|
||||||
|
amount_cents = Column(Integer, nullable=False)
|
||||||
|
currency = Column(String(3), default="usd")
|
||||||
|
payment_type = Column(String(50), nullable=False) # subscription, credits, one_time
|
||||||
|
status = Column(String(20), nullable=False) # succeeded, failed, pending, refunded
|
||||||
|
|
||||||
|
# Description
|
||||||
|
description = Column(String(255), nullable=True)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
# Indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ix_payment_history_user', 'user_id', 'created_at'),
|
||||||
|
)
|
||||||
341
database/repositories.py
Normal file
341
database/repositories.py
Normal file
@ -0,0 +1,341 @@
|
|||||||
|
"""
|
||||||
|
Repository layer for database operations
|
||||||
|
Provides clean interface for CRUD operations
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import and_, func, or_
|
||||||
|
|
||||||
|
from database.models import (
|
||||||
|
User, Translation, ApiKey, UsageLog, PaymentHistory,
|
||||||
|
PlanType, SubscriptionStatus
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepository:
|
||||||
|
"""Repository for User database operations"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_by_id(self, user_id: str) -> Optional[User]:
|
||||||
|
"""Get user by ID"""
|
||||||
|
return self.db.query(User).filter(User.id == user_id).first()
|
||||||
|
|
||||||
|
def get_by_email(self, email: str) -> Optional[User]:
|
||||||
|
"""Get user by email (case-insensitive)"""
|
||||||
|
return self.db.query(User).filter(
|
||||||
|
func.lower(User.email) == email.lower()
|
||||||
|
).first()
|
||||||
|
|
||||||
|
def get_by_stripe_customer(self, stripe_customer_id: str) -> Optional[User]:
|
||||||
|
"""Get user by Stripe customer ID"""
|
||||||
|
return self.db.query(User).filter(
|
||||||
|
User.stripe_customer_id == stripe_customer_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
name: str,
|
||||||
|
password_hash: str,
|
||||||
|
plan: PlanType = PlanType.FREE
|
||||||
|
) -> User:
|
||||||
|
"""Create a new user"""
|
||||||
|
user = User(
|
||||||
|
email=email.lower(),
|
||||||
|
name=name,
|
||||||
|
password_hash=password_hash,
|
||||||
|
plan=plan,
|
||||||
|
subscription_status=SubscriptionStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
self.db.add(user)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def update(self, user_id: str, **kwargs) -> Optional[User]:
|
||||||
|
"""Update user fields"""
|
||||||
|
user = self.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(user, key):
|
||||||
|
setattr(user, key, value)
|
||||||
|
|
||||||
|
user.updated_at = datetime.utcnow()
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def delete(self, user_id: str) -> bool:
|
||||||
|
"""Delete a user"""
|
||||||
|
user = self.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.db.delete(user)
|
||||||
|
self.db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def increment_usage(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
docs: int = 0,
|
||||||
|
pages: int = 0,
|
||||||
|
api_calls: int = 0
|
||||||
|
) -> Optional[User]:
|
||||||
|
"""Increment usage counters"""
|
||||||
|
user = self.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if usage needs to be reset (monthly)
|
||||||
|
if user.usage_reset_date:
|
||||||
|
now = datetime.utcnow()
|
||||||
|
if now.month != user.usage_reset_date.month or now.year != user.usage_reset_date.year:
|
||||||
|
user.docs_translated_this_month = 0
|
||||||
|
user.pages_translated_this_month = 0
|
||||||
|
user.api_calls_this_month = 0
|
||||||
|
user.usage_reset_date = now
|
||||||
|
|
||||||
|
user.docs_translated_this_month += docs
|
||||||
|
user.pages_translated_this_month += pages
|
||||||
|
user.api_calls_this_month += api_calls
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
|
||||||
|
"""Add extra credits to user"""
|
||||||
|
user = self.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user.extra_credits += credits
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def use_credits(self, user_id: str, credits: int) -> bool:
|
||||||
|
"""Use credits from user balance"""
|
||||||
|
user = self.get_by_id(user_id)
|
||||||
|
if not user or user.extra_credits < credits:
|
||||||
|
return False
|
||||||
|
|
||||||
|
user.extra_credits -= credits
|
||||||
|
self.db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_all_users(
|
||||||
|
self,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
plan: Optional[PlanType] = None
|
||||||
|
) -> List[User]:
|
||||||
|
"""Get all users with pagination"""
|
||||||
|
query = self.db.query(User)
|
||||||
|
if plan:
|
||||||
|
query = query.filter(User.plan == plan)
|
||||||
|
return query.offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
def count_users(self, plan: Optional[PlanType] = None) -> int:
|
||||||
|
"""Count total users"""
|
||||||
|
query = self.db.query(func.count(User.id))
|
||||||
|
if plan:
|
||||||
|
query = query.filter(User.plan == plan)
|
||||||
|
return query.scalar()
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationRepository:
|
||||||
|
"""Repository for Translation database operations"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
original_filename: str,
|
||||||
|
file_type: str,
|
||||||
|
target_language: str,
|
||||||
|
provider: str,
|
||||||
|
source_language: str = "auto",
|
||||||
|
file_size_bytes: int = 0,
|
||||||
|
page_count: int = 0,
|
||||||
|
) -> Translation:
|
||||||
|
"""Create a new translation record"""
|
||||||
|
translation = Translation(
|
||||||
|
user_id=user_id,
|
||||||
|
original_filename=original_filename,
|
||||||
|
file_type=file_type,
|
||||||
|
file_size_bytes=file_size_bytes,
|
||||||
|
page_count=page_count,
|
||||||
|
source_language=source_language,
|
||||||
|
target_language=target_language,
|
||||||
|
provider=provider,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
self.db.add(translation)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(translation)
|
||||||
|
return translation
|
||||||
|
|
||||||
|
def update_status(
|
||||||
|
self,
|
||||||
|
translation_id: str,
|
||||||
|
status: str,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
processing_time_ms: Optional[int] = None,
|
||||||
|
characters_translated: Optional[int] = None,
|
||||||
|
) -> Optional[Translation]:
|
||||||
|
"""Update translation status"""
|
||||||
|
translation = self.db.query(Translation).filter(
|
||||||
|
Translation.id == translation_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not translation:
|
||||||
|
return None
|
||||||
|
|
||||||
|
translation.status = status
|
||||||
|
if error_message:
|
||||||
|
translation.error_message = error_message
|
||||||
|
if processing_time_ms:
|
||||||
|
translation.processing_time_ms = processing_time_ms
|
||||||
|
if characters_translated:
|
||||||
|
translation.characters_translated = characters_translated
|
||||||
|
if status == "completed":
|
||||||
|
translation.completed_at = datetime.utcnow()
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(translation)
|
||||||
|
return translation
|
||||||
|
|
||||||
|
def get_user_translations(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
) -> List[Translation]:
|
||||||
|
"""Get user's translation history"""
|
||||||
|
query = self.db.query(Translation).filter(Translation.user_id == user_id)
|
||||||
|
if status:
|
||||||
|
query = query.filter(Translation.status == status)
|
||||||
|
return query.order_by(Translation.created_at.desc()).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
def get_user_stats(self, user_id: str, days: int = 30) -> Dict[str, Any]:
|
||||||
|
"""Get user's translation statistics"""
|
||||||
|
since = datetime.utcnow() - timedelta(days=days)
|
||||||
|
|
||||||
|
result = self.db.query(
|
||||||
|
func.count(Translation.id).label("total_translations"),
|
||||||
|
func.sum(Translation.page_count).label("total_pages"),
|
||||||
|
func.sum(Translation.characters_translated).label("total_characters"),
|
||||||
|
).filter(
|
||||||
|
and_(
|
||||||
|
Translation.user_id == user_id,
|
||||||
|
Translation.created_at >= since,
|
||||||
|
Translation.status == "completed",
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_translations": result.total_translations or 0,
|
||||||
|
"total_pages": result.total_pages or 0,
|
||||||
|
"total_characters": result.total_characters or 0,
|
||||||
|
"period_days": days,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyRepository:
|
||||||
|
"""Repository for API Key database operations"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_key(key: str) -> str:
|
||||||
|
"""Hash an API key"""
|
||||||
|
return hashlib.sha256(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
name: str,
|
||||||
|
scopes: List[str] = None,
|
||||||
|
expires_in_days: Optional[int] = None,
|
||||||
|
) -> tuple[ApiKey, str]:
|
||||||
|
"""Create a new API key. Returns (ApiKey, raw_key)"""
|
||||||
|
# Generate a secure random key
|
||||||
|
raw_key = f"tr_{secrets.token_urlsafe(32)}"
|
||||||
|
key_hash = self.hash_key(raw_key)
|
||||||
|
key_prefix = raw_key[:10]
|
||||||
|
|
||||||
|
expires_at = None
|
||||||
|
if expires_in_days:
|
||||||
|
expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
|
||||||
|
|
||||||
|
api_key = ApiKey(
|
||||||
|
user_id=user_id,
|
||||||
|
name=name,
|
||||||
|
key_hash=key_hash,
|
||||||
|
key_prefix=key_prefix,
|
||||||
|
scopes=scopes or ["translate"],
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
self.db.add(api_key)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(api_key)
|
||||||
|
|
||||||
|
return api_key, raw_key
|
||||||
|
|
||||||
|
def get_by_key(self, raw_key: str) -> Optional[ApiKey]:
|
||||||
|
"""Get API key by raw key value"""
|
||||||
|
key_hash = self.hash_key(raw_key)
|
||||||
|
api_key = self.db.query(ApiKey).filter(
|
||||||
|
and_(
|
||||||
|
ApiKey.key_hash == key_hash,
|
||||||
|
ApiKey.is_active == True,
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
# Check expiration
|
||||||
|
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update last used
|
||||||
|
api_key.last_used_at = datetime.utcnow()
|
||||||
|
api_key.usage_count += 1
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
def get_user_keys(self, user_id: str) -> List[ApiKey]:
|
||||||
|
"""Get all API keys for a user"""
|
||||||
|
return self.db.query(ApiKey).filter(
|
||||||
|
ApiKey.user_id == user_id
|
||||||
|
).order_by(ApiKey.created_at.desc()).all()
|
||||||
|
|
||||||
|
def revoke(self, key_id: str, user_id: str) -> bool:
|
||||||
|
"""Revoke an API key"""
|
||||||
|
api_key = self.db.query(ApiKey).filter(
|
||||||
|
and_(
|
||||||
|
ApiKey.id == key_id,
|
||||||
|
ApiKey.user_id == user_id,
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
return False
|
||||||
|
|
||||||
|
api_key.is_active = False
|
||||||
|
self.db.commit()
|
||||||
|
return True
|
||||||
@ -4,6 +4,35 @@
|
|||||||
version: '3.8'
|
version: '3.8'
|
||||||
|
|
||||||
services:
|
services:
|
||||||
|
# ===========================================
|
||||||
|
# PostgreSQL Database
|
||||||
|
# ===========================================
|
||||||
|
postgres:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
container_name: translate-postgres
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
- POSTGRES_USER=${POSTGRES_USER:-translate}
|
||||||
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-translate_secret_123}
|
||||||
|
- POSTGRES_DB=${POSTGRES_DB:-translate_db}
|
||||||
|
- PGDATA=/var/lib/postgresql/data/pgdata
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
networks:
|
||||||
|
- translate-network
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-translate} -d ${POSTGRES_DB:-translate_db}"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
start_period: 10s
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
memory: 512M
|
||||||
|
reservations:
|
||||||
|
memory: 128M
|
||||||
|
|
||||||
# ===========================================
|
# ===========================================
|
||||||
# Backend API Service
|
# Backend API Service
|
||||||
# ===========================================
|
# ===========================================
|
||||||
@ -14,23 +43,45 @@ services:
|
|||||||
container_name: translate-backend
|
container_name: translate-backend
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
environment:
|
environment:
|
||||||
|
# Database
|
||||||
|
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-translate}:${POSTGRES_PASSWORD:-translate_secret_123}@postgres:5432/${POSTGRES_DB:-translate_db}
|
||||||
|
# Redis
|
||||||
|
- REDIS_URL=redis://redis:6379/0
|
||||||
|
# Translation Services
|
||||||
- TRANSLATION_SERVICE=${TRANSLATION_SERVICE:-ollama}
|
- TRANSLATION_SERVICE=${TRANSLATION_SERVICE:-ollama}
|
||||||
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://ollama:11434}
|
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://ollama:11434}
|
||||||
- OLLAMA_MODEL=${OLLAMA_MODEL:-llama3}
|
- OLLAMA_MODEL=${OLLAMA_MODEL:-llama3}
|
||||||
- DEEPL_API_KEY=${DEEPL_API_KEY:-}
|
- DEEPL_API_KEY=${DEEPL_API_KEY:-}
|
||||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||||
|
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
|
||||||
|
# File Limits
|
||||||
- MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
|
- MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
|
||||||
|
# Rate Limiting
|
||||||
- RATE_LIMIT_REQUESTS_PER_MINUTE=${RATE_LIMIT_REQUESTS_PER_MINUTE:-60}
|
- RATE_LIMIT_REQUESTS_PER_MINUTE=${RATE_LIMIT_REQUESTS_PER_MINUTE:-60}
|
||||||
- RATE_LIMIT_TRANSLATIONS_PER_MINUTE=${RATE_LIMIT_TRANSLATIONS_PER_MINUTE:-10}
|
- RATE_LIMIT_TRANSLATIONS_PER_MINUTE=${RATE_LIMIT_TRANSLATIONS_PER_MINUTE:-10}
|
||||||
- ADMIN_USERNAME=${ADMIN_USERNAME:-admin}
|
# Admin Auth (CHANGE IN PRODUCTION!)
|
||||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-changeme123}
|
- ADMIN_USERNAME=${ADMIN_USERNAME}
|
||||||
- CORS_ORIGINS=${CORS_ORIGINS:-*}
|
- ADMIN_PASSWORD=${ADMIN_PASSWORD}
|
||||||
|
# Security
|
||||||
|
- JWT_SECRET=${JWT_SECRET}
|
||||||
|
- CORS_ORIGINS=${CORS_ORIGINS:-https://yourdomain.com}
|
||||||
|
# Stripe Payments
|
||||||
|
- STRIPE_SECRET_KEY=${STRIPE_SECRET_KEY:-}
|
||||||
|
- STRIPE_WEBHOOK_SECRET=${STRIPE_WEBHOOK_SECRET:-}
|
||||||
|
- STRIPE_STARTER_PRICE_ID=${STRIPE_STARTER_PRICE_ID:-}
|
||||||
|
- STRIPE_PRO_PRICE_ID=${STRIPE_PRO_PRICE_ID:-}
|
||||||
|
- STRIPE_BUSINESS_PRICE_ID=${STRIPE_BUSINESS_PRICE_ID:-}
|
||||||
volumes:
|
volumes:
|
||||||
- uploads_data:/app/uploads
|
- uploads_data:/app/uploads
|
||||||
- outputs_data:/app/outputs
|
- outputs_data:/app/outputs
|
||||||
- logs_data:/app/logs
|
- logs_data:/app/logs
|
||||||
networks:
|
networks:
|
||||||
- translate-network
|
- translate-network
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
@ -117,7 +168,7 @@ services:
|
|||||||
- with-ollama
|
- with-ollama
|
||||||
|
|
||||||
# ===========================================
|
# ===========================================
|
||||||
# Redis (Optional - For caching & sessions)
|
# Redis (Caching & Sessions)
|
||||||
# ===========================================
|
# ===========================================
|
||||||
redis:
|
redis:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
@ -130,11 +181,9 @@ services:
|
|||||||
- translate-network
|
- translate-network
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
interval: 30s
|
interval: 10s
|
||||||
timeout: 10s
|
timeout: 5s
|
||||||
retries: 3
|
retries: 5
|
||||||
profiles:
|
|
||||||
- with-cache
|
|
||||||
|
|
||||||
# ===========================================
|
# ===========================================
|
||||||
# Prometheus (Optional - Monitoring)
|
# Prometheus (Optional - Monitoring)
|
||||||
@ -190,6 +239,8 @@ networks:
|
|||||||
# Volumes
|
# Volumes
|
||||||
# ===========================================
|
# ===========================================
|
||||||
volumes:
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
driver: local
|
||||||
uploads_data:
|
uploads_data:
|
||||||
driver: local
|
driver: local
|
||||||
outputs_data:
|
outputs_data:
|
||||||
|
|||||||
12
main.py
12
main.py
@ -204,6 +204,18 @@ async def lifespan(app: FastAPI):
|
|||||||
# Startup
|
# Startup
|
||||||
logger.info("Starting Document Translation API...")
|
logger.info("Starting Document Translation API...")
|
||||||
config.ensure_directories()
|
config.ensure_directories()
|
||||||
|
|
||||||
|
# Initialize database
|
||||||
|
try:
|
||||||
|
from database.connection import init_db, check_db_connection
|
||||||
|
init_db()
|
||||||
|
if check_db_connection():
|
||||||
|
logger.info("✅ Database connection verified")
|
||||||
|
else:
|
||||||
|
logger.warning("⚠️ Database connection check failed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ Database initialization skipped: {e}")
|
||||||
|
|
||||||
await cleanup_manager.start()
|
await cleanup_manager.start()
|
||||||
logger.info("API ready to accept requests")
|
logger.info("API ready to accept requests")
|
||||||
|
|
||||||
|
|||||||
@ -28,8 +28,7 @@ stripe==7.0.0
|
|||||||
# Session storage & caching (optional but recommended for production)
|
# Session storage & caching (optional but recommended for production)
|
||||||
redis==5.0.1
|
redis==5.0.1
|
||||||
|
|
||||||
# Database (optional but recommended for production)
|
# Database (recommended for production)
|
||||||
# sqlalchemy==2.0.25
|
sqlalchemy==2.0.25
|
||||||
# asyncpg==0.29.0 # PostgreSQL async driver
|
psycopg2-binary==2.9.9 # PostgreSQL driver
|
||||||
# alembic==1.13.1 # Database migrations
|
alembic==1.13.1 # Database migrations
|
||||||
|
|
||||||
|
|||||||
160
scripts/migrate_to_db.py
Normal file
160
scripts/migrate_to_db.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
Migration script to move data from JSON files to database
|
||||||
|
Run this once to migrate existing users to the new database system
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from database.connection import init_db, get_db_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
from database.models import PlanType, SubscriptionStatus
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_users_from_json():
|
||||||
|
"""Migrate users from JSON file to database"""
|
||||||
|
json_path = Path("data/users.json")
|
||||||
|
|
||||||
|
if not json_path.exists():
|
||||||
|
print("No users.json found, nothing to migrate")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Initialize database
|
||||||
|
print("Initializing database tables...")
|
||||||
|
init_db()
|
||||||
|
|
||||||
|
# Load JSON data
|
||||||
|
with open(json_path, 'r') as f:
|
||||||
|
users_data = json.load(f)
|
||||||
|
|
||||||
|
print(f"Found {len(users_data)} users to migrate")
|
||||||
|
|
||||||
|
migrated = 0
|
||||||
|
skipped = 0
|
||||||
|
errors = 0
|
||||||
|
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
|
||||||
|
for user_id, user_data in users_data.items():
|
||||||
|
try:
|
||||||
|
# Check if user already exists
|
||||||
|
existing = repo.get_by_email(user_data.get('email', ''))
|
||||||
|
if existing:
|
||||||
|
print(f" Skipping {user_data.get('email')} - already exists")
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Map plan string to enum
|
||||||
|
plan_str = user_data.get('plan', 'free')
|
||||||
|
try:
|
||||||
|
plan = PlanType(plan_str)
|
||||||
|
except ValueError:
|
||||||
|
plan = PlanType.FREE
|
||||||
|
|
||||||
|
# Map subscription status
|
||||||
|
status_str = user_data.get('subscription_status', 'active')
|
||||||
|
try:
|
||||||
|
status = SubscriptionStatus(status_str)
|
||||||
|
except ValueError:
|
||||||
|
status = SubscriptionStatus.ACTIVE
|
||||||
|
|
||||||
|
# Create user with original ID
|
||||||
|
from database.models import User
|
||||||
|
user = User(
|
||||||
|
id=user_id,
|
||||||
|
email=user_data.get('email', '').lower(),
|
||||||
|
name=user_data.get('name', ''),
|
||||||
|
password_hash=user_data.get('password_hash', ''),
|
||||||
|
email_verified=user_data.get('email_verified', False),
|
||||||
|
avatar_url=user_data.get('avatar_url'),
|
||||||
|
plan=plan,
|
||||||
|
subscription_status=status,
|
||||||
|
stripe_customer_id=user_data.get('stripe_customer_id'),
|
||||||
|
stripe_subscription_id=user_data.get('stripe_subscription_id'),
|
||||||
|
docs_translated_this_month=user_data.get('docs_translated_this_month', 0),
|
||||||
|
pages_translated_this_month=user_data.get('pages_translated_this_month', 0),
|
||||||
|
api_calls_this_month=user_data.get('api_calls_this_month', 0),
|
||||||
|
extra_credits=user_data.get('extra_credits', 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse dates
|
||||||
|
if user_data.get('created_at'):
|
||||||
|
try:
|
||||||
|
user.created_at = datetime.fromisoformat(user_data['created_at'].replace('Z', '+00:00'))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if user_data.get('updated_at'):
|
||||||
|
try:
|
||||||
|
user.updated_at = datetime.fromisoformat(user_data['updated_at'].replace('Z', '+00:00'))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
print(f" Migrated: {user.email}")
|
||||||
|
migrated += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error migrating {user_data.get('email', user_id)}: {e}")
|
||||||
|
errors += 1
|
||||||
|
db.rollback()
|
||||||
|
|
||||||
|
print(f"\nMigration complete:")
|
||||||
|
print(f" Migrated: {migrated}")
|
||||||
|
print(f" Skipped: {skipped}")
|
||||||
|
print(f" Errors: {errors}")
|
||||||
|
|
||||||
|
# Backup original file
|
||||||
|
if migrated > 0:
|
||||||
|
backup_path = json_path.with_suffix('.json.bak')
|
||||||
|
os.rename(json_path, backup_path)
|
||||||
|
print(f"\nOriginal file backed up to: {backup_path}")
|
||||||
|
|
||||||
|
return migrated
|
||||||
|
|
||||||
|
|
||||||
|
def verify_migration():
|
||||||
|
"""Verify the migration was successful"""
|
||||||
|
from database.connection import get_db_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
count = repo.count_users()
|
||||||
|
print(f"\nDatabase now contains {count} users")
|
||||||
|
|
||||||
|
# List first 5 users
|
||||||
|
users = repo.get_all_users(limit=5)
|
||||||
|
if users:
|
||||||
|
print("\nSample users:")
|
||||||
|
for user in users:
|
||||||
|
print(f" - {user.email} ({user.plan.value})")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=" * 50)
|
||||||
|
print("JSON to Database Migration Script")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Check environment
|
||||||
|
db_url = os.getenv("DATABASE_URL", "")
|
||||||
|
if db_url:
|
||||||
|
print(f"Database: PostgreSQL")
|
||||||
|
else:
|
||||||
|
print(f"Database: SQLite (development)")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Run migration
|
||||||
|
migrate_users_from_json()
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
verify_migration()
|
||||||
@ -1,5 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Authentication service with JWT tokens and password hashing
|
Authentication service with JWT tokens and password hashing
|
||||||
|
|
||||||
|
This service provides user authentication with automatic backend selection:
|
||||||
|
- If DATABASE_URL is configured: Uses PostgreSQL database
|
||||||
|
- Otherwise: Falls back to JSON file storage (development mode)
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
@ -8,6 +12,9 @@ from datetime import datetime, timedelta
|
|||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Try to import optional dependencies
|
# Try to import optional dependencies
|
||||||
try:
|
try:
|
||||||
@ -15,6 +22,7 @@ try:
|
|||||||
JWT_AVAILABLE = True
|
JWT_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
JWT_AVAILABLE = False
|
JWT_AVAILABLE = False
|
||||||
|
logger.warning("PyJWT not installed. Using fallback token encoding.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
@ -22,17 +30,37 @@ try:
|
|||||||
PASSLIB_AVAILABLE = True
|
PASSLIB_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PASSLIB_AVAILABLE = False
|
PASSLIB_AVAILABLE = False
|
||||||
|
logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")
|
||||||
|
|
||||||
|
# Check if database is configured
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||||
|
USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql"))
|
||||||
|
|
||||||
|
if USE_DATABASE:
|
||||||
|
try:
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
from database.connection import get_sync_session, init_db as _init_db
|
||||||
|
from database import models as db_models
|
||||||
|
DATABASE_AVAILABLE = True
|
||||||
|
logger.info("Database backend enabled for authentication")
|
||||||
|
except ImportError as e:
|
||||||
|
DATABASE_AVAILABLE = False
|
||||||
|
USE_DATABASE = False
|
||||||
|
logger.warning(f"Database modules not available: {e}. Using JSON storage.")
|
||||||
|
else:
|
||||||
|
DATABASE_AVAILABLE = False
|
||||||
|
logger.info("Using JSON file storage for authentication (DATABASE_URL not configured)")
|
||||||
|
|
||||||
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
|
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
|
||||||
|
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
|
SECRET_KEY = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32)))
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||||
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||||
|
|
||||||
# Simple file-based storage (replace with database in production)
|
# Simple file-based storage (used when database is not configured)
|
||||||
USERS_FILE = Path("data/users.json")
|
USERS_FILE = Path("data/users.json")
|
||||||
USERS_FILE.parent.mkdir(exist_ok=True)
|
USERS_FILE.parent.mkdir(exist_ok=True)
|
||||||
|
|
||||||
@ -117,7 +145,7 @@ def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
|||||||
|
|
||||||
|
|
||||||
def load_users() -> Dict[str, Dict]:
|
def load_users() -> Dict[str, Dict]:
|
||||||
"""Load users from file storage"""
|
"""Load users from file storage (JSON backend only)"""
|
||||||
if USERS_FILE.exists():
|
if USERS_FILE.exists():
|
||||||
try:
|
try:
|
||||||
with open(USERS_FILE, 'r') as f:
|
with open(USERS_FILE, 'r') as f:
|
||||||
@ -128,54 +156,109 @@ def load_users() -> Dict[str, Dict]:
|
|||||||
|
|
||||||
|
|
||||||
def save_users(users: Dict[str, Dict]):
|
def save_users(users: Dict[str, Dict]):
|
||||||
"""Save users to file storage"""
|
"""Save users to file storage (JSON backend only)"""
|
||||||
with open(USERS_FILE, 'w') as f:
|
with open(USERS_FILE, 'w') as f:
|
||||||
json.dump(users, f, indent=2, default=str)
|
json.dump(users, f, indent=2, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
def _db_user_to_model(db_user) -> User:
|
||||||
|
"""Convert database user model to Pydantic User model"""
|
||||||
|
return User(
|
||||||
|
id=str(db_user.id),
|
||||||
|
email=db_user.email,
|
||||||
|
name=db_user.name or "",
|
||||||
|
password_hash=db_user.password_hash,
|
||||||
|
avatar_url=db_user.avatar_url,
|
||||||
|
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
|
||||||
|
subscription_status=SubscriptionStatus(db_user.subscription_status) if db_user.subscription_status else SubscriptionStatus.ACTIVE,
|
||||||
|
stripe_customer_id=db_user.stripe_customer_id,
|
||||||
|
stripe_subscription_id=db_user.stripe_subscription_id,
|
||||||
|
docs_translated_this_month=db_user.docs_translated_this_month or 0,
|
||||||
|
pages_translated_this_month=db_user.pages_translated_this_month or 0,
|
||||||
|
api_calls_this_month=db_user.api_calls_this_month or 0,
|
||||||
|
extra_credits=db_user.extra_credits or 0,
|
||||||
|
usage_reset_date=db_user.usage_reset_date or datetime.utcnow(),
|
||||||
|
default_source_lang=db_user.default_source_lang or "en",
|
||||||
|
default_target_lang=db_user.default_target_lang or "es",
|
||||||
|
default_provider=db_user.default_provider or "google",
|
||||||
|
created_at=db_user.created_at or datetime.utcnow(),
|
||||||
|
updated_at=db_user.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_email(email: str) -> Optional[User]:
|
def get_user_by_email(email: str) -> Optional[User]:
|
||||||
"""Get a user by email"""
|
"""Get a user by email"""
|
||||||
users = load_users()
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
for user_data in users.values():
|
with get_sync_session() as session:
|
||||||
if user_data.get("email", "").lower() == email.lower():
|
repo = UserRepository(session)
|
||||||
return User(**user_data)
|
db_user = repo.get_by_email(email)
|
||||||
return None
|
if db_user:
|
||||||
|
return _db_user_to_model(db_user)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
users = load_users()
|
||||||
|
for user_data in users.values():
|
||||||
|
if user_data.get("email", "").lower() == email.lower():
|
||||||
|
return User(**user_data)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||||
"""Get a user by ID"""
|
"""Get a user by ID"""
|
||||||
users = load_users()
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
if user_id in users:
|
with get_sync_session() as session:
|
||||||
return User(**users[user_id])
|
repo = UserRepository(session)
|
||||||
return None
|
db_user = repo.get_by_id(user_id)
|
||||||
|
if db_user:
|
||||||
|
return _db_user_to_model(db_user)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
users = load_users()
|
||||||
|
if user_id in users:
|
||||||
|
return User(**users[user_id])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def create_user(user_create: UserCreate) -> User:
|
def create_user(user_create: UserCreate) -> User:
|
||||||
"""Create a new user"""
|
"""Create a new user"""
|
||||||
users = load_users()
|
|
||||||
|
|
||||||
# Check if email exists
|
# Check if email exists
|
||||||
if get_user_by_email(user_create.email):
|
if get_user_by_email(user_create.email):
|
||||||
raise ValueError("Email already registered")
|
raise ValueError("Email already registered")
|
||||||
|
|
||||||
# Generate user ID
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
user_id = secrets.token_urlsafe(16)
|
with get_sync_session() as session:
|
||||||
|
repo = UserRepository(session)
|
||||||
# Create user
|
db_user = repo.create(
|
||||||
user = User(
|
email=user_create.email,
|
||||||
id=user_id,
|
name=user_create.name,
|
||||||
email=user_create.email,
|
password_hash=hash_password(user_create.password),
|
||||||
name=user_create.name,
|
plan=PlanType.FREE.value,
|
||||||
password_hash=hash_password(user_create.password),
|
subscription_status=SubscriptionStatus.ACTIVE.value
|
||||||
plan=PlanType.FREE,
|
)
|
||||||
subscription_status=SubscriptionStatus.ACTIVE,
|
session.commit()
|
||||||
)
|
session.refresh(db_user)
|
||||||
|
return _db_user_to_model(db_user)
|
||||||
# Save to storage
|
else:
|
||||||
users[user_id] = user.model_dump()
|
users = load_users()
|
||||||
save_users(users)
|
|
||||||
|
# Generate user ID
|
||||||
return user
|
user_id = secrets.token_urlsafe(16)
|
||||||
|
|
||||||
|
# Create user
|
||||||
|
user = User(
|
||||||
|
id=user_id,
|
||||||
|
email=user_create.email,
|
||||||
|
name=user_create.name,
|
||||||
|
password_hash=hash_password(user_create.password),
|
||||||
|
plan=PlanType.FREE,
|
||||||
|
subscription_status=SubscriptionStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to storage
|
||||||
|
users[user_id] = user.model_dump()
|
||||||
|
save_users(users)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(email: str, password: str) -> Optional[User]:
|
def authenticate_user(email: str, password: str) -> Optional[User]:
|
||||||
@ -190,15 +273,25 @@ def authenticate_user(email: str, password: str) -> Optional[User]:
|
|||||||
|
|
||||||
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
||||||
"""Update a user's data"""
|
"""Update a user's data"""
|
||||||
users = load_users()
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
if user_id not in users:
|
with get_sync_session() as session:
|
||||||
return None
|
repo = UserRepository(session)
|
||||||
|
db_user = repo.update(user_id, updates)
|
||||||
users[user_id].update(updates)
|
if db_user:
|
||||||
users[user_id]["updated_at"] = datetime.utcnow().isoformat()
|
session.commit()
|
||||||
save_users(users)
|
session.refresh(db_user)
|
||||||
|
return _db_user_to_model(db_user)
|
||||||
return User(**users[user_id])
|
return None
|
||||||
|
else:
|
||||||
|
users = load_users()
|
||||||
|
if user_id not in users:
|
||||||
|
return None
|
||||||
|
|
||||||
|
users[user_id].update(updates)
|
||||||
|
users[user_id]["updated_at"] = datetime.utcnow().isoformat()
|
||||||
|
save_users(users)
|
||||||
|
|
||||||
|
return User(**users[user_id])
|
||||||
|
|
||||||
|
|
||||||
def check_usage_limits(user: User) -> Dict[str, Any]:
|
def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||||
@ -212,7 +305,7 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
|
|||||||
"docs_translated_this_month": 0,
|
"docs_translated_this_month": 0,
|
||||||
"pages_translated_this_month": 0,
|
"pages_translated_this_month": 0,
|
||||||
"api_calls_this_month": 0,
|
"api_calls_this_month": 0,
|
||||||
"usage_reset_date": now.isoformat()
|
"usage_reset_date": now.isoformat() if not USE_DATABASE else now
|
||||||
})
|
})
|
||||||
user.docs_translated_this_month = 0
|
user.docs_translated_this_month = 0
|
||||||
user.pages_translated_this_month = 0
|
user.pages_translated_this_month = 0
|
||||||
@ -248,8 +341,8 @@ def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> b
|
|||||||
if use_credits:
|
if use_credits:
|
||||||
updates["extra_credits"] = max(0, user.extra_credits - pages_count)
|
updates["extra_credits"] = max(0, user.extra_credits - pages_count)
|
||||||
|
|
||||||
update_user(user_id, updates)
|
result = update_user(user_id, updates)
|
||||||
return True
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
def add_credits(user_id: str, credits: int) -> bool:
|
def add_credits(user_id: str, credits: int) -> bool:
|
||||||
@ -258,5 +351,14 @@ def add_credits(user_id: str, credits: int) -> bool:
|
|||||||
if not user:
|
if not user:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
update_user(user_id, {"extra_credits": user.extra_credits + credits})
|
result = update_user(user_id, {"extra_credits": user.extra_credits + credits})
|
||||||
return True
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def init_database():
|
||||||
|
"""Initialize the database (call on application startup)"""
|
||||||
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
|
_init_db()
|
||||||
|
logger.info("Database initialized successfully")
|
||||||
|
else:
|
||||||
|
logger.info("Using JSON file storage")
|
||||||
|
|||||||
245
services/auth_service_db.py
Normal file
245
services/auth_service_db.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
"""
|
||||||
|
Database-backed authentication service
|
||||||
|
Replaces JSON file storage with SQLAlchemy
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Try to import optional dependencies
|
||||||
|
try:
|
||||||
|
import jwt
|
||||||
|
JWT_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
JWT_AVAILABLE = False
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
|
||||||
|
try:
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
PASSLIB_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PASSLIB_AVAILABLE = False
|
||||||
|
|
||||||
|
from database.connection import get_db_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
from database.models import User, PlanType, SubscriptionStatus
|
||||||
|
from models.subscription import PLANS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""Hash a password using bcrypt or fallback to SHA256"""
|
||||||
|
if PASSLIB_AVAILABLE:
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
else:
|
||||||
|
salt = secrets.token_hex(16)
|
||||||
|
hashed = hashlib.sha256(f"{salt}{password}".encode()).hexdigest()
|
||||||
|
return f"sha256${salt}${hashed}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""Verify a password against its hash"""
|
||||||
|
if PASSLIB_AVAILABLE and not hashed_password.startswith("sha256$"):
|
||||||
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
else:
|
||||||
|
parts = hashed_password.split("$")
|
||||||
|
if len(parts) == 3 and parts[0] == "sha256":
|
||||||
|
salt = parts[1]
|
||||||
|
expected_hash = parts[2]
|
||||||
|
actual_hash = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
|
||||||
|
return secrets.compare_digest(actual_hash, expected_hash)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
|
||||||
|
"""Create a JWT access token"""
|
||||||
|
expire = datetime.utcnow() + (expires_delta or timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS))
|
||||||
|
|
||||||
|
if not JWT_AVAILABLE:
|
||||||
|
token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": "access"}
|
||||||
|
return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
|
||||||
|
|
||||||
|
to_encode = {"sub": user_id, "exp": expire, "type": "access"}
|
||||||
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(user_id: str) -> str:
|
||||||
|
"""Create a JWT refresh token"""
|
||||||
|
expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
||||||
|
if not JWT_AVAILABLE:
|
||||||
|
token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": "refresh"}
|
||||||
|
return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
|
||||||
|
|
||||||
|
to_encode = {"sub": user_id, "exp": expire, "type": "refresh"}
|
||||||
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Verify a JWT token and return payload"""
|
||||||
|
if not JWT_AVAILABLE:
|
||||||
|
try:
|
||||||
|
data = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
|
||||||
|
exp = datetime.fromisoformat(data["exp"])
|
||||||
|
if exp < datetime.utcnow():
|
||||||
|
return None
|
||||||
|
return {"sub": data["user_id"], "type": data.get("type", "access")}
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
return payload
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
return None
|
||||||
|
except jwt.JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(email: str, name: str, password: str) -> User:
|
||||||
|
"""Create a new user in the database"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
|
||||||
|
# Check if email already exists
|
||||||
|
existing = repo.get_by_email(email)
|
||||||
|
if existing:
|
||||||
|
raise ValueError("Email already registered")
|
||||||
|
|
||||||
|
password_hash = hash_password(password)
|
||||||
|
user = repo.create(
|
||||||
|
email=email,
|
||||||
|
name=name,
|
||||||
|
password_hash=password_hash,
|
||||||
|
plan=PlanType.FREE,
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user(email: str, password: str) -> Optional[User]:
|
||||||
|
"""Authenticate user and return user object if valid"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
user = repo.get_by_email(email)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not verify_password(password, user.password_hash):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update last login
|
||||||
|
repo.update(user.id, last_login_at=datetime.utcnow())
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||||
|
"""Get user by ID from database"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
return repo.get_by_id(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_by_email(email: str) -> Optional[User]:
|
||||||
|
"""Get user by email from database"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
return repo.get_by_email(email)
|
||||||
|
|
||||||
|
|
||||||
|
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
||||||
|
"""Update user fields in database"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
return repo.update(user_id, **updates)
|
||||||
|
|
||||||
|
|
||||||
|
def add_credits(user_id: str, credits: int) -> bool:
|
||||||
|
"""Add credits to user account"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
result = repo.add_credits(user_id, credits)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def use_credits(user_id: str, credits: int) -> bool:
|
||||||
|
"""Use credits from user account"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
return repo.use_credits(user_id, credits)
|
||||||
|
|
||||||
|
|
||||||
|
def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool:
|
||||||
|
"""Increment user usage counters"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
result = repo.increment_usage(user_id, docs=docs, pages=pages, api_calls=api_calls)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def check_usage_limits(user_id: str) -> Dict[str, Any]:
|
||||||
|
"""Check if user is within their plan limits"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
user = repo.get_by_id(user_id)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
return {"allowed": False, "reason": "User not found"}
|
||||||
|
|
||||||
|
plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE])
|
||||||
|
|
||||||
|
# Check document limit
|
||||||
|
docs_limit = plan_config["docs_per_month"]
|
||||||
|
if docs_limit > 0 and user.docs_translated_this_month >= docs_limit:
|
||||||
|
# Check if user has extra credits
|
||||||
|
if user.extra_credits <= 0:
|
||||||
|
return {
|
||||||
|
"allowed": False,
|
||||||
|
"reason": "Monthly document limit reached",
|
||||||
|
"limit": docs_limit,
|
||||||
|
"used": user.docs_translated_this_month,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"allowed": True,
|
||||||
|
"docs_remaining": max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1,
|
||||||
|
"extra_credits": user.extra_credits,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_usage_stats(user_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get detailed usage statistics for a user"""
|
||||||
|
with get_db_session() as db:
|
||||||
|
repo = UserRepository(db)
|
||||||
|
user = repo.get_by_id(user_id)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"docs_used": user.docs_translated_this_month,
|
||||||
|
"docs_limit": plan_config["docs_per_month"],
|
||||||
|
"docs_remaining": max(0, plan_config["docs_per_month"] - user.docs_translated_this_month) if plan_config["docs_per_month"] > 0 else -1,
|
||||||
|
"pages_used": user.pages_translated_this_month,
|
||||||
|
"extra_credits": user.extra_credits,
|
||||||
|
"max_pages_per_doc": plan_config["max_pages_per_doc"],
|
||||||
|
"max_file_size_mb": plan_config["max_file_size_mb"],
|
||||||
|
"allowed_providers": plan_config["providers"],
|
||||||
|
"api_access": plan_config.get("api_access", False),
|
||||||
|
"api_calls_used": user.api_calls_this_month if plan_config.get("api_access") else 0,
|
||||||
|
"api_calls_limit": plan_config.get("api_calls_per_month", 0),
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user