From e619b09c8adb99c9636577dc3d8e559913d5c6e4 Mon Sep 17 00:00:00 2001
From: magicyuan876 <317617749@qq.com>
Date: Fri, 6 Dec 2024 14:29:16 +0800
Subject: [PATCH] Refactor cache handling logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Extract the shared cache handling logic into the new functions handle_cache and save_to_cache
- Unify the cached record structure with a CacheData class
- Streamline the embedding-cache and regular-cache flows
- Add a mode parameter so each query mode can use its own cache strategy
- Refactor get_best_cached_response to make cache lookups more efficient
---
 lightrag/llm.py     | 496 ++++++++++++++++++++------------------
 lightrag/operate.py |   6 +-
 lightrag/utils.py   |  84 ++++----
 3 files changed, 277 insertions(+), 309 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index fef8c9a3..89d74a5b 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -4,7 +4,8 @@
 import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Optional
+from dataclasses import dataclass
 import aioboto3
 import aiohttp
@@ -59,39 +60,21 @@ async def openai_complete_if_cache(
     openai_async_client = (
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
@@ -105,24 +88,21 @@
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
+    return content
@@ -155,6 +135,8 @@ async def azure_openai_complete_if_cache(
     )
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    mode = kwargs.pop("mode", "default")
+
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -162,56 +144,35 @@
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
+    content = response.choices[0].message.content

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response.choices[0].message.content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return response.choices[0].message.content
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
+    return content


 class BedrockError(Exception):
@@ -253,6 +214,15 @@ async def bedrock_complete_if_cache(
     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})

+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     # Initialize Converse API arguments
     args = {"modelId": model, "messages": messages}

@@ -275,33 +245,14 @@
             kwargs.pop(param)
         )

-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     # Call model via Converse API
     session = aioboto3.Session()
@@ -311,30 +262,22 @@
     except Exception as e:
         raise BedrockError(e)

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response["output"]["message"]["content"][0]["text"],
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_max": max_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=response["output"]["message"]["content"][0]["text"],
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

-    return response["output"]["message"]["content"][0]["text"]
+    return response["output"]["message"]["content"][0]["text"]


 @lru_cache(maxsize=1)
@@ -372,32 +315,14 @@
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     input_prompt = ""
     try:
@@ -442,24 +367,22 @@
     response_text = hf_tokenizer.decode(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response_text,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response_text,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return response_text


@@ -489,55 +412,34 @@
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)

     result = response["message"]["content"]

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return result
@@ -649,32 +551,14 @@
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -692,24 +576,21 @@
     ):
         response += res.response

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return response


@@ -1139,6 +1020,75 @@ class MultiModel:
         return await next_model.gen_func(**args)


+async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+    """Generic cache handling function"""
+    if hashing_kv is None:
+        return None, None, None, None
+
+    # Get embedding cache configuration
+    embedding_cache_config = hashing_kv.global_config.get(
+        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+    )
+    is_embedding_cache_enabled = embedding_cache_config["enabled"]
+
+    quantized = min_val = max_val = None
+    if is_embedding_cache_enabled:
+        # Use embedding cache
+        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        current_embedding = await embedding_model_func([prompt])
+        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+        best_cached_response = await get_best_cached_response(
+            hashing_kv,
+            current_embedding[0],
+            similarity_threshold=embedding_cache_config["similarity_threshold"],
+            mode=mode,
+        )
+        if best_cached_response is not None:
+            return best_cached_response, None, None, None
+    else:
+        # Use regular cache
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+
+    return None, quantized, min_val, max_val
+
+
+@dataclass
+class CacheData:
+    args_hash: str
+    content: str
+    model: str
+    prompt: str
+    quantized: Optional[np.ndarray] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    mode: str = "default"
+
+
+async def save_to_cache(hashing_kv, cache_data: CacheData):
+    if hashing_kv is None:
+        return
+
+    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+
+    mode_cache[cache_data.args_hash] = {
+        "return": cache_data.content,
+        "model": cache_data.model,
+        "embedding": cache_data.quantized.tobytes().hex()
+        if cache_data.quantized is not None
+        else None,
+        "embedding_shape": cache_data.quantized.shape
+        if cache_data.quantized is not None
+        else None,
+        "embedding_min": cache_data.min_val,
+        "embedding_max": cache_data.max_val,
+        "original_prompt": cache_data.prompt,
+    }
+
+    await hashing_kv.upsert({cache_data.mode: mode_cache})
+
+
 if __name__ == "__main__":
     import asyncio

diff --git a/lightrag/operate.py b/lightrag/operate.py
index a846cfc5..5b911d34 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -474,7 +474,9 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt, keyword_extraction=True)
+    result = await use_model_func(
+        kw_prompt, keyword_extraction=True, mode=query_param.mode
+    )
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -534,6 +536,7 @@
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )
     if len(response) > len(sys_prompt):
         response = (
@@ -1035,6 +1038,7 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )

     if len(response) > len(sys_prompt):
diff --git a/lightrag/utils.py b/lightrag/utils.py
index d080ee03..70ec4341 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll):


 async def get_best_cached_response(
-    hashing_kv, current_embedding, similarity_threshold=0.95
-):
-    """Get the cached response with the highest similarity"""
-    try:
-        # Get all keys
-        all_keys = await hashing_kv.all_keys()
-        max_similarity = 0
-        best_cached_response = None
-
-        # Get cached data one by one
-        for key in all_keys:
-            cache_data = await hashing_kv.get_by_id(key)
-            if cache_data is None or "embedding" not in cache_data:
-                continue
-
-            # Convert cached embedding list to ndarray
-            cached_quantized = np.frombuffer(
-                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-            ).reshape(cache_data["embedding_shape"])
-            cached_embedding = dequantize_embedding(
-                cached_quantized,
-                cache_data["embedding_min"],
-                cache_data["embedding_max"],
-            )
-
-            similarity = cosine_similarity(current_embedding, cached_embedding)
-            if similarity > max_similarity:
-                max_similarity = similarity
-                best_cached_response = cache_data["return"]
-
-        if max_similarity > similarity_threshold:
-            return best_cached_response
+    hashing_kv,
+    current_embedding,
+    similarity_threshold=0.95,
+    mode="default",
+) -> Union[str, None]:
+    # Get mode-specific cache
+    mode_cache = await hashing_kv.get_by_id(mode)
+    if not mode_cache:
         return None
-    except Exception as e:
-        logger.warning(f"Error in get_best_cached_response: {e}")
-        return None

+    best_similarity = -1
+    best_response = None
+    best_prompt = None
+    best_cache_id = None
+
+    # Only iterate through cache entries for this mode
+    for cache_id, cache_data in mode_cache.items():
+        if cache_data["embedding"] is None:
+            continue
+
+        # Convert cached embedding list to ndarray
+        cached_quantized = np.frombuffer(
+            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
+        ).reshape(cache_data["embedding_shape"])
+        cached_embedding = dequantize_embedding(
+            cached_quantized,
+            cache_data["embedding_min"],
+            cache_data["embedding_max"],
+        )
+
+        similarity = cosine_similarity(current_embedding, cached_embedding)
+        if similarity > best_similarity:
+            best_similarity = similarity
+            best_response = cache_data["return"]
+            best_prompt = cache_data["original_prompt"]
+            best_cache_id = cache_id
+
+    if best_similarity > similarity_threshold:
+        prompt_display = (
+            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
+        )
+        log_data = {
+            "event": "cache_hit",
+            "mode": mode,
+            "similarity": round(best_similarity, 4),
+            "cache_id": best_cache_id,
+            "original_prompt": prompt_display,
+        }
+        logger.info(json.dumps(log_data))
+        return best_response
+    return None


 def cosine_similarity(v1, v2):
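Reviewer note: below is a minimal usage sketch of the cache helpers this patch introduces, assuming the patch is applied. The InMemoryKV class is an illustrative stand-in for a real BaseKVStorage backend (it is not part of LightRAG), the embedding cache is left disabled so only the exact-hash path in handle_cache is exercised, and the hash string, model name, and prompt are placeholders.

# Sketch: mode-keyed cache miss -> save -> hit, using the helpers added in lightrag/llm.py.
import asyncio

from lightrag.llm import CacheData, handle_cache, save_to_cache


class InMemoryKV:
    """Toy stand-in for BaseKVStorage: stores one dict entry per query mode."""

    def __init__(self):
        self._data = {}
        # Embedding cache disabled, so handle_cache takes the regular (exact-hash) path.
        self.global_config = {
            "embedding_cache_config": {"enabled": False, "similarity_threshold": 0.95}
        }

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, data: dict):
        self._data.update(data)


async def main():
    kv = InMemoryKV()
    args_hash = "example-hash"  # in LightRAG this comes from compute_args_hash(model, messages)
    prompt = "What is LightRAG?"

    # First lookup: nothing cached yet for mode="local".
    cached, quantized, min_val, max_val = await handle_cache(kv, args_hash, prompt, mode="local")
    assert cached is None

    # Store an answer under the mode-keyed cache entry.
    await save_to_cache(
        kv,
        CacheData(
            args_hash=args_hash,
            content="(cached answer)",
            model="example-model",
            prompt=prompt,
            mode="local",
        ),
    )

    # Second lookup with the same hash and mode hits the cache.
    cached, *_ = await handle_cache(kv, args_hash, prompt, mode="local")
    print(cached)  # "(cached answer)"


asyncio.run(main())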