Fix linting

This commit is contained in:
yangdx
2025-04-15 12:34:04 +08:00
parent 22b03ee1bb
commit 1de74c9228
7 changed files with 249 additions and 150 deletions

View File

@@ -388,7 +388,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = degree
return result
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one.
@@ -401,7 +403,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = degree
return result
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND
Default implementation fetches edges one by one.
@@ -417,7 +421,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = edge
return result
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one.

View File

@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully
return nodes
@@ -437,7 +439,9 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
@@ -547,7 +551,9 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
for key, default in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_props:
edge_props[key] = default
edges_dict[(src, tgt)] = edge_props
else:
# No edge found set default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
edges_dict[(src, tgt)] = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
await result.consume()
return edges_dict
@@ -644,7 +660,9 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes in one query using UNWIND.
@@ -654,7 +672,9 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})

View File

@@ -1472,16 +1472,15 @@ class PGGraphStorage(BaseGraphStorage):
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (
self.graph_name,
formatted_ids
)
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
@@ -1492,7 +1491,9 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = result["n"]["properties"]
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes_dict[result["node_id"]] = node_dict
return nodes_dict
@@ -1512,7 +1513,9 @@ class PGGraphStorage(BaseGraphStorage):
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
@@ -1521,7 +1524,7 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, count(r) AS degree
$$) AS (node_id text, degree bigint)""" % (
self.graph_name,
formatted_ids
formatted_ids,
)
results = await self._query(query)
@@ -1539,7 +1542,9 @@ class PGGraphStorage(BaseGraphStorage):
return degrees_dict
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
async def edge_degrees_batch(
self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
@@ -1570,7 +1575,9 @@ class PGGraphStorage(BaseGraphStorage):
return edge_degrees_dict
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
@@ -1587,8 +1594,8 @@ class PGGraphStorage(BaseGraphStorage):
src_nodes = []
tgt_nodes = []
for pair in pairs:
src_nodes.append(pair["src"].replace('"', ''))
tgt_nodes.append(pair["tgt"].replace('"', ''))
src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', ""))
# 构建查询,使用数组索引来匹配源节点和目标节点
src_array = ", ".join([f'"{src}"' for src in src_nodes])
@@ -1607,11 +1614,15 @@ class PGGraphStorage(BaseGraphStorage):
edges_dict = {}
for result in results:
if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
edges_dict[(result["source"], result["target"])] = result[
"edge_properties"
]
return edges_dict
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Get all edges for multiple nodes in a single batch operation.
@@ -1625,7 +1636,9 @@ class PGGraphStorage(BaseGraphStorage):
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
@@ -1634,7 +1647,7 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids
formatted_ids,
)
results = await self._query(query)

View File

@@ -1330,7 +1330,7 @@ async def _get_node_data(
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids)
knowledge_graph_inst.node_degrees_batch(node_ids),
)
# Now, if you need the node data and degree in order:
@@ -1474,8 +1474,12 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes = list(all_one_hop_nodes)
# Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
all_one_hop_nodes
)
all_one_hop_nodes_data = [
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
]
# Add null check for node data
all_one_hop_text_units_lookup = {
@@ -1575,7 +1579,7 @@ async def _find_most_related_edges_from_entities(
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
)
# Reconstruct edge_datas list in the same order as the deduplicated results.
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
)
# Reconstruct edge_datas list in the same order as results.
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.node_degrees_batch(entity_names)
knowledge_graph_inst.node_degrees_batch(entity_names),
)
# Rebuild the list in the same order as entity_names

View File

@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配"
assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配"
assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配"
assert (
nodes_dict[node1_id]["description"] == node1_data["description"]
), f"{node1_id} 描述不匹配"
assert (
nodes_dict[node2_id]["description"] == node2_data["description"]
), f"{node2_id} 描述不匹配"
assert (
nodes_dict[node3_id]["description"] == node3_data["description"]
), f"{node3_id} 描述不匹配"
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
print("== 测试 node_degrees_batch")
node_degrees = await storage.node_degrees_batch(node_ids)
print(f"批量获取节点度数结果: {node_degrees}")
assert len(node_degrees) == 3, f"应返回3个节点的度数实际返回 {len(node_degrees)}"
assert (
len(node_degrees) == 3
), f"应返回3个节点的度数实际返回 {len(node_degrees)}"
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}"
assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}"
assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}"
assert (
node_degrees[node1_id] == 3
), f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}"
assert (
node_degrees[node2_id] == 2
), f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}"
assert (
node_degrees[node3_id] == 3
), f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}"
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
print("== 测试 edge_degrees_batch")
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
edge_degrees = await storage.edge_degrees_batch(edges)
print(f"批量获取边度数结果: {edge_degrees}")
assert len(edge_degrees) == 3, f"应返回3条边的度数实际返回 {len(edge_degrees)}"
assert (node1_id, node2_id) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (node2_id, node3_id) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (node3_id, node4_id) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中"
assert (
len(edge_degrees) == 3
), f"应返回3条边的度数实际返回 {len(edge_degrees)}"
assert (
node1_id,
node2_id,
) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (
node2_id,
node3_id,
) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中"
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
assert edge_degrees[(node1_id, node2_id)] == 5, f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}"
assert edge_degrees[(node2_id, node3_id)] == 5, f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}"
assert edge_degrees[(node3_id, node4_id)] == 5, f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}"
assert (
edge_degrees[(node1_id, node2_id)] == 5
), f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}"
assert (
edge_degrees[(node2_id, node3_id)] == 5
), f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}"
assert (
edge_degrees[(node3_id, node4_id)] == 5
), f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}"
# 5. 测试 get_edges_batch - 批量获取多个边的属性
print("== 测试 get_edges_batch")
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
edges_dict = await storage.get_edges_batch(edge_dicts)
print(f"批量获取边属性结果: {edges_dict.keys()}")
assert len(edges_dict) == 3, f"应返回3条边的属性实际返回 {len(edges_dict)}"
assert (node1_id, node2_id) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (node2_id, node3_id) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (node3_id, node4_id) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中"
assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"{node1_id} -> {node2_id} 关系不匹配"
assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"{node2_id} -> {node3_id} 关系不匹配"
assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"{node3_id} -> {node4_id} 关系不匹配"
assert (
node1_id,
node2_id,
) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (
node2_id,
node3_id,
) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中"
assert (
edges_dict[(node1_id, node2_id)]["relationship"]
== edge1_data["relationship"]
), f"{node1_id} -> {node2_id} 关系不匹配"
assert (
edges_dict[(node2_id, node3_id)]["relationship"]
== edge2_data["relationship"]
), f"{node2_id} -> {node3_id} 关系不匹配"
assert (
edges_dict[(node3_id, node4_id)]["relationship"]
== edge5_data["relationship"]
), f"{node3_id} -> {node4_id} 关系不匹配"
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
print("== 测试 get_nodes_edges_batch")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
print(f"批量获取节点边结果: {nodes_edges.keys()}")
assert len(nodes_edges) == 2, f"应返回2个节点的边实际返回 {len(nodes_edges)}"
assert (
len(nodes_edges) == 2
), f"应返回2个节点的边实际返回 {len(nodes_edges)}"
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}"
assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}"
assert (
len(nodes_edges[node1_id]) == 3
), f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}"
assert (
len(nodes_edges[node3_id]) == 3
), f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}"
# 7. 清理数据
print("== 测试 drop")
result = await storage.drop()
print(f"清理结果: {result}")
assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
assert (
result["status"] == "success"
), f"清理应成功,实际状态为 {result['status']}"
print("\n批量操作测试完成")
return True