Merge branch 'main' into fix-entity-name-string

This commit is contained in:
zrguo
2024-12-09 17:30:40 +08:00
committed by GitHub
8 changed files with 406 additions and 289 deletions

View File

@@ -596,11 +596,7 @@ if __name__ == "__main__":
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:
- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
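In code, the new `use_llm_check` switch is just another key in the `embedding_cache_config` dict passed to `LightRAG`. A minimal sketch, assuming `OPENAI_API_KEY` is set so the default model and embedding functions can be used (working directory is an assumption):

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./dickens",
    embedding_cache_config={
        "enabled": True,               # look up cached answers by embedding similarity
        "similarity_threshold": 0.95,  # reuse a cached answer above this similarity
        "use_llm_check": False,        # optionally let the LLM verify the match
    },
)
```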
## API Server Implementation

View File

@@ -11,9 +11,17 @@ net = Network(height="100vh", notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
# Add colors to nodes
# Add colors and title to nodes
for node in net.nodes:
node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
if "description" in node:
node["title"] = node["description"]
# Add title to edges
for edge in net.edges:
if "description" in edge:
edge["title"] = edge["description"]
# Save and display the network
net.show("knowledge_graph.html")
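For context, the surrounding example script looks roughly like this once the hunk above is applied; the GraphML path is an assumption about where LightRAG stores the graph:

```python
import random

import networkx as nx
from pyvis.network import Network

# Assumed location of the stored graph; adjust to your working_dir
G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")

net = Network(height="100vh", notebook=True)
net.from_nx(G)

# Random node colors plus hover titles taken from the stored descriptions
for node in net.nodes:
    node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
    if "description" in node:
        node["title"] = node["description"]

# Edge descriptions become hover titles as well
for edge in net.edges:
    if "description" in edge:
        edge["title"] = edge["description"]

net.show("knowledge_graph.html")
```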

View File

@@ -0,0 +1,114 @@
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.llm import jina_embedding, openai_complete_if_cache
import os
import asyncio
async def embedding_func(texts: list[str]) -> np.ndarray:
return await jina_embedding(texts, api_key="YourJinaAPIKey")
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs,
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024, max_token_size=8192, func=embedding_func
),
)
async def lightraginsert(file_path, semaphore):
async with semaphore:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
with open(file_path, "r", encoding="gbk") as f:
content = f.read()
await rag.ainsert(content)
async def process_files(directory, concurrency_limit):
semaphore = asyncio.Semaphore(concurrency_limit)
tasks = []
for root, dirs, files in os.walk(directory):
for f in files:
file_path = os.path.join(root, f)
if f.startswith("."):
continue
tasks.append(lightraginsert(file_path, semaphore))
await asyncio.gather(*tasks)
async def main():
try:
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=embedding_func,
),
)
# Run insertion inside the already-running event loop
await process_files(WORKING_DIR, concurrency_limit=4)
# Perform naive search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
# Perform local search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
# Perform global search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
)
# Perform hybrid search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
)
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -87,7 +87,11 @@ class LightRAG:
)
# Default not to use embedding cache
embedding_cache_config: dict = field(
default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
kv_storage: str = field(default="JsonKVStorage")
vector_storage: str = field(default="NanoVectorDBStorage")
@@ -174,7 +178,6 @@ class LightRAG:
if self.enable_llm_cache
else None
)
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
@@ -481,6 +484,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache,
)
elif param.mode == "naive":
response = await naive_query(
@@ -489,6 +493,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache,
)
else:
raise ValueError(f"Unknown mode {param.mode}")
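With `hashing_kv=self.llm_response_cache` now threaded into both `kg_query` and `naive_query`, repeated queries can be answered from the cache. A hedged usage sketch; the working directory and default OpenAI-backed model functions are assumptions:

```python
from lightrag import LightRAG, QueryParam

# Sketch only; assumes OPENAI_API_KEY is set so the default model and
# embedding functions can be used.
rag = LightRAG(
    working_dir="./dickens",
    enable_llm_cache=True,
    embedding_cache_config={"enabled": True, "similarity_threshold": 0.95},
)

question = "What are the top themes in this story?"
first = rag.query(question, param=QueryParam(mode="hybrid"))   # answered by the LLM, then cached
second = rag.query(question, param=QueryParam(mode="hybrid"))  # served from llm_response_cache
```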

View File

@@ -4,8 +4,7 @@ import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union, Optional
from typing import List, Dict, Callable, Any, Union
from dataclasses import dataclass
import aioboto3
import aiohttp
import numpy as np
@@ -27,13 +26,9 @@ from tenacity import (
)
from transformers import AutoTokenizer, AutoModelForCausalLM
from .base import BaseKVStorage
from .utils import (
compute_args_hash,
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
quantize_embedding,
get_best_cached_response,
)
import sys
@@ -66,23 +61,13 @@ async def openai_complete_if_cache(
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
if "response_format" in kwargs: if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse( response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs model=model, messages=messages, **kwargs
@@ -95,21 +80,6 @@ async def openai_complete_if_cache(
if r"\u" in content: if r"\u" in content:
content = content.encode("utf-8").decode("unicode_escape") content = content.encode("utf-8").decode("unicode_escape")
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
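Since cache handling has been lifted out of the LLM wrappers, `openai_complete_if_cache` now just builds the message list and calls the API; the `hashing_kv` kwarg is simply discarded. A hedged sketch of a direct call, with the model name and key as placeholders:

```python
import asyncio

from lightrag.llm import openai_complete_if_cache

async def demo() -> None:
    # Caching now happens in kg_query/naive_query, not inside this wrapper.
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",  # placeholder model name
        "Summarize A Christmas Carol in one sentence.",
        system_prompt="You are a concise literary assistant.",
        api_key="sk-...",  # placeholder key
    )
    print(answer)

asyncio.run(demo())
```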
@@ -140,10 +110,7 @@ async def azure_openai_complete_if_cache(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
kwargs.pop("hashing_kv", None)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
mode = kwargs.pop("mode", "default")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -151,34 +118,11 @@ async def azure_openai_complete_if_cache(
if prompt is not None:
messages.append({"role": "user", "content": prompt})
# Handle cache
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
@@ -210,7 +154,7 @@ async def bedrock_complete_if_cache(
os.environ["AWS_SESSION_TOKEN"] = os.environ.get( os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token "AWS_SESSION_TOKEN", aws_session_token
) )
kwargs.pop("hashing_kv", None)
# Fix message history format
messages = []
for history_message in history_messages:
@@ -220,15 +164,6 @@ async def bedrock_complete_if_cache(
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
@@ -251,15 +186,6 @@ async def bedrock_complete_if_cache(
args["inferenceConfig"][inference_params_map.get(param, param)] = ( args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param) kwargs.pop(param)
) )
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
# Call model via Converse API
session = aioboto3.Session()
@@ -269,21 +195,6 @@ async def bedrock_complete_if_cache(
except Exception as e:
raise BedrockError(e)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response["output"]["message"]["content"][0]["text"],
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response["output"]["message"]["content"][0]["text"] return response["output"]["message"]["content"][0]["text"]
@@ -315,22 +226,12 @@ async def hf_model_if_cache(
) -> str:
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
input_prompt = "" input_prompt = ""
try: try:
input_prompt = hf_tokenizer.apply_chat_template( input_prompt = hf_tokenizer.apply_chat_template(
@@ -375,21 +276,6 @@ async def hf_model_if_cache(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response_text,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response_text
@@ -410,25 +296,14 @@ async def ollama_model_if_cache(
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
kwargs.pop("hashing_kv", None)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
if stream:
""" cannot cache stream response """
@@ -439,40 +314,7 @@ async def ollama_model_if_cache(
return inner()
else:
result = response["message"]["content"]
return response["message"]["content"]
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return result
result = response["message"]["content"]
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return result
@lru_cache(maxsize=1)
@@ -547,7 +389,7 @@ async def lmdeploy_model_if_cache(
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
kwargs.pop("hashing_kv", None)
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
@@ -579,19 +421,9 @@ async def lmdeploy_model_if_cache(
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
@@ -607,22 +439,6 @@ async def lmdeploy_model_if_cache(
session_id=1,
):
response += res.response
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response
@@ -767,6 +583,39 @@ async def openai_embedding(
return np.array([dp.embedding for dp in response.data])
async def fetch_data(url, headers, data):
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
response_json = await response.json()
data_list = response_json.get("data", [])
return data_list
async def jina_embedding(
texts: list[str],
dimensions: int = 1024,
late_chunking: bool = False,
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["JINA_API_KEY"] = api_key
url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
}
data = {
"model": "jina-embeddings-v3",
"normalized": True,
"embedding_type": "float",
"dimensions": f"{dimensions}",
"late_chunking": late_chunking,
"input": texts,
}
data_list = await fetch_data(url, headers, data)
return np.array([dp["embedding"] for dp in data_list])
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
stop=stop_after_attempt(3),
@@ -1052,75 +901,6 @@ class MultiModel:
return await next_model.gen_func(**args)
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
"""Generic cache handling function"""
if hashing_kv is None:
return None, None, None, None
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
# Use regular cache
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, quantized, min_val, max_val
@dataclass
class CacheData:
args_hash: str
content: str
model: str
prompt: str
quantized: Optional[np.ndarray] = None
min_val: Optional[float] = None
max_val: Optional[float] = None
mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None:
return
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"model": cache_data.model,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
"embedding_shape": cache_data.quantized.shape
if cache_data.quantized is not None
else None,
"embedding_min": cache_data.min_val,
"embedding_max": cache_data.max_val,
"original_prompt": cache_data.prompt,
}
await hashing_kv.upsert({cache_data.mode: mode_cache})
if __name__ == "__main__":
import asyncio

View File

@@ -17,6 +17,10 @@ from .utils import (
split_string_by_multi_markers,
truncate_list_by_token_size,
process_combine_contexts,
compute_args_hash,
handle_cache,
save_to_cache,
CacheData,
)
from .base import (
BaseGraphStorage,
@@ -452,8 +456,17 @@ async def kg_query(
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
context = None
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(
@@ -471,12 +484,9 @@ async def kg_query(
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
# LLM generate keywords # LLM generate keywords
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language) kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func( result = await use_model_func(kw_prompt, keyword_extraction=True)
kw_prompt, keyword_extraction=True, mode=query_param.mode
)
logger.info("kw_prompt result:") logger.info("kw_prompt result:")
print(result) print(result)
try: try:
@@ -537,7 +547,6 @@ async def kg_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
mode=query_param.mode,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
@@ -550,6 +559,19 @@ async def kg_query(
.strip()
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
),
)
return response
@@ -967,23 +989,37 @@ async def _find_related_text_unit_from_relationships(
for index, unit_list in enumerate(text_units):
for c_id in unit_list:
if c_id not in all_text_units_lookup:
all_text_units_lookup[c_id] = {
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
}
chunk_data = await text_chunks_db.get_by_id(c_id)
# Only store valid data
if chunk_data is not None and "content" in chunk_data:
all_text_units_lookup[c_id] = {
"data": chunk_data,
"order": index,
}
if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
]
if not all_text_units_lookup:
logger.warning("No valid text chunks found")
return []
all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
all_text_units = truncate_list_by_token_size(
all_text_units,
# Ensure all text chunks have content
valid_text_units = [
t for t in all_text_units if t["data"] is not None and "content" in t["data"]
]
if not valid_text_units:
logger.warning("No valid text chunks after filtering")
return []
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
)
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
return all_text_units
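The same defensive pattern introduced here, dropping missing or content-less chunks before token-based truncation, can be expressed on its own; a small sketch (the function name is illustrative):

```python
from lightrag.utils import truncate_list_by_token_size

def keep_valid_chunks(chunks: list, max_tokens: int) -> list:
    """Filter out None entries and chunks without 'content', then truncate by tokens."""
    valid = [c for c in chunks if c is not None and "content" in c]
    if not valid:
        return []
    return truncate_list_by_token_size(
        valid,
        key=lambda c: c["content"],
        max_token_size=max_tokens,
    )
```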
@@ -1013,29 +1049,57 @@ async def naive_query(
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
):
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response
results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
# Filter out invalid chunks
valid_chunks = [
chunk for chunk in chunks if chunk is not None and "content" in chunk
]
if not valid_chunks:
logger.warning("No valid chunks found after filtering")
return PROMPTS["fail_response"]
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)
if not maybe_trun_chunks:
logger.warning("No chunks left after truncation")
return PROMPTS["fail_response"]
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context: if query_param.only_need_context:
return section return section
sys_prompt_temp = PROMPTS["naive_rag_response"] sys_prompt_temp = PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format( sys_prompt = sys_prompt_temp.format(
content_data=section, response_type=query_param.response_type content_data=section, response_type=query_param.response_type
) )
if query_param.only_need_prompt: if query_param.only_need_prompt:
return sys_prompt return sys_prompt
response = await use_model_func( response = await use_model_func(
query, query,
system_prompt=sys_prompt, system_prompt=sys_prompt,
@@ -1054,4 +1118,18 @@ async def naive_query(
.strip()
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
),
)
return response

View File

@@ -261,3 +261,22 @@ Do not include information where the supporting evidence for it is not provided.
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
"""
PROMPTS[
"similarity_check"
] = """Please analyze the similarity between these two questions:
Question 1: {original_prompt}
Question 2: {cached_prompt}
Please evaluate:
1. Whether these two questions are semantically similar
2. Whether the answer to Question 2 can be used to answer Question 1
Please provide a similarity score between 0 and 1, where:
0: Completely unrelated or answer cannot be reused
1: Identical and answer can be directly reused
0.5: Partially related and answer needs modification to be used
Return only a number between 0-1, without any additional content.
"""

View File

@@ -9,12 +9,14 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List
from typing import Any, Union, List, Optional
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
ENCODER = None
logger = logging.getLogger("lightrag")
@@ -314,6 +316,9 @@ async def get_best_cached_response(
current_embedding,
similarity_threshold=0.95,
mode="default",
use_llm_check=False,
llm_func=None,
original_prompt=None,
) -> Union[str, None]:
# Get mode-specific cache
mode_cache = await hashing_kv.get_by_id(mode)
@@ -348,6 +353,37 @@ async def get_best_cached_response(
best_cache_id = cache_id
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if use_llm_check and llm_func and original_prompt and best_prompt:
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)
try:
llm_result = await llm_func(compare_prompt)
llm_result = llm_result.strip()
llm_similarity = float(llm_result)
# Replace vector similarity with LLM similarity score
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "llm_check_cache_rejected",
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
"cached_question": best_prompt[:100] + "..."
if len(best_prompt) > 100
else best_prompt,
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.info(json.dumps(log_data, ensure_ascii=False))
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
return None # Return None directly when LLM check fails
prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
@@ -390,3 +426,84 @@ def dequantize_embedding(
"""Restore quantized embedding""" """Restore quantized embedding"""
scale = (max_val - min_val) / (2**bits - 1) scale = (max_val - min_val) / (2**bits - 1)
return (quantized * scale + min_val).astype(np.float32) return (quantized * scale + min_val).astype(np.float32)
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
"""Generic cache handling function"""
if hashing_kv is None:
return None, None, None, None
# For naive mode, only use simple cache matching
if mode == "naive":
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, None, None, None
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
llm_model_func = hashing_kv.global_config.get("llm_model_func")
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt if use_llm_check else None,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
# Use regular cache
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, quantized, min_val, max_val
@dataclass
class CacheData:
args_hash: str
content: str
prompt: str
quantized: Optional[np.ndarray] = None
min_val: Optional[float] = None
max_val: Optional[float] = None
mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None:
return
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
"embedding_shape": cache_data.quantized.shape
if cache_data.quantized is not None
else None,
"embedding_min": cache_data.min_val,
"embedding_max": cache_data.max_val,
"original_prompt": cache_data.prompt,
}
await hashing_kv.upsert({cache_data.mode: mode_cache})
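Taken together, `handle_cache`, `CacheData`, and `save_to_cache` form a small round trip over any KV object exposing `get_by_id`/`upsert` plus a `global_config`. A minimal sketch with an in-memory stand-in (the `FakeKV` class is illustrative, not part of LightRAG):

```python
import asyncio

from lightrag.utils import CacheData, compute_args_hash, handle_cache, save_to_cache

class FakeKV:
    """In-memory stand-in for the hashing KV storage (illustrative only)."""

    def __init__(self, global_config: dict):
        self.global_config = global_config
        self._data: dict = {}

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, data: dict):
        self._data.update(data)

async def demo() -> None:
    kv = FakeKV({"embedding_cache_config": {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}})
    prompt = "What are the top themes in this story?"
    args_hash = compute_args_hash("local", prompt)

    cached, quantized, min_val, max_val = await handle_cache(kv, args_hash, prompt, mode="local")
    assert cached is None  # nothing stored yet

    await save_to_cache(
        kv,
        CacheData(args_hash=args_hash, content="Generosity, redemption, family.", prompt=prompt, mode="local"),
    )

    cached, *_ = await handle_cache(kv, args_hash, prompt, mode="local")
    print(cached)  # "Generosity, redemption, family."

asyncio.run(demo())
```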