From bc42afe7b65f92a5d73eb01f5410bdde9385ddd0 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 22:15:26 +0800
Subject: [PATCH] Unify llm_response_cache and hashing_kv, prevent creating an independent hashing_kv.

---
 lightrag/api/lightrag_server.py |  6 +--
 lightrag/api/utils_api.py       |  6 +--
 lightrag/lightrag.py            | 90 ++++-----------------------
 lightrag/operate.py             |  2 +-
 lightrag/utils.py               | 22 ++++----
 5 files changed, 30 insertions(+), 96 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index c42a816a..8871650a 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -323,7 +323,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -352,7 +352,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -416,7 +416,7 @@ def create_app(args):
             "doc_status_storage": args.doc_status_storage,
             "graph_storage": args.graph_storage,
             "vector_storage": args.vector_storage,
-            "enable_llm_cache": args.enable_llm_cache,
+            "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
         },
         "update_status": update_status,
     }
diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
index da443558..9a619f9e 100644
--- a/lightrag/api/utils_api.py
+++ b/lightrag/api/utils_api.py
@@ -361,7 +361,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
 
     # Inject LLM cache configuration
-    args.enable_llm_cache = get_env_value(
+    args.enable_llm_cache_for_extract = get_env_value(
         "ENABLE_LLM_CACHE_FOR_EXTRACT",
         False,
         bool
@@ -460,8 +460,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.cosine_threshold}")
     ASCIIColors.white(" ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
-    ASCIIColors.white(" └─ LLM Cache Enabled: ", end="")
-    ASCIIColors.yellow(f"{args.enable_llm_cache}")
+    ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")
 
     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index b06520fc..a91aa6fa 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -354,6 +354,7 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(self),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
 
@@ -404,18 +405,8 @@ class LightRAG:
             embedding_func=None,
         )
 
-        if self.llm_response_cache and hasattr(
-            self.llm_response_cache, "global_config"
-        ):
-            hashing_kv = self.llm_response_cache
-        else:
-            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        # Directly use llm_response_cache, don't
create a new object + hashing_kv = self.llm_response_cache self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -1260,16 +1251,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) elif param.mode == "naive": @@ -1279,16 +1261,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) elif param.mode == "mix": @@ -1301,16 +1274,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) else: @@ -1344,14 +1308,7 @@ class LightRAG: text=query, param=param, global_config=asdict(self), - hashing_kv=self.llm_response_cache - or self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache ) param.hl_keywords = hl_keywords @@ -1375,16 +1332,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache ) elif param.mode == "naive": response = await naive_query( @@ -1393,16 +1341,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache ) elif param.mode == "mix": response = await mix_kg_vector_query( @@ -1414,16 +1353,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( 
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
diff --git a/lightrag/operate.py b/lightrag/operate.py
index d16e170c..9ba3b06d 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -410,7 +410,6 @@ async def extract_entities(
         _prompt,
         "default",
         cache_type="extract",
-        force_llm_cache=True,
     )
     if cached_return:
         logger.debug(f"Found cache for {arg_hash}")
@@ -432,6 +431,7 @@ async def extract_entities(
                 cache_type="extract",
             ),
         )
+        logger.info(f"Extract: saved cache for {arg_hash}")
         return res
 
     if history_messages:
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 1b65097e..02c3236d 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
         return None, None, None, None
 
-    if mode != "default":
+    if mode != "default":  # handle cache for all types of queries
+        if not hashing_kv.global_config.get("enable_llm_cache"):
+            return None, None, None, None
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)
 
         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding similarity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -674,8 +673,13 @@ async def handle_cache(
                 logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
 
         return None, quantized, min_val, max_val
-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Here are the conditions for the code to reach this point:
+    # 1. All query modes: enable_llm_cache is True and embedding similarity matching is not enabled
+    # 2. Entity extraction: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
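
Note on the resulting cache behavior (an editorial sketch, not part of the patch): after this change every call site passes the shared llm_response_cache as hashing_kv, and handle_cache() decides whether to use it from two flags carried in that storage's global_config. enable_llm_cache gates all query modes, while enable_llm_cache_for_entity_extract gates entity extraction (the "default" mode). The snippet below models that gating with a hypothetical dict-backed FakeKV stand-in; names such as FakeKV and cache_enabled are illustrative only and do not exist in LightRAG.

# Minimal sketch of the gating logic introduced in handle_cache().
# FakeKV is a hypothetical stand-in for the shared llm_response_cache storage;
# like the real storage after this patch, it carries a global_config dict.
from dataclasses import dataclass, field


@dataclass
class FakeKV:
    global_config: dict = field(default_factory=dict)


def cache_enabled(hashing_kv, mode: str = "default") -> bool:
    """Mirror the early-return checks in the patched handle_cache()."""
    if hashing_kv is None:
        return False  # no shared cache object, nothing to read or write
    cfg = hashing_kv.global_config
    if mode != "default":
        # Query modes (e.g. naive/mix/kg queries) are gated by enable_llm_cache.
        return bool(cfg.get("enable_llm_cache"))
    # mode == "default" is used by entity extraction.
    return bool(cfg.get("enable_llm_cache_for_entity_extract"))


if __name__ == "__main__":
    kv = FakeKV(global_config={
        "enable_llm_cache": True,                      # query-response cache on
        "enable_llm_cache_for_entity_extract": False,  # extraction cache off (API default)
    })
    print(cache_enabled(kv, mode="mix"))      # True
    print(cache_enabled(kv, mode="default"))  # False
    print(cache_enabled(None, mode="mix"))    # False

When the query-mode path is enabled, handle_cache() may additionally attempt embedding-similarity matching before falling back to the exact-hash lookup, as shown in the utils.py hunk above.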