Merge pull request #1478 from danielaskdd/improve-extract-merge-parallel
Improve parallel processing logic for document extraction and merging
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
||||
|
||||
__version__ = "1.3.4"
|
||||
__version__ = "1.3.5"
|
||||
__author__ = "Zirui Guo"
|
||||
__url__ = "https://github.com/HKUDS/LightRAG"
|
||||
|
@@ -46,6 +46,7 @@ from .namespace import NameSpace, make_namespace
|
||||
from .operate import (
|
||||
chunking_by_token_size,
|
||||
extract_entities,
|
||||
merge_nodes_and_edges,
|
||||
kg_query,
|
||||
mix_kg_vector_query,
|
||||
naive_query,
|
||||
@@ -902,6 +903,7 @@ class LightRAG:
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> None:
|
||||
"""Process single document"""
|
||||
file_extraction_stage_ok = False
|
||||
async with semaphore:
|
||||
nonlocal processed_count
|
||||
current_file_number = 0
|
||||
@@ -919,7 +921,7 @@ class LightRAG:
|
||||
)
|
||||
pipeline_status["cur_batch"] = processed_count
|
||||
|
||||
log_message = f"Processing file ({current_file_number}/{total_files}): {file_path}"
|
||||
log_message = f"Processing file {current_file_number}/{total_files}: {file_path}"
|
||||
logger.info(log_message)
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
log_message = f"Processing d-id: {doc_id}"
|
||||
@@ -986,6 +988,61 @@ class LightRAG:
|
||||
text_chunks_task,
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
file_extraction_stage_ok = True
|
||||
|
||||
except Exception as e:
|
||||
# Log error and update pipeline status
|
||||
error_msg = f"Failed to extrat document {doc_id}: {traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = error_msg
|
||||
pipeline_status["history_messages"].append(error_msg)
|
||||
|
||||
# Cancel other tasks as they are no longer meaningful
|
||||
for task in [
|
||||
chunks_vdb_task,
|
||||
entity_relation_task,
|
||||
full_docs_task,
|
||||
text_chunks_task,
|
||||
]:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Update document status to failed
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.FAILED,
|
||||
"error": str(e),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"file_path": file_path,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Release semphore before entering to merge stage
|
||||
if file_extraction_stage_ok:
|
||||
try:
|
||||
# Get chunk_results from entity_relation_task
|
||||
chunk_results = await entity_relation_task
|
||||
await merge_nodes_and_edges(
|
||||
chunk_results=chunk_results, # result collected from entity_relation_task
|
||||
knowledge_graph_inst=self.chunk_entity_relation_graph,
|
||||
entity_vdb=self.entities_vdb,
|
||||
relationships_vdb=self.relationships_vdb,
|
||||
global_config=asdict(self),
|
||||
pipeline_status=pipeline_status,
|
||||
pipeline_status_lock=pipeline_status_lock,
|
||||
llm_response_cache=self.llm_response_cache,
|
||||
current_file_number=current_file_number,
|
||||
total_files=total_files,
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
@@ -1012,22 +1069,12 @@ class LightRAG:
|
||||
|
||||
except Exception as e:
|
||||
# Log error and update pipeline status
|
||||
error_msg = f"Failed to process document {doc_id}: {traceback.format_exc()}"
|
||||
|
||||
error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = error_msg
|
||||
pipeline_status["history_messages"].append(error_msg)
|
||||
|
||||
# Cancel other tasks as they are no longer meaningful
|
||||
for task in [
|
||||
chunks_vdb_task,
|
||||
entity_relation_task,
|
||||
full_docs_task,
|
||||
text_chunks_task,
|
||||
]:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
# Update document status to failed
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
@@ -1101,9 +1148,9 @@ class LightRAG:
|
||||
|
||||
async def _process_entity_relation_graph(
|
||||
self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
|
||||
) -> None:
|
||||
) -> list:
|
||||
try:
|
||||
await extract_entities(
|
||||
chunk_results = await extract_entities(
|
||||
chunk,
|
||||
knowledge_graph_inst=self.chunk_entity_relation_graph,
|
||||
entity_vdb=self.entities_vdb,
|
||||
@@ -1113,6 +1160,7 @@ class LightRAG:
|
||||
pipeline_status_lock=pipeline_status_lock,
|
||||
llm_response_cache=self.llm_response_cache,
|
||||
)
|
||||
return chunk_results
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to extract entities and relationships: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
@@ -476,6 +476,139 @@ async def _merge_edges_then_upsert(
|
||||
return edge_data
|
||||
|
||||
|
||||
async def merge_nodes_and_edges(
|
||||
chunk_results: list,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entity_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
global_config: dict[str, str],
|
||||
pipeline_status: dict = None,
|
||||
pipeline_status_lock=None,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
current_file_number: int = 0,
|
||||
total_files: int = 0,
|
||||
file_path: str = "unknown_source",
|
||||
) -> None:
|
||||
"""Merge nodes and edges from extraction results
|
||||
|
||||
Args:
|
||||
chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
|
||||
knowledge_graph_inst: Knowledge graph storage
|
||||
entity_vdb: Entity vector database
|
||||
relationships_vdb: Relationship vector database
|
||||
global_config: Global configuration
|
||||
pipeline_status: Pipeline status dictionary
|
||||
pipeline_status_lock: Lock for pipeline status
|
||||
llm_response_cache: LLM response cache
|
||||
"""
|
||||
# Get lock manager from shared storage
|
||||
from .kg.shared_storage import get_graph_db_lock
|
||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
||||
|
||||
# Collect all nodes and edges from all chunks
|
||||
all_nodes = defaultdict(list)
|
||||
all_edges = defaultdict(list)
|
||||
|
||||
for maybe_nodes, maybe_edges in chunk_results:
|
||||
# Collect nodes
|
||||
for entity_name, entities in maybe_nodes.items():
|
||||
all_nodes[entity_name].extend(entities)
|
||||
|
||||
# Collect edges with sorted keys for undirected graph
|
||||
for edge_key, edges in maybe_edges.items():
|
||||
sorted_edge_key = tuple(sorted(edge_key))
|
||||
all_edges[sorted_edge_key].extend(edges)
|
||||
|
||||
# Centralized processing of all nodes and edges
|
||||
entities_data = []
|
||||
relationships_data = []
|
||||
|
||||
# Merge nodes and edges
|
||||
# Use graph database lock to ensure atomic merges and updates
|
||||
async with graph_db_lock:
|
||||
async with pipeline_status_lock:
|
||||
log_message = f"Merging nodes/edges {current_file_number}/{total_files}: {file_path}"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
# Process and update all entities at once
|
||||
for entity_name, entities in all_nodes.items():
|
||||
entity_data = await _merge_nodes_then_upsert(
|
||||
entity_name,
|
||||
entities,
|
||||
knowledge_graph_inst,
|
||||
global_config,
|
||||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
llm_response_cache,
|
||||
)
|
||||
entities_data.append(entity_data)
|
||||
|
||||
# Process and update all relationships at once
|
||||
for edge_key, edges in all_edges.items():
|
||||
edge_data = await _merge_edges_then_upsert(
|
||||
edge_key[0],
|
||||
edge_key[1],
|
||||
edges,
|
||||
knowledge_graph_inst,
|
||||
global_config,
|
||||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
llm_response_cache,
|
||||
)
|
||||
if edge_data is not None:
|
||||
relationships_data.append(edge_data)
|
||||
|
||||
# Update total counts
|
||||
total_entities_count = len(entities_data)
|
||||
total_relations_count = len(relationships_data)
|
||||
|
||||
log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
# Update vector databases with all collected data
|
||||
if entity_vdb is not None and entities_data:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
||||
"entity_name": dp["entity_name"],
|
||||
"entity_type": dp["entity_type"],
|
||||
"content": f"{dp['entity_name']}\n{dp['description']}",
|
||||
"source_id": dp["source_id"],
|
||||
"file_path": dp.get("file_path", "unknown_source"),
|
||||
}
|
||||
for dp in entities_data
|
||||
}
|
||||
await entity_vdb.upsert(data_for_vdb)
|
||||
|
||||
log_message = (
|
||||
f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
|
||||
)
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
if relationships_vdb is not None and relationships_data:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
"tgt_id": dp["tgt_id"],
|
||||
"keywords": dp["keywords"],
|
||||
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
|
||||
"source_id": dp["source_id"],
|
||||
"file_path": dp.get("file_path", "unknown_source"),
|
||||
}
|
||||
for dp in relationships_data
|
||||
}
|
||||
await relationships_vdb.upsert(data_for_vdb)
|
||||
|
||||
|
||||
async def extract_entities(
|
||||
chunks: dict[str, TextChunkSchema],
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -485,7 +618,7 @@ async def extract_entities(
|
||||
pipeline_status: dict = None,
|
||||
pipeline_status_lock=None,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
) -> None:
|
||||
) -> list:
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||
|
||||
@@ -530,15 +663,6 @@ async def extract_entities(
|
||||
|
||||
processed_chunks = 0
|
||||
total_chunks = len(ordered_chunks)
|
||||
total_entities_count = 0
|
||||
total_relations_count = 0
|
||||
|
||||
# Get lock manager from shared storage
|
||||
from .kg.shared_storage import get_graph_db_lock
|
||||
|
||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
||||
|
||||
# Use the global use_llm_func_with_cache function from utils.py
|
||||
|
||||
async def _process_extraction_result(
|
||||
result: str, chunk_key: str, file_path: str = "unknown_source"
|
||||
@@ -709,101 +833,8 @@ async def extract_entities(
|
||||
# If all tasks completed successfully, collect results
|
||||
chunk_results = [task.result() for task in tasks]
|
||||
|
||||
# Collect all nodes and edges from all chunks
|
||||
all_nodes = defaultdict(list)
|
||||
all_edges = defaultdict(list)
|
||||
|
||||
for maybe_nodes, maybe_edges in chunk_results:
|
||||
# Collect nodes
|
||||
for entity_name, entities in maybe_nodes.items():
|
||||
all_nodes[entity_name].extend(entities)
|
||||
|
||||
# Collect edges with sorted keys for undirected graph
|
||||
for edge_key, edges in maybe_edges.items():
|
||||
sorted_edge_key = tuple(sorted(edge_key))
|
||||
all_edges[sorted_edge_key].extend(edges)
|
||||
|
||||
# Centralized processing of all nodes and edges
|
||||
entities_data = []
|
||||
relationships_data = []
|
||||
|
||||
# Use graph database lock to ensure atomic merges and updates
|
||||
async with graph_db_lock:
|
||||
# Process and update all entities at once
|
||||
for entity_name, entities in all_nodes.items():
|
||||
entity_data = await _merge_nodes_then_upsert(
|
||||
entity_name,
|
||||
entities,
|
||||
knowledge_graph_inst,
|
||||
global_config,
|
||||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
llm_response_cache,
|
||||
)
|
||||
entities_data.append(entity_data)
|
||||
|
||||
# Process and update all relationships at once
|
||||
for edge_key, edges in all_edges.items():
|
||||
edge_data = await _merge_edges_then_upsert(
|
||||
edge_key[0],
|
||||
edge_key[1],
|
||||
edges,
|
||||
knowledge_graph_inst,
|
||||
global_config,
|
||||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
llm_response_cache,
|
||||
)
|
||||
if edge_data is not None:
|
||||
relationships_data.append(edge_data)
|
||||
|
||||
# Update total counts
|
||||
total_entities_count = len(entities_data)
|
||||
total_relations_count = len(relationships_data)
|
||||
|
||||
log_message = f"Updating vector storage: {total_entities_count} entities..."
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
# Update vector databases with all collected data
|
||||
if entity_vdb is not None and entities_data:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
||||
"entity_name": dp["entity_name"],
|
||||
"entity_type": dp["entity_type"],
|
||||
"content": f"{dp['entity_name']}\n{dp['description']}",
|
||||
"source_id": dp["source_id"],
|
||||
"file_path": dp.get("file_path", "unknown_source"),
|
||||
}
|
||||
for dp in entities_data
|
||||
}
|
||||
await entity_vdb.upsert(data_for_vdb)
|
||||
|
||||
log_message = (
|
||||
f"Updating vector storage: {total_relations_count} relationships..."
|
||||
)
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
if relationships_vdb is not None and relationships_data:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
"tgt_id": dp["tgt_id"],
|
||||
"keywords": dp["keywords"],
|
||||
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
|
||||
"source_id": dp["source_id"],
|
||||
"file_path": dp.get("file_path", "unknown_source"),
|
||||
}
|
||||
for dp in relationships_data
|
||||
}
|
||||
await relationships_vdb.upsert(data_for_vdb)
|
||||
# Return the chunk_results for later processing in merge_nodes_and_edges
|
||||
return chunk_results
|
||||
|
||||
|
||||
async def kg_query(
|
||||
|
Reference in New Issue
Block a user