572 lines
18 KiB
Python
572 lines
18 KiB
Python
"""
Authentication service: JWT access/refresh tokens and password hashing.

The user store is selected automatically at import time:

- If DATABASE_URL is a ``postgresql`` URL and the project's ``database``
  package can be imported: users are stored in PostgreSQL.
- Otherwise: users are stored in a JSON file (intended for development).
"""
|
|
|
|
import os
|
|
import secrets
|
|
import hashlib
|
|
import uuid
|
|
import time
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional, Dict, Any
|
|
import json
|
|
from pathlib import Path
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)

# Optional third-party dependencies. Each probe sets an *_AVAILABLE flag so the
# rest of the module can switch to a standard-library fallback when the
# package is missing.
try:
    import jwt
except ImportError:
    JWT_AVAILABLE = False
    logger.warning("PyJWT not installed. Using fallback token encoding.")
else:
    JWT_AVAILABLE = True

try:
    from passlib.context import CryptContext

    # Built inside the try so any ImportError raised while constructing the
    # bcrypt context also routes to the fallback.
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
except ImportError:
    PASSLIB_AVAILABLE = False
    logger.warning("passlib not installed. Using SHA256 fallback for password hashing.")
else:
    PASSLIB_AVAILABLE = True
|
|
|
|
# Backend selection: PostgreSQL when DATABASE_URL is a postgresql URL *and* the
# project's database package imports; otherwise the JSON file store.
DATABASE_URL = os.getenv("DATABASE_URL", "")
# os.getenv with a "" default always yields a str, so this is already a bool.
USE_DATABASE = DATABASE_URL.startswith("postgresql")
DATABASE_AVAILABLE = False

if not USE_DATABASE:
    logger.info(
        "Using JSON file storage for authentication (DATABASE_URL not configured)"
    )
else:
    try:
        from database.repositories import UserRepository
        from database.connection import get_sync_session, init_db as _init_db
        from database import models as db_models
    except ImportError as e:
        # Configured but not importable: degrade to JSON storage.
        USE_DATABASE = False
        logger.warning(f"Database modules not available: {e}. Using JSON storage.")
    else:
        DATABASE_AVAILABLE = True
        logger.info("Database backend enabled for authentication")
|
|
|
|
from models.subscription import User, UserCreate, PlanType, SubscriptionStatus, PLANS
|
|
|
|
|
|
# Configuration
# JWT signing secret. JWT_SECRET wins; JWT_SECRET_KEY is accepted as an alias.
# `or` (rather than a getenv default) lets an *empty* JWT_SECRET -- e.g. a
# blank `JWT_SECRET=` line in .env -- fall through to JWT_SECRET_KEY instead of
# silently triggering the ephemeral-key path below.
_jwt_secret = os.getenv("JWT_SECRET") or os.getenv("JWT_SECRET_KEY")
if not _jwt_secret:
    # No secret configured: generate a per-process key so the service still
    # starts, at the cost of invalidating every token on restart.
    _jwt_secret = secrets.token_urlsafe(32)
    logger.critical(
        "SECURITY: JWT_SECRET_KEY is not configured! Using an ephemeral random key. "
        "ALL JWT TOKENS WILL BE INVALIDATED ON EVERY RESTART. "
        "Set JWT_SECRET_KEY in your .env file immediately."
    )
elif len(_jwt_secret.encode("utf-8")) < 32:
    # HMAC-SHA256 keys should be at least 256 bits (RFC 7518, section 3.2).
    logger.warning(
        "SECURITY: JWT secret is shorter than 32 bytes. "
        "Use a random value of at least 256 bits for HS256 signing."
    )
SECRET_KEY = _jwt_secret
|
|
|
|
# Token signing algorithm and lifetimes.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 15
REFRESH_TOKEN_EXPIRE_DAYS = 7

# JSON user store, used only when no database backend is active. The path is
# relative, so it resolves against the process working directory.
USERS_FILE = Path("data") / "users.json"
USERS_FILE.parent.mkdir(exist_ok=True)

# Revoked-token blocklist. Redis is used when REDIS_URL is reachable (entries
# survive restarts); otherwise revoked JTIs live in this process-local dict,
# mapping jti -> expiry as a Unix timestamp.
_revoked_jtis: dict[str, float] = {}
# Memoised Redis client: None = not tried yet, False = tried and unavailable.
_redis_blocklist_client = None
|
|
|
|
|
|
def _get_blocklist_redis():
    """Lazily connect to Redis for the token blocklist.

    The outcome is memoised in the module-level ``_redis_blocklist_client``:
    ``None`` means "not attempted yet", ``False`` means "attempted and
    unavailable" (no REDIS_URL, import failure, or failed ping), and anything
    else is the live client. Once marked unavailable, it is not retried.

    Returns:
        The Redis client, or None when the in-memory blocklist must be used.
    """
    global _redis_blocklist_client

    cached = _redis_blocklist_client
    if cached is not None:
        return None if cached is False else cached

    url = os.getenv("REDIS_URL", "")
    if url:
        try:
            import redis as redis_lib

            connection = redis_lib.from_url(url, decode_responses=True)
            connection.ping()  # fail fast if the server is unreachable
            _redis_blocklist_client = connection
            logger.info("Token blocklist using Redis (persistent across restarts)")
            return connection
        except Exception as e:
            logger.warning(f"Redis unavailable for token blocklist, using in-memory: {e}")

    _redis_blocklist_client = False
    return None
|
|
|
|
|
|
def revoke_token_jti(jti: str, expires_at: float) -> None:
    """Mark *jti* as revoked until *expires_at* (Unix seconds).

    Prefers Redis: the key ``revoked_jti:<jti>`` is set with a TTL of the
    remaining lifetime, floored at one second. If Redis is not configured, or
    the write fails, the JTI goes into the process-local ``_revoked_jtis``
    dict instead.
    """
    seconds_left = max(1, int(expires_at - time.time()))
    client = _get_blocklist_redis()
    if client:
        try:
            client.setex(f"revoked_jti:{jti}", seconds_left, "1")
        except Exception as e:
            logger.warning(f"Redis revoke failed, falling back to memory: {e}")
        else:
            return
    _revoked_jtis[jti] = expires_at
|
|
|
|
|
|
def is_token_revoked(jti: str) -> bool:
    """Report whether *jti* is on the revocation blocklist.

    An empty JTI is never considered revoked. Redis is consulted first; if it
    is unavailable or the lookup raises, the in-memory blocklist is used, and
    any of its entries whose expiry has passed are pruned first.
    """
    if not jti:
        return False

    client = _get_blocklist_redis()
    if client:
        try:
            return client.exists(f"revoked_jti:{jti}") == 1
        except Exception as e:
            logger.warning(f"Redis revoke check failed, falling back to memory: {e}")

    cutoff = time.time()
    stale = [key for key, until in _revoked_jtis.items() if until < cutoff]
    for key in stale:
        _revoked_jtis.pop(key, None)
    return jti in _revoked_jtis
|
|
|
|
|
|
# Work factor for the stdlib PBKDF2 fallback used when passlib is missing. The
# count is stored inside each hash, so raising it later keeps old hashes valid.
_PBKDF2_ITERATIONS = 600_000
_PBKDF2_SCHEME = "pbkdf2_sha256"
_LEGACY_SHA256_SCHEME = "sha256"


