From f6129857a1ed23cd369f717e5909e7a7f500a5d1 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Thu, 24 Apr 2025 20:03:01 +0800
Subject: [PATCH] Improve quantize and dequantize handling of embedding

---
 lightrag/utils.py | 43 +++++++++++++++++++++++++++++++++----------
 1 file changed, 33 insertions(+), 10 deletions(-)

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 2c18db69..e390d775 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -573,18 +573,32 @@ async def get_best_cached_response(
         if cache_type and cache_data.get("cache_type") != cache_type:
             continue
 
+        # Check if cache data is valid
         if cache_data["embedding"] is None:
             continue
-
-        # Convert cached embedding list to ndarray
-        cached_quantized = np.frombuffer(
-            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-        ).reshape(cache_data["embedding_shape"])
-        cached_embedding = dequantize_embedding(
-            cached_quantized,
-            cache_data["embedding_min"],
-            cache_data["embedding_max"],
-        )
+
+        try:
+            # Safely convert cached embedding
+            cached_quantized = np.frombuffer(
+                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
+            ).reshape(cache_data["embedding_shape"])
+
+            # Ensure min_val and max_val are valid float values
+            embedding_min = cache_data.get("embedding_min")
+            embedding_max = cache_data.get("embedding_max")
+
+            if embedding_min is None or embedding_max is None or embedding_min >= embedding_max:
+                logger.warning(f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}")
+                continue
+
+            cached_embedding = dequantize_embedding(
+                cached_quantized,
+                embedding_min,
+                embedding_max,
+            )
+        except Exception as e:
+            logger.warning(f"Error processing cached embedding: {str(e)}")
+            continue
 
         similarity = cosine_similarity(current_embedding, cached_embedding)
         if similarity > best_similarity:
@@ -668,6 +682,11 @@ def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tu
     min_val = embedding.min()
     max_val = embedding.max()
 
+    if min_val == max_val:
+        # handle constant vector
+        quantized = np.zeros_like(embedding, dtype=np.uint8)
+        return quantized, min_val, max_val
+
     # Quantize to 0-255 range
     scale = (2**bits - 1) / (max_val - min_val)
     quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
@@ -679,6 +698,10 @@ def dequantize_embedding(
     quantized: np.ndarray, min_val: float, max_val: float, bits=8
 ) -> np.ndarray:
     """Restore quantized embedding"""
+    if min_val == max_val:
+        # handle constant vector
+        return np.full_like(quantized, min_val, dtype=np.float32)
+
     scale = (max_val - min_val) / (2**bits - 1)
     return (quantized * scale + min_val).astype(np.float32)
 
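
Note for reviewers (not part of the patch): a minimal round-trip sketch of the behavior this change is meant to guarantee, assuming quantize_embedding and dequantize_embedding are importable from lightrag.utils with the signatures shown above. The hex/frombuffer step mirrors the restore path in get_best_cached_response; the constant vector is the degenerate case that previously led to a division by zero in the scale computation.

    import numpy as np

    from lightrag.utils import dequantize_embedding, quantize_embedding

    # Constant vector: min == max, so the old scale = (2**bits - 1) / (max - min) divided by zero.
    embedding = np.full(8, 0.5, dtype=np.float32)

    quantized, min_val, max_val = quantize_embedding(embedding)
    assert quantized.dtype == np.uint8 and (quantized == 0).all()

    # Simulate the cache round trip: store as a hex string, restore via fromhex/frombuffer/reshape.
    stored_hex = quantized.tobytes().hex()
    restored_quantized = np.frombuffer(
        bytes.fromhex(stored_hex), dtype=np.uint8
    ).reshape(quantized.shape)

    restored = dequantize_embedding(restored_quantized, min_val, max_val)
    assert np.allclose(restored, embedding)

With this patch, a constant embedding quantizes to zeros with min_val == max_val and dequantizes back to the original value. Note that the cache-read validation still skips entries where embedding_min >= embedding_max, so such cached entries are ignored rather than dequantized.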