Files
Keep/keep-notes/lib/ai/providers/ollama.ts

223 lines
7.3 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { createOpenAI } from '@ai-sdk/openai';
import { generateText as aiGenerateText, stepCountIs } from 'ai';
import { AIProvider, TagSuggestion, TitleSuggestion, ToolUseOptions, ToolCallResult } from '../types';
/**
 * AI provider backed by an Ollama server.
 *
 * Text generation, chat and embeddings go directly to Ollama's native REST
 * API under `<base>/api`. Tool calling goes through the AI SDK using the
 * OpenAI-compatible endpoint under `<base>/v1`.
 */
export class OllamaProvider implements AIProvider {
  /** Outermost `[{ ... }]` span (greedy, so nested objects stay intact). */
  private static readonly GREEDY_ARRAY = /\[\s*\{[\s\S]*\}\s*\]/;
  /** First `[{ ... }]` span; fallback when the greedy span is not valid JSON. */
  private static readonly LAZY_ARRAY = /\[\s*\{[\s\S]*?\}\s*\]/;
  /** A `{ "tags": [...] }` wrapper; group 1 captures the inner array. */
  private static readonly TAGS_WRAPPER = /\{\s*"tags"\s*:\s*(\[[\s\S]*\])\s*\}/;

  /** Native API root: always ends in `/api`, never with a trailing slash. */
  private baseUrl: string;
  private modelName: string;
  private embeddingModelName: string;
  /** AI SDK chat model bound to the OpenAI-compatible `/v1` endpoint. */
  private model: any;

  /**
   * @param baseUrl Ollama server URL, with or without a trailing `/api` and/or `/`.
   * @param modelName Model used for generation and chat (default `'llama3'`).
   * @param embeddingModelName Model used for embeddings; falls back to `modelName`.
   * @throws Error when `baseUrl` is empty (or only slashes/whitespace).
   */
  constructor(baseUrl: string, modelName: string = 'llama3', embeddingModelName?: string) {
    // Strip trailing slashes first, otherwise `http://host:11434/` would become
    // `http://host:11434//api` and `.../api/` would become `.../api//api`.
    const root = baseUrl ? baseUrl.trim().replace(/\/+$/, '') : '';
    if (!root) {
      throw new Error('baseUrl is required for OllamaProvider');
    }
    this.baseUrl = root.endsWith('/api') ? root : `${root}/api`;
    this.modelName = modelName;
    this.embeddingModelName = embeddingModelName || modelName;

    // Ollama also exposes an OpenAI-compatible API under /v1, which the AI SDK
    // client uses (streaming, tool calling). The API key is a placeholder.
    const ollamaClient = createOpenAI({
      baseURL: `${this.baseUrl.replace(/\/api$/, '')}/v1`,
      apiKey: 'ollama',
    });
    this.model = ollamaClient.chat(modelName);
  }

  /**
   * Ask the model for up to five short tags describing `content`.
   *
   * @param content Note text to analyse.
   * @param language `'fa'` or `'fr'` select a localised prompt; anything else uses English.
   * @returns Suggestions whose entries all carry a string `tag`. Returns `[]` when
   *   the request fails or the reply contains no usable JSON list (failures are logged).
   */
  async generateTags(content: string, language: string = "en"): Promise<TagSuggestion[]> {
    try {
      const data = await this.postJson('/generate', {
        model: this.modelName,
        prompt: OllamaProvider.buildTagPrompt(content, language),
        stream: false,
      });
      const text = OllamaProvider.responseText(data);

      const list = OllamaProvider.extractJsonArray(text);
      if (list) {
        return OllamaProvider.keepEntriesWith<TagSuggestion>(list, 'tag');
      }

      // Also accept the list wrapped as { "tags": [...] }.
      const inner = text.match(OllamaProvider.TAGS_WRAPPER)?.[1];
      const wrapped = inner ? OllamaProvider.tryParseJson(inner) : undefined;
      return Array.isArray(wrapped)
        ? OllamaProvider.keepEntriesWith<TagSuggestion>(wrapped, 'tag')
        : [];
    } catch (e) {
      console.error('Erreur API directe Ollama:', e);
      return [];
    }
  }

