@@ -990,28 +990,35 @@ async def _build_query_context(
|
|||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
else: # hybrid mode
|
else: # hybrid mode
|
||||||
|
ll_data, hl_data = await asyncio.gather(
|
||||||
|
_get_node_data(
|
||||||
|
ll_keywords,
|
||||||
|
knowledge_graph_inst,
|
||||||
|
entities_vdb,
|
||||||
|
text_chunks_db,
|
||||||
|
query_param,
|
||||||
|
),
|
||||||
|
_get_edge_data(
|
||||||
|
hl_keywords,
|
||||||
|
knowledge_graph_inst,
|
||||||
|
relationships_vdb,
|
||||||
|
text_chunks_db,
|
||||||
|
query_param,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
ll_entities_context,
|
ll_entities_context,
|
||||||
ll_relations_context,
|
ll_relations_context,
|
||||||
ll_text_units_context,
|
ll_text_units_context,
|
||||||
) = await _get_node_data(
|
) = ll_data
|
||||||
ll_keywords,
|
|
||||||
knowledge_graph_inst,
|
|
||||||
entities_vdb,
|
|
||||||
text_chunks_db,
|
|
||||||
query_param,
|
|
||||||
)
|
|
||||||
(
|
(
|
||||||
hl_entities_context,
|
hl_entities_context,
|
||||||
hl_relations_context,
|
hl_relations_context,
|
||||||
hl_text_units_context,
|
hl_text_units_context,
|
||||||
) = await _get_edge_data(
|
) = hl_data
|
||||||
hl_keywords,
|
|
||||||
knowledge_graph_inst,
|
|
||||||
relationships_vdb,
|
|
||||||
text_chunks_db,
|
|
||||||
query_param,
|
|
||||||
)
|
|
||||||
entities_context, relations_context, text_units_context = combine_contexts(
|
entities_context, relations_context, text_units_context = combine_contexts(
|
||||||
[hl_entities_context, ll_entities_context],
|
[hl_entities_context, ll_entities_context],
|
||||||
[hl_relations_context, ll_relations_context],
|
[hl_relations_context, ll_relations_context],
|
||||||
@@ -1045,28 +1052,31 @@ async def _get_node_data(
|
|||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
# get entity information
|
# get entity information
|
||||||
node_datas = await asyncio.gather(
|
node_datas, node_degrees = await asyncio.gather(
|
||||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||||
|
),
|
||||||
|
asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not all([n is not None for n in node_datas]):
|
if not all([n is not None for n in node_datas]):
|
||||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||||
|
|
||||||
# get entity degree
|
|
||||||
node_degrees = await asyncio.gather(
|
|
||||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
|
||||||
)
|
|
||||||
node_datas = [
|
node_datas = [
|
||||||
{**n, "entity_name": k["entity_name"], "rank": d}
|
{**n, "entity_name": k["entity_name"], "rank": d}
|
||||||
for k, n, d in zip(results, node_datas, node_degrees)
|
for k, n, d in zip(results, node_datas, node_degrees)
|
||||||
if n is not None
|
if n is not None
|
||||||
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
||||||
# get entitytext chunk
|
# get entitytext chunk
|
||||||
use_text_units = await _find_most_related_text_unit_from_entities(
|
use_text_units, use_relations = await asyncio.gather(
|
||||||
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
_find_most_related_text_unit_from_entities(
|
||||||
)
|
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||||
# get relate edges
|
),
|
||||||
use_relations = await _find_most_related_edges_from_entities(
|
_find_most_related_edges_from_entities(
|
||||||
node_datas, query_param, knowledge_graph_inst
|
node_datas, query_param, knowledge_graph_inst
|
||||||
|
),
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
|
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
|
||||||
@@ -1156,22 +1166,30 @@ async def _find_most_related_text_unit_from_entities(
|
|||||||
}
|
}
|
||||||
|
|
||||||
all_text_units_lookup = {}
|
all_text_units_lookup = {}
|
||||||
|
tasks = []
|
||||||
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
||||||
for c_id in this_text_units:
|
for c_id in this_text_units:
|
||||||
if c_id not in all_text_units_lookup:
|
if c_id not in all_text_units_lookup:
|
||||||
all_text_units_lookup[c_id] = {
|
tasks.append((c_id, index, this_edges))
|
||||||
"data": await text_chunks_db.get_by_id(c_id),
|
|
||||||
"order": index,
|
|
||||||
"relation_counts": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
if this_edges:
|
results = await asyncio.gather(
|
||||||
for e in this_edges:
|
*[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
|
||||||
if (
|
)
|
||||||
e[1] in all_one_hop_text_units_lookup
|
|
||||||
and c_id in all_one_hop_text_units_lookup[e[1]]
|
for (c_id, index, this_edges), data in zip(tasks, results):
|
||||||
):
|
all_text_units_lookup[c_id] = {
|
||||||
all_text_units_lookup[c_id]["relation_counts"] += 1
|
"data": data,
|
||||||
|
"order": index,
|
||||||
|
"relation_counts": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if this_edges:
|
||||||
|
for e in this_edges:
|
||||||
|
if (
|
||||||
|
e[1] in all_one_hop_text_units_lookup
|
||||||
|
and c_id in all_one_hop_text_units_lookup[e[1]]
|
||||||
|
):
|
||||||
|
all_text_units_lookup[c_id]["relation_counts"] += 1
|
||||||
|
|
||||||
# Filter out None values and ensure data has content
|
# Filter out None values and ensure data has content
|
||||||
all_text_units = [
|
all_text_units = [
|
||||||
@@ -1216,11 +1234,11 @@ async def _find_most_related_edges_from_entities(
|
|||||||
seen.add(sorted_edge)
|
seen.add(sorted_edge)
|
||||||
all_edges.append(sorted_edge)
|
all_edges.append(sorted_edge)
|
||||||
|
|
||||||
all_edges_pack = await asyncio.gather(
|
all_edges_pack, all_edges_degree = await asyncio.gather(
|
||||||
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
|
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
|
||||||
)
|
asyncio.gather(
|
||||||
all_edges_degree = await asyncio.gather(
|
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
|
||||||
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
|
),
|
||||||
)
|
)
|
||||||
all_edges_data = [
|
all_edges_data = [
|
||||||
{"src_tgt": k, "rank": d, **v}
|
{"src_tgt": k, "rank": d, **v}
|
||||||
@@ -1250,15 +1268,21 @@ async def _get_edge_data(
|
|||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
|
|
||||||
edge_datas = await asyncio.gather(
|
edge_datas, edge_degree = await asyncio.gather(
|
||||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||||
|
),
|
||||||
|
asyncio.gather(
|
||||||
|
*[
|
||||||
|
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not all([n is not None for n in edge_datas]):
|
if not all([n is not None for n in edge_datas]):
|
||||||
logger.warning("Some edges are missing, maybe the storage is damaged")
|
logger.warning("Some edges are missing, maybe the storage is damaged")
|
||||||
edge_degree = await asyncio.gather(
|
|
||||||
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
|
|
||||||
)
|
|
||||||
edge_datas = [
|
edge_datas = [
|
||||||
{
|
{
|
||||||
"src_id": k["src_id"],
|
"src_id": k["src_id"],
|
||||||
@@ -1279,11 +1303,13 @@ async def _get_edge_data(
|
|||||||
max_token_size=query_param.max_token_for_global_context,
|
max_token_size=query_param.max_token_for_global_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
use_entities = await _find_most_related_entities_from_relationships(
|
use_entities, use_text_units = await asyncio.gather(
|
||||||
edge_datas, query_param, knowledge_graph_inst
|
_find_most_related_entities_from_relationships(
|
||||||
)
|
edge_datas, query_param, knowledge_graph_inst
|
||||||
use_text_units = await _find_related_text_unit_from_relationships(
|
),
|
||||||
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
_find_related_text_unit_from_relationships(
|
||||||
|
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||||
|
),
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
|
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
|
||||||
@@ -1356,12 +1382,19 @@ async def _find_most_related_entities_from_relationships(
|
|||||||
entity_names.append(e["tgt_id"])
|
entity_names.append(e["tgt_id"])
|
||||||
seen.add(e["tgt_id"])
|
seen.add(e["tgt_id"])
|
||||||
|
|
||||||
node_datas = await asyncio.gather(
|
node_datas, node_degrees = await asyncio.gather(
|
||||||
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
asyncio.gather(
|
||||||
)
|
*[
|
||||||
|
knowledge_graph_inst.get_node(entity_name)
|
||||||
node_degrees = await asyncio.gather(
|
for entity_name in entity_names
|
||||||
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
|
]
|
||||||
|
),
|
||||||
|
asyncio.gather(
|
||||||
|
*[
|
||||||
|
knowledge_graph_inst.node_degree(entity_name)
|
||||||
|
for entity_name in entity_names
|
||||||
|
]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
node_datas = [
|
node_datas = [
|
||||||
{**n, "entity_name": k, "rank": d}
|
{**n, "entity_name": k, "rank": d}
|
||||||
@@ -1389,16 +1422,22 @@ async def _find_related_text_unit_from_relationships(
|
|||||||
]
|
]
|
||||||
all_text_units_lookup = {}
|
all_text_units_lookup = {}
|
||||||
|
|
||||||
|
async def fetch_chunk_data(c_id, index):
|
||||||
|
if c_id not in all_text_units_lookup:
|
||||||
|
chunk_data = await text_chunks_db.get_by_id(c_id)
|
||||||
|
# Only store valid data
|
||||||
|
if chunk_data is not None and "content" in chunk_data:
|
||||||
|
all_text_units_lookup[c_id] = {
|
||||||
|
"data": chunk_data,
|
||||||
|
"order": index,
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks = []
|
||||||
for index, unit_list in enumerate(text_units):
|
for index, unit_list in enumerate(text_units):
|
||||||
for c_id in unit_list:
|
for c_id in unit_list:
|
||||||
if c_id not in all_text_units_lookup:
|
tasks.append(fetch_chunk_data(c_id, index))
|
||||||
chunk_data = await text_chunks_db.get_by_id(c_id)
|
|
||||||
# Only store valid data
|
await asyncio.gather(*tasks)
|
||||||
if chunk_data is not None and "content" in chunk_data:
|
|
||||||
all_text_units_lookup[c_id] = {
|
|
||||||
"data": chunk_data,
|
|
||||||
"order": index,
|
|
||||||
}
|
|
||||||
|
|
||||||
if not all_text_units_lookup:
|
if not all_text_units_lookup:
|
||||||
logger.warning("No valid text chunks found")
|
logger.warning("No valid text chunks found")
|
||||||
|
Reference in New Issue
Block a user