From f2a1897b614eeeec2eb1ac244c667c56e9fc7be2 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 10:21:53 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20args=5Fhash=E5=9C=A8?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=B8=B8=E8=A7=84=E7=BC=93=E5=AD=98=E6=97=B6?= =?UTF-8?q?=E5=80=99=E6=89=8D=E8=AE=A1=E7=AE=97=E5=AF=BC=E8=87=B4embedding?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E6=97=B6=E6=B2=A1=E6=9C=89=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 +++++- lightrag/llm.py | 48 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 00612859..a2cbb217 100644 --- a/README.md +++ b/README.md @@ -596,7 +596,11 @@ if __name__ == "__main__": | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | -| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` | +| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters: +- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached. +- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM. + +Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` | ## API Server Implementation diff --git a/lightrag/llm.py b/lightrag/llm.py index 33fdd182..fdfb70a8 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -66,7 +66,11 @@ async def openai_complete_if_cache( messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + # Calculate args_hash only when using cache + args_hash = compute_args_hash(model, messages) + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -86,7 +90,6 @@ async def openai_complete_if_cache( return best_cached_response else: # Use regular cache - args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] @@ -159,7 +162,12 @@ async def azure_openai_complete_if_cache( messages.extend(history_messages) if prompt is not None: messages.append({"role": "user", "content": prompt}) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: + # Calculate args_hash only when using cache + args_hash = compute_args_hash(model, messages) + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -178,7 +186,7 @@ async def azure_openai_complete_if_cache( if best_cached_response is not None: return best_cached_response else: - args_hash = compute_args_hash(model, messages) + # Use regular cache if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] @@ -271,6 +279,9 @@ async def bedrock_complete_if_cache( hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: + # Calculate args_hash only when using cache + args_hash = compute_args_hash(model, messages) + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -290,7 +301,6 @@ async def bedrock_complete_if_cache( return best_cached_response else: # Use regular cache - args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] @@ -343,6 +353,11 @@ def initialize_hf_model(model_name): return hf_model, hf_tokenizer +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) async def hf_model_if_cache( model, prompt, @@ -359,7 +374,11 @@ async def hf_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: + # Calculate args_hash only when using cache + args_hash = compute_args_hash(model, messages) + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -379,7 +398,6 @@ async def hf_model_if_cache( return best_cached_response else: # Use regular cache - args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] @@ -448,6 +466,11 @@ async def hf_model_if_cache( return response_text +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) async def ollama_model_if_cache( model, prompt, @@ -468,7 +491,12 @@ async def ollama_model_if_cache( hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: + # Calculate args_hash only when using cache + args_hash = compute_args_hash(model, messages) + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -488,7 +516,6 @@ async def ollama_model_if_cache( return best_cached_response else: # Use regular cache - args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] @@ -542,6 +569,11 @@ def initialize_lmdeploy_pipeline( return lmdeploy_pipe +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) async def lmdeploy_model_if_cache( model, prompt, @@ -620,7 +652,12 @@ async def lmdeploy_model_if_cache( hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: + # Calculate args_hash only when using cache + args_hash = compute_args_hash(model, messages) + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -640,7 +677,6 @@ async def lmdeploy_model_if_cache( return best_cached_response else: # Use regular cache - args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] From 8a69604966bedc5a1d5e45cfa97b565b5f171443 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 10:28:35 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20args=5Fhash=E5=9C=A8?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=B8=B8=E8=A7=84=E7=BC=93=E5=AD=98=E6=97=B6?= =?UTF-8?q?=E5=80=99=E6=89=8D=E8=AE=A1=E7=AE=97=E5=AF=BC=E8=87=B4embedding?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E6=97=B6=E6=B2=A1=E6=9C=89=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index fdfb70a8..97c903d2 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,12 +1,16 @@ -import os +import base64 import copy -from functools import lru_cache import json +import os +import struct +from functools import lru_cache +from typing import List, Dict, Callable, Any + import aioboto3 import aiohttp import numpy as np import ollama - +import torch from openai import ( AsyncOpenAI, APIConnectionError, @@ -14,10 +18,7 @@ from openai import ( Timeout, AsyncAzureOpenAI, ) - -import base64 -import struct - +from pydantic import BaseModel, Field from tenacity import ( retry, stop_after_attempt, @@ -25,9 +26,7 @@ from tenacity import ( retry_if_exception_type, ) from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -from pydantic import BaseModel, Field -from typing import List, Dict, Callable, Any + from .base import BaseKVStorage from .utils import ( compute_args_hash, @@ -70,7 +69,7 @@ async def openai_complete_if_cache( if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) - + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -167,7 +166,7 @@ async def azure_openai_complete_if_cache( if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) - + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -281,7 +280,7 @@ async def bedrock_complete_if_cache( if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) - + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -378,7 +377,7 @@ async def hf_model_if_cache( if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) - + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -496,7 +495,7 @@ async def ollama_model_if_cache( if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) - + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -657,7 +656,7 @@ async def lmdeploy_model_if_cache( if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) - + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} @@ -867,7 +866,8 @@ async def openai_embedding( ) async def nvidia_openai_embedding( texts: list[str], - model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding + model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", + # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding base_url: str = "https://integrate.api.nvidia.com/v1", api_key: str = None, input_type: str = "passage", # query for retrieval, passage for embedding From 7c4bbe2474a251f1e17e77f54436637303b80311 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 10:40:48 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20args=5Fhash=E5=9C=A8?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=B8=B8=E8=A7=84=E7=BC=93=E5=AD=98=E6=97=B6?= =?UTF-8?q?=E5=80=99=E6=89=8D=E8=AE=A1=E7=AE=97=E5=AF=BC=E8=87=B4embedding?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E6=97=B6=E6=B2=A1=E6=9C=89=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 97c903d2..fef8c9a3 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -162,7 +162,6 @@ async def azure_openai_complete_if_cache( if prompt is not None: messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) @@ -373,7 +372,6 @@ async def hf_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) @@ -491,7 +489,6 @@ async def ollama_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) @@ -652,7 +649,6 @@ async def lmdeploy_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: # Calculate args_hash only when using cache args_hash = compute_args_hash(model, messages) From 0614a936082f2805958394a5530628e30f84cc9c Mon Sep 17 00:00:00 2001 From: Suroy <77138019+zsuroy@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:06:20 +0800 Subject: [PATCH 4/4] Update oracle_impl.py Fixed typing error in python3.9 --- lightrag/kg/oracle_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 8ed73772..34745312 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -143,7 +143,7 @@ class OracleDB: data = None return data - async def execute(self, sql: str, data: list | dict = None): + async def execute(self, sql: str, data: Union[list, dict] = None): # logger.info("go into OracleDB execute method") try: async with self.pool.acquire() as connection: