Fix LLM cache so it now works for node and edge merging

yangdx
2025-04-10 03:57:36 +08:00
parent ad087073aa
commit 8d858da4d0
3 changed files with 135 additions and 66 deletions


@@ -12,13 +12,17 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable
+from typing import Any, Callable, TYPE_CHECKING
 import xml.etree.ElementTree as ET
 import numpy as np
 import tiktoken
 from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv
 
+# Use TYPE_CHECKING to avoid circular imports
+if TYPE_CHECKING:
+    from lightrag.base import BaseKVStorage
+
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
 # the OS environment variables take precedence over the .env file
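
The TYPE_CHECKING guard above imports BaseKVStorage only for static type checkers; at runtime the import is skipped, which breaks the import cycle, so the annotation in the new helper further down is written as a string. A minimal standalone sketch of the pattern (the lookup function is illustrative, not part of this diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy/pyright), never at runtime,
    # so this module does not load lightrag.base when it is imported.
    from lightrag.base import BaseKVStorage


async def lookup(storage: "BaseKVStorage | None" = None) -> None:
    # The quoted annotation defers evaluation, so BaseKVStorage does not
    # need to be importable when this module is loaded.
    ...
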
@@ -908,6 +912,84 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
     return import_class
 
 
+async def use_llm_func_with_cache(
+    input_text: str,
+    use_llm_func: callable,
+    llm_response_cache: 'BaseKVStorage | None' = None,
+    max_tokens: int = None,
+    history_messages: list[dict[str, str]] = None,
+    cache_type: str = "extract"
+) -> str:
+    """Call LLM function with cache support
+
+    If cache is available and enabled (determined by handle_cache based on mode),
+    retrieve result from cache; otherwise call LLM function and save result to cache.
+
+    Args:
+        input_text: Input text to send to LLM
+        use_llm_func: LLM function to call
+        llm_response_cache: Cache storage instance
+        max_tokens: Maximum tokens for generation
+        history_messages: History messages list
+        cache_type: Type of cache
+
+    Returns:
+        LLM response text
+    """
+    if llm_response_cache:
+        if history_messages:
+            history = json.dumps(history_messages, ensure_ascii=False)
+            _prompt = history + "\n" + input_text
+        else:
+            _prompt = input_text
+
+        arg_hash = compute_args_hash(_prompt)
+        cached_return, _1, _2, _3 = await handle_cache(
+            llm_response_cache,
+            arg_hash,
+            _prompt,
+            "default",
+            cache_type=cache_type,
+        )
+        if cached_return:
+            logger.debug(f"Found cache for {arg_hash}")
+            statistic_data["llm_cache"] += 1
+            return cached_return
+        statistic_data["llm_call"] += 1
+
+        # Call LLM
+        kwargs = {}
+        if history_messages:
+            kwargs["history_messages"] = history_messages
+        if max_tokens is not None:
+            kwargs["max_tokens"] = max_tokens
+        res: str = await use_llm_func(input_text, **kwargs)
+
+        # Save to cache
+        logger.info(f"Saving LLM cache for {arg_hash}")
+        await save_to_cache(
+            llm_response_cache,
+            CacheData(
+                args_hash=arg_hash,
+                content=res,
+                prompt=_prompt,
+                cache_type=cache_type,
+            ),
+        )
+        return res
+
+    # When cache is disabled, directly call LLM
+    kwargs = {}
+    if history_messages:
+        kwargs["history_messages"] = history_messages
+    if max_tokens is not None:
+        kwargs["max_tokens"] = max_tokens
+
+    logger.info(f"Call LLM function with query text length: {len(input_text)}")
+    return await use_llm_func(input_text, **kwargs)
+
+
 def get_content_summary(content: str, max_length: int = 250) -> str:
     """Get summary of document content