Chat_bot_Rag/rag_chatbot.py

286 lines
10 KiB
Python

from typing import Dict, List, Any, Optional
import base64
from io import BytesIO
import pandas as pd
from PIL import Image
# Remplacer les imports dépréciés par les nouveaux packages
from langchain_qdrant import QdrantVectorStore
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from qdrant_client import QdrantClient
class MultimodalRAGChatbot:
"""
Chatbot RAG multimodal qui utilise Qdrant pour stocker les documents
"""
def __init__(
self,
qdrant_url: str = "http://localhost:6333",
qdrant_collection_name: str = "my_documents",
ollama_model: str = "llama3.1",
embedding_model: str = "mxbai-embed-large",
ollama_url: str = "http://localhost:11434" # Ajout de ce paramètre
):
"""
Initialise le chatbot RAG avec Qdrant
"""
# Initialiser le modèle d'embedding
self.embeddings = OllamaEmbeddings(
model=embedding_model,
base_url=ollama_url # Utilisation de l'URL d'Ollama
)
# Créer le client Qdrant
self.client = QdrantClient(url=qdrant_url)
# Se connecter à la collection existante
self.vector_store = QdrantVectorStore(
client=self.client,
collection_name=qdrant_collection_name,
embedding=self.embeddings
)
# Initialiser le retriever
self.retriever = self.vector_store.as_retriever(
search_type="similarity",
search_kwargs={"k": 5}
)
# Initialiser les modèles LLM
self.llm = ChatOllama(
model=ollama_model,
base_url=ollama_url # Utilisation de l'URL d'Ollama
)
self.streaming_llm = ChatOllama(
model=ollama_model,
base_url=ollama_url, # Utilisation de l'URL d'Ollama
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()]
)
# Historique des conversations
self.chat_history = []
print(f"Chatbot initialisé avec modèle: {ollama_model}")
print(f"Utilisant embeddings: {embedding_model}")
print(f"Connecté à Qdrant: {qdrant_url}, collection: {qdrant_collection_name}")
print(f"Ollama URL: {ollama_url}")
def chat(self, query: str, stream: bool = False):
"""
Traite une question de l'utilisateur et retourne une réponse
"""
# 1. Récupérer les documents pertinents
docs = self._retrieve_relevant_documents(query)
# 2. Préparer le contexte à partir des documents
context = self._format_documents(docs)
# 3. Préparer l'historique des conversations
history_text = self._format_chat_history()
# 4. Créer le prompt
prompt_template = ChatPromptTemplate.from_template("""
Tu es un assistant intelligent qui répond aux questions en utilisant uniquement
les informations fournies dans le contexte. Si tu ne trouves pas l'information
dans le contexte, dis simplement que tu ne sais pas. Lorsque tu mentionnes une
image ou un tableau, décris brièvement son contenu en te basant sur les
descriptions fournies.
Historique de conversation:
{chat_history}
Contexte:
{context}
Question de l'utilisateur: {question}
Réponds de façon concise et précise en citant les sources pertinentes.
""")
# 5. Générer la réponse
llm = self.streaming_llm if stream else self.llm
if stream:
print("\nRéponse:")
# Formater les messages pour le LLM
messages = prompt_template.format_messages(
chat_history=history_text,
context=context,
question=query
)
# Appeler le LLM
response = llm.invoke(messages)
answer = response.content
# 6. Mettre à jour l'historique des conversations
self.chat_history.append({"role": "user", "content": query})
self.chat_history.append({"role": "assistant", "content": answer})
# 7. Traiter les documents pour la sortie
texts, images, tables = self._process_documents(docs)
# 8. Préparer la réponse
result = {
"response": answer,
"texts": texts,
"images": images,
"tables": tables
}
return result
def _retrieve_relevant_documents(self, query: str, k: int = 5) -> List[Document]:
"""
Récupère les documents pertinents de la base Qdrant
"""
return self.vector_store.similarity_search(query, k=k)
def _format_documents(self, docs: List[Document]) -> str:
"""
Formate les documents pour le contexte
"""
formatted_docs = []
for i, doc in enumerate(docs):
metadata = doc.metadata
# Déterminer le type de document et le formater en conséquence
if "image_base64" in metadata:
# Image
formatted_docs.append(
f"[IMAGE {i+1}]\n"
f"Source: {metadata.get('source', 'Inconnue')}\n"
f"Page: {metadata.get('page_number', '')}\n"
f"Caption: {metadata.get('caption', '')}\n"
f"Description: {doc.page_content}\n"
)
elif "table_content" in metadata:
# Tableau
formatted_docs.append(
f"[TABLEAU {i+1}]\n"
f"Source: {metadata.get('source', 'Inconnue')}\n"
f"Page: {metadata.get('page_number', '')}\n"
f"Caption: {metadata.get('caption', '')}\n"
f"Description: {doc.page_content}\n"
)
else:
# Texte
formatted_docs.append(
f"[TEXTE {i+1}]\n"
f"Source: {metadata.get('source', 'Inconnue')}\n"
f"Page: {metadata.get('page_number', '')}\n"
f"{doc.page_content}\n"
)
return "\n".join(formatted_docs)
def _format_chat_history(self) -> str:
"""
Formate l'historique des conversations
"""
if not self.chat_history:
return "Pas d'historique de conversation."
formatted_history = []
for message in self.chat_history:
role = "Utilisateur" if message["role"] == "user" else "Assistant"
formatted_history.append(f"{role}: {message['content']}")
return "\n".join(formatted_history)
def _process_documents(self, docs: List[Document]):
"""
Traite les documents pour séparer textes, images et tableaux
"""
texts = []
images = []
tables = []
for doc in docs:
metadata = doc.metadata
# Déterminer le type de document
if "image_base64" in metadata:
# C'est une image
images.append({
"image_data": metadata.get("image_base64", ""),
"description": doc.page_content,
"caption": metadata.get("caption", ""),
"source": metadata.get("source", ""),
"page": metadata.get("page_number", "")
})
elif "table_content" in metadata:
# C'est un tableau
tables.append({
"table_data": metadata.get("table_content", ""),
"description": doc.page_content,
"caption": metadata.get("caption", ""),
"source": metadata.get("source", ""),
"page": metadata.get("page_number", "")
})
else:
# C'est du texte
texts.append({
"content": doc.page_content,
"source": metadata.get("source", ""),
"page": metadata.get("page_number", "")
})
return texts, images, tables
def clear_history(self):
"""
Efface l'historique de conversation
"""
self.chat_history = []
def display_image(self, image_data: str, caption: str = ""):
"""
Affiche une image à partir de sa représentation base64
"""
try:
# Décodage de l'image base64
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))
# Affichage selon l'environnement
try:
from IPython.display import display
print(f"Caption: {caption}")
display(image)
except ImportError:
image.show()
return True
except Exception as e:
print(f"Erreur lors de l'affichage de l'image: {e}")
return False
def format_table(self, table_data: str) -> str:
"""
Formate les données d'un tableau pour l'affichage
"""
try:
# Si format markdown
if isinstance(table_data, str) and table_data.strip().startswith("|"):
return table_data
# Essayer de parser comme JSON
import json
try:
data = json.loads(table_data)
df = pd.DataFrame(data)
return df.to_string(index=False)
except:
# Si échec, retourner les données brutes
return str(table_data)
except Exception as e:
return f"Erreur lors du formatage du tableau: {e}\n{table_data}"