335 lines
10 KiB
TypeScript
335 lines
10 KiB
TypeScript
import { prisma } from '@/lib/prisma';
|
|
import { decryptApiKey, encryptApiKey, hashApiKey } from '@/lib/crypto';
|
|
import {
|
|
VALID_PROVIDERS,
|
|
type AiGatewayProvider,
|
|
} from '@/lib/ai/router';
|
|
import { getProviderConfigKeys } from '@/lib/ai/factory';
|
|
import { getUserInfo, type SubscriptionTier } from '@/lib/entitlements';
|
|
import { redis } from '@/lib/redis';
|
|
|
|
/**
 * Signals that the user has an active BYOK key on file that could not be loaded
 * (for example, the stored ciphertext failed to decrypt).
 *
 * Callers can recognise it either with `instanceof` or via `error.code === 'BYOK_UNAVAILABLE'`.
 * NOTE(review): `resolveByokApiKey` in this module returns `null` on decryption failure rather
 * than throwing this, so it is presumably raised by callers — confirm.
 */
export class ByokUnavailableError extends Error {
  /** Stable, machine-readable identifier for this failure. */
  readonly code = 'BYOK_UNAVAILABLE';

  /** @param message - User-facing (French) message; defaults to a generic "could not be loaded" notice. */
  constructor(message: string = 'Votre clé API est configurée mais n\'a pas pu être chargée.') {
    super(message);
    this.name = 'ByokUnavailableError';
  }
}
|
|
|
|
/**
 * Providers a PRO subscriber may configure their own API key for.
 * The two `custom_*` entries target user-supplied endpoints that are OpenAI- or
 * Anthropic-compatible respectively.
 */
const PRO_BYOK_PROVIDERS: readonly AiGatewayProvider[] = [
  'openai', 'anthropic', 'deepseek', 'openrouter', 'minimax', 'zai',
  'custom_openai', 'custom_anthropic',
];
|
|
|
|
/**
 * Providers available for BYOK on tiers above PRO: every entry of `VALID_PROVIDERS`
 * except `ollama` and `lmstudio` (presumably local runtimes that need no API key — confirm).
 */
const BUSINESS_BYOK_PROVIDERS: readonly AiGatewayProvider[] = Array.from(VALID_PROVIDERS).filter(
  (provider) => !(provider === 'ollama' || provider === 'lmstudio'),
) as AiGatewayProvider[];
|
|
|
|
/**
 * List the providers a subscription tier may bring its own key for.
 *
 * - `BASIC` → none (a fresh empty array each call)
 * - `PRO`   → the curated PRO list
 * - any other tier → the full BUSINESS list
 *
 * @param tier - The user's subscription tier.
 * @returns A read-only list of allowed provider identifiers.
 */
export function getAllowedByokProviders(
  tier: SubscriptionTier,
): readonly AiGatewayProvider[] {
  switch (tier) {
    case 'BASIC':
      return [];
    case 'PRO':
      return PRO_BYOK_PROVIDERS;
    default:
      return BUSINESS_BYOK_PROVIDERS;
  }
}
|
|
|
|
/**
 * Whether `provider` is on the BYOK allow-list for `tier`.
 *
 * @param tier - The user's subscription tier.
 * @param provider - Provider identifier to test (arbitrary string; unknown values yield `false`).
 */
export function isByokProviderAllowed(
  tier: SubscriptionTier,
  provider: string,
): boolean {
  const permitted = getAllowedByokProviders(tier);
  return permitted.some((candidate) => candidate === provider);
}
|
|
|
|
/**
 * Whether the user has at least one BYOK key marked active, whatever its provider.
 *
 * @param userId - Owner of the keys.
 */
