asyncio optimizations
This commit is contained in:
@@ -941,28 +941,35 @@ async def _build_query_context(
|
||||
query_param,
|
||||
)
|
||||
else: # hybrid mode
|
||||
ll_data, hl_data = await asyncio.gather(
|
||||
_get_node_data(
|
||||
ll_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
),
|
||||
_get_edge_data(
|
||||
hl_keywords,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
),
|
||||
)
|
||||
|
||||
(
|
||||
ll_entities_context,
|
||||
ll_relations_context,
|
||||
ll_text_units_context,
|
||||
) = await _get_node_data(
|
||||
ll_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
) = ll_data
|
||||
|
||||
(
|
||||
hl_entities_context,
|
||||
hl_relations_context,
|
||||
hl_text_units_context,
|
||||
) = await _get_edge_data(
|
||||
hl_keywords,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
) = hl_data
|
||||
|
||||
entities_context, relations_context, text_units_context = combine_contexts(
|
||||
[hl_entities_context, ll_entities_context],
|
||||
[hl_relations_context, ll_relations_context],
|
||||
@@ -996,28 +1003,31 @@ async def _get_node_data(
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
# get entity information
|
||||
node_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
),
|
||||
)
|
||||
|
||||
if not all([n is not None for n in node_datas]):
|
||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||
|
||||
# get entity degree
|
||||
node_degrees = await asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k["entity_name"], "rank": d}
|
||||
for k, n, d in zip(results, node_datas, node_degrees)
|
||||
if n is not None
|
||||
] # TODO(review): what is text_chunks_db doing here? Don't remember it in airvx; check the diagram.
|
||||
# get entity text chunks
|
||||
use_text_units = await _find_most_related_text_unit_from_entities(
|
||||
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
)
|
||||
# get related edges
|
||||
use_relations = await _find_most_related_edges_from_entities(
|
||||
node_datas, query_param, knowledge_graph_inst
|
||||
use_text_units, use_relations = await asyncio.gather(
|
||||
_find_most_related_text_unit_from_entities(
|
||||
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
),
|
||||
_find_most_related_edges_from_entities(
|
||||
node_datas, query_param, knowledge_graph_inst
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
|
||||
@@ -1107,22 +1117,30 @@ async def _find_most_related_text_unit_from_entities(
|
||||
}
|
||||
|
||||
all_text_units_lookup = {}
|
||||
tasks = []
|
||||
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
||||
for c_id in this_text_units:
|
||||
if c_id not in all_text_units_lookup:
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": await text_chunks_db.get_by_id(c_id),
|
||||
"order": index,
|
||||
"relation_counts": 0,
|
||||
}
|
||||
tasks.append((c_id, index, this_edges))
|
||||
|
||||
if this_edges:
|
||||
for e in this_edges:
|
||||
if (
|
||||
e[1] in all_one_hop_text_units_lookup
|
||||
and c_id in all_one_hop_text_units_lookup[e[1]]
|
||||
):
|
||||
all_text_units_lookup[c_id]["relation_counts"] += 1
|
||||
results = await asyncio.gather(
|
||||
*[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
|
||||
)
|
||||
|
||||
for (c_id, index, this_edges), data in zip(tasks, results):
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": data,
|
||||
"order": index,
|
||||
"relation_counts": 0,
|
||||
}
|
||||
|
||||
if this_edges:
|
||||
for e in this_edges:
|
||||
if (
|
||||
e[1] in all_one_hop_text_units_lookup
|
||||
and c_id in all_one_hop_text_units_lookup[e[1]]
|
||||
):
|
||||
all_text_units_lookup[c_id]["relation_counts"] += 1
|
||||
|
||||
# Filter out None values and ensure data has content
|
||||
all_text_units = [
|
||||
@@ -1167,11 +1185,11 @@ async def _find_most_related_edges_from_entities(
|
||||
seen.add(sorted_edge)
|
||||
all_edges.append(sorted_edge)
|
||||
|
||||
all_edges_pack = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
|
||||
)
|
||||
all_edges_degree = await asyncio.gather(
|
||||
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
|
||||
all_edges_pack, all_edges_degree = await asyncio.gather(
|
||||
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
|
||||
),
|
||||
)
|
||||
all_edges_data = [
|
||||
{"src_tgt": k, "rank": d, **v}
|
||||
@@ -1201,15 +1219,21 @@ async def _get_edge_data(
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
|
||||
edge_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
edge_datas, edge_degree = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
|
||||
for r in results
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
if not all([n is not None for n in edge_datas]):
|
||||
logger.warning("Some edges are missing, maybe the storage is damaged")
|
||||
edge_degree = await asyncio.gather(
|
||||
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
|
||||
)
|
||||
|
||||
edge_datas = [
|
||||
{
|
||||
"src_id": k["src_id"],
|
||||
@@ -1230,11 +1254,13 @@ async def _get_edge_data(
|
||||
max_token_size=query_param.max_token_for_global_context,
|
||||
)
|
||||
|
||||
use_entities = await _find_most_related_entities_from_relationships(
|
||||
edge_datas, query_param, knowledge_graph_inst
|
||||
)
|
||||
use_text_units = await _find_related_text_unit_from_relationships(
|
||||
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
use_entities, use_text_units = await asyncio.gather(
|
||||
_find_most_related_entities_from_relationships(
|
||||
edge_datas, query_param, knowledge_graph_inst
|
||||
),
|
||||
_find_related_text_unit_from_relationships(
|
||||
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
|
||||
@@ -1307,12 +1333,19 @@ async def _find_most_related_entities_from_relationships(
|
||||
entity_names.append(e["tgt_id"])
|
||||
seen.add(e["tgt_id"])
|
||||
|
||||
node_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
||||
)
|
||||
|
||||
node_degrees = await asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.get_node(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.node_degree(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k, "rank": d}
|
||||
@@ -1340,16 +1373,22 @@ async def _find_related_text_unit_from_relationships(
|
||||
]
|
||||
all_text_units_lookup = {}
|
||||
|
||||
async def fetch_chunk_data(c_id, index):
|
||||
if c_id not in all_text_units_lookup:
|
||||
chunk_data = await text_chunks_db.get_by_id(c_id)
|
||||
# Only store valid data
|
||||
if chunk_data is not None and "content" in chunk_data:
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": chunk_data,
|
||||
"order": index,
|
||||
}
|
||||
|
||||
tasks = []
|
||||
for index, unit_list in enumerate(text_units):
|
||||
for c_id in unit_list:
|
||||
if c_id not in all_text_units_lookup:
|
||||
chunk_data = await text_chunks_db.get_by_id(c_id)
|
||||
# Only store valid data
|
||||
if chunk_data is not None and "content" in chunk_data:
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": chunk_data,
|
||||
"order": index,
|
||||
}
|
||||
tasks.append(fetch_chunk_data(c_id, index))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if not all_text_units_lookup:
|
||||
logger.warning("No valid text chunks found")
|
||||
|
Reference in New Issue
Block a user