From 5c9fd9c4d2f2245f028ff2fb2499010bedcb98e7 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Wed, 14 May 2025 01:14:15 +0800
Subject: [PATCH] Update Ollama sample code

---
 examples/lightrag_ollama_demo.py | 220 +++++++++++++++++++++++--------
 1 file changed, 168 insertions(+), 52 deletions(-)

diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index 437ace1b..1ce1f2da 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -1,19 +1,84 @@
 import asyncio
-import nest_asyncio
-
 import os
 import inspect
 import logging
+import logging.config
 from lightrag import LightRAG, QueryParam
 from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-from lightrag.utils import EmbeddingFunc
+from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
 from lightrag.kg.shared_storage import initialize_pipeline_status
 
-nest_asyncio.apply()
+from dotenv import load_dotenv
+
+load_dotenv(dotenv_path=".env", override=False)
 
 WORKING_DIR = "./dickens"
 
-logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+def configure_logging():
+    """Configure logging for the application"""
+
+    # Reset any existing handlers to ensure clean configuration
+    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
+        logger_instance = logging.getLogger(logger_name)
+        logger_instance.handlers = []
+        logger_instance.filters = []
+
+    # Get log directory path from environment variable or use current directory
+    log_dir = os.getenv("LOG_DIR", os.getcwd())
+    log_file_path = os.path.abspath(
+        os.path.join(log_dir, "lightrag_ollama_demo.log")
+    )
+
+    print(f"\nLightRAG demo log file: {log_file_path}\n")
+    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+
+    # Get log file max size and backup count from environment variables
+    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
+
+    logging.config.dictConfig(
+        {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "default": {
+                    "format": "%(levelname)s: %(message)s",
+                },
+                "detailed": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                },
+            },
+            "handlers": {
+                "console": {
+                    "formatter": "default",
+                    "class": "logging.StreamHandler",
+                    "stream": "ext://sys.stderr",
+                },
+                "file": {
+                    "formatter": "detailed",
+                    "class": "logging.handlers.RotatingFileHandler",
+                    "filename": log_file_path,
+                    "maxBytes": log_max_bytes,
+                    "backupCount": log_backup_count,
+                    "encoding": "utf-8",
+                },
+            },
+            "loggers": {
+                "lightrag": {
+                    "handlers": ["console", "file"],
+                    "level": "INFO",
+                    "propagate": False,
+                },
+            },
+        }
+    )
+
+    # Set the logger level to INFO
+    logger.setLevel(logging.INFO)
+    # Enable verbose debug if needed
+    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
+
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -23,18 +88,20 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         llm_model_func=ollama_model_complete,
-        llm_model_name="gemma2:2b",
-        llm_model_max_async=4,
-        llm_model_max_token_size=32768,
+        llm_model_name=os.getenv("LLM_MODEL", "qwen2.5-coder:7b"),
+        llm_model_max_token_size=8192,
         llm_model_kwargs={
-            "host": "http://localhost:11434",
-            "options": {"num_ctx": 32768},
+            "host": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
+            "options": {"num_ctx": 8192},
+            "timeout": int(os.getenv("TIMEOUT", "300")),
         },
         embedding_func=EmbeddingFunc(
-            embedding_dim=768,
-            max_token_size=8192,
+            embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
+            max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
             func=lambda texts: ollama_embed(
-                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
+                texts,
+                embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
+                host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
             ),
         ),
     )
@@ -50,54 +117,103 @@ async def print_stream(stream):
         print(chunk, end="", flush=True)
 
 
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
+async def main():
+    rag = None  # defined up front so the finally block can check it safely
+    try:
+        # Clear old data files
+        files_to_delete = [
+            "graph_chunk_entity_relation.graphml",
+            "kv_store_doc_status.json",
+            "kv_store_full_docs.json",
+            "kv_store_text_chunks.json",
+            "vdb_chunks.json",
+            "vdb_entities.json",
+            "vdb_relationships.json",
+        ]
 
-    # Insert example text
-    with open("./book.txt", "r", encoding="utf-8") as f:
-        rag.insert(f.read())
+        for file in files_to_delete:
+            file_path = os.path.join(WORKING_DIR, file)
+            if os.path.exists(file_path):
+                os.remove(file_path)
+                print(f"Deleting old file: {file_path}")
 
-    # Test different query modes
-    print("\nNaive Search:")
-    print(
-        rag.query(
-            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        # Initialize RAG instance
+        rag = await initialize_rag()
+
+        # Test embedding function
+        test_text = ["This is a test string for embedding."]
+        embedding = await rag.embedding_func(test_text)
+        embedding_dim = embedding.shape[1]
+        print("\n=======================")
+        print("Test embedding function")
+        print("=======================")
+        print(f"Test text: {test_text}")
+        print(f"Detected embedding dimension: {embedding_dim}\n\n")
+
+        with open("./book.txt", "r", encoding="utf-8") as f:
+            await rag.ainsert(f.read())
+
+        # Perform naive search
+        print("\n=====================")
+        print("Query mode: naive")
+        print("=====================")
+        resp = await rag.aquery(
+            "What are the top themes in this story?",
+            param=QueryParam(mode="naive", stream=True),
         )
-    )
+        if inspect.isasyncgen(resp):
+            await print_stream(resp)
+        else:
+            print(resp)
 
-    print("\nLocal Search:")
-    print(
-        rag.query(
-            "What are the top themes in this story?", param=QueryParam(mode="local")
+        # Perform local search
+        print("\n=====================")
+        print("Query mode: local")
+        print("=====================")
+        resp = await rag.aquery(
+            "What are the top themes in this story?",
+            param=QueryParam(mode="local", stream=True),
        )
-    )
+        if inspect.isasyncgen(resp):
+            await print_stream(resp)
+        else:
+            print(resp)
 
-    print("\nGlobal Search:")
-    print(
-        rag.query(
-            "What are the top themes in this story?", param=QueryParam(mode="global")
+        # Perform global search
+        print("\n=====================")
+        print("Query mode: global")
+        print("=====================")
+        resp = await rag.aquery(
+            "What are the top themes in this story?",
+            param=QueryParam(mode="global", stream=True),
         )
-    )
+        if inspect.isasyncgen(resp):
+            await print_stream(resp)
+        else:
+            print(resp)
 
-    print("\nHybrid Search:")
-    print(
-        rag.query(
-            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        # Perform hybrid search
+        print("\n=====================")
+        print("Query mode: hybrid")
+        print("=====================")
+        resp = await rag.aquery(
+            "What are the top themes in this story?",
+            param=QueryParam(mode="hybrid", stream=True),
         )
-    )
-
-    # stream response
-    resp = rag.query(
-        "What are the top themes in this story?",
-        param=QueryParam(mode="hybrid", stream=True),
-    )
-
-    if inspect.isasyncgen(resp):
-        asyncio.run(print_stream(resp))
-    else:
-        print(resp)
+        if inspect.isasyncgen(resp):
+            await print_stream(resp)
+        else:
+            print(resp)
+    except Exception as e:
+        print(f"An error occurred: {e}")
+    finally:
+        if rag:
+            await rag.llm_response_cache.index_done_callback()
+            await rag.finalize_storages()
 
 
 if __name__ == "__main__":
-    main()
+    # Configure logging before running the main function
+    configure_logging()
+    asyncio.run(main())
+    print("\nDone!")
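
Note on configuration (not part of the diff): every setting the demo now uses
goes through os.getenv() with a fallback, so the script runs unchanged with no
.env file at all. A minimal .env sketch, using only variable names that appear
in the diff above and the script's own default values (adjust to taste):

    LLM_MODEL=qwen2.5-coder:7b
    LLM_BINDING_HOST=http://localhost:11434
    TIMEOUT=300
    EMBEDDING_MODEL=bge-m3:latest
    EMBEDDING_BINDING_HOST=http://localhost:11434
    EMBEDDING_DIM=1024
    MAX_EMBED_TOKENS=8192
    LOG_DIR=.
    LOG_MAX_BYTES=10485760
    LOG_BACKUP_COUNT=5
    VERBOSE_DEBUG=false

Since load_dotenv() is called with override=False, variables already exported
in the shell take precedence over values in the file.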
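To try the patched demo end to end, a rough command sequence (assumptions: a
local Ollama server on the default port, the default model names above, and
./book.txt being any UTF-8 text file in the directory you run from):

    pip install python-dotenv        # for the new load_dotenv() import, if not already installed
    ollama pull qwen2.5-coder:7b     # default LLM_MODEL
    ollama pull bge-m3               # default EMBEDDING_MODEL
    python examples/lightrag_ollama_demo.py

The script deletes any previous demo output in ./dickens, re-ingests book.txt,
then streams answers for the naive, local, global, and hybrid query modes.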