Unify llm_response_cache and hashing_kv, preventing the creation of an independent hashing_kv.
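In short, the shared LLM response cache now carries global_config from the moment it is built, so every code path can use it directly as its hashing_kv instead of constructing a private KV store of its own. A minimal, hypothetical sketch of that pattern (the class and field names below are illustrative stand-ins, not the project's real API):

from dataclasses import dataclass, field, asdict

@dataclass
class KVCache:  # stand-in for the real key/value storage class
    namespace: str
    global_config: dict = field(default_factory=dict)
    data: dict = field(default_factory=dict)

@dataclass
class MiniRAG:  # stand-in for LightRAG
    enable_llm_cache: bool = True
    enable_llm_cache_for_entity_extract: bool = False

    def __post_init__(self) -> None:
        # Build the cache once, with the global config attached up front ...
        self.llm_response_cache = KVCache(
            namespace="kv_store_llm_response_cache",
            global_config=asdict(self),
        )
        # ... so later code can simply alias it instead of creating a second store.
        hashing_kv = self.llm_response_cache
        assert hashing_kv is self.llm_response_cache

MiniRAG()  # constructing the object exercises the sketch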
@@ -323,7 +323,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -352,7 +352,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -416,7 +416,7 @@ def create_app(args):
             "doc_status_storage": args.doc_status_storage,
             "graph_storage": args.graph_storage,
             "vector_storage": args.vector_storage,
-            "enable_llm_cache": args.enable_llm_cache,
+            "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
         },
         "update_status": update_status,
     }
@@ -361,7 +361,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

     # Inject LLM cache configuration
-    args.enable_llm_cache = get_env_value(
+    args.enable_llm_cache_for_extract = get_env_value(
         "ENABLE_LLM_CACHE_FOR_EXTRACT",
         False,
         bool
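For reference, the renamed setting is driven by the ENABLE_LLM_CACHE_FOR_EXTRACT environment variable. Below is a rough, hypothetical stand-in for the get_env_value helper used in the hunk above, showing how that variable would end up on args.enable_llm_cache_for_extract; the real helper may differ in details such as boolean parsing.

import os

def get_env_value(name: str, default, value_type=str):
    """Read an environment variable and coerce it to value_type (sketch only)."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    if value_type is bool:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return value_type(raw)

# e.g. run the server with: ENABLE_LLM_CACHE_FOR_EXTRACT=true lightrag-server
enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool)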
@@ -460,8 +460,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.cosine_threshold}")
     ASCIIColors.white(" ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
-    ASCIIColors.white(" └─ LLM Cache Enabled: ", end="")
-    ASCIIColors.yellow(f"{args.enable_llm_cache}")
+    ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")

     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
@@ -354,6 +354,7 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(self),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )

@@ -404,18 +405,8 @@ class LightRAG:
             embedding_func=None,
         )

-        if self.llm_response_cache and hasattr(
-            self.llm_response_cache, "global_config"
-        ):
-            hashing_kv = self.llm_response_cache
-        else:
-            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        # Directly use llm_response_cache, don't create a new object
+        hashing_kv = self.llm_response_cache

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
@@ -1260,16 +1251,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "naive":
@@ -1279,16 +1261,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "mix":
@@ -1301,16 +1274,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         else:
@@ -1344,14 +1308,7 @@ class LightRAG:
             text=query,
             param=param,
             global_config=asdict(self),
-            hashing_kv=self.llm_response_cache
-            or self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )

         param.hl_keywords = hl_keywords
@@ -1375,16 +1332,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1393,16 +1341,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
@@ -1414,16 +1353,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
@@ -410,7 +410,6 @@ async def extract_entities(
             _prompt,
             "default",
             cache_type="extract",
-            force_llm_cache=True,
         )
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
@@ -432,6 +431,7 @@ async def extract_entities(
                     cache_type="extract",
                 ),
             )
+            logger.info(f"Extract: saved cache for {arg_hash}")
             return res

         if history_messages:
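The two extract_entities hunks above sit inside a read-then-write cache flow: look up by argument hash, call the LLM only on a miss, then persist the result so the next identical prompt hits the cache. A simplified, hypothetical sketch of that round trip (the hashing_kv methods and the md5-based argument hash are assumptions, not the project's exact code):

import hashlib

async def use_llm_with_cache(hashing_kv, llm_func, prompt: str) -> str:
    """Hypothetical sketch of the extraction cache round trip."""
    arg_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()
    cached = await hashing_kv.get_by_id(arg_hash)  # assumed KV interface
    if cached is not None:
        return cached["return"]                    # cache hit: skip the LLM
    res = await llm_func(prompt)                   # cache miss: call the model
    await hashing_kv.upsert(                       # persist for the next identical prompt
        {arg_hash: {"return": res, "cache_type": "extract"}}
    )
    return res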
@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
         return None, None, None, None

-    if mode != "default":
+    if mode != "default":  # handle cache for all type of query
+        if not hashing_kv.global_config.get("enable_llm_cache"):
+            return None, None, None, None
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)

         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding simularity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -674,8 +673,13 @@ async def handle_cache(
             logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
             return None, quantized, min_val, max_val

-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Here is the conditions of code reaching this point:
+    # 1. All query mode: enable_llm_cache is True and embedding simularity is not enabled
+    # 2. Entity extract: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else: