diff --git a/README.md b/README.md index a2cbb217..a1454792 100644 --- a/README.md +++ b/README.md @@ -596,11 +596,7 @@ if __name__ == "__main__": | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | -| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters: -- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached. -- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM. - -Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` | +| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | ## API Server Implementation diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index 11279b3a..56642185 100644 --- a/examples/graph_visual_with_html.py +++ b/examples/graph_visual_with_html.py @@ -11,9 +11,17 @@ net = Network(height="100vh", notebook=True) # Convert NetworkX graph to Pyvis network net.from_nx(G) -# Add colors to nodes + +# Add colors and title to nodes for node in net.nodes: node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) + if "description" in node: + node["title"] = node["description"] + +# Add title to edges +for edge in net.edges: + if "description" in edge: + edge["title"] = edge["description"] # Save and display the network net.show("knowledge_graph.html") diff --git a/examples/lightrag_jinaai_demo.py b/examples/lightrag_jinaai_demo.py new file mode 100644 index 00000000..4daead75 --- /dev/null +++ b/examples/lightrag_jinaai_demo.py @@ -0,0 +1,114 @@ +import numpy as np +from lightrag import LightRAG, QueryParam +from lightrag.utils import EmbeddingFunc +from lightrag.llm import jina_embedding, openai_complete_if_cache +import os +import asyncio + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await jina_embedding(texts, api_key="YourJinaAPIKey") + + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "solar-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("UPSTAGE_API_KEY"), + base_url="https://api.upstage.ai/v1/solar", + **kwargs, + ) + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1024, max_token_size=8192, func=embedding_func + ), +) + + +async def lightraginsert(file_path, semaphore): + async with semaphore: + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + except UnicodeDecodeError: + # If UTF-8 decoding fails, try other encodings + with open(file_path, "r", encoding="gbk") as f: + content = f.read() + await rag.ainsert(content) + + +async def process_files(directory, concurrency_limit): + semaphore = asyncio.Semaphore(concurrency_limit) + tasks = [] + for root, dirs, files in os.walk(directory): + for f in files: + file_path = os.path.join(root, f) + if f.startswith("."): + continue + tasks.append(lightraginsert(file_path, semaphore)) + await asyncio.gather(*tasks) + + +async def main(): + try: + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1024, + max_token_size=8192, + func=embedding_func, + ), + ) + + asyncio.run(process_files(WORKING_DIR, concurrency_limit=4)) + + # Perform naive search + print( + await rag.aquery( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) + + # Perform local search + print( + await rag.aquery( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + # Perform global search + print( + await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode="global"), + ) + ) + + # Perform hybrid search + print( + await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid"), + ) + ) + except Exception as e: + print(f"An error occurred: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0a44187e..0eb1b27e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -87,7 +87,11 @@ class LightRAG: ) # Default not to use embedding cache embedding_cache_config: dict = field( - default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95} + default_factory=lambda: { + "enabled": False, + "similarity_threshold": 0.95, + "use_llm_check": False, + } ) kv_storage: str = field(default="JsonKVStorage") vector_storage: str = field(default="NanoVectorDBStorage") @@ -174,7 +178,6 @@ class LightRAG: if self.enable_llm_cache else None ) - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) @@ -481,6 +484,7 @@ class LightRAG: self.text_chunks, param, asdict(self), + hashing_kv=self.llm_response_cache, ) elif param.mode == "naive": response = await naive_query( @@ -489,6 +493,7 @@ class LightRAG: self.text_chunks, param, asdict(self), + hashing_kv=self.llm_response_cache, ) else: raise ValueError(f"Unknown mode {param.mode}") diff --git a/lightrag/llm.py b/lightrag/llm.py index 63913c90..53626b76 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -4,8 +4,7 @@ import json import os import struct from functools import lru_cache -from typing import List, Dict, Callable, Any, Union, Optional -from dataclasses import dataclass +from typing import List, Dict, Callable, Any, Union import aioboto3 import aiohttp import numpy as np @@ -27,13 +26,9 @@ from tenacity import ( ) from transformers import AutoTokenizer, AutoModelForCausalLM -from .base import BaseKVStorage from .utils import ( - compute_args_hash, wrap_embedding_func_with_attrs, locate_json_string_body_from_string, - quantize_embedding, - get_best_cached_response, ) import sys @@ -66,23 +61,13 @@ async def openai_complete_if_cache( openai_async_client = ( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) - + kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - if "response_format" in kwargs: response = await openai_async_client.beta.chat.completions.parse( model=model, messages=messages, **kwargs @@ -95,21 +80,6 @@ async def openai_complete_if_cache( if r"\u" in content: content = content.encode("utf-8").decode("unicode_escape") - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=content, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return content @@ -140,10 +110,7 @@ async def azure_openai_complete_if_cache( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) - - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - mode = kwargs.pop("mode", "default") - + kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -151,34 +118,11 @@ async def azure_openai_complete_if_cache( if prompt is not None: messages.append({"role": "user", "content": prompt}) - # Handle cache - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) content = response.choices[0].message.content - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=content, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return content @@ -210,7 +154,7 @@ async def bedrock_complete_if_cache( os.environ["AWS_SESSION_TOKEN"] = os.environ.get( "AWS_SESSION_TOKEN", aws_session_token ) - + kwargs.pop("hashing_kv", None) # Fix message history format messages = [] for history_message in history_messages: @@ -220,15 +164,6 @@ async def bedrock_complete_if_cache( # Add user prompt messages.append({"role": "user", "content": [{"text": prompt}]}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response # Initialize Converse API arguments args = {"modelId": model, "messages": messages} @@ -251,15 +186,6 @@ async def bedrock_complete_if_cache( args["inferenceConfig"][inference_params_map.get(param, param)] = ( kwargs.pop(param) ) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response # Call model via Converse API session = aioboto3.Session() @@ -269,21 +195,6 @@ async def bedrock_complete_if_cache( except Exception as e: raise BedrockError(e) - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response["output"]["message"]["content"][0]["text"], - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response["output"]["message"]["content"][0]["text"] @@ -315,22 +226,12 @@ async def hf_model_if_cache( ) -> str: model_name = model hf_model, hf_tokenizer = initialize_hf_model(model_name) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - + kwargs.pop("hashing_kv", None) input_prompt = "" try: input_prompt = hf_tokenizer.apply_chat_template( @@ -375,21 +276,6 @@ async def hf_model_if_cache( output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True ) - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response_text, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response_text @@ -410,25 +296,14 @@ async def ollama_model_if_cache( # kwargs.pop("response_format", None) # allow json host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) - + kwargs.pop("hashing_kv", None) ollama_client = ollama.AsyncClient(host=host, timeout=timeout) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) - - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - response = await ollama_client.chat(model=model, messages=messages, **kwargs) if stream: """ cannot cache stream response """ @@ -439,40 +314,7 @@ async def ollama_model_if_cache( return inner() else: - result = response["message"]["content"] - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=result, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return result - result = response["message"]["content"] - - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=result, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - - return result + return response["message"]["content"] @lru_cache(maxsize=1) @@ -547,7 +389,7 @@ async def lmdeploy_model_if_cache( from lmdeploy import version_info, GenerationConfig except Exception: raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") - + kwargs.pop("hashing_kv", None) kwargs.pop("response_format", None) max_new_tokens = kwargs.pop("max_tokens", 512) tp = kwargs.pop("tp", 1) @@ -579,19 +421,9 @@ async def lmdeploy_model_if_cache( if system_prompt: messages.append({"role": "system", "content": system_prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - gen_config = GenerationConfig( skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, @@ -607,22 +439,6 @@ async def lmdeploy_model_if_cache( session_id=1, ): response += res.response - - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response @@ -767,6 +583,39 @@ async def openai_embedding( return np.array([dp.embedding for dp in response.data]) +async def fetch_data(url, headers, data): + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=data) as response: + response_json = await response.json() + data_list = response_json.get("data", []) + return data_list + + +async def jina_embedding( + texts: list[str], + dimensions: int = 1024, + late_chunking: bool = False, + base_url: str = None, + api_key: str = None, +) -> np.ndarray: + if api_key: + os.environ["JINA_API_KEY"] = api_key + url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ['JINA_API_KEY']}", + } + data = { + "model": "jina-embeddings-v3", + "normalized": True, + "embedding_type": "float", + "dimensions": f"{dimensions}", + "late_chunking": late_chunking, + "input": texts, + } + data_list = await fetch_data(url, headers, data) + return np.array([dp["embedding"] for dp in data_list]) + @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512) @retry( stop=stop_after_attempt(3), @@ -1052,75 +901,6 @@ class MultiModel: return await next_model.gen_func(**args) -async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): - """Generic cache handling function""" - if hashing_kv is None: - return None, None, None, None - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - - quantized = min_val = max_val = None - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - mode=mode, - ) - if best_cached_response is not None: - return best_cached_response, None, None, None - else: - # Use regular cache - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - return mode_cache[args_hash]["return"], None, None, None - - return None, quantized, min_val, max_val - - -@dataclass -class CacheData: - args_hash: str - content: str - model: str - prompt: str - quantized: Optional[np.ndarray] = None - min_val: Optional[float] = None - max_val: Optional[float] = None - mode: str = "default" - - -async def save_to_cache(hashing_kv, cache_data: CacheData): - if hashing_kv is None: - return - - mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} - - mode_cache[cache_data.args_hash] = { - "return": cache_data.content, - "model": cache_data.model, - "embedding": cache_data.quantized.tobytes().hex() - if cache_data.quantized is not None - else None, - "embedding_shape": cache_data.quantized.shape - if cache_data.quantized is not None - else None, - "embedding_min": cache_data.min_val, - "embedding_max": cache_data.max_val, - "original_prompt": cache_data.prompt, - } - - await hashing_kv.upsert({cache_data.mode: mode_cache}) - - if __name__ == "__main__": import asyncio diff --git a/lightrag/operate.py b/lightrag/operate.py index 61c8058c..72734867 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -17,6 +17,10 @@ from .utils import ( split_string_by_multi_markers, truncate_list_by_token_size, process_combine_contexts, + compute_args_hash, + handle_cache, + save_to_cache, + CacheData, ) from .base import ( BaseGraphStorage, @@ -452,8 +456,17 @@ async def kg_query( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: - context = None + # Handle cache + use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): examples = "\n".join( @@ -471,12 +484,9 @@ async def kg_query( return PROMPTS["fail_response"] # LLM generate keywords - use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language) - result = await use_model_func( - kw_prompt, keyword_extraction=True, mode=query_param.mode - ) + result = await use_model_func(kw_prompt, keyword_extraction=True) logger.info("kw_prompt result:") print(result) try: @@ -537,7 +547,6 @@ async def kg_query( query, system_prompt=sys_prompt, stream=query_param.stream, - mode=query_param.mode, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( @@ -550,6 +559,19 @@ async def kg_query( .strip() ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) return response @@ -967,23 +989,37 @@ async def _find_related_text_unit_from_relationships( for index, unit_list in enumerate(text_units): for c_id in unit_list: if c_id not in all_text_units_lookup: - all_text_units_lookup[c_id] = { - "data": await text_chunks_db.get_by_id(c_id), - "order": index, - } + chunk_data = await text_chunks_db.get_by_id(c_id) + # Only store valid data + if chunk_data is not None and "content" in chunk_data: + all_text_units_lookup[c_id] = { + "data": chunk_data, + "order": index, + } - if any([v is None for v in all_text_units_lookup.values()]): - logger.warning("Text chunks are missing, maybe the storage is damaged") - all_text_units = [ - {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None - ] + if not all_text_units_lookup: + logger.warning("No valid text chunks found") + return [] + + all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()] all_text_units = sorted(all_text_units, key=lambda x: x["order"]) - all_text_units = truncate_list_by_token_size( - all_text_units, + + # Ensure all text chunks have content + valid_text_units = [ + t for t in all_text_units if t["data"] is not None and "content" in t["data"] + ] + + if not valid_text_units: + logger.warning("No valid text chunks after filtering") + return [] + + truncated_text_units = truncate_list_by_token_size( + valid_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, ) - all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units] + + all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] return all_text_units @@ -1013,29 +1049,57 @@ async def naive_query( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, global_config: dict, + hashing_kv: BaseKVStorage = None, ): + # Handle cache use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + results = await chunks_vdb.query(query, top_k=query_param.top_k) if not len(results): return PROMPTS["fail_response"] + chunks_ids = [r["id"] for r in results] chunks = await text_chunks_db.get_by_ids(chunks_ids) + # Filter out invalid chunks + valid_chunks = [ + chunk for chunk in chunks if chunk is not None and "content" in chunk + ] + + if not valid_chunks: + logger.warning("No valid chunks found after filtering") + return PROMPTS["fail_response"] + maybe_trun_chunks = truncate_list_by_token_size( - chunks, + valid_chunks, key=lambda x: x["content"], max_token_size=query_param.max_token_for_text_unit, ) + + if not maybe_trun_chunks: + logger.warning("No chunks left after truncation") + return PROMPTS["fail_response"] + logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) + if query_param.only_need_context: return section + sys_prompt_temp = PROMPTS["naive_rag_response"] sys_prompt = sys_prompt_temp.format( content_data=section, response_type=query_param.response_type ) + if query_param.only_need_prompt: return sys_prompt + response = await use_model_func( query, system_prompt=sys_prompt, @@ -1054,4 +1118,18 @@ async def naive_query( .strip() ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response diff --git a/lightrag/prompt.py b/lightrag/prompt.py index d758397b..863d38dc 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -261,3 +261,22 @@ Do not include information where the supporting evidence for it is not provided. Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. """ + +PROMPTS[ + "similarity_check" +] = """Please analyze the similarity between these two questions: + +Question 1: {original_prompt} +Question 2: {cached_prompt} + +Please evaluate: +1. Whether these two questions are semantically similar +2. Whether the answer to Question 2 can be used to answer Question 1 + +Please provide a similarity score between 0 and 1, where: +0: Completely unrelated or answer cannot be reused +1: Identical and answer can be directly reused +0.5: Partially related and answer needs modification to be used + +Return only a number between 0-1, without any additional content. +""" diff --git a/lightrag/utils.py b/lightrag/utils.py index 4c8d7996..49a7b498 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -9,12 +9,14 @@ import re from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union, List +from typing import Any, Union, List, Optional import xml.etree.ElementTree as ET import numpy as np import tiktoken +from lightrag.prompt import PROMPTS + ENCODER = None logger = logging.getLogger("lightrag") @@ -314,6 +316,9 @@ async def get_best_cached_response( current_embedding, similarity_threshold=0.95, mode="default", + use_llm_check=False, + llm_func=None, + original_prompt=None, ) -> Union[str, None]: # Get mode-specific cache mode_cache = await hashing_kv.get_by_id(mode) @@ -348,6 +353,37 @@ async def get_best_cached_response( best_cache_id = cache_id if best_similarity > similarity_threshold: + # If LLM check is enabled and all required parameters are provided + if use_llm_check and llm_func and original_prompt and best_prompt: + compare_prompt = PROMPTS["similarity_check"].format( + original_prompt=original_prompt, cached_prompt=best_prompt + ) + + try: + llm_result = await llm_func(compare_prompt) + llm_result = llm_result.strip() + llm_similarity = float(llm_result) + + # Replace vector similarity with LLM similarity score + best_similarity = llm_similarity + if best_similarity < similarity_threshold: + log_data = { + "event": "llm_check_cache_rejected", + "original_question": original_prompt[:100] + "..." + if len(original_prompt) > 100 + else original_prompt, + "cached_question": best_prompt[:100] + "..." + if len(best_prompt) > 100 + else best_prompt, + "similarity_score": round(best_similarity, 4), + "threshold": similarity_threshold, + } + logger.info(json.dumps(log_data, ensure_ascii=False)) + return None + except Exception as e: # Catch all possible exceptions + logger.warning(f"LLM similarity check failed: {e}") + return None # Return None directly when LLM check fails + prompt_display = ( best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt ) @@ -390,3 +426,84 @@ def dequantize_embedding( """Restore quantized embedding""" scale = (max_val - min_val) / (2**bits - 1) return (quantized * scale + min_val).astype(np.float32) + + +async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): + """Generic cache handling function""" + if hashing_kv is None: + return None, None, None, None + + # For naive mode, only use simple cache matching + if mode == "naive": + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + return None, None, None, None + + # Get embedding cache configuration + embedding_cache_config = hashing_kv.global_config.get( + "embedding_cache_config", + {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}, + ) + is_embedding_cache_enabled = embedding_cache_config["enabled"] + use_llm_check = embedding_cache_config.get("use_llm_check", False) + + quantized = min_val = max_val = None + if is_embedding_cache_enabled: + # Use embedding cache + embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + llm_model_func = hashing_kv.global_config.get("llm_model_func") + + current_embedding = await embedding_model_func([prompt]) + quantized, min_val, max_val = quantize_embedding(current_embedding[0]) + best_cached_response = await get_best_cached_response( + hashing_kv, + current_embedding[0], + similarity_threshold=embedding_cache_config["similarity_threshold"], + mode=mode, + use_llm_check=use_llm_check, + llm_func=llm_model_func if use_llm_check else None, + original_prompt=prompt if use_llm_check else None, + ) + if best_cached_response is not None: + return best_cached_response, None, None, None + else: + # Use regular cache + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + + return None, quantized, min_val, max_val + + +@dataclass +class CacheData: + args_hash: str + content: str + prompt: str + quantized: Optional[np.ndarray] = None + min_val: Optional[float] = None + max_val: Optional[float] = None + mode: str = "default" + + +async def save_to_cache(hashing_kv, cache_data: CacheData): + if hashing_kv is None: + return + + mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + + mode_cache[cache_data.args_hash] = { + "return": cache_data.content, + "embedding": cache_data.quantized.tobytes().hex() + if cache_data.quantized is not None + else None, + "embedding_shape": cache_data.quantized.shape + if cache_data.quantized is not None + else None, + "embedding_min": cache_data.min_val, + "embedding_max": cache_data.max_val, + "original_prompt": cache_data.prompt, + } + + await hashing_kv.upsert({cache_data.mode: mode_cache})