Merge remote-tracking branch 'origin/main' and fix syntax

partoneplay
2024-12-09 12:36:55 +08:00
8 changed files with 361 additions and 262 deletions


@@ -9,12 +9,14 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
-from typing import Any, Union, List
+from typing import Any, Union, List, Optional
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
ENCODER = None
logger = logging.getLogger("lightrag")
@@ -314,6 +316,9 @@ async def get_best_cached_response(
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
) -> Union[str, None]:
    # Get mode-specific cache
    mode_cache = await hashing_kv.get_by_id(mode)
@@ -348,6 +353,37 @@ async def get_best_cached_response(
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        # If LLM check is enabled and all required parameters are provided
        if use_llm_check and llm_func and original_prompt and best_prompt:
            compare_prompt = PROMPTS["similarity_check"].format(
                original_prompt=original_prompt, cached_prompt=best_prompt
            )

            try:
                llm_result = await llm_func(compare_prompt)
                llm_result = llm_result.strip()
                llm_similarity = float(llm_result)

                # Replace vector similarity with LLM similarity score
                best_similarity = llm_similarity
                if best_similarity < similarity_threshold:
                    log_data = {
                        "event": "llm_check_cache_rejected",
                        "original_question": original_prompt[:100] + "..."
                        if len(original_prompt) > 100
                        else original_prompt,
                        "cached_question": best_prompt[:100] + "..."
                        if len(best_prompt) > 100
                        else best_prompt,
                        "similarity_score": round(best_similarity, 4),
                        "threshold": similarity_threshold,
                    }
                    logger.info(json.dumps(log_data, ensure_ascii=False))
                    return None
            except Exception as e:  # Catch all possible exceptions
                logger.warning(f"LLM similarity check failed: {e}")
                return None  # Return None directly when LLM check fails

        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
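The block above imposes a simple contract on llm_func: it must be awaitable and return a string that float() can parse as a similarity score; anything else lands in the except branch and the cached entry is rejected. A throwaway stand-in for exercising this path (purely illustrative, not part of this commit) could be:

async def fake_similarity_llm(compare_prompt: str) -> str:
    # Stand-in for the real llm_model_func: a real model is expected to read
    # the similarity_check prompt and answer with a bare number in [0, 1].
    return "0.97"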
@@ -390,3 +426,84 @@ def dequantize_embedding(
"""Restore quantized embedding"""
scale = (max_val - min_val) / (2**bits - 1)
return (quantized * scale + min_val).astype(np.float32)
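handle_cache below relies on quantize_embedding, which is defined outside this diff. A minimal sketch of an 8-bit min-max quantizer consistent with dequantize_embedding above (an assumption for illustration, not code from this commit):

def quantize_embedding(embedding: np.ndarray, bits: int = 8) -> tuple:
    """Assumed counterpart of dequantize_embedding: min-max quantization to uint8."""
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val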
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None

    # For naive mode, only use simple cache matching
    if mode == "naive":
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None
        return None, None, None, None

    # Get embedding cache configuration
    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config",
        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
    )
    is_embedding_cache_enabled = embedding_cache_config["enabled"]
    use_llm_check = embedding_cache_config.get("use_llm_check", False)

    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use embedding cache
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        llm_model_func = hashing_kv.global_config.get("llm_model_func")

        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config["similarity_threshold"],
            mode=mode,
            use_llm_check=use_llm_check,
            llm_func=llm_model_func if use_llm_check else None,
            original_prompt=prompt if use_llm_check else None,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        # Use regular cache
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None

    return None, quantized, min_val, max_val
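handle_cache reads its settings from global_config["embedding_cache_config"]. A hedged configuration sketch enabling both the embedding cache and the LLM check, assuming the LightRAG constructor accepts an embedding_cache_config dict with exactly the keys read above (the working_dir value is a placeholder):

from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",  # placeholder path
    embedding_cache_config={
        "enabled": True,               # turn on embedding-similarity cache lookups
        "similarity_threshold": 0.95,  # minimum cosine similarity for a candidate hit
        "use_llm_check": True,         # re-verify near-hits via the similarity_check prompt
    },
)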
@dataclass
class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
    # Skip caching when no KV store is configured or when content is a
    # streaming async iterator that cannot be serialized
    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
        return

    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})
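A rough sketch of how these pieces fit together in a query path: check the cache, call the LLM only on a miss, then persist the fresh answer with its quantized embedding. The caller below (hash construction, the llm_model_func parameter) is an assumption for illustration, not code from this commit.

async def cached_llm_call(hashing_kv, llm_model_func, prompt: str, mode: str = "default") -> str:
    # Hypothetical caller illustrating handle_cache + save_to_cache.
    args_hash = md5(f"{mode}:{prompt}".encode()).hexdigest()  # stand-in for the project's own hash helper
    cached, quantized, min_val, max_val = await handle_cache(hashing_kv, args_hash, prompt, mode)
    if cached is not None:
        return cached  # cache hit: skip the LLM entirely

    response = await llm_model_func(prompt)
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return response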