缓存计算函数迁移到工具类
This commit is contained in:
@@ -9,7 +9,7 @@ import re
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Union, List
|
from typing import Any, Union, List, Optional
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -390,3 +390,71 @@ def dequantize_embedding(
|
|||||||
"""Restore quantized embedding"""
|
"""Restore quantized embedding"""
|
||||||
scale = (max_val - min_val) / (2**bits - 1)
|
scale = (max_val - min_val) / (2**bits - 1)
|
||||||
return (quantized * scale + min_val).astype(np.float32)
|
return (quantized * scale + min_val).astype(np.float32)
|
||||||
|
|
||||||
|
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
|
||||||
|
"""Generic cache handling function"""
|
||||||
|
if hashing_kv is None:
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
# Get embedding cache configuration
|
||||||
|
embedding_cache_config = hashing_kv.global_config.get(
|
||||||
|
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||||
|
)
|
||||||
|
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
||||||
|
|
||||||
|
quantized = min_val = max_val = None
|
||||||
|
if is_embedding_cache_enabled:
|
||||||
|
# Use embedding cache
|
||||||
|
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
||||||
|
current_embedding = await embedding_model_func([prompt])
|
||||||
|
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
||||||
|
best_cached_response = await get_best_cached_response(
|
||||||
|
hashing_kv,
|
||||||
|
current_embedding[0],
|
||||||
|
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
if best_cached_response is not None:
|
||||||
|
return best_cached_response, None, None, None
|
||||||
|
else:
|
||||||
|
# Use regular cache
|
||||||
|
mode_cache = await hashing_kv.get_by_id(mode) or {}
|
||||||
|
if args_hash in mode_cache:
|
||||||
|
return mode_cache[args_hash]["return"], None, None, None
|
||||||
|
|
||||||
|
return None, quantized, min_val, max_val
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheData:
|
||||||
|
args_hash: str
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
quantized: Optional[np.ndarray] = None
|
||||||
|
min_val: Optional[float] = None
|
||||||
|
max_val: Optional[float] = None
|
||||||
|
mode: str = "default"
|
||||||
|
|
||||||
|
|
||||||
|
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
|
if hashing_kv is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
|
||||||
|
|
||||||
|
mode_cache[cache_data.args_hash] = {
|
||||||
|
"return": cache_data.content,
|
||||||
|
"model": cache_data.model,
|
||||||
|
"embedding": cache_data.quantized.tobytes().hex()
|
||||||
|
if cache_data.quantized is not None
|
||||||
|
else None,
|
||||||
|
"embedding_shape": cache_data.quantized.shape
|
||||||
|
if cache_data.quantized is not None
|
||||||
|
else None,
|
||||||
|
"embedding_min": cache_data.min_val,
|
||||||
|
"embedding_max": cache_data.max_val,
|
||||||
|
"original_prompt": cache_data.prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
await hashing_kv.upsert({cache_data.mode: mode_cache})
|
||||||
|
Reference in New Issue
Block a user