Files
Momento/memento-note/lib/ai/providers/custom-openai.ts
Antigravity e09ea3a145
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 5s
fix: switch embedding dimension from 1536 to 2560 for qwen-embedding-4b
2026-05-12 09:07:55 +00:00

215 lines
7.5 KiB
TypeScript

import { createOpenAI } from '@ai-sdk/openai';
import { generateObject, generateText as aiGenerateText, stepCountIs } from 'ai';
import { z } from 'zod';
import { AIProvider, TagSuggestion, TitleSuggestion, ToolUseOptions, ToolCallResult } from '../types';
/**
 * {@link AIProvider} implementation for OpenAI-compatible HTTP APIs reachable
 * at a caller-supplied base URL.
 *
 * Chat, text, title, tag and tool calls go through the AI SDK client built in
 * the constructor. Embeddings are requested with a direct POST to
 * `${baseUrl}/embeddings`.
 */
export class CustomOpenAIProvider implements AIProvider {
  /** Substrings of a model id for which the `thinking` request flag is disabled. */
  private static readonly THINKING_MODEL_HINTS: readonly string[] = ['deepseek', 'thinking', 'reasoner'];

  // SDK chat-model handle; handed out unchanged by getModel().
  private model: any;
  private apiKey: string;
  private baseUrl: string;
  private embeddingModelName: string;

  /**
   * @param apiKey Bearer token sent to the endpoint.
   * @param baseUrl Root URL of the OpenAI-compatible API; trailing slashes are removed.
   * @param modelName Chat model id (defaults to `gpt-4o-mini`).
   * @param embeddingModelName Embedding model id (defaults to `text-embedding-3-small`).
   */
  constructor(
    apiKey: string,
    baseUrl: string,
    modelName: string = 'gpt-4o-mini',
    embeddingModelName: string = 'text-embedding-3-small'
  ) {
    this.apiKey = apiKey;
    // Normalise once and use the same value for the SDK client and the
    // hand-built embeddings URL so both always target the same root.
    this.baseUrl = baseUrl.replace(/\/+$/, '');
    this.embeddingModelName = embeddingModelName;

    // Create OpenAI-compatible client with custom base URL
    const customClient = createOpenAI({
      baseURL: this.baseUrl,
      apiKey: apiKey,
      fetch: async (url, options) => {
        const headers = new Headers(options?.headers);
        headers.set('HTTP-Referer', 'https://localhost:3000');
        headers.set('X-Title', 'Memento AI');

        // Disable DeepSeek extended thinking for reliable tool/function calling.
        // Only string bodies can be JSON payloads; anything else is forwarded as-is.
        const rawBody = options?.body;
        if (typeof rawBody === 'string') {
          const patchedBody = CustomOpenAIProvider.disableThinking(rawBody);
          if (patchedBody !== null) {
            return fetch(url, { ...options, headers, body: patchedBody });
          }
        }
        return fetch(url, { ...options, headers });
      },
    });

    // Use .chat() to force /chat/completions endpoint (avoids Responses API)
    this.model = customClient.chat(modelName);
  }

  /** Narrowing helper: true for any non-null object (arrays included). */
  private static isRecord(value: unknown): value is Record<string, unknown> {
    return typeof value === 'object' && value !== null;
  }

  /**
   * Returns `rawBody` re-serialised with `thinking: { type: 'disabled' }` when it
   * is a JSON object whose `model` matches one of the thinking-model hints;
   * returns `null` when the body should be sent unchanged (not JSON, or no match).
   */
  private static disableThinking(rawBody: string): string | null {
    let payload: unknown;
    try {
      payload = JSON.parse(rawBody);
    } catch {
      return null; // not JSON — leave the request alone
    }
    if (!CustomOpenAIProvider.isRecord(payload)) return null;

    const modelId = payload.model;
    if (typeof modelId !== 'string') return null;
    if (!CustomOpenAIProvider.THINKING_MODEL_HINTS.some((hint) => modelId.includes(hint))) return null;

    payload.thinking = { type: 'disabled' };
    return JSON.stringify(payload);
  }

  /** Best-effort message extraction from an arbitrary thrown value. */
  private static errorMessage(err: unknown): string {
    const message =
      CustomOpenAIProvider.isRecord(err) && typeof err.message === 'string' ? err.message : '';
    return message || String(err);
  }

  /**
   * Parses a model reply that should contain JSON.
   *
   * Tolerates surrounding whitespace, a leading "```json" / "```" fence, a
   * trailing "```" fence, and — as a fallback — a fenced block embedded in
   * other prose.
   *
   * @throws SyntaxError when no candidate parses as JSON.
   */
  private static parseJsonReply(text: string): unknown {
    const trimmed = text.trim();
    const edgeStripped = trimmed
      .replace(/^```(?:json)?\s*/i, '')
      .replace(/\s*```$/, '')
      .trim();
    try {
      return JSON.parse(edgeStripped);
    } catch (firstError) {
      // Fallback for replies like "Here you go:\n```json\n[...]\n```".
      const inner = /```(?:json)?\s*([\s\S]*?)```/i.exec(trimmed)?.[1];
      if (inner === undefined) throw firstError;
      return JSON.parse(inner.trim());
    }
  }

