重构缓存处理逻辑

- 提取通用缓存处理逻辑到新函数 handle_cache 和 save_to_cache
- 使用 CacheData 类统一缓存数据结构
- 优化嵌入式缓存和常规缓存的处理流程
- 添加模式参数以支持不同查询模式的缓存策略
- 重构 get_best_cached_response 函数,提高缓存查询效率
This commit is contained in:
magicyuan876
2024-12-06 14:29:16 +08:00
parent 5dfb74ef2d
commit e619b09c8a
3 changed files with 277 additions and 309 deletions

View File

@@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll):
async def get_best_cached_response(
    hashing_kv,
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
) -> Union[str, None]:
    """Return the cached response most similar to ``current_embedding``.

    Looks up the mode-specific cache bucket stored under ``mode`` in
    ``hashing_kv``, dequantizes each entry's stored embedding, and compares
    it against ``current_embedding`` with cosine similarity.  The best
    match is returned only when its similarity exceeds
    ``similarity_threshold``.

    Args:
        hashing_kv: KV storage exposing an async ``get_by_id`` method;
            each value is a dict mapping cache_id -> cache entry.
        current_embedding: Embedding vector of the incoming query.
        similarity_threshold: Minimum cosine similarity required to
            count as a cache hit.
        mode: Query mode whose cache bucket is searched.

    Returns:
        The cached response string on a hit, otherwise ``None``.
    """
    # Get mode-specific cache; each mode has its own bucket.
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None

    best_similarity = -1
    best_response = None
    best_prompt = None
    best_cache_id = None

    # Only iterate through cache entries for this mode.
    for cache_id, cache_data in mode_cache.items():
        # Skip entries without a stored embedding.  Use .get() so a
        # missing "embedding" key does not raise KeyError (restores the
        # previous `"embedding" not in cache_data` guard).
        if cache_data.get("embedding") is None:
            continue

        # The embedding is stored as a hex string of uint8 quantized
        # values; restore the float vector before comparing.
        cached_quantized = np.frombuffer(
            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
        ).reshape(cache_data["embedding_shape"])
        cached_embedding = dequantize_embedding(
            cached_quantized,
            cache_data["embedding_min"],
            cache_data["embedding_max"],
        )

        similarity = cosine_similarity(current_embedding, cached_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_response = cache_data["return"]
            best_prompt = cache_data["original_prompt"]
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        # Truncate long prompts so the structured log line stays readable.
        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
        log_data = {
            "event": "cache_hit",
            "mode": mode,
            "similarity": round(best_similarity, 4),
            "cache_id": best_cache_id,
            "original_prompt": prompt_display,
        }
        logger.info(json.dumps(log_data))
        return best_response
    return None
def cosine_similarity(v1, v2):