import { createOpenAI } from '@ai-sdk/openai';
import { generateText as aiGenerateText, stepCountIs } from 'ai';
import type { AIProvider, TagSuggestion, TitleSuggestion, ToolUseOptions, ToolCallResult } from '../types';

/** Greedily matches from the first `[{` to the last `}]` in a model reply. */
const JSON_OBJECT_ARRAY_PATTERN = /\[\s*\{[\s\S]*\}\s*\]/;
/** Matches a `{ "tags": [...] }` wrapper and captures the inner array. */
const TAGS_WRAPPER_PATTERN = /\{\s*"tags"\s*:\s*(\[[\s\S]*\])\s*\}/;

/** True for any non-null object (used to narrow parsed JSON before reading fields). */
function isRecord(value: unknown): value is Record<string, unknown> {
  return typeof value === 'object' && value !== null;
}

/** Returns `value` when it is a string, otherwise `''`. */
function asText(value: unknown): string {
  return typeof value === 'string' ? value : '';
}

/**
 * Parses `candidate` as JSON and returns it only if it is an array.
 * Returns `null` (never throws) for a missing candidate, invalid JSON, or a non-array,
 * so callers can chain fallbacks with `??`.
 */
function parseJsonArray(candidate: string | undefined): unknown[] | null {
  if (!candidate) return null;
  try {
    const parsed: unknown = JSON.parse(candidate);
    return Array.isArray(parsed) ? parsed : null;
  } catch {
    return null;
  }
}

/** Builds the tag-extraction prompt for `language` ('fa', 'fr', anything else falls back to English). */
function buildTagPrompt(content: string, language: string): string {
  switch (language) {
    case 'fa':
      return `متن زیر را تحلیل کن و مفاهیم کلیدی را به عنوان برچسب استخراج کن (حداکثر ۱-۳ کلمه). قوانین: - کلمات ربط را حذف کن. - عبارات ترکیبی را حفظ کن. - حداکثر ۵ برچسب. پاسخ فقط به صورت لیست JSON با فرمت [{"tag": "string", "confidence": number}] متن: "${content}"`;
    case 'fr':
      return `Analyse la note suivante et extrais les concepts clés sous forme de tags courts (1-3 mots max). Règles: - Pas de mots de liaison. - Garde les expressions composées ensemble. - Normalise en minuscules sauf noms propres. - Maximum 5 tags. Réponds UNIQUEMENT sous forme de liste JSON d'objets : [{"tag": "string", "confidence": number}]. Contenu de la note: "${content}"`;
    default:
      return `Analyze the following note and extract key concepts as short tags (1-3 words max). Rules: - No stop words. - Keep compound expressions together. \n- Lowercase unless proper noun. - Max 5 tags. Respond ONLY as a JSON list of objects: [{"tag": "string", "confidence": number}]. Note content: "${content}"`;
  }
}

/** Projects SDK tool calls onto `{ toolName, input }`; a missing list becomes `[]`. */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- SDK tool-call types depend on the (untyped) tool set
function mapToolCalls(calls: readonly any[] | undefined) {
  return calls?.map((tc) => ({ toolName: tc.toolName, input: tc.input })) ?? [];
}

/** Projects SDK tool results onto `{ toolName, input, output }`; a missing list becomes `[]`. */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- see mapToolCalls
function mapToolResults(results: readonly any[] | undefined) {
  return results?.map((tr) => ({ toolName: tr.toolName, input: tr.input, output: tr.output })) ?? [];
}

/**
 * AIProvider backed by an Ollama server.
 *
 * Plain generation, chat and embeddings call Ollama's native `/api/*` endpoints with `fetch`.
 * Tool use goes through the AI SDK with an OpenAI-compatible client pointed at `/v1`.
 */
export class OllamaProvider implements AIProvider {
  private readonly baseUrl: string;
  private readonly modelName: string;
  private readonly embeddingModelName: string;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- kept untyped to match getModel()'s contract
  private readonly model: any;

  /**
   * @param baseUrl Ollama server URL, with or without a trailing `/api` (trailing slashes are ignored).
   * @param modelName Model used for generation, chat and tool use.
   * @param embeddingModelName Model used for embeddings; defaults to `modelName`.
   * @throws Error when `baseUrl` is empty.
   */
  constructor(baseUrl: string, modelName = 'llama3', embeddingModelName?: string) {
    if (!baseUrl) {
      throw new Error('baseUrl is required for OllamaProvider');
    }
    // Strip trailing slashes so "http://host:11434/" does not become "http://host:11434//api".
    const trimmed = baseUrl.replace(/\/+$/, '');
    // Native endpoints are addressed relative to /api.
    this.baseUrl = trimmed.endsWith('/api') ? trimmed : `${trimmed}/api`;
    this.modelName = modelName;
    this.embeddingModelName = embeddingModelName || modelName;

    // Ollama also exposes /v1/chat/completions, which the OpenAI SDK client can talk to.
    const rootUrl = this.baseUrl.replace(/\/api$/, '');
    const ollamaClient = createOpenAI({
      baseURL: `${rootUrl}/v1`,
      apiKey: 'ollama',
    });
    this.model = ollamaClient.chat(modelName);
  }

