Merge pull request #642 from dimatill/main

asyncio optimizations
This commit is contained in:
zrguo
2025-01-25 01:52:24 +08:00
committed by GitHub

View File

@@ -990,28 +990,35 @@ async def _build_query_context(
query_param,
)
else: # hybrid mode
(
ll_entities_context,
ll_relations_context,
ll_text_units_context,
) = await _get_node_data(
ll_data, hl_data = await asyncio.gather(
_get_node_data(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = await _get_edge_data(
),
_get_edge_data(
hl_keywords,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param,
),
)
(
ll_entities_context,
ll_relations_context,
ll_text_units_context,
) = ll_data
(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = hl_data
entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context],
[hl_relations_context, ll_relations_context],
@@ -1045,28 +1052,31 @@ async def _get_node_data(
if not len(results):
return "", "", ""
# get entity information
node_datas = await asyncio.gather(
node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
),
asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
),
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
# get entity degree
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
node_datas = [
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entity text chunks
use_text_units = await _find_most_related_text_unit_from_entities(
use_text_units, use_relations = await asyncio.gather(
_find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
# get related edges
use_relations = await _find_most_related_edges_from_entities(
),
_find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
),
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
@@ -1156,11 +1166,19 @@ async def _find_most_related_text_unit_from_entities(
}
all_text_units_lookup = {}
tasks = []
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
for c_id in this_text_units:
if c_id not in all_text_units_lookup:
tasks.append((c_id, index, this_edges))
results = await asyncio.gather(
*[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
)
for (c_id, index, this_edges), data in zip(tasks, results):
all_text_units_lookup[c_id] = {
"data": await text_chunks_db.get_by_id(c_id),
"data": data,
"order": index,
"relation_counts": 0,
}
@@ -1216,11 +1234,11 @@ async def _find_most_related_edges_from_entities(
seen.add(sorted_edge)
all_edges.append(sorted_edge)
all_edges_pack = await asyncio.gather(
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
)
all_edges_degree = await asyncio.gather(
all_edges_pack, all_edges_degree = await asyncio.gather(
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
asyncio.gather(
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
),
)
all_edges_data = [
{"src_tgt": k, "rank": d, **v}
@@ -1250,15 +1268,21 @@ async def _get_edge_data(
if not len(results):
return "", "", ""
edge_datas = await asyncio.gather(
edge_datas, edge_degree = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
),
asyncio.gather(
*[
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
for r in results
]
),
)
if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
)
edge_datas = [
{
"src_id": k["src_id"],
@@ -1279,11 +1303,13 @@ async def _get_edge_data(
max_token_size=query_param.max_token_for_global_context,
)
use_entities = await _find_most_related_entities_from_relationships(
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst
)
use_text_units = await _find_related_text_unit_from_relationships(
),
_find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
),
)
logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
@@ -1356,12 +1382,19 @@ async def _find_most_related_entities_from_relationships(
entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"])
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
)
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[
knowledge_graph_inst.get_node(entity_name)
for entity_name in entity_names
]
),
asyncio.gather(
*[
knowledge_graph_inst.node_degree(entity_name)
for entity_name in entity_names
]
),
)
node_datas = [
{**n, "entity_name": k, "rank": d}
@@ -1389,8 +1422,7 @@ async def _find_related_text_unit_from_relationships(
]
all_text_units_lookup = {}
for index, unit_list in enumerate(text_units):
for c_id in unit_list:
async def fetch_chunk_data(c_id, index):
if c_id not in all_text_units_lookup:
chunk_data = await text_chunks_db.get_by_id(c_id)
# Only store valid data
@@ -1400,6 +1432,13 @@ async def _find_related_text_unit_from_relationships(
"order": index,
}
tasks = []
for index, unit_list in enumerate(text_units):
for c_id in unit_list:
tasks.append(fetch_chunk_data(c_id, index))
await asyncio.gather(*tasks)
if not all_text_units_lookup:
logger.warning("No valid text chunks found")
return []