Merge pull request #464 from billvsme/fix/asyncio.as_completed
Maybe very important!!! Fix embedding error
This commit is contained in:
@@ -56,16 +56,16 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = []
|
||||
for f in tqdm_async(
|
||||
asyncio.as_completed(embedding_tasks),
|
||||
total=len(embedding_tasks),
|
||||
desc="Generating embeddings",
|
||||
unit="batch",
|
||||
):
|
||||
embeddings = await f
|
||||
embeddings_list.append(embeddings)
|
||||
|
||||
async def wrapped_task(batch):
|
||||
result = await self.embedding_func(batch)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
||||
pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["vector"] = embeddings[i]
|
||||
|
@@ -96,16 +96,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = []
|
||||
for f in tqdm_async(
|
||||
asyncio.as_completed(embedding_tasks),
|
||||
total=len(embedding_tasks),
|
||||
desc="Generating embeddings",
|
||||
unit="batch",
|
||||
):
|
||||
embeddings = await f
|
||||
embeddings_list.append(embeddings)
|
||||
|
||||
async def wrapped_task(batch):
|
||||
result = await self.embedding_func(batch)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
||||
pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
if len(embeddings) == len(list_data):
|
||||
for i, d in enumerate(list_data):
|
||||
|
Reference in New Issue
Block a user