Merge pull request #590 from jin38324/main

Enhance Robustness of insert Method with Pipeline Processing and Caching Mechanisms
This commit is contained in:
zrguo
2025-01-16 14:20:08 +08:00
committed by GitHub
11 changed files with 510 additions and 191 deletions

View File

@@ -20,6 +20,7 @@ from .utils import (
handle_cache,
save_to_cache,
CacheData,
statistic_data,
)
from .base import (
BaseGraphStorage,
@@ -96,6 +97,10 @@ async def _handle_entity_relation_summary(
description: str,
global_config: dict,
) -> str:
"""Handle entity relation summary
For each entity or relation, input is the combined description of already existing description and new description.
If too long, use LLM to summarize.
"""
use_llm_func: callable = global_config["llm_model_func"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
@@ -176,6 +181,7 @@ async def _merge_nodes_then_upsert(
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
already_entity_types = []
already_source_ids = []
already_description = []
@@ -356,7 +362,7 @@ async def extract_entities(
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages)
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
@@ -368,8 +374,10 @@ async def extract_entities(
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
return cached_return
statistic_data["llm_call"] += 1
if history_messages:
res: str = await use_llm_func(
input_text, history_messages=history_messages
@@ -388,6 +396,11 @@ async def extract_entities(
return await use_llm_func(input_text)
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
""" "Prpocess a single chunk
Args:
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
"""
nonlocal already_processed, already_entities, already_relations
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
@@ -451,10 +464,8 @@ async def extract_entities(
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
logger.debug(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return dict(maybe_nodes), dict(maybe_edges)
@@ -462,8 +473,10 @@ async def extract_entities(
for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Extracting entities from chunks",
desc="Level 2 - Extracting entities and relationships",
unit="chunk",
position=1,
leave=False,
):
results.append(await result)
@@ -474,7 +487,7 @@ async def extract_entities(
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
logger.info("Inserting entities into storage...")
logger.debug("Inserting entities into storage...")
all_entities_data = []
for result in tqdm_async(
asyncio.as_completed(
@@ -484,12 +497,14 @@ async def extract_entities(
]
),
total=len(maybe_nodes),
desc="Inserting entities",
desc="Level 3 - Inserting entities",
unit="entity",
position=2,
leave=False,
):
all_entities_data.append(await result)
logger.info("Inserting relationships into storage...")
logger.debug("Inserting relationships into storage...")
all_relationships_data = []
for result in tqdm_async(
asyncio.as_completed(
@@ -501,8 +516,10 @@ async def extract_entities(
]
),
total=len(maybe_edges),
desc="Inserting relationships",
desc="Level 3 - Inserting relationships",
unit="relationship",
position=3,
leave=False,
):
all_relationships_data.append(await result)