Merge remote-tracking branch 'origin/main' and fix syntax
@@ -9,12 +9,14 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List
from typing import Any, Union, List, Optional
import xml.etree.ElementTree as ET

import numpy as np
import tiktoken

from lightrag.prompt import PROMPTS

ENCODER = None

logger = logging.getLogger("lightrag")
@@ -314,6 +316,9 @@ async def get_best_cached_response(
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
) -> Union[str, None]:
    # Get mode-specific cache
    mode_cache = await hashing_kv.get_by_id(mode)
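For context on the signature change above: the body of `get_best_cached_response` (largely unchanged, so not shown in full) scans the entries in `mode_cache` and keeps the one whose stored embedding scores highest against `current_embedding`, presumably via cosine similarity. A minimal sketch of that scan; `cosine_similarity` and `best_match` are illustrative names, not from the source:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Standard cosine similarity between two 1-D vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def best_match(current: np.ndarray, cached: dict) -> tuple:
    # Track the highest-scoring entry, as the surrounding function does
    # with its best_similarity / best_cache_id variables.
    best_id, best_sim = None, -1.0
    for cache_id, embedding in cached.items():
        sim = cosine_similarity(current, embedding)
        if sim > best_sim:
            best_id, best_sim = cache_id, sim
    return best_id, best_sim
```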
@@ -348,6 +353,37 @@ async def get_best_cached_response(
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        # If LLM check is enabled and all required parameters are provided
        if use_llm_check and llm_func and original_prompt and best_prompt:
            compare_prompt = PROMPTS["similarity_check"].format(
                original_prompt=original_prompt, cached_prompt=best_prompt
            )

            try:
                llm_result = await llm_func(compare_prompt)
                llm_result = llm_result.strip()
                llm_similarity = float(llm_result)

                # Replace vector similarity with LLM similarity score
                best_similarity = llm_similarity
                if best_similarity < similarity_threshold:
                    log_data = {
                        "event": "llm_check_cache_rejected",
                        "original_question": original_prompt[:100] + "..."
                        if len(original_prompt) > 100
                        else original_prompt,
                        "cached_question": best_prompt[:100] + "..."
                        if len(best_prompt) > 100
                        else best_prompt,
                        "similarity_score": round(best_similarity, 4),
                        "threshold": similarity_threshold,
                    }
                    logger.info(json.dumps(log_data, ensure_ascii=False))
                    return None
            except Exception as e:  # Catch all possible exceptions
                logger.warning(f"LLM similarity check failed: {e}")
                return None  # Return None directly when LLM check fails

        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
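The new `use_llm_check` branch builds a `similarity_check` prompt, awaits `llm_func`, and parses the reply with `float()`, so any model wrapper passed in must return a bare number as text. A stub that satisfies that contract (illustrative only; a real `llm_func` would call an actual model):

```python
async def stub_llm_func(prompt: str) -> str:
    # A real implementation would forward `prompt` to an LLM; the contract
    # implied by the diff is just "return a string parseable as a float".
    return "0.97"

# Hypothetical call site:
# response = await get_best_cached_response(
#     hashing_kv,
#     current_embedding,
#     use_llm_check=True,
#     llm_func=stub_llm_func,
#     original_prompt="the incoming query",
# )
```

Note the fail-closed design: a reply that `float()` cannot parse lands in the `except` branch, and the candidate is rejected rather than served from cache.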
@@ -390,3 +426,84 @@ def dequantize_embedding(
    """Restore quantized embedding"""
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)

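`quantize_embedding` itself is not shown in this diff, but the dequantization above implies linear min/max quantization over `2**bits` levels. A round-trip sketch under that assumption (the `uint8` dtype and `bits=8` default are inferred, not confirmed by the diff):

```python
import numpy as np

def quantize_embedding_sketch(embedding: np.ndarray, bits: int = 8):
    # Map floats linearly onto the integers [0, 2**bits - 1].
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (max_val - min_val) / (2**bits - 1)
    quantized = np.round((embedding - min_val) / scale).astype(np.uint8)
    return quantized, min_val, max_val

emb = np.random.rand(8).astype(np.float32)
q, lo, hi = quantize_embedding_sketch(emb)
scale = (hi - lo) / (2**8 - 1)
restored = (q * scale + lo).astype(np.float32)  # mirrors dequantize_embedding
assert np.allclose(emb, restored, atol=scale)   # lossy, within one step
```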
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None

    # For naive mode, only use simple cache matching
    if mode == "naive":
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None
        return None, None, None, None

    # Get embedding cache configuration
    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config",
        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
    )
    is_embedding_cache_enabled = embedding_cache_config["enabled"]
    use_llm_check = embedding_cache_config.get("use_llm_check", False)

    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use embedding cache
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        llm_model_func = hashing_kv.global_config.get("llm_model_func")

        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config["similarity_threshold"],
            mode=mode,
            use_llm_check=use_llm_check,
            llm_func=llm_model_func if use_llm_check else None,
            original_prompt=prompt if use_llm_check else None,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        # Use regular cache
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None

    return None, quantized, min_val, max_val

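`handle_cache` reads its switches from `hashing_kv.global_config`; given the keys consulted above, a configuration that enables the embedding cache would presumably look like the following (values are illustrative):

```python
global_config = {
    "embedding_cache_config": {
        "enabled": True,               # route lookups through embeddings
        "similarity_threshold": 0.90,  # minimum similarity to serve a hit
        "use_llm_check": True,         # re-verify borderline hits via LLM
    },
    # "embedding_func" and "llm_model_func" are also read from this dict.
}
```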
@dataclass
class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"

async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
        return

    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})
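Putting the two halves together, a caller would presumably consult `handle_cache` first and, on a miss, carry the quantized embedding through to `save_to_cache`. A sketch of that flow (`answer_with_cache` and `llm` are illustrative names, not part of the source):

```python
async def answer_with_cache(hashing_kv, args_hash, query, llm):
    # 1) Cache lookup: returns (hit, quantized, min_val, max_val).
    cached, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, mode="default"
    )
    if cached is not None:
        return cached  # hit: skip the LLM call entirely

    # 2) Miss: query the model, then persist the answer together with the
    #    embedding metadata that handle_cache produced.
    answer = await llm(query)
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=answer,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode="default",
        ),
    )
    return answer
```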