Files
Momento/memento-note/lib/byok.ts
Antigravity a623454347
Some checks failed
CI / Lint, Unit Tests & Build (push) Failing after 1m32s
CI / Deploy production (on server) (push) Has been skipped
perf: memo GridCard, fuse save fns, fix slash tab active color
2026-06-14 14:06:05 +00:00

335 lines
10 KiB
TypeScript

import { prisma } from '@/lib/prisma';
import { decryptApiKey, encryptApiKey, hashApiKey } from '@/lib/crypto';
import {
VALID_PROVIDERS,
type AiGatewayProvider,
} from '@/lib/ai/router';
import { getProviderConfigKeys } from '@/lib/ai/factory';
import { getUserInfo, type SubscriptionTier } from '@/lib/entitlements';
import { redis } from '@/lib/redis';
/**
 * Error signalling that the user has an active BYOK (bring-your-own-key) entry
 * that could not be loaded, e.g. because decrypting it failed.
 * The default message is in French.
 */
export class ByokUnavailableError extends Error {
  /** Stable, machine-readable identifier for this failure. */
  readonly code = 'BYOK_UNAVAILABLE';

  constructor(message = "Votre clé API est configurée mais n'a pas pu être chargée.") {
    super(message);
    this.name = 'ByokUnavailableError';
  }
}
/** Providers for which a PRO-tier user may register their own API key. */
const PRO_BYOK_PROVIDERS: readonly AiGatewayProvider[] = [
  'openai',
  'anthropic',
  'deepseek',
  'openrouter',
  'minimax',
  'zai',
  'custom_openai', // custom OpenAI-compatible endpoint
  'custom_anthropic', // custom Anthropic-compatible endpoint
];

/**
 * Providers for higher tiers: every gateway provider except 'ollama' and
 * 'lmstudio', which are never offered for BYOK.
 */
const BUSINESS_BYOK_PROVIDERS: readonly AiGatewayProvider[] = [...VALID_PROVIDERS].filter(
  (candidate) => !['ollama', 'lmstudio'].includes(candidate),
) as AiGatewayProvider[];
/**
 * Lists the providers a user on `tier` may configure a BYOK key for.
 * BASIC gets none, PRO gets a curated list, and any other tier gets the full BYOK list.
 */
export function getAllowedByokProviders(
  tier: SubscriptionTier,
): readonly AiGatewayProvider[] {
  switch (tier) {
    case 'BASIC':
      return [];
    case 'PRO':
      return PRO_BYOK_PROVIDERS;
    default:
      return BUSINESS_BYOK_PROVIDERS;
  }
}
/** True when `provider` is one of the BYOK providers permitted for `tier`. */
export function isByokProviderAllowed(
  tier: SubscriptionTier,
  provider: string,
): boolean {
  const permitted = getAllowedByokProviders(tier);
  return permitted.some((candidate) => candidate === provider);
}
/** Whether the user has at least one active BYOK key, for any provider. */
export async function hasAnyActiveByok(userId: string): Promise<boolean> {
  const activeKeyCount = await prisma.userAPIKey.count({
    where: { isActive: true, userId },
  });
  return activeKeyCount !== 0;
}
/**
 * Looks up the user's active BYOK key for `provider`.
 *
 * If a key exists but the user's tier does not permit that provider, the key is
 * deactivated in the database, a warning is logged, and null is returned.
 *
 * @param tier Optional tier of the user. Pass it when already known to skip the
 *   getUserInfo lookup; it is only consulted when a key is found.
 * @returns The key row, or null when none is active or it was just deactivated.
 */
export async function getActiveByokKey(userId: string, provider: string, tier?: SubscriptionTier) {
  const key = await prisma.userAPIKey.findFirst({
    where: { userId, provider, isActive: true },
  });
  if (!key) {
    return key;
  }

  const effectiveTier = tier ?? (await getUserInfo(userId)).tier;
  if (isByokProviderAllowed(effectiveTier, provider)) {
    return key;
  }

  // The key is active, but this tier does not allow the provider.
  // Switch it off so later lookups skip it.
  await prisma.userAPIKey.update({
    where: { id: key.id },
    data: { isActive: false },
  });
  console.warn(`[byok] Deactivated key for ${provider} (user ${userId}) - tier ${effectiveTier} does not allow this provider`);
  return null;
}
/**
 * Deactivates every active API key whose provider is not allowed for the user's
 * current tier. Intended to be called when a user's subscription tier changes.
 *
 * @returns The number of keys that were deactivated.
 */
export async function deactivateUnauthorizedKeys(userId: string): Promise<number> {
  const { tier } = await getUserInfo(userId);
  const permitted = new Set<string>(getAllowedByokProviders(tier));

  const activeKeys = await prisma.userAPIKey.findMany({
    where: { userId, isActive: true },
    select: { id: true, provider: true },
  });

  const staleIds = activeKeys
    .filter(({ provider }) => !permitted.has(provider))
    .map(({ id }) => id);
  if (staleIds.length === 0) {
    return 0;
  }

  // One batched write instead of one update per key.
  await prisma.userAPIKey.updateMany({
    where: { id: { in: staleIds } },
    data: { isActive: false },
  });
  console.log(`[byok] Deactivated ${staleIds.length} keys for user ${userId} after tier change to ${tier}`);
  return staleIds.length;
}
/**
 * Loads and decrypts the user's active BYOK key for `providerType`.
 *
 * On success, `lastUsedAt`/`lastUsedFor` are updated in the background (failures
 * there are only logged). Decryption failures are logged and yield null.
 *
 * @param feature Optional label recorded in `lastUsedFor`; an empty value is stored as null.
 * @returns The decrypted key plus its provider/model/baseUrl, or null.
 */
export async function resolveByokApiKey(
  userId: string,
  providerType: string,
  feature?: string,
): Promise<{ plaintext: string; provider: string; model: string | null; baseUrl: string | null } | null> {
  const row = await getActiveByokKey(userId, providerType);
  if (!row) {
    return null;
  }

  try {
    const plaintext = await decryptApiKey(row.encryptedKey);

    // Fire-and-forget usage stamp; the caller does not wait for it.
    void prisma.userAPIKey
      .update({
        where: { id: row.id },
        data: { lastUsedAt: new Date(), lastUsedFor: feature || null },
      })
      .catch((err) => {
        console.error('[byok] Failed to update lastUsedAt/lastUsedFor:', err);
      });

    const { provider, model, baseUrl } = row;
    return { plaintext, provider, model, baseUrl: baseUrl ?? null };
  } catch (err) {
    console.error('[byok] Failed to decrypt key for provider', providerType, err);
    return null;
  }
}
/**
 * Returns a usable active BYOK key for the user, or null.
 *
 * Resolution order:
 * 1. If `preferredProvider` is given, the active key for exactly that provider
 *    (via resolveByokApiKey, which also enforces the user's tier).
 * 2. Otherwise, or if step 1 yields nothing, the most recently used active key
 *    among the providers the user's current tier allows. `preferredProvider` is
 *    excluded here, since step 1 already tried it.
 *
 * This lets BYOK work regardless of which provider the admin configured.
 * Decryption failures are logged and yield null.
 *
 * @param feature Optional label recorded in `lastUsedFor` when a key is used.
 */
export async function getAnyActiveByokForUser(
  userId: string,
  preferredProvider?: string,
  feature?: string,
): Promise<{ plaintext: string; provider: string; model: string | null; baseUrl: string | null } | null> {
  // 1. Try exact match first.
  if (preferredProvider) {
    const exact = await resolveByokApiKey(userId, preferredProvider, feature);
    if (exact) return exact;
  }

  // 2. Fall back to any active key, restricted to providers the tier allows.
  // Without this filter, a key for a provider the user's tier no longer permits
  // could still be used here even though the exact path would refuse it.
  const { tier } = await getUserInfo(userId);
  const candidateProviders = getAllowedByokProviders(tier).filter(
    (p) => p !== preferredProvider,
  );
  if (candidateProviders.length === 0) return null;

  const anyRow = await prisma.userAPIKey.findFirst({
    where: { userId, isActive: true, provider: { in: candidateProviders } },
    orderBy: { lastUsedAt: 'desc' },
  });
  if (!anyRow) return null;

  try {
    const plaintext = await decryptApiKey(anyRow.encryptedKey);
    // Fire-and-forget usage stamp. Log failures instead of silently swallowing them.
    void prisma.userAPIKey
      .update({
        where: { id: anyRow.id },
        data: { lastUsedAt: new Date(), lastUsedFor: feature || null },
      })
      .catch((err) => {
        console.error('[byok] Failed to update lastUsedAt/lastUsedFor:', err);
      });
    return { plaintext, provider: anyRow.provider, model: anyRow.model, baseUrl: anyRow.baseUrl ?? null };
  } catch (err) {
    console.error('[byok] Failed to decrypt any active key:', err);
    return null;
  }
}
/**
 * Overlays the user's BYOK API key onto a provider config, when one is available.
 *
 * The key is written under the provider's API-key config slot, as reported by
 * getProviderConfigKeys. If there is no BYOK key, or the provider has no such
 * slot, the original config is returned with `usedByok: false`.
 *
 * NOTE(review): `byok.baseUrl` is resolved but not written into the config here.
 * Confirm that is intended for custom_* providers.
 */
export async function applyByokToConfig(
  billingUserId: string,
  providerType: string,
  config: Record<string, string>,
  feature?: string,
): Promise<{ config: Record<string, string>; usedByok: boolean; model: string | null }> {
  const byok = await resolveByokApiKey(billingUserId, providerType, feature);
  const apiKeyConfigKey = byok ? getProviderConfigKeys(providerType).apiKeyConfigKey : undefined;

  if (!byok || !apiKeyConfigKey) {
    return { config, usedByok: false, model: null };
  }

  const merged = { ...config, [apiKeyConfigKey]: byok.plaintext };
  return { config: merged, usedByok: true, model: byok.model };
}
/**
 * Creates or replaces the user's API key for `provider`, keyed by (userId, provider).
 *
 * The plaintext is stored encrypted, alongside its hash. On both insert and
 * update the key is marked active. Omitted alias/model/baseUrl are written as
 * '' / null / null, which also clears any previous value on update.
 */
export async function upsertUserApiKey(params: {
  userId: string;
  provider: AiGatewayProvider;
  plaintext: string;
  alias?: string;
  model?: string;
  baseUrl?: string;
}) {
  const { userId, provider, plaintext } = params;
  const encryptedKey = await encryptApiKey(plaintext);
  const keyHash = hashApiKey(plaintext);

  // Columns written identically whether the row is created or updated.
  const writable = {
    alias: params.alias ?? '',
    encryptedKey,
    keyHash,
    model: params.model ?? null,
    baseUrl: params.baseUrl ?? null,
    isActive: true,
  };

  return prisma.userAPIKey.upsert({
    where: { userId_provider: { userId, provider } },
    create: { userId, provider, ...writable },
    update: writable,
  });
}
/**
 * Finds another stored key of this user with the same key hash, in order to
 * detect one API key being registered under several providers.
 *
 * @param excludeProvider When given, rows for this provider are ignored.
 * @returns The provider of a matching row, or null if there is none.
 */
export async function findDuplicateApiKeyHash(
  userId: string,
  keyHash: string,
  excludeProvider?: string,
): Promise<string | null> {
  const providerFilter = excludeProvider ? { provider: { not: excludeProvider } } : {};
  const match = await prisma.userAPIKey.findFirst({
    where: { userId, keyHash, ...providerFilter },
    select: { provider: true },
  });
  if (!match) {
    return null;
  }
  return match.provider ?? null;
}
/**
 * Projects a stored API key row onto the fields that are safe to return to clients.
 *
 * Only the listed fields are copied. Any other properties on `row`, such as the
 * encrypted key or its hash, are dropped.
 */
export function toPublicApiKey(row: {
  provider: string;
  alias: string;
  model: string | null;
  baseUrl: string | null;
  isActive: boolean;
  lastUsedAt: Date | null;
  createdAt: Date;
  updatedAt: Date;
}) {
  const { provider, alias, model, baseUrl, isActive, lastUsedAt, createdAt, updatedAt } = row;
  return { provider, alias, model, baseUrl, isActive, lastUsedAt, createdAt, updatedAt };
}
/**
 * Redis Lua script implementing a fixed-window counter for API key creation.
 * Redis executes EVAL scripts atomically, so the read / compare / increment
 * sequence below cannot interleave with a concurrent request for the same key.
 *
 * KEYS[1] - counter key.
 * ARGV[1] - maximum number of increments allowed in the window.
 * ARGV[2] - window length in seconds. It is applied with EXPIRE only on the
 *           first increment, so the window starts at the first counted attempt.
 *
 * Returns:
 * - {-1, ttl} when the counter has already reached the limit. `ttl` is the raw
 *   TTL reply for the key, in seconds. The counter is not incremented.
 * - {newCount, limit - newCount} otherwise, after incrementing the counter.
 */
const RATE_LIMIT_LUA = `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call('GET', key) or '0')
if current >= limit then
local ttl = redis.call('TTL', key)
return {-1, ttl}
end
local newCount = redis.call('INCR', key)
if newCount == 1 then
redis.call('EXPIRE', key, window)
end
return {newCount, limit - newCount}
`;
/** One memoized rate-limit decision and the epoch-millisecond time it stops being valid. */
type CachedRateLimitEntry = {
  result: { allowed: boolean; remaining: number; resetAt: Date | null };
  expiresAt: number;
};

// Per-process memo of rate-limit decisions, keyed by the Redis rate-limit key.
const rateLimitCache: Map<string, CachedRateLimitEntry> = new Map();
// Cache lifetime in milliseconds (30 seconds).
const CACHE_TTL = 30 * 1000;
/**
 * Checks the per-user API key creation quota (5 creations per fixed one-hour
 * window, enforced in Redis via RATE_LIMIT_LUA). Each allowed call consumes one
 * unit of quota, so call this once per creation attempt.
 *
 * Denials are memoized in-process (rateLimitCache) for at most CACHE_TTL ms. They
 * are never kept past the Redis window's reset time, so a blocked user does not
 * hit Redis repeatedly but is unblocked as soon as their quota is restored.
 *
 * Fails open: if the Redis call throws, the request is allowed and the error is logged.
 *
 * @returns allowed: whether the creation may proceed. remaining: creations left
 *   in the current window. resetAt: when the window resets, known only for
 *   denials with a positive TTL.
 */
export async function checkApiKeyCreationRateLimit(userId: string): Promise<{ allowed: boolean; remaining: number; resetAt: Date | null }> {
  const key = `byok:ratelimit:create:${userId}`;
  const limit = 5; // max 5 creations per hour
  const window = 60 * 60; // 1 hour in seconds

  // Check the in-process cache first. Evict stale entries so the map does not
  // retain one entry per user that was ever rate limited.
  const cached = rateLimitCache.get(key);
  if (cached) {
    if (Date.now() < cached.expiresAt) {
      return cached.result;
    }
    rateLimitCache.delete(key);
  }

  try {
    const result = await redis.eval(RATE_LIMIT_LUA, 1, key, String(limit), String(window)) as number[];
    if (!Array.isArray(result)) {
      // Fallback for non-array results (shouldn't happen with correct Lua)
      return { allowed: true, remaining: limit, resetAt: null };
    }
    const [value, ttlOrRemaining] = result;

    // Rate limited: the script returned {-1, ttl}.
    if (value === -1) {
      const ttl = ttlOrRemaining as number;
      const resetAt = ttl > 0 ? new Date(Date.now() + ttl * 1000) : null;
      const rateLimitResult = { allowed: false, remaining: 0, resetAt };
      // Cache the denial for CACHE_TTL, but never beyond the Redis window reset.
      // Otherwise the user would stay blocked locally after their quota is restored.
      const cacheUntil = Math.min(
        Date.now() + CACHE_TTL,
        resetAt ? resetAt.getTime() : Number.POSITIVE_INFINITY,
      );
      rateLimitCache.set(key, { result: rateLimitResult, expiresAt: cacheUntil });
      return rateLimitResult;
    }

    // Allowed: the script returned {newCount, remaining}. Not cached, so the next
    // attempt is counted in Redis.
    const remaining = ttlOrRemaining as number;
    return { allowed: true, remaining, resetAt: null };
  } catch (err) {
    console.error('[byok] Rate limit check failed, allowing request:', err);
    return { allowed: true, remaining: limit, resetAt: null };
  }
}