Add get_node_edges and batch retrieval methods (get_nodes_batch, node_degrees_batch, edge_degrees_batch, get_edges_batch, get_nodes_edges_batch) to base.py and neo4j_impl.py
This commit is contained in:
@@ -309,6 +309,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""Upsert a node into the graph."""
|
||||
|
||||
@abstractmethod
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """Batch-fetch the given node ids in one UNWIND query, keyed by id."""
@abstractmethod
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """Return the degree of each requested node id, batched via UNWIND."""
@abstractmethod
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
    """Return the combined endpoint degree per edge pair; batched via UNWIND on top of node_degrees_batch."""
@abstractmethod
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
    """Batch-fetch edge properties for the given src/tgt pairs in one UNWIND query."""
@abstractmethod
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
    """Get nodes edges as a batch using UNWIND."""
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """Upsert a node into the graph."""
@@ -314,6 +314,37 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """
    Retrieve multiple nodes in one round trip using UNWIND.

    Args:
        node_ids: List of node entity IDs to fetch.

    Returns:
        A dictionary mapping each *found* node_id to its property dict.
        IDs with no matching node are omitted from the result (they are NOT
        mapped to None), so callers should use ``.get()`` for lookups.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
            UNWIND $node_ids AS id
            MATCH (n:base {entity_id: id})
            RETURN n.entity_id AS entity_id, n
        """
        result = await session.run(query, node_ids=node_ids)
        nodes = {}
        async for record in result:
            entity_id = record["entity_id"]
            node_dict = dict(record["n"])
            # Strip the internal 'base' label if a 'labels' property is present.
            if "labels" in node_dict:
                node_dict["labels"] = [
                    label for label in node_dict["labels"] if label != "base"
                ]
            nodes[entity_id] = node_dict
        await result.consume()  # make sure the result cursor is fully drained
        return nodes
async def node_degree(self, node_id: str) -> int:
|
||||
"""Get the degree (number of relationships) of a node with the given label.
|
||||
If multiple nodes have the same label, returns the degree of the first node.
|
||||
@@ -357,6 +388,41 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """
    Look up the relationship count (degree) for many nodes with a single
    UNWIND query.

    Args:
        node_ids: List of node labels (entity_id values) to look up.

    Returns:
        A dictionary mapping each node_id to its degree. IDs that match no
        node are reported with degree 0.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
            UNWIND $node_ids AS id
            MATCH (n:base {entity_id: id})
            RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
        """
        result = await session.run(query, node_ids=node_ids)
        degrees = {}
        async for record in result:
            degrees[record["entity_id"]] = record["degree"]
        await result.consume()  # drain the cursor before reusing the session

        # Ids that produced no record have no matching node; report them as 0.
        for nid in node_ids:
            if nid not in degrees:
                logger.warning(f"No node found with label '{nid}'")
                degrees[nid] = 0

        logger.debug(f"Neo4j batch node degree query returned: {degrees}")
        return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""Get the total degree (sum of relationships) of two nodes.
