improved parallel

Yannick Stephan
2025-02-09 20:41:18 +01:00
parent 5e3100221c
commit 2b99637584


@@ -524,15 +524,25 @@ class LightRAG:
         3. Process each chunk for entity and relation extraction
         4. Update the document status
         """
+        async def insert_full_doc(doc_id: str, content: str):
+            # Check if document is already processed
+            doc = await self.full_docs.get_by_id(doc_id)
+            if not doc:
+                await self.full_docs.upsert({doc_id: {"content": content}})
+
+        async def insert_doc_status(doc_id: str, chunks: dict[str, Any]):
+            # Check if chunks are already processed
+            doc = await self.text_chunks.get_by_id(doc_id)
+            if not doc:
+                await self.text_chunks.upsert(chunks)
+
         # 1. get all pending and failed documents
         to_process_docs: dict[str, DocProcessingStatus] = {}

         # Fetch failed documents
-        failed_docs = await self.doc_status.get_failed_docs()
-        to_process_docs.update(failed_docs)
-        pending_docs = await self.doc_status.get_pending_docs()
-        to_process_docs.update(pending_docs)
+        to_process_docs.update(await self.doc_status.get_failed_docs())
+        to_process_docs.update(await self.doc_status.get_pending_docs())

         if not to_process_docs:
             logger.info("All documents have been processed or are duplicates")
@@ -545,11 +555,10 @@ class LightRAG:
             for i in range(0, len(to_process_docs), batch_size)
         ]

-        # 3. iterate over batches
-        tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
-
         logger.info(f"Number of batches to process: {len(docs_batches)}.")

+        # 3. iterate over batches
+        tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
         for batch_idx, docs_batch in enumerate(docs_batches):
             # 4. iterate over batch
             for doc_id_processing_status in docs_batch:
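The docs_batches comprehension above cuts the pending/failed map into fixed-size batches with a stepped range. The slice expression itself lies outside this hunk, so the sketch below is an assumption about how it works; the seven-document map and batch_size of 3 are made up for illustration.

from typing import Any

# Hypothetical pending map; in LightRAG each value is a DocProcessingStatus.
to_process_docs: dict[str, Any] = {f"doc-{i}": {"status": "pending"} for i in range(7)}
batch_size = 3

# Materialize the items once, then take windows of batch_size with a stepped range.
items = list(to_process_docs.items())
docs_batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

for batch_idx, docs_batch in enumerate(docs_batches):
    print(f"batch {batch_idx + 1} of {len(docs_batches)}:", [doc_id for doc_id, _ in docs_batch])
# batch 1 of 3: ['doc-0', 'doc-1', 'doc-2']
# batch 2 of 3: ['doc-3', 'doc-4', 'doc-5']
# batch 3 of 3: ['doc-6']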
@@ -586,47 +595,35 @@ class LightRAG:
                 await self.chunks_vdb.upsert(chunks)
                 await self._process_entity_relation_graph(chunks)

-                tasks[doc_id] = []
-                # Check if document already processed the doc
-                if await self.full_docs.get_by_id(doc_id) is None:
-                    tasks[doc_id].append(
-                        self.full_docs.upsert(
-                            {doc_id: {"content": status_doc.content}}
-                        )
-                    )
-                # Check if chunks already processed the doc
-                if await self.text_chunks.get_by_id(doc_id) is None:
-                    tasks[doc_id].append(self.text_chunks.upsert(chunks))
-
                 # Process document (text chunks and full docs) in parallel
-                for task_doc_id, task in tasks.items():
-                    try:
-                        await asyncio.gather(*task)
-                        await self.doc_status.upsert(
-                            {
-                                task_doc_id: {
-                                    "status": DocStatus.PROCESSED,
-                                    "chunks_count": len(chunks),
-                                    "updated_at": datetime.now().isoformat(),
-                                }
-                            }
-                        )
-                        await self._insert_done()
-                    except Exception as e:
-                        logger.error(
-                            f"Failed to process document {task_doc_id}: {str(e)}"
-                        )
-                        await self.doc_status.upsert(
-                            {
-                                task_doc_id: {
-                                    "status": DocStatus.FAILED,
-                                    "error": str(e),
-                                    "updated_at": datetime.now().isoformat(),
-                                }
-                            }
-                        )
-                        continue
+                tasks = []
+                tasks.append(insert_full_doc(doc_id, status_doc.content))
+                tasks.append(insert_doc_status(doc_id, chunks))
+                try:
+                    await asyncio.gather(*tasks)
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.PROCESSED,
+                                "chunks_count": len(chunks),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        }
+                    )
+                    await self._insert_done()
+                except Exception as e:
+                    logger.error(f"Failed to process document {doc_id}: {str(e)}")
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        }
+                    )
+                    continue

             logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
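After the rewrite, each document gathers exactly two coroutines and handles its own failure: asyncio.gather propagates the first exception raised by its awaitables, the handler records a FAILED status, and continue moves on to the next document in the batch. Below is a self-contained sketch of that error path; the plain dict stands in for self.doc_status and all names are illustrative, not LightRAG's API.

import asyncio
from datetime import datetime

doc_status: dict[str, dict] = {}  # stand-in for self.doc_status


async def ok() -> None:
    await asyncio.sleep(0)


async def boom() -> None:
    raise RuntimeError("boom")


async def process(doc_id: str, tasks: list) -> None:
    try:
        # gather() raises the first exception from its awaitables,
        # which sends the whole document to the except branch.
        await asyncio.gather(*tasks)
        doc_status[doc_id] = {
            "status": "PROCESSED",
            "updated_at": datetime.now().isoformat(),
        }
    except Exception as e:
        doc_status[doc_id] = {
            "status": "FAILED",
            "error": str(e),
            "updated_at": datetime.now().isoformat(),
        }


async def main() -> None:
    await process("doc-1", [ok(), ok()])
    await process("doc-2", [ok(), boom()])
    print(doc_status["doc-1"]["status"], doc_status["doc-2"]["status"])  # PROCESSED FAILED


asyncio.run(main())

One failing document therefore no longer aborts the rest of the batch; the per-document try/except plus continue keeps the loop running.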
@@ -640,8 +637,9 @@ class LightRAG:
                 global_config=asdict(self),
             )

             if new_kg is None:
-                logger.info("No entities or relationships extracted!")
+                logger.info("No new entities or relationships extracted.")
             else:
+                logger.info("New entities or relationships extracted.")
                 self.chunk_entity_relation_graph = new_kg
         except Exception as e: