get_node added and all to base.py and to neo4j_impl.py file
This commit is contained in:
@@ -1233,16 +1233,20 @@ async def _get_node_data(
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
# get entity information
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
),
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids)
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
node_datas = [nodes_dict.get(nid) for nid in node_ids]
|
||||
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
|
||||
|
||||
if not all([n is not None for n in node_datas]):
|
||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||
|
||||
@@ -1374,9 +1378,10 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_one_hop_nodes.update([e[1] for e in this_edges])
|
||||
|
||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
|
||||
)
|
||||
|
||||
# Batch retrieve one-hop node data using get_nodes_batch
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
||||
|
||||
# Add null check for node data
|
||||
all_one_hop_text_units_lookup = {
|
||||
@@ -1512,29 +1517,34 @@ async def _get_edge_data(
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
|
||||
edge_datas, edge_degree = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
|
||||
for r in results
|
||||
]
|
||||
),
|
||||
# Prepare edge pairs in two forms:
|
||||
# For the batch edge properties function, use dicts.
|
||||
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
|
||||
# For edge degrees, use tuples.
|
||||
edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
|
||||
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.get_edges_degree_batch(edge_pairs_tuples)
|
||||
)
|
||||
|
||||
edge_datas = [
|
||||
{
|
||||
"src_id": k["src_id"],
|
||||
"tgt_id": k["tgt_id"],
|
||||
"rank": d,
|
||||
"created_at": k.get("__created_at__", None),
|
||||
**v,
|
||||
}
|
||||
for k, v, d in zip(results, edge_datas, edge_degree)
|
||||
if v is not None
|
||||
]
|
||||
# Reconstruct edge_datas list in the same order as results.
|
||||
edge_datas = []
|
||||
for k in results:
|
||||
pair = (k["src_id"], k["tgt_id"])
|
||||
edge_props = edge_data_dict.get(pair)
|
||||
if edge_props is not None:
|
||||
# Use edge degree from the batch as rank.
|
||||
combined = {
|
||||
"src_id": k["src_id"],
|
||||
"tgt_id": k["tgt_id"],
|
||||
"rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
|
||||
"created_at": k.get("__created_at__", None),
|
||||
**edge_props,
|
||||
}
|
||||
edge_datas.append(combined)
|
||||
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1640,24 +1650,23 @@ async def _find_most_related_entities_from_relationships(
|
||||
entity_names.append(e["tgt_id"])
|
||||
seen.add(e["tgt_id"])
|
||||
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.get_node(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.node_degree(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(entity_names),
|
||||
knowledge_graph_inst.get_node_degrees_batch(entity_names)
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k, "rank": d}
|
||||
for k, n, d in zip(entity_names, node_datas, node_degrees)
|
||||
]
|
||||
|
||||
# Rebuild the list in the same order as entity_names
|
||||
node_datas = []
|
||||
for entity_name in entity_names:
|
||||
node = nodes_dict.get(entity_name)
|
||||
degree = degrees_dict.get(entity_name, 0)
|
||||
if node is None:
|
||||
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
|
||||
continue
|
||||
# Combine the node data with the entity name and computed degree (as rank)
|
||||
combined = {**node, "entity_name": entity_name, "rank": degree}
|
||||
node_datas.append(combined)
|
||||
|
||||
len_node_datas = len(node_datas)
|
||||
node_datas = truncate_list_by_token_size(
|
||||
|
Reference in New Issue
Block a user