fix(billing): unify quota counters, fix Stripe webhooks, tier/plan sync
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s

This commit is contained in:
2026-06-14 17:39:34 +02:00
parent fa637abff0
commit 45e44dd7b2
9 changed files with 546 additions and 256 deletions

View File

@@ -93,34 +93,45 @@ class UserRepository:
self.db.commit() self.db.commit()
return True return True
def increment_usage( def reset_usage_if_needed(self, user_id: str) -> Optional[User]:
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0 """Reset monthly counters if usage_reset_date is in a previous month."""
) -> Optional[User]:
"""Increment usage counters"""
user = self.get_by_id(user_id) user = self.get_by_id(user_id)
if not user: if not user:
return None return None
# Check if usage needs to be reset (monthly)
if user.usage_reset_date:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
if ( reset_date = user.usage_reset_date
now.month != user.usage_reset_date.month if reset_date is None or reset_date.month != now.month or reset_date.year != now.year:
or now.year != user.usage_reset_date.year
):
user.docs_translated_this_month = 0 user.docs_translated_this_month = 0
user.pages_translated_this_month = 0 user.pages_translated_this_month = 0
user.api_calls_this_month = 0 user.api_calls_this_month = 0
user.usage_reset_date = now user.usage_reset_date = now
user.updated_at = now
user.docs_translated_this_month += docs
user.pages_translated_this_month += pages
user.api_calls_this_month += api_calls
self.db.commit() self.db.commit()
self.db.refresh(user) self.db.refresh(user)
return user return user
def increment_usage(
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
) -> Optional[User]:
"""Increment usage counters atomically via SQL UPDATE."""
from sqlalchemy import update
now = datetime.now(timezone.utc)
self.db.execute(
update(User)
.where(User.id == user_id)
.values(
docs_translated_this_month=User.docs_translated_this_month + docs,
pages_translated_this_month=User.pages_translated_this_month + pages,
api_calls_this_month=User.api_calls_this_month + api_calls,
updated_at=now,
)
.execution_options(synchronize_session=False)
)
self.db.commit()
return self.get_by_id(user_id)
def add_credits(self, user_id: str, credits: int) -> Optional[User]: def add_credits(self, user_id: str, credits: int) -> Optional[User]:
"""Add extra credits to user""" """Add extra credits to user"""
user = self.get_by_id(user_id) user = self.get_by_id(user_id)
@@ -133,14 +144,21 @@ class UserRepository:
return user return user
def use_credits(self, user_id: str, credits: int) -> bool: def use_credits(self, user_id: str, credits: int) -> bool:
"""Use credits from user balance""" """Use credits from user balance atomically (prevents overdraft)."""
user = self.get_by_id(user_id) from sqlalchemy import update
if not user or user.extra_credits < credits:
return False
user.extra_credits -= credits now = datetime.now(timezone.utc)
result = self.db.execute(
update(User)
.where(User.id == user_id, User.extra_credits >= credits)
.values(
extra_credits=User.extra_credits - credits,
updated_at=now,
)
.execution_options(synchronize_session=False)
)
self.db.commit() self.db.commit()
return True return result.rowcount > 0
def get_all_users( def get_all_users(
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None

View File

@@ -204,6 +204,7 @@ class User(BaseModel):
# Subscription info # Subscription info
plan: PlanType = PlanType.FREE plan: PlanType = PlanType.FREE
tier: str = "free" # binary-ish tier for JWT/auth: free for free/starter, pro for paid
subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE
stripe_customer_id: Optional[str] = None stripe_customer_id: Optional[str] = None
stripe_subscription_id: Optional[str] = None stripe_subscription_id: Optional[str] = None

View File

@@ -104,13 +104,14 @@ async def require_user(credentials: HTTPAuthorizationCredentials = Depends(secur
def user_to_response(user) -> UserResponse: def user_to_response(user) -> UserResponse:
"""Convert User to UserResponse with plan limits""" """Convert User to UserResponse with plan limits"""
plan_limits = PLANS[user.plan] plan_limits = PLANS[user.plan]
tier = getattr(user, "tier", None) or user.plan
return UserResponse( return UserResponse(
id=user.id, id=user.id,
email=user.email, email=user.email,
name=user.name, name=user.name,
avatar_url=user.avatar_url, avatar_url=user.avatar_url,
plan=user.plan, plan=user.plan,
tier=user.plan, tier=tier,
subscription_status=user.subscription_status, subscription_status=user.subscription_status,
subscription_ends_at=getattr(user, 'subscription_ends_at', None), subscription_ends_at=getattr(user, 'subscription_ends_at', None),
cancel_at_period_end=getattr(user, 'cancel_at_period_end', False), cancel_at_period_end=getattr(user, 'cancel_at_period_end', False),
@@ -540,7 +541,7 @@ async def login_v1(request: Request):
) )
access_token = create_access_token( access_token = create_access_token(
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15) user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
) )
refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7)) refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
@@ -626,7 +627,7 @@ async def google_auth_v1(body: GoogleAuthRequest):
content={"error": "USER_CREATE_FAILED", "message": str(exc)}, content={"error": "USER_CREATE_FAILED", "message": str(exc)},
) )
access_token = create_access_token(user.id) access_token = create_access_token(user.id, tier=getattr(user, "tier", "free"))
refresh_token = create_refresh_token(user.id) refresh_token = create_refresh_token(user.id)
return JSONResponse( return JSONResponse(
@@ -724,7 +725,7 @@ async def refresh_v1(request: Request):
) )
access_token = create_access_token( access_token = create_access_token(
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15) user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
) )
new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7)) new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))

