From 8562ecdebca67f689b682275f63e42e6634fa47f Mon Sep 17 00:00:00 2001
From: Larfii <834462287@qq.com>
Date: Mon, 25 Nov 2024 15:04:38 +0800
Subject: [PATCH] Add a progress bar

---
 lightrag/lightrag.py |  5 +++-
 lightrag/operate.py  | 59 +++++++++++++++++++++++++++++++-------------
 lightrag/storage.py  | 14 ++++++++---
 3 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 7fafadcf..28e72102 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -243,7 +244,9 @@ class LightRAG:
         logger.info(f"[New Docs] inserting {len(new_docs)} docs")
 
         inserting_chunks = {}
-        for doc_key, doc in new_docs.items():
+        for doc_key, doc in tqdm_async(
+            new_docs.items(), desc="Chunking documents", unit="doc"
+        ):
             chunks = {
                 compute_mdhash_id(dp["content"], prefix="chunk-"): {
                     **dp,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index cf236633..9e4b768a 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import re
+from tqdm.asyncio import tqdm as tqdm_async
 from typing import Union
 from collections import Counter, defaultdict
 import warnings
@@ -329,11 +330,15 @@ async def extract_entities(
         )
         return dict(maybe_nodes), dict(maybe_edges)
 
-    # use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
-    results = await asyncio.gather(
-        *[_process_single_content(c) for c in ordered_chunks]
-    )
-    print()  # clear the progress bar
+    results = []
+    for result in tqdm_async(
+        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
+        total=len(ordered_chunks),
+        desc="Extracting entities from chunks",
+        unit="chunk",
+    ):
+        results.append(await result)
+
     maybe_nodes = defaultdict(list)
     maybe_edges = defaultdict(list)
     for m_nodes, m_edges in results:
@@ -341,18 +346,38 @@ async def extract_entities(
             maybe_nodes[k].extend(v)
         for k, v in m_edges.items():
             maybe_edges[tuple(sorted(k))].extend(v)
-    all_entities_data = await asyncio.gather(
-        *[
-            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-            for k, v in maybe_nodes.items()
-        ]
-    )
-    all_relationships_data = await asyncio.gather(
-        *[
-            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
-            for k, v in maybe_edges.items()
-        ]
-    )
+    logger.info("Inserting entities into storage...")
+    all_entities_data = []
+    for result in tqdm_async(
+        asyncio.as_completed(
+            [
+                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
+                for k, v in maybe_nodes.items()
+            ]
+        ),
+        total=len(maybe_nodes),
+        desc="Inserting entities",
+        unit="entity",
+    ):
+        all_entities_data.append(await result)
+
+    logger.info("Inserting relationships into storage...")
+    all_relationships_data = []
+    for result in tqdm_async(
+        asyncio.as_completed(
+            [
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
+                for k, v in maybe_edges.items()
+            ]
+        ),
+        total=len(maybe_edges),
+        desc="Inserting relationships",
+        unit="relationship",
+    ):
+        all_relationships_data.append(await result)
+
     if not len(all_entities_data):
         logger.warning("Didn't extract any entities, maybe your LLM is not working")
         return None
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 9a4c3d4c..007d6534 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -1,6 +1,7 @@
 import asyncio
 import html
 import os
+from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 from typing import Any, Union, cast
 import networkx as nx
@@ -95,9 +96,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
             contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
-        embeddings_list = await asyncio.gather(
-            *[self.embedding_func(batch) for batch in batches]
-        )
+        embedding_tasks = [self.embedding_func(batch) for batch in batches]
+        embeddings_list = []
+        for f in tqdm_async(
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
+        ):
+            embeddings = await f
+            embeddings_list.append(embeddings)
         embeddings = np.concatenate(embeddings_list)
         for i, d in enumerate(list_data):
             d["__vector__"] = embeddings[i]