- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration
365 lines
12 KiB
Python
365 lines
12 KiB
Python
"""
Authentication service with JWT tokens and password hashing.

This service provides user authentication with automatic backend selection:
- If DATABASE_URL is set to a PostgreSQL URL (starting with "postgresql"): uses the PostgreSQL database
- Otherwise: falls back to JSON file storage (development mode)
"""
|
|
import os
|
|
import secrets
|
|
import hashlib
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any
|
|
import json
|
|
from pathlib import Path
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)


# Optional dependencies: each one degrades to a stdlib fallback when missing.
try:
    import jwt

    JWT_AVAILABLE = True
except ImportError:
    JWT_AVAILABLE = False
    logger.warning("PyJWT not installed. Using fallback token encoding.")

try:
    from passlib.context import CryptContext

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    PASSLIB_AVAILABLE = True
except ImportError:
    PASSLIB_AVAILABLE = False
    logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")


# Storage backend selection: only PostgreSQL URLs enable the database backend.
DATABASE_URL = os.getenv("DATABASE_URL", "")
USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql"))

if USE_DATABASE:
    try:
        from database.repositories import UserRepository
        from database.connection import get_sync_session, init_db as _init_db
        from database import models as db_models

        DATABASE_AVAILABLE = True
        logger.info("Database backend enabled for authentication")
    except ImportError as e:
        # Invariant relied on elsewhere: USE_DATABASE is only True when the
        # database modules actually imported.
        DATABASE_AVAILABLE = False
        USE_DATABASE = False
        logger.warning("Database modules not available: %s. Using JSON storage.", e)
else:
    DATABASE_AVAILABLE = False
    if DATABASE_URL:
        # A non-PostgreSQL URL (e.g. sqlite://) is configured but ignored here;
        # say so instead of claiming DATABASE_URL is missing.
        logger.warning(
            "DATABASE_URL is set but is not a PostgreSQL URL; "
            "using JSON file storage for authentication"
        )
    else:
        logger.info("Using JSON file storage for authentication (DATABASE_URL not configured)")
|
|
|
|
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
|
|
|
|
|
|
# Configuration

# JWT signing secret. Empty values are treated as unset: a blank `JWT_SECRET=`
# line in an .env file must never yield an empty (trivially forgeable) key.
SECRET_KEY = os.getenv("JWT_SECRET") or os.getenv("JWT_SECRET_KEY") or ""
if not SECRET_KEY:
    SECRET_KEY = secrets.token_urlsafe(32)
    # A random per-process key means tokens die on restart and are rejected
    # by other worker processes; make that visible.
    logging.getLogger(__name__).warning(
        "JWT_SECRET / JWT_SECRET_KEY not set; using a random per-process secret. "
        "Issued tokens will not survive a restart or be shared across workers."
    )

ALGORITHM = "HS256"

ACCESS_TOKEN_EXPIRE_HOURS = 24

REFRESH_TOKEN_EXPIRE_DAYS = 30

# Simple file-based storage (used when database is not configured)
USERS_FILE = Path("data/users.json")
USERS_FILE.parent.mkdir(exist_ok=True)
|
|
|
|
|
|
# Work factor for the stdlib PBKDF2-HMAC-SHA256 fallback (OWASP Password
# Storage Cheat Sheet recommends >= 600,000 iterations for this PRF).
_PBKDF2_ITERATIONS = 600_000
# Scheme tags stored as the first "$"-separated field of fallback hashes.
_PBKDF2_SCHEME = "pbkdf2_sha256"
_LEGACY_SHA256_SCHEME = "sha256"


def hash_password(password: str) -> str:
    """Hash a password for storage.

    Uses passlib's bcrypt context when passlib is installed. Otherwise falls
    back to stdlib PBKDF2-HMAC-SHA256 with a random salt (32 hex chars),
    encoded as ``pbkdf2_sha256$<iterations>$<salt>$<digest hex>``. The older
    single-round ``sha256$<salt>$<digest>`` format is still accepted by
    ``verify_password`` but is no longer produced.
    """
    if PASSLIB_AVAILABLE:
        return pwd_context.hash(password)
    salt = secrets.token_hex(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), salt.encode(), _PBKDF2_ITERATIONS
    ).hex()
    return f"{_PBKDF2_SCHEME}${_PBKDF2_ITERATIONS}${salt}${digest}"


