Improve quantize and dequantize handling of embedding

This commit is contained in:
yangdx
2025-04-24 20:03:01 +08:00
parent f595834c60
commit f6129857a1

View File

@@ -573,18 +573,32 @@ async def get_best_cached_response(
if cache_type and cache_data.get("cache_type") != cache_type:
continue
# Check if cache data is valid
if cache_data["embedding"] is None:
continue
# Convert cached embedding list to ndarray
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
cached_embedding = dequantize_embedding(
cached_quantized,
cache_data["embedding_min"],
cache_data["embedding_max"],
)
try:
# Safely convert cached embedding
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
# Ensure min_val and max_val are valid float values
embedding_min = cache_data.get("embedding_min")
embedding_max = cache_data.get("embedding_max")
if embedding_min is None or embedding_max is None or embedding_min >= embedding_max:
logger.warning(f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}")
continue
cached_embedding = dequantize_embedding(
cached_quantized,
embedding_min,
embedding_max,
)
except Exception as e:
logger.warning(f"Error processing cached embedding: {str(e)}")
continue
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
@@ -668,6 +682,11 @@ def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tu
min_val = embedding.min()
max_val = embedding.max()
if min_val == max_val:
# handle constant vector
quantized = np.zeros_like(embedding, dtype=np.uint8)
return quantized, min_val, max_val
# Quantize to 0-255 range
scale = (2**bits - 1) / (max_val - min_val)
quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
@@ -679,6 +698,10 @@ def dequantize_embedding(
quantized: np.ndarray, min_val: float, max_val: float, bits=8
) -> np.ndarray:
"""Restore quantized embedding"""
if min_val == max_val:
# handle constant vector
return np.full_like(quantized, min_val, dtype=np.float32)
scale = (max_val - min_val) / (2**bits - 1)
return (quantized * scale + min_val).astype(np.float32)