Refactor cache handling logic for better readability; function behavior unchanged.

Author: yangdx
Date:   2025-02-02 00:10:21 +08:00
Parent: c9481c81b9
Commit: 6c7d7c25d3


```diff
@@ -490,56 +490,50 @@ async def handle_cache(
     if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
         return None, None, None, None
 
-    # For default mode, only use simple cache matching
-    if mode == "default":
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-        return None, None, None, None
-
-    # Get embedding cache configuration
-    embedding_cache_config = hashing_kv.global_config.get(
-        "embedding_cache_config",
-        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
-    )
-    is_embedding_cache_enabled = embedding_cache_config["enabled"]
-    use_llm_check = embedding_cache_config.get("use_llm_check", False)
-
-    quantized = min_val = max_val = None
-    if is_embedding_cache_enabled:
-        # Use embedding cache
-        current_embedding = await hashing_kv.embedding_func([prompt])
-        llm_model_func = (
-            hashing_kv.llm_model_func if hasattr(hashing_kv, "llm_model_func") else None
-        )
-        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-        best_cached_response = await get_best_cached_response(
-            hashing_kv,
-            current_embedding[0],
-            similarity_threshold=embedding_cache_config["similarity_threshold"],
-            mode=mode,
-            use_llm_check=use_llm_check,
-            llm_func=llm
-            if (use_llm_check and llm is not None)
-            else (llm_model_func if use_llm_check else None),
-            original_prompt=prompt if use_llm_check else None,
-            cache_type=cache_type,
-        )
-        if best_cached_response is not None:
-            return best_cached_response, None, None, None
-    else:
-        # Use regular cache
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-
-    return None, quantized, min_val, max_val
+    if mode != "default":
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config",
+            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        use_llm_check = embedding_cache_config.get("use_llm_check", False)
+
+        quantized = min_val = max_val = None
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            current_embedding = await hashing_kv.embedding_func([prompt])
+            llm_model_func = (
+                hashing_kv.llm_model_func if hasattr(hashing_kv, "llm_model_func") else None
+            )
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+                mode=mode,
+                use_llm_check=use_llm_check,
+                llm_func=llm
+                if (use_llm_check and llm is not None)
+                else (llm_model_func if use_llm_check else None),
+                original_prompt=prompt if use_llm_check else None,
+                cache_type=cache_type,
+            )
+            if best_cached_response is not None:
+                return best_cached_response, None, None, None
+            else:
+                return None, quantized, min_val, max_val
+
+    # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
+    # Use regular cache
+    if exists_func(hashing_kv, "get_by_mode_and_id"):
+        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+    else:
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+    if args_hash in mode_cache:
+        return mode_cache[args_hash]["return"], None, None, None
+
+    return None, None, None, None
 
 
 @dataclass
```
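
For context on how this function is consumed: the refactor only reorders the branches, so callers still receive the same four-element tuple (cached_response, quantized, min_val, max_val). The sketch below is illustrative only; the full signature of handle_cache is not visible in this hunk, so the positional parameter order, the module path, and the compute_args_hash helper are assumptions inferred from the variables the diff uses.

```python
# Illustrative sketch, not part of this commit: a caller consulting handle_cache
# before invoking the LLM. Parameter order of handle_cache, the module path, and
# the compute_args_hash helper are assumptions, not confirmed by this diff.
from lightrag.utils import compute_args_hash, handle_cache  # module path assumed


async def query_with_cache(hashing_kv, prompt: str, mode: str = "hybrid") -> str:
    args_hash = compute_args_hash(mode, prompt)

    # handle_cache always returns a 4-tuple:
    # (cached_response, quantized, min_val, max_val)
    cached, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached is not None:
        # Exact-hash hit ("default" mode) or embedding-similarity hit (other modes).
        return cached

    # Cache miss: call the model (the diff suggests hashing_kv may expose
    # llm_model_func); quantized/min_val/max_val would be passed along when the
    # new response is persisted so the embedding cache can serve it later.
    response = await hashing_kv.llm_model_func(prompt)
    return response
```

Either way, "default" mode stays on the exact args-hash lookup path, while non-default modes first try the embedding-similarity cache when embedding_cache_config["enabled"] is true, which is exactly the control flow the rewritten hunk makes explicit with the early `if mode != "default":` guard.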