diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index c72498c5..fe046eb4 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -56,16 +56,16 @@ class MilvusVectorDBStorge(BaseVectorStorage): contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] - embedding_tasks = [self.embedding_func(batch) for batch in batches] - embeddings_list = [] - for f in tqdm_async( - await asyncio.gather(*embedding_tasks), - total=len(embedding_tasks), - desc="Generating embeddings", - unit="batch", - ): - embeddings = await f - embeddings_list.append(embeddings) + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch") + embeddings_list = await asyncio.gather(*embedding_tasks) + embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["vector"] = embeddings[i] diff --git a/lightrag/storage.py b/lightrag/storage.py index 534c6e2e..037a9c2f 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -96,16 +96,16 @@ class NanoVectorDBStorage(BaseVectorStorage): contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] - embedding_tasks = [self.embedding_func(batch) for batch in batches] - embeddings_list = [] - for f in tqdm_async( - await asyncio.gather(*embedding_tasks), - total=len(embedding_tasks), - desc="Generating embeddings", - unit="batch", - ): - embeddings = await f - embeddings_list.append(embeddings) + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch") + embeddings_list = await asyncio.gather(*embedding_tasks) + embeddings = np.concatenate(embeddings_list) if len(embeddings) == len(list_data): for i, d in enumerate(list_data):