diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index f37b4e09..f3a3ac9a 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -4,7 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, Callable, Coroutine, Optional, Type, Union, cast
 import traceback
 from .operate import (
     chunking_by_token_size,
@@ -561,72 +561,96 @@ class LightRAG:
         ]
         for i, el in enumerate(batch_docs_list):
             items = ((k, v) for d in el for k, v in d.items())
+
+            tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
+
+            doc_status: dict[str, Any] = {
+                "status": DocStatus.PROCESSING,
+                "updated_at": datetime.now().isoformat(),
+            }
+
             for doc_id, doc in tqdm_async(
                 items,
                 desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}",
             ):
-                doc_status: dict[str, Any] = {
-                    "content_summary": doc["content_summary"],
-                    "content_length": doc["content_length"],
-                    "status": DocStatus.PROCESSING,
-                    "created_at": doc["created_at"],
-                    "updated_at": datetime.now().isoformat(),
+                doc_status.update(
+                    {
+                        "content_summary": doc["content_summary"],
+                        "content_length": doc["content_length"],
+                        "created_at": doc["created_at"],
+                    }
+                )
+
+                await self.doc_status.upsert({doc_id: doc_status})
+
+                # Generate chunks from document
+                chunks: dict[str, Any] = {
+                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                        **dp,
+                        "full_doc_id": doc_id,
+                    }
+                    for dp in self.chunking_func(
+                        doc["content"],
+                        split_by_character,
+                        split_by_character_only,
+                        self.chunk_overlap_token_size,
+                        self.chunk_token_size,
+                        self.tiktoken_model_name,
+                    )
                 }
                 try:
-                    await self.doc_status.upsert({doc_id: doc_status})
-
-                    # Generate chunks from document
-                    chunks: dict[str, Any] = {
-                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                            **dp,
-                            "full_doc_id": doc_id,
-                        }
-                        for dp in self.chunking_func(
-                            doc["content"],
-                            split_by_character,
-                            split_by_character_only,
-                            self.chunk_overlap_token_size,
-                            self.chunk_token_size,
-                            self.tiktoken_model_name,
-                        )
-                    }
-                    await self.chunks_vdb.upsert(chunks)
-
-                    # Update status with chunks information
-                    await self._process_entity_relation_graph(chunks)
-
-                    if doc_id not in full_docs_new_docs_ids:
-                        await self.full_docs.upsert(
-                            {doc_id: {"content": doc["content"]}}
-                        )
-
-                    if doc_id not in text_chunks_new_docs_ids:
-                        await self.text_chunks.upsert(chunks)
-
-                    doc_status.update(
-                        {
-                            "status": DocStatus.PROCESSED,
-                            "chunks_count": len(chunks),
-                            "updated_at": datetime.now().isoformat(),
-                        }
-                    )
-                    await self.doc_status.upsert({doc_id: doc_status})
-                    await self._insert_done()
-
+                    # A doc marked FAILED already failed on the full-doc / text-chunk upsert, so skip graph processing for it
+                    if doc["status"] != DocStatus.FAILED:
+                        # Ensure chunk insertion and graph processing happen sequentially
+                        await self._process_entity_relation_graph(chunks)
+                        await self.chunks_vdb.upsert(chunks)
                 except Exception as e:
-                    # Update status with failed information
                     doc_status.update(
                         {
-                            "status": DocStatus.FAILED,
+                            "status": DocStatus.PENDING,
                             "error": str(e),
                             "updated_at": datetime.now().isoformat(),
                         }
                     )
                     await self.doc_status.upsert({doc_id: doc_status})
-                    logger.error(
-                        f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+
+                if doc_id not in full_docs_new_docs_ids:
+                    tasks.setdefault(doc_id, []).append(
+                        self.full_docs.upsert({doc_id: {"content": doc["content"]}})
                     )
-                    continue
+
+                if doc_id not in text_chunks_new_docs_ids:
+                    tasks.setdefault(doc_id, []).append(self.text_chunks.upsert(chunks))
+
+            for doc_id, task in tasks.items():
+                try:
+                    await asyncio.gather(*task)
+
+                    # Update document status
+                    doc_status.update(
+                        {
+                            "status": DocStatus.PROCESSED,
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
+                    await self.doc_status.upsert({doc_id: doc_status})
+                    await self._insert_done()
+
+                except Exception as e:
+                    # Update status with failed information
+                    doc_status.update(
+                        {
+                            "status": DocStatus.FAILED,
+                            "error": str(e),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
+                    await self.doc_status.upsert({doc_id: doc_status})
+                    logger.error(
+                        f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    )
+                    continue
 
     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
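Note: the change above stops awaiting the full_docs / text_chunks upserts inside the per-document loop; it queues them as coroutines keyed by doc_id and awaits each document's batch with asyncio.gather in a second pass. Below is a minimal, self-contained sketch of that queue-then-gather pattern; the upsert coroutine is a hypothetical stand-in for the storage calls and is not part of the LightRAG API.

import asyncio
from typing import Any, Coroutine


async def upsert(name: str) -> None:
    # Hypothetical stand-in for self.full_docs.upsert / self.text_chunks.upsert.
    await asyncio.sleep(0)
    print(f"upserted {name}")


async def main() -> None:
    # Queue each document's upsert coroutines instead of awaiting them one by one.
    tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
    for doc_id in ("doc-1", "doc-2"):
        tasks.setdefault(doc_id, []).append(upsert(f"{doc_id}/full_doc"))
        tasks.setdefault(doc_id, []).append(upsert(f"{doc_id}/chunks"))

    # Second pass: run each document's queued upserts concurrently.
    for doc_id, doc_tasks in tasks.items():
        await asyncio.gather(*doc_tasks)
        print(f"finished {doc_id}")


asyncio.run(main())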