diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 6343b291..91aacddf 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -768,14 +768,17 @@ class LightRAG:
         4. Update the document status
         """
         # 1. Get all pending, failed, and abnormally terminated processing documents.
-        to_process_docs: dict[str, DocProcessingStatus] = {}
+        # Run the asynchronous status retrievals in parallel using asyncio.gather
+        processing_docs, failed_docs, pending_docs = await asyncio.gather(
+            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
+            self.doc_status.get_docs_by_status(DocStatus.FAILED),
+            self.doc_status.get_docs_by_status(DocStatus.PENDING),
+        )

-        processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
+        to_process_docs: dict[str, DocProcessingStatus] = {}
         to_process_docs.update(processing_docs)
-        failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
         to_process_docs.update(failed_docs)
-        pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
-        to_process_docs.update(pendings_docs)
+        to_process_docs.update(pending_docs)

         if not to_process_docs:
             logger.info("All documents have been processed or are duplicates")
@@ -789,10 +792,11 @@ class LightRAG:
         ]

         logger.info(f"Number of batches to process: {len(docs_batches)}.")
-
-        tasks: list[tuple[str, DocProcessingStatus, dict[str, Any], Any]] = []
         # 3. iterate over batches
         for batch_idx, docs_batch in enumerate(docs_batches):
+            logger.info(
+                f"Start processing batch {batch_idx + 1} of {len(docs_batches)}."
+            )
             # 4. iterate over batch
             for doc_id_processing_status in docs_batch:
                 doc_id, status_doc = doc_id_processing_status
@@ -826,91 +830,47 @@ class LightRAG:
                     )
                 }

-                # Prepare async tasks with full context
-                tasks.extend(
-                    [
-                        (
-                            doc_status_id,
-                            status_doc,
-                            chunks,
-                            self.chunks_vdb.upsert(chunks),
-                        ),
-                        (
-                            doc_status_id,
-                            status_doc,
-                            chunks,
-                            self._process_entity_relation_graph(chunks),
-                        ),
-                        (
-                            doc_status_id,
-                            status_doc,
-                            chunks,
-                            self.full_docs.upsert(
-                                {doc_id: {"content": status_doc.content}}
-                            ),
-                        ),
-                        (
-                            doc_status_id,
-                            status_doc,
-                            chunks,
-                            self.text_chunks.upsert(chunks),
-                        ),
-                    ]
-                )
-
-        # Execute tasks as they complete
-        for future in asyncio.as_completed([task for _, _, _, task in tasks]):
-            try:
-                # Wait for the completed task
-                await future
-
-                # Retrieve the full context of the completed task
-                completed_doc_status_id, status_doc, chunks, _ = next(
-                    (doc_id, s_doc, ch, task)
-                    for doc_id, s_doc, ch, task in tasks
-                    if task == future
-                )
-
-                # Update status to processed
-                await self.doc_status.upsert(
-                    {
-                        completed_doc_status_id: {
-                            "status": DocStatus.PROCESSED,
-                            "chunks_count": len(chunks),
-                            "content": status_doc.content,
-                            "content_summary": status_doc.content_summary,
-                            "content_length": status_doc.content_length,
-                            "created_at": status_doc.created_at,
-                            "updated_at": datetime.now().isoformat(),
+                # Process document (text chunks and full docs) in parallel
+                tasks = [
+                    self.chunks_vdb.upsert(chunks),
+                    self._process_entity_relation_graph(chunks),
+                    self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
+                    self.text_chunks.upsert(chunks),
+                    self.doc_status.upsert(
+                        {
+                            doc_status_id: {
+                                "status": DocStatus.PROCESSED,
+                                "chunks_count": len(chunks),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                            }
                         }
-                }
-            )
-                logger.info(f"Completed doc_id: {completed_doc_status_id}")
-            except Exception as e:
-                # Retrieve the context of the failed task
-                failed_doc_status_id, status_doc, chunks, _ = next(
-                    (doc_id, s_doc, ch, task)
-                    for doc_id, s_doc, ch, task in tasks
-                    if task == future
-                )
-                logger.error(
-                    f"Failed to process document {failed_doc_status_id}: {str(e)}"
-                )
+                    ),
+                ]
+                try:
+                    await asyncio.gather(*tasks)
+                    await self._insert_done()

-                await self.doc_status.upsert(
-                    {
-                        failed_doc_status_id: {
-                            "status": DocStatus.FAILED,
-                            "error": str(e),
-                            "content": status_doc.content,
-                            "content_summary": status_doc.content_summary,
-                            "content_length": status_doc.content_length,
-                            "created_at": status_doc.created_at,
-                            "updated_at": datetime.now().isoformat(),
+                except Exception as e:
+                    logger.error(f"Failed to process document {doc_id}: {str(e)}")
+                    await self.doc_status.upsert(
+                        {
+                            doc_status_id: {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                            }
                         }
-                }
-            )
-        await self._insert_done()
+                    )
+                    continue
+            logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try: