diff --git a/lightrag/utils.py b/lightrag/utils.py
index d4d42b40..edf96dcc 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -490,56 +490,50 @@ async def handle_cache(
     if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
         return None, None, None, None
 
-    # For default mode, only use simple cache matching
-    if mode == "default":
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-        return None, None, None, None
-
-    # Get embedding cache configuration
-    embedding_cache_config = hashing_kv.global_config.get(
-        "embedding_cache_config",
-        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
-    )
-    is_embedding_cache_enabled = embedding_cache_config["enabled"]
-    use_llm_check = embedding_cache_config.get("use_llm_check", False)
-
-    quantized = min_val = max_val = None
-    if is_embedding_cache_enabled:
-        # Use embedding cache
-        current_embedding = await hashing_kv.embedding_func([prompt])
-        llm_model_func = (
-            hashing_kv.llm_model_func if hasattr(hashing_kv, "llm_model_func") else None
+    if mode != "default":
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config",
+            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
         )
-        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-        best_cached_response = await get_best_cached_response(
-            hashing_kv,
-            current_embedding[0],
-            similarity_threshold=embedding_cache_config["similarity_threshold"],
-            mode=mode,
-            use_llm_check=use_llm_check,
-            llm_func=llm
-            if (use_llm_check and llm is not None)
-            else (llm_model_func if use_llm_check else None),
-            original_prompt=prompt if use_llm_check else None,
-            cache_type=cache_type,
-        )
-        if best_cached_response is not None:
-            return best_cached_response, None, None, None
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        use_llm_check = embedding_cache_config.get("use_llm_check", False)
+
+        quantized = min_val = max_val = None
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            current_embedding = await hashing_kv.embedding_func([prompt])
+            llm_model_func = (
+                hashing_kv.llm_model_func if hasattr(hashing_kv, "llm_model_func") else None
+            )
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+                mode=mode,
+                use_llm_check=use_llm_check,
+                llm_func=llm
+                if (use_llm_check and llm is not None)
+                else (llm_model_func if use_llm_check else None),
+                original_prompt=prompt if use_llm_check else None,
+                cache_type=cache_type,
+            )
+            if best_cached_response is not None:
+                return best_cached_response, None, None, None
+            else:
+                return None, quantized, min_val, max_val
+
+    # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
+    # Use regular cache
+    if exists_func(hashing_kv, "get_by_mode_and_id"):
+        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
-        # Use regular cache
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+    if args_hash in mode_cache:
+        return mode_cache[args_hash]["return"], None, None, None
 
-    return None, quantized, min_val, max_val
+    return None, None, None, None
 
 
 @dataclass
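
For context, a minimal sketch of how a query path consumes the four-tuple returned by the reordered handle_cache. compute_args_hash and handle_cache are the existing functions in lightrag/utils.py; the wrapper function, the llm_call callable, and the mode value below are illustrative assumptions, not part of the patch, and on a miss the quantized/min_val/max_val values would normally be handed to the cache-saving helper:

from lightrag.utils import compute_args_hash, handle_cache

async def answer_with_cache(hashing_kv, prompt, llm_call, mode="local"):
    # Hash the (mode, prompt) pair the same way the query paths do.
    args_hash = compute_args_hash(mode, prompt)
    cached, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode=mode, cache_type="query"
    )
    if cached is not None:
        # Hit: either an embedding-similarity match (non-default mode with the
        # embedding cache enabled) or an exact args_hash match from the regular cache.
        return cached
    # Miss: quantized/min_val/max_val are populated only when the embedding cache
    # was consulted; they are meant to be passed along when the fresh response is saved.
    return await llm_call(prompt)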