  /** Locates the list of title entries: the value itself, or its `titles` / `suggestions` array. */
  private static pickTitleEntries(parsed: unknown): unknown[] {
    if (Array.isArray(parsed)) return parsed;
    if (!CustomOpenAIProvider.isRecord(parsed)) return [];
    for (const key of ['titles', 'suggestions']) {
      const candidate = parsed[key];
      if (Array.isArray(candidate)) return candidate;
    }
    return [];
  }

  /**
   * Converts one parsed entry into a suggestion, or `null` when the entry is
   * unusable (null, boolean, ...). Strings become titles with confidence 0.5;
   * bare numbers become a confidence with an empty title; objects are read
   * via `title`/`name` and `confidence`/`score`.
   */
  private static toTitleSuggestion(entry: unknown): TitleSuggestion | null {
    if (typeof entry === 'string') return { title: entry, confidence: 0.5 };
    if (typeof entry === 'number') return { title: '', confidence: entry };
    if (!CustomOpenAIProvider.isRecord(entry)) return null;

    const title =
      [entry.title, entry.name].find((v): v is string => typeof v === 'string' && v !== '') ?? '';
    // Accept any finite number so that an explicit confidence of 0 is preserved.
    const confidence =
      [entry.confidence, entry.score].find(
        (v): v is number => typeof v === 'number' && Number.isFinite(v)
      ) ?? 0.5;
    return { title, confidence };
  }

  /** Pulls the vector out of an embeddings response, or `null` if the shape is not recognised. */
  private static extractEmbedding(data: unknown): number[] | null {
    if (!CustomOpenAIProvider.isRecord(data)) return null;

    // Standard OpenAI-compatible response: { data: [{ embedding: number[] }] }
    if (Array.isArray(data.data)) {
      const first: unknown = data.data[0];
      if (CustomOpenAIProvider.isRecord(first) && Array.isArray(first.embedding)) {
        return first.embedding as number[];
      }
    }
    // Fallback: some providers return { embedding: number[] }
    if (Array.isArray(data.embedding)) {
      return data.embedding as number[];
    }
    return null;
  }

