asyncio optimizations

This commit is contained in:
Dmytro Til
2025-01-24 16:06:04 +01:00
parent 4e5ca51e38
commit f7b66d2c22

View File

@@ -941,28 +941,35 @@ async def _build_query_context(
query_param, query_param,
) )
else: # hybrid mode else: # hybrid mode
( ll_data, hl_data = await asyncio.gather(
ll_entities_context, _get_node_data(
ll_relations_context,
ll_text_units_context,
) = await _get_node_data(
ll_keywords, ll_keywords,
knowledge_graph_inst, knowledge_graph_inst,
entities_vdb, entities_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) ),
( _get_edge_data(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = await _get_edge_data(
hl_keywords, hl_keywords,
knowledge_graph_inst, knowledge_graph_inst,
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
),
) )
(
ll_entities_context,
ll_relations_context,
ll_text_units_context,
) = ll_data
(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = hl_data
entities_context, relations_context, text_units_context = combine_contexts( entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context], [hl_entities_context, ll_entities_context],
[hl_relations_context, ll_relations_context], [hl_relations_context, ll_relations_context],
@@ -996,28 +1003,31 @@ async def _get_node_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
# get entity information # get entity information
node_datas = await asyncio.gather( node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
),
asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
),
) )
if not all([n is not None for n in node_datas]): if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged") logger.warning("Some nodes are missing, maybe the storage is damaged")
# get entity degree
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
node_datas = [ node_datas = [
{**n, "entity_name": k["entity_name"], "rank": d} {**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees) for k, n, d in zip(results, node_datas, node_degrees)
if n is not None if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entitytext chunk # get entitytext chunk
use_text_units = await _find_most_related_text_unit_from_entities( use_text_units, use_relations = await asyncio.gather(
_find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst node_datas, query_param, text_chunks_db, knowledge_graph_inst
) ),
# get relate edges _find_most_related_edges_from_entities(
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst node_datas, query_param, knowledge_graph_inst
),
) )
logger.info( logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
@@ -1107,11 +1117,19 @@ async def _find_most_related_text_unit_from_entities(
} }
all_text_units_lookup = {} all_text_units_lookup = {}
tasks = []
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
for c_id in this_text_units: for c_id in this_text_units:
if c_id not in all_text_units_lookup: if c_id not in all_text_units_lookup:
tasks.append((c_id, index, this_edges))
results = await asyncio.gather(
*[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
)
for (c_id, index, this_edges), data in zip(tasks, results):
all_text_units_lookup[c_id] = { all_text_units_lookup[c_id] = {
"data": await text_chunks_db.get_by_id(c_id), "data": data,
"order": index, "order": index,
"relation_counts": 0, "relation_counts": 0,
} }
@@ -1167,11 +1185,11 @@ async def _find_most_related_edges_from_entities(
seen.add(sorted_edge) seen.add(sorted_edge)
all_edges.append(sorted_edge) all_edges.append(sorted_edge)
all_edges_pack = await asyncio.gather( all_edges_pack, all_edges_degree = await asyncio.gather(
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges] asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
) asyncio.gather(
all_edges_degree = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
),
) )
all_edges_data = [ all_edges_data = [
{"src_tgt": k, "rank": d, **v} {"src_tgt": k, "rank": d, **v}
@@ -1201,15 +1219,21 @@ async def _get_edge_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
edge_datas = await asyncio.gather( edge_datas, edge_degree = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
),
asyncio.gather(
*[
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
for r in results
]
),
) )
if not all([n is not None for n in edge_datas]): if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged") logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
)
edge_datas = [ edge_datas = [
{ {
"src_id": k["src_id"], "src_id": k["src_id"],
@@ -1230,11 +1254,13 @@ async def _get_edge_data(
max_token_size=query_param.max_token_for_global_context, max_token_size=query_param.max_token_for_global_context,
) )
use_entities = await _find_most_related_entities_from_relationships( use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst edge_datas, query_param, knowledge_graph_inst
) ),
use_text_units = await _find_related_text_unit_from_relationships( _find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst edge_datas, query_param, text_chunks_db, knowledge_graph_inst
),
) )
logger.info( logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units" f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
@@ -1307,12 +1333,19 @@ async def _find_most_related_entities_from_relationships(
entity_names.append(e["tgt_id"]) entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"]) seen.add(e["tgt_id"])
node_datas = await asyncio.gather( node_datas, node_degrees = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] asyncio.gather(
) *[
knowledge_graph_inst.get_node(entity_name)
node_degrees = await asyncio.gather( for entity_name in entity_names
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names] ]
),
asyncio.gather(
*[
knowledge_graph_inst.node_degree(entity_name)
for entity_name in entity_names
]
),
) )
node_datas = [ node_datas = [
{**n, "entity_name": k, "rank": d} {**n, "entity_name": k, "rank": d}
@@ -1340,8 +1373,7 @@ async def _find_related_text_unit_from_relationships(
] ]
all_text_units_lookup = {} all_text_units_lookup = {}
for index, unit_list in enumerate(text_units): async def fetch_chunk_data(c_id, index):
for c_id in unit_list:
if c_id not in all_text_units_lookup: if c_id not in all_text_units_lookup:
chunk_data = await text_chunks_db.get_by_id(c_id) chunk_data = await text_chunks_db.get_by_id(c_id)
# Only store valid data # Only store valid data
@@ -1351,6 +1383,13 @@ async def _find_related_text_unit_from_relationships(
"order": index, "order": index,
} }
tasks = []
for index, unit_list in enumerate(text_units):
for c_id in unit_list:
tasks.append(fetch_chunk_data(c_id, index))
await asyncio.gather(*tasks)
if not all_text_units_lookup: if not all_text_units_lookup:
logger.warning("No valid text chunks found") logger.warning("No valid text chunks found")
return [] return []