Some enhancements:
- Extend the llm_cache storage to support `get_by_mode_and_id`, improving performance when a real KV server backs the cache.
- Provide an option for developers to cache LLM responses while extracting entities from a document. This addresses the pain point that when the process fails partway through, chunks that were already processed have to be sent to the LLM again, wasting time and money. With the new option enabled (it is disabled by default), the per-chunk results are cached, which can save significant time and money, especially for beginners. See the sketches after this list.
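A minimal sketch of the storage-side addition, assuming the new method follows the existing `get_by_id` pattern on `BaseKVStorage`; the implementation is not part of the diff below:

```python
from typing import Union

class BaseKVStorage:
    async def get_by_id(self, id: str) -> Union[dict, None]: ...

    # Hypothetical signature, assumed from the get_by_id pattern: a real KV
    # server can answer one keyed (mode, hash) lookup instead of returning
    # the whole mode bucket.
    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
        """Fetch a single cached LLM response by cache mode and args hash."""
        ...
```

And a usage sketch for the new extraction-cache option, assuming it is exposed as a `LightRAG` constructor argument of the same name (other required arguments such as the LLM and embedding functions are omitted here):

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    enable_llm_cache_for_entity_extract=True,  # off by default
)

# If extraction fails midway, re-running insert() reuses the cached LLM
# responses for chunks that already completed instead of paying for them again.
rag.insert(open("book.txt").read())
```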
```diff
@@ -253,9 +253,13 @@ async def extract_entities(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict,
+    llm_response_cache: BaseKVStorage = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
+    enable_llm_cache_for_entity_extract: bool = global_config[
+        "enable_llm_cache_for_entity_extract"
+    ]
 
     ordered_chunks = list(chunks.items())
     # add language and example number params to prompt
```
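The new `llm_response_cache` parameter and the flag both arrive through `global_config`. In LightRAG that dict is typically produced from the `LightRAG` dataclass via `asdict`, so exposing the flag amounts to adding a dataclass field; the sketch below assumes that pattern, since the surrounding code is not shown in this diff:

```python
from dataclasses import dataclass, asdict

# Assumed plumbing (not shown in this diff): the LightRAG dataclass fields
# become the global_config dict handed to extract_entities.
@dataclass
class LightRAG:
    entity_extract_max_gleaning: int = 1
    enable_llm_cache_for_entity_extract: bool = False  # new flag, off by default

rag = LightRAG()
global_config = asdict(rag)  # passed as extract_entities(..., global_config=...)
assert global_config["enable_llm_cache_for_entity_extract"] is False
```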
```diff
@@ -300,6 +304,52 @@ async def extract_entities(
     already_entities = 0
     already_relations = 0
 
+    async def _user_llm_func_with_cache(
+        input_text: str, history_messages: list[dict[str, str]] = None
+    ) -> str:
+        if enable_llm_cache_for_entity_extract and llm_response_cache:
+            need_to_restore = False
+            if (
+                global_config["embedding_cache_config"]
+                and global_config["embedding_cache_config"]["enabled"]
+            ):
+                new_config = global_config.copy()
+                new_config["embedding_cache_config"] = None
+                new_config["enable_llm_cache"] = True
+                llm_response_cache.global_config = new_config
+                need_to_restore = True
+            if history_messages:
+                history = json.dumps(history_messages)
+                _prompt = history + "\n" + input_text
+            else:
+                _prompt = input_text
+
+            arg_hash = compute_args_hash(_prompt)
+            cached_return, _1, _2, _3 = await handle_cache(
+                llm_response_cache, arg_hash, _prompt, "default"
+            )
+            if need_to_restore:
+                llm_response_cache.global_config = global_config
+            if cached_return:
+                return cached_return
+
+            if history_messages:
+                res: str = await use_llm_func(
+                    input_text, history_messages=history_messages
+                )
+            else:
+                res: str = await use_llm_func(input_text)
+            await save_to_cache(
+                llm_response_cache,
+                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+            )
+            return res
+
+        if history_messages:
+            return await use_llm_func(input_text, history_messages=history_messages)
+        else:
+            return await use_llm_func(input_text)
+
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
         chunk_key = chunk_key_dp[0]
```
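The wrapper keys the cache on a hash of the serialized history plus the current prompt, so each gleaning round (whose history differs) gets its own cache entry. It also temporarily disables `embedding_cache_config`, presumably so `handle_cache` performs an exact hash lookup rather than an embedding-similarity match, restoring the original config afterwards. A standalone illustration of the keying, with `compute_args_hash` reproduced here from the md5-over-args pattern in LightRAG's utils (treat the body as illustrative):

```python
import json
from hashlib import md5

def compute_args_hash(*args) -> str:
    # Illustrative re-implementation: md5 over the stringified arguments.
    return md5(str(args).encode()).hexdigest()

prompt = "continue extracting entities ..."
history = [
    {"role": "user", "content": "initial extraction prompt"},
    {"role": "assistant", "content": "first-pass entities"},
]

# Same scheme as _user_llm_func_with_cache: history is serialized and
# prepended, so identical prompts with different histories never collide.
key_round_1 = compute_args_hash(json.dumps(history) + "\n" + prompt)
key_no_hist = compute_args_hash(prompt)
assert key_round_1 != key_no_hist
```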
```diff
@@ -310,17 +360,19 @@ async def extract_entities(
             **context_base, input_text="{input_text}"
         ).format(**context_base, input_text=content)
 
-        final_result = await use_llm_func(hint_prompt)
+        final_result = await _user_llm_func_with_cache(hint_prompt)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
         for now_glean_index in range(entity_extract_max_gleaning):
-            glean_result = await use_llm_func(continue_prompt, history_messages=history)
+            glean_result = await _user_llm_func_with_cache(
+                continue_prompt, history_messages=history
+            )
 
             history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
             final_result += glean_result
             if now_glean_index == entity_extract_max_gleaning - 1:
                 break
 
-        if_loop_result: str = await use_llm_func(
+        if_loop_result: str = await _user_llm_func_with_cache(
            if_loop_prompt, history_messages=history
         )
         if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
```
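The net effect at these call sites: re-running extraction over the same chunks after a crash becomes nearly free. A self-contained toy (mock LLM and plain dict cache, not LightRAG code) demonstrating the accounting:

```python
import asyncio

calls = 0
cache: dict[str, str] = {}

async def llm(prompt: str) -> str:
    global calls
    calls += 1  # stands in for a paid LLM call
    return f"entities for: {prompt}"

async def cached_llm(prompt: str) -> str:
    if prompt in cache:
        return cache[prompt]
    res = await llm(prompt)
    cache[prompt] = res
    return res

async def main() -> None:
    chunks = ["chunk-1", "chunk-2", "chunk-3"]
    await asyncio.gather(*(cached_llm(c) for c in chunks))  # first (interrupted) run
    await asyncio.gather(*(cached_llm(c) for c in chunks))  # retry: all cache hits
    print(calls)  # 3, not 6: the retry cost no LLM calls

asyncio.run(main())
```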