Fix LLM cache so it now works for node and edge merging

This commit is contained in:
yangdx
2025-04-10 03:57:36 +08:00
parent ad087073aa
commit 8d858da4d0
3 changed files with 135 additions and 66 deletions
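In short, this commit replaces the local _user_llm_func_with_cache helper inside extract_entities with a shared use_llm_func_with_cache helper in utils.py, and threads llm_response_cache through _merge_nodes_then_upsert, _merge_edges_then_upsert and _handle_entity_relation_summary, so the description re-summarization performed while merging nodes and edges goes through the same LLM cache as entity extraction. Below is a minimal usage sketch, assuming the helper is importable from lightrag.utils as the utils.py diff suggests; fake_llm and main are illustrative stand-ins, not repository code, and passing llm_response_cache=None simply skips caching.

import asyncio

# Assumption: the new shared helper lives in lightrag.utils (per the diff below).
from lightrag.utils import use_llm_func_with_cache


async def fake_llm(prompt: str, **kwargs) -> str:
    # Stand-in for the configured llm_model_func; returns a canned summary.
    return f"summary of: {prompt[:40]}"


async def main() -> None:
    summary = await use_llm_func_with_cache(
        "Combined description of entity FooCorp ...",  # input_text
        fake_llm,                                      # use_llm_func
        llm_response_cache=None,  # real callers pass the shared KV storage instance
        max_tokens=500,
        cache_type="extract",
    )
    print(summary)


asyncio.run(main())

With a real BaseKVStorage instance instead of None, the helper hashes the prompt, consults handle_cache, and stores new responses via save_to_cache, so identical prompts can be served from cache on later calls.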

View File

@@ -1 +1 @@
-__api_version__ = "0142"
+__api_version__ = "0143"

View File

@@ -24,8 +24,8 @@ from .utils import (
     handle_cache,
     save_to_cache,
     CacheData,
-    statistic_data,
     get_conversation_turns,
+    use_llm_func_with_cache,
 )
 from .base import (
     BaseGraphStorage,
@@ -108,6 +108,7 @@ async def _handle_entity_relation_summary(
     global_config: dict,
     pipeline_status: dict = None,
     pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
 ) -> str:
     """Handle entity relation summary
     For each entity or relation, input is the combined description of already existing description and new description.
@@ -125,13 +126,6 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description

-    # Update pipeline status when LLM summary is needed
-    status_message = "Use LLM to re-summary description..."
-    logger.info(status_message)
-    if pipeline_status is not None and pipeline_status_lock is not None:
-        async with pipeline_status_lock:
-            pipeline_status["latest_message"] = status_message
-            pipeline_status["history_messages"].append(status_message)
     prompt_template = PROMPTS["summarize_entity_descriptions"]
     use_description = decode_tokens_by_tiktoken(
         tokens[:llm_max_tokens], model_name=tiktoken_model_name
@@ -143,7 +137,23 @@ async def _handle_entity_relation_summary(
     )
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
-    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
+
+    # Update pipeline status when LLM summary is needed
+    status_message = "Use LLM to re-summary description..."
+    logger.info(status_message)
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = status_message
+            pipeline_status["history_messages"].append(status_message)
+
+    # Use LLM function with cache
+    summary = await use_llm_func_with_cache(
+        use_prompt,
+        use_llm_func,
+        llm_response_cache=llm_response_cache,
+        max_tokens=summary_max_tokens,
+        cache_type="extract",
+    )
     return summary
@@ -224,6 +234,7 @@ async def _merge_nodes_then_upsert(
     global_config: dict,
     pipeline_status: dict = None,
     pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
 ):
     """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
     already_entity_types = []
@@ -269,7 +280,12 @@ async def _merge_nodes_then_upsert(
     logger.debug(f"file_path: {file_path}")
     description = await _handle_entity_relation_summary(
-        entity_name, description, global_config, pipeline_status, pipeline_status_lock
+        entity_name,
+        description,
+        global_config,
+        pipeline_status,
+        pipeline_status_lock,
+        llm_response_cache,
     )
     node_data = dict(
         entity_id=entity_name,
@@ -294,6 +310,7 @@ async def _merge_edges_then_upsert(
     global_config: dict,
     pipeline_status: dict = None,
     pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
 ):
     already_weights = []
     already_source_ids = []
@@ -303,7 +320,7 @@ async def _merge_edges_then_upsert(
     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
         # Update pipeline status when an edge that needs merging is found
         status_message = f"Merging edges: {src_id} - {tgt_id}"
         logger.info(status_message)
         if pipeline_status is not None and pipeline_status_lock is not None:
             async with pipeline_status_lock:
@@ -393,6 +410,7 @@ async def _merge_edges_then_upsert(
             global_config,
             pipeline_status,
             pipeline_status_lock,
+            llm_response_cache,
         )
     await knowledge_graph_inst.upsert_edge(
         src_id,
@@ -428,11 +446,9 @@ async def extract_entities(
     pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
 ) -> None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
-    enable_llm_cache_for_entity_extract: bool = global_config[
-        "enable_llm_cache_for_entity_extract"
-    ]

     ordered_chunks = list(chunks.items())
     # add language and example number params to prompt
@@ -483,51 +499,7 @@ async def extract_entities(
     graph_db_lock = get_graph_db_lock(enable_logging=False)

-    async def _user_llm_func_with_cache(
-        input_text: str, history_messages: list[dict[str, str]] = None
-    ) -> str:
-        if enable_llm_cache_for_entity_extract and llm_response_cache:
-            if history_messages:
-                history = json.dumps(history_messages, ensure_ascii=False)
-                _prompt = history + "\n" + input_text
-            else:
-                _prompt = input_text
-
-            # TODO add cache_type="extract"
-            arg_hash = compute_args_hash(_prompt)
-            cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache,
-                arg_hash,
-                _prompt,
-                "default",
-                cache_type="extract",
-            )
-            if cached_return:
-                logger.debug(f"Found cache for {arg_hash}")
-                statistic_data["llm_cache"] += 1
-                return cached_return
-            statistic_data["llm_call"] += 1
-            if history_messages:
-                res: str = await use_llm_func(
-                    input_text, history_messages=history_messages
-                )
-            else:
-                res: str = await use_llm_func(input_text)
-            await save_to_cache(
-                llm_response_cache,
-                CacheData(
-                    args_hash=arg_hash,
-                    content=res,
-                    prompt=_prompt,
-                    cache_type="extract",
-                ),
-            )
-            return res
-
-        if history_messages:
-            return await use_llm_func(input_text, history_messages=history_messages)
-        else:
-            return await use_llm_func(input_text)
+    # Use the global use_llm_func_with_cache function from utils.py

     async def _process_extraction_result(
         result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -592,7 +564,12 @@ async def extract_entities(
                 **context_base, input_text="{input_text}"
             ).format(**context_base, input_text=content)

-            final_result = await _user_llm_func_with_cache(hint_prompt)
+            final_result = await use_llm_func_with_cache(
+                hint_prompt,
+                use_llm_func,
+                llm_response_cache=llm_response_cache,
+                cache_type="extract",
+            )
             history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

             # Process initial extraction with file path
@@ -602,8 +579,12 @@ async def extract_entities(
             # Process additional gleaning results
             for now_glean_index in range(entity_extract_max_gleaning):
-                glean_result = await _user_llm_func_with_cache(
-                    continue_prompt, history_messages=history
+                glean_result = await use_llm_func_with_cache(
+                    continue_prompt,
+                    use_llm_func,
+                    llm_response_cache=llm_response_cache,
+                    history_messages=history,
+                    cache_type="extract",
                 )

                 history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -622,8 +603,12 @@ async def extract_entities(
                 if now_glean_index == entity_extract_max_gleaning - 1:
                     break

-                if_loop_result: str = await _user_llm_func_with_cache(
-                    if_loop_prompt, history_messages=history
+                if_loop_result: str = await use_llm_func_with_cache(
+                    if_loop_prompt,
+                    use_llm_func,
+                    llm_response_cache=llm_response_cache,
+                    history_messages=history,
+                    cache_type="extract",
                 )
                 if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                 if if_loop_result != "yes":
@@ -653,6 +638,7 @@ async def extract_entities(
                     global_config,
                     pipeline_status,
                     pipeline_status_lock,
+                    llm_response_cache,
                 )
                 chunk_entities_data.append(entity_data)
@@ -668,6 +654,7 @@ async def extract_entities(
                     global_config,
                     pipeline_status,
                     pipeline_status_lock,
+                    llm_response_cache,
                 )
                 chunk_relationships_data.append(edge_data)
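All of the entity-extraction and merge changes above forward llm_response_cache to the shared helper, and the utils.py diff below supplies that helper. For reference, a small sketch of how its cache key is built, assuming only what the diff shows (compute_args_hash over the serialized history plus prompt); demo_cache_key and the plain dict are hypothetical stand-ins, not repository code.

import json

from lightrag.utils import compute_args_hash  # existing helper reused by the new function


def demo_cache_key(input_text: str, history_messages: list[dict[str, str]] | None = None) -> str:
    # Mirrors the key construction in use_llm_func_with_cache:
    # serialized history (if any) + "\n" + prompt, then hashed.
    if history_messages:
        _prompt = json.dumps(history_messages, ensure_ascii=False) + "\n" + input_text
    else:
        _prompt = input_text
    return compute_args_hash(_prompt)


fake_cache: dict[str, str] = {}  # stand-in for the real llm_response_cache storage
key = demo_cache_key("Extract entities from ...", [{"role": "user", "content": "hi"}])
if key not in fake_cache:
    fake_cache[key] = "llm response"  # the real helper stores a CacheData via save_to_cache
print(key, "->", fake_cache[key])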

View File

@@ -12,13 +12,17 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable
+from typing import Any, Callable, TYPE_CHECKING
 import xml.etree.ElementTree as ET
 import numpy as np
 import tiktoken
 from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv

+# Use TYPE_CHECKING to avoid circular imports
+if TYPE_CHECKING:
+    from lightrag.base import BaseKVStorage

 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
 # the OS environment variables take precedence over the .env file
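The TYPE_CHECKING block added above imports BaseKVStorage only for type checkers, which, per the diff's own comment, avoids a circular import between utils.py and lightrag.base. A generic sketch of that pattern, with illustrative module names (a_module, b_module) that are not part of the repository:

# a_module.py -- annotate with a type from a module that would otherwise import us back.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from b_module import HeavyType


def process(item: "HeavyType") -> str:
    # At runtime the annotation is just a string, so no circular import happens.
    return repr(item)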
@@ -908,6 +912,84 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
     return import_class


+async def use_llm_func_with_cache(
+    input_text: str,
+    use_llm_func: callable,
+    llm_response_cache: "BaseKVStorage | None" = None,
+    max_tokens: int = None,
+    history_messages: list[dict[str, str]] = None,
+    cache_type: str = "extract",
+) -> str:
+    """Call LLM function with cache support
+
+    If cache is available and enabled (determined by handle_cache based on mode),
+    retrieve result from cache; otherwise call LLM function and save result to cache.
+
+    Args:
+        input_text: Input text to send to LLM
+        use_llm_func: LLM function to call
+        llm_response_cache: Cache storage instance
+        max_tokens: Maximum tokens for generation
+        history_messages: History messages list
+        cache_type: Type of cache
+
+    Returns:
+        LLM response text
+    """
+    if llm_response_cache:
+        if history_messages:
+            history = json.dumps(history_messages, ensure_ascii=False)
+            _prompt = history + "\n" + input_text
+        else:
+            _prompt = input_text
+
+        arg_hash = compute_args_hash(_prompt)
+        cached_return, _1, _2, _3 = await handle_cache(
+            llm_response_cache,
+            arg_hash,
+            _prompt,
+            "default",
+            cache_type=cache_type,
+        )
+        if cached_return:
+            logger.debug(f"Found cache for {arg_hash}")
+            statistic_data["llm_cache"] += 1
+            return cached_return
+        statistic_data["llm_call"] += 1
+
+        # Call LLM
+        kwargs = {}
+        if history_messages:
+            kwargs["history_messages"] = history_messages
+        if max_tokens is not None:
+            kwargs["max_tokens"] = max_tokens
+        res: str = await use_llm_func(input_text, **kwargs)
+
+        # Save to cache
+        logger.info(f"Saving LLM cache for {arg_hash}")
+        await save_to_cache(
+            llm_response_cache,
+            CacheData(
+                args_hash=arg_hash,
+                content=res,
+                prompt=_prompt,
+                cache_type=cache_type,
+            ),
+        )
+        return res
+
+    # When cache is disabled, directly call LLM
+    kwargs = {}
+    if history_messages:
+        kwargs["history_messages"] = history_messages
+    if max_tokens is not None:
+        kwargs["max_tokens"] = max_tokens
+    logger.info(f"Call LLM function with query text length: {len(input_text)}")
+    return await use_llm_func(input_text, **kwargs)
+
+
 def get_content_summary(content: str, max_length: int = 250) -> str:
     """Get summary of document content