feat(cache): 增加 LLM 相似性检查功能并优化缓存机制
- 在 embedding 缓存配置中添加 use_llm_check 参数 - 实现 LLM 相似性检查逻辑,作为缓存命中的二次验证- 优化 naive 模式的缓存处理流程 - 调整缓存数据结构,移除不必要的 model 字段
This commit is contained in:
@@ -15,6 +15,8 @@ import xml.etree.ElementTree as ET
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
|
||||
from lightrag.prompt import PROMPTS
|
||||
|
||||
ENCODER = None
|
||||
|
||||
logger = logging.getLogger("lightrag")
|
||||
@@ -314,6 +316,9 @@ async def get_best_cached_response(
|
||||
current_embedding,
|
||||
similarity_threshold=0.95,
|
||||
mode="default",
|
||||
use_llm_check=False,
|
||||
llm_func=None,
|
||||
original_prompt=None,
|
||||
) -> Union[str, None]:
|
||||
# Get mode-specific cache
|
||||
mode_cache = await hashing_kv.get_by_id(mode)
|
||||
@@ -348,6 +353,37 @@ async def get_best_cached_response(
|
||||
best_cache_id = cache_id
|
||||
|
||||
if best_similarity > similarity_threshold:
|
||||
# If LLM check is enabled and all required parameters are provided
|
||||
if use_llm_check and llm_func and original_prompt and best_prompt:
|
||||
compare_prompt = PROMPTS["similarity_check"].format(
|
||||
original_prompt=original_prompt, cached_prompt=best_prompt
|
||||
)
|
||||
|
||||
try:
|
||||
llm_result = await llm_func(compare_prompt)
|
||||
llm_result = llm_result.strip()
|
||||
llm_similarity = float(llm_result)
|
||||
|
||||
# Replace vector similarity with LLM similarity score
|
||||
best_similarity = llm_similarity
|
||||
if best_similarity < similarity_threshold:
|
||||
log_data = {
|
||||
"event": "llm_check_cache_rejected",
|
||||
"original_question": original_prompt[:100] + "..."
|
||||
if len(original_prompt) > 100
|
||||
else original_prompt,
|
||||
"cached_question": best_prompt[:100] + "..."
|
||||
if len(best_prompt) > 100
|
||||
else best_prompt,
|
||||
"similarity_score": round(best_similarity, 4),
|
||||
"threshold": similarity_threshold,
|
||||
}
|
||||
logger.info(json.dumps(log_data, ensure_ascii=False))
|
||||
return None
|
||||
except Exception as e: # Catch all possible exceptions
|
||||
logger.warning(f"LLM similarity check failed: {e}")
|
||||
return None # Return None directly when LLM check fails
|
||||
|
||||
prompt_display = (
|
||||
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
|
||||
)
|
||||
@@ -391,21 +427,33 @@ def dequantize_embedding(
|
||||
scale = (max_val - min_val) / (2**bits - 1)
|
||||
return (quantized * scale + min_val).astype(np.float32)
|
||||
|
||||
|
||||
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
|
||||
"""Generic cache handling function"""
|
||||
if hashing_kv is None:
|
||||
return None, None, None, None
|
||||
|
||||
# For naive mode, only use simple cache matching
|
||||
if mode == "naive":
|
||||
mode_cache = await hashing_kv.get_by_id(mode) or {}
|
||||
if args_hash in mode_cache:
|
||||
return mode_cache[args_hash]["return"], None, None, None
|
||||
return None, None, None, None
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
"embedding_cache_config",
|
||||
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
|
||||
)
|
||||
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
||||
use_llm_check = embedding_cache_config.get("use_llm_check", False)
|
||||
|
||||
quantized = min_val = max_val = None
|
||||
if is_embedding_cache_enabled:
|
||||
# Use embedding cache
|
||||
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
||||
llm_model_func = hashing_kv.global_config.get("llm_model_func")
|
||||
|
||||
current_embedding = await embedding_model_func([prompt])
|
||||
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
||||
best_cached_response = await get_best_cached_response(
|
||||
@@ -413,6 +461,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
|
||||
current_embedding[0],
|
||||
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
||||
mode=mode,
|
||||
use_llm_check=use_llm_check,
|
||||
llm_func=llm_model_func if use_llm_check else None,
|
||||
original_prompt=prompt if use_llm_check else None,
|
||||
)
|
||||
if best_cached_response is not None:
|
||||
return best_cached_response, None, None, None
|
||||
@@ -429,7 +480,6 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
|
||||
class CacheData:
|
||||
args_hash: str
|
||||
content: str
|
||||
model: str
|
||||
prompt: str
|
||||
quantized: Optional[np.ndarray] = None
|
||||
min_val: Optional[float] = None
|
||||
@@ -445,7 +495,6 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||
|
||||
mode_cache[cache_data.args_hash] = {
|
||||
"return": cache_data.content,
|
||||
"model": cache_data.model,
|
||||
"embedding": cache_data.quantized.tobytes().hex()
|
||||
if cache_data.quantized is not None
|
||||
else None,
|
||||
|
Reference in New Issue
Block a user