From 18040aa95c535076ca3f4fc4a047ac1ac54715aa Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 28 Apr 2025 01:14:00 +0800
Subject: [PATCH] Improve parallel handling logic between extraction and merge
 operations

---
 lightrag/lightrag.py |  76 +++++++++++---
 lightrag/operate.py  | 243 ++++++++++++++++++++++++-------------------
 2 files changed, 199 insertions(+), 120 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index e6507c37..b8314870 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -46,6 +46,7 @@ from .namespace import NameSpace, make_namespace
 from .operate import (
     chunking_by_token_size,
     extract_entities,
+    merge_nodes_and_edges,
     kg_query,
     mix_kg_vector_query,
     naive_query,
@@ -902,6 +903,7 @@ class LightRAG:
             semaphore: asyncio.Semaphore,
         ) -> None:
             """Process single document"""
+            file_extraction_stage_ok = False
             async with semaphore:
                 nonlocal processed_count
                 current_file_number = 0
@@ -919,7 +921,7 @@ class LightRAG:
                     )
                     pipeline_status["cur_batch"] = processed_count

-                    log_message = f"Processing file ({current_file_number}/{total_files}): {file_path}"
+                    log_message = f"Processing file {current_file_number}/{total_files}: {file_path}"
                     logger.info(log_message)
                     pipeline_status["history_messages"].append(log_message)
                     log_message = f"Processing d-id: {doc_id}"
@@ -986,6 +988,61 @@ class LightRAG:
                         text_chunks_task,
                     ]
                     await asyncio.gather(*tasks)
+                    file_extraction_stage_ok = True
+
+                except Exception as e:
+                    # Log error and update pipeline status
+                    error_msg = f"Failed to extract document {doc_id}: {traceback.format_exc()}"
+                    logger.error(error_msg)
+                    async with pipeline_status_lock:
+                        pipeline_status["latest_message"] = error_msg
+                        pipeline_status["history_messages"].append(error_msg)
+
+                    # Cancel other tasks as they are no longer meaningful
+                    for task in [
+                        chunks_vdb_task,
+                        entity_relation_task,
+                        full_docs_task,
+                        text_chunks_task,
+                    ]:
+                        if not task.done():
+                            task.cancel()
+
+                    # Update document status to failed
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                                "file_path": file_path,
+                            }
+                        }
+                    )
+
+            # Release semaphore before entering the merge stage
+            if file_extraction_stage_ok:
+                try:
+                    # Get chunk_results from entity_relation_task
+                    chunk_results = await entity_relation_task
+                    await merge_nodes_and_edges(
+                        chunk_results=chunk_results,  # result collected from entity_relation_task
+                        knowledge_graph_inst=self.chunk_entity_relation_graph,
+                        entity_vdb=self.entities_vdb,
+                        relationships_vdb=self.relationships_vdb,
+                        global_config=asdict(self),
+                        pipeline_status=pipeline_status,
+                        pipeline_status_lock=pipeline_status_lock,
+                        llm_response_cache=self.llm_response_cache,
+                        current_file_number=current_file_number,
+                        total_files=total_files,
+                        file_path=file_path,
+                    )
+
                     await self.doc_status.upsert(
                         {
                             doc_id: {
@@ -1012,22 +1069,12 @@ class LightRAG:

                 except Exception as e:
                     # Log error and update pipeline status
-                    error_msg = f"Failed to process document {doc_id}: {traceback.format_exc()}"
-
+                    error_msg = f"Merge stage failed for document {doc_id}: {traceback.format_exc()}"
                    logger.error(error_msg)
                     async with pipeline_status_lock:
                         pipeline_status["latest_message"] = error_msg
                         pipeline_status["history_messages"].append(error_msg)

-                    # Cancel other tasks as they are no longer meaningful
-                    for task in [
-                        
chunks_vdb_task, - entity_relation_task, - full_docs_task, - text_chunks_task, - ]: - if not task.done(): - task.cancel() # Update document status to failed await self.doc_status.upsert( { @@ -1101,9 +1148,9 @@ class LightRAG: async def _process_entity_relation_graph( self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None - ) -> None: + ) -> list: try: - await extract_entities( + chunk_results = await extract_entities( chunk, knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, @@ -1113,6 +1160,7 @@ class LightRAG: pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, ) + return chunk_results except Exception as e: error_msg = f"Failed to extract entities and relationships: {str(e)}" logger.error(error_msg) diff --git a/lightrag/operate.py b/lightrag/operate.py index 882e3ac2..83ef36ad 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -476,6 +476,139 @@ async def _merge_edges_then_upsert( return edge_data +async def merge_nodes_and_edges( + chunk_results: list, + knowledge_graph_inst: BaseGraphStorage, + entity_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + global_config: dict[str, str], + pipeline_status: dict = None, + pipeline_status_lock=None, + llm_response_cache: BaseKVStorage | None = None, + current_file_number: int = 0, + total_files: int = 0, + file_path: str = "unknown_source", +) -> None: + """Merge nodes and edges from extraction results + + Args: + chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships + knowledge_graph_inst: Knowledge graph storage + entity_vdb: Entity vector database + relationships_vdb: Relationship vector database + global_config: Global configuration + pipeline_status: Pipeline status dictionary + pipeline_status_lock: Lock for pipeline status + llm_response_cache: LLM response cache + """ + # Get lock manager from shared storage + from .kg.shared_storage import get_graph_db_lock + graph_db_lock = get_graph_db_lock(enable_logging=False) + + # Collect all nodes and edges from all chunks + all_nodes = defaultdict(list) + all_edges = defaultdict(list) + + for maybe_nodes, maybe_edges in chunk_results: + # Collect nodes + for entity_name, entities in maybe_nodes.items(): + all_nodes[entity_name].extend(entities) + + # Collect edges with sorted keys for undirected graph + for edge_key, edges in maybe_edges.items(): + sorted_edge_key = tuple(sorted(edge_key)) + all_edges[sorted_edge_key].extend(edges) + + # Centralized processing of all nodes and edges + entities_data = [] + relationships_data = [] + + # Merge nodes and edges + # Use graph database lock to ensure atomic merges and updates + async with graph_db_lock: + async with pipeline_status_lock: + log_message = f"Merging nodes/edges {current_file_number}/{total_files}: {file_path}" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Process and update all entities at once + for entity_name, entities in all_nodes.items(): + entity_data = await _merge_nodes_then_upsert( + entity_name, + entities, + knowledge_graph_inst, + global_config, + pipeline_status, + pipeline_status_lock, + llm_response_cache, + ) + entities_data.append(entity_data) + + # Process and update all relationships at once + for edge_key, edges in all_edges.items(): + edge_data = await _merge_edges_then_upsert( + edge_key[0], + edge_key[1], + edges, + knowledge_graph_inst, + 
global_config, + pipeline_status, + pipeline_status_lock, + llm_response_cache, + ) + if edge_data is not None: + relationships_data.append(edge_data) + + # Update total counts + total_entities_count = len(entities_data) + total_relations_count = len(relationships_data) + + log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}" + logger.info(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Update vector databases with all collected data + if entity_vdb is not None and entities_data: + data_for_vdb = { + compute_mdhash_id(dp["entity_name"], prefix="ent-"): { + "entity_name": dp["entity_name"], + "entity_type": dp["entity_type"], + "content": f"{dp['entity_name']}\n{dp['description']}", + "source_id": dp["source_id"], + "file_path": dp.get("file_path", "unknown_source"), + } + for dp in entities_data + } + await entity_vdb.upsert(data_for_vdb) + + log_message = ( + f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}" + ) + logger.info(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + if relationships_vdb is not None and relationships_data: + data_for_vdb = { + compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { + "src_id": dp["src_id"], + "tgt_id": dp["tgt_id"], + "keywords": dp["keywords"], + "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", + "source_id": dp["source_id"], + "file_path": dp.get("file_path", "unknown_source"), + } + for dp in relationships_data + } + await relationships_vdb.upsert(data_for_vdb) + + async def extract_entities( chunks: dict[str, TextChunkSchema], knowledge_graph_inst: BaseGraphStorage, @@ -485,7 +618,7 @@ async def extract_entities( pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, -) -> None: +) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -530,15 +663,6 @@ async def extract_entities( processed_chunks = 0 total_chunks = len(ordered_chunks) - total_entities_count = 0 - total_relations_count = 0 - - # Get lock manager from shared storage - from .kg.shared_storage import get_graph_db_lock - - graph_db_lock = get_graph_db_lock(enable_logging=False) - - # Use the global use_llm_func_with_cache function from utils.py async def _process_extraction_result( result: str, chunk_key: str, file_path: str = "unknown_source" @@ -708,102 +832,9 @@ async def extract_entities( # If all tasks completed successfully, collect results chunk_results = [task.result() for task in tasks] - - # Collect all nodes and edges from all chunks - all_nodes = defaultdict(list) - all_edges = defaultdict(list) - - for maybe_nodes, maybe_edges in chunk_results: - # Collect nodes - for entity_name, entities in maybe_nodes.items(): - all_nodes[entity_name].extend(entities) - - # Collect edges with sorted keys for undirected graph - for edge_key, edges in maybe_edges.items(): - sorted_edge_key = tuple(sorted(edge_key)) - all_edges[sorted_edge_key].extend(edges) - - # Centralized processing of all nodes and edges - entities_data = [] - relationships_data = [] - - # Use graph database lock to ensure atomic merges and updates 
- async with graph_db_lock: - # Process and update all entities at once - for entity_name, entities in all_nodes.items(): - entity_data = await _merge_nodes_then_upsert( - entity_name, - entities, - knowledge_graph_inst, - global_config, - pipeline_status, - pipeline_status_lock, - llm_response_cache, - ) - entities_data.append(entity_data) - - # Process and update all relationships at once - for edge_key, edges in all_edges.items(): - edge_data = await _merge_edges_then_upsert( - edge_key[0], - edge_key[1], - edges, - knowledge_graph_inst, - global_config, - pipeline_status, - pipeline_status_lock, - llm_response_cache, - ) - if edge_data is not None: - relationships_data.append(edge_data) - - # Update total counts - total_entities_count = len(entities_data) - total_relations_count = len(relationships_data) - - log_message = f"Updating vector storage: {total_entities_count} entities..." - logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - # Update vector databases with all collected data - if entity_vdb is not None and entities_data: - data_for_vdb = { - compute_mdhash_id(dp["entity_name"], prefix="ent-"): { - "entity_name": dp["entity_name"], - "entity_type": dp["entity_type"], - "content": f"{dp['entity_name']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in entities_data - } - await entity_vdb.upsert(data_for_vdb) - - log_message = ( - f"Updating vector storage: {total_relations_count} relationships..." - ) - logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - if relationships_vdb is not None and relationships_data: - data_for_vdb = { - compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { - "src_id": dp["src_id"], - "tgt_id": dp["tgt_id"], - "keywords": dp["keywords"], - "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in relationships_data - } - await relationships_vdb.upsert(data_for_vdb) + + # Return the chunk_results for later processing in merge_nodes_and_edges + return chunk_results async def kg_query(
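
Note on the resulting control flow: after this patch each document goes through two stages. Extraction (extract_entities, which is LLM-bound) runs under the per-batch semaphore and now only returns chunk_results; the merge into the graph and vector stores happens afterwards in merge_nodes_and_edges, which serializes cross-file updates behind the shared graph database lock. Because the semaphore is released before the merge stage, the next file's extraction can overlap with the current file's merge. The standalone asyncio sketch below illustrates that pattern only; extract_file and merge_results are hypothetical stand-ins, not LightRAG APIs (the real code uses extract_entities, merge_nodes_and_edges, and get_graph_db_lock from kg.shared_storage).

import asyncio

# Hypothetical stand-ins for the real pipeline steps (extract_entities /
# merge_nodes_and_edges); they only simulate the timing behaviour.
async def extract_file(path: str) -> list[tuple[dict, dict]]:
    await asyncio.sleep(0.2)  # LLM-bound entity/relation extraction
    return [({"Entity-" + path: [{"entity_name": "Entity-" + path}]}, {})]

async def merge_results(path: str, chunk_results: list, graph_db_lock: asyncio.Lock) -> None:
    # The real merge_nodes_and_edges acquires the shared graph lock itself,
    # so merges from different files never interleave.
    async with graph_db_lock:
        await asyncio.sleep(0.1)  # graph + vector-store upserts
        print(f"merged {len(chunk_results)} chunk result(s) from {path}")

async def process_document(path: str, semaphore: asyncio.Semaphore, graph_db_lock: asyncio.Lock) -> None:
    extraction_ok = False
    chunk_results: list = []
    async with semaphore:  # bounds how many files extract concurrently
        try:
            chunk_results = await extract_file(path)
            extraction_ok = True
        except Exception:
            pass  # in the real code: mark the document FAILED and skip the merge stage
    # Semaphore released here: another file can start extracting while
    # this one merges under the lock.
    if extraction_ok:
        await merge_results(path, chunk_results, graph_db_lock)

async def main() -> None:
    semaphore = asyncio.Semaphore(2)  # cf. max_parallel_insert
    graph_db_lock = asyncio.Lock()
    await asyncio.gather(
        *(process_document(p, semaphore, graph_db_lock) for p in ["a.txt", "b.txt", "c.txt"])
    )

asyncio.run(main())

With two extraction slots and one shared lock, a second file begins extracting while the first is still merging, which is exactly the overlap this patch is after; only the graph and vector-store updates remain serialized.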