def _pbkdf2_hex(password: str, salt: str, iterations: int) -> str:
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* with *salt*."""
    return hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt.encode("utf-8"), iterations
    ).hex()


def _digests_match(actual: str, expected: str) -> bool:
    """Constant-time comparison that also tolerates non-ASCII stored values."""
    return secrets.compare_digest(actual.encode("utf-8"), expected.encode("utf-8"))


def hash_password(password: str) -> str:
    """Hash *password* for storage.

    Uses passlib's bcrypt context when passlib is installed. Otherwise uses
    salted PBKDF2-HMAC-SHA256 from the standard library, stored as
    ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``. The former
    fallback -- a single salted SHA-256 round -- is far too fast to resist
    offline guessing; ``verify_password`` still accepts those legacy hashes.
    """
    if PASSLIB_AVAILABLE:
        return pwd_context.hash(password)
    salt = secrets.token_hex(16)
    digest = _pbkdf2_hex(password, salt, _PBKDF2_ITERATIONS)
    return f"{_PBKDF2_SCHEME}${_PBKDF2_ITERATIONS}${salt}${digest}"


def verify_password(plain_password: str, hashed_password: Optional[str]) -> bool:
    """Check *plain_password* against a stored hash.

    Recognised formats: the PBKDF2 fallback written by ``hash_password``, the
    legacy ``sha256$<salt>$<hex>`` format, and (when passlib is installed)
    anything passlib's context understands, e.g. bcrypt. A missing, empty,
    malformed or unrecognised hash yields False rather than raising.
    """
    if not hashed_password:
        # Accounts without a local password (or corrupted rows) never match.
        return False

    scheme, _, rest = hashed_password.partition("$")

    if scheme == _PBKDF2_SCHEME:
        parts = rest.split("$")
        if len(parts) != 3:
            return False
        iterations_text, salt, expected = parts
        try:
            actual = _pbkdf2_hex(plain_password, salt, int(iterations_text))
        except (ValueError, OverflowError):
            # Non-numeric, non-positive or out-of-range iteration count.
            return False
        return _digests_match(actual, expected)

    if scheme == _LEGACY_SHA256_SCHEME:
        parts = rest.split("$")
        if len(parts) != 2:
            return False
        salt, expected = parts
        actual = hashlib.sha256(f"{salt}{plain_password}".encode()).hexdigest()
        return _digests_match(actual, expected)

    if PASSLIB_AVAILABLE:
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except (ValueError, TypeError):
            # passlib raises ValueError subclasses for hashes it cannot parse.
            return False

    return False
|
|
|
|
|
|
def _b64url_encode(raw: bytes) -> str:
    """Base64url-encode *raw* without padding (the JWS/JWT convention)."""
    import base64

    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def _b64url_decode(text: str) -> bytes:
    """Decode unpadded base64url *text*; raises ValueError on malformed input."""
    import base64

    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def _fallback_signature(signing_input: str) -> str:
    """Return the base64url HMAC-SHA256 of *signing_input* keyed by SECRET_KEY."""
    import hmac

    digest = hmac.new(
        SECRET_KEY.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256
    ).digest()
    return _b64url_encode(digest)


def _fallback_encode(claims: Dict[str, Any]) -> str:
    """Encode *claims* as a compact HS256 JWT using only the standard library.

    Used when PyJWT is not installed. The previous fallback emitted unsigned
    base64 JSON that anyone could forge; tokens in that old format are now
    rejected by ``verify_token``.
    """
    segments = [
        _b64url_encode(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in ({"alg": "HS256", "typ": "JWT"}, claims)
    ]
    signing_input = ".".join(segments)
    return f"{signing_input}.{_fallback_signature(signing_input)}"


def _fallback_decode(token: str) -> Optional[Dict[str, Any]]:
    """Verify a token produced by ``_fallback_encode`` and return its claims.

    Returns None if the token is malformed, not HS256, carries a bad
    signature, lacks a numeric ``exp``, or has expired.
    """
    import hmac

    try:
        header_b64, claims_b64, signature = token.split(".")
        header = json.loads(_b64url_decode(header_b64))
        # Never let the token pick its own algorithm (e.g. "none").
        if header.get("alg") != "HS256":
            return None
        expected = _fallback_signature(f"{header_b64}.{claims_b64}")
        if not hmac.compare_digest(expected.encode("ascii"), signature.encode("utf-8")):
            return None
        claims = json.loads(_b64url_decode(claims_b64))
    except (ValueError, TypeError, AttributeError):
        return None
    if not isinstance(claims, dict):
        return None
    exp = claims.get("exp")
    if not isinstance(exp, (int, float)) or exp <= time.time():
        return None
    return claims


def _encode_token(claims: Dict[str, Any]) -> str:
    """Sign *claims* (whose ``exp`` is an aware datetime) as an HS256 JWT."""
    if JWT_AVAILABLE:
        token = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
        # PyJWT 1.x returns bytes, 2.x returns str; callers always get str.
        return token.decode("utf-8") if isinstance(token, bytes) else token
    # Mirror PyJWT, which serialises a datetime ``exp`` as integer epoch seconds.
    return _fallback_encode({**claims, "exp": int(claims["exp"].timestamp())})


def create_access_token(
    user_id: str, tier: str = "free", expires_delta: Optional[timedelta] = None
) -> str:
    """Create a signed access token carrying the user's tier.

    Claims: ``sub`` (user id), ``tier``, ``exp`` (now + *expires_delta*, or
    ACCESS_TOKEN_EXPIRE_MINUTES when it is None/zero), ``type="access"`` and a
    random ``jti`` that can be put on the revocation blocklist.
    """
    expire = datetime.now(timezone.utc) + (
        expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    return _encode_token(
        {
            "sub": user_id,
            "tier": tier,
            "exp": expire,
            "type": "access",
            "jti": str(uuid.uuid4()),
        }
    )


def create_refresh_token(
    user_id: str, expires_delta: Optional[timedelta] = None
) -> str:
    """Create a signed refresh token (REFRESH_TOKEN_EXPIRE_DAYS by default).

    Claims: ``sub``, ``exp``, ``type="refresh"`` and a random ``jti``.
    """
    expire = datetime.now(timezone.utc) + (
        expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    )
    return _encode_token(
        {
            "sub": user_id,
            "exp": expire,
            "type": "refresh",
            "jti": str(uuid.uuid4()),
        }
    )


def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Validate *token* and return its claims, or None if it is unusable.

    A token is rejected when it is empty, malformed, signed with a different
    key or algorithm, expired, or its ``jti`` is on the revocation blocklist.
    Access and refresh tokens are both accepted; callers that care must check
    the ``type`` claim.
    """
    if not token:
        return None
    if JWT_AVAILABLE:
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.PyJWTError:  # includes ExpiredSignatureError
            return None
    else:
        payload = _fallback_decode(token)
        if payload is None:
            return None
    jti = payload.get("jti")
    if jti and is_token_revoked(jti):
        return None
    return payload