def _verify_pbkdf2(plain_password: str, hashed_password: str) -> bool:
    """Check a ``pbkdf2_sha256$<iterations>$<salt>$<digest>`` hash."""
    parts = hashed_password.split("$")
    if len(parts) != 4:
        return False
    _, iterations_text, salt, expected = parts
    try:
        iterations = int(iterations_text)
    except ValueError:
        return False
    if iterations < 1:
        return False
    actual = hashlib.pbkdf2_hmac(
        "sha256", plain_password.encode(), salt.encode(), iterations
    ).hex()
    # Compare as bytes so a corrupted (non-ASCII) stored digest returns False
    # instead of raising TypeError.
    return secrets.compare_digest(actual.encode(), expected.encode())


def _verify_legacy_sha256(plain_password: str, hashed_password: str) -> bool:
    """Check a legacy single-round ``sha256$<salt>$<digest>`` hash."""
    parts = hashed_password.split("$")
    if len(parts) != 3:
        return False
    _, salt, expected = parts
    actual = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
    return secrets.compare_digest(actual.encode(), expected.encode())


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its stored hash.

    Fallback-format hashes (``pbkdf2_sha256$...`` and legacy ``sha256$...``)
    are checked with the standard library whether or not passlib is
    installed; anything else is handed to passlib. Empty, malformed or
    unrecognised hashes yield False rather than an exception.
    """
    if not hashed_password:
        return False
    scheme = hashed_password.split("$", 1)[0]
    if scheme == _PBKDF2_SCHEME:
        return _verify_pbkdf2(plain_password, hashed_password)
    if scheme == _LEGACY_SHA256_SCHEME:
        return _verify_legacy_sha256(plain_password, hashed_password)
    if not PASSLIB_AVAILABLE:
        # bcrypt hashes cannot be checked without passlib.
        return False
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except ValueError:
        # passlib raises ValueError subclasses (e.g. UnknownHashError) for
        # hashes it cannot identify or parse.
        return False
|
|
|
|
|
|
def _fallback_signature(body: str) -> str:
    """Return the hex HMAC-SHA256 of *body*, keyed with SECRET_KEY."""
    import hmac

    return hmac.new(SECRET_KEY.encode(), body.encode(), hashlib.sha256).hexdigest()


def _encode_fallback_token(user_id: str, expire: datetime, token_type: str) -> str:
    """Build a signed token for use when PyJWT is unavailable.

    Format: ``<urlsafe-base64 JSON>.<hex HMAC-SHA256 of the base64 part>``.
    The HMAC is what prevents clients from minting tokens for arbitrary user
    IDs; a bare base64 payload could be forged by anyone.
    """
    import base64

    token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": token_type}
    body = base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
    return f"{body}.{_fallback_signature(body)}"


def _decode_fallback_token(token: str) -> Optional[Dict[str, Any]]:
    """Validate a token from ``_encode_fallback_token``.

    Returns ``{"sub": user_id, "type": token_type}``, or None if the token is
    not a string, is malformed, has a bad signature, or has expired.
    """
    import base64
    import hmac

    if not isinstance(token, str):
        return None
    body, separator, signature = token.rpartition(".")
    if not separator:
        return None
    # Constant-time comparison so the signature cannot be probed via timing.
    if not hmac.compare_digest(signature.encode(), _fallback_signature(body).encode()):
        return None
    try:
        data = json.loads(base64.urlsafe_b64decode(body.encode()).decode())
        expired = datetime.fromisoformat(data["exp"]) < datetime.utcnow()
        user_id = data["user_id"]
        token_type = data.get("type")
    except (ValueError, KeyError, TypeError):
        return None
    if expired:
        return None
    return {"sub": user_id, "type": token_type}


def _encode_jwt(claims: Dict[str, Any]) -> str:
    """Sign *claims* with PyJWT using SECRET_KEY/ALGORITHM, returning ``str``."""
    token = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    # PyJWT 1.x returns bytes, 2.x returns str; normalise to match the annotation.
    return token.decode() if isinstance(token, bytes) else token


def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create an access token for *user_id*.

    Args:
        user_id: Stored as the ``sub`` claim (``user_id`` in fallback tokens).
        expires_delta: Token lifetime; None means ACCESS_TOKEN_EXPIRE_HOURS hours.

    Returns:
        A JWT signed with SECRET_KEY/ALGORITHM when PyJWT is installed,
        otherwise an HMAC-signed fallback token.
    """
    # `is None` (not truthiness) so an explicit timedelta(0) is honoured.
    if expires_delta is None:
        expires_delta = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    expire = datetime.utcnow() + expires_delta
    if not JWT_AVAILABLE:
        return _encode_fallback_token(user_id, expire, "access")
    return _encode_jwt({"sub": user_id, "exp": expire, "type": "access"})


