Add a progress bar
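
Replaces the bare asyncio.gather calls in document chunking, entity extraction, entity/relationship upserts, and embedding generation with tqdm-driven loops (tqdm.asyncio.tqdm over asyncio.as_completed), so long-running inserts report progress on the console.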
@@ -1,5 +1,6 @@
 import asyncio
 import os
+from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -243,7 +244,9 @@ class LightRAG:
             logger.info(f"[New Docs] inserting {len(new_docs)} docs")

             inserting_chunks = {}
-            for doc_key, doc in new_docs.items():
+            for doc_key, doc in tqdm_async(
+                new_docs.items(), desc="Chunking documents", unit="doc"
+            ):
                 chunks = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
                         **dp,
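The chunking loop iterates a plain dict, so tqdm_async behaves like the ordinary tqdm wrapper around a sync iterable. A minimal, self-contained sketch of this usage; the docs dict and chunk helper are illustrative stand-ins, not code from this commit:

    from tqdm.asyncio import tqdm as tqdm_async

    docs = {f"doc-{i}": "some text " * 40 for i in range(100)}  # stand-in corpus

    def chunk(text: str) -> list[str]:
        # stand-in for the real chunking logic
        return [text[i : i + 64] for i in range(0, len(text), 64)]

    chunks = {}
    for key, text in tqdm_async(docs.items(), desc="Chunking documents", unit="doc"):
        chunks[key] = chunk(text)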
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import re
+from tqdm.asyncio import tqdm as tqdm_async
 from typing import Union
 from collections import Counter, defaultdict
 import warnings
@@ -329,11 +330,15 @@ async def extract_entities(
         )
         return dict(maybe_nodes), dict(maybe_edges)

-    # use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
-    results = await asyncio.gather(
-        *[_process_single_content(c) for c in ordered_chunks]
-    )
-    print()  # clear the progress bar
+    results = []
+    for result in tqdm_async(
+        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
+        total=len(ordered_chunks),
+        desc="Extracting entities from chunks",
+        unit="chunk",
+    ):
+        results.append(await result)
+
     maybe_nodes = defaultdict(list)
     maybe_edges = defaultdict(list)
     for m_nodes, m_edges in results:
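The replacement pattern feeds asyncio.as_completed into the tqdm wrapper: as_completed yields awaitables as tasks finish, so the bar ticks per completed chunk, and total must be passed explicitly because the iterator has no length. A runnable sketch under a stand-in work coroutine:

    import asyncio
    from tqdm.asyncio import tqdm as tqdm_async

    async def work(i: int) -> int:
        await asyncio.sleep(0.01 * (i % 5))  # stand-in for _process_single_content
        return i * i

    async def main() -> list[int]:
        coros = [work(i) for i in range(50)]
        results = []
        for fut in tqdm_async(
            asyncio.as_completed(coros), total=len(coros), desc="Processing", unit="chunk"
        ):
            results.append(await fut)  # arrives in completion order, not submission order
        return results

    asyncio.run(main())

Completion order is harmless in this hunk and the next: each result is merged into maybe_nodes/maybe_edges by key, so no positional correspondence is assumed.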
@@ -341,18 +346,38 @@ async def extract_entities(
             maybe_nodes[k].extend(v)
         for k, v in m_edges.items():
             maybe_edges[tuple(sorted(k))].extend(v)
-    all_entities_data = await asyncio.gather(
-        *[
-            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-            for k, v in maybe_nodes.items()
-        ]
-    )
-    all_relationships_data = await asyncio.gather(
-        *[
-            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
-            for k, v in maybe_edges.items()
-        ]
-    )
+    logger.info("Inserting entities into storage...")
+    all_entities_data = []
+    for result in tqdm_async(
+        asyncio.as_completed(
+            [
+                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
+                for k, v in maybe_nodes.items()
+            ]
+        ),
+        total=len(maybe_nodes),
+        desc="Inserting entities",
+        unit="entity",
+    ):
+        all_entities_data.append(await result)
+
+    logger.info("Inserting relationships into storage...")
+    all_relationships_data = []
+    for result in tqdm_async(
+        asyncio.as_completed(
+            [
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
+                for k, v in maybe_edges.items()
+            ]
+        ),
+        total=len(maybe_edges),
+        desc="Inserting relationships",
+        unit="relationship",
+    ):
+        all_relationships_data.append(await result)
+
     if not len(all_entities_data):
         logger.warning("Didn't extract any entities, maybe your LLM is not working")
         return None
@@ -1,6 +1,7 @@
 import asyncio
 import html
 import os
+from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 from typing import Any, Union, cast
 import networkx as nx
@@ -95,9 +96,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
             contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
-        embeddings_list = await asyncio.gather(
-            *[self.embedding_func(batch) for batch in batches]
-        )
+        embedding_tasks = [self.embedding_func(batch) for batch in batches]
+        embeddings_list = []
+        for f in tqdm_async(
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
+        ):
+            embeddings = await f
+            embeddings_list.append(embeddings)
         embeddings = np.concatenate(embeddings_list)
         for i, d in enumerate(list_data):
             d["__vector__"] = embeddings[i]
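One caveat worth noting (an editorial observation, not part of the commit): because as_completed yields in completion order, embeddings_list is no longer guaranteed to follow batch order, yet np.concatenate pairs row i with list_data[i]. If the embedding backend finishes batches out of order, vectors could be misaligned. An order-preserving alternative with the same progress bar is tqdm_async.gather, which wraps asyncio.gather; a sketch with a stand-in embed function:

    import asyncio
    import numpy as np
    from tqdm.asyncio import tqdm as tqdm_async

    async def embed(batch: list[str]) -> np.ndarray:
        await asyncio.sleep(0.01)         # stand-in for self.embedding_func(batch)
        return np.zeros((len(batch), 8))  # fake embeddings of dimension 8

    async def main() -> np.ndarray:
        contents = [f"text {i}" for i in range(100)]
        batches = [contents[i : i + 32] for i in range(0, len(contents), 32)]
        # tqdm_async.gather preserves argument order while still showing progress
        embeddings_list = await tqdm_async.gather(
            *[embed(b) for b in batches], desc="Generating embeddings", unit="batch"
        )
        return np.concatenate(embeddings_list)

    asyncio.run(main())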