get_node added and all to base.py and to neo4j_impl.py file
This commit is contained in:
@@ -314,6 +314,37 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Retrieve multiple nodes in one query using UNWIND.
|
||||
|
||||
Args:
|
||||
node_ids: List of node entity IDs to fetch.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its node data (or None if not found).
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
RETURN n.entity_id AS entity_id, n
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
nodes = {}
|
||||
async for record in result:
|
||||
entity_id = record["entity_id"]
|
||||
node = record["n"]
|
||||
node_dict = dict(node)
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
nodes[entity_id] = node_dict
|
||||
await result.consume() # Make sure to consume the result fully
|
||||
return nodes
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""Get the degree (number of relationships) of a node with the given label.
|
||||
If multiple nodes have the same label, returns the degree of the first node.
|
||||
@@ -357,6 +388,41 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||
|
||||
Args:
|
||||
node_ids: List of node labels (entity_id values) to look up.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
If a node is not found, its degree will be set to 0.
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
degrees = {}
|
||||
async for record in result:
|
||||
entity_id = record["entity_id"]
|
||||
degrees[entity_id] = record["degree"]
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
# For any node_id that did not return a record, set degree to 0.
|
||||
for nid in node_ids:
|
||||
if nid not in degrees:
|
||||
logger.warning(f"No node found with label '{nid}'")
|
||||
degrees[nid] = 0
|
||||
|
||||
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
||||
return degrees
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""Get the total degree (sum of relationships) of two nodes.
|
||||
|
||||
@@ -376,6 +442,30 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
return degrees
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
"""
|
||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||
in batch using the already implemented node_degrees_batch.
|
||||
|
||||
Args:
|
||||
edge_pairs: List of (src, tgt) tuples.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
||||
"""
|
||||
# Collect unique node IDs from all edge pairs.
|
||||
unique_node_ids = {src for src, _ in edge_pairs}
|
||||
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
||||
|
||||
# Get degrees for all nodes in one go.
|
||||
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
||||
|
||||
# Sum up degrees for each edge pair.
|
||||
edge_degrees = {}
|
||||
for src, tgt in edge_pairs:
|
||||
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
||||
return edge_degrees
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
@@ -463,6 +553,43 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
"""
|
||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||
|
||||
Args:
|
||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||
|
||||
Returns:
|
||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $pairs AS pair
|
||||
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
|
||||
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
|
||||
"""
|
||||
result = await session.run(query, pairs=pairs)
|
||||
edges_dict = {}
|
||||
async for record in result:
|
||||
src = record["src_id"]
|
||||
tgt = record["tgt_id"]
|
||||
edges = record["edges"]
|
||||
if edges and len(edges) > 0:
|
||||
edge_props = edges[0] # choose the first if multiple exist
|
||||
# Ensure required keys exist with defaults
|
||||
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
|
||||
if key not in edge_props:
|
||||
edge_props[key] = default
|
||||
edges_dict[(src, tgt)] = edge_props
|
||||
else:
|
||||
# No edge found – set default edge properties
|
||||
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
|
||||
await result.consume()
|
||||
return edges_dict
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
|
||||
@@ -523,6 +650,35 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
OPTIONAL MATCH (n)-[r]-(connected:base)
|
||||
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
# Initialize the dictionary with empty lists for each node ID
|
||||
edges_dict = {node_id: [] for node_id in node_ids}
|
||||
async for record in result:
|
||||
queried_id = record["queried_id"]
|
||||
source_label = record["source_entity_id"]
|
||||
target_label = record["target_entity_id"]
|
||||
if source_label and target_label:
|
||||
edges_dict[queried_id].append((source_label, target_label))
|
||||
await result.consume() # Ensure results are fully consumed
|
||||
return edges_dict
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
Reference in New Issue
Block a user