import gradio as gr import base64 from io import BytesIO from PIL import Image import pandas as pd import traceback import threading import queue import time from rag_chatbot import MultimodalRAGChatbot from langchain.callbacks.base import BaseCallbackHandler # Handler personnalisé pour capturer les tokens en streaming class GradioStreamingHandler(BaseCallbackHandler): def __init__(self): self.tokens_queue = queue.Queue() self.full_text = "" def on_llm_new_token(self, token, **kwargs): self.tokens_queue.put(token) self.full_text += token # Fonction pour créer un objet Image à partir des données base64 def base64_to_image(base64_data): """Convertit une image base64 en objet Image pour l'affichage direct""" try: if not base64_data: return None image_bytes = base64.b64decode(base64_data) image = Image.open(BytesIO(image_bytes)) return image except Exception as e: print(f"Erreur lors de la conversion d'image: {e}") return None # Configuration pour initialiser le chatbot qdrant_url = "http://localhost:6333" qdrant_collection_name = "my_custom_collection" embedding_model = "mxbai-embed-large" ollama_url = "http://127.0.0.1:11434" default_model = "llama3.2" # Liste des modèles disponibles AVAILABLE_MODELS = ["llama3.1", "llama3.2","deepseek-r1:7b", "deepseek-r1:14b"] # Mapping des langues pour une meilleure compréhension par le LLM LANGUAGE_MAPPING = { "Français": "français", "English": "English", "Español": "español", "Deutsch": "Deutsch", "Italiano": "italiano", "中文": "Chinese", "日本語": "Japanese", "العربية": "Arabic" } # Initialiser le chatbot RAG avec le modèle par défaut rag_bot = MultimodalRAGChatbot( qdrant_url=qdrant_url, qdrant_collection_name=qdrant_collection_name, ollama_model=default_model, embedding_model=embedding_model, ollama_url=ollama_url ) print(f"Chatbot initialisé avec modèle: {default_model}") # Variables globales pour stocker les images et tableaux de la dernière requête current_images = [] current_tables = [] # Fonction pour changer de modèle def change_model(model_name): global rag_bot try: # Réinitialiser le chatbot avec le nouveau modèle rag_bot = MultimodalRAGChatbot( qdrant_url=qdrant_url, qdrant_collection_name=qdrant_collection_name, ollama_model=model_name, embedding_model=embedding_model, ollama_url=ollama_url ) print(f"Modèle changé pour: {model_name}") return f"✅ Modèle changé pour: {model_name}" except Exception as e: print(f"Erreur lors du changement de modèle: {e}") return f"❌ Erreur: {str(e)}" # Fonction pour changer de collection def change_collection(collection_name): global rag_bot, qdrant_collection_name try: # Mise à jour de la variable globale qdrant_collection_name = collection_name # Réinitialiser le chatbot avec la nouvelle collection rag_bot = MultimodalRAGChatbot( qdrant_url=qdrant_url, qdrant_collection_name=collection_name, ollama_model=rag_bot.llm.model, # Conserver le modèle actuel embedding_model=embedding_model, ollama_url=ollama_url ) print(f"Collection changée pour: {collection_name}") return f"✅ Collection changée pour: {collection_name}" except Exception as e: print(f"Erreur lors du changement de collection: {e}") return f"❌ Erreur: {str(e)}" # Fonction de traitement des requêtes avec support du streaming dans Gradio def process_query(message, history, streaming, show_sources, max_images, language): global current_images, current_tables if not message.strip(): return history, "", None, None current_images = [] current_tables = [] try: if streaming: # Version avec streaming dans Gradio history = history + [(message, "")] # 1. Récupérer les documents pertinents docs = rag_bot._retrieve_relevant_documents(message) # 2. Préparer le contexte et l'historique context = rag_bot._format_documents(docs) history_text = rag_bot._format_chat_history() # 3. Préparer le prompt from langchain.prompts import ChatPromptTemplate prompt_template = ChatPromptTemplate.from_template(""" Tu es un assistant documentaire spécialisé qui utilise toutes les informations disponibles dans le contexte fourni. TRÈS IMPORTANT: Tu dois répondre EXCLUSIVEMENT en {language}. Ne réponds JAMAIS dans une autre langue. Instructions spécifiques: 1. Pour chaque image mentionnée dans le contexte, inclue TOUJOURS dans ta réponse: - La légende/caption exacte de l'image - La source et le numéro de page - Une description brève de ce qu'elle montre 2. Pour chaque tableau mentionné dans le contexte, inclue TOUJOURS: - Le titre/caption exact du tableau - La source et le numéro de page - Ce que contient et signifie le tableau 3. Lorsque tu cites des équations mathématiques: - Utilise la syntaxe LaTeX exacte comme dans le document ($...$ ou $$...$$) - Reproduis-les fidèlement sans modification 4. IMPORTANT: Ne pas inventer d'informations - si une donnée n'est pas explicitement fournie dans le contexte, indique clairement que cette information n'est pas disponible dans les documents fournis. 5. Cite précisément les sources pour chaque élément d'information (format: [Source, Page]). 6. CRUCIAL: Ta réponse doit être UNIQUEMENT et INTÉGRALEMENT en {language}, quelle que soit la langue de la question. Historique de conversation: {chat_history} Contexte (à utiliser pour répondre): {context} Question: {question} Réponds de façon structurée et précise en intégrant activement les images, tableaux et équations disponibles dans le contexte. Ta réponse doit être exclusivement en {language}. """) # 4. Formater les messages pour le LLM messages = prompt_template.format_messages( chat_history=history_text, context=context, question=message, language=LANGUAGE_MAPPING.get(language, "français") # Use the mapped language value ) # 5. Créer un handler de streaming personnalisé from langchain_ollama import ChatOllama handler = GradioStreamingHandler() # 6. Créer un modèle LLM avec notre handler streaming_llm = ChatOllama( model=rag_bot.llm.model, base_url=rag_bot.llm.base_url, streaming=True, callbacks=[handler] ) # 7. Lancer la génération dans un thread pour ne pas bloquer l'UI def generate_response(): streaming_llm.invoke(messages) thread = threading.Thread(target=generate_response) thread.start() # 8. Récupérer les tokens et mettre à jour l'interface partial_response = "" # Attendre les tokens avec un timeout while thread.is_alive() or not handler.tokens_queue.empty(): try: token = handler.tokens_queue.get(timeout=0.05) partial_response += token history[-1] = (message, partial_response) yield history, "", None, None except queue.Empty: continue # 9. Thread terminé, mettre à jour l'historique de conversation du chatbot rag_bot.chat_history.append({"role": "user", "content": message}) rag_bot.chat_history.append({"role": "assistant", "content": partial_response}) # 10. Récupérer les sources, images, tableaux texts, images, tables = rag_bot._process_documents(docs) # Préparer les informations sur les sources source_info = "" if texts: source_info += f"📚 {len(texts)} textes • " if images: source_info += f"🖼️ {len(images)} images • " if tables: source_info += f"📊 {len(tables)} tableaux" if source_info: source_info = "Sources trouvées: " + source_info # 11. Traiter les images if show_sources and images: images = images[:max_images] for img in images: img_data = img.get("image_data") if img_data: image = base64_to_image(img_data) if image: current_images.append({ "image": image, "caption": img.get("caption", ""), "source": img.get("source", ""), "page": img.get("page", ""), "description": img.get("description", "") }) # 12. Traiter les tableaux if show_sources and tables: for table in tables: current_tables.append({ "data": rag_bot.format_table(table.get("table_data", "")), "caption": table.get("caption", ""), "source": table.get("source", ""), "page": table.get("page", ""), "description": table.get("description", "") }) # 13. Retourner les résultats finaux yield history, source_info, display_images(), display_tables() else: # Version sans streaming (code existant) result = rag_bot.chat(message, stream=False) history = history + [(message, result["response"])] # Préparer les informations sur les sources source_info = "" if "texts" in result: source_info += f"📚 {len(result['texts'])} textes • " if "images" in result: source_info += f"🖼️ {len(result['images'])} images • " if "tables" in result: source_info += f"📊 {len(result['tables'])} tableaux" if source_info: source_info = "Sources trouvées: " + source_info # Traiter les images et tableaux if show_sources and "images" in result and result["images"]: images = result["images"][:max_images] for img in images: img_data = img.get("image_data") if img_data: image = base64_to_image(img_data) if image: current_images.append({ "image": image, "caption": img.get("caption", ""), "source": img.get("source", ""), "page": img.get("page", ""), "description": img.get("description", "") }) if show_sources and "tables" in result and result["tables"]: tables = result["tables"] for table in tables: current_tables.append({ "data": rag_bot.format_table(table.get("table_data", "")), "caption": table.get("caption", ""), "source": table.get("source", ""), "page": table.get("page", ""), "description": table.get("description", "") }) return history, source_info, display_images(), display_tables() except Exception as e: error_msg = f"Une erreur est survenue: {str(e)}" traceback_text = traceback.format_exc() print(error_msg) print(traceback_text) history = history + [(message, error_msg)] return history, "Erreur lors du traitement de la requête", None, None # Fonctions pour afficher les images et tableaux def display_images(): if not current_images: return None gallery = [] for img_data in current_images: image = img_data["image"] if image: caption = f"{img_data['caption']} (Source: {img_data['source']}, Page: {img_data['page']})" gallery.append((image, caption)) return gallery if gallery else None def display_tables(): if not current_tables: return None html = "" for idx, table in enumerate(current_tables): # Convert raw table data to a proper HTML table table_data = table['data'] table_html = "" # Try to convert the table data to a formatted HTML table try: # If it's a string representation, convert to DataFrame and then to HTML if isinstance(table_data, str): # Try to parse as markdown table or CSV if '|' in table_data: # Clean up the table data - remove extra pipes and spaces rows = table_data.strip().split('\n') table_html = '
| {cell_content} | ' else: table_html += f'{cell_content} | ' table_html += '
|---|
{table_data}'
else:
# For any other format, just use a pre tag
table_html = f'{table_data}'
except Exception as e:
# Fallback if conversion fails
print(f"Error formatting table {idx}: {e}")
table_html = f'{table_data}'
# Create the table container with metadata
html += f"""
Source: {table['source']}, Page: {table['page']}
Description: {table['description']}
{table_html}