Add a progress bar

This commit is contained in:
Larfii
2024-11-25 15:04:38 +08:00
parent 8161bd19f9
commit ac1587ad2a
3 changed files with 57 additions and 21 deletions

View File

@@ -1,5 +1,6 @@
 import asyncio
 import os
+from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -243,7 +244,9 @@ class LightRAG:
         logger.info(f"[New Docs] inserting {len(new_docs)} docs")
 
         inserting_chunks = {}
-        for doc_key, doc in new_docs.items():
+        for doc_key, doc in tqdm_async(
+            new_docs.items(), desc="Chunking documents", unit="doc"
+        ):
             chunks = {
                 compute_mdhash_id(dp["content"], prefix="chunk-"): {
                     **dp,

View File

@@ -1,6 +1,7 @@
 import asyncio
 import json
 import re
+from tqdm.asyncio import tqdm as tqdm_async
 from typing import Union
 from collections import Counter, defaultdict
 import warnings
@@ -329,11 +330,15 @@ async def extract_entities(
         )
         return dict(maybe_nodes), dict(maybe_edges)
 
-    # use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
-    results = await asyncio.gather(
-        *[_process_single_content(c) for c in ordered_chunks]
-    )
-    print()  # clear the progress bar
+    results = []
+    for result in tqdm_async(
+        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
+        total=len(ordered_chunks),
+        desc="Extracting entities from chunks",
+        unit="chunk",
+    ):
+        results.append(await result)
+
     maybe_nodes = defaultdict(list)
     maybe_edges = defaultdict(list)
     for m_nodes, m_edges in results:
@@ -341,18 +346,38 @@ async def extract_entities(
             maybe_nodes[k].extend(v)
         for k, v in m_edges.items():
             maybe_edges[tuple(sorted(k))].extend(v)
-    all_entities_data = await asyncio.gather(
-        *[
-            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-            for k, v in maybe_nodes.items()
-        ]
-    )
-    all_relationships_data = await asyncio.gather(
-        *[
-            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
-            for k, v in maybe_edges.items()
-        ]
-    )
+    logger.info("Inserting entities into storage...")
+    all_entities_data = []
+    for result in tqdm_async(
+        asyncio.as_completed(
+            [
+                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
+                for k, v in maybe_nodes.items()
+            ]
+        ),
+        total=len(maybe_nodes),
+        desc="Inserting entities",
+        unit="entity",
+    ):
+        all_entities_data.append(await result)
+
+    logger.info("Inserting relationships into storage...")
+    all_relationships_data = []
+    for result in tqdm_async(
+        asyncio.as_completed(
+            [
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
+                for k, v in maybe_edges.items()
+            ]
+        ),
+        total=len(maybe_edges),
+        desc="Inserting relationships",
+        unit="relationship",
+    ):
+        all_relationships_data.append(await result)
     if not len(all_entities_data):
         logger.warning("Didn't extract any entities, maybe your LLM is not working")
         return None

View File

@@ -1,6 +1,7 @@
 import asyncio
 import html
 import os
+from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 from typing import Any, Union, cast
 import networkx as nx
@@ -95,9 +96,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
             contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
-        embeddings_list = await asyncio.gather(
-            *[self.embedding_func(batch) for batch in batches]
-        )
+        embedding_tasks = [self.embedding_func(batch) for batch in batches]
+        embeddings_list = []
+        for f in tqdm_async(
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
+        ):
+            embeddings = await f
+            embeddings_list.append(embeddings)
         embeddings = np.concatenate(embeddings_list)
         for i, d in enumerate(list_data):
             d["__vector__"] = embeddings[i]