export async function hasAnyActiveByok(userId: string): Promise<boolean> {
  const activeKeyCount = await prisma.userAPIKey.count({
    where: { isActive: true, userId },
  });
  return activeKeyCount !== 0;
}
|
|
|
|
/**
 * Look up the user's active BYOK key row for `provider`.
 *
 * Safety check: if the row exists but the user's tier no longer allows that provider, the row is
 * marked inactive in the database (a side effect), a warning is logged, and `null` is returned.
 *
 * @param userId - Owner of the key.
 * @param provider - Provider identifier the row must match.
 * @param tier - Optional pre-fetched subscription tier; when omitted it is loaded with
 *   `getUserInfo`, costing an extra lookup (only when a row was found).
 * @returns The active row, or `null` when none exists or it was just deactivated.
 */
export async function getActiveByokKey(userId: string, provider: string, tier?: SubscriptionTier) {
  const activeRow = await prisma.userAPIKey.findFirst({
    where: { userId, provider, isActive: true },
  });
  if (!activeRow) return null;

  let resolvedTier = tier;
  if (resolvedTier == null) {
    resolvedTier = (await getUserInfo(userId)).tier;
  }

  if (isByokProviderAllowed(resolvedTier, provider)) return activeRow;

  // Tier was downgraded (or the provider was removed from it): retire the key.
  await prisma.userAPIKey.update({
    where: { id: activeRow.id },
    data: { isActive: false },
  });
  console.warn(`[byok] Deactivated key for ${provider} (user ${userId}) - tier ${resolvedTier} does not allow this provider`);
  return null;
}
|
|
|
|
/**
 * Deactivate every active API key whose provider is not allowed by the user's current tier.
 * Intended to be called whenever a user's subscription tier changes.
 *
 * @param userId - Owner of the keys.
 * @returns How many keys were deactivated (0 when nothing needed changing; no write is issued then).
 */
export async function deactivateUnauthorizedKeys(userId: string): Promise<number> {
  const { tier } = await getUserInfo(userId);
  const permitted = new Set<string>(getAllowedByokProviders(tier));

  const activeKeys = await prisma.userAPIKey.findMany({
    where: { userId, isActive: true },
    select: { id: true, provider: true },
  });

  const staleIds = activeKeys
    .filter(({ provider }) => !permitted.has(provider))
    .map(({ id }) => id);

  if (staleIds.length === 0) return 0;

  // One batched write for all disallowed keys.
  await prisma.userAPIKey.updateMany({
    where: { id: { in: staleIds } },
    data: { isActive: false },
  });

  console.log(`[byok] Deactivated ${staleIds.length} keys for user ${userId} after tier change to ${tier}`);
  return staleIds.length;
}
|
|
|
|
/**
 * Load and decrypt the user's active BYOK key for `providerType`.
 *
 * The tier check from `getActiveByokKey` applies (a disallowed key is deactivated and yields `null`).
 * On success, `lastUsedAt`/`lastUsedFor` are stamped in the background; a failure of that write is
 * logged and does not affect the result.
 *
 * @param userId - Owner of the key.
 * @param providerType - Provider identifier to look up.
 * @param feature - Recorded as `lastUsedFor` (stored as `null` when falsy).
 * @returns The plaintext key with its provider, model and baseUrl, or `null` when there is no
 *   usable key or decryption fails (the failure is logged, not thrown).
 */
export async function resolveByokApiKey(
  userId: string,
  providerType: string,
  feature?: string,
): Promise<{ plaintext: string; provider: string; model: string | null; baseUrl: string | null } | null> {
  const keyRow = await getActiveByokKey(userId, providerType);
  if (!keyRow) return null;

  try {
    const decrypted = await decryptApiKey(keyRow.encryptedKey);

    // Fire-and-forget usage stamp; deliberately not awaited.
    void prisma.userAPIKey
      .update({
        where: { id: keyRow.id },
        data: { lastUsedAt: new Date(), lastUsedFor: feature || null },
      })
      .catch((err) => { console.error('[byok] Failed to update lastUsedAt/lastUsedFor:', err); });

    return {
      plaintext: decrypted,
      provider: keyRow.provider,
      model: keyRow.model,
      baseUrl: keyRow.baseUrl ?? null,
    };
  } catch (err) {
    console.error('[byok] Failed to decrypt key for provider', providerType, err);
    return null;
  }
}
|
|
|
|
/**
 * Return a decrypted active BYOK key for the user, preferring `preferredProvider`.
 *
 * Lookup order:
 *   1. The active key for `preferredProvider`, via `resolveByokApiKey` (which enforces the tier
 *      check and may deactivate a key the tier no longer allows).
 *   2. Otherwise, the first active key ordered by `lastUsedAt` descending, restricted to providers
 *      the user's current tier allows and excluding `preferredProvider` (already tried in step 1).
 *
 * This lets BYOK work regardless of which provider the admin configured, without the fallback
 * handing out a key for a provider the user's tier does not permit.
 *
 * @param userId - Owner of the keys.
 * @param preferredProvider - Provider to try first; omitted/empty skips step 1.
 * @param feature - Recorded as `lastUsedFor` on the returned key (stored as `null` when falsy).
 * @returns The decrypted key with its provider, model and baseUrl, or `null` when no usable key
 *   exists or decryption fails (logged, not thrown).
 */
export async function getAnyActiveByokForUser(
  userId: string,
  preferredProvider?: string,
  feature?: string,
): Promise<{ plaintext: string; provider: string; model: string | null; baseUrl: string | null } | null> {
  // 1. Try exact match first
  if (preferredProvider) {
    const exact = await resolveByokApiKey(userId, preferredProvider, feature);
    if (exact) return exact;
  }

  // 2. Fall back to an active key whose provider the user's current tier still allows.
  // Without this filter the fallback would bypass the tier check that exact matches get.
  const { tier } = await getUserInfo(userId);
  const fallbackProviders = getAllowedByokProviders(tier).filter((p) => p !== preferredProvider);
  if (fallbackProviders.length === 0) return null;

  const anyRow = await prisma.userAPIKey.findFirst({
    where: { userId, isActive: true, provider: { in: fallbackProviders } },
    // NOTE(review): placement of NULL lastUsedAt under DESC is database-dependent (PostgreSQL sorts
    // NULLs first), so a never-used key may win over the most recently used one — confirm intended.
    orderBy: { lastUsedAt: 'desc' },
  });
  if (!anyRow) return null;

  try {
    const plaintext = await decryptApiKey(anyRow.encryptedKey);
    // Best-effort usage stamp: not awaited, but failures are logged rather than silently dropped.
    void prisma.userAPIKey
      .update({
        where: { id: anyRow.id },
        data: { lastUsedAt: new Date(), lastUsedFor: feature || null },
      })
      .catch((err) => {
        console.error('[byok] Failed to update lastUsedAt/lastUsedFor:', err);
      });
    return { plaintext, provider: anyRow.provider, model: anyRow.model, baseUrl: anyRow.baseUrl ?? null };
  } catch (err) {
    console.error('[byok] Failed to decrypt any active key:', err);
    return null;
  }
}
|
|
|
|
/**
 * Overlay the user's BYOK key onto a provider config, if one is available.
 *
 * The key is written under the config key returned by `getProviderConfigKeys(providerType)`.
 * When there is no usable BYOK key, or the provider exposes no API-key config key, the input
 * `config` object is returned unchanged (same reference) with `usedByok: false`.
 *
 * NOTE(review): the resolved `baseUrl` is not applied here, so custom endpoints presumably get it
 * elsewhere — confirm. Also, `resolveByokApiKey` stamps `lastUsedAt` even when this function then
 * falls back because no API-key config key exists.
 *
 * @param billingUserId - User whose BYOK key is looked up.
 * @param providerType - Provider identifier.
 * @param config - Base config; never mutated.
 * @param feature - Forwarded as the `lastUsedFor` label.
 * @returns The (possibly) augmented config, whether BYOK was applied, and the key's model override.
 */
export async function applyByokToConfig(
  billingUserId: string,
  providerType: string,
  config: Record<string, string>,
  feature?: string,
): Promise<{ config: Record<string, string>; usedByok: boolean; model: string | null }> {
  const byok = await resolveByokApiKey(billingUserId, providerType, feature);
  const apiKeyConfigKey = byok ? getProviderConfigKeys(providerType).apiKeyConfigKey : undefined;

  if (!byok || !apiKeyConfigKey) {
    return { config, usedByok: false, model: null };
  }

  const merged = { ...config, [apiKeyConfigKey]: byok.plaintext };
  return { config: merged, usedByok: true, model: byok.model };
}
|
|
|
|
/**
 * Create or replace the user's API key for a provider (unique on `userId` + `provider`).
 *
 * The plaintext is encrypted and hashed before storage. On update every stored field is
 * overwritten: omitted `alias` becomes `''`, omitted `model`/`baseUrl` become `null`,
 * and the key is (re)activated.
 *
 * @returns The upserted row.
 */
export async function upsertUserApiKey(params: {
  userId: string;
  provider: AiGatewayProvider;
  plaintext: string;
  alias?: string;
  model?: string;
  baseUrl?: string;
}) {
  const { userId, provider, plaintext, alias, model, baseUrl } = params;

  // Shared between the create and update branches so both write identical values.
  const storedFields = {
    alias: alias ?? '',
    encryptedKey: await encryptApiKey(plaintext),
    keyHash: hashApiKey(plaintext),
    model: model ?? null,
    baseUrl: baseUrl ?? null,
    isActive: true,
  };

  return prisma.userAPIKey.upsert({
    where: { userId_provider: { userId, provider } },
    create: { userId, provider, ...storedFields },
    update: storedFields,
  });
}
|
|
|
|
/**
 * Find which of the user's stored keys (if any) already has this key hash.
 *
 * @param userId - Owner of the keys.
 * @param keyHash - Hash of the candidate plaintext key.
 * @param excludeProvider - When non-empty, rows for this provider are ignored (e.g. the one being replaced).
 * @returns The provider of a matching row, or `null` when there is none.
 */
export async function findDuplicateApiKeyHash(
  userId: string,
  keyHash: string,
  excludeProvider?: string,
): Promise<string | null> {
  const providerFilter = excludeProvider ? { provider: { not: excludeProvider } } : undefined;

  const match = await prisma.userAPIKey.findFirst({
    where: { userId, keyHash, ...providerFilter },
    select: { provider: true },
  });

  return match ? match.provider : null;
}
|
|
|
|
/**
 * Project a stored API key row onto the fields that are safe to return to clients.
 *
 * Only the fields listed below are copied, so anything else on the row (such as the encrypted
 * key or its hash) is dropped even when the caller passes a full database record.
 *
 * @param row - A stored key row (extra properties are allowed and ignored).
 * @returns A new plain object with the public fields.
 */
export function toPublicApiKey(row: {
  provider: string;
  alias: string;
  model: string | null;
  baseUrl: string | null;
  isActive: boolean;
  lastUsedAt: Date | null;
  createdAt: Date;
  updatedAt: Date;
}) {
  const { provider, alias, model, baseUrl, isActive, lastUsedAt, createdAt, updatedAt } = row;
  return { provider, alias, model, baseUrl, isActive, lastUsedAt, createdAt, updatedAt };
}
|
|
|
|
/**
 * Redis Lua script implementing a fixed-window counter, used by `checkApiKeyCreationRateLimit`
 * (max 5 API key creations per hour per user).
 *
 * Inputs:
 *   KEYS[1] - counter key
 *   ARGV[1] - limit: maximum increments allowed per window
 *   ARGV[2] - window length in seconds, applied with EXPIRE when the counter is first created
 *
 * Returns a two-element array:
 *   {-1, ttl}                     when the counter is already at or above the limit (nothing is
 *                                 incremented); ttl is the key's remaining TTL from Redis TTL
 *   {newCount, limit - newCount}  otherwise, after incrementing
 *
 * Running the read-check-increment as a single EVAL keeps it atomic, avoiding the race a
 * separate GET followed by INCR would have.
 */
