feat: Add PostgreSQL database infrastructure
- Add SQLAlchemy models for User, Translation, ApiKey, UsageLog, PaymentHistory - Add database connection management with PostgreSQL/SQLite support - Add repository layer for CRUD operations - Add Alembic migration setup with initial migration - Update auth_service to automatically use database when DATABASE_URL is set - Update docker-compose.yml with PostgreSQL service and Redis (non-optional) - Add database migration script (scripts/migrate_to_db.py) - Update .env.example with database configuration
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Authentication service with JWT tokens and password hashing
|
||||
|
||||
This service provides user authentication with automatic backend selection:
|
||||
- If DATABASE_URL is configured: Uses PostgreSQL database
|
||||
- Otherwise: Falls back to JSON file storage (development mode)
|
||||
"""
|
||||
import os
|
||||
import secrets
|
||||
@@ -8,6 +12,9 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
@@ -15,6 +22,7 @@ try:
|
||||
JWT_AVAILABLE = True
|
||||
except ImportError:
|
||||
JWT_AVAILABLE = False
|
||||
logger.warning("PyJWT not installed. Using fallback token encoding.")
|
||||
|
||||
try:
|
||||
from passlib.context import CryptContext
|
||||
@@ -22,17 +30,37 @@ try:
|
||||
PASSLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
PASSLIB_AVAILABLE = False
|
||||
logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")
|
||||
|
||||
# Check if database is configured
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "")
|
||||
USE_DATABASE = bool(DATABASE_URL and DATABASE_URL.startswith("postgresql"))
|
||||
|
||||
if USE_DATABASE:
|
||||
try:
|
||||
from database.repositories import UserRepository
|
||||
from database.connection import get_sync_session, init_db as _init_db
|
||||
from database import models as db_models
|
||||
DATABASE_AVAILABLE = True
|
||||
logger.info("Database backend enabled for authentication")
|
||||
except ImportError as e:
|
||||
DATABASE_AVAILABLE = False
|
||||
USE_DATABASE = False
|
||||
logger.warning(f"Database modules not available: {e}. Using JSON storage.")
|
||||
else:
|
||||
DATABASE_AVAILABLE = False
|
||||
logger.info("Using JSON file storage for authentication (DATABASE_URL not configured)")
|
||||
|
||||
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
|
||||
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
|
||||
SECRET_KEY = os.getenv("JWT_SECRET", os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32)))
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
# Simple file-based storage (replace with database in production)
|
||||
# Simple file-based storage (used when database is not configured)
|
||||
USERS_FILE = Path("data/users.json")
|
||||
USERS_FILE.parent.mkdir(exist_ok=True)
|
||||
|
||||
@@ -117,7 +145,7 @@ def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
def load_users() -> Dict[str, Dict]:
|
||||
"""Load users from file storage"""
|
||||
"""Load users from file storage (JSON backend only)"""
|
||||
if USERS_FILE.exists():
|
||||
try:
|
||||
with open(USERS_FILE, 'r') as f:
|
||||
@@ -128,54 +156,109 @@ def load_users() -> Dict[str, Dict]:
|
||||
|
||||
|
||||
def save_users(users: Dict[str, Dict]):
|
||||
"""Save users to file storage"""
|
||||
"""Save users to file storage (JSON backend only)"""
|
||||
with open(USERS_FILE, 'w') as f:
|
||||
json.dump(users, f, indent=2, default=str)
|
||||
|
||||
|
||||
def _db_user_to_model(db_user) -> User:
|
||||
"""Convert database user model to Pydantic User model"""
|
||||
return User(
|
||||
id=str(db_user.id),
|
||||
email=db_user.email,
|
||||
name=db_user.name or "",
|
||||
password_hash=db_user.password_hash,
|
||||
avatar_url=db_user.avatar_url,
|
||||
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
|
||||
subscription_status=SubscriptionStatus(db_user.subscription_status) if db_user.subscription_status else SubscriptionStatus.ACTIVE,
|
||||
stripe_customer_id=db_user.stripe_customer_id,
|
||||
stripe_subscription_id=db_user.stripe_subscription_id,
|
||||
docs_translated_this_month=db_user.docs_translated_this_month or 0,
|
||||
pages_translated_this_month=db_user.pages_translated_this_month or 0,
|
||||
api_calls_this_month=db_user.api_calls_this_month or 0,
|
||||
extra_credits=db_user.extra_credits or 0,
|
||||
usage_reset_date=db_user.usage_reset_date or datetime.utcnow(),
|
||||
default_source_lang=db_user.default_source_lang or "en",
|
||||
default_target_lang=db_user.default_target_lang or "es",
|
||||
default_provider=db_user.default_provider or "google",
|
||||
created_at=db_user.created_at or datetime.utcnow(),
|
||||
updated_at=db_user.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def get_user_by_email(email: str) -> Optional[User]:
|
||||
"""Get a user by email"""
|
||||
users = load_users()
|
||||
for user_data in users.values():
|
||||
if user_data.get("email", "").lower() == email.lower():
|
||||
return User(**user_data)
|
||||
return None
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.get_by_email(email)
|
||||
if db_user:
|
||||
return _db_user_to_model(db_user)
|
||||
return None
|
||||
else:
|
||||
users = load_users()
|
||||
for user_data in users.values():
|
||||
if user_data.get("email", "").lower() == email.lower():
|
||||
return User(**user_data)
|
||||
return None
|
||||
|
||||
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get a user by ID"""
|
||||
users = load_users()
|
||||
if user_id in users:
|
||||
return User(**users[user_id])
|
||||
return None
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.get_by_id(user_id)
|
||||
if db_user:
|
||||
return _db_user_to_model(db_user)
|
||||
return None
|
||||
else:
|
||||
users = load_users()
|
||||
if user_id in users:
|
||||
return User(**users[user_id])
|
||||
return None
|
||||
|
||||
|
||||
def create_user(user_create: UserCreate) -> User:
|
||||
"""Create a new user"""
|
||||
users = load_users()
|
||||
|
||||
# Check if email exists
|
||||
if get_user_by_email(user_create.email):
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
# Generate user ID
|
||||
user_id = secrets.token_urlsafe(16)
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
id=user_id,
|
||||
email=user_create.email,
|
||||
name=user_create.name,
|
||||
password_hash=hash_password(user_create.password),
|
||||
plan=PlanType.FREE,
|
||||
subscription_status=SubscriptionStatus.ACTIVE,
|
||||
)
|
||||
|
||||
# Save to storage
|
||||
users[user_id] = user.model_dump()
|
||||
save_users(users)
|
||||
|
||||
return user
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.create(
|
||||
email=user_create.email,
|
||||
name=user_create.name,
|
||||
password_hash=hash_password(user_create.password),
|
||||
plan=PlanType.FREE.value,
|
||||
subscription_status=SubscriptionStatus.ACTIVE.value
|
||||
)
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return _db_user_to_model(db_user)
|
||||
else:
|
||||
users = load_users()
|
||||
|
||||
# Generate user ID
|
||||
user_id = secrets.token_urlsafe(16)
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
id=user_id,
|
||||
email=user_create.email,
|
||||
name=user_create.name,
|
||||
password_hash=hash_password(user_create.password),
|
||||
plan=PlanType.FREE,
|
||||
subscription_status=SubscriptionStatus.ACTIVE,
|
||||
)
|
||||
|
||||
# Save to storage
|
||||
users[user_id] = user.model_dump()
|
||||
save_users(users)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def authenticate_user(email: str, password: str) -> Optional[User]:
|
||||
@@ -190,15 +273,25 @@ def authenticate_user(email: str, password: str) -> Optional[User]:
|
||||
|
||||
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
||||
"""Update a user's data"""
|
||||
users = load_users()
|
||||
if user_id not in users:
|
||||
return None
|
||||
|
||||
users[user_id].update(updates)
|
||||
users[user_id]["updated_at"] = datetime.utcnow().isoformat()
|
||||
save_users(users)
|
||||
|
||||
return User(**users[user_id])
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
db_user = repo.update(user_id, updates)
|
||||
if db_user:
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return _db_user_to_model(db_user)
|
||||
return None
|
||||
else:
|
||||
users = load_users()
|
||||
if user_id not in users:
|
||||
return None
|
||||
|
||||
users[user_id].update(updates)
|
||||
users[user_id]["updated_at"] = datetime.utcnow().isoformat()
|
||||
save_users(users)
|
||||
|
||||
return User(**users[user_id])
|
||||
|
||||
|
||||
def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||
@@ -212,7 +305,7 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||
"docs_translated_this_month": 0,
|
||||
"pages_translated_this_month": 0,
|
||||
"api_calls_this_month": 0,
|
||||
"usage_reset_date": now.isoformat()
|
||||
"usage_reset_date": now.isoformat() if not USE_DATABASE else now
|
||||
})
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
@@ -248,8 +341,8 @@ def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> b
|
||||
if use_credits:
|
||||
updates["extra_credits"] = max(0, user.extra_credits - pages_count)
|
||||
|
||||
update_user(user_id, updates)
|
||||
return True
|
||||
result = update_user(user_id, updates)
|
||||
return result is not None
|
||||
|
||||
|
||||
def add_credits(user_id: str, credits: int) -> bool:
|
||||
@@ -258,5 +351,14 @@ def add_credits(user_id: str, credits: int) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
|
||||
update_user(user_id, {"extra_credits": user.extra_credits + credits})
|
||||
return True
|
||||
result = update_user(user_id, {"extra_credits": user.extra_credits + credits})
|
||||
return result is not None
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize the database (call on application startup)"""
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
_init_db()
|
||||
logger.info("Database initialized successfully")
|
||||
else:
|
||||
logger.info("Using JSON file storage")
|
||||
|
||||
245
services/auth_service_db.py
Normal file
245
services/auth_service_db.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Database-backed authentication service
|
||||
Replaces JSON file storage with SQLAlchemy
|
||||
"""
|
||||
import os
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import jwt
|
||||
JWT_AVAILABLE = True
|
||||
except ImportError:
|
||||
JWT_AVAILABLE = False
|
||||
import json
|
||||
import base64
|
||||
|
||||
try:
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
PASSLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
PASSLIB_AVAILABLE = False
|
||||
|
||||
from database.connection import get_db_session
|
||||
from database.repositories import UserRepository
|
||||
from database.models import User, PlanType, SubscriptionStatus
|
||||
from models.subscription import PLANS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt or fallback to SHA256"""
|
||||
if PASSLIB_AVAILABLE:
|
||||
return pwd_context.hash(password)
|
||||
else:
|
||||
salt = secrets.token_hex(16)
|
||||
hashed = hashlib.sha256(f"{salt}{password}".encode()).hexdigest()
|
||||
return f"sha256${salt}${hashed}"
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash"""
|
||||
if PASSLIB_AVAILABLE and not hashed_password.startswith("sha256$"):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
else:
|
||||
parts = hashed_password.split("$")
|
||||
if len(parts) == 3 and parts[0] == "sha256":
|
||||
salt = parts[1]
|
||||
expected_hash = parts[2]
|
||||
actual_hash = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
|
||||
return secrets.compare_digest(actual_hash, expected_hash)
|
||||
return False
|
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a JWT access token"""
|
||||
expire = datetime.utcnow() + (expires_delta or timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS))
|
||||
|
||||
if not JWT_AVAILABLE:
|
||||
token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": "access"}
|
||||
return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
|
||||
|
||||
to_encode = {"sub": user_id, "exp": expire, "type": "access"}
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: str) -> str:
|
||||
"""Create a JWT refresh token"""
|
||||
expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
if not JWT_AVAILABLE:
|
||||
token_data = {"user_id": user_id, "exp": expire.isoformat(), "type": "refresh"}
|
||||
return base64.urlsafe_b64encode(json.dumps(token_data).encode()).decode()
|
||||
|
||||
to_encode = {"sub": user_id, "exp": expire, "type": "refresh"}
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify a JWT token and return payload"""
|
||||
if not JWT_AVAILABLE:
|
||||
try:
|
||||
data = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
|
||||
exp = datetime.fromisoformat(data["exp"])
|
||||
if exp < datetime.utcnow():
|
||||
return None
|
||||
return {"sub": data["user_id"], "type": data.get("type", "access")}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
return None
|
||||
except jwt.JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def create_user(email: str, name: str, password: str) -> User:
|
||||
"""Create a new user in the database"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
|
||||
# Check if email already exists
|
||||
existing = repo.get_by_email(email)
|
||||
if existing:
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
password_hash = hash_password(password)
|
||||
user = repo.create(
|
||||
email=email,
|
||||
name=name,
|
||||
password_hash=password_hash,
|
||||
plan=PlanType.FREE,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def authenticate_user(email: str, password: str) -> Optional[User]:
|
||||
"""Authenticate user and return user object if valid"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
user = repo.get_by_email(email)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
# Update last login
|
||||
repo.update(user.id, last_login_at=datetime.utcnow())
|
||||
return user
|
||||
|
||||
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by ID from database"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
return repo.get_by_id(user_id)
|
||||
|
||||
|
||||
def get_user_by_email(email: str) -> Optional[User]:
|
||||
"""Get user by email from database"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
return repo.get_by_email(email)
|
||||
|
||||
|
||||
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
||||
"""Update user fields in database"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
return repo.update(user_id, **updates)
|
||||
|
||||
|
||||
def add_credits(user_id: str, credits: int) -> bool:
|
||||
"""Add credits to user account"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
result = repo.add_credits(user_id, credits)
|
||||
return result is not None
|
||||
|
||||
|
||||
def use_credits(user_id: str, credits: int) -> bool:
|
||||
"""Use credits from user account"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
return repo.use_credits(user_id, credits)
|
||||
|
||||
|
||||
def increment_usage(user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0) -> bool:
|
||||
"""Increment user usage counters"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
result = repo.increment_usage(user_id, docs=docs, pages=pages, api_calls=api_calls)
|
||||
return result is not None
|
||||
|
||||
|
||||
def check_usage_limits(user_id: str) -> Dict[str, Any]:
|
||||
"""Check if user is within their plan limits"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
user = repo.get_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
return {"allowed": False, "reason": "User not found"}
|
||||
|
||||
plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE])
|
||||
|
||||
# Check document limit
|
||||
docs_limit = plan_config["docs_per_month"]
|
||||
if docs_limit > 0 and user.docs_translated_this_month >= docs_limit:
|
||||
# Check if user has extra credits
|
||||
if user.extra_credits <= 0:
|
||||
return {
|
||||
"allowed": False,
|
||||
"reason": "Monthly document limit reached",
|
||||
"limit": docs_limit,
|
||||
"used": user.docs_translated_this_month,
|
||||
}
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"docs_remaining": max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1,
|
||||
"extra_credits": user.extra_credits,
|
||||
}
|
||||
|
||||
|
||||
def get_user_usage_stats(user_id: str) -> Dict[str, Any]:
|
||||
"""Get detailed usage statistics for a user"""
|
||||
with get_db_session() as db:
|
||||
repo = UserRepository(db)
|
||||
user = repo.get_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
plan_config = PLANS.get(user.plan, PLANS[PlanType.FREE])
|
||||
|
||||
return {
|
||||
"docs_used": user.docs_translated_this_month,
|
||||
"docs_limit": plan_config["docs_per_month"],
|
||||
"docs_remaining": max(0, plan_config["docs_per_month"] - user.docs_translated_this_month) if plan_config["docs_per_month"] > 0 else -1,
|
||||
"pages_used": user.pages_translated_this_month,
|
||||
"extra_credits": user.extra_credits,
|
||||
"max_pages_per_doc": plan_config["max_pages_per_doc"],
|
||||
"max_file_size_mb": plan_config["max_file_size_mb"],
|
||||
"allowed_providers": plan_config["providers"],
|
||||
"api_access": plan_config.get("api_access", False),
|
||||
"api_calls_used": user.api_calls_this_month if plan_config.get("api_access") else 0,
|
||||
"api_calls_limit": plan_config.get("api_calls_per_month", 0),
|
||||
}
|
||||
Reference in New Issue
Block a user