diff --git a/examples/lightrag_multi_model_all_modes_demo.py b/examples/lightrag_multi_model_all_modes_demo.py
index c2f9c3d2..16e18782 100644
--- a/examples/lightrag_multi_model_all_modes_demo.py
+++ b/examples/lightrag_multi_model_all_modes_demo.py
@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+
 async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
@@ -21,6 +22,7 @@ async def initialize_rag():
 
     return rag
 
+
 def main():
     # Initialize RAG instance
     rag = asyncio.run(initialize_rag())
@@ -33,8 +35,7 @@ def main():
     print("--- NAIVE mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="naive")
+            "What are the main themes in this story?", param=QueryParam(mode="naive")
         )
     )
 
@@ -42,8 +43,7 @@ def main():
     print("\n--- LOCAL mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="local")
+            "What are the main themes in this story?", param=QueryParam(mode="local")
         )
     )
 
@@ -51,8 +51,7 @@ def main():
     print("\n--- GLOBAL mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="global")
+            "What are the main themes in this story?", param=QueryParam(mode="global")
         )
     )
 
@@ -60,8 +59,7 @@ def main():
     print("\n--- HYBRID mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="hybrid")
+            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
         )
     )
 
@@ -69,8 +67,7 @@ def main():
     print("\n--- MIX mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="mix")
+            "What are the main themes in this story?", param=QueryParam(mode="mix")
        )
     )
 
@@ -81,10 +78,11 @@ def main():
             "How does the character development reflect Victorian-era attitudes?",
             param=QueryParam(
                 mode="global",
-                model_func=gpt_4o_complete # Override default model with more capable one
-            )
+                model_func=gpt_4o_complete,  # Override default model with more capable one
+            ),
         )
     )
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/lightrag/base.py b/lightrag/base.py
index faece842..3db337e5 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -84,7 +84,7 @@ class QueryParam:
 
     ids: list[str] | None = None
     """List of ids to filter the results."""
-    
+
     model_func: Callable[..., object] | None = None
     """Optional override for the LLM model function to use for this specific query.
     If provided, this will be used instead of the global model function.
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 442f00e3..d404bffa 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1338,7 +1338,7 @@ class LightRAG:
         """
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)
-        
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
diff --git a/lightrag/operate.py b/lightrag/operate.py
index f7de6b5e..9f5eb92b 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     # 5. Call the LLM for keyword extraction
-    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)
 
     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"