Fix LLM cache so it now works for node and edge merging

yangdx
2025-04-10 03:57:36 +08:00
parent ad087073aa
commit 8d858da4d0
3 changed files with 135 additions and 66 deletions


@@ -1 +1 @@
__api_version__ = "0142"
__api_version__ = "0143"


@@ -24,8 +24,8 @@ from .utils import (
handle_cache,
save_to_cache,
CacheData,
statistic_data,
get_conversation_turns,
use_llm_func_with_cache,
)
from .base import (
BaseGraphStorage,
@@ -108,6 +108,7 @@ async def _handle_entity_relation_summary(
global_config: dict,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
) -> str:
"""Handle entity relation summary
For each entity or relation, the input is the already existing description combined with the new description.
@@ -125,13 +126,6 @@ async def _handle_entity_relation_summary(
if len(tokens) < summary_max_tokens: # No need for summary
return description
# Update pipeline status when LLM summary is needed
status_message = "Use LLM to re-summary description..."
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = decode_tokens_by_tiktoken(
tokens[:llm_max_tokens], model_name=tiktoken_model_name
@@ -143,7 +137,23 @@ async def _handle_entity_relation_summary(
)
use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
# Update pipeline status when LLM summary is needed
status_message = "Use LLM to re-summary description..."
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
# Use LLM function with cache
summary = await use_llm_func_with_cache(
use_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
max_tokens=summary_max_tokens,
cache_type="extract",
)
return summary
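For illustration, a rough sketch of what the cached summary path means across repeated runs. The entity name and description below are made up, and the example assumes the merged description exceeds summary_max_tokens (otherwise the function returns it unchanged without calling the LLM):
async def _demo_cached_summary(
    global_config, pipeline_status, pipeline_status_lock, llm_response_cache
):
    # Hypothetical merged description produced by node merging (made-up content,
    # repeated so it is long enough to trigger the LLM summary path)
    merged_description = " ".join(
        ["Alice is a researcher focused on knowledge graphs."] * 200
    )
    # First run: the LLM is called and the result is stored under the prompt hash
    first = await _handle_entity_relation_summary(
        "Alice", merged_description, global_config,
        pipeline_status, pipeline_status_lock, llm_response_cache,
    )
    # Second run with the same merged description: handle_cache finds the hash,
    # so the summary comes straight from llm_response_cache with no LLM call
    second = await _handle_entity_relation_summary(
        "Alice", merged_description, global_config,
        pipeline_status, pipeline_status_lock, llm_response_cache,
    )
    assert first == second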
@@ -224,6 +234,7 @@ async def _merge_nodes_then_upsert(
global_config: dict,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
):
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
already_entity_types = []
@@ -269,7 +280,12 @@ async def _merge_nodes_then_upsert(
logger.debug(f"file_path: {file_path}")
description = await _handle_entity_relation_summary(
entity_name, description, global_config, pipeline_status, pipeline_status_lock
entity_name,
description,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
node_data = dict(
entity_id=entity_name,
@@ -294,6 +310,7 @@ async def _merge_edges_then_upsert(
global_config: dict,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
):
already_weights = []
already_source_ids = []
@@ -393,6 +410,7 @@ async def _merge_edges_then_upsert(
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
await knowledge_graph_inst.upsert_edge(
src_id,
@@ -428,11 +446,9 @@ async def extract_entities(
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
) -> None:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
"enable_llm_cache_for_entity_extract"
]
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
@@ -483,51 +499,7 @@ async def extract_entities(
graph_db_lock = get_graph_db_lock(enable_logging=False)
async def _user_llm_func_with_cache(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
if history_messages:
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
# TODO add cache_type="extract"
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache,
arg_hash,
_prompt,
"default",
cache_type="extract",
)
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
return cached_return
statistic_data["llm_call"] += 1
if history_messages:
res: str = await use_llm_func(
input_text, history_messages=history_messages
)
else:
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(
args_hash=arg_hash,
content=res,
prompt=_prompt,
cache_type="extract",
),
)
return res
if history_messages:
return await use_llm_func(input_text, history_messages=history_messages)
else:
return await use_llm_func(input_text)
# Use the global use_llm_func_with_cache function from utils.py
async def _process_extraction_result(
result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -592,7 +564,12 @@ async def extract_entities(
**context_base, input_text="{input_text}"
).format(**context_base, input_text=content)
final_result = await _user_llm_func_with_cache(hint_prompt)
final_result = await use_llm_func_with_cache(
hint_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
cache_type="extract",
)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
# Process initial extraction with file path
@@ -602,8 +579,12 @@ async def extract_entities(
# Process additional gleaning results
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await _user_llm_func_with_cache(
continue_prompt, history_messages=history
glean_result = await use_llm_func_with_cache(
continue_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
history_messages=history,
cache_type="extract",
)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -622,8 +603,12 @@ async def extract_entities(
if now_glean_index == entity_extract_max_gleaning - 1:
break
if_loop_result: str = await _user_llm_func_with_cache(
if_loop_prompt, history_messages=history
if_loop_result: str = await use_llm_func_with_cache(
if_loop_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
history_messages=history,
cache_type="extract",
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
@@ -653,6 +638,7 @@ async def extract_entities(
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
chunk_entities_data.append(entity_data)
@@ -668,6 +654,7 @@ async def extract_entities(
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
chunk_relationships_data.append(edge_data)


@@ -12,13 +12,17 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable
from typing import Any, Callable, TYPE_CHECKING
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
# Use TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
from lightrag.base import BaseKVStorage
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
@@ -908,6 +912,84 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
return import_class
async def use_llm_func_with_cache(
input_text: str,
use_llm_func: callable,
llm_response_cache: 'BaseKVStorage | None' = None,
max_tokens: int = None,
history_messages: list[dict[str, str]] = None,
cache_type: str = "extract"
) -> str:
"""Call LLM function with cache support
If cache is available and enabled (determined by handle_cache based on mode),
retrieve result from cache; otherwise call LLM function and save result to cache.
Args:
input_text: Input text to send to LLM
use_llm_func: LLM function to call
llm_response_cache: Cache storage instance
max_tokens: Maximum tokens for generation
history_messages: History messages list
cache_type: Type of cache
Returns:
LLM response text
"""
if llm_response_cache:
if history_messages:
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache,
arg_hash,
_prompt,
"default",
cache_type=cache_type,
)
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
return cached_return
statistic_data["llm_call"] += 1
# Call LLM
kwargs = {}
if history_messages:
kwargs["history_messages"] = history_messages
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func(input_text, **kwargs)
# Save to cache
logger.info(f"Saving LLM cache for {arg_hash}")
await save_to_cache(
llm_response_cache,
CacheData(
args_hash=arg_hash,
content=res,
prompt=_prompt,
cache_type=cache_type,
),
)
return res
# When cache is disabled, directly call LLM
kwargs = {}
if history_messages:
kwargs["history_messages"] = history_messages
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
return await use_llm_func(input_text, **kwargs)
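A minimal usage sketch of the new helper; fake_llm and the _demo wrapper are hypothetical stand-ins (a real caller passes global_config["llm_model_func"] and a configured BaseKVStorage instance):
async def fake_llm(prompt: str, history_messages=None, max_tokens=None) -> str:
    # Stand-in for the configured llm_model_func
    return "extracted entities for: " + prompt[:40]

async def _demo(llm_response_cache):
    # First call: cache miss, fake_llm runs, result saved under the prompt hash
    first = await use_llm_func_with_cache(
        "Extract entities from: Alice met Bob in Paris.",
        fake_llm,
        llm_response_cache=llm_response_cache,
        max_tokens=512,
        cache_type="extract",
    )
    # Identical second call: served from the cache, fake_llm is not invoked again
    second = await use_llm_func_with_cache(
        "Extract entities from: Alice met Bob in Paris.",
        fake_llm,
        llm_response_cache=llm_response_cache,
        max_tokens=512,
        cache_type="extract",
    )
    assert first == second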
def get_content_summary(content: str, max_length: int = 250) -> str:
"""Get summary of document content