286 lines
10 KiB
Python
286 lines
10 KiB
Python
from typing import Dict, List, Any, Optional
|
|
import base64
|
|
from io import BytesIO
|
|
import pandas as pd
|
|
from PIL import Image
|
|
|
|
# Remplacer les imports dépréciés par les nouveaux packages
|
|
from langchain_qdrant import QdrantVectorStore
|
|
from langchain_ollama import OllamaEmbeddings, ChatOllama
|
|
from langchain.prompts import ChatPromptTemplate
|
|
from langchain.schema import Document
|
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
from qdrant_client import QdrantClient
|
|
|
|
class MultimodalRAGChatbot:
|
|
"""
|
|
Chatbot RAG multimodal qui utilise Qdrant pour stocker les documents
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
qdrant_url: str = "http://localhost:6333",
|
|
qdrant_collection_name: str = "my_documents",
|
|
ollama_model: str = "llama3.1",
|
|
embedding_model: str = "mxbai-embed-large",
|
|
ollama_url: str = "http://localhost:11434" # Ajout de ce paramètre
|
|
):
|
|
"""
|
|
Initialise le chatbot RAG avec Qdrant
|
|
"""
|
|
# Initialiser le modèle d'embedding
|
|
self.embeddings = OllamaEmbeddings(
|
|
model=embedding_model,
|
|
base_url=ollama_url # Utilisation de l'URL d'Ollama
|
|
)
|
|
|
|
# Créer le client Qdrant
|
|
self.client = QdrantClient(url=qdrant_url)
|
|
|
|
# Se connecter à la collection existante
|
|
self.vector_store = QdrantVectorStore(
|
|
client=self.client,
|
|
collection_name=qdrant_collection_name,
|
|
embedding=self.embeddings
|
|
)
|
|
|
|
# Initialiser le retriever
|
|
self.retriever = self.vector_store.as_retriever(
|
|
search_type="similarity",
|
|
search_kwargs={"k": 5}
|
|
)
|
|
|
|
# Initialiser les modèles LLM
|
|
self.llm = ChatOllama(
|
|
model=ollama_model,
|
|
base_url=ollama_url # Utilisation de l'URL d'Ollama
|
|
)
|
|
self.streaming_llm = ChatOllama(
|
|
model=ollama_model,
|
|
base_url=ollama_url, # Utilisation de l'URL d'Ollama
|
|
streaming=True,
|
|
callbacks=[StreamingStdOutCallbackHandler()]
|
|
)
|
|
|
|
# Historique des conversations
|
|
self.chat_history = []
|
|
|
|
print(f"Chatbot initialisé avec modèle: {ollama_model}")
|
|
print(f"Utilisant embeddings: {embedding_model}")
|
|
print(f"Connecté à Qdrant: {qdrant_url}, collection: {qdrant_collection_name}")
|
|
print(f"Ollama URL: {ollama_url}")
|
|
|
|
def chat(self, query: str, stream: bool = False):
|
|
"""
|
|
Traite une question de l'utilisateur et retourne une réponse
|
|
"""
|
|
# 1. Récupérer les documents pertinents
|
|
docs = self._retrieve_relevant_documents(query)
|
|
|
|
# 2. Préparer le contexte à partir des documents
|
|
context = self._format_documents(docs)
|
|
|
|
# 3. Préparer l'historique des conversations
|
|
history_text = self._format_chat_history()
|
|
|
|
# 4. Créer le prompt
|
|
prompt_template = ChatPromptTemplate.from_template("""
|
|
Tu es un assistant intelligent qui répond aux questions en utilisant uniquement
|
|
les informations fournies dans le contexte. Si tu ne trouves pas l'information
|
|
dans le contexte, dis simplement que tu ne sais pas. Lorsque tu mentionnes une
|
|
image ou un tableau, décris brièvement son contenu en te basant sur les
|
|
descriptions fournies.
|
|
|
|
Historique de conversation:
|
|
{chat_history}
|
|
|
|
Contexte:
|
|
{context}
|
|
|
|
Question de l'utilisateur: {question}
|
|
|
|
Réponds de façon concise et précise en citant les sources pertinentes.
|
|
""")
|
|
|
|
# 5. Générer la réponse
|
|
llm = self.streaming_llm if stream else self.llm
|
|
|
|
if stream:
|
|
print("\nRéponse:")
|
|
|
|
# Formater les messages pour le LLM
|
|
messages = prompt_template.format_messages(
|
|
chat_history=history_text,
|
|
context=context,
|
|
question=query
|
|
)
|
|
|
|
# Appeler le LLM
|
|
response = llm.invoke(messages)
|
|
answer = response.content
|
|
|
|
# 6. Mettre à jour l'historique des conversations
|
|
self.chat_history.append({"role": "user", "content": query})
|
|
self.chat_history.append({"role": "assistant", "content": answer})
|
|
|
|
# 7. Traiter les documents pour la sortie
|
|
texts, images, tables = self._process_documents(docs)
|
|
|
|
# 8. Préparer la réponse
|
|
result = {
|
|
"response": answer,
|
|
"texts": texts,
|
|
"images": images,
|
|
"tables": tables
|
|
}
|
|
|
|
return result
|
|
|
|
def _retrieve_relevant_documents(self, query: str, k: int = 5) -> List[Document]:
|
|
"""
|
|
Récupère les documents pertinents de la base Qdrant
|
|
"""
|
|
return self.vector_store.similarity_search(query, k=k)
|
|
|
|
def _format_documents(self, docs: List[Document]) -> str:
|
|
"""
|
|
Formate les documents pour le contexte
|
|
"""
|
|
formatted_docs = []
|
|
|
|
for i, doc in enumerate(docs):
|
|
metadata = doc.metadata
|
|
|
|
# Déterminer le type de document et le formater en conséquence
|
|
if "image_base64" in metadata:
|
|
# Image
|
|
formatted_docs.append(
|
|
f"[IMAGE {i+1}]\n"
|
|
f"Source: {metadata.get('source', 'Inconnue')}\n"
|
|
f"Page: {metadata.get('page_number', '')}\n"
|
|
f"Caption: {metadata.get('caption', '')}\n"
|
|
f"Description: {doc.page_content}\n"
|
|
)
|
|
elif "table_content" in metadata:
|
|
# Tableau
|
|
formatted_docs.append(
|
|
f"[TABLEAU {i+1}]\n"
|
|
f"Source: {metadata.get('source', 'Inconnue')}\n"
|
|
f"Page: {metadata.get('page_number', '')}\n"
|
|
f"Caption: {metadata.get('caption', '')}\n"
|
|
f"Description: {doc.page_content}\n"
|
|
)
|
|
else:
|
|
# Texte
|
|
formatted_docs.append(
|
|
f"[TEXTE {i+1}]\n"
|
|
f"Source: {metadata.get('source', 'Inconnue')}\n"
|
|
f"Page: {metadata.get('page_number', '')}\n"
|
|
f"{doc.page_content}\n"
|
|
)
|
|
|
|
return "\n".join(formatted_docs)
|
|
|
|
def _format_chat_history(self) -> str:
|
|
"""
|
|
Formate l'historique des conversations
|
|
"""
|
|
if not self.chat_history:
|
|
return "Pas d'historique de conversation."
|
|
|
|
formatted_history = []
|
|
|
|
for message in self.chat_history:
|
|
role = "Utilisateur" if message["role"] == "user" else "Assistant"
|
|
formatted_history.append(f"{role}: {message['content']}")
|
|
|
|
return "\n".join(formatted_history)
|
|
|
|
def _process_documents(self, docs: List[Document]):
|
|
"""
|
|
Traite les documents pour séparer textes, images et tableaux
|
|
"""
|
|
texts = []
|
|
images = []
|
|
tables = []
|
|
|
|
for doc in docs:
|
|
metadata = doc.metadata
|
|
|
|
# Déterminer le type de document
|
|
if "image_base64" in metadata:
|
|
# C'est une image
|
|
images.append({
|
|
"image_data": metadata.get("image_base64", ""),
|
|
"description": doc.page_content,
|
|
"caption": metadata.get("caption", ""),
|
|
"source": metadata.get("source", ""),
|
|
"page": metadata.get("page_number", "")
|
|
})
|
|
elif "table_content" in metadata:
|
|
# C'est un tableau
|
|
tables.append({
|
|
"table_data": metadata.get("table_content", ""),
|
|
"description": doc.page_content,
|
|
"caption": metadata.get("caption", ""),
|
|
"source": metadata.get("source", ""),
|
|
"page": metadata.get("page_number", "")
|
|
})
|
|
else:
|
|
# C'est du texte
|
|
texts.append({
|
|
"content": doc.page_content,
|
|
"source": metadata.get("source", ""),
|
|
"page": metadata.get("page_number", "")
|
|
})
|
|
|
|
return texts, images, tables
|
|
|
|
def clear_history(self):
|
|
"""
|
|
Efface l'historique de conversation
|
|
"""
|
|
self.chat_history = []
|
|
|
|
def display_image(self, image_data: str, caption: str = ""):
|
|
"""
|
|
Affiche une image à partir de sa représentation base64
|
|
"""
|
|
try:
|
|
# Décodage de l'image base64
|
|
image_bytes = base64.b64decode(image_data)
|
|
image = Image.open(BytesIO(image_bytes))
|
|
|
|
# Affichage selon l'environnement
|
|
try:
|
|
from IPython.display import display
|
|
print(f"Caption: {caption}")
|
|
display(image)
|
|
except ImportError:
|
|
image.show()
|
|
|
|
return True
|
|
except Exception as e:
|
|
print(f"Erreur lors de l'affichage de l'image: {e}")
|
|
return False
|
|
|
|
def format_table(self, table_data: str) -> str:
|
|
"""
|
|
Formate les données d'un tableau pour l'affichage
|
|
"""
|
|
try:
|
|
# Si format markdown
|
|
if isinstance(table_data, str) and table_data.strip().startswith("|"):
|
|
return table_data
|
|
|
|
# Essayer de parser comme JSON
|
|
import json
|
|
try:
|
|
data = json.loads(table_data)
|
|
df = pd.DataFrame(data)
|
|
return df.to_string(index=False)
|
|
except:
|
|
# Si échec, retourner les données brutes
|
|
return str(table_data)
|
|
except Exception as e:
|
|
return f"Erreur lors du formatage du tableau: {e}\n{table_data}" |