def create_refresh_token(user_id: str) -> str:
    """Create a refresh token for *user_id* valid for REFRESH_TOKEN_EXPIRE_DAYS days.

    Uses a signed JWT when PyJWT is installed, otherwise an HMAC-signed
    fallback token.
    """
    expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    if not JWT_AVAILABLE:
        return _encode_fallback_token(user_id, expire, "refresh")
    return _encode_jwt({"sub": user_id, "exp": expire, "type": "refresh"})


def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token and return its payload, or None if invalid or expired.

    The payload always carries ``sub`` (the user ID) and ``type``.

    NOTE(review): the ``type`` claim is not checked here, so a refresh token
    is accepted wherever an access token is expected unless the caller
    checks ``payload.get("type")``.
    """
    if not JWT_AVAILABLE:
        return _decode_fallback_token(token)
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError:
        # Base class of ExpiredSignatureError, DecodeError, InvalidSignatureError, ...
        # PyJWT has no `JWTError` (that name is python-jose's); referencing it
        # turned every invalid token into an AttributeError.
        return None
|
|
|
|
|
|
def load_users() -> Dict[str, Dict]:
    """Load the user table from USERS_FILE (JSON backend only).

    Returns the decoded JSON object (presumably user_id -> user record, as
    written by ``save_users``), or an empty dict if the file is missing,
    unreadable, not valid JSON, or not a JSON object.

    NOTE(review): returning {} on a corrupt file is best-effort behaviour kept
    for compatibility, but the next ``save_users`` will then overwrite the
    file, so read failures are logged loudly.
    """
    if not USERS_FILE.exists():
        return {}
    try:
        with open(USERS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.exception("Could not read user store %s; treating it as empty", USERS_FILE)
        return {}
    if not isinstance(data, dict):
        logger.error("User store %s does not contain a JSON object; treating it as empty", USERS_FILE)
        return {}
    return data
|
|
|
|
|
|
def save_users(users: Dict[str, Dict]):
    """Persist *users* to USERS_FILE as indented JSON (JSON backend only).

    The data is written to a temporary file in the same directory and then
    moved over USERS_FILE with ``os.replace``, so a crash mid-write leaves the
    previous file intact rather than a truncated one. Values JSON cannot
    encode natively (e.g. datetimes) are stored via ``str()``.

    NOTE: ``tempfile.mkstemp`` creates the file with owner-only permissions,
    so the replaced users.json (which holds password hashes) is not
    group/world readable.
    """
    import tempfile

    fd, tmp_name = tempfile.mkstemp(
        dir=USERS_FILE.parent, prefix=f"{USERS_FILE.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(users, f, indent=2, default=str)
        os.replace(tmp_name, USERS_FILE)
    except BaseException:
        # Clean up the temp file; the existing USERS_FILE is untouched.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def _db_user_to_model(db_user) -> User:
    """Build a Pydantic ``User`` from a database user row.

    Falsy columns fall back to: ``""`` for name, the FREE plan, an ACTIVE
    subscription, zero usage counters and credits, "en" -> "es" languages,
    the "google" provider, and the current UTC time for ``usage_reset_date``
    and ``created_at``. The row ``id`` is converted to ``str``.
    """
    if db_user.plan:
        plan = PlanType(db_user.plan)
    else:
        plan = PlanType.FREE

    if db_user.subscription_status:
        status = SubscriptionStatus(db_user.subscription_status)
    else:
        status = SubscriptionStatus.ACTIVE

    usage = {
        "docs_translated_this_month": db_user.docs_translated_this_month or 0,
        "pages_translated_this_month": db_user.pages_translated_this_month or 0,
        "api_calls_this_month": db_user.api_calls_this_month or 0,
        "extra_credits": db_user.extra_credits or 0,
    }
    preferences = {
        "default_source_lang": db_user.default_source_lang or "en",
        "default_target_lang": db_user.default_target_lang or "es",
        "default_provider": db_user.default_provider or "google",
    }

    return User(
        id=str(db_user.id),
        email=db_user.email,
        name=db_user.name or "",
        password_hash=db_user.password_hash,
        avatar_url=db_user.avatar_url,
        plan=plan,
        subscription_status=status,
        stripe_customer_id=db_user.stripe_customer_id,
        stripe_subscription_id=db_user.stripe_subscription_id,
        usage_reset_date=db_user.usage_reset_date or datetime.utcnow(),
        created_at=db_user.created_at or datetime.utcnow(),
        updated_at=db_user.updated_at,
        **usage,
        **preferences,
    )
|
|
|
|
|
|
def get_user_by_email(email: str) -> Optional[User]:
    """Look up a user by email address; return None when nobody matches.

    The JSON backend compares addresses case-insensitively via ``str.lower``.
    The database backend delegates matching to ``UserRepository.get_by_email``.
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        match = next(
            (
                record
                for record in load_users().values()
                if record.get("email", "").lower() == email.lower()
            ),
            None,
        )
        return None if match is None else User(**match)

    with get_sync_session() as session:
        row = UserRepository(session).get_by_email(email)
        # Convert while the session is still open.
        return _db_user_to_model(row) if row else None
|
|
|
|
|
|
def get_user_by_id(user_id: str) -> Optional[User]:
    """Fetch a user by primary key from the active backend, or None if absent."""
    if USE_DATABASE and DATABASE_AVAILABLE:
        with get_sync_session() as session:
            row = UserRepository(session).get_by_id(user_id)
            # Convert while the session is still open.
            return None if not row else _db_user_to_model(row)

    users = load_users()
    try:
        record = users[user_id]
    except KeyError:
        return None
    return User(**record)
|
|
|
|
|
|
def create_user(user_create: UserCreate) -> User:
    """Register a new account on the active storage backend.

    New accounts start on the FREE plan with an ACTIVE subscription.

    Raises:
        ValueError: "Email already registered" if ``get_user_by_email``
            already finds an account for this address.

    NOTE(review): the existence check and the insert are not atomic, so two
    concurrent sign-ups for the same email can race.
    """
    if get_user_by_email(user_create.email):
        raise ValueError("Email already registered")

    password_hash = hash_password(user_create.password)

    if USE_DATABASE and DATABASE_AVAILABLE:
        with get_sync_session() as session:
            row = UserRepository(session).create(
                email=user_create.email,
                name=user_create.name,
                password_hash=password_hash,
                plan=PlanType.FREE.value,
                subscription_status=SubscriptionStatus.ACTIVE.value,
            )
            session.commit()
            session.refresh(row)
            return _db_user_to_model(row)

    # JSON backend: random URL-safe ID, record stored under that key.
    users = load_users()
    new_id = secrets.token_urlsafe(16)
    new_user = User(
        id=new_id,
        email=user_create.email,
        name=user_create.name,
        password_hash=password_hash,
        plan=PlanType.FREE,
        subscription_status=SubscriptionStatus.ACTIVE,
    )
    users[new_id] = new_user.model_dump()
    save_users(users)
    return new_user
|
|
|
|
|
|
# Lazily computed throwaway hash, used so that logins for unknown emails
# cost about as much as logins with a wrong password.
_DUMMY_PASSWORD_HASH: Optional[str] = None


def _dummy_password_hash() -> str:
    """Return a throwaway hash in the active hashing scheme (computed once)."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = hash_password(secrets.token_urlsafe(16))
    return _DUMMY_PASSWORD_HASH


def authenticate_user(email: str, password: str) -> Optional[User]:
    """Authenticate a user with email and password.

    Returns the matching ``User`` when the password verifies, otherwise None.
    Returns None (instead of crashing) for accounts that have no password
    hash stored.
    """
    user = get_user_by_email(email)
    if not user or not user.password_hash:
        # Run a comparable hash verification anyway, so response time does
        # not reveal whether the email is registered.
        verify_password(password, _dummy_password_hash())
        return None
    if not verify_password(password, user.password_hash):
        return None
    return user
|
|
|
|
|
|
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply *updates* (field name -> new value) to a user and return the result.

    Returns None if no user has *user_id*.

    Database backend: delegates to ``UserRepository.update`` and commits.
    JSON backend: stamps ``updated_at`` (ISO string) and validates the merged
    record as a ``User`` *before* writing it, so an invalid update raises
    whatever ``User(...)`` raises without corrupting the stored record.
    """
    if USE_DATABASE and DATABASE_AVAILABLE:
        with get_sync_session() as session:
            repo = UserRepository(session)
            db_user = repo.update(user_id, updates)
            if db_user:
                session.commit()
                session.refresh(db_user)
                return _db_user_to_model(db_user)
            return None

    users = load_users()
    if user_id not in users:
        return None

    # Same key order as dict.update: existing keys keep their position.
    merged = {**users[user_id], **updates, "updated_at": datetime.utcnow().isoformat()}
    user = User(**merged)  # validate first; raises before anything is persisted
    users[user_id] = merged
    save_users(users)
    return user
