async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Look up a previously cached response for ``prompt``.

    Two strategies are supported, selected by the ``embedding_cache_config``
    entry of ``hashing_kv.global_config``:

    * disabled (the default): exact lookup of ``args_hash`` inside the cache
      bucket stored under ``mode``;
    * enabled: the prompt is embedded and handed to
      ``get_best_cached_response`` together with the configured
      ``similarity_threshold``.

    Args:
        hashing_kv: Key-value storage exposing ``global_config`` and an async
            ``get_by_id``; ``None`` turns caching off entirely.
        args_hash: Hash identifying the request, used for the exact lookup.
        prompt: Prompt text; only embedded when the embedding cache is on.
        mode: Name of the cache bucket to read from.

    Returns:
        A 4-tuple ``(cached_response, quantized, min_val, max_val)``.
        On a hit the response is set and the other three are ``None``. On a
        miss the response is ``None``; with the embedding cache enabled the
        remaining items carry the result of ``quantize_embedding`` for the
        prompt embedding (so the caller can store it), otherwise they are
        ``None`` as well.
    """
    if hashing_kv is None:
        return None, None, None, None

    # Tolerate a missing, ``None`` or partially filled config instead of
    # raising KeyError/TypeError: absent keys fall back to the defaults
    # (cache disabled, threshold 0.95).
    embedding_cache_config = (
        hashing_kv.global_config.get("embedding_cache_config") or {}
    )
    is_embedding_cache_enabled = embedding_cache_config.get("enabled", False)

    if not is_embedding_cache_enabled:
        # Regular cache: exact match on the argument hash within this mode.
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None
        return None, None, None, None

    # Embedding cache: embed the prompt and search for a similar entry.
    embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
    current_embedding = await embedding_model_func([prompt])
    quantized, min_val, max_val = quantize_embedding(current_embedding[0])
    best_cached_response = await get_best_cached_response(
        hashing_kv,
        current_embedding[0],
        similarity_threshold=embedding_cache_config.get(
            "similarity_threshold", 0.95
        ),
        mode=mode,
    )
    if best_cached_response is not None:
        return best_cached_response, None, None, None

    # Miss: hand the quantized embedding back so the caller can persist it.
    return None, quantized, min_val, max_val


@dataclass
class CacheData:
    """Values describing a single cache entry, as consumed by ``save_to_cache``.

    Attributes:
        args_hash: Key of the entry inside its mode bucket.
        content: Response text stored under the ``"return"`` key.
        model: Model identifier stored alongside the response.
        prompt: Original prompt stored alongside the response.
        quantized: Optional quantized embedding of the prompt.
        min_val: Optional lower bound that accompanies ``quantized``.
        max_val: Optional upper bound that accompanies ``quantized``.
        mode: Name of the cache bucket the entry belongs to.
    """

    args_hash: str
    content: str
    model: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
    """Store one response in the cache bucket named by ``cache_data.mode``.

    The bucket is read with ``hashing_kv.get_by_id``, the entry keyed by
    ``cache_data.args_hash`` is added or overwritten, and the whole bucket is
    written back with ``hashing_kv.upsert``. Nothing happens when
    ``hashing_kv`` is ``None``.

    Args:
        hashing_kv: Key-value storage with async ``get_by_id`` and ``upsert``,
            or ``None`` to skip caching.
        cache_data: The values to persist for this entry.
    """
    if hashing_kv is None:
        return

    bucket = await hashing_kv.get_by_id(cache_data.mode) or {}

    # The embedding is stored as a hex string of its raw bytes plus its
    # shape; both are None when no quantized embedding was supplied.
    vector = cache_data.quantized
    if vector is None:
        embedding_hex = None
        embedding_shape = None
    else:
        embedding_hex = vector.tobytes().hex()
        embedding_shape = vector.shape

    entry = {
        "return": cache_data.content,
        "model": cache_data.model,
        "embedding": embedding_hex,
        "embedding_shape": embedding_shape,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }
    bucket[cache_data.args_hash] = entry

    await hashing_kv.upsert({cache_data.mode: bucket})