View File

@@ -44,8 +44,13 @@ from typing_extensions import Annotated
from config import config from config import config
from translators import ExcelTranslator, WordTranslator, PowerPointTranslator from translators import ExcelTranslator, WordTranslator, PowerPointTranslator
from models.subscription import PlanType from models.subscription import PlanType
from middleware.tier_quota import tier_quota_service from services.auth_service import (
from services.auth_service import record_usage, check_usage_limits record_usage,
check_usage_limits,
reserve_translation_quota,
release_translation_quota,
)
from middleware.tier_quota import _seconds_until_next_month, _next_month_utc
from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator
from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key
from utils import file_handler from utils import file_handler
@@ -548,6 +553,7 @@ async def translate_document_v1(
""" """
request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8]) request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8])
quota_reserved = False
try: try:
if not file and not file_url: if not file and not file_url:
raise TranslateEndpointError( raise TranslateEndpointError(
@@ -615,40 +621,42 @@ async def translate_document_v1(
) )
if current_user: if current_user:
quota = await tier_quota_service.check_quota(user_id, tier)
if not quota.allowed:
retry_after = tier_quota_service.seconds_until_reset()
raise HTTPException(
status_code=429,
detail={
"error": "QUOTA_EXCEEDED",
"message": f"Monthly limit reached ({quota.current_usage}/{quota.limit} documents). Upgrade your plan for more.",
"details": {
"current_usage": quota.current_usage,
"limit": quota.limit,
"tier": tier,
"reset_at": quota.reset_at_utc.isoformat(),
},
},
headers={"Retry-After": str(retry_after)},
)
# Strict database plan limit check (Starter, Pro, Business, Enterprise)
usage = check_usage_limits(current_user) usage = check_usage_limits(current_user)
if not usage["can_translate"]: if not usage["can_translate"]:
retry_after = _seconds_until_next_month()
raise HTTPException( raise HTTPException(
status_code=429, status_code=429,
detail={ detail={
"error": "QUOTA_EXCEEDED", "error": "QUOTA_EXCEEDED",
"message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.", "message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.",
"details": { "details": {
"current_usage": usage["docs_used"], "current_usage": usage['docs_used'],
"limit": usage["docs_limit"], "limit": usage['docs_limit'],
"tier": tier, "tier": tier,
"reset_at": _next_month_utc().isoformat(),
}, },
}, },
headers={"Retry-After": str(retry_after)},
) )
rate_limit_remaining = quota.remaining # Atomically reserve one document slot now so concurrent requests cannot
# overshoot the monthly quota while background jobs are still running.
reserved = await asyncio.to_thread(reserve_translation_quota, user_id)
if not reserved:
retry_after = _seconds_until_next_month()
raise HTTPException(
status_code=429,
detail={
"error": "QUOTA_EXCEEDED",
"message": "Monthly limit reached. Upgrade your plan for more.",
"details": {
"tier": tier,
"reset_at": _next_month_utc().isoformat(),
},
},
headers={"Retry-After": str(retry_after)},
)
quota_reserved = True
rate_limit_remaining = usage["docs_remaining"]
else: else:
rate_limit_remaining = -1 rate_limit_remaining = -1
@@ -839,6 +847,8 @@ async def translate_document_v1(
) )
except TranslateEndpointError as e: except TranslateEndpointError as e:
if quota_reserved and user_id:
await asyncio.to_thread(release_translation_quota, user_id)
status_code = 400 status_code = 400
if e.code == TranslateEndpointError.FILE_TOO_LARGE: if e.code == TranslateEndpointError.FILE_TOO_LARGE:
status_code = 413 status_code = 413
@@ -852,8 +862,12 @@ async def translate_document_v1(
content=e.to_dict(), content=e.to_dict(),
) )
except HTTPException: except HTTPException:
if quota_reserved and user_id:
await asyncio.to_thread(release_translation_quota, user_id)
raise raise
except Exception as e: except Exception as e:
if quota_reserved and user_id:
await asyncio.to_thread(release_translation_quota, user_id)
logger.error(f"[{request_id}] Unexpected error: {e}") logger.error(f"[{request_id}] Unexpected error: {e}")
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
@@ -973,6 +987,7 @@ async def _run_translation_job(
return return
tracker = ProgressTracker(job_id, _translation_jobs) tracker = ProgressTracker(job_id, _translation_jobs)
usage_recorded = False
try: try:
job["status"] = "processing" job["status"] = "processing"
@@ -1337,14 +1352,14 @@ async def _run_translation_job(
else: else:
cost_factor = 5 cost_factor = 5
for _ in range(cost_factor):
await tier_quota_service.increment_on_success(user_id)
# Persist monthly usage counters in PostgreSQL (docs + pages) # Persist monthly usage counters in PostgreSQL (docs + pages)
pages = await asyncio.to_thread( pages = await asyncio.to_thread(
_estimate_pages, input_path, file_extension _estimate_pages, input_path, file_extension
) )
await asyncio.to_thread(record_usage, user_id, pages, False, cost_factor) await asyncio.to_thread(
record_usage, user_id, pages, cost_factor, reserved_docs=1
)
usage_recorded = True
logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}") logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}")
# Apply watermark for Free-tier users # Apply watermark for Free-tier users
@@ -1364,6 +1379,12 @@ async def _run_translation_job(
record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success") record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success")
logger.info(f"Job {job_id}: Completed successfully") logger.info(f"Job {job_id}: Completed successfully")
except asyncio.CancelledError:
# Background task cancelled (e.g. TestClient teardown or server shutdown).
# The document slot was already reserved at request time; keep it consumed
# so quota enforcement remains deterministic.
logger.warning(f"Job {job_id}: translation task cancelled, keeping reserved quota")
raise
except Exception as e: except Exception as e:
# Check if this is our structured TranslationProviderError # Check if this is our structured TranslationProviderError
if type(e).__name__ == "TranslationProviderError": if type(e).__name__ == "TranslationProviderError":
@@ -1376,6 +1397,13 @@ async def _run_translation_job(
# Record translation failure metric # Record translation failure metric
record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error") record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error")
if user_id and not usage_recorded:
try:
await asyncio.to_thread(release_translation_quota, user_id)
logger.info(f"Job {job_id}: released reserved quota after failure")
except Exception as release_err:
logger.exception(f"Job {job_id}: failed to release reserved quota: {release_err}")
finally: finally:
if webhook_url: if webhook_url:
try: try:

View File

@@ -278,6 +278,7 @@ def _db_user_to_model(db_user) -> User:
password_hash=db_user.password_hash, password_hash=db_user.password_hash,
avatar_url=db_user.avatar_url, avatar_url=db_user.avatar_url,
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE, plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
tier=db_user.tier or "free",
subscription_status=SubscriptionStatus(db_user.subscription_status) subscription_status=SubscriptionStatus(db_user.subscription_status)
if db_user.subscription_status if db_user.subscription_status
else SubscriptionStatus.ACTIVE, else SubscriptionStatus.ACTIVE,
@@ -460,36 +461,48 @@ def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
return User(**users[user_id]) return User(**users[user_id])
def check_usage_limits(user: User) -> Dict[str, Any]: def _reset_usage_if_needed(user_id: str) -> Optional[User]:
"""Check if user has exceeded their plan limits""" """Reset monthly counters if usage_reset_date is in a previous month."""
plan = PLANS[user.plan] if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
# Reset usage if it's a new month with get_sync_session() as session:
repo = UserRepository(session)
repo.reset_usage_if_needed(user_id)
else:
user = get_user_by_id(user_id)
if not user:
return None
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
if ( reset_date = user.usage_reset_date
user.usage_reset_date.month != now.month if reset_date.month != now.month or reset_date.year != now.year:
or user.usage_reset_date.year != now.year
):
update_user( update_user(
user.id, user_id,
{ {
"docs_translated_this_month": 0, "docs_translated_this_month": 0,
"pages_translated_this_month": 0, "pages_translated_this_month": 0,
"api_calls_this_month": 0, "api_calls_this_month": 0,
"usage_reset_date": now.isoformat() if not USE_DATABASE else now, "usage_reset_date": now.isoformat(),
}, },
) )
user.docs_translated_this_month = 0 return get_user_by_id(user_id)
user.pages_translated_this_month = 0
user.api_calls_this_month = 0
def check_usage_limits(user: User) -> Dict[str, Any]:
"""Check if user has exceeded their plan limits"""
# Ensure counters are reset if we've entered a new month.
refreshed = _reset_usage_if_needed(user.id)
if refreshed:
user = refreshed
plan = PLANS[user.plan]
docs_limit = plan["docs_per_month"] docs_limit = plan["docs_per_month"]
docs_remaining = ( unlimited = docs_limit == -1
max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1 docs_remaining = -1 if unlimited else max(0, docs_limit - user.docs_translated_this_month)
)
return { return {
"can_translate": docs_remaining != 0 or user.extra_credits > 0, "can_translate": unlimited or docs_remaining != 0 or user.extra_credits > 0,
"docs_used": user.docs_translated_this_month, "docs_used": user.docs_translated_this_month,
"docs_limit": docs_limit, "docs_limit": docs_limit,
"docs_remaining": docs_remaining, "docs_remaining": docs_remaining,
@@ -502,23 +515,190 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
def record_usage( def record_usage(
user_id: str, pages_count: int, use_credits: bool = False, cost_factor: int = 1 user_id: str, pages_count: int, cost_factor: int = 1, reserved_docs: int = 0
) -> bool: ) -> bool:
"""Record document translation usage with optional cost factor depending on AI model""" """Record document translation usage with optional cost factor depending on AI model.
`reserved_docs` is the number of document slots already reserved at request time
(e.g. by ``reserve_translation_quota``). Those slots are not counted again; only
the remaining ``max(0, cost_factor - reserved_docs)`` docs are added here.
Automatically consumes extra credits first; falls back to monthly quota.
Returns True if usage was recorded successfully, False otherwise.
"""
user = _reset_usage_if_needed(user_id)
if not user:
return False
total_cost = pages_count * cost_factor
docs_to_record = max(0, cost_factor - reserved_docs)
plan = PLANS[user.plan]
docs_limit = plan["docs_per_month"]
unlimited = docs_limit == -1
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
if unlimited:
# Paid unlimited plans: track usage for analytics/downgrade safety.
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
return True
# Prefer credits first, then quota.
if user.extra_credits > 0:
credits_to_use = min(user.extra_credits, total_cost)
if repo.use_credits(user_id, credits_to_use):
remaining_cost = total_cost - credits_to_use
if remaining_cost > 0:
repo.increment_usage(
user_id, docs=docs_to_record, pages=remaining_cost
)
return True
return False
if user.docs_translated_this_month + docs_to_record <= docs_limit:
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
return True
return False
else:
# JSON fallback (non-atomic, dev only)
if unlimited:
return update_user(
user_id,
{
"docs_translated_this_month": user.docs_translated_this_month
+ docs_to_record,
"pages_translated_this_month": user.pages_translated_this_month
+ total_cost,
},
) is not None
if user.extra_credits > 0:
credits_to_use = min(user.extra_credits, total_cost)
new_credits = user.extra_credits - credits_to_use
remaining_cost = total_cost - credits_to_use
updates = {"extra_credits": new_credits}
if remaining_cost > 0:
if user.docs_translated_this_month + docs_to_record > docs_limit:
return False
updates["docs_translated_this_month"] = (
user.docs_translated_this_month + docs_to_record
)
updates["pages_translated_this_month"] = (
user.pages_translated_this_month + remaining_cost
)
return update_user(user_id, updates) is not None
if user.docs_translated_this_month + docs_to_record <= docs_limit:
return update_user(
user_id,
{
"docs_translated_this_month": user.docs_translated_this_month
+ docs_to_record,
"pages_translated_this_month": user.pages_translated_this_month
+ total_cost,
},
) is not None
return False
def reserve_translation_quota(user_id: str) -> bool:
"""Atomically reserve one document slot at request time.
Returns True if the reservation succeeded. This prevents race conditions where
multiple concurrent requests could exceed the monthly quota.
"""
user = _reset_usage_if_needed(user_id)
if not user:
return False
plan = PLANS[user.plan]
docs_limit = plan["docs_per_month"]
unlimited = docs_limit == -1
if unlimited:
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
repo.increment_usage(user_id, docs=1, pages=0)
return True
return update_user(
user_id,
{"docs_translated_this_month": user.docs_translated_this_month + 1},
) is not None
# Limited plan: prefer credits first, then quota.
if user.extra_credits > 0:
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
return repo.use_credits(user_id, 1)
return update_user(
user_id, {"extra_credits": max(0, user.extra_credits - 1)}
) is not None
if user.docs_translated_this_month + 1 <= docs_limit:
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
repo.increment_usage(user_id, docs=1, pages=0)
return True
return update_user(
user_id,
{"docs_translated_this_month": user.docs_translated_this_month + 1},
) is not None
return False
def release_translation_quota(user_id: str) -> bool:
"""Release a previously reserved document slot (e.g. on translation failure)."""
user = get_user_by_id(user_id) user = get_user_by_id(user_id)
if not user: if not user:
return False return False
updates = { if USE_DATABASE and DATABASE_AVAILABLE:
"docs_translated_this_month": user.docs_translated_this_month + cost_factor, from database.connection import get_sync_session
"pages_translated_this_month": user.pages_translated_this_month + (pages_count * cost_factor), from database.repositories import UserRepository
} from sqlalchemy import update
if use_credits: with get_sync_session() as session:
updates["extra_credits"] = max(0, user.extra_credits - (pages_count * cost_factor)) now = datetime.now(timezone.utc)
session.execute(
update(db_models.User)
.where(
db_models.User.id == user_id,
db_models.User.docs_translated_this_month > 0,
)
.values(
docs_translated_this_month=db_models.User.docs_translated_this_month
- 1,
updated_at=now,
)
.execution_options(synchronize_session=False)
)
session.commit()
return True
result = update_user(user_id, updates) if user.docs_translated_this_month > 0:
return result is not None return update_user(
user_id,
{"docs_translated_this_month": user.docs_translated_this_month - 1},
) is not None
return True
def add_credits(user_id: str, credits: int) -> bool: def add_credits(user_id: str, credits: int) -> bool:

View File

@@ -4,7 +4,7 @@ Stripe payment integration for subscriptions and credits
import os import os
import logging import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime from datetime import datetime, timezone
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -263,6 +263,10 @@ async def handle_webhook(payload: bytes, sig_header: str) -> Dict[str, Any]:
invoice = event["data"]["object"] invoice = event["data"]["object"]
await handle_payment_failed(invoice) await handle_payment_failed(invoice)
elif event["type"] == "invoice.paid":
invoice = event["data"]["object"]
await handle_invoice_paid(invoice)
return {"status": "success"} return {"status": "success"}
@@ -276,7 +280,9 @@ async def handle_checkout_completed(session: Dict):
session_id = session.get("id") session_id = session.get("id")
# Check for duplicate session processing using PaymentHistory # Check for duplicate session processing using PaymentHistory.
# Use the Stripe payment_intent (one-time) or subscription id (recurring) as the
# idempotency key, NOT the checkout session id — Stripe can redeliver the event.
db_available = False db_available = False
try: try:
from database.connection import get_sync_session from database.connection import get_sync_session
@@ -285,14 +291,22 @@ async def handle_checkout_completed(session: Dict):
except ImportError: except ImportError:
pass pass
payment_intent_id = session.get("payment_intent")
subscription_id = session.get("subscription")
if db_available and session_id: if db_available and session_id:
try: try:
with get_sync_session() as db_session: with get_sync_session() as db_session:
existing = db_session.query(DBPaymentHistory).filter( from sqlalchemy import or_
DBPaymentHistory.stripe_payment_intent_id == session_id
).first() filters = [DBPaymentHistory.stripe_payment_intent_id == session_id]
if payment_intent_id:
filters.append(DBPaymentHistory.stripe_payment_intent_id == payment_intent_id)
if subscription_id:
filters.append(DBPaymentHistory.stripe_invoice_id == subscription_id)
existing = db_session.query(DBPaymentHistory).filter(or_(*filters)).first()
if existing: if existing:
logger.info("Checkout session %s already processed. Skipping.", session_id) logger.info("Checkout session %s already processed (pi=%s sub=%s). Skipping.",
session_id, payment_intent_id, subscription_id)
return return
except Exception as e: except Exception as e:
logger.error("Error checking PaymentHistory duplication: %s", e) logger.error("Error checking PaymentHistory duplication: %s", e)
@@ -312,7 +326,7 @@ async def handle_checkout_completed(session: Dict):
with get_sync_session() as db_session: with get_sync_session() as db_session:
payment = DBPaymentHistory( payment = DBPaymentHistory(
user_id=user_id, user_id=user_id,
stripe_payment_intent_id=session_id, stripe_payment_intent_id=payment_intent_id or session_id,
stripe_invoice_id=session.get("invoice") or session.get("subscription"), stripe_invoice_id=session.get("invoice") or session.get("subscription"),
amount_cents=session.get("amount_total") or 0, amount_cents=session.get("amount_total") or 0,
currency=session.get("currency") or "usd", currency=session.get("currency") or "usd",
@@ -377,7 +391,6 @@ async def handle_checkout_completed(session: Dict):
subscription_id = subscription_raw.get("id") subscription_id = subscription_raw.get("id")
period_end = subscription_raw.get("current_period_end") period_end = subscription_raw.get("current_period_end")
if period_end: if period_end:
from datetime import timezone
subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc) subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
# Derive tier from plan (DB constraint: only 'free' or 'pro') # Derive tier from plan (DB constraint: only 'free' or 'pro')
@@ -408,7 +421,7 @@ async def handle_checkout_completed(session: Dict):
with get_sync_session() as db_session: with get_sync_session() as db_session:
payment = DBPaymentHistory( payment = DBPaymentHistory(
user_id=user_id, user_id=user_id,
stripe_payment_intent_id=session_id, stripe_payment_intent_id=payment_intent_id or session_id,
stripe_invoice_id=subscription_id or session.get("invoice"), stripe_invoice_id=subscription_id or session.get("invoice"),
amount_cents=session.get("amount_total") or 0, amount_cents=session.get("amount_total") or 0,
currency=session.get("currency") or "usd", currency=session.get("currency") or "usd",
@@ -481,15 +494,13 @@ async def handle_subscription_updated(subscription: Dict):
period_end = subscription.get("current_period_end") period_end = subscription.get("current_period_end")
ends_str = "" ends_str = ""
if period_end: if period_end:
from datetime import timezone
ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y") ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y")
period_end = subscription.get("current_period_end")
update_user(user_id, { update_user(user_id, {
"subscription_status": status.value, "subscription_status": status.value,
"cancel_at_period_end": stripe_cancel_at_period_end, "cancel_at_period_end": stripe_cancel_at_period_end,
"subscription_ends_at": datetime.fromtimestamp( "subscription_ends_at": datetime.fromtimestamp(period_end, tz=timezone.utc) if period_end else None,
subscription.get("current_period_end", 0)
).isoformat() if subscription.get("current_period_end") else None
}) })
# Send cancellation email if they just selected to cancel # Send cancellation email if they just selected to cancel
@@ -531,7 +542,7 @@ async def handle_subscription_deleted(subscription: Dict):
if not user: if not user:
return return
had_active_sub = user.plan != PlanType.FREE.value or user.tier != "free" had_active_sub = user.plan != PlanType.FREE or user.tier != "free"
update_user(user_id, { update_user(user_id, {
"plan": PlanType.FREE.value, "plan": PlanType.FREE.value,
@@ -621,6 +632,50 @@ async def handle_payment_failed(invoice: Dict):
logger.error("handle_payment_failed DB error: %s", exc) logger.error("handle_payment_failed DB error: %s", exc)
async def handle_invoice_paid(invoice: Dict):
"""Extend subscription_ends_at when a recurring invoice is paid."""
customer_id = invoice.get("customer")
if not customer_id:
return
subscription_id = invoice.get("subscription")
period_end = invoice.get("period_end") or invoice.get("lines", {}).get("data", [{}])[0].get("period", {}).get("end")
try:
from database.connection import get_sync_session
from database.models import User as DBUser
with get_sync_session() as session:
db_user = (
session.query(DBUser)
.filter(DBUser.stripe_customer_id == customer_id)
.first()
)
if not db_user:
return
if subscription_id and db_user.stripe_subscription_id != subscription_id:
# The paid invoice belongs to a different subscription; do not update.
logger.warning(
"Invoice paid for customer %s but subscription id mismatch (expected %s, got %s)",
customer_id, db_user.stripe_subscription_id, subscription_id,
)
return
if period_end:
new_end = datetime.fromtimestamp(period_end, tz=timezone.utc)
if db_user.subscription_ends_at is None or new_end > db_user.subscription_ends_at:
db_user.subscription_ends_at = new_end
db_user.updated_at = datetime.now(timezone.utc)
session.commit()
logger.info(
"Extended subscription_ends_at for user %s to %s",
db_user.id, new_end.isoformat(),
)
except Exception as exc:
logger.error("handle_invoice_paid error: %s", exc)
async def cancel_subscription(user_id: str) -> Dict[str, Any]: async def cancel_subscription(user_id: str) -> Dict[str, Any]:
"""Cancel a user's subscription at period end.""" """Cancel a user's subscription at period end."""
if not is_stripe_configured(): if not is_stripe_configured():
@@ -631,6 +686,10 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
return {"error": "No active subscription found"} return {"error": "No active subscription found"}
try: try:
subscription = stripe.Subscription.retrieve(user.stripe_subscription_id)
if subscription.customer != user.stripe_customer_id:
return {"error": "Subscription does not belong to current user"}
subscription = stripe.Subscription.modify( subscription = stripe.Subscription.modify(
user.stripe_subscription_id, user.stripe_subscription_id,
cancel_at_period_end=True, cancel_at_period_end=True,
@@ -638,13 +697,13 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
cancel_at = None cancel_at = None
if subscription.cancel_at: if subscription.cancel_at:
cancel_at = datetime.fromtimestamp(subscription.cancel_at).isoformat() cancel_at = datetime.fromtimestamp(subscription.cancel_at, tz=timezone.utc)
subscription_ends_at = None subscription_ends_at = None
ends_str = "" ends_str = ""
if subscription.current_period_end: if subscription.current_period_end:
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end).isoformat() subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc)
ends_str = datetime.fromtimestamp(subscription.current_period_end).strftime("%d/%m/%Y") ends_str = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc).strftime("%d/%m/%Y")
is_new_cancel = not user.cancel_at_period_end is_new_cancel = not user.cancel_at_period_end

View File

@@ -110,26 +110,22 @@ def client(users_file: Path, monkeypatch):
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow) monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow) monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
from middleware.tier_quota import TierQuotaService def _check_usage_limits_allow(user):
return {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
async def _check_quota_allow(self, user_id, tier): monkeypatch.setattr(
from middleware.tier_quota import QuotaResult "routes.translate_routes.check_usage_limits", _check_usage_limits_allow
from datetime import datetime, timezone, timedelta
now = datetime.now(timezone.utc)
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
) )
return QuotaResult(
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
)
async def _increment_noop(self, user_id):
pass
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
from main import app from main import app

View File

@@ -573,16 +573,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = pro_client.post( response = pro_client.post(
@@ -598,16 +600,18 @@ class TestURLIngestionIntegration:
def test_free_user_rejected(self, free_client, monkeypatch): def test_free_user_rejected(self, free_client, monkeypatch):
"""AC7: Free user receives PRO_FEATURE_REQUIRED (403)""" """AC7: Free user receives PRO_FEATURE_REQUIRED (403)"""
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = free_client.post( response = free_client.post(
@@ -632,16 +636,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = pro_client.post( response = pro_client.post(
@@ -668,16 +674,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = pro_client.post( response = pro_client.post(
@@ -704,16 +712,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = pro_client.post( response = pro_client.post(
@@ -738,16 +748,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = pro_client.post( response = pro_client.post(
@@ -772,16 +784,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download) monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr( monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota", "routes.translate_routes.check_usage_limits",
AsyncMock( lambda user: {
return_value=MagicMock( "can_translate": True,
allowed=True, "docs_used": 0,
remaining=5, "docs_limit": 5,
reset_at_utc=None, "docs_remaining": 5,
current_usage=0, "pages_used": 0,
limit=5, "extra_credits": 0,
) "max_pages_per_doc": 50,
), "max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
) )
response = pro_client.post( response = pro_client.post(

View File

@@ -115,26 +115,22 @@ def client(users_file: Path, monkeypatch):
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow) monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow) monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
from middleware.tier_quota import TierQuotaService def _check_usage_limits_allow(user):
return {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
async def _check_quota_allow(self, user_id, tier): monkeypatch.setattr(
from middleware.tier_quota import QuotaResult "routes.translate_routes.check_usage_limits", _check_usage_limits_allow
from datetime import datetime, timezone, timedelta
now = datetime.now(timezone.utc)
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
) )
return QuotaResult(
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
)
async def _increment_noop(self, user_id):
pass
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
from main import app from main import app
@@ -477,24 +473,22 @@ class TestQuotaExceeded:
def test_returns_429_when_quota_exceeded(self, client, monkeypatch): def test_returns_429_when_quota_exceeded(self, client, monkeypatch):
"""Returns 429 with QUOTA_EXCEEDED when quota exceeded""" """Returns 429 with QUOTA_EXCEEDED when quota exceeded"""
from middleware.tier_quota import TierQuotaService, QuotaResult def _check_usage_limits_denied(user):
from datetime import datetime, timezone, timedelta return {
"can_translate": False,
"docs_used": 5,
"docs_limit": 5,
"docs_remaining": 0,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
async def _check_quota_denied(self, user_id, tier): monkeypatch.setattr(
now = datetime.now(timezone.utc) "routes.translate_routes.check_usage_limits", _check_usage_limits_denied
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
) )
return QuotaResult(
allowed=False,
remaining=0,
reset_at_utc=reset_at,
current_usage=5,
limit=5,
)
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
# Register and login # Register and login
client.post(REGISTER_URL, json=VALID_USER) client.post(REGISTER_URL, json=VALID_USER)
@@ -530,24 +524,23 @@ class TestQuotaExceeded:
def test_includes_retry_after_header(self, client, monkeypatch): def test_includes_retry_after_header(self, client, monkeypatch):
"""Includes Retry-After header on 429""" """Includes Retry-After header on 429"""
from middleware.tier_quota import TierQuotaService, QuotaResult
from datetime import datetime, timezone, timedelta
async def _check_quota_denied(self, user_id, tier): def _check_usage_limits_denied(user):
now = datetime.now(timezone.utc) return {
tomorrow = now.date() + timedelta(days=1) "can_translate": False,
reset_at = datetime( "docs_used": 5,
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc "docs_limit": 5,
) "docs_remaining": 0,
return QuotaResult( "pages_used": 0,
allowed=False, "extra_credits": 0,
remaining=0, "max_pages_per_doc": 50,
reset_at_utc=reset_at, "max_file_size_mb": 10,
current_usage=5, "allowed_providers": ["google", "deepl"],
limit=5, }
)
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied) monkeypatch.setattr(
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
)
client.post(REGISTER_URL, json=VALID_USER) client.post(REGISTER_URL, json=VALID_USER)
response = client.post( response = client.post(