Merge pull request #464 from billvsme/fix/asyncio.as_completed

Maybe very important!!! Fix embedding error
This commit is contained in:
zrguo
2024-12-13 17:17:37 +08:00
committed by GitHub
2 changed files with 20 additions and 20 deletions

View File

@@ -56,16 +56,16 @@ class MilvusVectorDBStorge(BaseVectorStorage):
contents[i : i + self._max_batch_size] contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size) for i in range(0, len(contents), self._max_batch_size)
] ]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = [] async def wrapped_task(batch):
for f in tqdm_async( result = await self.embedding_func(batch)
asyncio.as_completed(embedding_tasks), pbar.update(1)
total=len(embedding_tasks), return result
desc="Generating embeddings",
unit="batch", embedding_tasks = [wrapped_task(batch) for batch in batches]
): pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
embeddings = await f embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list) embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data): for i, d in enumerate(list_data):
d["vector"] = embeddings[i] d["vector"] = embeddings[i]

View File

@@ -96,16 +96,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
contents[i : i + self._max_batch_size] contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size) for i in range(0, len(contents), self._max_batch_size)
] ]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = [] async def wrapped_task(batch):
for f in tqdm_async( result = await self.embedding_func(batch)
asyncio.as_completed(embedding_tasks), pbar.update(1)
total=len(embedding_tasks), return result
desc="Generating embeddings",
unit="batch", embedding_tasks = [wrapped_task(batch) for batch in batches]
): pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
embeddings = await f embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list) embeddings = np.concatenate(embeddings_list)
if len(embeddings) == len(list_data): if len(embeddings) == len(list_data):
for i, d in enumerate(list_data): for i, d in enumerate(list_data):