Improve parallel handling logic between extraction and merge operation
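In rough terms, the diff below splits per-document processing into two stages: an extraction stage that runs under the per-document concurrency semaphore, and a merge stage that runs only after extraction succeeds and the semaphore slot is released, serialized by a shared graph-database lock. A minimal sketch of that control flow is shown here, using simplified stand-in coroutines and names (extract_chunks, merge_results, graph_db_lock); these are illustrative assumptions, not the actual LightRAG API.

import asyncio

# Minimal sketch of the two-stage flow: extraction under a semaphore,
# merging afterwards under a single shared lock. Names are stand-ins.

async def extract_chunks(doc_id: str) -> list:
    # Stand-in for the extraction stage: returns per-chunk (nodes, edges) results.
    await asyncio.sleep(0.01)
    return [({"EntityA": [{"doc": doc_id}]}, {("EntityA", "EntityB"): [{"doc": doc_id}]})]

async def merge_results(doc_id: str, chunk_results: list, graph_db_lock: asyncio.Lock) -> None:
    # Merging writes to shared graph storage, so it is serialized by one lock.
    async with graph_db_lock:
        await asyncio.sleep(0.01)
        print(f"{doc_id}: merged {len(chunk_results)} chunk result(s)")

async def process_document(doc_id: str, semaphore: asyncio.Semaphore, graph_db_lock: asyncio.Lock) -> None:
    extraction_ok = False
    chunk_results: list = []
    async with semaphore:  # the semaphore bounds only the extraction stage
        try:
            chunk_results = await extract_chunks(doc_id)
            extraction_ok = True
        except Exception as exc:
            print(f"{doc_id}: extraction failed: {exc}")
    if extraction_ok:  # merge runs after the semaphore slot is released
        await merge_results(doc_id, chunk_results, graph_db_lock)

async def main() -> None:
    semaphore = asyncio.Semaphore(2)   # cap on concurrent extractions
    graph_db_lock = asyncio.Lock()     # stands in for get_graph_db_lock()
    await asyncio.gather(*(process_document(f"doc-{i}", semaphore, graph_db_lock) for i in range(4)))

asyncio.run(main())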
@@ -46,6 +46,7 @@ from .namespace import NameSpace, make_namespace
 from .operate import (
     chunking_by_token_size,
     extract_entities,
+    merge_nodes_and_edges,
     kg_query,
     mix_kg_vector_query,
     naive_query,
@@ -902,6 +903,7 @@ class LightRAG:
             semaphore: asyncio.Semaphore,
         ) -> None:
             """Process single document"""
+            file_extraction_stage_ok = False
             async with semaphore:
                 nonlocal processed_count
                 current_file_number = 0
@@ -919,7 +921,7 @@ class LightRAG:
                         )
                         pipeline_status["cur_batch"] = processed_count

-                        log_message = f"Processing file ({current_file_number}/{total_files}): {file_path}"
+                        log_message = f"Processing file {current_file_number}/{total_files}: {file_path}"
                         logger.info(log_message)
                         pipeline_status["history_messages"].append(log_message)
                         log_message = f"Processing d-id: {doc_id}"
@@ -986,6 +988,61 @@ class LightRAG:
                         text_chunks_task,
                     ]
                     await asyncio.gather(*tasks)
+                    file_extraction_stage_ok = True
+
+                except Exception as e:
+                    # Log error and update pipeline status
+                    error_msg = f"Failed to extract document {doc_id}: {traceback.format_exc()}"
+                    logger.error(error_msg)
+                    async with pipeline_status_lock:
+                        pipeline_status["latest_message"] = error_msg
+                        pipeline_status["history_messages"].append(error_msg)
+
+                    # Cancel other tasks as they are no longer meaningful
+                    for task in [
+                        chunks_vdb_task,
+                        entity_relation_task,
+                        full_docs_task,
+                        text_chunks_task,
+                    ]:
+                        if not task.done():
+                            task.cancel()
+
+                    # Update document status to failed
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                                "file_path": file_path,
+                            }
+                        }
+                    )
+
+            # Release semaphore before entering the merge stage
+            if file_extraction_stage_ok:
+                try:
+                    # Get chunk_results from entity_relation_task
+                    chunk_results = await entity_relation_task
+                    await merge_nodes_and_edges(
+                        chunk_results=chunk_results,  # results collected from entity_relation_task
+                        knowledge_graph_inst=self.chunk_entity_relation_graph,
+                        entity_vdb=self.entities_vdb,
+                        relationships_vdb=self.relationships_vdb,
+                        global_config=asdict(self),
+                        pipeline_status=pipeline_status,
+                        pipeline_status_lock=pipeline_status_lock,
+                        llm_response_cache=self.llm_response_cache,
+                        current_file_number=current_file_number,
+                        total_files=total_files,
+                        file_path=file_path,
+                    )
+
                     await self.doc_status.upsert(
                         {
                             doc_id: {
@@ -1012,22 +1069,12 @@ class LightRAG:

                 except Exception as e:
                     # Log error and update pipeline status
-                    error_msg = f"Failed to process document {doc_id}: {traceback.format_exc()}"
+                    error_msg = f"Merging stage failed for document {doc_id}: {traceback.format_exc()}"

                     logger.error(error_msg)
                     async with pipeline_status_lock:
                         pipeline_status["latest_message"] = error_msg
                         pipeline_status["history_messages"].append(error_msg)
-
-                    # Cancel other tasks as they are no longer meaningful
-                    for task in [
-                        chunks_vdb_task,
-                        entity_relation_task,
-                        full_docs_task,
-                        text_chunks_task,
-                    ]:
-                        if not task.done():
-                            task.cancel()
                     # Update document status to failed
                     await self.doc_status.upsert(
                         {
@@ -1101,9 +1148,9 @@ class LightRAG:

    async def _process_entity_relation_graph(
        self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
-    ) -> None:
+    ) -> list:
        try:
-            await extract_entities(
+            chunk_results = await extract_entities(
                chunk,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
@@ -1113,6 +1160,7 @@ class LightRAG:
                pipeline_status_lock=pipeline_status_lock,
                llm_response_cache=self.llm_response_cache,
            )
+            return chunk_results
        except Exception as e:
            error_msg = f"Failed to extract entities and relationships: {str(e)}"
            logger.error(error_msg)
@@ -476,6 +476,139 @@ async def _merge_edges_then_upsert(
    return edge_data


+async def merge_nodes_and_edges(
+    chunk_results: list,
+    knowledge_graph_inst: BaseGraphStorage,
+    entity_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    global_config: dict[str, str],
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
+    current_file_number: int = 0,
+    total_files: int = 0,
+    file_path: str = "unknown_source",
+) -> None:
+    """Merge nodes and edges from extraction results
+
+    Args:
+        chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
+        knowledge_graph_inst: Knowledge graph storage
+        entity_vdb: Entity vector database
+        relationships_vdb: Relationship vector database
+        global_config: Global configuration
+        pipeline_status: Pipeline status dictionary
+        pipeline_status_lock: Lock for pipeline status
+        llm_response_cache: LLM response cache
+    """
+    # Get lock manager from shared storage
+    from .kg.shared_storage import get_graph_db_lock
+    graph_db_lock = get_graph_db_lock(enable_logging=False)
+
+    # Collect all nodes and edges from all chunks
+    all_nodes = defaultdict(list)
+    all_edges = defaultdict(list)
+
+    for maybe_nodes, maybe_edges in chunk_results:
+        # Collect nodes
+        for entity_name, entities in maybe_nodes.items():
+            all_nodes[entity_name].extend(entities)
+
+        # Collect edges with sorted keys for undirected graph
+        for edge_key, edges in maybe_edges.items():
+            sorted_edge_key = tuple(sorted(edge_key))
+            all_edges[sorted_edge_key].extend(edges)
+
+    # Centralized processing of all nodes and edges
+    entities_data = []
+    relationships_data = []
+
+    # Merge nodes and edges
+    # Use graph database lock to ensure atomic merges and updates
+    async with graph_db_lock:
+        async with pipeline_status_lock:
+            log_message = f"Merging nodes/edges {current_file_number}/{total_files}: {file_path}"
+            logger.info(log_message)
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
+
+        # Process and update all entities at once
+        for entity_name, entities in all_nodes.items():
+            entity_data = await _merge_nodes_then_upsert(
+                entity_name,
+                entities,
+                knowledge_graph_inst,
+                global_config,
+                pipeline_status,
+                pipeline_status_lock,
+                llm_response_cache,
+            )
+            entities_data.append(entity_data)
+
+        # Process and update all relationships at once
+        for edge_key, edges in all_edges.items():
+            edge_data = await _merge_edges_then_upsert(
+                edge_key[0],
+                edge_key[1],
+                edges,
+                knowledge_graph_inst,
+                global_config,
+                pipeline_status,
+                pipeline_status_lock,
+                llm_response_cache,
+            )
+            if edge_data is not None:
+                relationships_data.append(edge_data)
+
+        # Update total counts
+        total_entities_count = len(entities_data)
+        total_relations_count = len(relationships_data)
+
+        log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
+        logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
+
+        # Update vector databases with all collected data
+        if entity_vdb is not None and entities_data:
+            data_for_vdb = {
+                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                    "entity_name": dp["entity_name"],
+                    "entity_type": dp["entity_type"],
+                    "content": f"{dp['entity_name']}\n{dp['description']}",
+                    "source_id": dp["source_id"],
+                    "file_path": dp.get("file_path", "unknown_source"),
+                }
+                for dp in entities_data
+            }
+            await entity_vdb.upsert(data_for_vdb)
+
+        log_message = (
+            f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
+        )
+        logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
+
+        if relationships_vdb is not None and relationships_data:
+            data_for_vdb = {
+                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                    "src_id": dp["src_id"],
+                    "tgt_id": dp["tgt_id"],
+                    "keywords": dp["keywords"],
+                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
+                    "source_id": dp["source_id"],
+                    "file_path": dp.get("file_path", "unknown_source"),
+                }
+                for dp in relationships_data
+            }
+            await relationships_vdb.upsert(data_for_vdb)
+
+
 async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
@@ -485,7 +618,7 @@ async def extract_entities(
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
-) -> None:
+) -> list:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

@@ -530,15 +663,6 @@ async def extract_entities(

    processed_chunks = 0
    total_chunks = len(ordered_chunks)
-    total_entities_count = 0
-    total_relations_count = 0
-
-    # Get lock manager from shared storage
-    from .kg.shared_storage import get_graph_db_lock
-
-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-
-    # Use the global use_llm_func_with_cache function from utils.py

    async def _process_extraction_result(
        result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -708,102 +832,9 @@ async def extract_entities(

    # If all tasks completed successfully, collect results
    chunk_results = [task.result() for task in tasks]

-    # Collect all nodes and edges from all chunks
-    all_nodes = defaultdict(list)
-    all_edges = defaultdict(list)
-
-    for maybe_nodes, maybe_edges in chunk_results:
-        # Collect nodes
-        for entity_name, entities in maybe_nodes.items():
-            all_nodes[entity_name].extend(entities)
-
-        # Collect edges with sorted keys for undirected graph
-        for edge_key, edges in maybe_edges.items():
-            sorted_edge_key = tuple(sorted(edge_key))
-            all_edges[sorted_edge_key].extend(edges)
-
-    # Centralized processing of all nodes and edges
-    entities_data = []
-    relationships_data = []
-
-    # Use graph database lock to ensure atomic merges and updates
-    async with graph_db_lock:
-        # Process and update all entities at once
-        for entity_name, entities in all_nodes.items():
-            entity_data = await _merge_nodes_then_upsert(
-                entity_name,
-                entities,
-                knowledge_graph_inst,
-                global_config,
-                pipeline_status,
-                pipeline_status_lock,
-                llm_response_cache,
-            )
-            entities_data.append(entity_data)
-
-        # Process and update all relationships at once
-        for edge_key, edges in all_edges.items():
-            edge_data = await _merge_edges_then_upsert(
-                edge_key[0],
-                edge_key[1],
-                edges,
-                knowledge_graph_inst,
-                global_config,
-                pipeline_status,
-                pipeline_status_lock,
-                llm_response_cache,
-            )
-            if edge_data is not None:
-                relationships_data.append(edge_data)
-
-        # Update total counts
-        total_entities_count = len(entities_data)
-        total_relations_count = len(relationships_data)
-
-        log_message = f"Updating vector storage: {total_entities_count} entities..."
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        # Update vector databases with all collected data
-        if entity_vdb is not None and entities_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                    "entity_name": dp["entity_name"],
-                    "entity_type": dp["entity_type"],
-                    "content": f"{dp['entity_name']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in entities_data
-            }
-            await entity_vdb.upsert(data_for_vdb)
-
-        log_message = (
-            f"Updating vector storage: {total_relations_count} relationships..."
-        )
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        if relationships_vdb is not None and relationships_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                    "src_id": dp["src_id"],
-                    "tgt_id": dp["tgt_id"],
-                    "keywords": dp["keywords"],
-                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in relationships_data
-            }
-            await relationships_vdb.upsert(data_for_vdb)
+    # Return the chunk_results for later processing in merge_nodes_and_edges
+    return chunk_results


 async def kg_query(
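The hand-off between the two stages is the chunk_results list that extract_entities now returns: each element is a (maybe_nodes, maybe_edges) pair, where maybe_nodes maps an entity name to its raw extractions and maybe_edges maps a (src, tgt) pair to raw relationship extractions. Below is a small illustration (with made-up data) of the pooling step that merge_nodes_and_edges performs before upserting, including the key sorting that treats edges as undirected.

from collections import defaultdict

# Illustrative chunk_results: two chunks that both mention "Alice" and the
# Alice-Bob relation, with the edge reported in opposite directions.
chunk_results = [
    ({"Alice": [{"description": "engineer"}]}, {("Alice", "Bob"): [{"keywords": "works_with"}]}),
    ({"Alice": [{"description": "author"}]},   {("Bob", "Alice"): [{"keywords": "knows"}]}),
]

all_nodes = defaultdict(list)
all_edges = defaultdict(list)

for maybe_nodes, maybe_edges in chunk_results:
    for entity_name, entities in maybe_nodes.items():
        all_nodes[entity_name].extend(entities)
    for edge_key, edges in maybe_edges.items():
        # Sorting the key folds (A, B) and (B, A) into one undirected edge.
        all_edges[tuple(sorted(edge_key))].extend(edges)

print(dict(all_nodes))  # {'Alice': [{'description': 'engineer'}, {'description': 'author'}]}
print(dict(all_edges))  # {('Alice', 'Bob'): [{'keywords': 'works_with'}, {'keywords': 'knows'}]}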