From 98d005dc1c5bb8a34897c9fd77c2d3f1b2089322 Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Wed, 19 Feb 2025 23:26:21 +0100
Subject: [PATCH] updated parallel

---
 lightrag/lightrag.py | 123 +++++++++++++++++++++++++++++--------------
 lightrag/operate.py  |  10 +---
 2 files changed, 85 insertions(+), 48 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index e46d548c..6343b291 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -790,6 +790,7 @@ class LightRAG:
 
         logger.info(f"Number of batches to process: {len(docs_batches)}.")
 
+        tasks: list[tuple[str, DocProcessingStatus, dict[str, Any], Any]] = []
         # 3. iterate over batches
         for batch_idx, docs_batch in enumerate(docs_batches):
             # 4. iterate over batch
@@ -825,47 +826,91 @@ class LightRAG:
                     )
                 }
 
-                # Process document (text chunks and full docs) in parallel
-                tasks = [
-                    self.chunks_vdb.upsert(chunks),
-                    self._process_entity_relation_graph(chunks),
-                    self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
-                    self.text_chunks.upsert(chunks),
-                ]
-                try:
-                    await asyncio.gather(*tasks)
-                    await self.doc_status.upsert(
-                        {
-                            doc_status_id: {
-                                "status": DocStatus.PROCESSED,
-                                "chunks_count": len(chunks),
-                                "content": status_doc.content,
-                                "content_summary": status_doc.content_summary,
-                                "content_length": status_doc.content_length,
-                                "created_at": status_doc.created_at,
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        }
-                    )
-                    await self._insert_done()
+                # Prepare async tasks with full context
+                tasks.extend(
+                    [
+                        (
+                            doc_status_id,
+                            status_doc,
+                            chunks,
+                            self.chunks_vdb.upsert(chunks),
+                        ),
+                        (
+                            doc_status_id,
+                            status_doc,
+                            chunks,
+                            self._process_entity_relation_graph(chunks),
+                        ),
+                        (
+                            doc_status_id,
+                            status_doc,
+                            chunks,
+                            self.full_docs.upsert(
+                                {doc_id: {"content": status_doc.content}}
+                            ),
+                        ),
+                        (
+                            doc_status_id,
+                            status_doc,
+                            chunks,
+                            self.text_chunks.upsert(chunks),
+                        ),
+                    ]
+                )
 
-                except Exception as e:
-                    logger.error(f"Failed to process document {doc_id}: {str(e)}")
-                    await self.doc_status.upsert(
-                        {
-                            doc_status_id: {
-                                "status": DocStatus.FAILED,
-                                "error": str(e),
-                                "content": status_doc.content,
-                                "content_summary": status_doc.content_summary,
-                                "content_length": status_doc.content_length,
-                                "created_at": status_doc.created_at,
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        }
-                    )
-                    continue
-            logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
+        # Execute tasks as they complete
+        for future in asyncio.as_completed([task for _, _, _, task in tasks]):
+            try:
+                # Wait for the completed task
+                await future
+
+                # Retrieve the full context of the completed task
+                completed_doc_status_id, status_doc, chunks, _ = next(
+                    (doc_id, s_doc, ch, task)
+                    for doc_id, s_doc, ch, task in tasks
+                    if task == future
+                )
+
+                # Update status to processed
+                await self.doc_status.upsert(
+                    {
+                        completed_doc_status_id: {
+                            "status": DocStatus.PROCESSED,
+                            "chunks_count": len(chunks),
+                            "content": status_doc.content,
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    }
+                )
+                logger.info(f"Completed doc_id: {completed_doc_status_id}")
+            except Exception as e:
+                # Retrieve the context of the failed task
+                failed_doc_status_id, status_doc, chunks, _ = next(
+                    (doc_id, s_doc, ch, task)
+                    for doc_id, s_doc, ch, task in tasks
+                    if task == future
+                )
+                logger.error(
+                    f"Failed to process document {failed_doc_status_id}: {str(e)}"
+                )
+
+                await self.doc_status.upsert(
+                    {
+                        failed_doc_status_id: {
+                            "status": DocStatus.FAILED,
+                            "error": str(e),
+                            "content": status_doc.content,
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    }
+                )
+        await self._insert_done()
 
     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 9552f2ed..27950b7d 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1326,15 +1326,12 @@ async def _get_edge_data(
         ),
     )
 
-    if not all([n is not None for n in edge_datas]):
-        logger.warning("Some edges are missing, maybe the storage is damaged")
-
     edge_datas = [
         {
             "src_id": k["src_id"],
             "tgt_id": k["tgt_id"],
             "rank": d,
-            "created_at": k.get("__created_at__", None),  # 从 KV 存储中获取时间元数据
+            "created_at": k.get("__created_at__", None),
             **v,
         }
         for k, v, d in zip(results, edge_datas, edge_degree)
@@ -1343,16 +1340,11 @@ async def _get_edge_data(
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
-    len_edge_datas = len(edge_datas)
     edge_datas = truncate_list_by_token_size(
         edge_datas,
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_global_context,
     )
-    logger.debug(
-        f"Truncate relations from {len_edge_datas} to {len(edge_datas)} (max tokens:{query_param.max_token_for_global_context})"
-    )
-
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
             edge_datas, query_param, knowledge_graph_inst
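
Review note on the as_completed loop: at least on CPython 3.8-3.12, asyncio.as_completed() yields internal awaitables, not the objects that were submitted, so the `if task == future` lookup never matches and both next(...) calls raise StopIteration. Separately, each document enqueues four independent tasks, so the first one to finish would already mark the whole document PROCESSED. A minimal sketch of a context-preserving alternative; `fake_upsert` and `with_context` are hypothetical stand-ins for LightRAG's real storage calls, not its API:

    import asyncio

    async def fake_upsert(doc_id: str) -> None:
        # Stand-in for the per-document chunks_vdb/full_docs/text_chunks work.
        await asyncio.sleep(0.01)
        if doc_id == "doc-2":
            raise ValueError("simulated storage failure")

    async def with_context(doc_id: str, coro):
        # Carry the doc id through the future's own result, so no identity
        # comparison against the submitted coroutine is ever needed.
        try:
            await coro
            return doc_id, None
        except Exception as exc:
            return doc_id, exc

    async def main() -> None:
        doc_ids = ["doc-1", "doc-2", "doc-3"]
        tasks = [
            asyncio.ensure_future(with_context(d, fake_upsert(d))) for d in doc_ids
        ]
        # Futures are yielded in completion order; each result names its document.
        for future in asyncio.as_completed(tasks):
            doc_id, err = await future
            if err is None:
                print(f"PROCESSED {doc_id}")      # doc_status.upsert(PROCESSED) here
            else:
                print(f"FAILED {doc_id}: {err}")  # doc_status.upsert(FAILED) here

    asyncio.run(main())

Returning `(doc_id, error)` from the wrapper keeps both the success and failure branches tied to the right doc_status_id, which is what the four-tuple bookkeeping in this patch is trying to achieve.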
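Review note on the operate.py hunk: the removed `all(...)` check only logged a warning, and with or without it a None entry in edge_datas breaks the `**v` unpack with a TypeError. A self-contained sketch of a guard that actually skips missing payloads; the sample triples are made up to mirror the zipped (results, edge_datas, edge_degree) shapes in _get_edge_data:

    # Sample data mirroring the shapes used in _get_edge_data.
    results = [
        {"src_id": "a", "tgt_id": "b", "__created_at__": 1700000000},
        {"src_id": "b", "tgt_id": "c"},
    ]
    edge_datas = [{"description": "a-b", "weight": 1.0}, None]  # one payload missing
    edge_degree = [3, 1]

    edge_rows = [
        {
            "src_id": k["src_id"],
            "tgt_id": k["tgt_id"],
            "rank": d,
            "created_at": k.get("__created_at__", None),
            **v,
        }
        for k, v, d in zip(results, edge_datas, edge_degree)
        if v is not None  # skip the storage gaps the removed warning used to flag
    ]
    print(edge_rows)  # the None entry is skipped instead of raising TypeError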