From 2a4ff7c0d087e7ec3da7de5b34f63b4765b78c2c Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 17 Mar 2025 04:00:38 +0800
Subject: [PATCH] Fix pipeline batch processing problem

- Process batches one by one
- Process documents in parallel within each batch
---
 lightrag/lightrag.py | 263 ++++++++++++++++++++++---------------------
 1 file changed, 136 insertions(+), 127 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5a5461e0..76aa1dc8 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -769,7 +769,7 @@ class LightRAG:
         async with pipeline_status_lock:
             # Ensure only one worker is processing documents
             if not pipeline_status.get("busy", False):
-                # First, check if there are any documents that need processing
+
                 processing_docs, failed_docs, pending_docs = await asyncio.gather(
                     self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
                     self.doc_status.get_docs_by_status(DocStatus.FAILED),
                     self.doc_status.get_docs_by_status(DocStatus.PENDING),
                 )

                 to_process_docs: dict[str, DocProcessingStatus] = {}
                 to_process_docs.update(processing_docs)
                 to_process_docs.update(failed_docs)
                 to_process_docs.update(pending_docs)

-                # If there are no documents to process, return directly and leave the contents of pipeline_status unchanged
                 if not to_process_docs:
                     logger.info("No documents to process")
                     return

-                # There are documents to process; update pipeline_status
                 pipeline_status.update(
                     {
                         "busy": True,
@@ -825,7 +823,7 @@ class LightRAG:
                 for i in range(0, len(to_process_docs), self.max_parallel_insert)
             ]

-            log_message = f"Number of batches to process: {len(docs_batches)}."
+            log_message = f"Processing {len(to_process_docs)} document(s) in {len(docs_batches)} batches"
             logger.info(log_message)

             # Update pipeline status with current batch information
@@ -834,140 +832,151 @@ class LightRAG:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)

