From 021152f95c979cdb33d63b5568d35a8f0ec9a784 Mon Sep 17 00:00:00 2001
From: omdivyatej
Date: Sun, 23 Mar 2025 19:58:58 +0530
Subject: [PATCH 1/5] first

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index eb2575e7..95caf1d2 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@

🚀 LightRAG: Simple and Fast Retrieval-Augmented Generation

+hi om
From f8ba98c1ff5f80bcaff2bb53488a0aee1d805c14 Mon Sep 17 00:00:00 2001
From: omdivyatej
Date: Sun, 23 Mar 2025 20:00:25 +0530
Subject: [PATCH 2/5] first

---
 README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/README.md b/README.md
index 95caf1d2..eb2575e7 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,4 @@

🚀 LightRAG: Simple and Fast Retrieval-Augmented Generation

-hi om
From 3522da1b21f2d67740f02a96013a61310be2bff5 Mon Sep 17 00:00:00 2001
From: omdivyatej
Date: Sun, 23 Mar 2025 21:33:49 +0530
Subject: [PATCH 3/5] specify LLM for query

---
 .../lightrag_multi_model_all_modes_demo.py | 93 +++++++++++++++++++
 lightrag/base.py | 7 ++
 lightrag/lightrag.py | 10 +-
 lightrag/operate.py | 10 +-
 4 files changed, 112 insertions(+), 8 deletions(-)
 create mode 100644 examples/lightrag_multi_model_all_modes_demo.py

diff --git a/examples/lightrag_multi_model_all_modes_demo.py b/examples/lightrag_multi_model_all_modes_demo.py
new file mode 100644
index 00000000..04adf642
--- /dev/null
+++ b/examples/lightrag_multi_model_all_modes_demo.py
@@ -0,0 +1,93 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
+from lightrag.kg.shared_storage import initialize_pipeline_status
+from lightrag.utils import setup_logger
+
+setup_logger("lightrag", level="INFO")
+
+WORKING_DIR = "./all_modes_demo"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def initialize_rag():
+    # Initialize LightRAG with a base model (gpt-4o-mini)
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=openai_embed,
+        llm_model_func=gpt_4o_mini_complete, # Default model for most queries
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Load the data
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    # Example query
+    query_text = "What are the main themes in this story?"
+
+    # Demonstrate using default model (gpt-4o-mini) for all modes
+    print("\n===== Default Model (gpt-4o-mini) =====")
+
+    for mode in ["local", "global", "hybrid", "naive", "mix"]:
+        print(f"\n--- {mode.upper()} mode with default model ---")
+        response = rag.query(
+            query_text,
+            param=QueryParam(mode=mode)
+        )
+        print(response)
+
+    # Demonstrate using custom model (gpt-4o) for all modes
+    print("\n===== Custom Model (gpt-4o) =====")
+
+    for mode in ["local", "global", "hybrid", "naive", "mix"]:
+        print(f"\n--- {mode.upper()} mode with custom model ---")
+        response = rag.query(
+            query_text,
+            param=QueryParam(
+                mode=mode,
+                model_func=gpt_4o_complete # Override with more capable model
+            )
+        )
+        print(response)
+
+    # Mixed approach - use different models for different modes
+    print("\n===== Strategic Model Selection =====")
+
+    # Complex analytical question
+    complex_query = "How does the character development in the story reflect Victorian-era social values?"
+
+    # Use default model for simpler modes
+    print("\n--- NAIVE mode with default model (suitable for simple retrieval) ---")
+    response1 = rag.query(
+        complex_query,
+        param=QueryParam(mode="naive") # Use default model for basic retrieval
+    )
+    print(response1)
+
+    # Use more capable model for complex modes
+    print("\n--- HYBRID mode with more capable model (for complex analysis) ---")
+    response2 = rag.query(
+        complex_query,
+        param=QueryParam(
+            mode="hybrid",
+            model_func=gpt_4o_complete # Use more capable model for complex analysis
+        )
+    )
+    print(response2)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/lightrag/base.py b/lightrag/base.py
index f0376c01..faece842 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -83,6 +84,12 @@ class QueryParam:

     ids: list[str] | None = None
     """List of ids to filter the results."""
+
+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """


 @dataclass
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 49f3d955..442f00e3 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@ class LightRAG:
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 3291c49f..f7de6b5e 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -705,7 +705,7 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -866,7 +866,7 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -926,7 +926,7 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
     )
@@ -1731,7 +1731,7 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -1850,7 +1850,7 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"

From f87c235a4c00d13fac047a848cfb70e2346e1524 Mon Sep 17 00:00:00 2001
From: omdivyatej
Date: Sun, 23 Mar 2025 21:42:56 +0530
Subject: [PATCH 4/5] less comments

---
 .../lightrag_multi_model_all_modes_demo.py | 111 +++++++++---------
 1 file changed, 54 insertions(+), 57 deletions(-)

diff --git a/examples/lightrag_multi_model_all_modes_demo.py b/examples/lightrag_multi_model_all_modes_demo.py
index 04adf642..c2f9c3d2 100644
--- a/examples/lightrag_multi_model_all_modes_demo.py
+++ b/examples/lightrag_multi_model_all_modes_demo.py
@@ -3,22 +3,17 @@ import asyncio
 from lightrag import LightRAG, QueryParam
 from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
 from lightrag.kg.shared_storage import initialize_pipeline_status
-from lightrag.utils import setup_logger

-setup_logger("lightrag", level="INFO")
-
-WORKING_DIR = "./all_modes_demo"
+WORKING_DIR = "./lightrag_demo"

 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

-
 async def initialize_rag():
-    # Initialize LightRAG with a base model (gpt-4o-mini)
     rag = LightRAG(
         working_dir=WORKING_DIR,
         embedding_func=openai_embed,
-        llm_model_func=gpt_4o_mini_complete, # Default model for most queries
+        llm_model_func=gpt_4o_mini_complete, # Default model for queries
     )

     await rag.initialize_storages()
@@ -26,7 +21,6 @@ async def initialize_rag():

     return rag

-
 def main():
     # Initialize RAG instance
     rag = asyncio.run(initialize_rag())
@@ -34,60 +28,63 @@ def main():
     # Load the data
     with open("./book.txt", "r", encoding="utf-8") as f:
         rag.insert(f.read())
-
-    # Example query
-    query_text = "What are the main themes in this story?"
-
-    # Demonstrate using default model (gpt-4o-mini) for all modes
-    print("\n===== Default Model (gpt-4o-mini) =====")
-
-    for mode in ["local", "global", "hybrid", "naive", "mix"]:
-        print(f"\n--- {mode.upper()} mode with default model ---")
-        response = rag.query(
-            query_text,
-            param=QueryParam(mode=mode)
+
+    # Query with naive mode (default model)
+    print("--- NAIVE mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?",
+            param=QueryParam(mode="naive")
         )
-        print(response)
-
-    # Demonstrate using custom model (gpt-4o) for all modes
-    print("\n===== Custom Model (gpt-4o) =====")
-
-    for mode in ["local", "global", "hybrid", "naive", "mix"]:
-        print(f"\n--- {mode.upper()} mode with custom model ---")
-        response = rag.query(
-            query_text,
+    )
+
+    # Query with local mode (default model)
+    print("\n--- LOCAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?",
+            param=QueryParam(mode="local")
+        )
+    )
+
+    # Query with global mode (default model)
+    print("\n--- GLOBAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?",
+            param=QueryParam(mode="global")
+        )
+    )
+
+    # Query with hybrid mode (default model)
+    print("\n--- HYBRID mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?",
+            param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Query with mix mode (default model)
+    print("\n--- MIX mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?",
+            param=QueryParam(mode="mix")
+        )
+    )
+
+    # Query with a custom model (gpt-4o) for a more complex question
+    print("\n--- Using custom model for complex analysis ---")
+    print(
+        rag.query(
+            "How does the character development reflect Victorian-era attitudes?",
             param=QueryParam(
-                mode=mode,
-                model_func=gpt_4o_complete # Override with more capable model
+                mode="global",
+                model_func=gpt_4o_complete # Override default model with more capable one
             )
         )
-        print(response)
-
-    # Mixed approach - use different models for different modes
-    print("\n===== Strategic Model Selection =====")
-
-    # Complex analytical question
-    complex_query = "How does the character development in the story reflect Victorian-era social values?"
-
-    # Use default model for simpler modes
-    print("\n--- NAIVE mode with default model (suitable for simple retrieval) ---")
-    response1 = rag.query(
-        complex_query,
-        param=QueryParam(mode="naive") # Use default model for basic retrieval
     )
-    print(response1)
-
-    # Use more capable model for complex modes
-    print("\n--- HYBRID mode with more capable model (for complex analysis) ---")
-    response2 = rag.query(
-        complex_query,
-        param=QueryParam(
-            mode="hybrid",
-            model_func=gpt_4o_complete # Use more capable model for complex analysis
-        )
-    )
-    print(response2)
-

 if __name__ == "__main__":
     main()
\ No newline at end of file

From f049f2f5c46fa52399b817a70d3ef7c85d2ee8cc Mon Sep 17 00:00:00 2001
From: omdivyatej
Date: Tue, 25 Mar 2025 15:20:09 +0530
Subject: [PATCH 5/5] linting errors

---
 .../lightrag_multi_model_all_modes_demo.py | 24 ++++++++--------
 lightrag/base.py | 2 +-
 lightrag/lightrag.py | 2 +-
 lightrag/operate.py | 28 +++++++++++++++----
 4 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/examples/lightrag_multi_model_all_modes_demo.py b/examples/lightrag_multi_model_all_modes_demo.py
index c2f9c3d2..16e18782 100644
--- a/examples/lightrag_multi_model_all_modes_demo.py
+++ b/examples/lightrag_multi_model_all_modes_demo.py
@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+
 async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
@@ -21,6 +22,7 @@ async def initialize_rag():

     return rag

+
 def main():
     # Initialize RAG instance
     rag = asyncio.run(initialize_rag())
@@ -33,8 +35,7 @@ def main():
     print("--- NAIVE mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="naive")
+            "What are the main themes in this story?", param=QueryParam(mode="naive")
         )
     )

@@ -42,8 +43,7 @@ def main():
     print("\n--- LOCAL mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="local")
+            "What are the main themes in this story?", param=QueryParam(mode="local")
         )
     )

@@ -51,8 +51,7 @@ def main():
     print("\n--- GLOBAL mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="global")
+            "What are the main themes in this story?", param=QueryParam(mode="global")
         )
     )

@@ -60,8 +59,7 @@ def main():
     print("\n--- HYBRID mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="hybrid")
+            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
         )
     )

@@ -69,8 +67,7 @@ def main():
     print("\n--- MIX mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="mix")
+            "What are the main themes in this story?", param=QueryParam(mode="mix")
         )
     )

@@ -81,10 +78,11 @@ def main():
             "How does the character development reflect Victorian-era attitudes?",
             param=QueryParam(
                 mode="global",
-                model_func=gpt_4o_complete # Override default model with more capable one
-            )
+                model_func=gpt_4o_complete, # Override default model with more capable one
+            ),
         )
     )
+

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/lightrag/base.py b/lightrag/base.py
index faece842..3db337e5 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -84,7 +84,7 @@ class QueryParam:

     ids: list[str] | None = None
     """List of ids to filter the results."""
-
+
     model_func: Callable[..., object] | None = None
     """Optional override for the LLM model function to use for this specific query.
     If provided, this will be used instead of the global model function.
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 442f00e3..d404bffa 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1338,7 +1338,7 @@ class LightRAG:
         """
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)
-
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
diff --git a/lightrag/operate.py b/lightrag/operate.py
index f7de6b5e..9f5eb92b 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,7 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
     )
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"