diff --git a/database/repositories.py b/database/repositories.py index 51c2b21..4e69725 100644 --- a/database/repositories.py +++ b/database/repositories.py @@ -93,34 +93,45 @@ class UserRepository: self.db.commit() return True - def increment_usage( - self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0 - ) -> Optional[User]: - """Increment usage counters""" + def reset_usage_if_needed(self, user_id: str) -> Optional[User]: + """Reset monthly counters if usage_reset_date is in a previous month.""" user = self.get_by_id(user_id) if not user: return None - # Check if usage needs to be reset (monthly) - if user.usage_reset_date: - now = datetime.now(timezone.utc) - if ( - now.month != user.usage_reset_date.month - or now.year != user.usage_reset_date.year - ): - user.docs_translated_this_month = 0 - user.pages_translated_this_month = 0 - user.api_calls_this_month = 0 - user.usage_reset_date = now - - user.docs_translated_this_month += docs - user.pages_translated_this_month += pages - user.api_calls_this_month += api_calls - - self.db.commit() - self.db.refresh(user) + now = datetime.now(timezone.utc) + reset_date = user.usage_reset_date + if reset_date is None or reset_date.month != now.month or reset_date.year != now.year: + user.docs_translated_this_month = 0 + user.pages_translated_this_month = 0 + user.api_calls_this_month = 0 + user.usage_reset_date = now + user.updated_at = now + self.db.commit() + self.db.refresh(user) return user + def increment_usage( + self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0 + ) -> Optional[User]: + """Increment usage counters atomically via SQL UPDATE.""" + from sqlalchemy import update + + now = datetime.now(timezone.utc) + self.db.execute( + update(User) + .where(User.id == user_id) + .values( + docs_translated_this_month=User.docs_translated_this_month + docs, + pages_translated_this_month=User.pages_translated_this_month + pages, + api_calls_this_month=User.api_calls_this_month + api_calls, + updated_at=now, + ) + .execution_options(synchronize_session=False) + ) + self.db.commit() + return self.get_by_id(user_id) + def add_credits(self, user_id: str, credits: int) -> Optional[User]: """Add extra credits to user""" user = self.get_by_id(user_id) @@ -133,14 +144,21 @@ class UserRepository: return user def use_credits(self, user_id: str, credits: int) -> bool: - """Use credits from user balance""" - user = self.get_by_id(user_id) - if not user or user.extra_credits < credits: - return False + """Use credits from user balance atomically (prevents overdraft).""" + from sqlalchemy import update - user.extra_credits -= credits + now = datetime.now(timezone.utc) + result = self.db.execute( + update(User) + .where(User.id == user_id, User.extra_credits >= credits) + .values( + extra_credits=User.extra_credits - credits, + updated_at=now, + ) + .execution_options(synchronize_session=False) + ) self.db.commit() - return True + return result.rowcount > 0 def get_all_users( self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None diff --git a/models/subscription.py b/models/subscription.py index 3fb7c7f..81d6c6e 100644 --- a/models/subscription.py +++ b/models/subscription.py @@ -204,6 +204,7 @@ class User(BaseModel): # Subscription info plan: PlanType = PlanType.FREE + tier: str = "free" # binary-ish tier for JWT/auth: free for free/starter, pro for paid subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE stripe_customer_id: Optional[str] = None stripe_subscription_id: Optional[str] = None diff --git a/routes/auth_routes.py b/routes/auth_routes.py index 3f308a7..8cc7980 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -104,13 +104,14 @@ async def require_user(credentials: HTTPAuthorizationCredentials = Depends(secur def user_to_response(user) -> UserResponse: """Convert User to UserResponse with plan limits""" plan_limits = PLANS[user.plan] + tier = getattr(user, "tier", None) or user.plan return UserResponse( id=user.id, email=user.email, name=user.name, avatar_url=user.avatar_url, plan=user.plan, - tier=user.plan, + tier=tier, subscription_status=user.subscription_status, subscription_ends_at=getattr(user, 'subscription_ends_at', None), cancel_at_period_end=getattr(user, 'cancel_at_period_end', False), @@ -540,7 +541,7 @@ async def login_v1(request: Request): ) access_token = create_access_token( - user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15) + user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15) ) refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7)) @@ -626,7 +627,7 @@ async def google_auth_v1(body: GoogleAuthRequest): content={"error": "USER_CREATE_FAILED", "message": str(exc)}, ) - access_token = create_access_token(user.id) + access_token = create_access_token(user.id, tier=getattr(user, "tier", "free")) refresh_token = create_refresh_token(user.id) return JSONResponse( @@ -724,7 +725,7 @@ async def refresh_v1(request: Request): ) access_token = create_access_token( - user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15) + user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15) ) new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7)) diff --git a/routes/translate_routes.py b/routes/translate_routes.py index 10e4ee9..f128389 100644 --- a/routes/translate_routes.py +++ b/routes/translate_routes.py @@ -44,8 +44,13 @@ from typing_extensions import Annotated from config import config from translators import ExcelTranslator, WordTranslator, PowerPointTranslator from models.subscription import PlanType -from middleware.tier_quota import tier_quota_service -from services.auth_service import record_usage, check_usage_limits +from services.auth_service import ( + record_usage, + check_usage_limits, + reserve_translation_quota, + release_translation_quota, +) +from middleware.tier_quota import _seconds_until_next_month, _next_month_utc from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key from utils import file_handler @@ -548,6 +553,7 @@ async def translate_document_v1( """ request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8]) + quota_reserved = False try: if not file and not file_url: raise TranslateEndpointError( @@ -615,40 +621,42 @@ async def translate_document_v1( ) if current_user: - quota = await tier_quota_service.check_quota(user_id, tier) - if not quota.allowed: - retry_after = tier_quota_service.seconds_until_reset() - raise HTTPException( - status_code=429, - detail={ - "error": "QUOTA_EXCEEDED", - "message": f"Monthly limit reached ({quota.current_usage}/{quota.limit} documents). Upgrade your plan for more.", - "details": { - "current_usage": quota.current_usage, - "limit": quota.limit, - "tier": tier, - "reset_at": quota.reset_at_utc.isoformat(), - }, - }, - headers={"Retry-After": str(retry_after)}, - ) - - # Strict database plan limit check (Starter, Pro, Business, Enterprise) usage = check_usage_limits(current_user) if not usage["can_translate"]: + retry_after = _seconds_until_next_month() raise HTTPException( status_code=429, detail={ "error": "QUOTA_EXCEEDED", "message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.", "details": { - "current_usage": usage["docs_used"], - "limit": usage["docs_limit"], + "current_usage": usage['docs_used'], + "limit": usage['docs_limit'], "tier": tier, + "reset_at": _next_month_utc().isoformat(), }, }, + headers={"Retry-After": str(retry_after)}, ) - rate_limit_remaining = quota.remaining + # Atomically reserve one document slot now so concurrent requests cannot + # overshoot the monthly quota while background jobs are still running. + reserved = await asyncio.to_thread(reserve_translation_quota, user_id) + if not reserved: + retry_after = _seconds_until_next_month() + raise HTTPException( + status_code=429, + detail={ + "error": "QUOTA_EXCEEDED", + "message": "Monthly limit reached. Upgrade your plan for more.", + "details": { + "tier": tier, + "reset_at": _next_month_utc().isoformat(), + }, + }, + headers={"Retry-After": str(retry_after)}, + ) + quota_reserved = True + rate_limit_remaining = usage["docs_remaining"] else: rate_limit_remaining = -1 @@ -839,6 +847,8 @@ async def translate_document_v1( ) except TranslateEndpointError as e: + if quota_reserved and user_id: + await asyncio.to_thread(release_translation_quota, user_id) status_code = 400 if e.code == TranslateEndpointError.FILE_TOO_LARGE: status_code = 413 @@ -852,8 +862,12 @@ async def translate_document_v1( content=e.to_dict(), ) except HTTPException: + if quota_reserved and user_id: + await asyncio.to_thread(release_translation_quota, user_id) raise except Exception as e: + if quota_reserved and user_id: + await asyncio.to_thread(release_translation_quota, user_id) logger.error(f"[{request_id}] Unexpected error: {e}") return JSONResponse( status_code=400, @@ -973,6 +987,7 @@ async def _run_translation_job( return tracker = ProgressTracker(job_id, _translation_jobs) + usage_recorded = False try: job["status"] = "processing" @@ -1337,14 +1352,14 @@ async def _run_translation_job( else: cost_factor = 5 - for _ in range(cost_factor): - await tier_quota_service.increment_on_success(user_id) - # Persist monthly usage counters in PostgreSQL (docs + pages) pages = await asyncio.to_thread( _estimate_pages, input_path, file_extension ) - await asyncio.to_thread(record_usage, user_id, pages, False, cost_factor) + await asyncio.to_thread( + record_usage, user_id, pages, cost_factor, reserved_docs=1 + ) + usage_recorded = True logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}") # Apply watermark for Free-tier users @@ -1364,6 +1379,12 @@ async def _run_translation_job( record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success") logger.info(f"Job {job_id}: Completed successfully") + except asyncio.CancelledError: + # Background task cancelled (e.g. TestClient teardown or server shutdown). + # The document slot was already reserved at request time; keep it consumed + # so quota enforcement remains deterministic. + logger.warning(f"Job {job_id}: translation task cancelled, keeping reserved quota") + raise except Exception as e: # Check if this is our structured TranslationProviderError if type(e).__name__ == "TranslationProviderError": @@ -1376,6 +1397,13 @@ async def _run_translation_job( # Record translation failure metric record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error") + if user_id and not usage_recorded: + try: + await asyncio.to_thread(release_translation_quota, user_id) + logger.info(f"Job {job_id}: released reserved quota after failure") + except Exception as release_err: + logger.exception(f"Job {job_id}: failed to release reserved quota: {release_err}") + finally: if webhook_url: try: diff --git a/services/auth_service.py b/services/auth_service.py index f7c22bc..5c33429 100644 --- a/services/auth_service.py +++ b/services/auth_service.py @@ -278,6 +278,7 @@ def _db_user_to_model(db_user) -> User: password_hash=db_user.password_hash, avatar_url=db_user.avatar_url, plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE, + tier=db_user.tier or "free", subscription_status=SubscriptionStatus(db_user.subscription_status) if db_user.subscription_status else SubscriptionStatus.ACTIVE, @@ -460,36 +461,48 @@ def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]: return User(**users[user_id]) +def _reset_usage_if_needed(user_id: str) -> Optional[User]: + """Reset monthly counters if usage_reset_date is in a previous month.""" + if USE_DATABASE and DATABASE_AVAILABLE: + from database.connection import get_sync_session + from database.repositories import UserRepository + + with get_sync_session() as session: + repo = UserRepository(session) + repo.reset_usage_if_needed(user_id) + else: + user = get_user_by_id(user_id) + if not user: + return None + now = datetime.now(timezone.utc) + reset_date = user.usage_reset_date + if reset_date.month != now.month or reset_date.year != now.year: + update_user( + user_id, + { + "docs_translated_this_month": 0, + "pages_translated_this_month": 0, + "api_calls_this_month": 0, + "usage_reset_date": now.isoformat(), + }, + ) + return get_user_by_id(user_id) + + def check_usage_limits(user: User) -> Dict[str, Any]: """Check if user has exceeded their plan limits""" + # Ensure counters are reset if we've entered a new month. + refreshed = _reset_usage_if_needed(user.id) + if refreshed: + user = refreshed + plan = PLANS[user.plan] - - # Reset usage if it's a new month - now = datetime.now(timezone.utc) - if ( - user.usage_reset_date.month != now.month - or user.usage_reset_date.year != now.year - ): - update_user( - user.id, - { - "docs_translated_this_month": 0, - "pages_translated_this_month": 0, - "api_calls_this_month": 0, - "usage_reset_date": now.isoformat() if not USE_DATABASE else now, - }, - ) - user.docs_translated_this_month = 0 - user.pages_translated_this_month = 0 - user.api_calls_this_month = 0 - docs_limit = plan["docs_per_month"] - docs_remaining = ( - max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1 - ) + unlimited = docs_limit == -1 + docs_remaining = -1 if unlimited else max(0, docs_limit - user.docs_translated_this_month) return { - "can_translate": docs_remaining != 0 or user.extra_credits > 0, + "can_translate": unlimited or docs_remaining != 0 or user.extra_credits > 0, "docs_used": user.docs_translated_this_month, "docs_limit": docs_limit, "docs_remaining": docs_remaining, @@ -502,23 +515,190 @@ def check_usage_limits(user: User) -> Dict[str, Any]: def record_usage( - user_id: str, pages_count: int, use_credits: bool = False, cost_factor: int = 1 + user_id: str, pages_count: int, cost_factor: int = 1, reserved_docs: int = 0 ) -> bool: - """Record document translation usage with optional cost factor depending on AI model""" + """Record document translation usage with optional cost factor depending on AI model. + + `reserved_docs` is the number of document slots already reserved at request time + (e.g. by ``reserve_translation_quota``). Those slots are not counted again; only + the remaining ``max(0, cost_factor - reserved_docs)`` docs are added here. + + Automatically consumes extra credits first; falls back to monthly quota. + Returns True if usage was recorded successfully, False otherwise. + """ + user = _reset_usage_if_needed(user_id) + if not user: + return False + + total_cost = pages_count * cost_factor + docs_to_record = max(0, cost_factor - reserved_docs) + plan = PLANS[user.plan] + docs_limit = plan["docs_per_month"] + unlimited = docs_limit == -1 + + if USE_DATABASE and DATABASE_AVAILABLE: + from database.connection import get_sync_session + from database.repositories import UserRepository + + with get_sync_session() as session: + repo = UserRepository(session) + + if unlimited: + # Paid unlimited plans: track usage for analytics/downgrade safety. + repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost) + return True + + # Prefer credits first, then quota. + if user.extra_credits > 0: + credits_to_use = min(user.extra_credits, total_cost) + if repo.use_credits(user_id, credits_to_use): + remaining_cost = total_cost - credits_to_use + if remaining_cost > 0: + repo.increment_usage( + user_id, docs=docs_to_record, pages=remaining_cost + ) + return True + return False + + if user.docs_translated_this_month + docs_to_record <= docs_limit: + repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost) + return True + return False + else: + # JSON fallback (non-atomic, dev only) + if unlimited: + return update_user( + user_id, + { + "docs_translated_this_month": user.docs_translated_this_month + + docs_to_record, + "pages_translated_this_month": user.pages_translated_this_month + + total_cost, + }, + ) is not None + + if user.extra_credits > 0: + credits_to_use = min(user.extra_credits, total_cost) + new_credits = user.extra_credits - credits_to_use + remaining_cost = total_cost - credits_to_use + updates = {"extra_credits": new_credits} + if remaining_cost > 0: + if user.docs_translated_this_month + docs_to_record > docs_limit: + return False + updates["docs_translated_this_month"] = ( + user.docs_translated_this_month + docs_to_record + ) + updates["pages_translated_this_month"] = ( + user.pages_translated_this_month + remaining_cost + ) + return update_user(user_id, updates) is not None + + if user.docs_translated_this_month + docs_to_record <= docs_limit: + return update_user( + user_id, + { + "docs_translated_this_month": user.docs_translated_this_month + + docs_to_record, + "pages_translated_this_month": user.pages_translated_this_month + + total_cost, + }, + ) is not None + return False + + +def reserve_translation_quota(user_id: str) -> bool: + """Atomically reserve one document slot at request time. + + Returns True if the reservation succeeded. This prevents race conditions where + multiple concurrent requests could exceed the monthly quota. + """ + user = _reset_usage_if_needed(user_id) + if not user: + return False + + plan = PLANS[user.plan] + docs_limit = plan["docs_per_month"] + unlimited = docs_limit == -1 + + if unlimited: + if USE_DATABASE and DATABASE_AVAILABLE: + from database.connection import get_sync_session + from database.repositories import UserRepository + + with get_sync_session() as session: + repo = UserRepository(session) + repo.increment_usage(user_id, docs=1, pages=0) + return True + return update_user( + user_id, + {"docs_translated_this_month": user.docs_translated_this_month + 1}, + ) is not None + + # Limited plan: prefer credits first, then quota. + if user.extra_credits > 0: + if USE_DATABASE and DATABASE_AVAILABLE: + from database.connection import get_sync_session + from database.repositories import UserRepository + + with get_sync_session() as session: + repo = UserRepository(session) + return repo.use_credits(user_id, 1) + return update_user( + user_id, {"extra_credits": max(0, user.extra_credits - 1)} + ) is not None + + if user.docs_translated_this_month + 1 <= docs_limit: + if USE_DATABASE and DATABASE_AVAILABLE: + from database.connection import get_sync_session + from database.repositories import UserRepository + + with get_sync_session() as session: + repo = UserRepository(session) + repo.increment_usage(user_id, docs=1, pages=0) + return True + return update_user( + user_id, + {"docs_translated_this_month": user.docs_translated_this_month + 1}, + ) is not None + + return False + + +def release_translation_quota(user_id: str) -> bool: + """Release a previously reserved document slot (e.g. on translation failure).""" user = get_user_by_id(user_id) if not user: return False - updates = { - "docs_translated_this_month": user.docs_translated_this_month + cost_factor, - "pages_translated_this_month": user.pages_translated_this_month + (pages_count * cost_factor), - } + if USE_DATABASE and DATABASE_AVAILABLE: + from database.connection import get_sync_session + from database.repositories import UserRepository + from sqlalchemy import update - if use_credits: - updates["extra_credits"] = max(0, user.extra_credits - (pages_count * cost_factor)) + with get_sync_session() as session: + now = datetime.now(timezone.utc) + session.execute( + update(db_models.User) + .where( + db_models.User.id == user_id, + db_models.User.docs_translated_this_month > 0, + ) + .values( + docs_translated_this_month=db_models.User.docs_translated_this_month + - 1, + updated_at=now, + ) + .execution_options(synchronize_session=False) + ) + session.commit() + return True - result = update_user(user_id, updates) - return result is not None + if user.docs_translated_this_month > 0: + return update_user( + user_id, + {"docs_translated_this_month": user.docs_translated_this_month - 1}, + ) is not None + return True def add_credits(user_id: str, credits: int) -> bool: diff --git a/services/payment_service.py b/services/payment_service.py index 234cdb7..474f985 100644 --- a/services/payment_service.py +++ b/services/payment_service.py @@ -4,7 +4,7 @@ Stripe payment integration for subscriptions and credits import os import logging from typing import Optional, Dict, Any -from datetime import datetime +from datetime import datetime, timezone logger = logging.getLogger(__name__) @@ -250,19 +250,23 @@ async def handle_webhook(payload: bytes, sig_header: str) -> Dict[str, Any]: if event["type"] == "checkout.session.completed": session = event["data"]["object"] await handle_checkout_completed(session) - + elif event["type"] == "customer.subscription.updated": subscription = event["data"]["object"] await handle_subscription_updated(subscription) - + elif event["type"] == "customer.subscription.deleted": subscription = event["data"]["object"] await handle_subscription_deleted(subscription) - + elif event["type"] == "invoice.payment_failed": invoice = event["data"]["object"] await handle_payment_failed(invoice) - + + elif event["type"] == "invoice.paid": + invoice = event["data"]["object"] + await handle_invoice_paid(invoice) + return {"status": "success"} @@ -276,7 +280,9 @@ async def handle_checkout_completed(session: Dict): session_id = session.get("id") - # Check for duplicate session processing using PaymentHistory + # Check for duplicate session processing using PaymentHistory. + # Use the Stripe payment_intent (one-time) or subscription id (recurring) as the + # idempotency key, NOT the checkout session id — Stripe can redeliver the event. db_available = False try: from database.connection import get_sync_session @@ -285,14 +291,22 @@ async def handle_checkout_completed(session: Dict): except ImportError: pass + payment_intent_id = session.get("payment_intent") + subscription_id = session.get("subscription") if db_available and session_id: try: with get_sync_session() as db_session: - existing = db_session.query(DBPaymentHistory).filter( - DBPaymentHistory.stripe_payment_intent_id == session_id - ).first() + from sqlalchemy import or_ + + filters = [DBPaymentHistory.stripe_payment_intent_id == session_id] + if payment_intent_id: + filters.append(DBPaymentHistory.stripe_payment_intent_id == payment_intent_id) + if subscription_id: + filters.append(DBPaymentHistory.stripe_invoice_id == subscription_id) + existing = db_session.query(DBPaymentHistory).filter(or_(*filters)).first() if existing: - logger.info("Checkout session %s already processed. Skipping.", session_id) + logger.info("Checkout session %s already processed (pi=%s sub=%s). Skipping.", + session_id, payment_intent_id, subscription_id) return except Exception as e: logger.error("Error checking PaymentHistory duplication: %s", e) @@ -312,7 +326,7 @@ async def handle_checkout_completed(session: Dict): with get_sync_session() as db_session: payment = DBPaymentHistory( user_id=user_id, - stripe_payment_intent_id=session_id, + stripe_payment_intent_id=payment_intent_id or session_id, stripe_invoice_id=session.get("invoice") or session.get("subscription"), amount_cents=session.get("amount_total") or 0, currency=session.get("currency") or "usd", @@ -377,7 +391,6 @@ async def handle_checkout_completed(session: Dict): subscription_id = subscription_raw.get("id") period_end = subscription_raw.get("current_period_end") if period_end: - from datetime import timezone subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc) # Derive tier from plan (DB constraint: only 'free' or 'pro') @@ -408,7 +421,7 @@ async def handle_checkout_completed(session: Dict): with get_sync_session() as db_session: payment = DBPaymentHistory( user_id=user_id, - stripe_payment_intent_id=session_id, + stripe_payment_intent_id=payment_intent_id or session_id, stripe_invoice_id=subscription_id or session.get("invoice"), amount_cents=session.get("amount_total") or 0, currency=session.get("currency") or "usd", @@ -481,15 +494,13 @@ async def handle_subscription_updated(subscription: Dict): period_end = subscription.get("current_period_end") ends_str = "" if period_end: - from datetime import timezone ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y") + period_end = subscription.get("current_period_end") update_user(user_id, { "subscription_status": status.value, "cancel_at_period_end": stripe_cancel_at_period_end, - "subscription_ends_at": datetime.fromtimestamp( - subscription.get("current_period_end", 0) - ).isoformat() if subscription.get("current_period_end") else None + "subscription_ends_at": datetime.fromtimestamp(period_end, tz=timezone.utc) if period_end else None, }) # Send cancellation email if they just selected to cancel @@ -531,7 +542,7 @@ async def handle_subscription_deleted(subscription: Dict): if not user: return - had_active_sub = user.plan != PlanType.FREE.value or user.tier != "free" + had_active_sub = user.plan != PlanType.FREE or user.tier != "free" update_user(user_id, { "plan": PlanType.FREE.value, @@ -621,6 +632,50 @@ async def handle_payment_failed(invoice: Dict): logger.error("handle_payment_failed DB error: %s", exc) +async def handle_invoice_paid(invoice: Dict): + """Extend subscription_ends_at when a recurring invoice is paid.""" + customer_id = invoice.get("customer") + if not customer_id: + return + + subscription_id = invoice.get("subscription") + period_end = invoice.get("period_end") or invoice.get("lines", {}).get("data", [{}])[0].get("period", {}).get("end") + + try: + from database.connection import get_sync_session + from database.models import User as DBUser + + with get_sync_session() as session: + db_user = ( + session.query(DBUser) + .filter(DBUser.stripe_customer_id == customer_id) + .first() + ) + if not db_user: + return + + if subscription_id and db_user.stripe_subscription_id != subscription_id: + # The paid invoice belongs to a different subscription; do not update. + logger.warning( + "Invoice paid for customer %s but subscription id mismatch (expected %s, got %s)", + customer_id, db_user.stripe_subscription_id, subscription_id, + ) + return + + if period_end: + new_end = datetime.fromtimestamp(period_end, tz=timezone.utc) + if db_user.subscription_ends_at is None or new_end > db_user.subscription_ends_at: + db_user.subscription_ends_at = new_end + db_user.updated_at = datetime.now(timezone.utc) + session.commit() + logger.info( + "Extended subscription_ends_at for user %s to %s", + db_user.id, new_end.isoformat(), + ) + except Exception as exc: + logger.error("handle_invoice_paid error: %s", exc) + + async def cancel_subscription(user_id: str) -> Dict[str, Any]: """Cancel a user's subscription at period end.""" if not is_stripe_configured(): @@ -631,6 +686,10 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]: return {"error": "No active subscription found"} try: + subscription = stripe.Subscription.retrieve(user.stripe_subscription_id) + if subscription.customer != user.stripe_customer_id: + return {"error": "Subscription does not belong to current user"} + subscription = stripe.Subscription.modify( user.stripe_subscription_id, cancel_at_period_end=True, @@ -638,13 +697,13 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]: cancel_at = None if subscription.cancel_at: - cancel_at = datetime.fromtimestamp(subscription.cancel_at).isoformat() + cancel_at = datetime.fromtimestamp(subscription.cancel_at, tz=timezone.utc) subscription_ends_at = None ends_str = "" if subscription.current_period_end: - subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end).isoformat() - ends_str = datetime.fromtimestamp(subscription.current_period_end).strftime("%d/%m/%Y") + subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc) + ends_str = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc).strftime("%d/%m/%Y") is_new_cancel = not user.cancel_at_period_end diff --git a/tests/test_download_endpoint.py b/tests/test_download_endpoint.py index 9f3329e..7eb386b 100644 --- a/tests/test_download_endpoint.py +++ b/tests/test_download_endpoint.py @@ -110,26 +110,22 @@ def client(users_file: Path, monkeypatch): monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow) monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow) - from middleware.tier_quota import TierQuotaService + def _check_usage_limits_allow(user): + return { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + } - async def _check_quota_allow(self, user_id, tier): - from middleware.tier_quota import QuotaResult - from datetime import datetime, timezone, timedelta - - now = datetime.now(timezone.utc) - tomorrow = now.date() + timedelta(days=1) - reset_at = datetime( - tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc - ) - return QuotaResult( - allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5 - ) - - async def _increment_noop(self, user_id): - pass - - monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow) - monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop) + monkeypatch.setattr( + "routes.translate_routes.check_usage_limits", _check_usage_limits_allow + ) from main import app diff --git a/tests/test_story_2_16_url_ingestion.py b/tests/test_story_2_16_url_ingestion.py index 47e1adc..f6f1cca 100644 --- a/tests/test_story_2_16_url_ingestion.py +++ b/tests/test_story_2_16_url_ingestion.py @@ -573,16 +573,18 @@ class TestURLIngestionIntegration: monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = pro_client.post( @@ -598,16 +600,18 @@ class TestURLIngestionIntegration: def test_free_user_rejected(self, free_client, monkeypatch): """AC7: Free user receives PRO_FEATURE_REQUIRED (403)""" monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = free_client.post( @@ -632,16 +636,18 @@ class TestURLIngestionIntegration: monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = pro_client.post( @@ -668,16 +674,18 @@ class TestURLIngestionIntegration: monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = pro_client.post( @@ -704,16 +712,18 @@ class TestURLIngestionIntegration: monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = pro_client.post( @@ -738,16 +748,18 @@ class TestURLIngestionIntegration: monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = pro_client.post( @@ -772,16 +784,18 @@ class TestURLIngestionIntegration: monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr( - "middleware.tier_quota.TierQuotaService.check_quota", - AsyncMock( - return_value=MagicMock( - allowed=True, - remaining=5, - reset_at_utc=None, - current_usage=0, - limit=5, - ) - ), + "routes.translate_routes.check_usage_limits", + lambda user: { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + }, ) response = pro_client.post( diff --git a/tests/test_translate_endpoint.py b/tests/test_translate_endpoint.py index f388c0b..af7909c 100644 --- a/tests/test_translate_endpoint.py +++ b/tests/test_translate_endpoint.py @@ -115,26 +115,22 @@ def client(users_file: Path, monkeypatch): monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow) monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow) - from middleware.tier_quota import TierQuotaService + def _check_usage_limits_allow(user): + return { + "can_translate": True, + "docs_used": 0, + "docs_limit": 5, + "docs_remaining": 5, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + } - async def _check_quota_allow(self, user_id, tier): - from middleware.tier_quota import QuotaResult - from datetime import datetime, timezone, timedelta - - now = datetime.now(timezone.utc) - tomorrow = now.date() + timedelta(days=1) - reset_at = datetime( - tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc - ) - return QuotaResult( - allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5 - ) - - async def _increment_noop(self, user_id): - pass - - monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow) - monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop) + monkeypatch.setattr( + "routes.translate_routes.check_usage_limits", _check_usage_limits_allow + ) from main import app @@ -477,24 +473,22 @@ class TestQuotaExceeded: def test_returns_429_when_quota_exceeded(self, client, monkeypatch): """Returns 429 with QUOTA_EXCEEDED when quota exceeded""" - from middleware.tier_quota import TierQuotaService, QuotaResult - from datetime import datetime, timezone, timedelta + def _check_usage_limits_denied(user): + return { + "can_translate": False, + "docs_used": 5, + "docs_limit": 5, + "docs_remaining": 0, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + } - async def _check_quota_denied(self, user_id, tier): - now = datetime.now(timezone.utc) - tomorrow = now.date() + timedelta(days=1) - reset_at = datetime( - tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc - ) - return QuotaResult( - allowed=False, - remaining=0, - reset_at_utc=reset_at, - current_usage=5, - limit=5, - ) - - monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied) + monkeypatch.setattr( + "routes.translate_routes.check_usage_limits", _check_usage_limits_denied + ) # Register and login client.post(REGISTER_URL, json=VALID_USER) @@ -530,24 +524,23 @@ class TestQuotaExceeded: def test_includes_retry_after_header(self, client, monkeypatch): """Includes Retry-After header on 429""" - from middleware.tier_quota import TierQuotaService, QuotaResult - from datetime import datetime, timezone, timedelta - async def _check_quota_denied(self, user_id, tier): - now = datetime.now(timezone.utc) - tomorrow = now.date() + timedelta(days=1) - reset_at = datetime( - tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc - ) - return QuotaResult( - allowed=False, - remaining=0, - reset_at_utc=reset_at, - current_usage=5, - limit=5, - ) + def _check_usage_limits_denied(user): + return { + "can_translate": False, + "docs_used": 5, + "docs_limit": 5, + "docs_remaining": 0, + "pages_used": 0, + "extra_credits": 0, + "max_pages_per_doc": 50, + "max_file_size_mb": 10, + "allowed_providers": ["google", "deepl"], + } - monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied) + monkeypatch.setattr( + "routes.translate_routes.check_usage_limits", _check_usage_limits_denied + ) client.post(REGISTER_URL, json=VALID_USER) response = client.post(