-            batches: list[Any] = []
-            # 3. iterate over batches
-            for batch_idx, docs_batch in enumerate(docs_batches):
-                # Update current batch in pipeline status (directly, as it's atomic)
-                pipeline_status["cur_batch"] += 1
-
-                async def batch(
-                    batch_idx: int,
-                    docs_batch: list[tuple[str, DocProcessingStatus]],
-                    size_batch: int,
-                ) -> None:
-                    log_message = (
-                        f"Start processing batch {batch_idx + 1} of {size_batch}."
-                    )
-                    logger.info(log_message)
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-                    # 4. iterate over batch
-                    for doc_id_processing_status in docs_batch:
-                        doc_id, status_doc = doc_id_processing_status
-                        # Generate chunks from document
-                        chunks: dict[str, Any] = {
-                            compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                                **dp,
-                                "full_doc_id": doc_id,
-                            }
-                            for dp in self.chunking_func(
-                                status_doc.content,
-                                split_by_character,
-                                split_by_character_only,
-                                self.chunk_overlap_token_size,
-                                self.chunk_token_size,
-                                self.tiktoken_model_name,
-                            )
-                        }
-                        # Process document (text chunks and full docs) in parallel
-                        # Create tasks with references for potential cancellation
-                        doc_status_task = asyncio.create_task(
-                            self.doc_status.upsert(
-                                {
-                                    doc_id: {
-                                        "status": DocStatus.PROCESSING,
-                                        "updated_at": datetime.now().isoformat(),
-                                        "content": status_doc.content,
-                                        "content_summary": status_doc.content_summary,
-                                        "content_length": status_doc.content_length,
-                                        "created_at": status_doc.created_at,
-                                    }
-                                }
-                            )
-                        )
-                        chunks_vdb_task = asyncio.create_task(
-                            self.chunks_vdb.upsert(chunks)
-                        )
-                        entity_relation_task = asyncio.create_task(
-                            self._process_entity_relation_graph(
-                                chunks, pipeline_status, pipeline_status_lock
-                            )
-                        )
-                        full_docs_task = asyncio.create_task(
-                            self.full_docs.upsert(
-                                {doc_id: {"content": status_doc.content}}
-                            )
-                        )
-                        text_chunks_task = asyncio.create_task(
-                            self.text_chunks.upsert(chunks)
-                        )
-                        tasks = [
-                            doc_status_task,
-                            chunks_vdb_task,
-                            entity_relation_task,
-                            full_docs_task,
-                            text_chunks_task,
-                        ]
-                        try:
-                            await asyncio.gather(*tasks)
-                            await self.doc_status.upsert(
-                                {
-                                    doc_id: {
-                                        "status": DocStatus.PROCESSED,
-                                        "chunks_count": len(chunks),
-                                        "content": status_doc.content,
-                                        "content_summary": status_doc.content_summary,
-                                        "content_length": status_doc.content_length,
-                                        "created_at": status_doc.created_at,
-                                        "updated_at": datetime.now().isoformat(),
-                                    }
-                                }
-                            )
-                        except Exception as e:
-                            # Log error and update pipeline status
-                            error_msg = (
-                                f"Failed to process document {doc_id}: {str(e)}"
-                            )
-                            logger.error(error_msg)
-                            pipeline_status["latest_message"] = error_msg
-                            pipeline_status["history_messages"].append(error_msg)
-
-                            # Cancel other tasks as they are no longer meaningful
-                            for task in [
-                                chunks_vdb_task,
-                                entity_relation_task,
-                                full_docs_task,
-                                text_chunks_task,
-                            ]:
-                                if not task.done():
-                                    task.cancel()
-
-                            # Update document status to failed
-                            await self.doc_status.upsert(
-                                {
-                                    doc_id: {
-                                        "status": DocStatus.FAILED,
-                                        "error": str(e),
-                                        "content": status_doc.content,
-                                        "content_summary": status_doc.content_summary,
-                                        "content_length": status_doc.content_length,
-                                        "created_at": status_doc.created_at,
-                                        "updated_at": datetime.now().isoformat(),
-                                    }
-                                }
-                            )
-                            continue
-                    log_message = (
-                        f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
-                    )
-                    logger.info(log_message)
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-
-                batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
-
-            await asyncio.gather(*batches)
-            await self._insert_done()
+            async def process_document(
+                doc_id: str,
+                status_doc: DocProcessingStatus,
+                split_by_character: str | None,
+                split_by_character_only: bool,
+                pipeline_status: dict,
+                pipeline_status_lock: asyncio.Lock
+            ) -> None:
+                """Process single document"""
+                try:
+                    # Generate chunks from document
+                    chunks: dict[str, Any] = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_id,
+                        }
+                        for dp in self.chunking_func(
+                            status_doc.content,
+                            split_by_character,
+                            split_by_character_only,
+                            self.chunk_overlap_token_size,
+                            self.chunk_token_size,
+                            self.tiktoken_model_name,
+                        )
+                    }
+                    # Process document (text chunks and full docs) in parallel
+                    # Create tasks with references for potential cancellation
+                    doc_status_task = asyncio.create_task(
+                        self.doc_status.upsert(
+                            {
+                                doc_id: {
+                                    "status": DocStatus.PROCESSING,
+                                    "updated_at": datetime.now().isoformat(),
+                                    "content": status_doc.content,
+                                    "content_summary": status_doc.content_summary,
+                                    "content_length": status_doc.content_length,
+                                    "created_at": status_doc.created_at,
+                                }
+                            }
+                        )
+                    )
+                    chunks_vdb_task = asyncio.create_task(
+                        self.chunks_vdb.upsert(chunks)
+                    )
+                    entity_relation_task = asyncio.create_task(
+                        self._process_entity_relation_graph(
+                            chunks, pipeline_status, pipeline_status_lock
+                        )
+                    )
+                    full_docs_task = asyncio.create_task(
+                        self.full_docs.upsert(
+                            {doc_id: {"content": status_doc.content}}
+                        )
+                    )
+                    text_chunks_task = asyncio.create_task(
+                        self.text_chunks.upsert(chunks)
+                    )
+                    tasks = [
+                        doc_status_task,
+                        chunks_vdb_task,
+                        entity_relation_task,
+                        full_docs_task,
+                        text_chunks_task,
+                    ]
+                    await asyncio.gather(*tasks)
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.PROCESSED,
+                                "chunks_count": len(chunks),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        }
+                    )
+                except Exception as e:
+                    # Log error and update pipeline status
+                    error_msg = (
+                        f"Failed to process document {doc_id}: {str(e)}"
+                    )
+                    logger.error(error_msg)
+                    async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append(error_msg) + + # Cancel other tasks as they are no longer meaningful + for task in [ chunks_vdb_task, entity_relation_task, full_docs_task, text_chunks_task, - ] - try: - await asyncio.gather(*tasks) - await self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.PROCESSED, - "chunks_count": len(chunks), - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, - "updated_at": datetime.now().isoformat(), - } - } - ) - except Exception as e: - # Log error and update pipeline status - error_msg = ( - f"Failed to process document {doc_id}: {str(e)}" - ) - logger.error(error_msg) - pipeline_status["latest_message"] = error_msg - pipeline_status["history_messages"].append(error_msg) - - # Cancel other tasks as they are no longer meaningful - for task in [ - chunks_vdb_task, - entity_relation_task, - full_docs_task, - text_chunks_task, - ]: - if not task.done(): - task.cancel() - - # Update document status to failed - await self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.FAILED, - "error": str(e), - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, - "updated_at": datetime.now().isoformat(), - } - } - ) - continue - log_message = ( - f"Completed batch {batch_idx + 1} of {len(docs_batches)}." + ]: + if not task.done(): + task.cancel() + # Update document status to failed + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.FAILED, + "error": str(e), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + } + } ) - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - batches.append(batch(batch_idx, docs_batch, len(docs_batches))) + # 3. iterate over batches + total_batches = len(docs_batches) + for batch_idx, docs_batch in enumerate(docs_batches): - await asyncio.gather(*batches) - await self._insert_done() + current_batch = batch_idx + 1 + log_message = f"Start processing batch {current_batch} of {total_batches}." + logger.info(log_message) + pipeline_status["cur_batch"] = current_batch + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + doc_tasks = [] + for doc_id, status_doc in docs_batch: + doc_tasks.append( + process_document( + doc_id, + status_doc, + split_by_character, + split_by_character_only, + pipeline_status, + pipeline_status_lock + ) + ) + + # Process documents in one batch parallelly + await asyncio.gather(*doc_tasks) + await self._insert_done() + + log_message = f"Completed batch {current_batch} of {total_batches}." + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + # Check if there's a pending request to process more documents (with lock) has_pending_request = False @@ -1042,7 +1051,7 @@ class LightRAG: ] await asyncio.gather(*tasks) - log_message = "All Insert done" + log_message = "All data persist to disk" logger.info(log_message) if pipeline_status is not None and pipeline_status_lock is not None: