Refactor: Entity and edge merging in extract_entities
- Improves efficiency by merging identical entities and edges in a single operation
- Ensures proper handling of undirected graph edges
- Changes the merge stage from chunk level to document level
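The core of the refactor is visible in the final hunk: each chunk worker returns its extracted `(maybe_nodes, maybe_edges)`, and the caller merges everything once per document, sorting edge endpoints so the undirected graph treats `(a, b)` and `(b, a)` as the same edge. A minimal sketch of that collection step, using made-up stand-in data in place of the real chunk results:

```python
from collections import defaultdict

# Hypothetical stand-in for per-chunk extraction output; in the real code this
# comes from `chunk_results = await asyncio.gather(*tasks)`.
chunk_results = [
    ({"Alice": [{"entity_type": "person"}]}, {("Alice", "Bob"): [{"weight": 1.0}]}),
    ({"Bob": [{"entity_type": "person"}]}, {("Bob", "Alice"): [{"weight": 2.0}]}),
]

all_nodes = defaultdict(list)
all_edges = defaultdict(list)

for maybe_nodes, maybe_edges in chunk_results:
    for entity_name, entities in maybe_nodes.items():
        all_nodes[entity_name].extend(entities)
    for edge_key, edges in maybe_edges.items():
        # Sort endpoints so (a, b) and (b, a) collapse into one undirected edge.
        sorted_edge_key = tuple(sorted(edge_key))
        all_edges[sorted_edge_key].extend(edges)

print(dict(all_edges))
# {('Alice', 'Bob'): [{'weight': 1.0}, {'weight': 2.0}]}
```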
@@ -139,7 +139,7 @@ async def _handle_entity_relation_summary(
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
 
     # Update pipeline status when LLM summary is needed
-    status_message = "Use LLM to re-summary description..."
+    status_message = " == Use LLM == to re-summary description..."
     logger.info(status_message)
     if pipeline_status is not None and pipeline_status_lock is not None:
         async with pipeline_status_lock:
@@ -244,14 +244,6 @@ async def _merge_nodes_then_upsert(
 
    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node is not None:
-        # Update pipeline status when a node that needs merging is found
-        status_message = f"Merging entity: {entity_name}"
-        logger.info(status_message)
-        if pipeline_status is not None and pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = status_message
-                pipeline_status["history_messages"].append(status_message)
-
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
@@ -278,15 +270,24 @@ async def _merge_nodes_then_upsert(
        set([dp["file_path"] for dp in nodes_data] + already_file_paths)
    )
 
-    logger.debug(f"file_path: {file_path}")
-    description = await _handle_entity_relation_summary(
-        entity_name,
-        description,
-        global_config,
-        pipeline_status,
-        pipeline_status_lock,
-        llm_response_cache,
-    )
+    if len(nodes_data) > 1 or len(already_entity_types) > 0:
+        # Update pipeline status when a node that needs merging
+        status_message = f"Merging entity: {entity_name} | {len(nodes_data)}+{len(already_entity_types)}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
+        description = await _handle_entity_relation_summary(
+            entity_name,
+            description,
+            global_config,
+            pipeline_status,
+            pipeline_status_lock,
+            llm_response_cache,
+        )
 
    node_data = dict(
        entity_id=entity_name,
        entity_type=entity_type,
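The gist of this hunk (and of the matching edge-side hunk further down) is that the LLM re-summary and its status message now fire only when there is actually something to merge: more than one fresh description fragment, or fragments plus an existing node. A small sketch of that gating pattern, with `summarize_with_llm` as a hypothetical stand-in for `_handle_entity_relation_summary`:

```python
from typing import Awaitable, Callable


async def maybe_resummarize(
    entity_name: str,
    description: str,
    nodes_data: list[dict],
    already_entity_types: list[str],
    summarize_with_llm: Callable[[str, str], Awaitable[str]],
) -> str:
    # Skip the LLM call when a single fragment and no pre-existing node
    # means there is nothing to merge; keep the description as-is.
    if len(nodes_data) > 1 or len(already_entity_types) > 0:
        return await summarize_with_llm(entity_name, description)
    return description
```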
@@ -319,14 +320,6 @@ async def _merge_edges_then_upsert(
    already_file_paths = []
 
    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
-        # Update pipeline status when an edge that needs merging is found
-        status_message = f"Merging edge::: {src_id} - {tgt_id}"
-        logger.info(status_message)
-        if pipeline_status is not None and pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = status_message
-                pipeline_status["history_messages"].append(status_message)
-
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        # Handle the case where get_edge returns None or missing fields
        if already_edge:
@@ -404,14 +397,25 @@ async def _merge_edges_then_upsert(
                    "file_path": file_path,
                },
            )
-    description = await _handle_entity_relation_summary(
-        f"({src_id}, {tgt_id})",
-        description,
-        global_config,
-        pipeline_status,
-        pipeline_status_lock,
-        llm_response_cache,
-    )
+
+    if len(edges_data) > 1 or len(already_weights) > 0:
+        # Update pipeline status when a edge that needs merging
+        status_message = f"Merging edge::: {src_id} - {tgt_id} | {len(edges_data)}+{len(already_weights)}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
+        description = await _handle_entity_relation_summary(
+            f"({src_id}, {tgt_id})",
+            description,
+            global_config,
+            pipeline_status,
+            pipeline_status_lock,
+            llm_response_cache,
+        )
 
    await knowledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
@@ -550,8 +554,10 @@ async def extract_entities(
        Args:
            chunk_key_dp (tuple[str, TextChunkSchema]):
                ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
+        Returns:
+            tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships
        """
-        nonlocal processed_chunks, total_entities_count, total_relations_count
+        nonlocal processed_chunks
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
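After this change the per-chunk worker no longer touches the shared totals via `nonlocal`; it only reports what it found, and the caller derives the counts once from the merged, document-level results (next hunk). A hypothetical sketch of that worker/caller split:

```python
import asyncio
from collections import defaultdict


async def process_chunk(content: str) -> tuple[dict, dict]:
    # Stand-in for _process_single_content: return the chunk's findings
    # instead of incrementing shared counters.
    first_word = content.split()[0]
    return {first_word: [{"description": content}]}, {}


async def main() -> None:
    chunks = ["Alice met Bob", "Bob met Carol"]
    chunk_results = await asyncio.gather(*(process_chunk(c) for c in chunks))

    all_nodes: defaultdict[str, list] = defaultdict(list)
    for maybe_nodes, _maybe_edges in chunk_results:
        for name, entities in maybe_nodes.items():
            all_nodes[name].extend(entities)

    # Counts are set once at document level, not accumulated per chunk.
    total_entities_count = len(all_nodes)
    print(total_entities_count)  # 2


asyncio.run(main())
```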
@@ -623,75 +629,91 @@ async def extract_entities(
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
 
-        # Use graph database lock to ensure atomic merges and updates
-        chunk_entities_data = []
-        chunk_relationships_data = []
-
-        async with graph_db_lock:
-            # Process and update entities
-            for entity_name, entities in maybe_nodes.items():
-                entity_data = await _merge_nodes_then_upsert(
-                    entity_name,
-                    entities,
-                    knowledge_graph_inst,
-                    global_config,
-                    pipeline_status,
-                    pipeline_status_lock,
-                    llm_response_cache,
-                )
-                chunk_entities_data.append(entity_data)
-
-            # Process and update relationships
-            for edge_key, edges in maybe_edges.items():
-                # Ensure edge direction consistency
-                sorted_edge_key = tuple(sorted(edge_key))
-                edge_data = await _merge_edges_then_upsert(
-                    sorted_edge_key[0],
-                    sorted_edge_key[1],
-                    edges,
-                    knowledge_graph_inst,
-                    global_config,
-                    pipeline_status,
-                    pipeline_status_lock,
-                    llm_response_cache,
-                )
-                chunk_relationships_data.append(edge_data)
-
-            # Update vector database (within the same lock to ensure atomicity)
-            if entity_vdb is not None and chunk_entities_data:
-                data_for_vdb = {
-                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                        "entity_name": dp["entity_name"],
-                        "entity_type": dp["entity_type"],
-                        "content": f"{dp['entity_name']}\n{dp['description']}",
-                        "source_id": dp["source_id"],
-                        "file_path": dp.get("file_path", "unknown_source"),
-                    }
-                    for dp in chunk_entities_data
-                }
-                await entity_vdb.upsert(data_for_vdb)
-
-            if relationships_vdb is not None and chunk_relationships_data:
-                data_for_vdb = {
-                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                        "src_id": dp["src_id"],
-                        "tgt_id": dp["tgt_id"],
-                        "keywords": dp["keywords"],
-                        "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                        "source_id": dp["source_id"],
-                        "file_path": dp.get("file_path", "unknown_source"),
-                    }
-                    for dp in chunk_relationships_data
-                }
-                await relationships_vdb.upsert(data_for_vdb)
-
-            # Update counters
-            total_entities_count += len(chunk_entities_data)
-            total_relations_count += len(chunk_relationships_data)
-
-    # Handle all chunks in parallel
+        # Return the extracted nodes and edges for centralized processing
+        return maybe_nodes, maybe_edges
+
+    # Handle all chunks in parallel and collect results
    tasks = [_process_single_content(c) for c in ordered_chunks]
-    await asyncio.gather(*tasks)
+    chunk_results = await asyncio.gather(*tasks)
 
+    # Collect all nodes and edges from all chunks
+    all_nodes = defaultdict(list)
+    all_edges = defaultdict(list)
+
+    for maybe_nodes, maybe_edges in chunk_results:
+        # Collect nodes
+        for entity_name, entities in maybe_nodes.items():
+            all_nodes[entity_name].extend(entities)
+
+        # Collect edges with sorted keys for undirected graph
+        for edge_key, edges in maybe_edges.items():
+            sorted_edge_key = tuple(sorted(edge_key))
+            all_edges[sorted_edge_key].extend(edges)
+
+    # Centralized processing of all nodes and edges
+    entities_data = []
+    relationships_data = []
+
+    # Use graph database lock to ensure atomic merges and updates
+    async with graph_db_lock:
+        # Process and update all entities at once
+        for entity_name, entities in all_nodes.items():
+            entity_data = await _merge_nodes_then_upsert(
+                entity_name,
+                entities,
+                knowledge_graph_inst,
+                global_config,
+                pipeline_status,
+                pipeline_status_lock,
+                llm_response_cache,
+            )
+            entities_data.append(entity_data)
+
+        # Process and update all relationships at once
+        for edge_key, edges in all_edges.items():
+            edge_data = await _merge_edges_then_upsert(
+                edge_key[0],
+                edge_key[1],
+                edges,
+                knowledge_graph_inst,
+                global_config,
+                pipeline_status,
+                pipeline_status_lock,
+                llm_response_cache,
+            )
+            relationships_data.append(edge_data)
+
+        # Update vector databases with all collected data
+        if entity_vdb is not None and entities_data:
+            data_for_vdb = {
+                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                    "entity_name": dp["entity_name"],
+                    "entity_type": dp["entity_type"],
+                    "content": f"{dp['entity_name']}\n{dp['description']}",
+                    "source_id": dp["source_id"],
+                    "file_path": dp.get("file_path", "unknown_source"),
+                }
+                for dp in entities_data
+            }
+            await entity_vdb.upsert(data_for_vdb)
+
+        if relationships_vdb is not None and relationships_data:
+            data_for_vdb = {
+                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                    "src_id": dp["src_id"],
+                    "tgt_id": dp["tgt_id"],
+                    "keywords": dp["keywords"],
+                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
+                    "source_id": dp["source_id"],
+                    "file_path": dp.get("file_path", "unknown_source"),
+                }
+                for dp in relationships_data
+            }
+            await relationships_vdb.upsert(data_for_vdb)
+
+        # Update total counts
+        total_entities_count = len(entities_data)
+        total_relations_count = len(relationships_data)
+
    log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
    logger.info(log_message)