# Chat_bot_Rag/gradio_chatbot.py
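"""Gradio front-end for the multimodal RAG chatbot.

Wires a Gradio Blocks UI to MultimodalRAGChatbot: documents are retrieved
from a Qdrant collection, answers are generated by an Ollama model (with
optional token streaming), and the images, tables and equations cited in
the sources are rendered alongside the chat.

Run with: python gradio_chatbot.py
"""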

import gradio as gr
import base64
from io import BytesIO
from PIL import Image
import pandas as pd
import traceback
import threading
import queue
import time
from rag_chatbot import MultimodalRAGChatbot
from langchain.callbacks.base import BaseCallbackHandler
# Custom handler that captures streamed tokens from the LLM
class GradioStreamingHandler(BaseCallbackHandler):
    def __init__(self):
        self.tokens_queue = queue.Queue()
        self.full_text = ""

    def on_llm_new_token(self, token, **kwargs):
        self.tokens_queue.put(token)
        self.full_text += token
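# LangChain invokes on_llm_new_token() once per generated token when the chat
# model is created with streaming=True and this handler in its callbacks; the
# queue hands tokens from the generation thread to the Gradio generator that
# updates the UI.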
# Build an Image object from base64 data
def base64_to_image(base64_data):
    """Converts a base64-encoded image into a PIL Image for direct display."""
    try:
        if not base64_data:
            return None
        image_bytes = base64.b64decode(base64_data)
        image = Image.open(BytesIO(image_bytes))
        return image
    except Exception as e:
        print(f"Error while converting image: {e}")
        return None
# Configuration used to initialize the chatbot
qdrant_url = "http://localhost:6333"
qdrant_collection_name = "my_custom_collection"
embedding_model = "mxbai-embed-large"
ollama_url = "http://127.0.0.1:11434"
default_model = "llama3.2"
# List of available models
AVAILABLE_MODELS = ["llama3.1", "llama3.2", "deepseek-r1:7b", "deepseek-r1:14b"]
# Map the UI language labels to names the LLM understands in its prompt
LANGUAGE_MAPPING = {
    "Français": "français",
    "English": "English",
    "Español": "español",
    "Deutsch": "Deutsch",
    "Italiano": "italiano",
    "中文": "Chinese",
    "日本語": "Japanese",
    "العربية": "Arabic"
}
# Initialize the RAG chatbot with the default model
rag_bot = MultimodalRAGChatbot(
    qdrant_url=qdrant_url,
    qdrant_collection_name=qdrant_collection_name,
    ollama_model=default_model,
    embedding_model=embedding_model,
    ollama_url=ollama_url
)
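# rag_bot is a module-level singleton: change_model() and change_collection()
# below rebuild it so that subsequent queries pick up the new settings.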
print(f"Chatbot initialisé avec modèle: {default_model}")
# Variables globales pour stocker les images et tableaux de la dernière requête
current_images = []
current_tables = []
# Switch the Ollama model
def change_model(model_name):
    global rag_bot
    try:
        # Re-create the chatbot with the new model
        rag_bot = MultimodalRAGChatbot(
            qdrant_url=qdrant_url,
            qdrant_collection_name=qdrant_collection_name,
            ollama_model=model_name,
            embedding_model=embedding_model,
            ollama_url=ollama_url
        )
        print(f"Model changed to: {model_name}")
        return f"✅ Model changed to: {model_name}"
    except Exception as e:
        print(f"Error while changing model: {e}")
        return f"❌ Error: {str(e)}"
# Switch the Qdrant collection
def change_collection(collection_name):
    global rag_bot, qdrant_collection_name
    try:
        # Update the global variable
        qdrant_collection_name = collection_name
        # Re-create the chatbot with the new collection
        rag_bot = MultimodalRAGChatbot(
            qdrant_url=qdrant_url,
            qdrant_collection_name=collection_name,
            ollama_model=rag_bot.llm.model,  # keep the current model
            embedding_model=embedding_model,
            ollama_url=ollama_url
        )
        print(f"Collection changed to: {collection_name}")
        return f"✅ Collection changed to: {collection_name}"
    except Exception as e:
        print(f"Error while changing collection: {e}")
        return f"❌ Error: {str(e)}"
# Query handler with streaming support in Gradio
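# Because this function contains `yield`, Python treats it as a generator:
# every exit path must yield its outputs (a bare `return value` inside a
# generator is discarded), and Gradio streams each yielded tuple to the UI.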
def process_query(message, history, streaming, show_sources, max_images, language):
    global current_images, current_tables
    if not message.strip():
        yield history, "", None, None
        return
    current_images = []
    current_tables = []
    try:
        if streaming:
            # Streaming path: update the Gradio chat token by token
            history = history + [(message, "")]
            # 1. Retrieve the relevant documents
            docs = rag_bot._retrieve_relevant_documents(message)
            # 2. Prepare the context and the conversation history
            context = rag_bot._format_documents(docs)
            history_text = rag_bot._format_chat_history()
            # 3. Prepare the prompt
            from langchain.prompts import ChatPromptTemplate
            prompt_template = ChatPromptTemplate.from_template("""
You are a specialized document assistant that uses all the information available in the provided context.
VERY IMPORTANT: you must answer EXCLUSIVELY in {language}. NEVER answer in another language.
Specific instructions:
1. For every image mentioned in the context, ALWAYS include in your answer:
   - The exact caption of the image
   - The source and page number
   - A brief description of what it shows
2. For every table mentioned in the context, ALWAYS include:
   - The exact title/caption of the table
   - The source and page number
   - What the table contains and what it means
3. When you quote mathematical equations:
   - Use the exact LaTeX syntax as in the document ($...$ or $$...$$)
   - Reproduce them faithfully, without modification
4. IMPORTANT: do not invent information - if a piece of data is not explicitly provided in the context,
   state clearly that this information is not available in the provided documents.
5. Cite the sources precisely for every piece of information (format: [Source, Page]).
6. CRUCIAL: your answer must be ONLY and ENTIRELY in {language}, whatever the language of the question.
Conversation history:
{chat_history}
Context (to use for the answer):
{context}
Question: {question}
Answer in a structured and precise way, actively integrating the images, tables and equations available in the context.
Your answer must be exclusively in {language}.
""")
            # 4. Format the messages for the LLM
            messages = prompt_template.format_messages(
                chat_history=history_text,
                context=context,
                question=message,
                language=LANGUAGE_MAPPING.get(language, "français")  # use the mapped language name
            )
            # 5. Create a custom streaming handler
            from langchain_ollama import ChatOllama
            handler = GradioStreamingHandler()
            # 6. Create an LLM bound to our handler
            streaming_llm = ChatOllama(
                model=rag_bot.llm.model,
                base_url=rag_bot.llm.base_url,
                streaming=True,
                callbacks=[handler]
            )
            # 7. Run the generation in a thread so it does not block the UI
            def generate_response():
                streaming_llm.invoke(messages)

            thread = threading.Thread(target=generate_response)
            thread.start()
            # 8. Drain the tokens and update the interface
            partial_response = ""
            # Wait for tokens with a timeout so the loop can re-check the thread
            while thread.is_alive() or not handler.tokens_queue.empty():
                try:
                    token = handler.tokens_queue.get(timeout=0.05)
                    partial_response += token
                    history[-1] = (message, partial_response)
                    yield history, "", None, None
                except queue.Empty:
                    continue
            # 9. Thread finished: update the chatbot's conversation history
            rag_bot.chat_history.append({"role": "user", "content": message})
            rag_bot.chat_history.append({"role": "assistant", "content": partial_response})
            # 10. Collect the sources, images and tables
            texts, images, tables = rag_bot._process_documents(docs)
            # Build the source summary
            source_info = ""
            if texts:
                source_info += f"📚 {len(texts)} texts • "
            if images:
                source_info += f"🖼️ {len(images)} images • "
            if tables:
                source_info += f"📊 {len(tables)} tables"
            if source_info:
                source_info = "Sources found: " + source_info
            # 11. Process the images
            if show_sources and images:
                images = images[:max_images]
                for img in images:
                    img_data = img.get("image_data")
                    if img_data:
                        image = base64_to_image(img_data)
                        if image:
                            current_images.append({
                                "image": image,
                                "caption": img.get("caption", ""),
                                "source": img.get("source", ""),
                                "page": img.get("page", ""),
                                "description": img.get("description", "")
                            })
            # 12. Process the tables
            if show_sources and tables:
                for table in tables:
                    current_tables.append({
                        "data": rag_bot.format_table(table.get("table_data", "")),
                        "caption": table.get("caption", ""),
                        "source": table.get("source", ""),
                        "page": table.get("page", ""),
                        "description": table.get("description", "")
                    })
            # 13. Yield the final results
            yield history, source_info, display_images(), display_tables()
        else:
            # Non-streaming path
            result = rag_bot.chat(message, stream=False)
            history = history + [(message, result["response"])]
            # Build the source summary
            source_info = ""
            if "texts" in result:
                source_info += f"📚 {len(result['texts'])} texts • "
            if "images" in result:
                source_info += f"🖼️ {len(result['images'])} images • "
            if "tables" in result:
                source_info += f"📊 {len(result['tables'])} tables"
            if source_info:
                source_info = "Sources found: " + source_info
            # Process the images and tables
            if show_sources and "images" in result and result["images"]:
                images = result["images"][:max_images]
                for img in images:
                    img_data = img.get("image_data")
                    if img_data:
                        image = base64_to_image(img_data)
                        if image:
                            current_images.append({
                                "image": image,
                                "caption": img.get("caption", ""),
                                "source": img.get("source", ""),
                                "page": img.get("page", ""),
                                "description": img.get("description", "")
                            })
            if show_sources and "tables" in result and result["tables"]:
                tables = result["tables"]
                for table in tables:
                    current_tables.append({
                        "data": rag_bot.format_table(table.get("table_data", "")),
                        "caption": table.get("caption", ""),
                        "source": table.get("source", ""),
                        "page": table.get("page", ""),
                        "description": table.get("description", "")
                    })
            # This function is a generator, so yield instead of return
            yield history, source_info, display_images(), display_tables()
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}"
        traceback_text = traceback.format_exc()
        print(error_msg)
        print(traceback_text)
        history = history + [(message, error_msg)]
        yield history, "Error while processing the query", None, None
# Helpers that render the collected images and tables
def display_images():
    if not current_images:
        return None
    gallery = []
    for img_data in current_images:
        image = img_data["image"]
        if image:
            caption = f"{img_data['caption']} (Source: {img_data['source']}, Page: {img_data['page']})"
            gallery.append((image, caption))
    return gallery if gallery else None
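# gr.Gallery accepts a list of (image, caption) tuples, which is exactly what
# display_images() returns; returning None simply clears the component.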
def display_tables():
    if not current_tables:
        return None
    html = ""
    for idx, table in enumerate(current_tables):
        # Convert raw table data to a proper HTML table
        table_data = table['data']
        table_html = ""
        # Try to convert the table data to a formatted HTML table
        try:
            # If it's a string representation, convert to an HTML table
            if isinstance(table_data, str):
                # Try to parse as a markdown table or CSV
                if '|' in table_data:
                    # Clean up the table data - remove extra pipes and spaces
                    rows = table_data.strip().split('\n')
                    table_html = '<div class="table-container"><table>'
                    for i, row in enumerate(rows):
                        # Skip the separator row (---|---) in markdown tables;
                        # allow spaces, which separator rows usually contain
                        if i == 1 and row.strip() and all(c in ':-| ' for c in row):
                            continue
                        # Process each cell
                        cells = row.split('|')
                        # Remove empty cells at the start/end (caused by leading/trailing |)
                        if cells and cells[0].strip() == '':
                            cells = cells[1:]
                        if cells and cells[-1].strip() == '':
                            cells = cells[:-1]
                        # Create the table row
                        if cells:
                            is_header = (i == 0)
                            table_html += '<tr>'
                            for cell in cells:
                                cell_content = cell.strip()
                                if is_header:
                                    table_html += f'<th>{cell_content}</th>'
                                else:
                                    table_html += f'<td>{cell_content}</td>'
                            table_html += '</tr>'
                    table_html += '</table></div>'
                else:
                    # If not pipe-separated, wrap in <pre> for code formatting
                    table_html = f'<pre>{table_data}</pre>'
            else:
                # For any other format, just use a <pre> tag
                table_html = f'<pre>{table_data}</pre>'
        except Exception as e:
            # Fallback if conversion fails
            print(f"Error formatting table {idx}: {e}")
            table_html = f'<pre>{table_data}</pre>'
        # Create the table container with metadata
        html += f"""
        <div style="margin-bottom: 20px; border: 1px solid #ddd; padding: 15px; border-radius: 8px;">
            <h3>{table['caption']}</h3>
            <p style="color:#666; font-size:0.9em;">Source: {table['source']}, Page: {table['page']}</p>
            <p><strong>Description:</strong> {table['description']}</p>
            {table_html}
        </div>
        """
    return html if html else None
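# The string returned above is rendered verbatim by the gr.HTML component; the
# table CSS injected in the <style> block further down should also apply to it.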
# Reset the conversation history and the cached sources
def reset_conversation():
    global current_images, current_tables
    current_images = []
    current_tables = []
    rag_bot.clear_history()
    return [], "", None, None
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# 📚 Intelligent document assistant")
    with gr.Row():
        with gr.Column(scale=2):
            chat_interface = gr.Chatbot(
                height=600,
                show_label=False,
                layout="bubble",
                elem_id="chatbot"
            )
            with gr.Row():
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Ask your question...",
                    container=False,
                    scale=4
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            clear_btn = gr.Button("Clear conversation")
            source_info = gr.Markdown("", elem_id="sources_info")
        with gr.Column(scale=1):
            with gr.Accordion("Options", open=True):
                # Model selector
                model_selector = gr.Dropdown(
                    choices=AVAILABLE_MODELS,
                    value=default_model,
                    label="Ollama model",
                    info="Choose the language model to use"
                )
                model_status = gr.Markdown(f"Current model: **{default_model}**")
                # Language selector
                language_selector = gr.Dropdown(
                    choices=["Français", "English", "Español", "Deutsch", "Italiano", "中文", "日本語", "العربية"],
                    value="Français",
                    label="Answer language",
                    info="Choose the language in which the assistant will answer"
                )
                # Qdrant collection selector
                collection_name_input = gr.Textbox(
                    value=qdrant_collection_name,
                    label="Qdrant collection",
                    info="Name of the document collection to use"
                )
                collection_status = gr.Markdown(f"Current collection: **{qdrant_collection_name}**")
                # Apply collection button
                apply_collection_btn = gr.Button("Apply collection")
                streaming = gr.Checkbox(
                    label="Streaming mode",
                    value=True,
                    info="Watch the answers appear progressively"
                )
                show_sources = gr.Checkbox(label="Show sources", value=True)
                max_images = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Max number of images"
                )
            gr.Markdown("---")
            gr.Markdown("### 🖼️ Relevant images")
            image_gallery = gr.Gallery(
                label="Relevant images",
                show_label=False,
                columns=2,
                height=300,
                object_fit="contain"
            )
            gr.Markdown("### 📊 Tables")
            tables_display = gr.HTML()
    # Wire up model switching
    model_selector.change(
        fn=change_model,
        inputs=model_selector,
        outputs=model_status
    )
    # Wire up collection switching
    apply_collection_btn.click(
        fn=change_collection,
        inputs=collection_name_input,
        outputs=collection_status
    )
    # Wire up the main actions
    msg.submit(
        process_query,
        inputs=[msg, chat_interface, streaming, show_sources, max_images, language_selector],
        outputs=[chat_interface, source_info, image_gallery, tables_display]
    ).then(lambda: "", outputs=msg)
    submit_btn.click(
        process_query,
        inputs=[msg, chat_interface, streaming, show_sources, max_images, language_selector],
        outputs=[chat_interface, source_info, image_gallery, tables_display]
    ).then(lambda: "", outputs=msg)
    clear_btn.click(
        reset_conversation,
        outputs=[chat_interface, source_info, image_gallery, tables_display]
    )
    # Improved support for mathematical equations with KaTeX
    gr.Markdown("""
<style>
.gradio-container {max-width: 1200px !important}
#chatbot {height: 600px; overflow-y: auto;}
#sources_info {margin-top: 10px; color: #666;}
/* Improved styles for equations */
.katex { font-size: 1.1em !important; }
.math-inline { background: #f8f9fa; padding: 2px 5px; border-radius: 4px; }
.math-display { background: #f8f9fa; margin: 10px 0; padding: 10px; border-radius: 5px; overflow-x: auto; text-align: center; }
/* Table styles */
table {
    border-collapse: collapse;
    width: 100%;
    margin: 15px 0;
    font-size: 0.9em;
}
table, th, td {
    border: 1px solid #ddd;
}
th, td {
    padding: 8px 12px;
    text-align: left;
}
th {
    background-color: #f2f2f2;
}
tr:nth-child(even) {
    background-color: #f9f9f9;
}
.table-container {
    overflow-x: auto;
    margin-top: 10px;
}
</style>
<!-- Loading KaTeX -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.8/dist/katex.min.css">
<script src="https://cdn.jsdelivr.net/npm/katex@0.16.8/dist/katex.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/katex@0.16.8/dist/contrib/auto-render.min.js"></script>
<script>
// Render math in an element via KaTeX's auto-render extension.
// NOTE: renamed from renderMathInElement - a top-level declaration with that
// name would overwrite KaTeX's global of the same name and recurse forever.
function renderMath(element) {
    if (!window.renderMathInElement) return;
    try {
        window.renderMathInElement(element, {
            delimiters: [
                {left: '$$', right: '$$', display: true},
                {left: '$', right: '$', display: false},
                {left: '\\\\(', right: '\\\\)', display: false},
                {left: '\\\\[', right: '\\\\]', display: true}
            ],
            throwOnError: false,
            trust: true,
            strict: false,
            // Backslashes are doubled twice so they survive both the Python
            // triple-quoted string and the JavaScript string literal
            macros: {
                "\\\\R": "\\\\mathbb{R}",
                "\\\\N": "\\\\mathbb{N}"
            }
        });
    } catch (e) {
        console.error("KaTeX rendering error:", e);
    }
}
// Fix and prepare text for LaTeX rendering
function prepareTextForLatex(text) {
    if (!text) return text;
    // Don't modify code blocks
    if (text.indexOf('<pre>') !== -1) {
        const parts = text.split(/<pre>|<\/pre>/);
        for (let i = 0; i < parts.length; i++) {
            // Even-indexed parts are the text outside <pre>...</pre> blocks
            if (i % 2 === 0) {
                parts[i] = prepareLatexInText(parts[i]);
            }
        }
        return parts.join('');
    }
    return prepareLatexInText(text);
}
// Helper to process LaTeX in regular text
function prepareLatexInText(text) {
    // Stash well-formed math expressions (display first, then inline) behind
    // placeholders so the spacing fixes below cannot corrupt them
    const saved = [];
    text = text.replace(/(\$\$[^\$]+\$\$|\$[^\$\\n]+\$)/g, function(m) {
        saved.push(m);
        return '%%MATH' + (saved.length - 1) + '%%';
    });
    // Fix common LaTeX formatting issues outside the protected regions
    text = text.replace(/\$(\S)/g, '$ $1');  // add a space after a stray $
    text = text.replace(/(\S)\$/g, '$1 $');  // add a space before a stray $
    // Handle subscripts: transform x_1 into x_{1} for better LaTeX compatibility
    text = text.replace(/([a-zA-Z])_([0-9a-zA-Z])/g, '$1_{$2}');
    // Restore the protected math expressions
    text = text.replace(/%%MATH(\d+)%%/g, function(m, i) {
        return saved[Number(i)];
    });
    return text;
}
// Enhanced message processor for KaTeX rendering
function processMessage(message) {
    if (!message) return;
    try {
        // Get direct textual content when possible
        const elements = message.querySelectorAll('p, li, h1, h2, h3, h4, h5, span');
        elements.forEach(el => {
            const originalText = el.innerHTML;
            const preparedText = prepareTextForLatex(originalText);
            // Only update if changes were made
            if (preparedText !== originalText) {
                el.innerHTML = preparedText;
            }
            // Render equations in this element
            renderMath(el);
        });
        // Also try to render the entire message as a fallback
        renderMath(message);
    } catch (e) {
        console.error("Error processing message for LaTeX:", e);
    }
}
// Monitor the chat for new messages
function setupMathObserver() {
    const chatElement = document.getElementById('chatbot');
    if (!chatElement) {
        setTimeout(setupMathObserver, 500);
        return;
    }
    // Process any existing messages
    chatElement.querySelectorAll('.message').forEach(processMessage);
    // Set up an observer for new content
    const observer = new MutationObserver((mutations) => {
        for (const mutation of mutations) {
            if (mutation.addedNodes.length > 0 || mutation.type === 'characterData') {
                chatElement.querySelectorAll('.message').forEach(processMessage);
                break;
            }
        }
    });
    observer.observe(chatElement, {
        childList: true,
        subtree: true,
        characterData: true
    });
    console.log("LaTeX rendering observer set up successfully");
}
// Initialize once the document is fully loaded
function initializeRendering() {
    if (window.renderMathInElement) {
        setupMathObserver();
    } else {
        // If KaTeX isn't loaded yet, wait for it
        const katexScript = document.querySelector('script[src*="auto-render.min.js"]');
        if (katexScript) {
            katexScript.onload = setupMathObserver;
        } else {
            // Last resort: try again later
            setTimeout(initializeRendering, 500);
        }
    }
}
// Set up multiple trigger points to make sure initialization runs
document.addEventListener('DOMContentLoaded', function() {
    setTimeout(initializeRendering, 800);
});
window.addEventListener('load', function() {
    setTimeout(initializeRendering, 1200);
});
</script>
""")
if __name__ == "__main__":
    demo.queue()
    demo.launch(share=False, inbrowser=True)