fix(billing): unify quota counters, fix Stripe webhooks, tier/plan sync
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s
This commit is contained in:
@@ -93,34 +93,45 @@ class UserRepository:
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def increment_usage(
|
||||
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
|
||||
) -> Optional[User]:
|
||||
"""Increment usage counters"""
|
||||
def reset_usage_if_needed(self, user_id: str) -> Optional[User]:
|
||||
"""Reset monthly counters if usage_reset_date is in a previous month."""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Check if usage needs to be reset (monthly)
|
||||
if user.usage_reset_date:
|
||||
now = datetime.now(timezone.utc)
|
||||
if (
|
||||
now.month != user.usage_reset_date.month
|
||||
or now.year != user.usage_reset_date.year
|
||||
):
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
user.api_calls_this_month = 0
|
||||
user.usage_reset_date = now
|
||||
|
||||
user.docs_translated_this_month += docs
|
||||
user.pages_translated_this_month += pages
|
||||
user.api_calls_this_month += api_calls
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
now = datetime.now(timezone.utc)
|
||||
reset_date = user.usage_reset_date
|
||||
if reset_date is None or reset_date.month != now.month or reset_date.year != now.year:
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
user.api_calls_this_month = 0
|
||||
user.usage_reset_date = now
|
||||
user.updated_at = now
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def increment_usage(
|
||||
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
|
||||
) -> Optional[User]:
|
||||
"""Increment usage counters atomically via SQL UPDATE."""
|
||||
from sqlalchemy import update
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
self.db.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(
|
||||
docs_translated_this_month=User.docs_translated_this_month + docs,
|
||||
pages_translated_this_month=User.pages_translated_this_month + pages,
|
||||
api_calls_this_month=User.api_calls_this_month + api_calls,
|
||||
updated_at=now,
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
self.db.commit()
|
||||
return self.get_by_id(user_id)
|
||||
|
||||
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
|
||||
"""Add extra credits to user"""
|
||||
user = self.get_by_id(user_id)
|
||||
@@ -133,14 +144,21 @@ class UserRepository:
|
||||
return user
|
||||
|
||||
def use_credits(self, user_id: str, credits: int) -> bool:
|
||||
"""Use credits from user balance"""
|
||||
user = self.get_by_id(user_id)
|
||||
if not user or user.extra_credits < credits:
|
||||
return False
|
||||
"""Use credits from user balance atomically (prevents overdraft)."""
|
||||
from sqlalchemy import update
|
||||
|
||||
user.extra_credits -= credits
|
||||
now = datetime.now(timezone.utc)
|
||||
result = self.db.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id, User.extra_credits >= credits)
|
||||
.values(
|
||||
extra_credits=User.extra_credits - credits,
|
||||
updated_at=now,
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
self.db.commit()
|
||||
return True
|
||||
return result.rowcount > 0
|
||||
|
||||
def get_all_users(
|
||||
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
|
||||
|
||||
@@ -204,6 +204,7 @@ class User(BaseModel):
|
||||
|
||||
# Subscription info
|
||||
plan: PlanType = PlanType.FREE
|
||||
tier: str = "free" # binary-ish tier for JWT/auth: free for free/starter, pro for paid
|
||||
subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE
|
||||
stripe_customer_id: Optional[str] = None
|
||||
stripe_subscription_id: Optional[str] = None
|
||||
|
||||
@@ -104,13 +104,14 @@ async def require_user(credentials: HTTPAuthorizationCredentials = Depends(secur
|
||||
def user_to_response(user) -> UserResponse:
|
||||
"""Convert User to UserResponse with plan limits"""
|
||||
plan_limits = PLANS[user.plan]
|
||||
tier = getattr(user, "tier", None) or user.plan
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
avatar_url=user.avatar_url,
|
||||
plan=user.plan,
|
||||
tier=user.plan,
|
||||
tier=tier,
|
||||
subscription_status=user.subscription_status,
|
||||
subscription_ends_at=getattr(user, 'subscription_ends_at', None),
|
||||
cancel_at_period_end=getattr(user, 'cancel_at_period_end', False),
|
||||
@@ -540,7 +541,7 @@ async def login_v1(request: Request):
|
||||
)
|
||||
|
||||
access_token = create_access_token(
|
||||
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15)
|
||||
user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
|
||||
)
|
||||
refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
|
||||
|
||||
@@ -626,7 +627,7 @@ async def google_auth_v1(body: GoogleAuthRequest):
|
||||
content={"error": "USER_CREATE_FAILED", "message": str(exc)},
|
||||
)
|
||||
|
||||
access_token = create_access_token(user.id)
|
||||
access_token = create_access_token(user.id, tier=getattr(user, "tier", "free"))
|
||||
refresh_token = create_refresh_token(user.id)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -724,7 +725,7 @@ async def refresh_v1(request: Request):
|
||||
)
|
||||
|
||||
access_token = create_access_token(
|
||||
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15)
|
||||
user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
|
||||
)
|
||||
new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
|
||||
|
||||
|
||||
@@ -44,8 +44,13 @@ from typing_extensions import Annotated
|
||||
from config import config
|
||||
from translators import ExcelTranslator, WordTranslator, PowerPointTranslator
|
||||
from models.subscription import PlanType
|
||||
from middleware.tier_quota import tier_quota_service
|
||||
from services.auth_service import record_usage, check_usage_limits
|
||||
from services.auth_service import (
|
||||
record_usage,
|
||||
check_usage_limits,
|
||||
reserve_translation_quota,
|
||||
release_translation_quota,
|
||||
)
|
||||
from middleware.tier_quota import _seconds_until_next_month, _next_month_utc
|
||||
from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator
|
||||
from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key
|
||||
from utils import file_handler
|
||||
@@ -548,6 +553,7 @@ async def translate_document_v1(
|
||||
"""
|
||||
request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8])
|
||||
|
||||
quota_reserved = False
|
||||
try:
|
||||
if not file and not file_url:
|
||||
raise TranslateEndpointError(
|
||||
@@ -615,40 +621,42 @@ async def translate_document_v1(
|
||||
)
|
||||
|
||||
if current_user:
|
||||
quota = await tier_quota_service.check_quota(user_id, tier)
|
||||
if not quota.allowed:
|
||||
retry_after = tier_quota_service.seconds_until_reset()
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": "QUOTA_EXCEEDED",
|
||||
"message": f"Monthly limit reached ({quota.current_usage}/{quota.limit} documents). Upgrade your plan for more.",
|
||||
"details": {
|
||||
"current_usage": quota.current_usage,
|
||||
"limit": quota.limit,
|
||||
"tier": tier,
|
||||
"reset_at": quota.reset_at_utc.isoformat(),
|
||||
},
|
||||
},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
|
||||
# Strict database plan limit check (Starter, Pro, Business, Enterprise)
|
||||
usage = check_usage_limits(current_user)
|
||||
if not usage["can_translate"]:
|
||||
retry_after = _seconds_until_next_month()
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": "QUOTA_EXCEEDED",
|
||||
"message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.",
|
||||
"details": {
|
||||
"current_usage": usage["docs_used"],
|
||||
"limit": usage["docs_limit"],
|
||||
"current_usage": usage['docs_used'],
|
||||
"limit": usage['docs_limit'],
|
||||
"tier": tier,
|
||||
"reset_at": _next_month_utc().isoformat(),
|
||||
},
|
||||
},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
rate_limit_remaining = quota.remaining
|
||||
# Atomically reserve one document slot now so concurrent requests cannot
|
||||
# overshoot the monthly quota while background jobs are still running.
|
||||
reserved = await asyncio.to_thread(reserve_translation_quota, user_id)
|
||||
if not reserved:
|
||||
retry_after = _seconds_until_next_month()
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": "QUOTA_EXCEEDED",
|
||||
"message": "Monthly limit reached. Upgrade your plan for more.",
|
||||
"details": {
|
||||
"tier": tier,
|
||||
"reset_at": _next_month_utc().isoformat(),
|
||||
},
|
||||
},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
quota_reserved = True
|
||||
rate_limit_remaining = usage["docs_remaining"]
|
||||
else:
|
||||
rate_limit_remaining = -1
|
||||
|
||||
@@ -839,6 +847,8 @@ async def translate_document_v1(
|
||||
)
|
||||
|
||||
except TranslateEndpointError as e:
|
||||
if quota_reserved and user_id:
|
||||
await asyncio.to_thread(release_translation_quota, user_id)
|
||||
status_code = 400
|
||||
if e.code == TranslateEndpointError.FILE_TOO_LARGE:
|
||||
status_code = 413
|
||||
@@ -852,8 +862,12 @@ async def translate_document_v1(
|
||||
content=e.to_dict(),
|
||||
)
|
||||
except HTTPException:
|
||||
if quota_reserved and user_id:
|
||||
await asyncio.to_thread(release_translation_quota, user_id)
|
||||
raise
|
||||
except Exception as e:
|
||||
if quota_reserved and user_id:
|
||||
await asyncio.to_thread(release_translation_quota, user_id)
|
||||
logger.error(f"[{request_id}] Unexpected error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
@@ -973,6 +987,7 @@ async def _run_translation_job(
|
||||
return
|
||||
|
||||
tracker = ProgressTracker(job_id, _translation_jobs)
|
||||
usage_recorded = False
|
||||
|
||||
try:
|
||||
job["status"] = "processing"
|
||||
@@ -1337,14 +1352,14 @@ async def _run_translation_job(
|
||||
else:
|
||||
cost_factor = 5
|
||||
|
||||
for _ in range(cost_factor):
|
||||
await tier_quota_service.increment_on_success(user_id)
|
||||
|
||||
# Persist monthly usage counters in PostgreSQL (docs + pages)
|
||||
pages = await asyncio.to_thread(
|
||||
_estimate_pages, input_path, file_extension
|
||||
)
|
||||
await asyncio.to_thread(record_usage, user_id, pages, False, cost_factor)
|
||||
await asyncio.to_thread(
|
||||
record_usage, user_id, pages, cost_factor, reserved_docs=1
|
||||
)
|
||||
usage_recorded = True
|
||||
logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}")
|
||||
|
||||
# Apply watermark for Free-tier users
|
||||
@@ -1364,6 +1379,12 @@ async def _run_translation_job(
|
||||
record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success")
|
||||
logger.info(f"Job {job_id}: Completed successfully")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Background task cancelled (e.g. TestClient teardown or server shutdown).
|
||||
# The document slot was already reserved at request time; keep it consumed
|
||||
# so quota enforcement remains deterministic.
|
||||
logger.warning(f"Job {job_id}: translation task cancelled, keeping reserved quota")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Check if this is our structured TranslationProviderError
|
||||
if type(e).__name__ == "TranslationProviderError":
|
||||
@@ -1376,6 +1397,13 @@ async def _run_translation_job(
|
||||
# Record translation failure metric
|
||||
record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error")
|
||||
|
||||
if user_id and not usage_recorded:
|
||||
try:
|
||||
await asyncio.to_thread(release_translation_quota, user_id)
|
||||
logger.info(f"Job {job_id}: released reserved quota after failure")
|
||||
except Exception as release_err:
|
||||
logger.exception(f"Job {job_id}: failed to release reserved quota: {release_err}")
|
||||
|
||||
finally:
|
||||
if webhook_url:
|
||||
try:
|
||||
|
||||
@@ -278,6 +278,7 @@ def _db_user_to_model(db_user) -> User:
|
||||
password_hash=db_user.password_hash,
|
||||
avatar_url=db_user.avatar_url,
|
||||
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
|
||||
tier=db_user.tier or "free",
|
||||
subscription_status=SubscriptionStatus(db_user.subscription_status)
|
||||
if db_user.subscription_status
|
||||
else SubscriptionStatus.ACTIVE,
|
||||
@@ -460,36 +461,48 @@ def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
||||
return User(**users[user_id])
|
||||
|
||||
|
||||
def _reset_usage_if_needed(user_id: str) -> Optional[User]:
|
||||
"""Reset monthly counters if usage_reset_date is in a previous month."""
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
from database.connection import get_sync_session
|
||||
from database.repositories import UserRepository
|
||||
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
repo.reset_usage_if_needed(user_id)
|
||||
else:
|
||||
user = get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
now = datetime.now(timezone.utc)
|
||||
reset_date = user.usage_reset_date
|
||||
if reset_date.month != now.month or reset_date.year != now.year:
|
||||
update_user(
|
||||
user_id,
|
||||
{
|
||||
"docs_translated_this_month": 0,
|
||||
"pages_translated_this_month": 0,
|
||||
"api_calls_this_month": 0,
|
||||
"usage_reset_date": now.isoformat(),
|
||||
},
|
||||
)
|
||||
return get_user_by_id(user_id)
|
||||
|
||||
|
||||
def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||
"""Check if user has exceeded their plan limits"""
|
||||
# Ensure counters are reset if we've entered a new month.
|
||||
refreshed = _reset_usage_if_needed(user.id)
|
||||
if refreshed:
|
||||
user = refreshed
|
||||
|
||||
plan = PLANS[user.plan]
|
||||
|
||||
# Reset usage if it's a new month
|
||||
now = datetime.now(timezone.utc)
|
||||
if (
|
||||
user.usage_reset_date.month != now.month
|
||||
or user.usage_reset_date.year != now.year
|
||||
):
|
||||
update_user(
|
||||
user.id,
|
||||
{
|
||||
"docs_translated_this_month": 0,
|
||||
"pages_translated_this_month": 0,
|
||||
"api_calls_this_month": 0,
|
||||
"usage_reset_date": now.isoformat() if not USE_DATABASE else now,
|
||||
},
|
||||
)
|
||||
user.docs_translated_this_month = 0
|
||||
user.pages_translated_this_month = 0
|
||||
user.api_calls_this_month = 0
|
||||
|
||||
docs_limit = plan["docs_per_month"]
|
||||
docs_remaining = (
|
||||
max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1
|
||||
)
|
||||
unlimited = docs_limit == -1
|
||||
docs_remaining = -1 if unlimited else max(0, docs_limit - user.docs_translated_this_month)
|
||||
|
||||
return {
|
||||
"can_translate": docs_remaining != 0 or user.extra_credits > 0,
|
||||
"can_translate": unlimited or docs_remaining != 0 or user.extra_credits > 0,
|
||||
"docs_used": user.docs_translated_this_month,
|
||||
"docs_limit": docs_limit,
|
||||
"docs_remaining": docs_remaining,
|
||||
@@ -502,23 +515,190 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def record_usage(
|
||||
user_id: str, pages_count: int, use_credits: bool = False, cost_factor: int = 1
|
||||
user_id: str, pages_count: int, cost_factor: int = 1, reserved_docs: int = 0
|
||||
) -> bool:
|
||||
"""Record document translation usage with optional cost factor depending on AI model"""
|
||||
"""Record document translation usage with optional cost factor depending on AI model.
|
||||
|
||||
`reserved_docs` is the number of document slots already reserved at request time
|
||||
(e.g. by ``reserve_translation_quota``). Those slots are not counted again; only
|
||||
the remaining ``max(0, cost_factor - reserved_docs)`` docs are added here.
|
||||
|
||||
Automatically consumes extra credits first; falls back to monthly quota.
|
||||
Returns True if usage was recorded successfully, False otherwise.
|
||||
"""
|
||||
user = _reset_usage_if_needed(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
total_cost = pages_count * cost_factor
|
||||
docs_to_record = max(0, cost_factor - reserved_docs)
|
||||
plan = PLANS[user.plan]
|
||||
docs_limit = plan["docs_per_month"]
|
||||
unlimited = docs_limit == -1
|
||||
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
from database.connection import get_sync_session
|
||||
from database.repositories import UserRepository
|
||||
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
|
||||
if unlimited:
|
||||
# Paid unlimited plans: track usage for analytics/downgrade safety.
|
||||
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
|
||||
return True
|
||||
|
||||
# Prefer credits first, then quota.
|
||||
if user.extra_credits > 0:
|
||||
credits_to_use = min(user.extra_credits, total_cost)
|
||||
if repo.use_credits(user_id, credits_to_use):
|
||||
remaining_cost = total_cost - credits_to_use
|
||||
if remaining_cost > 0:
|
||||
repo.increment_usage(
|
||||
user_id, docs=docs_to_record, pages=remaining_cost
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
if user.docs_translated_this_month + docs_to_record <= docs_limit:
|
||||
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
# JSON fallback (non-atomic, dev only)
|
||||
if unlimited:
|
||||
return update_user(
|
||||
user_id,
|
||||
{
|
||||
"docs_translated_this_month": user.docs_translated_this_month
|
||||
+ docs_to_record,
|
||||
"pages_translated_this_month": user.pages_translated_this_month
|
||||
+ total_cost,
|
||||
},
|
||||
) is not None
|
||||
|
||||
if user.extra_credits > 0:
|
||||
credits_to_use = min(user.extra_credits, total_cost)
|
||||
new_credits = user.extra_credits - credits_to_use
|
||||
remaining_cost = total_cost - credits_to_use
|
||||
updates = {"extra_credits": new_credits}
|
||||
if remaining_cost > 0:
|
||||
if user.docs_translated_this_month + docs_to_record > docs_limit:
|
||||
return False
|
||||
updates["docs_translated_this_month"] = (
|
||||
user.docs_translated_this_month + docs_to_record
|
||||
)
|
||||
updates["pages_translated_this_month"] = (
|
||||
user.pages_translated_this_month + remaining_cost
|
||||
)
|
||||
return update_user(user_id, updates) is not None
|
||||
|
||||
if user.docs_translated_this_month + docs_to_record <= docs_limit:
|
||||
return update_user(
|
||||
user_id,
|
||||
{
|
||||
"docs_translated_this_month": user.docs_translated_this_month
|
||||
+ docs_to_record,
|
||||
"pages_translated_this_month": user.pages_translated_this_month
|
||||
+ total_cost,
|
||||
},
|
||||
) is not None
|
||||
return False
|
||||
|
||||
|
||||
def reserve_translation_quota(user_id: str) -> bool:
|
||||
"""Atomically reserve one document slot at request time.
|
||||
|
||||
Returns True if the reservation succeeded. This prevents race conditions where
|
||||
multiple concurrent requests could exceed the monthly quota.
|
||||
"""
|
||||
user = _reset_usage_if_needed(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
plan = PLANS[user.plan]
|
||||
docs_limit = plan["docs_per_month"]
|
||||
unlimited = docs_limit == -1
|
||||
|
||||
if unlimited:
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
from database.connection import get_sync_session
|
||||
from database.repositories import UserRepository
|
||||
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
repo.increment_usage(user_id, docs=1, pages=0)
|
||||
return True
|
||||
return update_user(
|
||||
user_id,
|
||||
{"docs_translated_this_month": user.docs_translated_this_month + 1},
|
||||
) is not None
|
||||
|
||||
# Limited plan: prefer credits first, then quota.
|
||||
if user.extra_credits > 0:
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
from database.connection import get_sync_session
|
||||
from database.repositories import UserRepository
|
||||
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
return repo.use_credits(user_id, 1)
|
||||
return update_user(
|
||||
user_id, {"extra_credits": max(0, user.extra_credits - 1)}
|
||||
) is not None
|
||||
|
||||
if user.docs_translated_this_month + 1 <= docs_limit:
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
from database.connection import get_sync_session
|
||||
from database.repositories import UserRepository
|
||||
|
||||
with get_sync_session() as session:
|
||||
repo = UserRepository(session)
|
||||
repo.increment_usage(user_id, docs=1, pages=0)
|
||||
return True
|
||||
return update_user(
|
||||
user_id,
|
||||
{"docs_translated_this_month": user.docs_translated_this_month + 1},
|
||||
) is not None
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def release_translation_quota(user_id: str) -> bool:
|
||||
"""Release a previously reserved document slot (e.g. on translation failure)."""
|
||||
user = get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
updates = {
|
||||
"docs_translated_this_month": user.docs_translated_this_month + cost_factor,
|
||||
"pages_translated_this_month": user.pages_translated_this_month + (pages_count * cost_factor),
|
||||
}
|
||||
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||
from database.connection import get_sync_session
|
||||
from database.repositories import UserRepository
|
||||
from sqlalchemy import update
|
||||
|
||||
if use_credits:
|
||||
updates["extra_credits"] = max(0, user.extra_credits - (pages_count * cost_factor))
|
||||
with get_sync_session() as session:
|
||||
now = datetime.now(timezone.utc)
|
||||
session.execute(
|
||||
update(db_models.User)
|
||||
.where(
|
||||
db_models.User.id == user_id,
|
||||
db_models.User.docs_translated_this_month > 0,
|
||||
)
|
||||
.values(
|
||||
docs_translated_this_month=db_models.User.docs_translated_this_month
|
||||
- 1,
|
||||
updated_at=now,
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
result = update_user(user_id, updates)
|
||||
return result is not None
|
||||
if user.docs_translated_this_month > 0:
|
||||
return update_user(
|
||||
user_id,
|
||||
{"docs_translated_this_month": user.docs_translated_this_month - 1},
|
||||
) is not None
|
||||
return True
|
||||
|
||||
|
||||
def add_credits(user_id: str, credits: int) -> bool:
|
||||
|
||||
@@ -4,7 +4,7 @@ Stripe payment integration for subscriptions and credits
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -263,6 +263,10 @@ async def handle_webhook(payload: bytes, sig_header: str) -> Dict[str, Any]:
|
||||
invoice = event["data"]["object"]
|
||||
await handle_payment_failed(invoice)
|
||||
|
||||
elif event["type"] == "invoice.paid":
|
||||
invoice = event["data"]["object"]
|
||||
await handle_invoice_paid(invoice)
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@@ -276,7 +280,9 @@ async def handle_checkout_completed(session: Dict):
|
||||
|
||||
session_id = session.get("id")
|
||||
|
||||
# Check for duplicate session processing using PaymentHistory
|
||||
# Check for duplicate session processing using PaymentHistory.
|
||||
# Use the Stripe payment_intent (one-time) or subscription id (recurring) as the
|
||||
# idempotency key, NOT the checkout session id — Stripe can redeliver the event.
|
||||
db_available = False
|
||||
try:
|
||||
from database.connection import get_sync_session
|
||||
@@ -285,14 +291,22 @@ async def handle_checkout_completed(session: Dict):
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
payment_intent_id = session.get("payment_intent")
|
||||
subscription_id = session.get("subscription")
|
||||
if db_available and session_id:
|
||||
try:
|
||||
with get_sync_session() as db_session:
|
||||
existing = db_session.query(DBPaymentHistory).filter(
|
||||
DBPaymentHistory.stripe_payment_intent_id == session_id
|
||||
).first()
|
||||
from sqlalchemy import or_
|
||||
|
||||
filters = [DBPaymentHistory.stripe_payment_intent_id == session_id]
|
||||
if payment_intent_id:
|
||||
filters.append(DBPaymentHistory.stripe_payment_intent_id == payment_intent_id)
|
||||
if subscription_id:
|
||||
filters.append(DBPaymentHistory.stripe_invoice_id == subscription_id)
|
||||
existing = db_session.query(DBPaymentHistory).filter(or_(*filters)).first()
|
||||
if existing:
|
||||
logger.info("Checkout session %s already processed. Skipping.", session_id)
|
||||
logger.info("Checkout session %s already processed (pi=%s sub=%s). Skipping.",
|
||||
session_id, payment_intent_id, subscription_id)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error("Error checking PaymentHistory duplication: %s", e)
|
||||
@@ -312,7 +326,7 @@ async def handle_checkout_completed(session: Dict):
|
||||
with get_sync_session() as db_session:
|
||||
payment = DBPaymentHistory(
|
||||
user_id=user_id,
|
||||
stripe_payment_intent_id=session_id,
|
||||
stripe_payment_intent_id=payment_intent_id or session_id,
|
||||
stripe_invoice_id=session.get("invoice") or session.get("subscription"),
|
||||
amount_cents=session.get("amount_total") or 0,
|
||||
currency=session.get("currency") or "usd",
|
||||
@@ -377,7 +391,6 @@ async def handle_checkout_completed(session: Dict):
|
||||
subscription_id = subscription_raw.get("id")
|
||||
period_end = subscription_raw.get("current_period_end")
|
||||
if period_end:
|
||||
from datetime import timezone
|
||||
subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
|
||||
|
||||
# Derive tier from plan (DB constraint: only 'free' or 'pro')
|
||||
@@ -408,7 +421,7 @@ async def handle_checkout_completed(session: Dict):
|
||||
with get_sync_session() as db_session:
|
||||
payment = DBPaymentHistory(
|
||||
user_id=user_id,
|
||||
stripe_payment_intent_id=session_id,
|
||||
stripe_payment_intent_id=payment_intent_id or session_id,
|
||||
stripe_invoice_id=subscription_id or session.get("invoice"),
|
||||
amount_cents=session.get("amount_total") or 0,
|
||||
currency=session.get("currency") or "usd",
|
||||
@@ -481,15 +494,13 @@ async def handle_subscription_updated(subscription: Dict):
|
||||
period_end = subscription.get("current_period_end")
|
||||
ends_str = ""
|
||||
if period_end:
|
||||
from datetime import timezone
|
||||
ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y")
|
||||
|
||||
period_end = subscription.get("current_period_end")
|
||||
update_user(user_id, {
|
||||
"subscription_status": status.value,
|
||||
"cancel_at_period_end": stripe_cancel_at_period_end,
|
||||
"subscription_ends_at": datetime.fromtimestamp(
|
||||
subscription.get("current_period_end", 0)
|
||||
).isoformat() if subscription.get("current_period_end") else None
|
||||
"subscription_ends_at": datetime.fromtimestamp(period_end, tz=timezone.utc) if period_end else None,
|
||||
})
|
||||
|
||||
# Send cancellation email if they just selected to cancel
|
||||
@@ -531,7 +542,7 @@ async def handle_subscription_deleted(subscription: Dict):
|
||||
if not user:
|
||||
return
|
||||
|
||||
had_active_sub = user.plan != PlanType.FREE.value or user.tier != "free"
|
||||
had_active_sub = user.plan != PlanType.FREE or user.tier != "free"
|
||||
|
||||
update_user(user_id, {
|
||||
"plan": PlanType.FREE.value,
|
||||
@@ -621,6 +632,50 @@ async def handle_payment_failed(invoice: Dict):
|
||||
logger.error("handle_payment_failed DB error: %s", exc)
|
||||
|
||||
|
||||
async def handle_invoice_paid(invoice: Dict):
|
||||
"""Extend subscription_ends_at when a recurring invoice is paid."""
|
||||
customer_id = invoice.get("customer")
|
||||
if not customer_id:
|
||||
return
|
||||
|
||||
subscription_id = invoice.get("subscription")
|
||||
period_end = invoice.get("period_end") or invoice.get("lines", {}).get("data", [{}])[0].get("period", {}).get("end")
|
||||
|
||||
try:
|
||||
from database.connection import get_sync_session
|
||||
from database.models import User as DBUser
|
||||
|
||||
with get_sync_session() as session:
|
||||
db_user = (
|
||||
session.query(DBUser)
|
||||
.filter(DBUser.stripe_customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
if not db_user:
|
||||
return
|
||||
|
||||
if subscription_id and db_user.stripe_subscription_id != subscription_id:
|
||||
# The paid invoice belongs to a different subscription; do not update.
|
||||
logger.warning(
|
||||
"Invoice paid for customer %s but subscription id mismatch (expected %s, got %s)",
|
||||
customer_id, db_user.stripe_subscription_id, subscription_id,
|
||||
)
|
||||
return
|
||||
|
||||
if period_end:
|
||||
new_end = datetime.fromtimestamp(period_end, tz=timezone.utc)
|
||||
if db_user.subscription_ends_at is None or new_end > db_user.subscription_ends_at:
|
||||
db_user.subscription_ends_at = new_end
|
||||
db_user.updated_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Extended subscription_ends_at for user %s to %s",
|
||||
db_user.id, new_end.isoformat(),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("handle_invoice_paid error: %s", exc)
|
||||
|
||||
|
||||
async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
||||
"""Cancel a user's subscription at period end."""
|
||||
if not is_stripe_configured():
|
||||
@@ -631,6 +686,10 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
||||
return {"error": "No active subscription found"}
|
||||
|
||||
try:
|
||||
subscription = stripe.Subscription.retrieve(user.stripe_subscription_id)
|
||||
if subscription.customer != user.stripe_customer_id:
|
||||
return {"error": "Subscription does not belong to current user"}
|
||||
|
||||
subscription = stripe.Subscription.modify(
|
||||
user.stripe_subscription_id,
|
||||
cancel_at_period_end=True,
|
||||
@@ -638,13 +697,13 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
||||
|
||||
cancel_at = None
|
||||
if subscription.cancel_at:
|
||||
cancel_at = datetime.fromtimestamp(subscription.cancel_at).isoformat()
|
||||
cancel_at = datetime.fromtimestamp(subscription.cancel_at, tz=timezone.utc)
|
||||
|
||||
subscription_ends_at = None
|
||||
ends_str = ""
|
||||
if subscription.current_period_end:
|
||||
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end).isoformat()
|
||||
ends_str = datetime.fromtimestamp(subscription.current_period_end).strftime("%d/%m/%Y")
|
||||
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc)
|
||||
ends_str = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc).strftime("%d/%m/%Y")
|
||||
|
||||
is_new_cancel = not user.cancel_at_period_end
|
||||
|
||||
|
||||
@@ -110,26 +110,22 @@ def client(users_file: Path, monkeypatch):
|
||||
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
|
||||
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
|
||||
|
||||
from middleware.tier_quota import TierQuotaService
|
||||
def _check_usage_limits_allow(user):
|
||||
return {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
}
|
||||
|
||||
async def _check_quota_allow(self, user_id, tier):
|
||||
from middleware.tier_quota import QuotaResult
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
tomorrow = now.date() + timedelta(days=1)
|
||||
reset_at = datetime(
|
||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
||||
)
|
||||
return QuotaResult(
|
||||
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
|
||||
)
|
||||
|
||||
async def _increment_noop(self, user_id):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
|
||||
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
|
||||
monkeypatch.setattr(
|
||||
"routes.translate_routes.check_usage_limits", _check_usage_limits_allow
|
||||
)
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@@ -573,16 +573,18 @@ class TestURLIngestionIntegration:
|
||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = pro_client.post(
|
||||
@@ -598,16 +600,18 @@ class TestURLIngestionIntegration:
|
||||
def test_free_user_rejected(self, free_client, monkeypatch):
|
||||
"""AC7: Free user receives PRO_FEATURE_REQUIRED (403)"""
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = free_client.post(
|
||||
@@ -632,16 +636,18 @@ class TestURLIngestionIntegration:
|
||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = pro_client.post(
|
||||
@@ -668,16 +674,18 @@ class TestURLIngestionIntegration:
|
||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = pro_client.post(
|
||||
@@ -704,16 +712,18 @@ class TestURLIngestionIntegration:
|
||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = pro_client.post(
|
||||
@@ -738,16 +748,18 @@ class TestURLIngestionIntegration:
|
||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = pro_client.post(
|
||||
@@ -772,16 +784,18 @@ class TestURLIngestionIntegration:
|
||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
||||
AsyncMock(
|
||||
return_value=MagicMock(
|
||||
allowed=True,
|
||||
remaining=5,
|
||||
reset_at_utc=None,
|
||||
current_usage=0,
|
||||
limit=5,
|
||||
)
|
||||
),
|
||||
"routes.translate_routes.check_usage_limits",
|
||||
lambda user: {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
},
|
||||
)
|
||||
|
||||
response = pro_client.post(
|
||||
|
||||
@@ -115,26 +115,22 @@ def client(users_file: Path, monkeypatch):
|
||||
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
|
||||
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
|
||||
|
||||
from middleware.tier_quota import TierQuotaService
|
||||
def _check_usage_limits_allow(user):
|
||||
return {
|
||||
"can_translate": True,
|
||||
"docs_used": 0,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 5,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
}
|
||||
|
||||
async def _check_quota_allow(self, user_id, tier):
|
||||
from middleware.tier_quota import QuotaResult
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
tomorrow = now.date() + timedelta(days=1)
|
||||
reset_at = datetime(
|
||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
||||
)
|
||||
return QuotaResult(
|
||||
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
|
||||
)
|
||||
|
||||
async def _increment_noop(self, user_id):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
|
||||
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
|
||||
monkeypatch.setattr(
|
||||
"routes.translate_routes.check_usage_limits", _check_usage_limits_allow
|
||||
)
|
||||
|
||||
from main import app
|
||||
|
||||
@@ -477,24 +473,22 @@ class TestQuotaExceeded:
|
||||
|
||||
def test_returns_429_when_quota_exceeded(self, client, monkeypatch):
|
||||
"""Returns 429 with QUOTA_EXCEEDED when quota exceeded"""
|
||||
from middleware.tier_quota import TierQuotaService, QuotaResult
|
||||
from datetime import datetime, timezone, timedelta
|
||||
def _check_usage_limits_denied(user):
|
||||
return {
|
||||
"can_translate": False,
|
||||
"docs_used": 5,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 0,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
}
|
||||
|
||||
async def _check_quota_denied(self, user_id, tier):
|
||||
now = datetime.now(timezone.utc)
|
||||
tomorrow = now.date() + timedelta(days=1)
|
||||
reset_at = datetime(
|
||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
||||
)
|
||||
return QuotaResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_at_utc=reset_at,
|
||||
current_usage=5,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
|
||||
monkeypatch.setattr(
|
||||
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
|
||||
)
|
||||
|
||||
# Register and login
|
||||
client.post(REGISTER_URL, json=VALID_USER)
|
||||
@@ -530,24 +524,23 @@ class TestQuotaExceeded:
|
||||
|
||||
def test_includes_retry_after_header(self, client, monkeypatch):
|
||||
"""Includes Retry-After header on 429"""
|
||||
from middleware.tier_quota import TierQuotaService, QuotaResult
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
async def _check_quota_denied(self, user_id, tier):
|
||||
now = datetime.now(timezone.utc)
|
||||
tomorrow = now.date() + timedelta(days=1)
|
||||
reset_at = datetime(
|
||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
||||
)
|
||||
return QuotaResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_at_utc=reset_at,
|
||||
current_usage=5,
|
||||
limit=5,
|
||||
)
|
||||
def _check_usage_limits_denied(user):
|
||||
return {
|
||||
"can_translate": False,
|
||||
"docs_used": 5,
|
||||
"docs_limit": 5,
|
||||
"docs_remaining": 0,
|
||||
"pages_used": 0,
|
||||
"extra_credits": 0,
|
||||
"max_pages_per_doc": 50,
|
||||
"max_file_size_mb": 10,
|
||||
"allowed_providers": ["google", "deepl"],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
|
||||
monkeypatch.setattr(
|
||||
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
|
||||
)
|
||||
|
||||
client.post(REGISTER_URL, json=VALID_USER)
|
||||
response = client.post(
|
||||
|
||||
Reference in New Issue
Block a user