Fix linting

This commit is contained in:
yangdx
2025-04-15 12:34:04 +08:00
parent 22b03ee1bb
commit 1de74c9228
7 changed files with 249 additions and 150 deletions

View File

@@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND
Default implementation fetches nodes one by one.
Override this method for better performance in storage backends
that support batch operations.
@@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Node degrees as a batch using UNWIND
Default implementation fetches node degrees one by one.
Override this method for better performance in storage backends
that support batch operations.
@@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = degree
return result
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one.
Override this method for better performance in storage backends
that support batch operations.
@@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = degree
return result
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND
Default implementation fetches edges one by one.
Override this method for better performance in storage backends
that support batch operations.
@@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = edge
return result
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one.
Override this method for better performance in storage backends
that support batch operations.

View File

@@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage):
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully
return nodes
@@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage):
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
async with self._driver.session(
@@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage):
entity_id = record["entity_id"]
degrees[entity_id] = record["degree"]
await result.consume() # Ensure result is fully consumed
# For any node_id that did not return a record, set degree to 0.
for nid in node_ids:
if nid not in degrees:
logger.warning(f"No node found with label '{nid}'")
degrees[nid] = 0
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
return degrees
@@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edge_pairs: List of (src, tgt) tuples.
Returns:
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
"""
# Collect unique node IDs from all edge pairs.
unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes in one go.
degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair.
edge_degrees = {}
for src, tgt in edge_pairs:
@@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
for key, default in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_props:
edge_props[key] = default
edges_dict[(src, tgt)] = edge_props
else:
# No edge found; fall back to default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
edges_dict[(src, tgt)] = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
await result.consume()
return edges_dict
@@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes in one query using UNWIND.
Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})

View File

@@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage):
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (
self.graph_name,
formatted_ids
)
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
# Build result dictionary
nodes_dict = {}
for result in results:
@@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = result["n"]["properties"]
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes_dict[result["node_id"]] = node_dict
return nodes_dict
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
@@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, count(r) AS degree
$$) AS (node_id text, degree bigint)""" % (
self.graph_name,
formatted_ids
formatted_ids,
)
results = await self._query(query)
# Build result dictionary
degrees_dict = {}
for result in results:
if result["node_id"] is not None:
degrees_dict[result["node_id"]] = int(result["degree"])
# Ensure all requested node_ids are in the result dictionary
for node_id in node_ids:
if node_id not in degrees_dict:
degrees_dict[node_id] = 0
return degrees_dict
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
async def edge_degrees_batch(
self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edges: List of (source_node_id, target_node_id) tuples
Returns:
Dictionary mapping edge tuples to their combined degrees
"""
# Nothing to look up: return an empty mapping without calling node_degrees_batch.
if not edges:
return {}
# Use node_degrees_batch to get all node degrees efficiently.
# Endpoints are collected into a set first, so each distinct node ID is
# passed to node_degrees_batch only once even if it appears in many edges.
all_nodes = set()
for src, tgt in edges:
all_nodes.add(src)
all_nodes.add(tgt)
node_degrees = await self.node_degrees_batch(list(all_nodes))
# Calculate edge degrees: an edge's degree is the sum of its two endpoint degrees.
# An endpoint missing from node_degrees contributes 0 via .get(..., 0).
# Duplicate (src, tgt) pairs in `edges` map to the same key, so the result
# holds one entry per distinct pair.
edge_degrees_dict = {}
for src, tgt in edges:
src_degree = node_degrees.get(src, 0)
tgt_degree = node_degrees.get(tgt, 0)
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
return edge_degrees_dict
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
# Nothing to look up: skip the query entirely.
if not pairs:
return {}
# Extract the source and target node IDs from the list of dictionaries.
# Double quotes are stripped because each ID is wrapped in "..." when it is
# embedded in the Cypher text below.
src_nodes = []
tgt_nodes = []
for pair in pairs:
src_nodes.append(pair["src"].replace('"', ''))
tgt_nodes.append(pair["tgt"].replace('"', ''))
src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', ""))
# Build the query, using array indexes to match each source node with its target node:
# sources[i] and targets[i] come from the same input pair.
src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
# NOTE(review): the IDs are interpolated directly into the query string and only
# double quotes are removed above — confirm that entity IDs can never contain a
# backslash or the "$$" delimiter, which could alter the quoted Cypher/SQL text.
# The relationship pattern has no arrow, so an edge is matched in either direction.
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
results = await self._query(query)
# Build the result dictionary.
# Rows with a falsy source, target or edge_properties are skipped, and pairs that
# produce no row never get a key, so callers must handle missing (src, tgt) entries.
# The keys are the quote-stripped IDs returned by the query; if several rows share
# a (source, target) pair, the last row processed overwrites the earlier ones.
edges_dict = {}
for result in results:
if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
edges_dict[(result["source"], result["target"])] = result[
"edge_properties"
]
return edges_dict
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Get all edges for multiple nodes in a single batch operation.
Args:
node_ids: List of node IDs to get edges for
Returns:
Dictionary mapping node IDs to lists of (source, target) edge tuples
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
@@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids
formatted_ids,
)
results = await self._query(query)
# Build result dictionary
nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in results:
@@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage):
nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"])
)
return nodes_edges_dict
async def get_all_labels(self) -> list[str]:
"""
Get all labels (node IDs) in the graph.

View File

@@ -1323,14 +1323,14 @@ async def _get_node_data(
if not len(results):
return "", "", ""
# Extract all entity IDs from the results list
node_ids = [r["entity_name"] for r in results]
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids)
knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids),
)
# Now, if you need the node data and degree in order:
@@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities(
for dp in node_datas
if dp["source_id"] is not None
]
node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
# Build the edges list in the same order as node_datas.
@@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes)
# Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
all_one_hop_nodes
)
all_one_hop_nodes_data = [
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
]
# Add null check for node data
all_one_hop_text_units_lookup = {
@@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities(
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
# For edge degrees, use tuples.
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
)
# Reconstruct edge_datas list in the same order as the deduplicated results.
all_edges_data = []
for pair in all_edges:
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
)
# Reconstruct edge_datas list in the same order as results.
@@ -1652,7 +1655,7 @@ async def _get_edge_data(
**edge_props,
}
edge_datas.append(combined)
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.node_degrees_batch(entity_names)
knowledge_graph_inst.node_degrees_batch(entity_names),
)
# Rebuild the list in the same order as entity_names