Refactor: Entity and edge merging in extract_entities

- Improves efficiency by merging identical entities and edges in a single operation
- Ensures proper handling of undirected graph edges
- Changes the merge stage from chunk level to document level (see the sketch below)
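
In effect, per-chunk extraction results are now gathered first and merged once per document, with edge keys sorted so that (A, B) and (B, A) land in the same bucket. A minimal, self-contained sketch of that collection pattern (the sample data and names such as chunk_results are illustrative, not the actual LightRAG code):

from collections import defaultdict

# Illustrative per-chunk output: ({entity_name: [records]}, {(src, tgt): [records]})
chunk_results = [
    ({"Alice": [{"entity_type": "person"}]}, {("Alice", "Bob"): [{"weight": 1.0}]}),
    ({"Bob": [{"entity_type": "person"}]}, {("Bob", "Alice"): [{"weight": 2.0}]}),
]

all_nodes = defaultdict(list)
all_edges = defaultdict(list)

for maybe_nodes, maybe_edges in chunk_results:
    for entity_name, entities in maybe_nodes.items():
        all_nodes[entity_name].extend(entities)
    for edge_key, edges in maybe_edges.items():
        # Sort the key so both directions of an undirected edge merge into one entry
        all_edges[tuple(sorted(edge_key))].extend(edges)

assert len(all_edges[("Alice", "Bob")]) == 2  # both directions collapsed into one edge

With this in place, each entity and each undirected edge is merged and upserted exactly once per document instead of once per chunk.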
Author: yangdx
Date: 2025-04-10 14:19:06 +08:00
Parent: 28d462e46f
Commit: 35431644ad


@@ -139,7 +139,7 @@ async def _handle_entity_relation_summary(
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
     # Update pipeline status when LLM summary is needed
-    status_message = "Use LLM to re-summary description..."
+    status_message = " == Use LLM == to re-summary description..."
     logger.info(status_message)
     if pipeline_status is not None and pipeline_status_lock is not None:
         async with pipeline_status_lock:
@@ -244,14 +244,6 @@ async def _merge_nodes_then_upsert(
     already_node = await knowledge_graph_inst.get_node(entity_name)
     if already_node is not None:
-        # Update pipeline status when a node that needs merging is found
-        status_message = f"Merging entity: {entity_name}"
-        logger.info(status_message)
-        if pipeline_status is not None and pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = status_message
-                pipeline_status["history_messages"].append(status_message)
         already_entity_types.append(already_node["entity_type"])
         already_source_ids.extend(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
@@ -278,7 +270,15 @@ async def _merge_nodes_then_upsert(
         set([dp["file_path"] for dp in nodes_data] + already_file_paths)
     )
-    logger.debug(f"file_path: {file_path}")
+
+    if len(nodes_data) > 1 or len(already_entity_types) > 0:
+        # Update pipeline status when a node that needs merging
+        status_message = f"Merging entity: {entity_name} | {len(nodes_data)}+{len(already_entity_types)}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
     description = await _handle_entity_relation_summary(
         entity_name,
         description,
@@ -287,6 +287,7 @@ async def _merge_nodes_then_upsert(
         pipeline_status_lock,
         llm_response_cache,
     )
+
     node_data = dict(
         entity_id=entity_name,
         entity_type=entity_type,
@@ -319,14 +320,6 @@ async def _merge_edges_then_upsert(
     already_file_paths = []
     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
-        # Update pipeline status when an edge that needs merging is found
-        status_message = f"Merging edge::: {src_id} - {tgt_id}"
-        logger.info(status_message)
-        if pipeline_status is not None and pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = status_message
-                pipeline_status["history_messages"].append(status_message)
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
         # Handle the case where get_edge returns None or missing fields
         if already_edge:
@@ -404,6 +397,16 @@ async def _merge_edges_then_upsert(
                     "file_path": file_path,
                 },
             )
+
+    if len(edges_data) > 1 or len(already_weights) > 0:
+        # Update pipeline status when a edge that needs merging
+        status_message = f"Merging edge::: {src_id} - {tgt_id} | {len(edges_data)}+{len(already_weights)}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
     description = await _handle_entity_relation_summary(
         f"({src_id}, {tgt_id})",
         description,
@@ -412,6 +415,7 @@ async def _merge_edges_then_upsert(
         pipeline_status_lock,
         llm_response_cache,
     )
+
     await knowledge_graph_inst.upsert_edge(
         src_id,
         tgt_id,
@@ -550,8 +554,10 @@ async def extract_entities(
         Args:
             chunk_key_dp (tuple[str, TextChunkSchema]):
                 ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
+        Returns:
+            tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships
         """
-        nonlocal processed_chunks, total_entities_count, total_relations_count
+        nonlocal processed_chunks
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -623,13 +629,35 @@ async def extract_entities(
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
-        # Use graph database lock to ensure atomic merges and updates
-        chunk_entities_data = []
-        chunk_relationships_data = []
-        async with graph_db_lock:
-            # Process and update entities
-            for entity_name, entities in maybe_nodes.items():
+        # Return the extracted nodes and edges for centralized processing
+        return maybe_nodes, maybe_edges
+
+    # Handle all chunks in parallel and collect results
+    tasks = [_process_single_content(c) for c in ordered_chunks]
+    chunk_results = await asyncio.gather(*tasks)
+
+    # Collect all nodes and edges from all chunks
+    all_nodes = defaultdict(list)
+    all_edges = defaultdict(list)
+
+    for maybe_nodes, maybe_edges in chunk_results:
+        # Collect nodes
+        for entity_name, entities in maybe_nodes.items():
+            all_nodes[entity_name].extend(entities)
+
+        # Collect edges with sorted keys for undirected graph
+        for edge_key, edges in maybe_edges.items():
+            sorted_edge_key = tuple(sorted(edge_key))
+            all_edges[sorted_edge_key].extend(edges)
+
+    # Centralized processing of all nodes and edges
+    entities_data = []
+    relationships_data = []
+
+    # Use graph database lock to ensure atomic merges and updates
+    async with graph_db_lock:
+        # Process and update all entities at once
+        for entity_name, entities in all_nodes.items():
             entity_data = await _merge_nodes_then_upsert(
                 entity_name,
                 entities,
@@ -639,15 +667,13 @@ async def extract_entities(
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            chunk_entities_data.append(entity_data)
-            # Process and update relationships
-            for edge_key, edges in maybe_edges.items():
-                # Ensure edge direction consistency
-                sorted_edge_key = tuple(sorted(edge_key))
+            entities_data.append(entity_data)
+
+        # Process and update all relationships at once
+        for edge_key, edges in all_edges.items():
             edge_data = await _merge_edges_then_upsert(
-                sorted_edge_key[0],
-                sorted_edge_key[1],
+                edge_key[0],
+                edge_key[1],
                 edges,
                 knowledge_graph_inst,
                 global_config,
@@ -655,10 +681,10 @@ async def extract_entities(
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            chunk_relationships_data.append(edge_data)
-            # Update vector database (within the same lock to ensure atomicity)
-            if entity_vdb is not None and chunk_entities_data:
+            relationships_data.append(edge_data)
+
+        # Update vector databases with all collected data
+        if entity_vdb is not None and entities_data:
             data_for_vdb = {
                 compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                     "entity_name": dp["entity_name"],
@@ -667,11 +693,11 @@ async def extract_entities(
                     "source_id": dp["source_id"],
                     "file_path": dp.get("file_path", "unknown_source"),
                 }
-                for dp in chunk_entities_data
+                for dp in entities_data
             }
             await entity_vdb.upsert(data_for_vdb)
-            if relationships_vdb is not None and chunk_relationships_data:
+        if relationships_vdb is not None and relationships_data:
             data_for_vdb = {
                 compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                     "src_id": dp["src_id"],
@@ -681,17 +707,13 @@ async def extract_entities(
                     "source_id": dp["source_id"],
                     "file_path": dp.get("file_path", "unknown_source"),
                 }
-                for dp in chunk_relationships_data
+                for dp in relationships_data
            }
             await relationships_vdb.upsert(data_for_vdb)
-            # Update counters
-            total_entities_count += len(chunk_entities_data)
-            total_relations_count += len(chunk_relationships_data)
-    # Handle all chunks in parallel
-    tasks = [_process_single_content(c) for c in ordered_chunks]
-    await asyncio.gather(*tasks)
+    # Update total counts
+    total_entities_count = len(entities_data)
+    total_relations_count = len(relationships_data)
     log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
     logger.info(log_message)
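
The control-flow consequence of the diff above: chunks are still extracted concurrently, but graph merging now happens once, after all chunks complete, inside a single lock-guarded pass. A rough sketch of that shape (extract_chunk and merge_all are hypothetical placeholders, not LightRAG APIs):

import asyncio

async def extract_chunk(chunk: str) -> list[str]:
    # Placeholder for per-chunk LLM extraction
    return [chunk.upper()]

async def merge_all(results: list[list[str]], lock: asyncio.Lock) -> int:
    # Document-level merge: one lock acquisition for the whole batch
    async with lock:
        merged = {entity for chunk_entities in results for entity in chunk_entities}
        return len(merged)

async def main() -> None:
    chunks = ["alpha", "beta", "alpha"]
    # Fan out extraction across all chunks, then merge the collected results once
    results = await asyncio.gather(*(extract_chunk(c) for c in chunks))
    total = await merge_all(results, asyncio.Lock())
    print(f"merged {total} unique entities")

asyncio.run(main())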