  /**
   * `fetch` with an abort-based timeout.
   *
   * The timer is cleared as soon as `fetch` settles, so reading the response
   * body afterwards is not bounded by `timeoutMs`. Any `signal` present in
   * `options` is replaced by the internal one.
   *
   * @param timeoutMs Milliseconds before the request is aborted (default 60 000).
   */
  private async fetchWithTimeout(url: string, options: RequestInit, timeoutMs: number = 60_000): Promise<Response> {
    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), timeoutMs);
    try {
      return await fetch(url, { ...options, signal: controller.signal });
    } finally {
      clearTimeout(timer);
    }
  }

  /**
   * Asks the model for 1–5 tags describing `content`.
   *
   * @returns The suggested tags, or an empty array if generation fails (the error is logged).
   */
  async generateTags(content: string): Promise<TagSuggestion[]> {
    try {
      const { object } = await generateObject({
        model: this.model,
        schema: z.object({
          tags: z.array(z.object({
            tag: z.string().describe('Short tag name in lowercase'),
            confidence: z.number().min(0).max(1).describe('Confidence level between 0 and 1')
          }))
        }),
        prompt: `Analyze the following note and suggest 1 to 5 relevant tags.\nNote content: "${content}"`,
      });
      return object.tags;
    } catch (e) {
      console.error('Error generating tags (Custom OpenAI):', e);
      return [];
    }
  }

  /**
   * Requests an embedding vector for `text` from `${baseUrl}/embeddings`.
   *
   * @throws Error on a non-2xx response, on an unrecognised response shape, or
   *   when the request fails/aborts. Errors are logged and rethrown.
   */
  async getEmbeddings(text: string): Promise<number[]> {
    const endpoint = `${this.baseUrl}/embeddings`;
    try {
      const response = await this.fetchWithTimeout(endpoint, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          'Authorization': `Bearer ${this.apiKey}`,
          'HTTP-Referer': 'https://localhost:3000',
          'X-Title': 'Memento AI',
        },
        body: JSON.stringify({
          model: this.embeddingModelName,
          input: text,
        }),
      });
      if (!response.ok) {
        const errText = await response.text();
        throw new Error(`${endpoint} error ${response.status}: ${errText}`);
      }

      const data: unknown = await response.json();
      const embedding = CustomOpenAIProvider.extractEmbedding(data);
      if (embedding === null) {
        throw new Error(`Unexpected embeddings response shape: ${JSON.stringify(data)}`);
      }
      return embedding;
    } catch (e) {
      console.error('Error generating embeddings (CustomOpenAI):', e);
      throw e;
    }
  }

  /**
   * Runs `prompt` and interprets the reply as a JSON list of title suggestions.
   *
   * @returns Parsed suggestions, or an empty array if generation or parsing fails (the error is logged).
   */
  async generateTitles(prompt: string): Promise<TitleSuggestion[]> {
    try {
      // Use generateText instead of generateObject — DeepSeek doesn't support
      // response_format: json_schema via the OpenAI compat layer
      const { text } = await aiGenerateText({
        model: this.model,
        prompt: prompt,
      });
      const parsed = CustomOpenAIProvider.parseJsonReply(text);
      return CustomOpenAIProvider.pickTitleEntries(parsed)
        .map((entry) => CustomOpenAIProvider.toTitleSuggestion(entry))
        .filter((suggestion): suggestion is TitleSuggestion => suggestion !== null);
    } catch (e) {
      console.error('Error generating titles (Custom OpenAI):', e);
      return [];
    }
  }

  /**
   * Runs a single-prompt completion and returns the trimmed text.
   *
   * @throws Whatever the underlying call throws (logged, then rethrown).
   */
  async generateText(prompt: string): Promise<string> {
    try {
      const { text } = await aiGenerateText({
        model: this.model,
        prompt: prompt,
      });
      return text.trim();
    } catch (e) {
      console.error('Error generating text (Custom OpenAI):', e);
      throw e;
    }
  }

  /**
   * Runs a multi-message completion.
   *
   * @param messages Conversation passed through to the SDK unchanged.
   * @param systemPrompt Optional system prompt.
   * @returns An object `{ text }` holding the trimmed reply.
   * @throws Whatever the underlying call throws (logged, then rethrown).
   */
  async chat(messages: any[], systemPrompt?: string): Promise<any> {
    try {
      const { text } = await aiGenerateText({
        model: this.model,
        system: systemPrompt,
        messages: messages,
      });
      return { text: text.trim() };
    } catch (e) {
      console.error('Error in chat (Custom OpenAI):', e);
      throw e;
    }
  }

  /**
   * Runs a tool-enabled generation of up to `maxSteps` steps (default 10).
   *
   * If the call fails with an error mentioning `reasoning_content` or
   * `thinking mode`, it is retried once limited to a single step.
   *
   * @returns Tool calls, tool results, final text and per-step details.
   * @throws Any other error, and any error raised by the retry.
   */
  async generateWithTools(options: ToolUseOptions): Promise<ToolCallResult> {
    const { tools, maxSteps = 10, systemPrompt, messages, prompt } = options;

    // `messages` takes precedence over `prompt` when both are supplied.
    const buildOpts = (steps: number): Record<string, any> => {
      const opts: Record<string, any> = { model: this.model, tools, stopWhen: stepCountIs(steps) };
      if (systemPrompt) opts.system = systemPrompt;
      if (messages) opts.messages = messages;
      else if (prompt) opts.prompt = prompt;
      return opts;
    };

    const mapCalls = (calls: any): any[] =>
      calls?.map((tc: any) => ({ toolName: tc.toolName, input: tc.input })) || [];
    const mapResults = (results: any): any[] =>
      results?.map((tr: any) => ({ toolName: tr.toolName, input: tr.input, output: tr.output })) || [];

    const toResult = (r: any): ToolCallResult => ({
      toolCalls: mapCalls(r.toolCalls),
      toolResults: mapResults(r.toolResults),
      text: r.text,
      steps: r.steps?.map((step: any) => ({
        text: step.text,
        toolCalls: mapCalls(step.toolCalls),
        toolResults: mapResults(step.toolResults),
      })) || [],
    });

    try {
      const result = await aiGenerateText(buildOpts(maxSteps) as any);
      return toResult(result);
    } catch (err: unknown) {
      // DeepSeek reasoning/thinking models require reasoning_content to be passed back
      // between multi-step calls, which the AI SDK doesn't handle via the OpenAI-compat layer.
      // Retry with a single step so the model calls the tool directly.
      const msg = CustomOpenAIProvider.errorMessage(err);
      if (msg.includes('reasoning_content') || msg.includes('thinking mode')) {
        console.warn('[CustomOpenAI] Reasoning model detected — retrying with maxSteps=1');
        const result = await aiGenerateText(buildOpts(1) as any);
        return toResult(result);
      }
      throw err;
    }
  }

  /** Returns the SDK chat-model handle created in the constructor. */
  getModel() {
    return this.model;
  }
}