Unify llm_response_cache and hashing_kv, prevent creating an independent hashing_kv.

yangdx
2025-03-09 22:15:26 +08:00
parent e47883d872
commit bc42afe7b6
5 changed files with 30 additions and 96 deletions


@@ -323,7 +323,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -352,7 +352,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -416,7 +416,7 @@ def create_app(args):
"doc_status_storage": args.doc_status_storage, "doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage, "graph_storage": args.graph_storage,
"vector_storage": args.vector_storage, "vector_storage": args.vector_storage,
"enable_llm_cache": args.enable_llm_cache, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
}, },
"update_status": update_status, "update_status": update_status,
} }
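
Outside the API server, the same setting maps onto the LightRAG constructor argument shown in the hunks above; a minimal standalone sketch (working_dir and the flag values are illustrative, only the two cache-related arguments are the point):

    from lightrag import LightRAG

    rag = LightRAG(
        working_dir="./rag_storage",
        enable_llm_cache=True,  # cache for query-time LLM responses
        enable_llm_cache_for_entity_extract=False,  # what the server now derives from args.enable_llm_cache_for_extract
    )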


@@ -361,7 +361,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

     # Inject LLM cache configuration
-    args.enable_llm_cache = get_env_value(
+    args.enable_llm_cache_for_extract = get_env_value(
         "ENABLE_LLM_CACHE_FOR_EXTRACT",
         False,
         bool
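
Since the flag is injected from the environment, enabling the extraction cache is a one-line change in the server's environment (or its .env file, assuming the usual dotenv loading):

    ENABLE_LLM_CACHE_FOR_EXTRACT=true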
@@ -460,8 +460,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.cosine_threshold}")
     ASCIIColors.white(" ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
-    ASCIIColors.white(" └─ LLM Cache Enabled: ", end="")
-    ASCIIColors.yellow(f"{args.enable_llm_cache}")
+    ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")

     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")


@@ -354,6 +354,7 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(self),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
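
handle_cache treats whatever object it receives as hashing_kv as a KV store that also carries the global config and an embedding function, which is why global_config has to be attached to llm_response_cache here. A rough sketch of the expected shape (an illustrative Protocol describing this commit's usage, not a class defined in the codebase):

    from typing import Any, Protocol

    class HashingKVLike(Protocol):
        # handle_cache() reads "enable_llm_cache", "enable_llm_cache_for_entity_extract"
        # and "embedding_cache_config" from this dict
        global_config: dict[str, Any]
        embedding_func: Any  # used for embedding-similarity cache matching

        async def get_by_id(self, id: str) -> dict | None: ...
        async def upsert(self, data: dict[str, dict]) -> None: ...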
@@ -404,18 +405,8 @@ class LightRAG:
         embedding_func=None,
     )

-        if self.llm_response_cache and hasattr(
-            self.llm_response_cache, "global_config"
-        ):
-            hashing_kv = self.llm_response_cache
-        else:
-            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        # Directly use llm_response_cache, don't create a new object
+        hashing_kv = self.llm_response_cache

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
@@ -1260,16 +1251,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "naive":
@@ -1279,16 +1261,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "mix":
@@ -1301,16 +1274,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         else:
@@ -1344,14 +1308,7 @@ class LightRAG:
                 text=query,
                 param=param,
                 global_config=asdict(self),
-                hashing_kv=self.llm_response_cache
-                or self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )

         param.hl_keywords = hl_keywords
@@ -1375,16 +1332,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1393,16 +1341,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
@@ -1414,16 +1353,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
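
With every mode now handed the same self.llm_response_cache instance, repeated queries share one cache regardless of which branch they take; a quick way to observe this from user code (a sketch, assuming an initialized rag with documents already inserted and LLM caching enabled):

    from lightrag import QueryParam

    async def demo(rag):
        q = "Which entities appear in the indexed documents?"
        first = await rag.aquery(q, param=QueryParam(mode="hybrid"))   # goes to the LLM, response cached
        second = await rag.aquery(q, param=QueryParam(mode="hybrid"))  # answered from llm_response_cache
        assert first == second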


@@ -410,7 +410,6 @@ async def extract_entities(
             _prompt,
             "default",
             cache_type="extract",
-            force_llm_cache=True,
         )
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
@@ -432,6 +431,7 @@ async def extract_entities(
cache_type="extract", cache_type="extract",
), ),
) )
logger.info(f"Extract: saved cache for {arg_hash}")
return res return res
if history_messages: if history_messages:
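
With force_llm_cache gone, extraction caching is governed only by enable_llm_cache_for_entity_extract, checked inside handle_cache; the lookup at the top of extract_entities reduces to roughly this (condensed from the diff above, names as shown there):

    cached_return, quantized, min_val, max_val = await handle_cache(
        llm_response_cache,  # the shared cache storage, no separately created hashing_kv
        arg_hash,
        _prompt,
        "default",  # extraction always uses the "default" mode
        cache_type="extract",
    )
    if cached_return:
        logger.debug(f"Found cache for {arg_hash}")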


@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
         return None, None, None, None

-    if mode != "default":
+    if mode != "default":  # handle cache for all type of query
+        if not hashing_kv.global_config.get("enable_llm_cache"):
+            return None, None, None, None
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)

         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding simularity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -674,8 +673,13 @@ async def handle_cache(
logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})") logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
return None, quantized, min_val, max_val return None, quantized, min_val, max_val
# For default mode or is_embedding_cache_enabled is False, use regular cache else: # handle cache for entity extraction
# default mode is for extract_entities or naive query if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None
# Here is the conditions of code reaching this point:
# 1. All query mode: enable_llm_cache is True and embedding simularity is not enabled
# 2. Entity extract: enable_llm_cache_for_entity_extract is True
if exists_func(hashing_kv, "get_by_mode_and_id"): if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else: else:
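
Taken together, these hunks give handle_cache a single branch point: non-default (query) modes are gated by enable_llm_cache plus the optional embedding-similarity lookup, while the default mode used by entity extraction is gated by enable_llm_cache_for_entity_extract; both paths then fall through to the exact-match lookup. A condensed sketch of the resulting control flow (not the verbatim function body):

    async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
        if hashing_kv is None:
            return None, None, None, None

        if mode != "default":  # query-time caching
            if not hashing_kv.global_config.get("enable_llm_cache"):
                return None, None, None, None
            # optional embedding-similarity lookup may return a hit or miss here
        else:  # entity-extraction caching
            if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
                return None, None, None, None

        # exact-match lookup keyed by (mode, args_hash) in the shared KV store
        ...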