Improve quantize and dequantize handling of embedding
This commit is contained in:
@@ -573,18 +573,32 @@ async def get_best_cached_response(
|
||||
if cache_type and cache_data.get("cache_type") != cache_type:
|
||||
continue
|
||||
|
||||
# Check if cache data is valid
|
||||
if cache_data["embedding"] is None:
|
||||
continue
|
||||
|
||||
# Convert cached embedding list to ndarray
|
||||
cached_quantized = np.frombuffer(
|
||||
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
|
||||
).reshape(cache_data["embedding_shape"])
|
||||
cached_embedding = dequantize_embedding(
|
||||
cached_quantized,
|
||||
cache_data["embedding_min"],
|
||||
cache_data["embedding_max"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Safely convert cached embedding
|
||||
cached_quantized = np.frombuffer(
|
||||
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
|
||||
).reshape(cache_data["embedding_shape"])
|
||||
|
||||
# Ensure min_val and max_val are valid float values
|
||||
embedding_min = cache_data.get("embedding_min")
|
||||
embedding_max = cache_data.get("embedding_max")
|
||||
|
||||
if embedding_min is None or embedding_max is None or embedding_min >= embedding_max:
|
||||
logger.warning(f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}")
|
||||
continue
|
||||
|
||||
cached_embedding = dequantize_embedding(
|
||||
cached_quantized,
|
||||
embedding_min,
|
||||
embedding_max,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing cached embedding: {str(e)}")
|
||||
continue
|
||||
|
||||
similarity = cosine_similarity(current_embedding, cached_embedding)
|
||||
if similarity > best_similarity:
|
||||
@@ -668,6 +682,11 @@ def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tu
|
||||
min_val = embedding.min()
|
||||
max_val = embedding.max()
|
||||
|
||||
if min_val == max_val:
|
||||
# handle constant vector
|
||||
quantized = np.zeros_like(embedding, dtype=np.uint8)
|
||||
return quantized, min_val, max_val
|
||||
|
||||
# Quantize to 0-255 range
|
||||
scale = (2**bits - 1) / (max_val - min_val)
|
||||
quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
|
||||
@@ -679,6 +698,10 @@ def dequantize_embedding(
|
||||
quantized: np.ndarray, min_val: float, max_val: float, bits=8
|
||||
) -> np.ndarray:
|
||||
"""Restore quantized embedding"""
|
||||
if min_val == max_val:
|
||||
# handle constant vector
|
||||
return np.full_like(quantized, min_val, dtype=np.float32)
|
||||
|
||||
scale = (max_val - min_val) / (2**bits - 1)
|
||||
return (quantized * scale + min_val).astype(np.float32)
|
||||
|
||||
|
Reference in New Issue
Block a user