From ef69009c159fdcabdda023fc0dc5082f704ab5b3 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 28 Apr 2025 19:36:21 +0800
Subject: [PATCH] Increase the priority of LLM requests related to queries

---
 lightrag/lightrag.py |  3 +++
 lightrag/operate.py  | 58 ++++++++++++++++++++++++++------------------
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index a1e1f051..cd472922 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1445,6 +1445,9 @@ class LightRAG:
         elif param.mode == "bypass":
             # Bypass mode: directly use LLM without knowledge retrieval
             use_llm_func = param.model_func or global_config["llm_model_func"]
+            # Apply higher priority (8) to the bypass-mode query LLM call
+            use_llm_func = partial(use_llm_func, _priority=8)
+
             param.stream = True if param.stream is None else param.stream
             response = await use_llm_func(
                 query.strip(),
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 13c40cf4..2e583525 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -853,12 +853,14 @@ async def kg_query(
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to the query-related LLM call
+        use_model_func = partial(use_model_func, _priority=5)
+
     # Handle cache
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1054,9 +1056,13 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = (
-        param.model_func if param.model_func else global_config["llm_model_func"]
-    )
+    if param.model_func:
+        use_model_func = param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to the query-related LLM call
+        use_model_func = partial(use_model_func, _priority=5)
+
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -1119,12 +1125,15 @@ async def mix_kg_vector_query(
     """
     # get tokenizer
     tokenizer: Tokenizer = global_config["tokenizer"]
+
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to the query-related LLM call
+        use_model_func = partial(use_model_func, _priority=5)
+
     # 1. Cache handling
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -2010,12 +2019,14 @@ async def naive_query(
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to the query-related LLM call
+        use_model_func = partial(use_model_func, _priority=5)
+
     # Handle cache
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -2142,15 +2153,16 @@ async def kg_query_with_keywords(
     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
     Then it uses those to build context and produce a final LLM response.
     """
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to the query-related LLM call
+        use_model_func = partial(use_model_func, _priority=5)

     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
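
Note: the diff only binds the `_priority` keyword at the query call sites (8 for bypass mode, 5 for the other query paths); it does not show how the wrapped `llm_model_func` consumes that keyword. Below is a minimal, self-contained sketch of one way a priority-aware concurrency limiter could honor it, assuming a decorator intercepts `_priority` before the call reaches the model. The names `priority_limited` and `fake_llm` are hypothetical and not part of LightRAG.

import asyncio
from functools import partial


def priority_limited(max_concurrent: int = 1):
    """Decorator sketch (not LightRAG's implementation): serve at most
    `max_concurrent` calls at once, highest `_priority` first."""

    def decorator(func):
        queue = None      # created lazily, once an event loop is running
        workers = []      # keep references so worker tasks are not GC'd
        counter = 0       # tie-breaker keeps equal priorities FIFO

        async def worker():
            while True:
                _, _, fut, args, kwargs = await queue.get()
                try:
                    fut.set_result(await func(*args, **kwargs))
                except Exception as exc:          # surface failures to callers
                    fut.set_exception(exc)
                finally:
                    queue.task_done()

        async def wrapper(*args, _priority: int = 0, **kwargs):
            nonlocal queue, counter
            if queue is None:
                queue = asyncio.PriorityQueue()
                workers.extend(
                    asyncio.create_task(worker()) for _ in range(max_concurrent)
                )
            counter += 1
            fut = asyncio.get_running_loop().create_future()
            # Negate so a larger _priority (e.g. 8 for bypass queries, 5 for
            # the other query paths) is dequeued before priority 0 work.
            await queue.put((-_priority, counter, fut, args, kwargs))
            return await fut

        return wrapper

    return decorator


@priority_limited(max_concurrent=1)
async def fake_llm(prompt: str) -> str:
    print("running:", prompt)
    await asyncio.sleep(0.1)                      # stand-in for a model call
    return f"answer to: {prompt}"


async def main():
    # Mirrors the patch: bind a higher priority for query-time calls.
    query_llm = partial(fake_llm, _priority=8)
    background_llm = partial(fake_llm, _priority=0)
    print(await asyncio.gather(
        background_llm("summarize entity A"),     # enqueued first...
        query_llm("answer the user question"),    # ...but served first
    ))


asyncio.run(main())

With max_concurrent=1, the call bound with _priority=8 starts before the priority-0 call that was enqueued earlier, which is the effect the patch aims for on the query paths.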