fix(billing): unify quota counters, fix Stripe webhooks, tier/plan sync
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 3m16s
This commit is contained in:
@@ -93,34 +93,45 @@ class UserRepository:
|
|||||||
self.db.commit()
|
self.db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def increment_usage(
|
def reset_usage_if_needed(self, user_id: str) -> Optional[User]:
|
||||||
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
|
"""Reset monthly counters if usage_reset_date is in a previous month."""
|
||||||
) -> Optional[User]:
|
|
||||||
"""Increment usage counters"""
|
|
||||||
user = self.get_by_id(user_id)
|
user = self.get_by_id(user_id)
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if usage needs to be reset (monthly)
|
now = datetime.now(timezone.utc)
|
||||||
if user.usage_reset_date:
|
reset_date = user.usage_reset_date
|
||||||
now = datetime.now(timezone.utc)
|
if reset_date is None or reset_date.month != now.month or reset_date.year != now.year:
|
||||||
if (
|
user.docs_translated_this_month = 0
|
||||||
now.month != user.usage_reset_date.month
|
user.pages_translated_this_month = 0
|
||||||
or now.year != user.usage_reset_date.year
|
user.api_calls_this_month = 0
|
||||||
):
|
user.usage_reset_date = now
|
||||||
user.docs_translated_this_month = 0
|
user.updated_at = now
|
||||||
user.pages_translated_this_month = 0
|
self.db.commit()
|
||||||
user.api_calls_this_month = 0
|
self.db.refresh(user)
|
||||||
user.usage_reset_date = now
|
|
||||||
|
|
||||||
user.docs_translated_this_month += docs
|
|
||||||
user.pages_translated_this_month += pages
|
|
||||||
user.api_calls_this_month += api_calls
|
|
||||||
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(user)
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
def increment_usage(
|
||||||
|
self, user_id: str, docs: int = 0, pages: int = 0, api_calls: int = 0
|
||||||
|
) -> Optional[User]:
|
||||||
|
"""Increment usage counters atomically via SQL UPDATE."""
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
self.db.execute(
|
||||||
|
update(User)
|
||||||
|
.where(User.id == user_id)
|
||||||
|
.values(
|
||||||
|
docs_translated_this_month=User.docs_translated_this_month + docs,
|
||||||
|
pages_translated_this_month=User.pages_translated_this_month + pages,
|
||||||
|
api_calls_this_month=User.api_calls_this_month + api_calls,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
|
self.db.commit()
|
||||||
|
return self.get_by_id(user_id)
|
||||||
|
|
||||||
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
|
def add_credits(self, user_id: str, credits: int) -> Optional[User]:
|
||||||
"""Add extra credits to user"""
|
"""Add extra credits to user"""
|
||||||
user = self.get_by_id(user_id)
|
user = self.get_by_id(user_id)
|
||||||
@@ -133,14 +144,21 @@ class UserRepository:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
def use_credits(self, user_id: str, credits: int) -> bool:
|
def use_credits(self, user_id: str, credits: int) -> bool:
|
||||||
"""Use credits from user balance"""
|
"""Use credits from user balance atomically (prevents overdraft)."""
|
||||||
user = self.get_by_id(user_id)
|
from sqlalchemy import update
|
||||||
if not user or user.extra_credits < credits:
|
|
||||||
return False
|
|
||||||
|
|
||||||
user.extra_credits -= credits
|
now = datetime.now(timezone.utc)
|
||||||
|
result = self.db.execute(
|
||||||
|
update(User)
|
||||||
|
.where(User.id == user_id, User.extra_credits >= credits)
|
||||||
|
.values(
|
||||||
|
extra_credits=User.extra_credits - credits,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return True
|
return result.rowcount > 0
|
||||||
|
|
||||||
def get_all_users(
|
def get_all_users(
|
||||||
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
|
self, skip: int = 0, limit: int = 100, plan: Optional[PlanType] = None
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ class User(BaseModel):
|
|||||||
|
|
||||||
# Subscription info
|
# Subscription info
|
||||||
plan: PlanType = PlanType.FREE
|
plan: PlanType = PlanType.FREE
|
||||||
|
tier: str = "free" # binary-ish tier for JWT/auth: free for free/starter, pro for paid
|
||||||
subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE
|
subscription_status: SubscriptionStatus = SubscriptionStatus.ACTIVE
|
||||||
stripe_customer_id: Optional[str] = None
|
stripe_customer_id: Optional[str] = None
|
||||||
stripe_subscription_id: Optional[str] = None
|
stripe_subscription_id: Optional[str] = None
|
||||||
|
|||||||
@@ -104,13 +104,14 @@ async def require_user(credentials: HTTPAuthorizationCredentials = Depends(secur
|
|||||||
def user_to_response(user) -> UserResponse:
|
def user_to_response(user) -> UserResponse:
|
||||||
"""Convert User to UserResponse with plan limits"""
|
"""Convert User to UserResponse with plan limits"""
|
||||||
plan_limits = PLANS[user.plan]
|
plan_limits = PLANS[user.plan]
|
||||||
|
tier = getattr(user, "tier", None) or user.plan
|
||||||
return UserResponse(
|
return UserResponse(
|
||||||
id=user.id,
|
id=user.id,
|
||||||
email=user.email,
|
email=user.email,
|
||||||
name=user.name,
|
name=user.name,
|
||||||
avatar_url=user.avatar_url,
|
avatar_url=user.avatar_url,
|
||||||
plan=user.plan,
|
plan=user.plan,
|
||||||
tier=user.plan,
|
tier=tier,
|
||||||
subscription_status=user.subscription_status,
|
subscription_status=user.subscription_status,
|
||||||
subscription_ends_at=getattr(user, 'subscription_ends_at', None),
|
subscription_ends_at=getattr(user, 'subscription_ends_at', None),
|
||||||
cancel_at_period_end=getattr(user, 'cancel_at_period_end', False),
|
cancel_at_period_end=getattr(user, 'cancel_at_period_end', False),
|
||||||
@@ -540,7 +541,7 @@ async def login_v1(request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15)
|
user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
|
||||||
)
|
)
|
||||||
refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
|
refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
|
||||||
|
|
||||||
@@ -626,7 +627,7 @@ async def google_auth_v1(body: GoogleAuthRequest):
|
|||||||
content={"error": "USER_CREATE_FAILED", "message": str(exc)},
|
content={"error": "USER_CREATE_FAILED", "message": str(exc)},
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = create_access_token(user.id)
|
access_token = create_access_token(user.id, tier=getattr(user, "tier", "free"))
|
||||||
refresh_token = create_refresh_token(user.id)
|
refresh_token = create_refresh_token(user.id)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@@ -724,7 +725,7 @@ async def refresh_v1(request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
user.id, tier=user.plan.value, expires_delta=timedelta(minutes=15)
|
user.id, tier=getattr(user, "tier", user.plan.value), expires_delta=timedelta(minutes=15)
|
||||||
)
|
)
|
||||||
new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
|
new_refresh_token = create_refresh_token(user.id, expires_delta=timedelta(days=7))
|
||||||
|
|
||||||
|
|||||||
@@ -44,8 +44,13 @@ from typing_extensions import Annotated
|
|||||||
from config import config
|
from config import config
|
||||||
from translators import ExcelTranslator, WordTranslator, PowerPointTranslator
|
from translators import ExcelTranslator, WordTranslator, PowerPointTranslator
|
||||||
from models.subscription import PlanType
|
from models.subscription import PlanType
|
||||||
from middleware.tier_quota import tier_quota_service
|
from services.auth_service import (
|
||||||
from services.auth_service import record_usage, check_usage_limits
|
record_usage,
|
||||||
|
check_usage_limits,
|
||||||
|
reserve_translation_quota,
|
||||||
|
release_translation_quota,
|
||||||
|
)
|
||||||
|
from middleware.tier_quota import _seconds_until_next_month, _next_month_utc
|
||||||
from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator
|
from middleware.validation import FileValidator, ValidationError, LanguageValidator, webhook_validator
|
||||||
from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key
|
from middleware.api_key_auth import get_authenticated_user, get_user_from_api_key
|
||||||
from utils import file_handler
|
from utils import file_handler
|
||||||
@@ -548,6 +553,7 @@ async def translate_document_v1(
|
|||||||
"""
|
"""
|
||||||
request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8])
|
request_id = getattr(request.state, "request_id", str(uuid.uuid4())[:8])
|
||||||
|
|
||||||
|
quota_reserved = False
|
||||||
try:
|
try:
|
||||||
if not file and not file_url:
|
if not file and not file_url:
|
||||||
raise TranslateEndpointError(
|
raise TranslateEndpointError(
|
||||||
@@ -615,40 +621,42 @@ async def translate_document_v1(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if current_user:
|
if current_user:
|
||||||
quota = await tier_quota_service.check_quota(user_id, tier)
|
|
||||||
if not quota.allowed:
|
|
||||||
retry_after = tier_quota_service.seconds_until_reset()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429,
|
|
||||||
detail={
|
|
||||||
"error": "QUOTA_EXCEEDED",
|
|
||||||
"message": f"Monthly limit reached ({quota.current_usage}/{quota.limit} documents). Upgrade your plan for more.",
|
|
||||||
"details": {
|
|
||||||
"current_usage": quota.current_usage,
|
|
||||||
"limit": quota.limit,
|
|
||||||
"tier": tier,
|
|
||||||
"reset_at": quota.reset_at_utc.isoformat(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
headers={"Retry-After": str(retry_after)},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Strict database plan limit check (Starter, Pro, Business, Enterprise)
|
|
||||||
usage = check_usage_limits(current_user)
|
usage = check_usage_limits(current_user)
|
||||||
if not usage["can_translate"]:
|
if not usage["can_translate"]:
|
||||||
|
retry_after = _seconds_until_next_month()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=429,
|
status_code=429,
|
||||||
detail={
|
detail={
|
||||||
"error": "QUOTA_EXCEEDED",
|
"error": "QUOTA_EXCEEDED",
|
||||||
"message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.",
|
"message": f"Monthly limit reached ({usage['docs_used']}/{usage['docs_limit']} documents). Upgrade your plan for more.",
|
||||||
"details": {
|
"details": {
|
||||||
"current_usage": usage["docs_used"],
|
"current_usage": usage['docs_used'],
|
||||||
"limit": usage["docs_limit"],
|
"limit": usage['docs_limit'],
|
||||||
"tier": tier,
|
"tier": tier,
|
||||||
|
"reset_at": _next_month_utc().isoformat(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
headers={"Retry-After": str(retry_after)},
|
||||||
)
|
)
|
||||||
rate_limit_remaining = quota.remaining
|
# Atomically reserve one document slot now so concurrent requests cannot
|
||||||
|
# overshoot the monthly quota while background jobs are still running.
|
||||||
|
reserved = await asyncio.to_thread(reserve_translation_quota, user_id)
|
||||||
|
if not reserved:
|
||||||
|
retry_after = _seconds_until_next_month()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail={
|
||||||
|
"error": "QUOTA_EXCEEDED",
|
||||||
|
"message": "Monthly limit reached. Upgrade your plan for more.",
|
||||||
|
"details": {
|
||||||
|
"tier": tier,
|
||||||
|
"reset_at": _next_month_utc().isoformat(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={"Retry-After": str(retry_after)},
|
||||||
|
)
|
||||||
|
quota_reserved = True
|
||||||
|
rate_limit_remaining = usage["docs_remaining"]
|
||||||
else:
|
else:
|
||||||
rate_limit_remaining = -1
|
rate_limit_remaining = -1
|
||||||
|
|
||||||
@@ -839,6 +847,8 @@ async def translate_document_v1(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except TranslateEndpointError as e:
|
except TranslateEndpointError as e:
|
||||||
|
if quota_reserved and user_id:
|
||||||
|
await asyncio.to_thread(release_translation_quota, user_id)
|
||||||
status_code = 400
|
status_code = 400
|
||||||
if e.code == TranslateEndpointError.FILE_TOO_LARGE:
|
if e.code == TranslateEndpointError.FILE_TOO_LARGE:
|
||||||
status_code = 413
|
status_code = 413
|
||||||
@@ -852,8 +862,12 @@ async def translate_document_v1(
|
|||||||
content=e.to_dict(),
|
content=e.to_dict(),
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
if quota_reserved and user_id:
|
||||||
|
await asyncio.to_thread(release_translation_quota, user_id)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if quota_reserved and user_id:
|
||||||
|
await asyncio.to_thread(release_translation_quota, user_id)
|
||||||
logger.error(f"[{request_id}] Unexpected error: {e}")
|
logger.error(f"[{request_id}] Unexpected error: {e}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
@@ -973,6 +987,7 @@ async def _run_translation_job(
|
|||||||
return
|
return
|
||||||
|
|
||||||
tracker = ProgressTracker(job_id, _translation_jobs)
|
tracker = ProgressTracker(job_id, _translation_jobs)
|
||||||
|
usage_recorded = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
job["status"] = "processing"
|
job["status"] = "processing"
|
||||||
@@ -1337,14 +1352,14 @@ async def _run_translation_job(
|
|||||||
else:
|
else:
|
||||||
cost_factor = 5
|
cost_factor = 5
|
||||||
|
|
||||||
for _ in range(cost_factor):
|
|
||||||
await tier_quota_service.increment_on_success(user_id)
|
|
||||||
|
|
||||||
# Persist monthly usage counters in PostgreSQL (docs + pages)
|
# Persist monthly usage counters in PostgreSQL (docs + pages)
|
||||||
pages = await asyncio.to_thread(
|
pages = await asyncio.to_thread(
|
||||||
_estimate_pages, input_path, file_extension
|
_estimate_pages, input_path, file_extension
|
||||||
)
|
)
|
||||||
await asyncio.to_thread(record_usage, user_id, pages, False, cost_factor)
|
await asyncio.to_thread(
|
||||||
|
record_usage, user_id, pages, cost_factor, reserved_docs=1
|
||||||
|
)
|
||||||
|
usage_recorded = True
|
||||||
logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}")
|
logger.info(f"Job {job_id}: usage recorded — {pages} page(s) with cost factor {cost_factor}")
|
||||||
|
|
||||||
# Apply watermark for Free-tier users
|
# Apply watermark for Free-tier users
|
||||||
@@ -1364,6 +1379,12 @@ async def _run_translation_job(
|
|||||||
record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success")
|
record_translation(provider=provider, file_type=file_extension or "unknown", duration=duration, status="success")
|
||||||
logger.info(f"Job {job_id}: Completed successfully")
|
logger.info(f"Job {job_id}: Completed successfully")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Background task cancelled (e.g. TestClient teardown or server shutdown).
|
||||||
|
# The document slot was already reserved at request time; keep it consumed
|
||||||
|
# so quota enforcement remains deterministic.
|
||||||
|
logger.warning(f"Job {job_id}: translation task cancelled, keeping reserved quota")
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Check if this is our structured TranslationProviderError
|
# Check if this is our structured TranslationProviderError
|
||||||
if type(e).__name__ == "TranslationProviderError":
|
if type(e).__name__ == "TranslationProviderError":
|
||||||
@@ -1376,6 +1397,13 @@ async def _run_translation_job(
|
|||||||
# Record translation failure metric
|
# Record translation failure metric
|
||||||
record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error")
|
record_translation(provider=provider, file_type=file_extension or "unknown", duration=0, status="error")
|
||||||
|
|
||||||
|
if user_id and not usage_recorded:
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(release_translation_quota, user_id)
|
||||||
|
logger.info(f"Job {job_id}: released reserved quota after failure")
|
||||||
|
except Exception as release_err:
|
||||||
|
logger.exception(f"Job {job_id}: failed to release reserved quota: {release_err}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if webhook_url:
|
if webhook_url:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -278,6 +278,7 @@ def _db_user_to_model(db_user) -> User:
|
|||||||
password_hash=db_user.password_hash,
|
password_hash=db_user.password_hash,
|
||||||
avatar_url=db_user.avatar_url,
|
avatar_url=db_user.avatar_url,
|
||||||
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
|
plan=PlanType(db_user.plan) if db_user.plan else PlanType.FREE,
|
||||||
|
tier=db_user.tier or "free",
|
||||||
subscription_status=SubscriptionStatus(db_user.subscription_status)
|
subscription_status=SubscriptionStatus(db_user.subscription_status)
|
||||||
if db_user.subscription_status
|
if db_user.subscription_status
|
||||||
else SubscriptionStatus.ACTIVE,
|
else SubscriptionStatus.ACTIVE,
|
||||||
@@ -460,36 +461,48 @@ def update_user(user_id: str, updates: Dict[str, Any]) -> Optional[User]:
|
|||||||
return User(**users[user_id])
|
return User(**users[user_id])
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_usage_if_needed(user_id: str) -> Optional[User]:
|
||||||
|
"""Reset monthly counters if usage_reset_date is in a previous month."""
|
||||||
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
|
from database.connection import get_sync_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
|
||||||
|
with get_sync_session() as session:
|
||||||
|
repo = UserRepository(session)
|
||||||
|
repo.reset_usage_if_needed(user_id)
|
||||||
|
else:
|
||||||
|
user = get_user_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
reset_date = user.usage_reset_date
|
||||||
|
if reset_date.month != now.month or reset_date.year != now.year:
|
||||||
|
update_user(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"docs_translated_this_month": 0,
|
||||||
|
"pages_translated_this_month": 0,
|
||||||
|
"api_calls_this_month": 0,
|
||||||
|
"usage_reset_date": now.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return get_user_by_id(user_id)
|
||||||
|
|
||||||
|
|
||||||
def check_usage_limits(user: User) -> Dict[str, Any]:
|
def check_usage_limits(user: User) -> Dict[str, Any]:
|
||||||
"""Check if user has exceeded their plan limits"""
|
"""Check if user has exceeded their plan limits"""
|
||||||
|
# Ensure counters are reset if we've entered a new month.
|
||||||
|
refreshed = _reset_usage_if_needed(user.id)
|
||||||
|
if refreshed:
|
||||||
|
user = refreshed
|
||||||
|
|
||||||
plan = PLANS[user.plan]
|
plan = PLANS[user.plan]
|
||||||
|
|
||||||
# Reset usage if it's a new month
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
if (
|
|
||||||
user.usage_reset_date.month != now.month
|
|
||||||
or user.usage_reset_date.year != now.year
|
|
||||||
):
|
|
||||||
update_user(
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"docs_translated_this_month": 0,
|
|
||||||
"pages_translated_this_month": 0,
|
|
||||||
"api_calls_this_month": 0,
|
|
||||||
"usage_reset_date": now.isoformat() if not USE_DATABASE else now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
user.docs_translated_this_month = 0
|
|
||||||
user.pages_translated_this_month = 0
|
|
||||||
user.api_calls_this_month = 0
|
|
||||||
|
|
||||||
docs_limit = plan["docs_per_month"]
|
docs_limit = plan["docs_per_month"]
|
||||||
docs_remaining = (
|
unlimited = docs_limit == -1
|
||||||
max(0, docs_limit - user.docs_translated_this_month) if docs_limit > 0 else -1
|
docs_remaining = -1 if unlimited else max(0, docs_limit - user.docs_translated_this_month)
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"can_translate": docs_remaining != 0 or user.extra_credits > 0,
|
"can_translate": unlimited or docs_remaining != 0 or user.extra_credits > 0,
|
||||||
"docs_used": user.docs_translated_this_month,
|
"docs_used": user.docs_translated_this_month,
|
||||||
"docs_limit": docs_limit,
|
"docs_limit": docs_limit,
|
||||||
"docs_remaining": docs_remaining,
|
"docs_remaining": docs_remaining,
|
||||||
@@ -502,23 +515,190 @@ def check_usage_limits(user: User) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def record_usage(
|
def record_usage(
|
||||||
user_id: str, pages_count: int, use_credits: bool = False, cost_factor: int = 1
|
user_id: str, pages_count: int, cost_factor: int = 1, reserved_docs: int = 0
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Record document translation usage with optional cost factor depending on AI model"""
|
"""Record document translation usage with optional cost factor depending on AI model.
|
||||||
|
|
||||||
|
`reserved_docs` is the number of document slots already reserved at request time
|
||||||
|
(e.g. by ``reserve_translation_quota``). Those slots are not counted again; only
|
||||||
|
the remaining ``max(0, cost_factor - reserved_docs)`` docs are added here.
|
||||||
|
|
||||||
|
Automatically consumes extra credits first; falls back to monthly quota.
|
||||||
|
Returns True if usage was recorded successfully, False otherwise.
|
||||||
|
"""
|
||||||
|
user = _reset_usage_if_needed(user_id)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
total_cost = pages_count * cost_factor
|
||||||
|
docs_to_record = max(0, cost_factor - reserved_docs)
|
||||||
|
plan = PLANS[user.plan]
|
||||||
|
docs_limit = plan["docs_per_month"]
|
||||||
|
unlimited = docs_limit == -1
|
||||||
|
|
||||||
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
|
from database.connection import get_sync_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
|
||||||
|
with get_sync_session() as session:
|
||||||
|
repo = UserRepository(session)
|
||||||
|
|
||||||
|
if unlimited:
|
||||||
|
# Paid unlimited plans: track usage for analytics/downgrade safety.
|
||||||
|
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Prefer credits first, then quota.
|
||||||
|
if user.extra_credits > 0:
|
||||||
|
credits_to_use = min(user.extra_credits, total_cost)
|
||||||
|
if repo.use_credits(user_id, credits_to_use):
|
||||||
|
remaining_cost = total_cost - credits_to_use
|
||||||
|
if remaining_cost > 0:
|
||||||
|
repo.increment_usage(
|
||||||
|
user_id, docs=docs_to_record, pages=remaining_cost
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
if user.docs_translated_this_month + docs_to_record <= docs_limit:
|
||||||
|
repo.increment_usage(user_id, docs=docs_to_record, pages=total_cost)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# JSON fallback (non-atomic, dev only)
|
||||||
|
if unlimited:
|
||||||
|
return update_user(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"docs_translated_this_month": user.docs_translated_this_month
|
||||||
|
+ docs_to_record,
|
||||||
|
"pages_translated_this_month": user.pages_translated_this_month
|
||||||
|
+ total_cost,
|
||||||
|
},
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
if user.extra_credits > 0:
|
||||||
|
credits_to_use = min(user.extra_credits, total_cost)
|
||||||
|
new_credits = user.extra_credits - credits_to_use
|
||||||
|
remaining_cost = total_cost - credits_to_use
|
||||||
|
updates = {"extra_credits": new_credits}
|
||||||
|
if remaining_cost > 0:
|
||||||
|
if user.docs_translated_this_month + docs_to_record > docs_limit:
|
||||||
|
return False
|
||||||
|
updates["docs_translated_this_month"] = (
|
||||||
|
user.docs_translated_this_month + docs_to_record
|
||||||
|
)
|
||||||
|
updates["pages_translated_this_month"] = (
|
||||||
|
user.pages_translated_this_month + remaining_cost
|
||||||
|
)
|
||||||
|
return update_user(user_id, updates) is not None
|
||||||
|
|
||||||
|
if user.docs_translated_this_month + docs_to_record <= docs_limit:
|
||||||
|
return update_user(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"docs_translated_this_month": user.docs_translated_this_month
|
||||||
|
+ docs_to_record,
|
||||||
|
"pages_translated_this_month": user.pages_translated_this_month
|
||||||
|
+ total_cost,
|
||||||
|
},
|
||||||
|
) is not None
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def reserve_translation_quota(user_id: str) -> bool:
|
||||||
|
"""Atomically reserve one document slot at request time.
|
||||||
|
|
||||||
|
Returns True if the reservation succeeded. This prevents race conditions where
|
||||||
|
multiple concurrent requests could exceed the monthly quota.
|
||||||
|
"""
|
||||||
|
user = _reset_usage_if_needed(user_id)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
plan = PLANS[user.plan]
|
||||||
|
docs_limit = plan["docs_per_month"]
|
||||||
|
unlimited = docs_limit == -1
|
||||||
|
|
||||||
|
if unlimited:
|
||||||
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
|
from database.connection import get_sync_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
|
||||||
|
with get_sync_session() as session:
|
||||||
|
repo = UserRepository(session)
|
||||||
|
repo.increment_usage(user_id, docs=1, pages=0)
|
||||||
|
return True
|
||||||
|
return update_user(
|
||||||
|
user_id,
|
||||||
|
{"docs_translated_this_month": user.docs_translated_this_month + 1},
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
# Limited plan: prefer credits first, then quota.
|
||||||
|
if user.extra_credits > 0:
|
||||||
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
|
from database.connection import get_sync_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
|
||||||
|
with get_sync_session() as session:
|
||||||
|
repo = UserRepository(session)
|
||||||
|
return repo.use_credits(user_id, 1)
|
||||||
|
return update_user(
|
||||||
|
user_id, {"extra_credits": max(0, user.extra_credits - 1)}
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
if user.docs_translated_this_month + 1 <= docs_limit:
|
||||||
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
|
from database.connection import get_sync_session
|
||||||
|
from database.repositories import UserRepository
|
||||||
|
|
||||||
|
with get_sync_session() as session:
|
||||||
|
repo = UserRepository(session)
|
||||||
|
repo.increment_usage(user_id, docs=1, pages=0)
|
||||||
|
return True
|
||||||
|
return update_user(
|
||||||
|
user_id,
|
||||||
|
{"docs_translated_this_month": user.docs_translated_this_month + 1},
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def release_translation_quota(user_id: str) -> bool:
|
||||||
|
"""Release a previously reserved document slot (e.g. on translation failure)."""
|
||||||
user = get_user_by_id(user_id)
|
user = get_user_by_id(user_id)
|
||||||
if not user:
|
if not user:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
updates = {
|
if USE_DATABASE and DATABASE_AVAILABLE:
|
||||||
"docs_translated_this_month": user.docs_translated_this_month + cost_factor,
|
from database.connection import get_sync_session
|
||||||
"pages_translated_this_month": user.pages_translated_this_month + (pages_count * cost_factor),
|
from database.repositories import UserRepository
|
||||||
}
|
from sqlalchemy import update
|
||||||
|
|
||||||
if use_credits:
|
with get_sync_session() as session:
|
||||||
updates["extra_credits"] = max(0, user.extra_credits - (pages_count * cost_factor))
|
now = datetime.now(timezone.utc)
|
||||||
|
session.execute(
|
||||||
|
update(db_models.User)
|
||||||
|
.where(
|
||||||
|
db_models.User.id == user_id,
|
||||||
|
db_models.User.docs_translated_this_month > 0,
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
docs_translated_this_month=db_models.User.docs_translated_this_month
|
||||||
|
- 1,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
result = update_user(user_id, updates)
|
if user.docs_translated_this_month > 0:
|
||||||
return result is not None
|
return update_user(
|
||||||
|
user_id,
|
||||||
|
{"docs_translated_this_month": user.docs_translated_this_month - 1},
|
||||||
|
) is not None
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def add_credits(user_id: str, credits: int) -> bool:
|
def add_credits(user_id: str, credits: int) -> bool:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Stripe payment integration for subscriptions and credits
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -250,19 +250,23 @@ async def handle_webhook(payload: bytes, sig_header: str) -> Dict[str, Any]:
|
|||||||
if event["type"] == "checkout.session.completed":
|
if event["type"] == "checkout.session.completed":
|
||||||
session = event["data"]["object"]
|
session = event["data"]["object"]
|
||||||
await handle_checkout_completed(session)
|
await handle_checkout_completed(session)
|
||||||
|
|
||||||
elif event["type"] == "customer.subscription.updated":
|
elif event["type"] == "customer.subscription.updated":
|
||||||
subscription = event["data"]["object"]
|
subscription = event["data"]["object"]
|
||||||
await handle_subscription_updated(subscription)
|
await handle_subscription_updated(subscription)
|
||||||
|
|
||||||
elif event["type"] == "customer.subscription.deleted":
|
elif event["type"] == "customer.subscription.deleted":
|
||||||
subscription = event["data"]["object"]
|
subscription = event["data"]["object"]
|
||||||
await handle_subscription_deleted(subscription)
|
await handle_subscription_deleted(subscription)
|
||||||
|
|
||||||
elif event["type"] == "invoice.payment_failed":
|
elif event["type"] == "invoice.payment_failed":
|
||||||
invoice = event["data"]["object"]
|
invoice = event["data"]["object"]
|
||||||
await handle_payment_failed(invoice)
|
await handle_payment_failed(invoice)
|
||||||
|
|
||||||
|
elif event["type"] == "invoice.paid":
|
||||||
|
invoice = event["data"]["object"]
|
||||||
|
await handle_invoice_paid(invoice)
|
||||||
|
|
||||||
return {"status": "success"}
|
return {"status": "success"}
|
||||||
|
|
||||||
|
|
||||||
@@ -276,7 +280,9 @@ async def handle_checkout_completed(session: Dict):
|
|||||||
|
|
||||||
session_id = session.get("id")
|
session_id = session.get("id")
|
||||||
|
|
||||||
# Check for duplicate session processing using PaymentHistory
|
# Check for duplicate session processing using PaymentHistory.
|
||||||
|
# Use the Stripe payment_intent (one-time) or subscription id (recurring) as the
|
||||||
|
# idempotency key, NOT the checkout session id — Stripe can redeliver the event.
|
||||||
db_available = False
|
db_available = False
|
||||||
try:
|
try:
|
||||||
from database.connection import get_sync_session
|
from database.connection import get_sync_session
|
||||||
@@ -285,14 +291,22 @@ async def handle_checkout_completed(session: Dict):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
payment_intent_id = session.get("payment_intent")
|
||||||
|
subscription_id = session.get("subscription")
|
||||||
if db_available and session_id:
|
if db_available and session_id:
|
||||||
try:
|
try:
|
||||||
with get_sync_session() as db_session:
|
with get_sync_session() as db_session:
|
||||||
existing = db_session.query(DBPaymentHistory).filter(
|
from sqlalchemy import or_
|
||||||
DBPaymentHistory.stripe_payment_intent_id == session_id
|
|
||||||
).first()
|
filters = [DBPaymentHistory.stripe_payment_intent_id == session_id]
|
||||||
|
if payment_intent_id:
|
||||||
|
filters.append(DBPaymentHistory.stripe_payment_intent_id == payment_intent_id)
|
||||||
|
if subscription_id:
|
||||||
|
filters.append(DBPaymentHistory.stripe_invoice_id == subscription_id)
|
||||||
|
existing = db_session.query(DBPaymentHistory).filter(or_(*filters)).first()
|
||||||
if existing:
|
if existing:
|
||||||
logger.info("Checkout session %s already processed. Skipping.", session_id)
|
logger.info("Checkout session %s already processed (pi=%s sub=%s). Skipping.",
|
||||||
|
session_id, payment_intent_id, subscription_id)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error checking PaymentHistory duplication: %s", e)
|
logger.error("Error checking PaymentHistory duplication: %s", e)
|
||||||
@@ -312,7 +326,7 @@ async def handle_checkout_completed(session: Dict):
|
|||||||
with get_sync_session() as db_session:
|
with get_sync_session() as db_session:
|
||||||
payment = DBPaymentHistory(
|
payment = DBPaymentHistory(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
stripe_payment_intent_id=session_id,
|
stripe_payment_intent_id=payment_intent_id or session_id,
|
||||||
stripe_invoice_id=session.get("invoice") or session.get("subscription"),
|
stripe_invoice_id=session.get("invoice") or session.get("subscription"),
|
||||||
amount_cents=session.get("amount_total") or 0,
|
amount_cents=session.get("amount_total") or 0,
|
||||||
currency=session.get("currency") or "usd",
|
currency=session.get("currency") or "usd",
|
||||||
@@ -377,7 +391,6 @@ async def handle_checkout_completed(session: Dict):
|
|||||||
subscription_id = subscription_raw.get("id")
|
subscription_id = subscription_raw.get("id")
|
||||||
period_end = subscription_raw.get("current_period_end")
|
period_end = subscription_raw.get("current_period_end")
|
||||||
if period_end:
|
if period_end:
|
||||||
from datetime import timezone
|
|
||||||
subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
|
subscription_ends_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
|
||||||
|
|
||||||
# Derive tier from plan (DB constraint: only 'free' or 'pro')
|
# Derive tier from plan (DB constraint: only 'free' or 'pro')
|
||||||
@@ -408,7 +421,7 @@ async def handle_checkout_completed(session: Dict):
|
|||||||
with get_sync_session() as db_session:
|
with get_sync_session() as db_session:
|
||||||
payment = DBPaymentHistory(
|
payment = DBPaymentHistory(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
stripe_payment_intent_id=session_id,
|
stripe_payment_intent_id=payment_intent_id or session_id,
|
||||||
stripe_invoice_id=subscription_id or session.get("invoice"),
|
stripe_invoice_id=subscription_id or session.get("invoice"),
|
||||||
amount_cents=session.get("amount_total") or 0,
|
amount_cents=session.get("amount_total") or 0,
|
||||||
currency=session.get("currency") or "usd",
|
currency=session.get("currency") or "usd",
|
||||||
@@ -481,15 +494,13 @@ async def handle_subscription_updated(subscription: Dict):
|
|||||||
period_end = subscription.get("current_period_end")
|
period_end = subscription.get("current_period_end")
|
||||||
ends_str = ""
|
ends_str = ""
|
||||||
if period_end:
|
if period_end:
|
||||||
from datetime import timezone
|
|
||||||
ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y")
|
ends_str = datetime.fromtimestamp(period_end, tz=timezone.utc).strftime("%d/%m/%Y")
|
||||||
|
|
||||||
|
period_end = subscription.get("current_period_end")
|
||||||
update_user(user_id, {
|
update_user(user_id, {
|
||||||
"subscription_status": status.value,
|
"subscription_status": status.value,
|
||||||
"cancel_at_period_end": stripe_cancel_at_period_end,
|
"cancel_at_period_end": stripe_cancel_at_period_end,
|
||||||
"subscription_ends_at": datetime.fromtimestamp(
|
"subscription_ends_at": datetime.fromtimestamp(period_end, tz=timezone.utc) if period_end else None,
|
||||||
subscription.get("current_period_end", 0)
|
|
||||||
).isoformat() if subscription.get("current_period_end") else None
|
|
||||||
})
|
})
|
||||||
|
|
||||||
# Send cancellation email if they just selected to cancel
|
# Send cancellation email if they just selected to cancel
|
||||||
@@ -531,7 +542,7 @@ async def handle_subscription_deleted(subscription: Dict):
|
|||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
|
|
||||||
had_active_sub = user.plan != PlanType.FREE.value or user.tier != "free"
|
had_active_sub = user.plan != PlanType.FREE or user.tier != "free"
|
||||||
|
|
||||||
update_user(user_id, {
|
update_user(user_id, {
|
||||||
"plan": PlanType.FREE.value,
|
"plan": PlanType.FREE.value,
|
||||||
@@ -621,6 +632,50 @@ async def handle_payment_failed(invoice: Dict):
|
|||||||
logger.error("handle_payment_failed DB error: %s", exc)
|
logger.error("handle_payment_failed DB error: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_invoice_paid(invoice: Dict):
|
||||||
|
"""Extend subscription_ends_at when a recurring invoice is paid."""
|
||||||
|
customer_id = invoice.get("customer")
|
||||||
|
if not customer_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
subscription_id = invoice.get("subscription")
|
||||||
|
period_end = invoice.get("period_end") or invoice.get("lines", {}).get("data", [{}])[0].get("period", {}).get("end")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from database.connection import get_sync_session
|
||||||
|
from database.models import User as DBUser
|
||||||
|
|
||||||
|
with get_sync_session() as session:
|
||||||
|
db_user = (
|
||||||
|
session.query(DBUser)
|
||||||
|
.filter(DBUser.stripe_customer_id == customer_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not db_user:
|
||||||
|
return
|
||||||
|
|
||||||
|
if subscription_id and db_user.stripe_subscription_id != subscription_id:
|
||||||
|
# The paid invoice belongs to a different subscription; do not update.
|
||||||
|
logger.warning(
|
||||||
|
"Invoice paid for customer %s but subscription id mismatch (expected %s, got %s)",
|
||||||
|
customer_id, db_user.stripe_subscription_id, subscription_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if period_end:
|
||||||
|
new_end = datetime.fromtimestamp(period_end, tz=timezone.utc)
|
||||||
|
if db_user.subscription_ends_at is None or new_end > db_user.subscription_ends_at:
|
||||||
|
db_user.subscription_ends_at = new_end
|
||||||
|
db_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Extended subscription_ends_at for user %s to %s",
|
||||||
|
db_user.id, new_end.isoformat(),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("handle_invoice_paid error: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
||||||
"""Cancel a user's subscription at period end."""
|
"""Cancel a user's subscription at period end."""
|
||||||
if not is_stripe_configured():
|
if not is_stripe_configured():
|
||||||
@@ -631,6 +686,10 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
|||||||
return {"error": "No active subscription found"}
|
return {"error": "No active subscription found"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
subscription = stripe.Subscription.retrieve(user.stripe_subscription_id)
|
||||||
|
if subscription.customer != user.stripe_customer_id:
|
||||||
|
return {"error": "Subscription does not belong to current user"}
|
||||||
|
|
||||||
subscription = stripe.Subscription.modify(
|
subscription = stripe.Subscription.modify(
|
||||||
user.stripe_subscription_id,
|
user.stripe_subscription_id,
|
||||||
cancel_at_period_end=True,
|
cancel_at_period_end=True,
|
||||||
@@ -638,13 +697,13 @@ async def cancel_subscription(user_id: str) -> Dict[str, Any]:
|
|||||||
|
|
||||||
cancel_at = None
|
cancel_at = None
|
||||||
if subscription.cancel_at:
|
if subscription.cancel_at:
|
||||||
cancel_at = datetime.fromtimestamp(subscription.cancel_at).isoformat()
|
cancel_at = datetime.fromtimestamp(subscription.cancel_at, tz=timezone.utc)
|
||||||
|
|
||||||
subscription_ends_at = None
|
subscription_ends_at = None
|
||||||
ends_str = ""
|
ends_str = ""
|
||||||
if subscription.current_period_end:
|
if subscription.current_period_end:
|
||||||
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end).isoformat()
|
subscription_ends_at = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc)
|
||||||
ends_str = datetime.fromtimestamp(subscription.current_period_end).strftime("%d/%m/%Y")
|
ends_str = datetime.fromtimestamp(subscription.current_period_end, tz=timezone.utc).strftime("%d/%m/%Y")
|
||||||
|
|
||||||
is_new_cancel = not user.cancel_at_period_end
|
is_new_cancel = not user.cancel_at_period_end
|
||||||
|
|
||||||
|
|||||||
@@ -110,26 +110,22 @@ def client(users_file: Path, monkeypatch):
|
|||||||
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
|
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
|
||||||
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
|
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
|
||||||
|
|
||||||
from middleware.tier_quota import TierQuotaService
|
def _check_usage_limits_allow(user):
|
||||||
|
return {
|
||||||
|
"can_translate": True,
|
||||||
|
"docs_used": 0,
|
||||||
|
"docs_limit": 5,
|
||||||
|
"docs_remaining": 5,
|
||||||
|
"pages_used": 0,
|
||||||
|
"extra_credits": 0,
|
||||||
|
"max_pages_per_doc": 50,
|
||||||
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
}
|
||||||
|
|
||||||
async def _check_quota_allow(self, user_id, tier):
|
monkeypatch.setattr(
|
||||||
from middleware.tier_quota import QuotaResult
|
"routes.translate_routes.check_usage_limits", _check_usage_limits_allow
|
||||||
from datetime import datetime, timezone, timedelta
|
)
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
tomorrow = now.date() + timedelta(days=1)
|
|
||||||
reset_at = datetime(
|
|
||||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
|
||||||
)
|
|
||||||
return QuotaResult(
|
|
||||||
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _increment_noop(self, user_id):
|
|
||||||
pass
|
|
||||||
|
|
||||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
|
|
||||||
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
|
|
||||||
|
|
||||||
from main import app
|
from main import app
|
||||||
|
|
||||||
|
|||||||
@@ -573,16 +573,18 @@ class TestURLIngestionIntegration:
|
|||||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = pro_client.post(
|
response = pro_client.post(
|
||||||
@@ -598,16 +600,18 @@ class TestURLIngestionIntegration:
|
|||||||
def test_free_user_rejected(self, free_client, monkeypatch):
|
def test_free_user_rejected(self, free_client, monkeypatch):
|
||||||
"""AC7: Free user receives PRO_FEATURE_REQUIRED (403)"""
|
"""AC7: Free user receives PRO_FEATURE_REQUIRED (403)"""
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = free_client.post(
|
response = free_client.post(
|
||||||
@@ -632,16 +636,18 @@ class TestURLIngestionIntegration:
|
|||||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = pro_client.post(
|
response = pro_client.post(
|
||||||
@@ -668,16 +674,18 @@ class TestURLIngestionIntegration:
|
|||||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = pro_client.post(
|
response = pro_client.post(
|
||||||
@@ -704,16 +712,18 @@ class TestURLIngestionIntegration:
|
|||||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = pro_client.post(
|
response = pro_client.post(
|
||||||
@@ -738,16 +748,18 @@ class TestURLIngestionIntegration:
|
|||||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = pro_client.post(
|
response = pro_client.post(
|
||||||
@@ -772,16 +784,18 @@ class TestURLIngestionIntegration:
|
|||||||
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
monkeypatch.setattr("routes.translate_routes.download_from_url", mock_download)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"middleware.tier_quota.TierQuotaService.check_quota",
|
"routes.translate_routes.check_usage_limits",
|
||||||
AsyncMock(
|
lambda user: {
|
||||||
return_value=MagicMock(
|
"can_translate": True,
|
||||||
allowed=True,
|
"docs_used": 0,
|
||||||
remaining=5,
|
"docs_limit": 5,
|
||||||
reset_at_utc=None,
|
"docs_remaining": 5,
|
||||||
current_usage=0,
|
"pages_used": 0,
|
||||||
limit=5,
|
"extra_credits": 0,
|
||||||
)
|
"max_pages_per_doc": 50,
|
||||||
),
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = pro_client.post(
|
response = pro_client.post(
|
||||||
|
|||||||
@@ -115,26 +115,22 @@ def client(users_file: Path, monkeypatch):
|
|||||||
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
|
monkeypatch.setattr(RateLimitManager, "check_request", _check_request_allow)
|
||||||
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
|
monkeypatch.setattr(RateLimitManager, "check_translation", _check_translation_allow)
|
||||||
|
|
||||||
from middleware.tier_quota import TierQuotaService
|
def _check_usage_limits_allow(user):
|
||||||
|
return {
|
||||||
|
"can_translate": True,
|
||||||
|
"docs_used": 0,
|
||||||
|
"docs_limit": 5,
|
||||||
|
"docs_remaining": 5,
|
||||||
|
"pages_used": 0,
|
||||||
|
"extra_credits": 0,
|
||||||
|
"max_pages_per_doc": 50,
|
||||||
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
}
|
||||||
|
|
||||||
async def _check_quota_allow(self, user_id, tier):
|
monkeypatch.setattr(
|
||||||
from middleware.tier_quota import QuotaResult
|
"routes.translate_routes.check_usage_limits", _check_usage_limits_allow
|
||||||
from datetime import datetime, timezone, timedelta
|
)
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
tomorrow = now.date() + timedelta(days=1)
|
|
||||||
reset_at = datetime(
|
|
||||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
|
||||||
)
|
|
||||||
return QuotaResult(
|
|
||||||
allowed=True, remaining=5, reset_at_utc=reset_at, current_usage=0, limit=5
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _increment_noop(self, user_id):
|
|
||||||
pass
|
|
||||||
|
|
||||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_allow)
|
|
||||||
monkeypatch.setattr(TierQuotaService, "increment_on_success", _increment_noop)
|
|
||||||
|
|
||||||
from main import app
|
from main import app
|
||||||
|
|
||||||
@@ -477,24 +473,22 @@ class TestQuotaExceeded:
|
|||||||
|
|
||||||
def test_returns_429_when_quota_exceeded(self, client, monkeypatch):
|
def test_returns_429_when_quota_exceeded(self, client, monkeypatch):
|
||||||
"""Returns 429 with QUOTA_EXCEEDED when quota exceeded"""
|
"""Returns 429 with QUOTA_EXCEEDED when quota exceeded"""
|
||||||
from middleware.tier_quota import TierQuotaService, QuotaResult
|
def _check_usage_limits_denied(user):
|
||||||
from datetime import datetime, timezone, timedelta
|
return {
|
||||||
|
"can_translate": False,
|
||||||
|
"docs_used": 5,
|
||||||
|
"docs_limit": 5,
|
||||||
|
"docs_remaining": 0,
|
||||||
|
"pages_used": 0,
|
||||||
|
"extra_credits": 0,
|
||||||
|
"max_pages_per_doc": 50,
|
||||||
|
"max_file_size_mb": 10,
|
||||||
|
"allowed_providers": ["google", "deepl"],
|
||||||
|
}
|
||||||
|
|
||||||
async def _check_quota_denied(self, user_id, tier):
|
monkeypatch.setattr(
|
||||||
now = datetime.now(timezone.utc)
|
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
|
||||||
tomorrow = now.date() + timedelta(days=1)
|
)
|
||||||
reset_at = datetime(
|
|
||||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
|
||||||
)
|
|
||||||
return QuotaResult(
|
|
||||||
allowed=False,
|
|
||||||
remaining=0,
|
|
||||||
reset_at_utc=reset_at,
|
|
||||||
current_usage=5,
|
|
||||||
limit=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
|
|
||||||
|
|
||||||
# Register and login
|
# Register and login
|
||||||
client.post(REGISTER_URL, json=VALID_USER)
|
client.post(REGISTER_URL, json=VALID_USER)
|
||||||
@@ -530,24 +524,23 @@ class TestQuotaExceeded:
|
|||||||
|
|
||||||
def test_includes_retry_after_header(self, client, monkeypatch):
|
def test_includes_retry_after_header(self, client, monkeypatch):
|
||||||
"""Includes Retry-After header on 429"""
|
"""Includes Retry-After header on 429"""
|
||||||
from middleware.tier_quota import TierQuotaService, QuotaResult
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
|
|
||||||
async def _check_quota_denied(self, user_id, tier):
|
def _check_usage_limits_denied(user):
|
||||||
now = datetime.now(timezone.utc)
|
return {
|
||||||
tomorrow = now.date() + timedelta(days=1)
|
"can_translate": False,
|
||||||
reset_at = datetime(
|
"docs_used": 5,
|
||||||
tomorrow.year, tomorrow.month, tomorrow.day, tzinfo=timezone.utc
|
"docs_limit": 5,
|
||||||
)
|
"docs_remaining": 0,
|
||||||
return QuotaResult(
|
"pages_used": 0,
|
||||||
allowed=False,
|
"extra_credits": 0,
|
||||||
remaining=0,
|
"max_pages_per_doc": 50,
|
||||||
reset_at_utc=reset_at,
|
"max_file_size_mb": 10,
|
||||||
current_usage=5,
|
"allowed_providers": ["google", "deepl"],
|
||||||
limit=5,
|
}
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(TierQuotaService, "check_quota", _check_quota_denied)
|
monkeypatch.setattr(
|
||||||
|
"routes.translate_routes.check_usage_limits", _check_usage_limits_denied
|
||||||
|
)
|
||||||
|
|
||||||
client.post(REGISTER_URL, json=VALID_USER)
|
client.post(REGISTER_URL, json=VALID_USER)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
|
|||||||
Reference in New Issue
Block a user