- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration
246 lines
8.3 KiB
Python
246 lines
8.3 KiB
Python
"""
|
|
Database-backed authentication service
|
|
Replaces JSON file storage with SQLAlchemy
|
|
"""
|
|
import os
|
|
import secrets
|
|
import hashlib
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any
|
|
import logging
|
|
|
|
# Try to import optional dependencies
|
|
try:
|
|
import jwt
|
|
JWT_AVAILABLE = True
|
|
except ImportError:
|
|
JWT_AVAILABLE = False
|
|
import json
|
|
import base64
|
|
|
|
try:
|
|
from passlib.context import CryptContext
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
PASSLIB_AVAILABLE = True
|
|
except ImportError:
|
|
PASSLIB_AVAILABLE = False
|
|
|
|
from database.connection import get_db_session
|
|
from database.repositories import UserRepository
|
|
from database.models import User, PlanType, SubscriptionStatus
|
|
from models.subscription import PLANS
|
|
|
|
logger = logging.getLogger(__name__)


# Configuration
# JWT / fallback-token signing key. When JWT_SECRET_KEY is unset (or empty) a
# random per-process key is generated instead: every restart, and every
# separate worker process, then rejects tokens issued by the others. That is
# only acceptable for local development, so it is logged loudly. An empty
# value is treated as unset rather than used as an (insecure) empty HMAC key.
_env_secret = os.getenv("JWT_SECRET_KEY", "")
if _env_secret:
    SECRET_KEY = _env_secret
else:
    SECRET_KEY = secrets.token_urlsafe(32)
    logger.warning(
        "JWT_SECRET_KEY is not set; using a random per-process signing key. "
        "Issued tokens will not survive restarts or be shared across workers."
    )

ALGORITHM = "HS256"

ACCESS_TOKEN_EXPIRE_HOURS = 24

REFRESH_TOKEN_EXPIRE_DAYS = 30
|
|
|
|
|
|
# Work factor for the standard-library PBKDF2 fallback (600,000 follows the
# OWASP Password Storage Cheat Sheet for PBKDF2-HMAC-SHA256). The count is
# stored inside each hash, so raising it later does not break old hashes.
_PBKDF2_ITERATIONS = 600_000
_PBKDF2_PREFIX = "pbkdf2_sha256$"
# Format produced by earlier versions of this module (single salted SHA-256).
_LEGACY_SHA256_PREFIX = "sha256$"


def hash_password(password: str) -> str:
    """Hash a password for storage.

    Uses passlib's bcrypt context when passlib is installed. Otherwise falls
    back to the standard library's PBKDF2-HMAC-SHA256 with a random 16-byte
    salt, encoded as ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``.

    Args:
        password: The plain-text password.

    Returns:
        The encoded hash string, suitable for ``verify_password``.
    """
    if PASSLIB_AVAILABLE:
        return pwd_context.hash(password)
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, _PBKDF2_ITERATIONS)
    return f"{_PBKDF2_PREFIX}{_PBKDF2_ITERATIONS}${salt.hex()}${digest.hex()}"


def _verify_pbkdf2(plain_password: str, hashed_password: str) -> bool:
    """Check a password against a ``pbkdf2_sha256$iter$salt$digest`` string."""
    parts = hashed_password.split("$")
    if len(parts) != 4:
        return False
    _, iterations, salt_hex, expected_hex = parts
    try:
        expected = bytes.fromhex(expected_hex)
        actual = hashlib.pbkdf2_hmac(
            "sha256", plain_password.encode(), bytes.fromhex(salt_hex), int(iterations)
        )
    except (ValueError, OverflowError):
        # Non-hex salt/digest, non-integer or out-of-range iteration count.
        return False
    return secrets.compare_digest(actual, expected)


def _verify_legacy_sha256(plain_password: str, hashed_password: str) -> bool:
    """Check a password against the legacy ``sha256$salt$hexdigest`` format."""
    parts = hashed_password.split("$")
    if len(parts) != 3:
        return False
    _, salt, expected_hash = parts
    actual_hash = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    return secrets.compare_digest(actual_hash.encode(), expected_hash.encode())


def verify_password(plain_password: str, hashed_password: Optional[str]) -> bool:
    """Verify a password against a stored hash.

    Stdlib formats (``pbkdf2_sha256$`` and legacy ``sha256$``) are checked
    locally. Anything else is delegated to passlib.

    Args:
        plain_password: The candidate plain-text password.
        hashed_password: The stored hash. ``None`` or empty never matches.

    Returns:
        True only when the password matches. Malformed hashes, and hashes
        passlib cannot identify, yield False instead of raising.
    """
    if not hashed_password:
        return False
    if hashed_password.startswith(_PBKDF2_PREFIX):
        return _verify_pbkdf2(plain_password, hashed_password)
    if hashed_password.startswith(_LEGACY_SHA256_PREFIX):
        return _verify_legacy_sha256(plain_password, hashed_password)
    if not PASSLIB_AVAILABLE:
        # Without passlib, bcrypt (or other passlib-format) hashes cannot be checked.
        return False
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except ValueError:
        # passlib raises ValueError subclasses (e.g. UnknownHashError) for
        # hashes it cannot identify or parse.
        return False
