get_node added and all to base.py and to neo4j_impl.py file

This commit is contained in:
frederikhendrix
2025-04-07 19:09:31 +02:00
parent d11f346cd5
commit 182aee2e14
3 changed files with 234 additions and 49 deletions

View File

@@ -309,6 +309,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Upsert a node into the graph."""
@abstractmethod
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND"""
@abstractmethod
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Node degrees as a batch using UNWIND"""
@abstractmethod
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
@abstractmethod
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND"""
@abstractmethod
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
""""Get nodes edges as a batch using UNWIND"""
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Upsert an edge into the graph."""

View File

@@ -314,6 +314,37 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node for {node_id}: {str(e)}")
raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, n
"""
result = await session.run(query, node_ids=node_ids)
nodes = {}
async for record in result:
entity_id = record["entity_id"]
node = record["n"]
node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully
return nodes
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
@@ -357,6 +388,41 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
"""
result = await session.run(query, node_ids=node_ids)
degrees = {}
async for record in result:
entity_id = record["entity_id"]
degrees[entity_id] = record["degree"]
await result.consume() # Ensure result is fully consumed
# For any node_id that did not return a record, set degree to 0.
for nid in node_ids:
if nid not in degrees:
logger.warning(f"No node found with label '{nid}'")
degrees[nid] = 0
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes.
@@ -376,6 +442,30 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edge_pairs: List of (src, tgt) tuples.
Returns:
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
"""
# Collect unique node IDs from all edge pairs.
unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes in one go.
degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair.
edge_degrees = {}
for src, tgt in edge_pairs:
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
@@ -463,6 +553,43 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $pairs AS pair
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
"""
result = await session.run(query, pairs=pairs)
edges_dict = {}
async for record in result:
src = record["src_id"]
tgt = record["tgt_id"]
edges = record["edges"]
if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
if key not in edge_props:
edge_props[key] = default
edges_dict[(src, tgt)] = edge_props
else:
# No edge found set default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
await result.consume()
return edges_dict
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
@@ -523,6 +650,35 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes in one query using UNWIND.
Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
OPTIONAL MATCH (n)-[r]-(connected:base)
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
"""
result = await session.run(query, node_ids=node_ids)
# Initialize the dictionary with empty lists for each node ID
edges_dict = {node_id: [] for node_id in node_ids}
async for record in result:
queried_id = record["queried_id"]
source_label = record["source_entity_id"]
target_label = record["target_entity_id"]
if source_label and target_label:
edges_dict[queried_id].append((source_label, target_label))
await result.consume() # Ensure results are fully consumed
return edges_dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),

View File

@@ -1233,16 +1233,20 @@ async def _get_node_data(
if not len(results):
return "", "", ""
# get entity information
node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
),
asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
),
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids)
)
# Now, if you need the node data and degree in order:
node_datas = [nodes_dict.get(nid) for nid in node_ids]
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
@@ -1374,9 +1378,10 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes)
all_one_hop_nodes_data = await asyncio.gather(
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
)
# Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
# Add null check for node data
all_one_hop_text_units_lookup = {
@@ -1512,29 +1517,34 @@ async def _get_edge_data(
if not len(results):
return "", "", ""
edge_datas, edge_degree = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
),
asyncio.gather(
*[
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
for r in results
]
),
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
# For edge degrees, use tuples.
edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.get_edges_degree_batch(edge_pairs_tuples)
)
edge_datas = [
{
"src_id": k["src_id"],
"tgt_id": k["tgt_id"],
"rank": d,
"created_at": k.get("__created_at__", None),
**v,
}
for k, v, d in zip(results, edge_datas, edge_degree)
if v is not None
]
# Reconstruct edge_datas list in the same order as results.
edge_datas = []
for k in results:
pair = (k["src_id"], k["tgt_id"])
edge_props = edge_data_dict.get(pair)
if edge_props is not None:
# Use edge degree from the batch as rank.
combined = {
"src_id": k["src_id"],
"tgt_id": k["tgt_id"],
"rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
"created_at": k.get("__created_at__", None),
**edge_props,
}
edge_datas.append(combined)
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1640,24 +1650,23 @@ async def _find_most_related_entities_from_relationships(
entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"])
node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[
knowledge_graph_inst.get_node(entity_name)
for entity_name in entity_names
]
),
asyncio.gather(
*[
knowledge_graph_inst.node_degree(entity_name)
for entity_name in entity_names
]
),
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.get_node_degrees_batch(entity_names)
)
node_datas = [
{**n, "entity_name": k, "rank": d}
for k, n, d in zip(entity_names, node_datas, node_degrees)
]
# Rebuild the list in the same order as entity_names
node_datas = []
for entity_name in entity_names:
node = nodes_dict.get(entity_name)
degree = degrees_dict.get(entity_name, 0)
if node is None:
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
continue
# Combine the node data with the entity name and computed degree (as rank)
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(