  /**
   * Embed `text` with the configured embedding model.
   *
   * @returns The embedding vector, or `[]` on any failure (logged), including a
   *   reply that lacks an `embedding` array.
   */
  async getEmbeddings(text: string): Promise<number[]> {
    try {
      const data = await this.postJson('/embeddings', {
        model: this.embeddingModelName,
        prompt: text,
      });
      const embedding = data.embedding;
      if (!Array.isArray(embedding)) {
        throw new Error('Ollama error: reply has no "embedding" array');
      }
      return embedding;
    } catch (e) {
      console.error('Erreur embeddings directs Ollama:', e);
      return [];
    }
  }

  /**
   * Generate title suggestions for a caller-built `prompt`.
   * A fixed instruction asking for a JSON `[{title, confidence}]` list is appended.
   *
   * @returns Suggestions whose entries all carry a string `title`; `[]` on failure (logged).
   */
  async generateTitles(prompt: string): Promise<TitleSuggestion[]> {
    try {
      const data = await this.postJson('/generate', {
        model: this.modelName,
        prompt: `${prompt}\n\nRéponds UNIQUEMENT sous forme de tableau JSON : [{"title": "string", "confidence": number}]`,
        stream: false,
      });
      const list = OllamaProvider.extractJsonArray(OllamaProvider.responseText(data));
      return list ? OllamaProvider.keepEntriesWith<TitleSuggestion>(list, 'title') : [];
    } catch (e) {
      console.error('Erreur génération titres Ollama:', e);
      return [];
    }
  }

  /**
   * Single-shot completion.
   *
   * @returns The model's reply with surrounding whitespace trimmed.
   * @throws Error on HTTP failure or a reply without a `response` string (logged, then rethrown).
   */
  async generateText(prompt: string): Promise<string> {
    try {
      const data = await this.postJson('/generate', {
        model: this.modelName,
        prompt: prompt,
        stream: false,
      });
      return OllamaProvider.responseText(data).trim();
    } catch (e) {
      console.error('Erreur génération texte Ollama:', e);
      throw e;
    }
  }

  /**
   * Multi-turn chat via Ollama's native `/chat` endpoint.
   *
   * @param messages Conversation history; only `role` and `content` are forwarded.
   * @param systemPrompt Optional system message, placed before the history.
   * @returns `{ text }` with the trimmed assistant reply, or `''` when the reply has no content.
   * @throws Error on HTTP failure (logged, then rethrown).
   */
  async chat(messages: any[], systemPrompt?: string): Promise<any> {
    try {
      const ollamaMessages = messages.map(m => ({ role: m.role, content: m.content }));
      if (systemPrompt) {
        ollamaMessages.unshift({ role: 'system', content: systemPrompt });
      }
      const data = await this.postJson('/chat', {
        model: this.modelName,
        messages: ollamaMessages,
        stream: false,
      });
      const message = data.message as { content?: unknown } | null | undefined;
      const content = message?.content;
      return { text: typeof content === 'string' ? content.trim() : '' };
    } catch (e) {
      console.error('Erreur chat Ollama:', e);
      throw e;
    }
  }

  /** The AI SDK model bound to Ollama's OpenAI-compatible endpoint. */
  getModel() {
    return this.model;
  }

  /**
   * Run a tool-calling generation through the AI SDK.
   * It stops after `maxSteps` steps (default 10). Only one of
   * `messages` / `prompt` is forwarded, and `messages` takes precedence.
   *
   * @returns Tool calls, tool results and final text, overall and per step.
   */
  async generateWithTools(options: ToolUseOptions): Promise<ToolCallResult> {
    const { tools, maxSteps = 10, systemPrompt, messages, prompt } = options;
    const opts: Record<string, any> = {
      model: this.model,
      tools,
      stopWhen: stepCountIs(maxSteps),
    };
    if (systemPrompt) opts.system = systemPrompt;
    if (messages) opts.messages = messages;
    else if (prompt) opts.prompt = prompt;

    const result = await aiGenerateText(opts as any);
    return {
      toolCalls: OllamaProvider.toToolCalls(result.toolCalls),
      toolResults: OllamaProvider.toToolResults(result.toolResults),
      text: result.text,
      steps: result.steps?.map((step: any) => ({
        text: step.text,
        toolCalls: OllamaProvider.toToolCalls(step.toolCalls),
        toolResults: OllamaProvider.toToolResults(step.toolResults),
      })) || [],
    };
  }