|
|
|
|
|
|
def _sign_fallback_payload(encoded_payload: str) -> str:
    """Return the hex HMAC-SHA256 of *encoded_payload*, keyed with SECRET_KEY."""
    import hmac

    return hmac.new(SECRET_KEY.encode(), encoded_payload.encode(), hashlib.sha256).hexdigest()


def _encode_fallback_token(user_id: str, expire: datetime, token_type: str) -> str:
    """Build a signed ``<urlsafe-base64 JSON>.<hex HMAC>`` token (used when PyJWT is missing).

    The HMAC keeps clients from minting or altering tokens. Previously the
    base64 JSON was unsigned and could be forged for any user id.
    """
    token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": token_type}
    encoded = base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
    return f"{encoded}.{_sign_fallback_payload(encoded)}"


def _decode_fallback_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify and decode a token produced by ``_encode_fallback_token``.

    Returns:
        ``{"sub": <user_id>, "type": <type>}``, or ``None`` when the token is
        not a string, is malformed, has a bad signature, or has expired.
    """
    if not isinstance(token, str):
        return None
    # urlsafe base64 never contains ".", so the last "." separates the signature.
    encoded, sep, signature = token.rpartition(".")
    if not sep:
        return None
    expected = _sign_fallback_payload(encoded)
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(signature.encode(), expected.encode()):
        return None
    try:
        data = json.loads(base64.urlsafe_b64decode(encoded.encode()).decode())
        # "exp" is written from datetime.utcnow(), so compare against naive UTC.
        if datetime.fromisoformat(data["exp"]) < datetime.utcnow():
            return None
        return {"sub": data["user_id"], "type": data.get("type", "access")}
    except (ValueError, KeyError, TypeError, AttributeError):
        return None


def _issue_token(user_id: str, expire: datetime, token_type: str) -> str:
    """Encode a *token_type* token for *user_id* that expires at *expire*."""
    if not JWT_AVAILABLE:
        return _encode_fallback_token(user_id, expire, token_type)
    token = jwt.encode(
        {"sub": user_id, "exp": expire, "type": token_type},
        SECRET_KEY,
        algorithm=ALGORITHM,
    )
    # PyJWT < 2.0 returns bytes; 2.x returns str. Always hand back str.
    return token.decode() if isinstance(token, bytes) else token


def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create an access token for *user_id*.

    Args:
        user_id: Stored as the ``sub`` claim (``user_id`` in fallback tokens).
        expires_delta: Lifetime of the token. ``None`` means
            ACCESS_TOKEN_EXPIRE_HOURS. ``timedelta(0)`` is honoured rather
            than silently replaced by the default.

    Returns:
        A JWT when PyJWT is importable, otherwise an HMAC-signed fallback token.
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    return _issue_token(user_id, datetime.utcnow() + expires_delta, "access")


def create_refresh_token(user_id: str) -> str:
    """Create a refresh token for *user_id* valid for REFRESH_TOKEN_EXPIRE_DAYS."""
    expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return _issue_token(user_id, expire, "refresh")


def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token and return its payload, or ``None`` if it is invalid or expired.

    With PyJWT the full decoded claims are returned. For fallback tokens only
    ``sub`` and ``type`` are returned, matching the previous behaviour.
    """
    if not JWT_AVAILABLE:
        return _decode_fallback_token(token)
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError:
        # PyJWT's base class for decode failures. It covers ExpiredSignatureError,
        # DecodeError and InvalidSignatureError. (PyJWT has no ``JWTError``;
        # referencing it raised AttributeError for any non-expired bad token.)
        return None
|
|
|
|
|
|
def create_user(email: str, name: str, password: str) -> User:
    """Register a new account on the FREE plan.

    Args:
        email: Address to register. Rejected if ``UserRepository.get_by_email``
            already returns a user for it.
        name: Display name stored on the record.
        password: Plain-text password. Only ``hash_password(password)`` is passed
            to the repository.

    Returns:
        Whatever ``UserRepository.create`` returns for the new record.

    Raises:
        ValueError: If the email is already registered.

    NOTE(review): the returned ORM object is used after ``get_db_session()``
    exits; confirm that context manager does not expire/detach it.
    """
    with get_db_session() as session:
        users = UserRepository(session)
        if users.get_by_email(email):
            raise ValueError("Email already registered")
        return users.create(
            email=email,
            name=name,
            password_hash=hash_password(password),
            plan=PlanType.FREE,
        )
|
|
|
|
|
|
def authenticate_user(email: str, password: str) -> Optional[User]:
    """Check credentials and return the matching user, or ``None``.

    On success the user's ``last_login_at`` is set to ``datetime.utcnow()``
    through ``UserRepository.update`` before the user is returned.

    Args:
        email: Login email, looked up with ``UserRepository.get_by_email``.
        password: Plain-text password checked with ``verify_password``.
    """
    with get_db_session() as session:
        users = UserRepository(session)
        candidate = users.get_by_email(email)
        # Unknown email and wrong password are deliberately indistinguishable.
        if not candidate or not verify_password(password, candidate.password_hash):
            return None
        users.update(candidate.id, last_login_at=datetime.utcnow())
        return candidate
