From ab75027b2211dba66d8f1549a50a372168e1274d Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 13 May 2025 23:59:00 +0800 Subject: [PATCH 1/5] Remove deprecated demo code --- examples/lightrag_ollama_age_demo.py | 113 --------------- examples/lightrag_siliconcloud_demo.py | 103 ------------- .../lightrag_siliconcloud_track_token_demo.py | 110 -------------- examples/lightrag_tidb_demo.py | 116 --------------- examples/lightrag_tongyi_openai_demo.py | 136 ------------------ examples/lightrag_zhipu_demo.py | 80 ----------- examples/lightrag_zhipu_postgres_demo.py | 109 -------------- 7 files changed, 767 deletions(-) delete mode 100644 examples/lightrag_ollama_age_demo.py delete mode 100644 examples/lightrag_siliconcloud_demo.py delete mode 100644 examples/lightrag_siliconcloud_track_token_demo.py delete mode 100644 examples/lightrag_tidb_demo.py delete mode 100644 examples/lightrag_tongyi_openai_demo.py delete mode 100644 examples/lightrag_zhipu_demo.py delete mode 100644 examples/lightrag_zhipu_postgres_demo.py diff --git a/examples/lightrag_ollama_age_demo.py b/examples/lightrag_ollama_age_demo.py deleted file mode 100644 index 0e1b441e..00000000 --- a/examples/lightrag_ollama_age_demo.py +++ /dev/null @@ -1,113 +0,0 @@ -import asyncio -import nest_asyncio - -import inspect -import logging -import os - -from lightrag import LightRAG, QueryParam -from lightrag.llm.ollama import ollama_embed, ollama_model_complete -from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status - -nest_asyncio.apply() - -WORKING_DIR = "./dickens_age" - -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -# AGE -os.environ["AGE_POSTGRES_DB"] = "postgresDB" -os.environ["AGE_POSTGRES_USER"] = "postgresUser" -os.environ["AGE_POSTGRES_PASSWORD"] = "postgresPW" -os.environ["AGE_POSTGRES_HOST"] = "localhost" -os.environ["AGE_POSTGRES_PORT"] = "5455" -os.environ["AGE_GRAPH_NAME"] = "dickens" - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=ollama_model_complete, - llm_model_name="llama3.1:8b", - llm_model_max_async=4, - llm_model_max_token_size=32768, - llm_model_kwargs={ - "host": "http://localhost:11434", - "options": {"num_ctx": 32768}, - }, - embedding_func=EmbeddingFunc( - embedding_dim=768, - max_token_size=8192, - func=lambda texts: ollama_embed( - texts, embed_model="nomic-embed-text", host="http://localhost:11434" - ), - ), - graph_storage="AGEStorage", - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -async def print_stream(stream): - async for chunk in stream: - print(chunk, end="", flush=True) - - -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) - - # Insert example text - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - # Test different query modes - print("\nNaive Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - print("\nLocal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - print("\nGlobal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - - print("\nHybrid Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - - # stream response - resp = rag.query( - "What are the top themes in this story?", - param=QueryParam(mode="hybrid", stream=True), - ) - - if inspect.isasyncgen(resp): - asyncio.run(print_stream(resp)) - else: - print(resp) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py deleted file mode 100644 index 7a414aca..00000000 --- a/examples/lightrag_siliconcloud_demo.py +++ /dev/null @@ -1,103 +0,0 @@ -import os -import asyncio -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import openai_complete_if_cache -from lightrag.llm.siliconcloud import siliconcloud_embedding -from lightrag.utils import EmbeddingFunc -import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status - -WORKING_DIR = "./dickens" - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - "Qwen/Qwen2.5-7B-Instruct", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=os.getenv("SILICONFLOW_API_KEY"), - base_url="https://api.siliconflow.cn/v1/", - **kwargs, - ) - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await siliconcloud_embedding( - texts, - model="netease-youdao/bce-embedding-base_v1", - api_key=os.getenv("SILICONFLOW_API_KEY"), - max_token_size=512, - ) - - -# function test -async def test_funcs(): - result = await llm_model_func("How are you?") - print("llm_model_func: ", result) - - result = await embedding_func(["How are you?"]) - print("embedding_func: ", result) - - -asyncio.run(test_funcs()) - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=768, max_token_size=512, func=embedding_func - ), - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) - - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - # Perform naive search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - # Perform local search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - # Perform global search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - - # Perform hybrid search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_siliconcloud_track_token_demo.py b/examples/lightrag_siliconcloud_track_token_demo.py deleted file mode 100644 index d82a30bc..00000000 --- a/examples/lightrag_siliconcloud_track_token_demo.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import asyncio -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import openai_complete_if_cache -from lightrag.llm.siliconcloud import siliconcloud_embedding -from lightrag.utils import EmbeddingFunc -from lightrag.utils import TokenTracker -import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status -from dotenv import load_dotenv - -load_dotenv() - -token_tracker = TokenTracker() -WORKING_DIR = "./dickens" - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - "Qwen/Qwen2.5-7B-Instruct", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=os.getenv("SILICONFLOW_API_KEY"), - base_url="https://api.siliconflow.cn/v1/", - token_tracker=token_tracker, - **kwargs, - ) - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await siliconcloud_embedding( - texts, - model="BAAI/bge-m3", - api_key=os.getenv("SILICONFLOW_API_KEY"), - max_token_size=512, - ) - - -# function test -async def test_funcs(): - # Context Manager Method - with token_tracker: - result = await llm_model_func("How are you?") - print("llm_model_func: ", result) - - -asyncio.run(test_funcs()) - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=1024, max_token_size=512, func=embedding_func - ), - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) - - # Reset tracker before processing queries - token_tracker.reset() - - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - - # Display final token usage after main query - print("Token usage:", token_tracker.get_usage()) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py deleted file mode 100644 index 50eac2ca..00000000 --- a/examples/lightrag_tidb_demo.py +++ /dev/null @@ -1,116 +0,0 @@ -########################################### -# TiDB storage implementation is deprecated -########################################### - -import asyncio -import os - -import numpy as np - -from lightrag import LightRAG, QueryParam -from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache -from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status - -WORKING_DIR = "./dickens" - -# We use SiliconCloud API to call LLM on Oracle Cloud -# More docs here https://docs.siliconflow.cn/introduction -BASE_URL = "https://api.siliconflow.cn/v1/" -APIKEY = "" -CHATMODEL = "" -EMBEDMODEL = "" - -os.environ["TIDB_HOST"] = "" -os.environ["TIDB_PORT"] = "" -os.environ["TIDB_USER"] = "" -os.environ["TIDB_PASSWORD"] = "" -os.environ["TIDB_DATABASE"] = "lightrag" - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - CHATMODEL, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=APIKEY, - base_url=BASE_URL, - **kwargs, - ) - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await siliconcloud_embedding( - texts, - # model=EMBEDMODEL, - api_key=APIKEY, - ) - - -async def get_embedding_dim(): - test_text = ["This is a test sentence."] - embedding = await embedding_func(test_text) - embedding_dim = embedding.shape[1] - return embedding_dim - - -async def initialize_rag(): - # Detect embedding dimension - embedding_dimension = await get_embedding_dim() - print(f"Detected embedding dimension: {embedding_dimension}") - - # Initialize LightRAG - # We use TiDB DB as the KV/vector - rag = LightRAG( - enable_llm_cache=False, - working_dir=WORKING_DIR, - chunk_token_size=512, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=512, - func=embedding_func, - ), - kv_storage="TiDBKVStorage", - vector_storage="TiDBVectorDBStorage", - graph_storage="TiDBGraphStorage", - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -async def main(): - try: - # Initialize RAG instance - rag = await initialize_rag() - - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - # Perform search in different modes - modes = ["naive", "local", "global", "hybrid"] - for mode in modes: - print("=" * 20, mode, "=" * 20) - print( - await rag.aquery( - "What are the top themes in this story?", - param=QueryParam(mode=mode), - ) - ) - print("-" * 100, "\n") - - except Exception as e: - print(f"An error occurred: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/lightrag_tongyi_openai_demo.py b/examples/lightrag_tongyi_openai_demo.py deleted file mode 100644 index f44c287e..00000000 --- a/examples/lightrag_tongyi_openai_demo.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import asyncio -from lightrag import LightRAG, QueryParam -from lightrag.utils import EmbeddingFunc -import numpy as np -from dotenv import load_dotenv -import logging -from openai import OpenAI -from lightrag.kg.shared_storage import initialize_pipeline_status - -logging.basicConfig(level=logging.INFO) - -load_dotenv() - -LLM_MODEL = os.environ.get("LLM_MODEL", "qwen-turbo-latest") -LLM_BINDING_HOST = "https://dashscope.aliyuncs.com/compatible-mode/v1" -LLM_BINDING_API_KEY = os.getenv("LLM_BINDING_API_KEY") - -EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-v3") -EMBEDDING_BINDING_HOST = os.getenv("EMBEDDING_BINDING_HOST", LLM_BINDING_HOST) -EMBEDDING_BINDING_API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY", LLM_BINDING_API_KEY) -EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", 1024)) -EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) -EMBEDDING_MAX_BATCH_SIZE = int(os.environ.get("EMBEDDING_MAX_BATCH_SIZE", 10)) - -print(f"LLM_MODEL: {LLM_MODEL}") -print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") - -WORKING_DIR = "./dickens" - -if os.path.exists(WORKING_DIR): - import shutil - - shutil.rmtree(WORKING_DIR) - -os.mkdir(WORKING_DIR) - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - client = OpenAI( - api_key=LLM_BINDING_API_KEY, - base_url=LLM_BINDING_HOST, - ) - - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - if history_messages: - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - - chat_completion = client.chat.completions.create( - model=LLM_MODEL, - messages=messages, - temperature=kwargs.get("temperature", 0), - top_p=kwargs.get("top_p", 1), - n=kwargs.get("n", 1), - extra_body={"enable_thinking": False}, - ) - return chat_completion.choices[0].message.content - - -async def embedding_func(texts: list[str]) -> np.ndarray: - client = OpenAI( - api_key=EMBEDDING_BINDING_API_KEY, - base_url=EMBEDDING_BINDING_HOST, - ) - - print("##### embedding: texts: %d #####" % len(texts)) - max_batch_size = EMBEDDING_MAX_BATCH_SIZE - embeddings = [] - for i in range(0, len(texts), max_batch_size): - batch = texts[i : i + max_batch_size] - embedding = client.embeddings.create(model=EMBEDDING_MODEL, input=batch) - embeddings += [item.embedding for item in embedding.data] - - return np.array(embeddings) - - -async def test_funcs(): - result = await llm_model_func("How are you?") - print("Resposta do llm_model_func: ", result) - - result = await embedding_func(["How are you?"]) - print("Resultado do embedding_func: ", result.shape) - print("Dimensão da embedding: ", result.shape[1]) - - -asyncio.run(test_funcs()) - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=EMBEDDING_DIM, - max_token_size=EMBEDDING_MAX_TOKEN_SIZE, - func=embedding_func, - ), - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -def main(): - rag = asyncio.run(initialize_rag()) - - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - query_text = "What are the main themes?" - - print("Result (Naive):") - print(rag.query(query_text, param=QueryParam(mode="naive"))) - - print("\nResult (Local):") - print(rag.query(query_text, param=QueryParam(mode="local"))) - - print("\nResult (Global):") - print(rag.query(query_text, param=QueryParam(mode="global"))) - - print("\nResult (Hybrid):") - print(rag.query(query_text, param=QueryParam(mode="hybrid"))) - - print("\nResult (mix):") - print(rag.query(query_text, param=QueryParam(mode="mix"))) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_zhipu_demo.py b/examples/lightrag_zhipu_demo.py deleted file mode 100644 index fdc37c9c..00000000 --- a/examples/lightrag_zhipu_demo.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import logging -import asyncio - - -from lightrag import LightRAG, QueryParam -from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding -from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status - -WORKING_DIR = "./dickens" - -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -api_key = os.environ.get("ZHIPUAI_API_KEY") -if api_key is None: - raise Exception("Please set ZHIPU_API_KEY in your environment") - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=zhipu_complete, - llm_model_name="glm-4-flashx", # Using the most cost/performance balance model, but you can change it here. - llm_model_max_async=4, - llm_model_max_token_size=32768, - embedding_func=EmbeddingFunc( - embedding_dim=2048, # Zhipu embedding-3 dimension - max_token_size=8192, - func=lambda texts: zhipu_embedding(texts), - ), - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) - - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - # Perform naive search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - # Perform local search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - # Perform global search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - - # Perform hybrid search - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py deleted file mode 100644 index e4a20f26..00000000 --- a/examples/lightrag_zhipu_postgres_demo.py +++ /dev/null @@ -1,109 +0,0 @@ -import asyncio -import logging -import os -import time -from dotenv import load_dotenv - -from lightrag import LightRAG, QueryParam -from lightrag.llm.zhipu import zhipu_complete -from lightrag.llm.ollama import ollama_embedding -from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status - -load_dotenv() -ROOT_DIR = os.environ.get("ROOT_DIR") -WORKING_DIR = f"{ROOT_DIR}/dickens-pg" - -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -# AGE -os.environ["AGE_GRAPH_NAME"] = "dickens" - -os.environ["POSTGRES_HOST"] = "localhost" -os.environ["POSTGRES_PORT"] = "15432" -os.environ["POSTGRES_USER"] = "rag" -os.environ["POSTGRES_PASSWORD"] = "rag" -os.environ["POSTGRES_DATABASE"] = "rag" - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=zhipu_complete, - llm_model_name="glm-4-flashx", - llm_model_max_async=4, - llm_model_max_token_size=32768, - enable_llm_cache_for_entity_extract=True, - embedding_func=EmbeddingFunc( - embedding_dim=1024, - max_token_size=8192, - func=lambda texts: ollama_embedding( - texts, embed_model="bge-m3", host="http://localhost:11434" - ), - ), - kv_storage="PGKVStorage", - doc_status_storage="PGDocStatusStorage", - graph_storage="PGGraphStorage", - vector_storage="PGVectorStorage", - auto_manage_storages_states=False, - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -async def main(): - # Initialize RAG instance - rag = await initialize_rag() - - # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c - rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func - - with open(f"{ROOT_DIR}/book.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) - - print("==== Trying to test the rag queries ====") - print("**** Start Naive Query ****") - start_time = time.time() - # Perform naive search - print( - await rag.aquery( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - print(f"Naive Query Time: {time.time() - start_time} seconds") - # Perform local search - print("**** Start Local Query ****") - start_time = time.time() - print( - await rag.aquery( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - print(f"Local Query Time: {time.time() - start_time} seconds") - # Perform global search - print("**** Start Global Query ****") - start_time = time.time() - print( - await rag.aquery( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - print(f"Global Query Time: {time.time() - start_time}") - # Perform hybrid search - print("**** Start Hybrid Query ****") - print( - await rag.aquery( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - print(f"Hybrid Query Time: {time.time() - start_time} seconds") - - -if __name__ == "__main__": - asyncio.run(main()) From aa36894d6e926dfc4d0a95c90dfc4bcee8d9f39d Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 14 May 2025 00:36:38 +0800 Subject: [PATCH 2/5] Remove deprecated demo code --- examples/lightrag_ollama_gremlin_demo.py | 122 ----------------- ...lightrag_ollama_neo4j_milvus_mongo_demo.py | 104 --------------- ..._openai_compatible_demo_embedding_cache.py | 123 ------------------ 3 files changed, 349 deletions(-) delete mode 100644 examples/lightrag_ollama_gremlin_demo.py delete mode 100644 examples/lightrag_ollama_neo4j_milvus_mongo_demo.py delete mode 100644 examples/lightrag_openai_compatible_demo_embedding_cache.py diff --git a/examples/lightrag_ollama_gremlin_demo.py b/examples/lightrag_ollama_gremlin_demo.py deleted file mode 100644 index 7ae62810..00000000 --- a/examples/lightrag_ollama_gremlin_demo.py +++ /dev/null @@ -1,122 +0,0 @@ -############################################## -# Gremlin storage implementation is deprecated -############################################## - -import asyncio -import inspect -import os - -# Uncomment these lines below to filter out somewhat verbose INFO level -# logging prints (the default loglevel is INFO). -# This has to go before the lightrag imports to work, -# which triggers linting errors, so we keep it commented out: -# import logging -# logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN) - -from lightrag import LightRAG, QueryParam -from lightrag.llm.ollama import ollama_embed, ollama_model_complete -from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status - -WORKING_DIR = "./dickens_gremlin" - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -# Gremlin -os.environ["GREMLIN_HOST"] = "localhost" -os.environ["GREMLIN_PORT"] = "8182" -os.environ["GREMLIN_GRAPH"] = "dickens" - -# Creating a non-default source requires manual -# configuration and a restart on the server: use the dafault "g" -os.environ["GREMLIN_TRAVERSE_SOURCE"] = "g" - -# No authorization by default on docker tinkerpop/gremlin-server -os.environ["GREMLIN_USER"] = "" -os.environ["GREMLIN_PASSWORD"] = "" - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=ollama_model_complete, - llm_model_name="llama3.1:8b", - llm_model_max_async=4, - llm_model_max_token_size=32768, - llm_model_kwargs={ - "host": "http://localhost:11434", - "options": {"num_ctx": 32768}, - }, - embedding_func=EmbeddingFunc( - embedding_dim=768, - max_token_size=8192, - func=lambda texts: ollama_embed( - texts, embed_model="nomic-embed-text", host="http://localhost:11434" - ), - ), - graph_storage="GremlinStorage", - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -async def print_stream(stream): - async for chunk in stream: - print(chunk, end="", flush=True) - - -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) - - # Insert example text - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - # Test different query modes - print("\nNaive Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - print("\nLocal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - print("\nGlobal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - - print("\nHybrid Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - - # stream response - resp = rag.query( - "What are the top themes in this story?", - param=QueryParam(mode="hybrid", stream=True), - ) - - if inspect.isasyncgen(resp): - asyncio.run(print_stream(resp)) - else: - print(resp) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py b/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py deleted file mode 100644 index b6cc931c..00000000 --- a/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -from lightrag import LightRAG, QueryParam -from lightrag.llm.ollama import ollama_model_complete, ollama_embed -from lightrag.utils import EmbeddingFunc -import asyncio -import nest_asyncio - -nest_asyncio.apply() -from lightrag.kg.shared_storage import initialize_pipeline_status - -# WorkingDir -ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) -WORKING_DIR = os.path.join(ROOT_DIR, "myKG") -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) -print(f"WorkingDir: {WORKING_DIR}") - -# mongo -os.environ["MONGO_URI"] = "mongodb://root:root@localhost:27017/" -os.environ["MONGO_DATABASE"] = "LightRAG" - -# neo4j -BATCH_SIZE_NODES = 500 -BATCH_SIZE_EDGES = 100 -os.environ["NEO4J_URI"] = "bolt://localhost:7687" -os.environ["NEO4J_USERNAME"] = "neo4j" -os.environ["NEO4J_PASSWORD"] = "neo4j" - -# milvus -os.environ["MILVUS_URI"] = "http://localhost:19530" -os.environ["MILVUS_USER"] = "root" -os.environ["MILVUS_PASSWORD"] = "root" -os.environ["MILVUS_DB_NAME"] = "lightrag" - - -async def initialize_rag(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=ollama_model_complete, - llm_model_name="qwen2.5:14b", - llm_model_max_async=4, - llm_model_max_token_size=32768, - llm_model_kwargs={ - "host": "http://127.0.0.1:11434", - "options": {"num_ctx": 32768}, - }, - embedding_func=EmbeddingFunc( - embedding_dim=1024, - max_token_size=8192, - func=lambda texts: ollama_embed( - texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434" - ), - ), - kv_storage="MongoKVStorage", - graph_storage="Neo4JStorage", - vector_storage="MilvusVectorDBStorage", - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) - - # Insert example text - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - - # Test different query modes - print("\nNaive Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - print("\nLocal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - print("\nGlobal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") - ) - ) - - print("\nHybrid Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") - ) - ) - - -if __name__ == "__main__": - main() diff --git a/examples/lightrag_openai_compatible_demo_embedding_cache.py b/examples/lightrag_openai_compatible_demo_embedding_cache.py deleted file mode 100644 index 4638219f..00000000 --- a/examples/lightrag_openai_compatible_demo_embedding_cache.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import asyncio -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc -import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status - -WORKING_DIR = "./dickens" - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - "solar-mini", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=os.getenv("UPSTAGE_API_KEY"), - base_url="https://api.upstage.ai/v1/solar", - **kwargs, - ) - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await openai_embed( - texts, - model="solar-embedding-1-large-query", - api_key=os.getenv("UPSTAGE_API_KEY"), - base_url="https://api.upstage.ai/v1/solar", - ) - - -async def get_embedding_dim(): - test_text = ["This is a test sentence."] - embedding = await embedding_func(test_text) - embedding_dim = embedding.shape[1] - return embedding_dim - - -# function test -async def test_funcs(): - result = await llm_model_func("How are you?") - print("llm_model_func: ", result) - - result = await embedding_func(["How are you?"]) - print("embedding_func: ", result) - - -# asyncio.run(test_funcs()) - - -async def initialize_rag(): - embedding_dimension = await get_embedding_dim() - print(f"Detected embedding dimension: {embedding_dimension}") - - rag = LightRAG( - working_dir=WORKING_DIR, - embedding_cache_config={ - "enabled": True, - "similarity_threshold": 0.90, - }, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=8192, - func=embedding_func, - ), - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -async def main(): - try: - # Initialize RAG instance - rag = await initialize_rag() - - with open("./book.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) - - # Perform naive search - print( - await rag.aquery( - "What are the top themes in this story?", param=QueryParam(mode="naive") - ) - ) - - # Perform local search - print( - await rag.aquery( - "What are the top themes in this story?", param=QueryParam(mode="local") - ) - ) - - # Perform global search - print( - await rag.aquery( - "What are the top themes in this story?", - param=QueryParam(mode="global"), - ) - ) - - # Perform hybrid search - print( - await rag.aquery( - "What are the top themes in this story?", - param=QueryParam(mode="hybrid"), - ) - ) - except Exception as e: - print(f"An error occurred: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) From b836d02cacebd55b60bf9bbcdce09439fdb99c4e Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 14 May 2025 01:13:03 +0800 Subject: [PATCH 3/5] Optimize Ollama LLM driver --- README-zh.md | 2 +- README.md | 2 +- lightrag/llm/ollama.py | 112 +++++++++++++++++++++++++++-------------- 3 files changed, 75 insertions(+), 41 deletions(-) diff --git a/README-zh.md b/README-zh.md index 66690ee8..5300b2cf 100644 --- a/README-zh.md +++ b/README-zh.md @@ -415,7 +415,7 @@ rag = LightRAG( embedding_func=EmbeddingFunc( embedding_dim=768, max_token_size=8192, - func=lambda texts: ollama_embedding( + func=lambda texts: ollama_embed( texts, embed_model="nomic-embed-text" ) diff --git a/README.md b/README.md index 449880f2..12e18f0d 100644 --- a/README.md +++ b/README.md @@ -447,7 +447,7 @@ rag = LightRAG( embedding_func=EmbeddingFunc( embedding_dim=768, max_token_size=8192, - func=lambda texts: ollama_embedding( + func=lambda texts: ollama_embed( texts, embed_model="nomic-embed-text" ) diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index 21ae9a67..7668be44 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -31,6 +31,7 @@ from lightrag.api import __api_version__ import numpy as np from typing import Union +from lightrag.utils import logger @retry( @@ -52,7 +53,7 @@ async def _ollama_model_if_cache( kwargs.pop("max_tokens", None) # kwargs.pop("response_format", None) # allow json host = kwargs.pop("host", None) - timeout = kwargs.pop("timeout", None) + timeout = kwargs.pop("timeout", None) or 300 # Default timeout 300s kwargs.pop("hashing_kv", None) api_key = kwargs.pop("api_key", None) headers = { @@ -61,32 +62,59 @@ async def _ollama_model_if_cache( } if api_key: headers["Authorization"] = f"Bearer {api_key}" + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) + + try: + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) - response = await ollama_client.chat(model=model, messages=messages, **kwargs) - if stream: - """cannot cache stream response and process reasoning""" + response = await ollama_client.chat(model=model, messages=messages, **kwargs) + if stream: + """cannot cache stream response and process reasoning""" - async def inner(): - async for chunk in response: - yield chunk["message"]["content"] + async def inner(): + try: + async for chunk in response: + yield chunk["message"]["content"] + except Exception as e: + logger.error(f"Error in stream response: {str(e)}") + raise + finally: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client for streaming") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client: {close_error}") - return inner() - else: - model_response = response["message"]["content"] + return inner() + else: + model_response = response["message"]["content"] - """ - If the model also wraps its thoughts in a specific tag, - this information is not needed for the final - response and can simply be trimmed. - """ + """ + If the model also wraps its thoughts in a specific tag, + this information is not needed for the final + response and can simply be trimmed. + """ - return model_response + return model_response + except Exception as e: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after exception") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after exception: {close_error}") + raise e + finally: + if not stream: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client for non-streaming response") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client in finally block: {close_error}") async def ollama_model_complete( @@ -105,19 +133,6 @@ async def ollama_model_complete( ) -async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: - """ - Deprecated in favor of `embed`. - """ - embed_text = [] - ollama_client = ollama.Client(**kwargs) - for text in texts: - data = ollama_client.embeddings(model=embed_model, prompt=text) - embed_text.append(data["embedding"]) - - return embed_text - - async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: api_key = kwargs.pop("api_key", None) headers = { @@ -125,8 +140,27 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: "User-Agent": f"LightRAG/{__api_version__}", } if api_key: - headers["Authorization"] = api_key - kwargs["headers"] = headers - ollama_client = ollama.Client(**kwargs) - data = ollama_client.embed(model=embed_model, input=texts) - return np.array(data["embeddings"]) + headers["Authorization"] = f"Bearer {api_key}" + + host = kwargs.pop("host", None) + timeout = kwargs.pop("timeout", None) or 90 # Default time out 90s + + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) + + try: + data = await ollama_client.embed(model=embed_model, input=texts) + return np.array(data["embeddings"]) + except Exception as e: + logger.error(f"Error in ollama_embed: {str(e)}") + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after exception in embed") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after exception in embed: {close_error}") + raise e + finally: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after embed") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after embed: {close_error}") From 5c9fd9c4d2f2245f028ff2fb2499010bedcb98e7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 14 May 2025 01:14:15 +0800 Subject: [PATCH 4/5] Update Ollama sample code --- examples/lightrag_ollama_demo.py | 219 +++++++++++++++++++++++-------- 1 file changed, 167 insertions(+), 52 deletions(-) diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index 437ace1b..1ce1f2da 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -1,19 +1,84 @@ import asyncio -import nest_asyncio - import os import inspect import logging +import logging.config from lightrag import LightRAG, QueryParam from lightrag.llm.ollama import ollama_model_complete, ollama_embed -from lightrag.utils import EmbeddingFunc +from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug from lightrag.kg.shared_storage import initialize_pipeline_status -nest_asyncio.apply() +from dotenv import load_dotenv + +load_dotenv(dotenv_path=".env", override=False) WORKING_DIR = "./dickens" -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + +def configure_logging(): + """Configure logging for the application""" + + # Reset any existing handlers to ensure clean configuration + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + + # Get log directory path from environment variable or use current directory + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath( + os.path.join(log_dir, "lightrag_ollama_demo.log") + ) + + print(f"\nLightRAG compatible demo log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", + }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + + # Set the logger level to INFO + logger.setLevel(logging.INFO) + # Enable verbose debug if needed + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true") + if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -23,18 +88,20 @@ async def initialize_rag(): rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=ollama_model_complete, - llm_model_name="gemma2:2b", - llm_model_max_async=4, - llm_model_max_token_size=32768, + llm_model_name=os.getenv("LLM_MODEL", "qwen2.5-coder:7b"), + llm_model_max_token_size=8192, llm_model_kwargs={ - "host": "http://localhost:11434", - "options": {"num_ctx": 32768}, + "host": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), + "options": {"num_ctx": 8192}, + "timeout": int(os.getenv("TIMEOUT", "300")), }, embedding_func=EmbeddingFunc( - embedding_dim=768, - max_token_size=8192, + embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")), + max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")), func=lambda texts: ollama_embed( - texts, embed_model="nomic-embed-text", host="http://localhost:11434" + texts, + embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"), + host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"), ), ), ) @@ -50,54 +117,102 @@ async def print_stream(stream): print(chunk, end="", flush=True) -def main(): - # Initialize RAG instance - rag = asyncio.run(initialize_rag()) +async def main(): + try: + # Clear old data files + files_to_delete = [ + "graph_chunk_entity_relation.graphml", + "kv_store_doc_status.json", + "kv_store_full_docs.json", + "kv_store_text_chunks.json", + "vdb_chunks.json", + "vdb_entities.json", + "vdb_relationships.json", + ] - # Insert example text - with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) + for file in files_to_delete: + file_path = os.path.join(WORKING_DIR, file) + if os.path.exists(file_path): + os.remove(file_path) + print(f"Deleting old file:: {file_path}") - # Test different query modes - print("\nNaive Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") + # Initialize RAG instance + rag = await initialize_rag() + + # Test embedding function + test_text = ["This is a test string for embedding."] + embedding = await rag.embedding_func(test_text) + embedding_dim = embedding.shape[1] + print("\n=======================") + print("Test embedding function") + print("========================") + print(f"Test dict: {test_text}") + print(f"Detected embedding dimension: {embedding_dim}\n\n") + + with open("./book.txt", "r", encoding="utf-8") as f: + await rag.ainsert(f.read()) + + # Perform naive search + print("\n=====================") + print("Query mode: naive") + print("=====================") + resp = await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode="naive", stream=True), ) - ) + if inspect.isasyncgen(resp): + await print_stream(resp) + else: + print(resp) - print("\nLocal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") + # Perform local search + print("\n=====================") + print("Query mode: local") + print("=====================") + resp = await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode="local", stream=True), ) - ) + if inspect.isasyncgen(resp): + await print_stream(resp) + else: + print(resp) - print("\nGlobal Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") + # Perform global search + print("\n=====================") + print("Query mode: global") + print("=====================") + resp = await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode="global", stream=True), ) - ) + if inspect.isasyncgen(resp): + await print_stream(resp) + else: + print(resp) - print("\nHybrid Search:") - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") + # Perform hybrid search + print("\n=====================") + print("Query mode: hybrid") + print("=====================") + resp = await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid", stream=True), ) - ) - - # stream response - resp = rag.query( - "What are the top themes in this story?", - param=QueryParam(mode="hybrid", stream=True), - ) - - if inspect.isasyncgen(resp): - asyncio.run(print_stream(resp)) - else: - print(resp) + if inspect.isasyncgen(resp): + await print_stream(resp) + else: + print(resp) + except Exception as e: + print(f"An error occurred: {e}") + finally: + if rag: + await rag.llm_response_cache.index_done_callback() + await rag.finalize_storages() if __name__ == "__main__": - main() + # Configure logging before running the main function + configure_logging() + asyncio.run(main()) + print("\nDone!") From 0e26cbebd0c86642e9f87c76ec5ccb016ecf9327 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 14 May 2025 01:14:45 +0800 Subject: [PATCH 5/5] Fix linting --- examples/lightrag_ollama_demo.py | 5 ++--- lightrag/llm/ollama.py | 26 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index 1ce1f2da..b012f685 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -26,9 +26,7 @@ def configure_logging(): # Get log directory path from environment variable or use current directory log_dir = os.getenv("LOG_DIR", os.getcwd()) - log_file_path = os.path.abspath( - os.path.join(log_dir, "lightrag_ollama_demo.log") - ) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_ollama_demo.log")) print(f"\nLightRAG compatible demo log file: {log_file_path}\n") os.makedirs(os.path.dirname(log_file_path), exist_ok=True) @@ -211,6 +209,7 @@ async def main(): await rag.llm_response_cache.index_done_callback() await rag.finalize_storages() + if __name__ == "__main__": # Configure logging before running the main function configure_logging() diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index 7668be44..3cf10511 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -62,9 +62,9 @@ async def _ollama_model_if_cache( } if api_key: headers["Authorization"] = f"Bearer {api_key}" - + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) - + try: messages = [] if system_prompt: @@ -106,15 +106,21 @@ async def _ollama_model_if_cache( await ollama_client._client.aclose() logger.debug("Successfully closed Ollama client after exception") except Exception as close_error: - logger.warning(f"Failed to close Ollama client after exception: {close_error}") + logger.warning( + f"Failed to close Ollama client after exception: {close_error}" + ) raise e finally: if not stream: try: await ollama_client._client.aclose() - logger.debug("Successfully closed Ollama client for non-streaming response") + logger.debug( + "Successfully closed Ollama client for non-streaming response" + ) except Exception as close_error: - logger.warning(f"Failed to close Ollama client in finally block: {close_error}") + logger.warning( + f"Failed to close Ollama client in finally block: {close_error}" + ) async def ollama_model_complete( @@ -141,12 +147,12 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: } if api_key: headers["Authorization"] = f"Bearer {api_key}" - + host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) or 90 # Default time out 90s - + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) - + try: data = await ollama_client.embed(model=embed_model, input=texts) return np.array(data["embeddings"]) @@ -156,7 +162,9 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: await ollama_client._client.aclose() logger.debug("Successfully closed Ollama client after exception in embed") except Exception as close_error: - logger.warning(f"Failed to close Ollama client after exception in embed: {close_error}") + logger.warning( + f"Failed to close Ollama client after exception in embed: {close_error}" + ) raise e finally: try: