Merge pull request #1328 from danielaskdd/main

Fix LLM cache to work for node and edge merging
Daniel.y
2025-04-10 04:24:34 +08:00
committed by GitHub
8 changed files with 176 additions and 67 deletions

View File

@@ -1 +1 @@
-__api_version__ = "0142"
+__api_version__ = "0143"

View File

@@ -116,7 +116,7 @@ class JsonDocStatusStorage(DocStatusStorage):
""" """
if not data: if not data:
return return
logger.info(f"Inserting {len(data)} records to {self.namespace}") logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock: async with self._storage_lock:
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)

View File

@@ -121,7 +121,7 @@ class JsonKVStorage(BaseKVStorage):
""" """
if not data: if not data:
return return
logger.info(f"Inserting {len(data)} records to {self.namespace}") logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock: async with self._storage_lock:
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)

View File

@@ -85,7 +85,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
        KG-storage-log should be used to avoid data corruption
        """
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

View File

@@ -392,7 +392,7 @@ class NetworkXStorage(BaseGraphStorage):
            # Check if storage was updated by another process
            if self.storage_updated.value:
                # Storage was updated by another process, reload data instead of saving
-                logger.warning(
+                logger.info(
                    f"Graph for {self.namespace} was updated by another process, reloading..."
                )
                self._graph = (

View File

@@ -361,7 +361,7 @@ class PGKVStorage(BaseKVStorage):
    ################ INSERT METHODS ################
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return
@@ -560,7 +560,7 @@ class PGVectorStorage(BaseVectorStorage):
        return upsert_sql, data

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return
@@ -949,7 +949,7 @@ class PGDocStatusStorage(DocStatusStorage):
        Args:
            data: dictionary of document IDs and their status data
        """
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

View File

@@ -24,8 +24,8 @@ from .utils import (
    handle_cache,
    save_to_cache,
    CacheData,
-    statistic_data,
    get_conversation_turns,
+    use_llm_func_with_cache,
)
from .base import (
    BaseGraphStorage,
@@ -106,6 +106,9 @@ async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
) -> str:
    """Handle entity relation summary
    For each entity or relation, input is the combined description of already existing description and new description.
@@ -122,6 +125,7 @@ async def _handle_entity_relation_summary(
    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    if len(tokens) < summary_max_tokens:  # No need for summary
        return description
+
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    use_description = decode_tokens_by_tiktoken(
        tokens[:llm_max_tokens], model_name=tiktoken_model_name
@@ -133,7 +137,23 @@ async def _handle_entity_relation_summary(
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
-    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
+
+    # Update pipeline status when LLM summary is needed
+    status_message = "Use LLM to re-summary description..."
+    logger.info(status_message)
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = status_message
+            pipeline_status["history_messages"].append(status_message)
+
+    # Use LLM function with cache
+    summary = await use_llm_func_with_cache(
+        use_prompt,
+        use_llm_func,
+        llm_response_cache=llm_response_cache,
+        max_tokens=summary_max_tokens,
+        cache_type="extract",
+    )
    return summary
@@ -212,6 +232,9 @@ async def _merge_nodes_then_upsert(
    nodes_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
):
    """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
    already_entity_types = []
@@ -221,6 +244,14 @@ async def _merge_nodes_then_upsert(
    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node is not None:
+        # Update pipeline status when a node that needs merging is found
+        status_message = f"Merging entity: {entity_name}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
@@ -249,7 +280,12 @@ async def _merge_nodes_then_upsert(
    logger.debug(f"file_path: {file_path}")
    description = await _handle_entity_relation_summary(
-        entity_name, description, global_config
+        entity_name,
+        description,
+        global_config,
+        pipeline_status,
+        pipeline_status_lock,
+        llm_response_cache,
    )
    node_data = dict(
        entity_id=entity_name,
@@ -272,6 +308,9 @@ async def _merge_edges_then_upsert(
    edges_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
):
    already_weights = []
    already_source_ids = []
@@ -280,6 +319,14 @@ async def _merge_edges_then_upsert(
    already_file_paths = []

    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
+        # Update pipeline status when an edge that needs merging is found
+        status_message = f"Merging edge::: {src_id} - {tgt_id}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        # Handle the case where get_edge returns None or missing fields
        if already_edge:
@@ -358,7 +405,12 @@ async def _merge_edges_then_upsert(
                },
            )
    description = await _handle_entity_relation_summary(
-        f"({src_id}, {tgt_id})", description, global_config
+        f"({src_id}, {tgt_id})",
+        description,
+        global_config,
+        pipeline_status,
+        pipeline_status_lock,
+        llm_response_cache,
    )
    await knowledge_graph_inst.upsert_edge(
        src_id,
@@ -396,9 +448,6 @@ async def extract_entities(
) -> None:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
-    enable_llm_cache_for_entity_extract: bool = global_config[
-        "enable_llm_cache_for_entity_extract"
-    ]
    ordered_chunks = list(chunks.items())

    # add language and example number params to prompt
@@ -449,51 +498,7 @@ async def extract_entities(
    graph_db_lock = get_graph_db_lock(enable_logging=False)

-    async def _user_llm_func_with_cache(
-        input_text: str, history_messages: list[dict[str, str]] = None
-    ) -> str:
-        if enable_llm_cache_for_entity_extract and llm_response_cache:
-            if history_messages:
-                history = json.dumps(history_messages, ensure_ascii=False)
-                _prompt = history + "\n" + input_text
-            else:
-                _prompt = input_text
-
-            # TODO add cache_type="extract"
-            arg_hash = compute_args_hash(_prompt)
-            cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache,
-                arg_hash,
-                _prompt,
-                "default",
-                cache_type="extract",
-            )
-            if cached_return:
-                logger.debug(f"Found cache for {arg_hash}")
-                statistic_data["llm_cache"] += 1
-                return cached_return
-            statistic_data["llm_call"] += 1
-            if history_messages:
-                res: str = await use_llm_func(
-                    input_text, history_messages=history_messages
-                )
-            else:
-                res: str = await use_llm_func(input_text)
-            await save_to_cache(
-                llm_response_cache,
-                CacheData(
-                    args_hash=arg_hash,
-                    content=res,
-                    prompt=_prompt,
-                    cache_type="extract",
-                ),
-            )
-            return res
-
-        if history_messages:
-            return await use_llm_func(input_text, history_messages=history_messages)
-        else:
-            return await use_llm_func(input_text)
+    # Use the global use_llm_func_with_cache function from utils.py

    async def _process_extraction_result(
        result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -558,7 +563,12 @@ async def extract_entities(
            **context_base, input_text="{input_text}"
        ).format(**context_base, input_text=content)

-        final_result = await _user_llm_func_with_cache(hint_prompt)
+        final_result = await use_llm_func_with_cache(
+            hint_prompt,
+            use_llm_func,
+            llm_response_cache=llm_response_cache,
+            cache_type="extract",
+        )
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

        # Process initial extraction with file path
@@ -568,8 +578,12 @@ async def extract_entities(
        # Process additional gleaning results
        for now_glean_index in range(entity_extract_max_gleaning):
-            glean_result = await _user_llm_func_with_cache(
-                continue_prompt, history_messages=history
+            glean_result = await use_llm_func_with_cache(
+                continue_prompt,
+                use_llm_func,
+                llm_response_cache=llm_response_cache,
+                history_messages=history,
+                cache_type="extract",
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -588,8 +602,12 @@ async def extract_entities(
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

-            if_loop_result: str = await _user_llm_func_with_cache(
-                if_loop_prompt, history_messages=history
+            if_loop_result: str = await use_llm_func_with_cache(
+                if_loop_prompt,
+                use_llm_func,
+                llm_response_cache=llm_response_cache,
+                history_messages=history,
+                cache_type="extract",
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
@@ -613,7 +631,13 @@ async def extract_entities(
            # Process and update entities
            for entity_name, entities in maybe_nodes.items():
                entity_data = await _merge_nodes_then_upsert(
-                    entity_name, entities, knowledge_graph_inst, global_config
+                    entity_name,
+                    entities,
+                    knowledge_graph_inst,
+                    global_config,
+                    pipeline_status,
+                    pipeline_status_lock,
+                    llm_response_cache,
                )
                chunk_entities_data.append(entity_data)
@@ -627,6 +651,9 @@ async def extract_entities(
                    edges,
                    knowledge_graph_inst,
                    global_config,
+                    pipeline_status,
+                    pipeline_status_lock,
+                    llm_response_cache,
                )
                chunk_relationships_data.append(edge_data)
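
For readers skimming the diff: the new pipeline_status / pipeline_status_lock arguments follow a single pattern everywhere they appear above, an optional shared status dict that is only touched under an asyncio lock and silently skipped when either argument is missing. A minimal, self-contained sketch of that pattern (report_status and the names in main are illustrative, not part of this change):

import asyncio

# Hypothetical standalone illustration of the optional status-reporting pattern
# used by the merge helpers: callers may pass a shared dict plus a lock, or omit
# both, in which case the update is skipped.
async def report_status(message: str, pipeline_status: dict = None, pipeline_status_lock=None):
    if pipeline_status is not None and pipeline_status_lock is not None:
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = message
            pipeline_status["history_messages"].append(message)

async def main():
    status = {"latest_message": "", "history_messages": []}
    lock = asyncio.Lock()
    await report_status("Merging entity: Alice", status, lock)
    await report_status("No status tracking here")  # no-op without the dict and lock
    print(status["latest_message"], len(status["history_messages"]))

asyncio.run(main())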

View File

@@ -12,13 +12,17 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
-from typing import Any, Callable
+from typing import Any, Callable, TYPE_CHECKING
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv

+# Use TYPE_CHECKING to avoid circular imports
+if TYPE_CHECKING:
+    from lightrag.base import BaseKVStorage
+
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
@@ -908,6 +912,84 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    return import_class


+async def use_llm_func_with_cache(
+    input_text: str,
+    use_llm_func: callable,
+    llm_response_cache: "BaseKVStorage | None" = None,
+    max_tokens: int = None,
+    history_messages: list[dict[str, str]] = None,
+    cache_type: str = "extract",
+) -> str:
+    """Call LLM function with cache support
+
+    If cache is available and enabled (determined by handle_cache based on mode),
+    retrieve result from cache; otherwise call LLM function and save result to cache.
+
+    Args:
+        input_text: Input text to send to LLM
+        use_llm_func: LLM function to call
+        llm_response_cache: Cache storage instance
+        max_tokens: Maximum tokens for generation
+        history_messages: History messages list
+        cache_type: Type of cache
+
+    Returns:
+        LLM response text
+    """
+    if llm_response_cache:
+        if history_messages:
+            history = json.dumps(history_messages, ensure_ascii=False)
+            _prompt = history + "\n" + input_text
+        else:
+            _prompt = input_text
+
+        arg_hash = compute_args_hash(_prompt)
+        cached_return, _1, _2, _3 = await handle_cache(
+            llm_response_cache,
+            arg_hash,
+            _prompt,
+            "default",
+            cache_type=cache_type,
+        )
+        if cached_return:
+            logger.debug(f"Found cache for {arg_hash}")
+            statistic_data["llm_cache"] += 1
+            return cached_return
+        statistic_data["llm_call"] += 1
+
+        # Call LLM
+        kwargs = {}
+        if history_messages:
+            kwargs["history_messages"] = history_messages
+        if max_tokens is not None:
+            kwargs["max_tokens"] = max_tokens
+        res: str = await use_llm_func(input_text, **kwargs)
+
+        # Save to cache
+        logger.info(f"Saving LLM cache for {arg_hash}")
+        await save_to_cache(
+            llm_response_cache,
+            CacheData(
+                args_hash=arg_hash,
+                content=res,
+                prompt=_prompt,
+                cache_type=cache_type,
+            ),
+        )
+        return res
+
+    # When cache is disabled, directly call LLM
+    kwargs = {}
+    if history_messages:
+        kwargs["history_messages"] = history_messages
+    if max_tokens is not None:
+        kwargs["max_tokens"] = max_tokens
+
+    logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
+    return await use_llm_func(input_text, **kwargs)
+
+
def get_content_summary(content: str, max_length: int = 250) -> str:
    """Get summary of document content