|
|
|
|
|
|
def load_users() -> Dict[str, Dict]:
    """Read the JSON user store (JSON backend only).

    Returns whatever ``json.load`` parses from USERS_FILE, or an empty dict
    when the file does not exist or cannot be read/parsed (the error is
    logged, not raised).
    """
    if not USERS_FILE.exists():
        return {}
    try:
        with open(USERS_FILE, "r") as handle:
            return json.load(handle)
    except Exception as e:
        logger.error(f"Failed to load users file: {e}")
        return {}
|
|
|
|
|
|
def save_users(users: Dict[str, Dict]):
    """Persist *users* to USERS_FILE atomically (JSON backend only).

    The JSON is written to a sibling ``.tmp`` file, flushed to disk, and then
    swapped in with ``os.replace``. A crash or serialisation error mid-write
    therefore leaves the previous file intact instead of a truncated one --
    which ``load_users`` would read as an empty store, letting the next save
    overwrite every account.

    Raises:
        Whatever ``json.dump`` or the file operations raise; the temp file is
        removed before the exception propagates.
    """
    tmp_path = USERS_FILE.with_name(USERS_FILE.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(users, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, USERS_FILE)
    except BaseException:
        try:
            tmp_path.unlink()
        except OSError:
            pass  # Best-effort cleanup; keep the original error.
        raise
|
|
|
|
|
|
def _db_user_to_model(db_user) -> User:
    """Build the Pydantic ``User`` from a database user row.

    Falsy/NULL columns fall back to: empty name, FREE plan, ACTIVE
    subscription status, zero usage counters and credits, "en"/"es"/"google"
    language and provider preferences, and the current UTC time for
    ``usage_reset_date`` / ``created_at``. ``updated_at`` is passed through
    as-is. The daily-count and preference columns are read with ``getattr``,
    so objects lacking those attributes still convert.
    """

    def column_or(attr: str, fallback):
        # A missing attribute and a falsy value both resolve to the fallback.
        return getattr(db_user, attr, None) or fallback

    plan = PlanType(db_user.plan) if db_user.plan else PlanType.FREE
    if db_user.subscription_status:
        status = SubscriptionStatus(db_user.subscription_status)
    else:
        status = SubscriptionStatus.ACTIVE

    return User(
        id=str(db_user.id),
        email=db_user.email,
        name=db_user.name or "",
        password_hash=db_user.password_hash,
        avatar_url=db_user.avatar_url,
        plan=plan,
        subscription_status=status,
        stripe_customer_id=db_user.stripe_customer_id,
        stripe_subscription_id=db_user.stripe_subscription_id,
        docs_translated_this_month=db_user.docs_translated_this_month or 0,
        pages_translated_this_month=db_user.pages_translated_this_month or 0,
        api_calls_this_month=db_user.api_calls_this_month or 0,
        daily_translation_count=column_or("daily_translation_count", 0),
        extra_credits=db_user.extra_credits or 0,
        usage_reset_date=db_user.usage_reset_date or datetime.now(timezone.utc),
        default_source_lang=column_or("default_source_lang", "en"),
        default_target_lang=column_or("default_target_lang", "es"),
        default_provider=column_or("default_provider", "google"),
        created_at=db_user.created_at or datetime.now(timezone.utc),
        updated_at=db_user.updated_at,
    )
|
|
|
|
|
|
def get_user_by_email(email: str) -> Optional[User]:
    """Look up a user by email address; return None when nobody matches.

    Database backend: delegates to ``UserRepository.get_by_email`` (whether
    that match is case-insensitive depends on the repository -- verify).
    JSON backend: scans stored records and compares emails case-insensitively
    via ``str.lower``, returning the first match.
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        users = load_users()
        matches = (
            User(**record)
            for record in users.values()
            if record.get("email", "").lower() == email.lower()
        )
        return next(matches, None)

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        found = UserRepository(session).get_by_email(email)
        return _db_user_to_model(found) if found else None
|
|
|
|
|
|
def get_user_by_id(user_id: str) -> Optional[User]:
    """Fetch a user by primary id from the active backend, or None if absent."""
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        users = load_users()
        if user_id not in users:
            return None
        return User(**users[user_id])

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        found = UserRepository(session).get_by_id(user_id)
        return _db_user_to_model(found) if found else None
|
|
|
|
|
|
def create_user(user_create: UserCreate) -> User:
    """Register a new user on the FREE plan and return it.

    The password is hashed with ``hash_password`` before storage. With the
    database backend the row is created through ``UserRepository.create``;
    with the JSON backend a random URL-safe id is generated and the record is
    added to the users file.

    Raises:
        ValueError: "Email already registered" when ``get_user_by_email``
            already finds an account for the address.
    """
    if get_user_by_email(user_create.email):
        raise ValueError("Email already registered")

    if not (USE_DATABASE and DATABASE_AVAILABLE):
        users = load_users()
        new_id = secrets.token_urlsafe(16)
        user = User(
            id=new_id,
            email=user_create.email,
            name=user_create.name,
            password_hash=hash_password(user_create.password),
            plan=PlanType.FREE,
            subscription_status=SubscriptionStatus.ACTIVE,
        )
        users[new_id] = user.model_dump()
        save_users(users)
        return user

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        created = UserRepository(session).create(
            email=user_create.email,
            name=user_create.name,
            hashed_password=hash_password(user_create.password),
            tier="free",
        )
        return _db_user_to_model(created)
|
|
|
|
|
|
def authenticate_user(email: str, password: str) -> Optional[User]:
    """Return the user if *email* and *password* are valid, else None.

    When no account matches, the password is still hashed once so the
    response takes roughly as long as a real verification; this makes it
    harder to discover which emails are registered by timing the endpoint.
    Accounts with no stored password hash never authenticate.
    """
    user = get_user_by_email(email)
    if not user:
        hash_password(password)  # equalise timing with the verify path
        return None
    if not user.password_hash:
        return None
    if not verify_password(password, user.password_hash):
        return None
    return user
|
|
|
|
|
|
def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
    """Apply field *updates* to a user and return the refreshed ``User``.

    Database backend: forwards *updates* as keyword arguments to
    ``UserRepository.update``. JSON backend: merges them into the stored
    record, stamps ``updated_at`` with the current UTC time as an ISO string,
    and rewrites the users file.

    Returns None when the user does not exist (or the repository returns
    nothing).
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        users = load_users()
        if user_id not in users:
            return None
        record = users[user_id]
        record.update(updates)
        record["updated_at"] = datetime.now(timezone.utc).isoformat()
        save_users(users)
        return User(**record)

    from database.connection import get_sync_session
    from database.repositories import UserRepository

    with get_sync_session() as session:
        updated = UserRepository(session).update(user_id, **updates)
        return _db_user_to_model(updated) if updated else None
|
|
|
|
|
|
def check_usage_limits(user: User) -> Dict[str, Any]:
    """Report a user's remaining monthly quota, resetting counters on a new month.

    If ``user.usage_reset_date`` falls in an earlier calendar month/year than
    now (UTC), the monthly counters are zeroed in storage via ``update_user``
    and on the passed-in *user* object, whose reset date is advanced too.

    Returns:
        dict with ``can_translate``, usage and limit figures, and plan caps.
        ``docs_remaining`` is -1 when the plan's ``docs_per_month`` is not
        positive, which keeps ``can_translate`` True (unlimited).
    """
    plan = PLANS[user.plan]

    # Reset usage if it's a new month
    now = datetime.now(timezone.utc)
    reset_date = user.usage_reset_date
    if reset_date.month != now.month or reset_date.year != now.year:
        update_user(
            user.id,
            {
                "docs_translated_this_month": 0,
                "pages_translated_this_month": 0,
                "api_calls_this_month": 0,
                # ISO string for the JSON store, datetime for the database.
                "usage_reset_date": now.isoformat() if not USE_DATABASE else now,
            },
        )
        user.docs_translated_this_month = 0
        user.pages_translated_this_month = 0
        user.api_calls_this_month = 0
        # Advance the in-memory reset date as well. Otherwise a second call with
        # this same object would pass the month check again and zero the stored
        # counters a second time, discarding usage recorded in between
        # (record_usage re-reads the user from storage).
        user.usage_reset_date = now

    docs_limit = plan["docs_per_month"]
    docs_remaining = (
        max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1
    )

    return {
        "can_translate": docs_remaining != 0 or user.extra_credits > 0,
        "docs_used": user.docs_translated_this_month,
        "docs_limit": docs_limit,
        "docs_remaining": docs_remaining,
        "pages_used": user.pages_translated_this_month,
        "extra_credits": user.extra_credits,
        "max_pages_per_doc": plan["max_pages_per_doc"],
        "max_file_size_mb": plan["max_file_size_mb"],
        "allowed_providers": plan["providers"],
    }
|
|
|
|
|
|
def record_usage(user_id: str, pages_count: int, use_credits: bool = False) -> bool:
    """Add one document and *pages_count* pages to the user's monthly usage.

    With *use_credits*, ``extra_credits`` is also reduced by *pages_count*,
    floored at zero. Returns False if the user is unknown or ``update_user``
    returns None.

    NOTE(review): this is a read-modify-write on the stored counters, so
    concurrent calls for the same user may lose increments.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    changes: Dict[str, Any] = dict(
        docs_translated_this_month=user.docs_translated_this_month + 1,
        pages_translated_this_month=user.pages_translated_this_month + pages_count,
    )
    if use_credits:
        changes["extra_credits"] = max(0, user.extra_credits - pages_count)

    return update_user(user_id, changes) is not None
|
|
|
|
|
|
def add_credits(user_id: str, credits: int) -> bool:
    """Increase a user's ``extra_credits`` by *credits*.

    No sign check is made on *credits*. Returns True only if the user exists
    and ``update_user`` returns a result.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False
    new_balance = user.extra_credits + credits
    return update_user(user_id, {"extra_credits": new_balance}) is not None
|
|
|
|
|
|
# Valid plan values for admin tier change (Story 1.7)
VALID_PLAN_VALUES = {"free", "starter", "pro", "business", "enterprise"}


def update_user_plan(user_id: str, plan: str) -> Optional[User]:
    """
    Change a user's plan (admin only), keeping ``plan`` and ``tier`` in sync.

    *plan* is trimmed and lower-cased; anything outside VALID_PLAN_VALUES
    (including None/empty) returns None without touching storage. ``tier``
    becomes "pro" for pro/business/enterprise and "free" otherwise (DB
    constraint). The database backend receives the ``PlanType`` enum, the
    JSON backend the lower-case string.
    """
    normalized = (plan or "").strip().lower()
    if normalized not in VALID_PLAN_VALUES:
        return None

    plan_type = PlanType(normalized)
    maps_to_pro = plan_type in (PlanType.PRO, PlanType.BUSINESS, PlanType.ENTERPRISE)
    tier = "pro" if maps_to_pro else "free"

    stored_plan = plan_type if (USE_DATABASE and DATABASE_AVAILABLE) else normalized
    return update_user(user_id, {"plan": stored_plan, "tier": tier})
|
|
|
|
|
|
def get_user_by_api_key(api_key: str) -> Optional[User]:
    """
    Get a user by API key.

    The raw key is SHA-256 hashed and matched against ``ApiKey.key_hash``.
    A matching key must be active and unexpired; on success its
    ``last_used_at`` and ``usage_count`` are updated and committed.

    Only the database backend supports API keys; with JSON storage this
    always returns None.

    Returns:
        The owning user, or None if the key is empty, unknown, or its owner
        cannot be loaded.

    Raises:
        ValueError: "API_KEY_REVOKED" if the key exists but is inactive.
        ValueError: "API_KEY_EXPIRED" if the key's ``expires_at`` has passed.
    """
    if not api_key:
        return None

    # Only database backend supports API keys
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        return None

    from database.connection import get_sync_session
    from database.models import ApiKey

    # Hash the provided key to compare with stored hash
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    now = datetime.now(timezone.utc)

    with get_sync_session() as session:
        api_key_record = (
            session.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
        )
        if not api_key_record:
            return None

        # Check if key is active (Story 3.2 - Revocation check)
        if not api_key_record.is_active:
            raise ValueError("API_KEY_REVOKED")

        expires_at = api_key_record.expires_at
        if expires_at:
            if expires_at.tzinfo is None:
                # A naive value compared with the aware `now` raises TypeError.
                # NOTE(review): assumes naive column values are stored in UTC.
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at < now:
                raise ValueError("API_KEY_EXPIRED")

        api_key_record.last_used_at = now
        api_key_record.usage_count = (api_key_record.usage_count or 0) + 1
        # Read the owner id before commit: if the session expires objects on
        # commit, accessing it afterwards would trigger another query.
        user_id = str(api_key_record.user_id)
        session.commit()

    # Load the user after the API-key session has closed instead of nesting a
    # second session inside it.
    return get_user_by_id(user_id)
|
|
|
|
|
|
def init_database():
    """Prepare the active storage backend; call once at application startup.

    Runs the database package's ``init_db`` when the PostgreSQL backend is
    active; for the JSON backend it only logs which storage is in use.
    """
    if not (USE_DATABASE and DATABASE_AVAILABLE):
        logger.info("Using JSON file storage")
        return
    _init_db()
    logger.info("Database initialized successfully")
|