|
|
|
|
|
|
def check_usage_limits(user: User) -> Dict[str, Any]:
    """Report the user's current usage against their plan limits.

    Side effect: if ``user.usage_reset_date`` falls in a different calendar
    month/year than now (``datetime.utcnow()``), the monthly counters are
    reset in storage via ``update_user`` and on the passed-in ``user`` object.

    Returns a dict with ``can_translate``, ``docs_used``, ``docs_limit``,
    ``docs_remaining`` (-1 when the plan's doc limit is not positive),
    ``pages_used``, ``extra_credits`` and the plan's ``max_pages_per_doc``,
    ``max_file_size_mb`` and ``allowed_providers``.
    """
    plan = PLANS[user.plan]

    # Reset usage if it's a new month
    now = datetime.utcnow()
    if user.usage_reset_date.month != now.month or user.usage_reset_date.year != now.year:
        update_user(user.id, {
            "docs_translated_this_month": 0,
            "pages_translated_this_month": 0,
            "api_calls_this_month": 0,
            # JSON storage keeps ISO strings; the database backend takes datetimes.
            "usage_reset_date": now if (USE_DATABASE and DATABASE_AVAILABLE) else now.isoformat(),
        })
        user.docs_translated_this_month = 0
        user.pages_translated_this_month = 0
        user.api_calls_this_month = 0
        # Keep the in-memory reset date in sync with storage; otherwise every
        # later check on this same object would re-issue the reset write.
        user.usage_reset_date = now

    docs_limit = plan["docs_per_month"]
    # NOTE(review): any non-positive limit (including 0) is reported as
    # unlimited (-1); confirm that matches how PLANS encodes limits.
    docs_remaining = max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1

    return {
        "can_translate": docs_remaining != 0 or user.extra_credits > 0,
        "docs_used": user.docs_translated_this_month,
        "docs_limit": docs_limit,
        "docs_remaining": docs_remaining,
        "pages_used": user.pages_translated_this_month,
        "extra_credits": user.extra_credits,
        "max_pages_per_doc": plan["max_pages_per_doc"],
        "max_file_size_mb": plan["max_file_size_mb"],
        "allowed_providers": plan["providers"],
    }
|
|
|
|
|
|
def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> bool:
    """Count one translated document of *pages_count* pages against a user.

    Increments the monthly document and page counters. With *use_credits*,
    *pages_count* is also deducted from ``extra_credits`` (floored at 0).
    Returns False when the user does not exist or ``update_user`` returns None.

    NOTE(review): this is a read-modify-write through two separate calls, so
    concurrent recordings for the same user can lose increments.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    changes = dict(
        docs_translated_this_month=user.docs_translated_this_month + 1,
        pages_translated_this_month=user.pages_translated_this_month + pages_count,
    )
    if use_credits:
        remaining_credits = user.extra_credits - pages_count
        changes["extra_credits"] = remaining_credits if remaining_credits > 0 else 0

    updated = update_user(user_id, changes)
    return updated is not None
|
|
|
|
|
|
def add_credits(user_id: str, credits: int) -> bool:
    """Increase a user's ``extra_credits`` balance by *credits*.

    Returns True if the update was applied, False if the user does not exist
    or ``update_user`` returns None.
    """
    user = get_user_by_id(user_id)
    if user:
        new_balance = user.extra_credits + credits
        return update_user(user_id, {"extra_credits": new_balance}) is not None
    return False
|
|
|
|
|
|
def init_database():
    """Prepare the storage backend; intended to be called on application startup.

    Runs the database module's ``init_db`` when the PostgreSQL backend is
    active; otherwise only logs that JSON file storage is in use.
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        logger.info("Using JSON file storage")
        return
    _init_db()
    logger.info("Database initialized successfully")
|