|
|
|
|
|
|
def get_user_by_id(user_id: str) -> Optional[User]:
    """Fetch a user by id via ``UserRepository.get_by_id`` in a fresh DB session."""
    with get_db_session() as session:
        found = UserRepository(session).get_by_id(user_id)
    return found
|
|
|
|
|
|
def get_user_by_email(email: str) -> Optional[User]:
    """Fetch a user by email via ``UserRepository.get_by_email`` in a fresh DB session."""
    with get_db_session() as session:
        found = UserRepository(session).get_by_email(email)
    return found
|
|
|
|
|
|
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply *updates* (column name -> new value) to a user.

    The dict is expanded into keyword arguments for ``UserRepository.update``,
    and that call's result is returned unchanged.
    """
    with get_db_session() as session:
        updated = UserRepository(session).update(user_id, **updates)
    return updated
|
|
|
|
|
|
def add_credits(user_id: str, credits: int) -> bool:
    """Grant *credits* to a user.

    Returns True when ``UserRepository.add_credits`` returns anything other
    than ``None``.
    """
    with get_db_session() as session:
        outcome = UserRepository(session).add_credits(user_id, credits)
    return outcome is not None
|
|
|
|
|
|
def use_credits(user_id: str, credits: int) -> bool:
    """Consume *credits* from a user; returns ``UserRepository.use_credits``'s result as-is."""
    with get_db_session() as session:
        consumed = UserRepository(session).use_credits(user_id, credits)
    return consumed
|
|
|
|
|
|
def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool:
    """Bump a user's usage counters by the given amounts.

    Returns True when ``UserRepository.increment_usage`` returns anything
    other than ``None``.
    """
    with get_db_session() as session:
        outcome = UserRepository(session).increment_usage(
            user_id, docs=docs, pages=pages, api_calls=api_calls
        )
    return outcome is not None
|
|
|
|
|
|
def check_usage_limits(user_id: str) -> Dict[str, Any]:
    """Decide whether a user may translate another document.

    The plan config comes from ``PLANS[user.plan]``, falling back to the FREE
    plan. A ``docs_per_month`` value <= 0 is treated as "no monthly cap".
    Once the cap is reached, the user is still allowed if ``extra_credits`` > 0.

    Returns:
        ``{"allowed": False, "reason": ...}`` (plus ``limit``/``used`` when the
        cap is hit), or ``{"allowed": True, "docs_remaining": ..., "extra_credits": ...}``
        where ``docs_remaining`` is -1 when the plan has no cap.

    NOTE(review): assumes PLANS is keyed by the same PlanType enum as
    ``user.plan`` (database.models); confirm, otherwise every user gets FREE limits.
    """
    with get_db_session() as session:
        account = UserRepository(session).get_by_id(user_id)
        if not account:
            return {"allowed": False, "reason": "User not found"}

        limits = PLANS.get(account.plan, PLANS[PlanType.FREE])
        monthly_cap = limits["docs_per_month"]
        capped = monthly_cap > 0

        over_cap = capped and account.docs_translated_this_month >= monthly_cap
        if over_cap and account.extra_credits <= 0:
            return {
                "allowed": False,
                "reason": "Monthly document limit reached",
                "limit": monthly_cap,
                "used": account.docs_translated_this_month,
            }

        if capped:
            remaining = max(0, monthly_cap - account.docs_translated_this_month)
        else:
            remaining = -1
        return {
            "allowed": True,
            "docs_remaining": remaining,
            "extra_credits": account.extra_credits,
        }
|
|
|
|
|
|
def get_user_usage_stats(user_id: str) -> Dict[str, Any]:
    """Summarise a user's monthly usage against their plan's limits.

    The plan config comes from ``PLANS[user.plan]``, falling back to the FREE
    plan. Returns ``{}`` when no user is found. ``docs_remaining`` is -1 when
    the plan's ``docs_per_month`` is <= 0. ``api_calls_used`` is reported as 0
    for plans without ``api_access``.
    """
    with get_db_session() as session:
        account = UserRepository(session).get_by_id(user_id)
        if not account:
            return {}

        limits = PLANS.get(account.plan, PLANS[PlanType.FREE])
        doc_cap = limits["docs_per_month"]
        docs_used = account.docs_translated_this_month
        has_api = limits.get("api_access", False)

        stats: Dict[str, Any] = {
            "docs_used": docs_used,
            "docs_limit": doc_cap,
            "docs_remaining": max(0, doc_cap - docs_used) if doc_cap > 0 else -1,
            "pages_used": account.pages_translated_this_month,
            "extra_credits": account.extra_credits,
            "max_pages_per_doc": limits["max_pages_per_doc"],
            "max_file_size_mb": limits["max_file_size_mb"],
            "allowed_providers": limits["providers"],
            "api_access": has_api,
            "api_calls_used": account.api_calls_this_month if has_api else 0,
            "api_calls_limit": limits.get("api_calls_per_month", 0),
        }
        return stats
|