const RATE_LIMIT_LUA = `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call('GET', key) or '0')
if current >= limit then
local ttl = redis.call('TTL', key)
return {-1, ttl}
end
local newCount = redis.call('INCR', key)
if newCount == 1 then
redis.call('EXPIRE', key, window)
end
return {newCount, limit - newCount}
`;
|
|
|
|
/**
 * Per-process cache of rate-limit verdicts, keyed by the Redis counter key.
 * Only "rate limited" verdicts are stored, so a blocked user does not hit Redis on every retry.
 * Expired entries are evicted lazily, when the same key is next checked.
 */
const rateLimitCache = new Map<string, { result: { allowed: boolean; remaining: number; resetAt: Date | null }; expiresAt: number }>();

/** Maximum time (ms) a cached "rate limited" verdict is served without asking Redis. */
const CACHE_TTL = 30_000; // 30 seconds

/**
 * Consume one API-key-creation slot for `userId` (max 5 per hour) and report whether it is allowed.
 *
 * Every call that reaches Redis increments the counter, so call this once per creation attempt,
 * not as a read-only probe. The check fails open: if Redis errors or returns an unexpected shape,
 * the request is allowed.
 *
 * @param userId - User attempting to create a key.
 * @returns `allowed` - whether creation may proceed; `remaining` - slots left in the current window
 *   (0 when blocked); `resetAt` - when the window ends, only set when blocked and Redis reported a
 *   positive TTL.
 */
export async function checkApiKeyCreationRateLimit(userId: string): Promise<{ allowed: boolean; remaining: number; resetAt: Date | null }> {
  const key = `byok:ratelimit:create:${userId}`;
  const limit = 5; // max 5 creations per hour
  const windowSeconds = 60 * 60; // 1 hour

  // Check cache first; drop stale entries so the map does not keep dead verdicts forever.
  const cached = rateLimitCache.get(key);
  if (cached) {
    if (Date.now() < cached.expiresAt) return cached.result;
    rateLimitCache.delete(key);
  }

  try {
    const result = await redis.eval(RATE_LIMIT_LUA, 1, key, String(limit), String(windowSeconds)) as number[];

    if (!Array.isArray(result)) {
      // Fallback for non-array results (shouldn't happen with correct Lua)
      return { allowed: true, remaining: limit, resetAt: null };
    }

    const [value, ttlOrRemaining] = result;

    // Rate limited
    if (value === -1) {
      const ttl = ttlOrRemaining as number;
      const now = Date.now();
      const resetAt = ttl > 0 ? new Date(now + ttl * 1000) : null;
      const rateLimitResult = { allowed: false, remaining: 0, resetAt };
      // Serve this verdict from memory for at most CACHE_TTL, and never past the moment the Redis
      // window resets — otherwise the user would stay blocked after their quota was restored.
      const cacheUntil = resetAt ? Math.min(now + CACHE_TTL, resetAt.getTime()) : now + CACHE_TTL;
      rateLimitCache.set(key, { result: rateLimitResult, expiresAt: cacheUntil });
      return rateLimitResult;
    }

    // Allowed. Not cached: every allowed attempt must go through Redis so the counter advances.
    const remaining = ttlOrRemaining as number;
    return { allowed: true, remaining, resetAt: null };
  } catch (err) {
    console.error('[byok] Rate limit check failed, allowing request:', err);
    return { allowed: true, remaining: limit, resetAt: null };
  }
}
|