  /**
   * POST `body` as JSON to `<baseUrl><path>` and return the decoded object.
   *
   * @throws Error (`Ollama error: <statusText>`) on a non-2xx status or a
   *   non-object body. Network and JSON-decoding errors propagate unchanged.
   */
  private async postJson(path: string, body: Record<string, unknown>): Promise<Record<string, unknown>> {
    const response = await fetch(`${this.baseUrl}${path}`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(body),
    });
    if (!response.ok) throw new Error(`Ollama error: ${response.statusText}`);
    const json: unknown = await response.json();
    if (typeof json !== 'object' || json === null) {
      throw new Error('Ollama error: unexpected reply body');
    }
    return json as Record<string, unknown>;
  }

  /** Return the `response` string of a `/generate` reply, or throw if it is missing. */
  private static responseText(data: Record<string, unknown>): string {
    const text = data.response;
    if (typeof text !== 'string') {
      throw new Error('Ollama error: reply has no "response" text');
    }
    return text;
  }

  /**
   * Build the tag-extraction prompt for `language` ('fa', 'fr', otherwise English).
   * Lines are joined with '\n' so source indentation never leaks into the prompt.
   */
  private static buildTagPrompt(content: string, language: string): string {
    switch (language) {
      case 'fa':
        return [
          `متن زیر را تحلیل کن و مفاهیم کلیدی را به عنوان برچسب استخراج کن (حداکثر ۱-۳ کلمه).`,
          `قوانین:`,
          `- کلمات ربط را حذف کن.`,
          `- عبارات ترکیبی را حفظ کن.`,
          `- حداکثر ۵ برچسب.`,
          `پاسخ فقط به صورت لیست JSON با فرمت [{"tag": "string", "confidence": number}]`,
          `متن: "${content}"`,
        ].join('\n');
      case 'fr':
        return [
          `Analyse la note suivante et extrais les concepts clés sous forme de tags courts (1-3 mots max).`,
          `Règles:`,
          `- Pas de mots de liaison.`,
          `- Garde les expressions composées ensemble.`,
          `- Normalise en minuscules sauf noms propres.`,
          `- Maximum 5 tags.`,
          `Réponds UNIQUEMENT sous forme de liste JSON d'objets : [{"tag": "string", "confidence": number}].`,
          `Contenu de la note: "${content}"`,
        ].join('\n');
      default:
        return [
          `Analyze the following note and extract key concepts as short tags (1-3 words max).`,
          `Rules:`,
          `- No stop words.`,
          `- Keep compound expressions together.`,
          `- Lowercase unless proper noun.`,
          `- Max 5 tags.`,
          `Respond ONLY as a JSON list of objects: [{"tag": "string", "confidence": number}].`,
          `Note content: "${content}"`,
        ].join('\n');
    }
  }

  /**
   * Find and parse the first JSON array of objects embedded in free-form model output.
   * The greedy span is tried first (it keeps nested objects intact). If it is not
   * valid JSON, for example because the model emitted several arrays, the first
   * lazy span is tried.
   *
   * @returns The parsed array, or `null` when nothing parses as an array.
   */
  private static extractJsonArray(text: string): unknown[] | null {
    for (const pattern of [OllamaProvider.GREEDY_ARRAY, OllamaProvider.LAZY_ARRAY]) {
      const match = text.match(pattern);
      if (!match) continue;
      const parsed = OllamaProvider.tryParseJson(match[0]);
      if (Array.isArray(parsed)) return parsed;
    }
    return null;
  }

  /** `JSON.parse` that yields `undefined` instead of throwing on malformed text. */
  private static tryParseJson(raw: string): unknown {
    try {
      return JSON.parse(raw);
    } catch {
      return undefined;
    }
  }

  /** Keep only array entries that are objects with a string-valued `field`. */
  private static keepEntriesWith<T>(items: unknown[], field: string): T[] {
    return items.filter(
      (item): item is T =>
        typeof item === 'object' &&
        item !== null &&
        typeof (item as Record<string, unknown>)[field] === 'string',
    );
  }

  /** Project AI SDK tool calls onto `{ toolName, input }`. */
  private static toToolCalls(calls: any[] | undefined) {
    return calls?.map((tc: any) => ({ toolName: tc.toolName, input: tc.input })) || [];
  }

  /** Project AI SDK tool results onto `{ toolName, input, output }`. */
  private static toToolResults(results: any[] | undefined) {
    return results?.map((tr: any) => ({ toolName: tr.toolName, input: tr.input, output: tr.output })) || [];
  }
}