Improve parallel handling logic between extraction and merge operation

This commit is contained in:
yangdx
2025-04-28 01:14:00 +08:00
parent 4acc2adc32
commit 18040aa95c
2 changed files with 199 additions and 120 deletions

View File

@@ -46,6 +46,7 @@ from .namespace import NameSpace, make_namespace
from .operate import ( from .operate import (
chunking_by_token_size, chunking_by_token_size,
extract_entities, extract_entities,
merge_nodes_and_edges,
kg_query, kg_query,
mix_kg_vector_query, mix_kg_vector_query,
naive_query, naive_query,
@@ -902,6 +903,7 @@ class LightRAG:
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
) -> None: ) -> None:
"""Process single document""" """Process single document"""
file_extraction_stage_ok = False
async with semaphore: async with semaphore:
nonlocal processed_count nonlocal processed_count
current_file_number = 0 current_file_number = 0
@@ -919,7 +921,7 @@ class LightRAG:
) )
pipeline_status["cur_batch"] = processed_count pipeline_status["cur_batch"] = processed_count
log_message = f"Processing file ({current_file_number}/{total_files}): {file_path}" log_message = f"Processing file {current_file_number}/{total_files}: {file_path}"
logger.info(log_message) logger.info(log_message)
pipeline_status["history_messages"].append(log_message) pipeline_status["history_messages"].append(log_message)
log_message = f"Processing d-id: {doc_id}" log_message = f"Processing d-id: {doc_id}"
@@ -986,6 +988,61 @@ class LightRAG:
text_chunks_task, text_chunks_task,
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
file_extraction_stage_ok = True
except Exception as e:
# Log error and update pipeline status
error_msg = f"Failed to extrat document {doc_id}: {traceback.format_exc()}"
logger.error(error_msg)
async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(error_msg)
# Cancel other tasks as they are no longer meaningful
for task in [
chunks_vdb_task,
entity_relation_task,
full_docs_task,
text_chunks_task,
]:
if not task.done():
task.cancel()
# Update document status to failed
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
)
# Release semphore before entering to merge stage
if file_extraction_stage_ok:
try:
# Get chunk_results from entity_relation_task
chunk_results = await entity_relation_task
await merge_nodes_and_edges(
chunk_results=chunk_results, # result collected from entity_relation_task
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
current_file_number=current_file_number,
total_files=total_files,
file_path=file_path,
)
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
doc_id: { doc_id: {
@@ -1012,22 +1069,12 @@ class LightRAG:
except Exception as e: except Exception as e:
# Log error and update pipeline status # Log error and update pipeline status
error_msg = f"Failed to process document {doc_id}: {traceback.format_exc()}" error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
logger.error(error_msg) logger.error(error_msg)
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(error_msg) pipeline_status["history_messages"].append(error_msg)
# Cancel other tasks as they are no longer meaningful
for task in [
chunks_vdb_task,
entity_relation_task,
full_docs_task,
text_chunks_task,
]:
if not task.done():
task.cancel()
# Update document status to failed # Update document status to failed
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
@@ -1101,9 +1148,9 @@ class LightRAG:
async def _process_entity_relation_graph( async def _process_entity_relation_graph(
self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
) -> None: ) -> list:
try: try:
await extract_entities( chunk_results = await extract_entities(
chunk, chunk,
knowledge_graph_inst=self.chunk_entity_relation_graph, knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb, entity_vdb=self.entities_vdb,
@@ -1113,6 +1160,7 @@ class LightRAG:
pipeline_status_lock=pipeline_status_lock, pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache, llm_response_cache=self.llm_response_cache,
) )
return chunk_results
except Exception as e: except Exception as e:
error_msg = f"Failed to extract entities and relationships: {str(e)}" error_msg = f"Failed to extract entities and relationships: {str(e)}"
logger.error(error_msg) logger.error(error_msg)

View File

@@ -476,6 +476,139 @@ async def _merge_edges_then_upsert(
return edge_data return edge_data
async def merge_nodes_and_edges(
chunk_results: list,
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict[str, str],
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
current_file_number: int = 0,
total_files: int = 0,
file_path: str = "unknown_source",
) -> None:
"""Merge nodes and edges from extraction results
Args:
chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
knowledge_graph_inst: Knowledge graph storage
entity_vdb: Entity vector database
relationships_vdb: Relationship vector database
global_config: Global configuration
pipeline_status: Pipeline status dictionary
pipeline_status_lock: Lock for pipeline status
llm_response_cache: LLM response cache
"""
# Get lock manager from shared storage
from .kg.shared_storage import get_graph_db_lock
graph_db_lock = get_graph_db_lock(enable_logging=False)
# Collect all nodes and edges from all chunks
all_nodes = defaultdict(list)
all_edges = defaultdict(list)
for maybe_nodes, maybe_edges in chunk_results:
# Collect nodes
for entity_name, entities in maybe_nodes.items():
all_nodes[entity_name].extend(entities)
# Collect edges with sorted keys for undirected graph
for edge_key, edges in maybe_edges.items():
sorted_edge_key = tuple(sorted(edge_key))
all_edges[sorted_edge_key].extend(edges)
# Centralized processing of all nodes and edges
entities_data = []
relationships_data = []
# Merge nodes and edges
# Use graph database lock to ensure atomic merges and updates
async with graph_db_lock:
async with pipeline_status_lock:
log_message = f"Merging nodes/edges {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process and update all entities at once
for entity_name, entities in all_nodes.items():
entity_data = await _merge_nodes_then_upsert(
entity_name,
entities,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
entities_data.append(entity_data)
# Process and update all relationships at once
for edge_key, edges in all_edges.items():
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
edges,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
if edge_data is not None:
relationships_data.append(edge_data)
# Update total counts
total_entities_count = len(entities_data)
total_relations_count = len(relationships_data)
log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Update vector databases with all collected data
if entity_vdb is not None and entities_data:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in entities_data
}
await entity_vdb.upsert(data_for_vdb)
log_message = (
f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
)
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if relationships_vdb is not None and relationships_data:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
async def extract_entities( async def extract_entities(
chunks: dict[str, TextChunkSchema], chunks: dict[str, TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
@@ -485,7 +618,7 @@ async def extract_entities(
pipeline_status: dict = None, pipeline_status: dict = None,
pipeline_status_lock=None, pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
) -> None: ) -> list:
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -530,15 +663,6 @@ async def extract_entities(
processed_chunks = 0 processed_chunks = 0
total_chunks = len(ordered_chunks) total_chunks = len(ordered_chunks)
total_entities_count = 0
total_relations_count = 0
# Get lock manager from shared storage
from .kg.shared_storage import get_graph_db_lock
graph_db_lock = get_graph_db_lock(enable_logging=False)
# Use the global use_llm_func_with_cache function from utils.py
async def _process_extraction_result( async def _process_extraction_result(
result: str, chunk_key: str, file_path: str = "unknown_source" result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -708,102 +832,9 @@ async def extract_entities(
# If all tasks completed successfully, collect results # If all tasks completed successfully, collect results
chunk_results = [task.result() for task in tasks] chunk_results = [task.result() for task in tasks]
# Collect all nodes and edges from all chunks # Return the chunk_results for later processing in merge_nodes_and_edges
all_nodes = defaultdict(list) return chunk_results
all_edges = defaultdict(list)
for maybe_nodes, maybe_edges in chunk_results:
# Collect nodes
for entity_name, entities in maybe_nodes.items():
all_nodes[entity_name].extend(entities)
# Collect edges with sorted keys for undirected graph
for edge_key, edges in maybe_edges.items():
sorted_edge_key = tuple(sorted(edge_key))
all_edges[sorted_edge_key].extend(edges)
# Centralized processing of all nodes and edges
entities_data = []
relationships_data = []
# Use graph database lock to ensure atomic merges and updates
async with graph_db_lock:
# Process and update all entities at once
for entity_name, entities in all_nodes.items():
entity_data = await _merge_nodes_then_upsert(
entity_name,
entities,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
entities_data.append(entity_data)
# Process and update all relationships at once
for edge_key, edges in all_edges.items():
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
edges,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
if edge_data is not None:
relationships_data.append(edge_data)
# Update total counts
total_entities_count = len(entities_data)
total_relations_count = len(relationships_data)
log_message = f"Updating vector storage: {total_entities_count} entities..."
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Update vector databases with all collected data
if entity_vdb is not None and entities_data:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in entities_data
}
await entity_vdb.upsert(data_for_vdb)
log_message = (
f"Updating vector storage: {total_relations_count} relationships..."
)
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if relationships_vdb is not None and relationships_data:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
async def kg_query( async def kg_query(