fix(billing): unify quota counters, fix Stripe webhooks, tier/plan sync
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s

This commit is contained in:
2026-06-14 17:39:34 +02:00
parent fa637abff0
commit 45e44dd7b2
9 changed files with 546 additions and 256 deletions

View File

@@ -93,34 +93,45 @@ class UserRepository:
self.db.commit()
return True
def increment_usage(
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
) -> Optional[User]:
"""Increment usage counters"""
def reset_usage_if_needed(self, user_id: str) -> Optional[User]:
"""Reset monthly counters if usage_reset_date is in a previous month."""
user = self.get_by_id(user_id)
if not user:
return None
# Check if usage needs to be reset (monthly)
if user.usage_reset_date:
now = datetime.now(timezone.utc)
if (
now.month != user.usage_reset_date.month
or now.year != user.usage_reset_date.year
):
user.docs_translated_this_month = 0
user.pages_translated_this_month = 0
user.api_calls_this_month = 0
user.usage_reset_date = now
user.docs_translated_this_month += docs
user.pages_translated_this_month += pages
user.api_calls_this_month += api_calls
self.db.commit()
self.db.refresh(user)
now = datetime.now(timezone.utc)
reset_date = user.usage_reset_date
if reset_date is None or reset_date.month != now.month or reset_date.year != now.year:
user.docs_translated_this_month = 0
user.pages_translated_this_month = 0
user.api_calls_this_month = 0
user.usage_reset_date = now
user.updated_at = now
self.db.commit()
self.db.refresh(user)
return user
def increment_usage(
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
) -> Optional[User]:
"""Increment usage counters atomically via SQL UPDATE."""
from sqlalchemy import update
now = datetime.now(timezone.utc)
self.db.execute(
update(User)
.where(User.id == user_id)
.values(
docs_translated_this_month=User.docs_translated_this_month + docs,
pages_translated_this_month=User.pages_translated_this_month + pages,
api_calls_this_month=User.api_calls_this_month + api_calls,
updated_at=now,
)
.execution_options(synchronize_session=False)
)
self.db.commit()
return self.get_by_id(user_id)
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
"""Add extra credits to user"""
user = self.get_by_id(user_id)
@@ -133,14 +144,21 @@ class UserRepository:
return user
def use_credits(self, user_id: str, credits: int) -> bool:
"""Use credits from user balance"""
user = self.get_by_id(user_id)
if not user or user.extra_credits < credits:
return False
"""Use credits from user balance atomically (prevents overdraft)."""
from sqlalchemy import update
user.extra_credits -= credits
now = datetime.now(timezone.utc)
result = self.db.execute(
update(User)
.where(User.id == user_id, User.extra_credits >= credits)
.values(
extra_credits=User.extra_credits - credits,
updated_at=now,
)
.execution_options(synchronize_session=False)
)
self.db.commit()
return True
return result.rowcount > 0
def get_all_users(
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None

View File

@@ -204,6 +204,7 @@ class User(BaseModel):
# Subscription info
plan: PlanType = PlanType.FREE
tier: str = "free" # binary-ish tier for JWT/auth: free for free/starter, pro for paid
subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE
stripe_customer_id: Optional[str] = None
stripe_subscription_id: Optional[str] = None

View File

@@ -104,13 +104,14 @@ async def require_user(credentials: HTTPAuthorizationCredentials = Depends(secur
def user_to_response(user) -> UserResponse:
"""Convert User to UserResponse with plan limits"""
plan_limits = PLANS[user.plan]
tier = getattr(user, "tier", None) or user.plan
return UserResponse(
id=user.id,
email=user.email,
name=user.name,
avatar_url=user.avatar_url,
plan=user.plan,
tier=user.plan,
tier=tier,
subscription_status=user.subscription_status,
subscription_ends_at=getattr(user, 'subscription_ends_at', None),
cancel_at_period_end=getattr(user, 'cancel_at_period_end', False),
@@ -540,7 +541,7 @@ async def login_v1(request: Request):
)
access_token = create_access_token(
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15)
user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
)
refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
@@ -626,7 +627,7 @@ async def google_auth_v1(body: GoogleAuthRequest):
content={"error": "USER_CREATE_FAILED", "message": str(exc)},
)
access_token = create_access_token(user.id)
access_token = create_access_token(user.id, tier=getattr(user, "tier", "free"))
refresh_token = create_refresh_token(user.id)
return JSONResponse(
@@ -724,7 +725,7 @@ async def refresh_v1(request: Request):
)
access_token = create_access_token(
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15)
user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
)
new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))

View File

@@ -44,8 +44,13 @@ from typing_extensions import Annotated
from config import config
from translators import ExcelTranslator, WordTranslator, PowerPointTranslator
from models.subscription import PlanType
from middleware.tier_quota import tier_quota_service
from services.auth_service import record_usage, check_usage_limits
from services.auth_service import (
record_usage,
check_usage_limits,
reserve_translation_quota,
release_translation_quota,
)
from middleware.tier_quota import _seconds_until_next_month, _next_month_utc
from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator
from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key
from utils import file_handler
@@ -548,6 +553,7 @@ async def translate_document_v1(
"""
request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8])
quota_reserved = False
try:
if not file and not file_url:
raise TranslateEndpointError(
@@ -615,40 +621,42 @@ async def translate_document_v1(
)
if current_user:
quota = await tier_quota_service.check_quota(user_id, tier)
if not quota.allowed:
retry_after = tier_quota_service.seconds_until_reset()
raise HTTPException(
status_code=429,
detail={
"error": "QUOTA_EXCEEDED",
"message": f"Monthly limit reached ({quota.current_usage}/{quota.limit} documents). Upgrade your plan for more.",
"details": {
"current_usage": quota.current_usage,
"limit": quota.limit,
"tier": tier,
"reset_at": quota.reset_at_utc.isoformat(),
},
},
headers={"Retry-After": str(retry_after)},
)
# Strict database plan limit check (Starter, Pro, Business, Enterprise)
usage = check_usage_limits(current_user)
if not usage["can_translate"]:
retry_after = _seconds_until_next_month()
raise HTTPException(
status_code=429,
detail={
"error": "QUOTA_EXCEEDED",
"message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.",
"details": {
"current_usage": usage["docs_used"],
"limit": usage["docs_limit"],
"current_usage": usage['docs_used'],
"limit": usage['docs_limit'],
"tier": tier,
"reset_at": _next_month_utc().isoformat(),
},
},
headers={"Retry-After": str(retry_after)},
)
rate_limit_remaining = quota.remaining
# Atomically reserve one document slot now so concurrent requests cannot
# overshoot the monthly quota while background jobs are still running.
reserved = await asyncio.to_thread(reserve_translation_quota, user_id)
if not reserved:
retry_after = _seconds_until_next_month()
raise HTTPException(
status_code=429,
detail={
"error": "QUOTA_EXCEEDED",
"message": "Monthly limit reached. Upgrade your plan for more.",
"details": {
"tier": tier,
"reset_at": _next_month_utc().isoformat(),
},
},
headers={"Retry-After": str(retry_after)},
)
quota_reserved = True
rate_limit_remaining = usage["docs_remaining"]
else:
rate_limit_remaining = -1
@@ -839,6 +847,8 @@ async def translate_document_v1(
)
except TranslateEndpointError as e:
if quota_reserved and user_id:
await asyncio.to_thread(release_translation_quota, user_id)
status_code = 400
if e.code == TranslateEndpointError.FILE_TOO_LARGE:
status_code = 413
@@ -852,8 +862,12 @@ async def translate_document_v1(
content=e.to_dict(),
)
except HTTPException:
if quota_reserved and user_id:
await asyncio.to_thread(release_translation_quota, user_id)
raise
except Exception as e:
if quota_reserved and user_id:
await asyncio.to_thread(release_translation_quota, user_id)
logger.error(f"[{request_id}] Unexpected error: {e}")
return JSONResponse(
status_code=400,
@@ -973,6 +987,7 @@ async def _run_translation_job(
return
tracker = ProgressTracker(job_id, _translation_jobs)
usage_recorded = False
try:
job["status"] = "processing"
@@ -1337,14 +1352,14 @@ async def _run_translation_job(
else:
cost_factor = 5
for _ in range(cost_factor):
await tier_quota_service.increment_on_success(user_id)
# Persist monthly usage counters in PostgreSQL (docs + pages)
pages = await asyncio.to_thread(
_estimate_pages, input_path, file_extension
)
await asyncio.to_thread(record_usage, user_id, pages, False, cost_factor)
await asyncio.to_thread(
record_usage, user_id, pages, cost_factor, reserved_docs=1
)
usage_recorded = True
logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}")
# Apply watermark for Free-tier users
@@ -1364,6 +1379,12 @@ async def _run_translation_job(
record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success")
logger.info(f"Job {job_id}: Completed successfully")
except asyncio.CancelledError:
# Background task cancelled (e.g. TestClient teardown or server shutdown).
# The document slot was already reserved at request time; keep it consumed
# so quota enforcement remains deterministic.
logger.warning(f"Job {job_id}: translation task cancelled, keeping reserved quota")
raise
except Exception as e:
# Check if this is our structured TranslationProviderError
if type(e).__name__ == "TranslationProviderError":
@@ -1376,6 +1397,13 @@ async def _run_translation_job(
# Record translation failure metric
record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error")
if user_id and not usage_recorded:
try:
await asyncio.to_thread(release_translation_quota, user_id)
logger.info(f"Job {job_id}: released reserved quota after failure")
except Exception as release_err:
logger.exception(f"Job {job_id}: failed to release reserved quota: {release_err}")
finally:
if webhook_url:
try:

View File

@@ -278,6 +278,7 @@ def _db_user_to_model(db_user) -> User:
password_hash=db_user.password_hash,
avatar_url=db_user.avatar_url,
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
tier=db_user.tier or "free",
subscription_status=SubscriptionStatus(db_user.subscription_status)
if db_user.subscription_status
else SubscriptionStatus.ACTIVE,
@@ -460,36 +461,48 @@ def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
return User(**users[user_id])
def _reset_usage_if_needed(user_id: str) -> Optional[User]:
"""Reset monthly counters if usage_reset_date is in a previous month."""
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
repo.reset_usage_if_needed(user_id)
else:
user = get_user_by_id(user_id)
if not user:
return None
now = datetime.now(timezone.utc)
reset_date = user.usage_reset_date
if reset_date.month != now.month or reset_date.year != now.year:
update_user(
user_id,
{
"docs_translated_this_month": 0,
"pages_translated_this_month": 0,
"api_calls_this_month": 0,
"usage_reset_date": now.isoformat(),
},
)
return get_user_by_id(user_id)
def check_usage_limits(user: User) -> Dict[str, Any]:
"""Check if user has exceeded their plan limits"""
# Ensure counters are reset if we've entered a new month.
refreshed = _reset_usage_if_needed(user.id)
if refreshed:
user = refreshed
plan = PLANS[user.plan]
# Reset usage if it's a new month
now = datetime.now(timezone.utc)
if (
user.usage_reset_date.month != now.month
or user.usage_reset_date.year != now.year
):
update_user(
user.id,
{
"docs_translated_this_month": 0,
"pages_translated_this_month": 0,
"api_calls_this_month": 0,
"usage_reset_date": now.isoformat() if not USE_DATABASE else now,
},
)
user.docs_translated_this_month = 0
user.pages_translated_this_month = 0
user.api_calls_this_month = 0
docs_limit = plan["docs_per_month"]
docs_remaining = (
max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1
)
unlimited = docs_limit == -1
docs_remaining = -1 if unlimited else max(0, docs_limit - user.docs_translated_this_month)
return {
"can_translate": docs_remaining != 0 or user.extra_credits > 0,
"can_translate": unlimited or docs_remaining != 0 or user.extra_credits > 0,
"docs_used": user.docs_translated_this_month,
"docs_limit": docs_limit,
"docs_remaining": docs_remaining,
@@ -502,23 +515,190 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
def record_usage(
user_id: str, pages_count: int, use_credits: bool = False, cost_factor: int = 1
user_id: str, pages_count: int, cost_factor: int = 1, reserved_docs: int = 0
) -> bool:
"""Record document translation usage with optional cost factor depending on AI model"""
"""Record document translation usage with optional cost factor depending on AI model.
`reserved_docs` is the number of document slots already reserved at request time
(e.g. by ``reserve_translation_quota``). Those slots are not counted again; only
the remaining ``max(0, cost_factor - reserved_docs)`` docs are added here.
Automatically consumes extra credits first; falls back to monthly quota.
Returns True if usage was recorded successfully, False otherwise.
"""
user = _reset_usage_if_needed(user_id)
if not user:
return False
total_cost = pages_count * cost_factor
docs_to_record = max(0, cost_factor - reserved_docs)
plan = PLANS[user.plan]
docs_limit = plan["docs_per_month"]
unlimited = docs_limit == -1
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
if unlimited:
# Paid unlimited plans: track usage for analytics/downgrade safety.
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
return True
# Prefer credits first, then quota.
if user.extra_credits > 0:
credits_to_use = min(user.extra_credits, total_cost)
if repo.use_credits(user_id, credits_to_use):
remaining_cost = total_cost - credits_to_use
if remaining_cost > 0:
repo.increment_usage(
user_id, docs=docs_to_record, pages=remaining_cost
)
return True
return False
if user.docs_translated_this_month + docs_to_record <= docs_limit:
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
return True
return False
else:
# JSON fallback (non-atomic, dev only)
if unlimited:
return update_user(
user_id,
{
"docs_translated_this_month": user.docs_translated_this_month
+ docs_to_record,
"pages_translated_this_month": user.pages_translated_this_month
+ total_cost,
},
) is not None
if user.extra_credits > 0:
credits_to_use = min(user.extra_credits, total_cost)
new_credits = user.extra_credits - credits_to_use
remaining_cost = total_cost - credits_to_use
updates = {"extra_credits": new_credits}
if remaining_cost > 0:
if user.docs_translated_this_month + docs_to_record > docs_limit:
return False
updates["docs_translated_this_month"] = (
user.docs_translated_this_month + docs_to_record
)
updates["pages_translated_this_month"] = (
user.pages_translated_this_month + remaining_cost
)
return update_user(user_id, updates) is not None
if user.docs_translated_this_month + docs_to_record <= docs_limit:
return update_user(
user_id,
{
"docs_translated_this_month": user.docs_translated_this_month
+ docs_to_record,
"pages_translated_this_month": user.pages_translated_this_month
+ total_cost,
},
) is not None
return False
def reserve_translation_quota(user_id: str) -> bool:
"""Atomically reserve one document slot at request time.
Returns True if the reservation succeeded. This prevents race conditions where
multiple concurrent requests could exceed the monthly quota.
"""
user = _reset_usage_if_needed(user_id)
if not user:
return False
plan = PLANS[user.plan]
docs_limit = plan["docs_per_month"]
unlimited = docs_limit == -1
if unlimited:
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
repo.increment_usage(user_id, docs=1, pages=0)
return True
return update_user(
user_id,
{"docs_translated_this_month": user.docs_translated_this_month + 1},
) is not None
# Limited plan: prefer credits first, then quota.
if user.extra_credits > 0:
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
return repo.use_credits(user_id, 1)
return update_user(
user_id, {"extra_credits": max(0, user.extra_credits - 1)}
) is not None
if user.docs_translated_this_month + 1 <= docs_limit:
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
with get_sync_session() as session:
repo = UserRepository(session)
repo.increment_usage(user_id, docs=1, pages=0)
return True
return update_user(
user_id,
{"docs_translated_this_month": user.docs_translated_this_month + 1},
) is not None
return False
def release_translation_quota(user_id: str) -> bool:
"""Release a previously reserved document slot (e.g. on translation failure)."""
user = get_user_by_id(user_id)
if not user:
return False
updates = {
"docs_translated_this_month": user.docs_translated_this_month + cost_factor,
"pages_translated_this_month": user.pages_translated_this_month + (pages_count * cost_factor),
}
if USE_DATABASE and DATABASE_AVAILABLE:
from database.connection import get_sync_session
from database.repositories import UserRepository
from sqlalchemy import update
if use_credits:
updates["extra_credits"] = max(0, user.extra_credits - (pages_count * cost_factor))
with get_sync_session() as session:
now = datetime.now(timezone.utc)
session.execute(
update(db_models.User)
.where(
db_models.User.id == user_id,
db_models.User.docs_translated_this_month > 0,
)
.values(
docs_translated_this_month=db_models.User.docs_translated_this_month
- 1,
updated_at=now,
)
.execution_options(synchronize_session=False)
)
session.commit()
return True
result = update_user(user_id, updates)
return result is not None
if user.docs_translated_this_month > 0:
return update_user(
user_id,
{"docs_translated_this_month": user.docs_translated_this_month - 1},
) is not None
return True
def add_credits(user_id: str, credits: int) -> bool:

View File

@@ -4,7 +4,7 @@ Stripe payment integration for subscriptions and credits
import os
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from datetime import datetime, timezone
logger = logging.getLogger(__name__)
@@ -250,19 +250,23 @@ async def handle_webhook(payload: bytes, sig_header: str) -> Dict[str, Any]:
if event["type"] == "checkout.session.completed":
session = event["data"]["object"]
await handle_checkout_completed(session)
elif event["type"] == "customer.subscription.updated":
subscription = event["data"]["object"]
await handle_subscription_updated(subscription)
elif event["type"] == "customer.subscription.deleted":
subscription = event["data"]["object"]
await handle_subscription_deleted(subscription)
elif event["type"] == "invoice.payment_failed":
invoice = event["data"]["object"]
await handle_payment_failed(invoice)
elif event["type"] == "invoice.paid":
invoice = event["data"]["object"]
await handle_invoice_paid(invoice)
return {"status": "success"}
@@ -276,7 +280,9 @@ async def handle_checkout_completed(session: Dict):
session_id = session.get("id")
# Check for duplicate session processing using PaymentHistory
# Check for duplicate session processing using PaymentHistory.
# Use the Stripe payment_intent (one-time) or subscription id (recurring) as the
# idempotency key, NOT the checkout session id — Stripe can redeliver the event.
db_available = False
try:
from database.connection import get_sync_session
@@ -285,14 +291,22 @@ async def handle_checkout_completed(session: Dict):
except ImportError:
pass
payment_intent_id = session.get("payment_intent")
subscription_id = session.get("subscription")
if db_available and session_id:
try:
with get_sync_session() as db_session:
existing = db_session.query(DBPaymentHistory).filter(
DBPaymentHistory.stripe_payment_intent_id == session_id
).first()
from sqlalchemy import or_
filters = [DBPaymentHistory.stripe_payment_intent_id == session_id]
if payment_intent_id:
filters.append(DBPaymentHistory.stripe_payment_intent_id == payment_intent_id)
if subscription_id:
filters.append(DBPaymentHistory.stripe_invoice_id == subscription_id)
existing = db_session.query(DBPaymentHistory).filter(or_(*filters)).first()
if existing:
logger.info("Checkout session %s already processed. Skipping.", session_id)
logger.info("Checkout session %s already processed (pi=%s sub=%s). Skipping.",
session_id, payment_intent_id, subscription_id)
return
except Exception as e:
logger.error("Error checking PaymentHistory duplication: %s", e)
@@ -312,7 +326,7 @@ async def handle_checkout_completed(session: Dict):
with get_sync_session() as db_session:
payment = DBPaymentHistory(
user_id=user_id,
stripe_payment_intent_id=session_id,
stripe_payment_intent_id=payment_intent_id or session_id,
stripe_invoice_id=session.get("invoice") or session.get("subscription"),
amount_cents=session.get("amount_total") or 0,
currency=session.get("currency") or "usd",
@@ -377,7 +391,6 @@ async def handle_checkout_completed(session: Dict):
subscription_id = subscription_raw.get("id")
period_end = subscription_raw.get("current_period_end")
if period_end:
from datetime import timezone
subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
# Derive tier from plan (DB constraint: only 'free' or 'pro')
@@ -408,7 +421,7 @@ async def handle_checkout_completed(session: Dict):
with get_sync_session() as db_session:
payment = DBPaymentHistory(
user_id=user_id,
stripe_payment_intent_id=session_id,
stripe_payment_intent_id=payment_intent_id or session_id,
stripe_invoice_id=subscription_id or session.get("invoice"),
amount_cents=session.get("amount_total") or 0,
currency=session.get("currency") or "usd",
@@ -481,15 +494,13 @@ async def handle_subscription_updated(subscription: Dict):
period_end = subscription.get("current_period_end")
ends_str = ""
if period_end:
from datetime import timezone
ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y")
period_end = subscription.get("current_period_end")
update_user(user_id, {
"subscription_status": status.value,
"cancel_at_period_end": stripe_cancel_at_period_end,
"subscription_ends_at": datetime.fromtimestamp(
subscription.get("current_period_end", 0)
).isoformat() if subscription.get("current_period_end") else None
"subscription_ends_at": datetime.fromtimestamp(period_end, tz=timezone.utc) if period_end else None,
})
# Send cancellation email if they just selected to cancel
@@ -531,7 +542,7 @@ async def handle_subscription_deleted(subscription: Dict):
if not user:
return
had_active_sub = user.plan != PlanType.FREE.value or user.tier != "free"
had_active_sub = user.plan != PlanType.FREE or user.tier != "free"
update_user(user_id, {
"plan": PlanType.FREE.value,
@@ -621,6 +632,50 @@ async def handle_payment_failed(invoice: Dict):
logger.error("handle_payment_failed DB error: %s", exc)
async def handle_invoice_paid(invoice: Dict):
"""Extend subscription_ends_at when a recurring invoice is paid."""
customer_id = invoice.get("customer")
if not customer_id:
return
subscription_id = invoice.get("subscription")
period_end = invoice.get("period_end") or invoice.get("lines", {}).get("data", [{}])[0].get("period", {}).get("end")
try:
from database.connection import get_sync_session
from database.models import User as DBUser
with get_sync_session() as session:
db_user = (
session.query(DBUser)
.filter(DBUser.stripe_customer_id == customer_id)
.first()
)
if not db_user:
return
if subscription_id and db_user.stripe_subscription_id != subscription_id:
# The paid invoice belongs to a different subscription; do not update.
logger.warning(
"Invoice paid for customer %s but subscription id mismatch (expected %s, got %s)",
customer_id, db_user.stripe_subscription_id, subscription_id,
)
return
if period_end:
new_end = datetime.fromtimestamp(period_end, tz=timezone.utc)
if db_user.subscription_ends_at is None or new_end > db_user.subscription_ends_at:
db_user.subscription_ends_at = new_end
db_user.updated_at = datetime.now(timezone.utc)
session.commit()
logger.info(
"Extended subscription_ends_at for user %s to %s",
db_user.id, new_end.isoformat(),
)
except Exception as exc:
logger.error("handle_invoice_paid error: %s", exc)
async def cancel_subscription(user_id: str) -> Dict[str, Any]:
"""Cancel a user's subscription at period end."""
if not is_stripe_configured():
@@ -631,6 +686,10 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
return {"error": "No active subscription found"}
try:
subscription = stripe.Subscription.retrieve(user.stripe_subscription_id)
if subscription.customer != user.stripe_customer_id:
return {"error": "Subscription does not belong to current user"}
subscription = stripe.Subscription.modify(
user.stripe_subscription_id,
cancel_at_period_end=True,
@@ -638,13 +697,13 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
cancel_at = None
if subscription.cancel_at:
cancel_at = datetime.fromtimestamp(subscription.cancel_at).isoformat()
cancel_at = datetime.fromtimestamp(subscription.cancel_at, tz=timezone.utc)
subscription_ends_at = None
ends_str = ""
if subscription.current_period_end:
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end).isoformat()
ends_str = datetime.fromtimestamp(subscription.current_period_end).strftime("%d/%m/%Y")
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc)
ends_str = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc).strftime("%d/%m/%Y")
is_new_cancel = not user.cancel_at_period_end

View File

@@ -110,26 +110,22 @@ def client(users_file: Path, monkeypatch):
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
from middleware.tier_quota import TierQuotaService
def _check_usage_limits_allow(user):
return {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
async def _check_quota_allow(self, user_id, tier):
from middleware.tier_quota import QuotaResult
from datetime import datetime, timezone, timedelta
now = datetime.now(timezone.utc)
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
)
return QuotaResult(
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
)
async def _increment_noop(self, user_id):
pass
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
monkeypatch.setattr(
"routes.translate_routes.check_usage_limits", _check_usage_limits_allow
)
from main import app

View File

@@ -573,16 +573,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = pro_client.post(
@@ -598,16 +600,18 @@ class TestURLIngestionIntegration:
def test_free_user_rejected(self, free_client, monkeypatch):
"""AC7: Free user receives PRO_FEATURE_REQUIRED (403)"""
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = free_client.post(
@@ -632,16 +636,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = pro_client.post(
@@ -668,16 +674,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = pro_client.post(
@@ -704,16 +712,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = pro_client.post(
@@ -738,16 +748,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = pro_client.post(
@@ -772,16 +784,18 @@ class TestURLIngestionIntegration:
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
monkeypatch.setattr(
"middleware.tier_quota.TierQuotaService.check_quota",
AsyncMock(
return_value=MagicMock(
allowed=True,
remaining=5,
reset_at_utc=None,
current_usage=0,
limit=5,
)
),
"routes.translate_routes.check_usage_limits",
lambda user: {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
},
)
response = pro_client.post(

View File

@@ -115,26 +115,22 @@ def client(users_file: Path, monkeypatch):
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
from middleware.tier_quota import TierQuotaService
def _check_usage_limits_allow(user):
return {
"can_translate": True,
"docs_used": 0,
"docs_limit": 5,
"docs_remaining": 5,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
async def _check_quota_allow(self, user_id, tier):
from middleware.tier_quota import QuotaResult
from datetime import datetime, timezone, timedelta
now = datetime.now(timezone.utc)
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
)
return QuotaResult(
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
)
async def _increment_noop(self, user_id):
pass
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
monkeypatch.setattr(
"routes.translate_routes.check_usage_limits", _check_usage_limits_allow
)
from main import app
@@ -477,24 +473,22 @@ class TestQuotaExceeded:
def test_returns_429_when_quota_exceeded(self, client, monkeypatch):
"""Returns 429 with QUOTA_EXCEEDED when quota exceeded"""
from middleware.tier_quota import TierQuotaService, QuotaResult
from datetime import datetime, timezone, timedelta
def _check_usage_limits_denied(user):
return {
"can_translate": False,
"docs_used": 5,
"docs_limit": 5,
"docs_remaining": 0,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
async def _check_quota_denied(self, user_id, tier):
now = datetime.now(timezone.utc)
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
)
return QuotaResult(
allowed=False,
remaining=0,
reset_at_utc=reset_at,
current_usage=5,
limit=5,
)
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
monkeypatch.setattr(
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
)
# Register and login
client.post(REGISTER_URL, json=VALID_USER)
@@ -530,24 +524,23 @@ class TestQuotaExceeded:
def test_includes_retry_after_header(self, client, monkeypatch):
"""Includes Retry-After header on 429"""
from middleware.tier_quota import TierQuotaService, QuotaResult
from datetime import datetime, timezone, timedelta
async def _check_quota_denied(self, user_id, tier):
now = datetime.now(timezone.utc)
tomorrow = now.date() + timedelta(days=1)
reset_at = datetime(
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
)
return QuotaResult(
allowed=False,
remaining=0,
reset_at_utc=reset_at,
current_usage=5,
limit=5,
)
def _check_usage_limits_denied(user):
return {
"can_translate": False,
"docs_used": 5,
"docs_limit": 5,
"docs_remaining": 0,
"pages_used": 0,
"extra_credits": 0,
"max_pages_per_doc": 50,
"max_file_size_mb": 10,
"allowed_providers": ["google", "deepl"],
}
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
monkeypatch.setattr(
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
)
client.post(REGISTER_URL, json=VALID_USER)
response = client.post(