|
||||
|
||||
@@ -376,6 +442,30 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
return degrees
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
    """
    Compute the combined degree for each edge (source degree + target degree)
    in batch, delegating the per-node lookups to node_degrees_batch.

    Args:
        edge_pairs: List of (src, tgt) tuples.

    Returns:
        A dictionary mapping each (src, tgt) tuple to the sum of the two
        endpoint degrees.
    """
    # De-duplicate endpoint ids so each node is queried exactly once.
    unique_node_ids = {endpoint for pair in edge_pairs for endpoint in pair}

    # One batched round trip for all endpoint degrees.
    degrees = await self.node_degrees_batch(list(unique_node_ids))

    # Missing nodes contribute 0, mirroring node_degrees_batch's contract.
    return {
        (src, tgt): degrees.get(src, 0) + degrees.get(tgt, 0)
        for src, tgt in edge_pairs
    }
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
@@ -463,6 +553,43 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
    """
    Retrieve edge properties for multiple (src, tgt) pairs in one query.

    Args:
        pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]

    Returns:
        A dictionary mapping (src, tgt) tuples to their edge properties.
        Pairs with no matching edge receive a default property dict.
    """
    # Keys every returned edge dict must carry, with their fallback values.
    default_props = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}

    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
            UNWIND $pairs AS pair
            MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
            RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
        """
        result = await session.run(query, pairs=pairs)
        edges_dict = {}
        async for record in result:
            key = (record["src_id"], record["tgt_id"])
            found = record["edges"]
            if found:
                # Parallel relationships may exist; keep the first one.
                props = found[0]
                # Fill in any required keys the stored edge is missing.
                for name, fallback in default_props.items():
                    props.setdefault(name, fallback)
                edges_dict[key] = props
            else:
                # No edge found – supply a fresh copy of the defaults.
                edges_dict[key] = dict(default_props)
        await result.consume()
        return edges_dict
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
|
||||
@@ -523,6 +650,35 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
    """
    Batch retrieve the edge lists of multiple nodes in one query using UNWIND.

    Args:
        node_ids: List of node IDs (entity_id) for which to retrieve edges.

    Returns:
        A dictionary mapping each node ID to a list of (source, target)
        tuples; ids with no edges (or no matching node) map to an empty list.
    """
    async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
        query = """
            UNWIND $node_ids AS id
            MATCH (n:base {entity_id: id})
            OPTIONAL MATCH (n)-[r]-(connected:base)
            RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
        """
        result = await session.run(query, node_ids=node_ids)

        # Pre-seed every requested id so absent nodes yield [].
        edges_dict = {node_id: [] for node_id in node_ids}
        async for record in result:
            src = record["source_entity_id"]
            tgt = record["target_entity_id"]
            # OPTIONAL MATCH emits a row with a null neighbor for edge-less
            # nodes; skip those rows.
            if src and tgt:
                edges_dict[record["queried_id"]].append((src, tgt))
        await result.consume()  # ensure results are fully consumed
        return edges_dict
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
@@ -1233,16 +1233,20 @@ async def _get_node_data(
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
# get entity information
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
),
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids)
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
node_datas = [nodes_dict.get(nid) for nid in node_ids]
|
||||
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
|
||||
|
||||
if not all([n is not None for n in node_datas]):
|
||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||
|
||||
@@ -1374,9 +1378,10 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_one_hop_nodes.update([e[1] for e in this_edges])
|
||||
|
||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
|
||||
)
|
||||
|
||||
# Batch retrieve one-hop node data using get_nodes_batch
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
||||
|
||||
# Add null check for node data
|
||||
all_one_hop_text_units_lookup = {
|
||||
@@ -1512,29 +1517,34 @@ async def _get_edge_data(
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
|
||||
edge_datas, edge_degree = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
|
||||
for r in results
|
||||
]
|
||||
),
|
||||
# Prepare edge pairs in two forms:
|
||||
# For the batch edge properties function, use dicts.
|
||||
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
|
||||
# For edge degrees, use tuples.
|
||||
edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
|
||||
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.get_edges_degree_batch(edge_pairs_tuples)
|
||||
)
|
||||
|
||||
edge_datas = [
|
||||
{
|
||||
"src_id": k["src_id"],
|
||||
"tgt_id": k["tgt_id"],
|
||||
"rank": d,
|
||||
"created_at": k.get("__created_at__", None),
|
||||
**v,
|
||||
}
|
||||
for k, v, d in zip(results, edge_datas, edge_degree)
|
||||
if v is not None
|
||||
]
|
||||
# Reconstruct edge_datas list in the same order as results.
|
||||
edge_datas = []
|
||||
for k in results:
|
||||
pair = (k["src_id"], k["tgt_id"])
|
||||
edge_props = edge_data_dict.get(pair)
|
||||
if edge_props is not None:
|
||||
# Use edge degree from the batch as rank.
|
||||
combined = {
|
||||
"src_id": k["src_id"],
|
||||
"tgt_id": k["tgt_id"],
|
||||
"rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
|
||||
"created_at": k.get("__created_at__", None),
|
||||
**edge_props,
|
||||
}
|
||||
edge_datas.append(combined)
|
||||
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1640,24 +1650,23 @@ async def _find_most_related_entities_from_relationships(
|
||||
entity_names.append(e["tgt_id"])
|
||||
seen.add(e["tgt_id"])
|
||||
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.get_node(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.node_degree(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(entity_names),
|
||||
knowledge_graph_inst.get_node_degrees_batch(entity_names)
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k, "rank": d}
|
||||
for k, n, d in zip(entity_names, node_datas, node_degrees)
|
||||
]
|
||||
|
||||
# Rebuild the list in the same order as entity_names
|
||||
node_datas = []
|
||||
for entity_name in entity_names:
|
||||
node = nodes_dict.get(entity_name)
|
||||
degree = degrees_dict.get(entity_name, 0)
|
||||
if node is None:
|
||||
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
|
||||
continue
|
||||
# Combine the node data with the entity name and computed degree (as rank)
|
||||
combined = {**node, "entity_name": entity_name, "rank": degree}
|
||||
node_datas.append(combined)
|
||||
|
||||
len_node_datas = len(node_datas)
|
||||
node_datas = truncate_list_by_token_size(
|
||||
|
Reference in New Issue
Block a user