  /**
   * POSTs `payload` as JSON to `${baseUrl}/${endpoint}` and returns the parsed reply.
   * A non-object JSON reply is normalized to `{}` so callers can read fields safely.
   * @throws Error starting with "Ollama error:" on a non-2xx status (status, statusText and body included).
   */
  private async postJson(endpoint: string, payload: Record<string, unknown>): Promise<Record<string, unknown>> {
    const response = await fetch(`${this.baseUrl}/${endpoint}`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(payload),
    });
    if (!response.ok) {
      // statusText is frequently empty; the body usually carries the server's explanation.
      const detail = await response.text().catch(() => '');
      throw new Error(`Ollama error: ${response.status} ${response.statusText}${detail ? ` - ${detail}` : ''}`);
    }
    const data: unknown = await response.json();
    return isRecord(data) ? data : {};
  }

  /**
   * Asks the model for up to 5 tags describing `content`.
   * Accepts either a bare JSON array or a `{ "tags": [...] }` wrapper in the reply.
   * Best-effort: any failure is logged and yields `[]`.
   */
  async generateTags(content: string, language: string = 'en'): Promise<TagSuggestion[]> {
    try {
      const data = await this.postJson('generate', {
        model: this.modelName,
        prompt: buildTagPrompt(content, language),
        stream: false,
      });
      const reply = asText(data.response);
      // Try the bare array first, then the { "tags": [...] } wrapper.
      const tags =
        parseJsonArray(reply.match(JSON_OBJECT_ARRAY_PATTERN)?.[0]) ??
        parseJsonArray(reply.match(TAGS_WRAPPER_PATTERN)?.[1]) ??
        [];
      // NOTE(review): entries are not validated against TagSuggestion; the prompt only requests that shape.
      return tags as TagSuggestion[];
    } catch (e) {
      console.error('Erreur API directe Ollama:', e);
      return [];
    }
  }

  /** Returns the embedding vector for `text`, or `[]` on failure or when the reply has none. */
  async getEmbeddings(text: string): Promise<number[]> {
    try {
      const data = await this.postJson('embeddings', {
        model: this.embeddingModelName,
        prompt: text,
      });
      return Array.isArray(data.embedding) ? data.embedding : [];
    } catch (e) {
      console.error('Erreur embeddings directs Ollama:', e);
      return [];
    }
  }

  /**
   * Generates title suggestions for `prompt`, asking for a JSON array of `{ title, confidence }`.
   * Best-effort: any failure is logged and yields `[]`.
   */
  async generateTitles(prompt: string): Promise<TitleSuggestion[]> {
    try {
      const data = await this.postJson('generate', {
        model: this.modelName,
        prompt: `${prompt}\n\nRéponds UNIQUEMENT sous forme de tableau JSON : [{"title": "string", "confidence": number}]`,
        stream: false,
      });
      // Extract the JSON array from the reply; anything unparsable yields no titles.
      const titles = parseJsonArray(asText(data.response).match(JSON_OBJECT_ARRAY_PATTERN)?.[0]) ?? [];
      // NOTE(review): entries are not validated against TitleSuggestion.
      return titles as TitleSuggestion[];
    } catch (e) {
      console.error('Erreur génération titres Ollama:', e);
      return [];
    }
  }

  /**
   * Single-shot completion; returns the trimmed reply text.
   * @throws Error on HTTP failure or when the reply has no `response` string (logged first).
   */
  async generateText(prompt: string): Promise<string> {
    try {
      const data = await this.postJson('generate', {
        model: this.modelName,
        prompt: prompt,
        stream: false,
      });
      const reply = data.response;
      if (typeof reply !== 'string') {
        throw new Error('Ollama error: reply has no "response" text');
      }
      return reply.trim();
    } catch (e) {
      console.error('Erreur génération texte Ollama:', e);
      throw e;
    }
  }

  /**
   * Multi-turn chat via /api/chat. Only `role` and `content` of each message are forwarded;
   * `systemPrompt`, when given, is sent as a leading system message.
   * @returns `{ text }` with the trimmed assistant content (`''` when absent).
   * @throws Error on HTTP failure (logged first).
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- message shape is defined by AIProvider callers
  async chat(messages: any[], systemPrompt?: string): Promise<{ text: string }> {
    try {
      const ollamaMessages = [
        ...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
        ...messages.map((m) => ({ role: m.role, content: m.content })),
      ];
      const data = await this.postJson('chat', {
        model: this.modelName,
        messages: ollamaMessages,
        stream: false,
      });
      const message = data.message;
      const content = isRecord(message) ? message.content : undefined;
      return { text: typeof content === 'string' ? content.trim() : '' };
    } catch (e) {
      console.error('Erreur chat Ollama:', e);
      throw e;
    }
  }

  /** Returns the AI SDK chat model bound to the OpenAI-compatible /v1 endpoint. */
  getModel() {
    return this.model;
  }

  /**
   * Runs an AI SDK tool-use loop, stopping after `maxSteps` steps (default 10).
   * `messages` takes precedence over `prompt` when both are supplied.
   */
  async generateWithTools(options: ToolUseOptions): Promise<ToolCallResult> {
    const { tools, maxSteps = 10, systemPrompt, messages, prompt } = options;
    const opts: Record<string, unknown> = {
      model: this.model,
      tools,
      stopWhen: stepCountIs(maxSteps),
    };
    if (systemPrompt) opts.system = systemPrompt;
    if (messages) opts.messages = messages;
    else if (prompt) opts.prompt = prompt;

    // The SDK's options type treats prompt/messages as alternatives, which a dynamically
    // assembled object cannot express statically.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const result = await aiGenerateText(opts as any);
    return {
      toolCalls: mapToolCalls(result.toolCalls),
      toolResults: mapToolResults(result.toolResults),
      text: result.text,
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      steps: result.steps?.map((step: any) => ({
        text: step.text,
        toolCalls: mapToolCalls(step.toolCalls),
        toolResults: mapToolResults(step.toolResults),
      })) ?? [],
    };
  }
}