Unify llm_response_cache and hashing_kv, prevent creating an independent hashing_kv.

This commit is contained in:
yangdx
2025-03-09 22:15:26 +08:00
parent e47883d872
commit bc42afe7b6
5 changed files with 30 additions and 96 deletions

View File

@@ -323,7 +323,7 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache, # Read from args
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
@@ -352,7 +352,7 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache, # Read from args
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
@@ -416,7 +416,7 @@ def create_app(args):
"doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage,
"vector_storage": args.vector_storage,
"enable_llm_cache": args.enable_llm_cache,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
},
"update_status": update_status,
}

View File

@@ -361,7 +361,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
# Inject LLM cache configuration
args.enable_llm_cache = get_env_value(
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT",
False,
bool
@@ -460,8 +460,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.cosine_threshold}")
ASCIIColors.white(" ├─ Top-K: ", end="")
ASCIIColors.yellow(f"{args.top_k}")
ASCIIColors.white(" └─ LLM Cache Enabled: ", end="")
ASCIIColors.yellow(f"{args.enable_llm_cache}")
ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")
# System Configuration
ASCIIColors.magenta("\n💾 Storage Configuration:")

View File

@@ -354,6 +354,7 @@ class LightRAG:
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), # Add global_config to ensure cache works properly
embedding_func=self.embedding_func,
)
@@ -404,18 +405,8 @@ class LightRAG:
embedding_func=None,
)
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
# Directly use llm_response_cache, don't create a new object
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
@@ -1260,16 +1251,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
elif param.mode == "naive":
@@ -1279,16 +1261,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
elif param.mode == "mix":
@@ -1301,16 +1274,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
else:
@@ -1344,14 +1308,7 @@ class LightRAG:
text=query,
param=param,
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
)
param.hl_keywords = hl_keywords
@@ -1375,16 +1332,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
)
elif param.mode == "naive":
response = await naive_query(
@@ -1393,16 +1341,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
@@ -1414,16 +1353,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
)
else:
raise ValueError(f"Unknown mode {param.mode}")

View File

@@ -410,7 +410,6 @@ async def extract_entities(
_prompt,
"default",
cache_type="extract",
force_llm_cache=True,
)
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
@@ -432,6 +431,7 @@ async def extract_entities(
cache_type="extract",
),
)
logger.info(f"Extract: saved cache for {arg_hash}")
return res
if history_messages:

View File

@@ -633,15 +633,15 @@ async def handle_cache(
prompt,
mode="default",
cache_type=None,
force_llm_cache=False,
):
"""Generic cache handling function"""
if hashing_kv is None or not (
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
):
if hashing_kv is None:
return None, None, None, None
if mode != "default": # handle cache for all type of query
if not hashing_kv.global_config.get("enable_llm_cache"):
return None, None, None, None
if mode != "default":
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
use_llm_check = embedding_cache_config.get("use_llm_check", False)
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
if is_embedding_cache_enabled: # Use embedding simularity to match cache
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -674,8 +673,13 @@ async def handle_cache(
logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
return None, quantized, min_val, max_val
# For default mode or is_embedding_cache_enabled is False, use regular cache
# default mode is for extract_entities or naive query
else: # handle cache for entity extraction
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None
# Here is the conditions of code reaching this point:
# 1. All query mode: enable_llm_cache is True and embedding simularity is not enabled
# 2. Entity extract: enable_llm_cache_for_entity_extract is True
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else: