Fix linting
This commit is contained in:
@@ -388,7 +388,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
result[node_id] = degree
|
result[node_id] = degree
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
async def edge_degrees_batch(
|
||||||
|
self, edge_pairs: list[tuple[str, str]]
|
||||||
|
) -> dict[tuple[str, str], int]:
|
||||||
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
||||||
|
|
||||||
Default implementation calculates edge degrees one by one.
|
Default implementation calculates edge degrees one by one.
|
||||||
@@ -401,7 +403,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
result[(src_id, tgt_id)] = degree
|
result[(src_id, tgt_id)] = degree
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
async def get_edges_batch(
|
||||||
|
self, pairs: list[dict[str, str]]
|
||||||
|
) -> dict[tuple[str, str], dict]:
|
||||||
"""Get edges as a batch using UNWIND
|
"""Get edges as a batch using UNWIND
|
||||||
|
|
||||||
Default implementation fetches edges one by one.
|
Default implementation fetches edges one by one.
|
||||||
@@ -417,7 +421,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
result[(src_id, tgt_id)] = edge
|
result[(src_id, tgt_id)] = edge
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
async def get_nodes_edges_batch(
|
||||||
|
self, node_ids: list[str]
|
||||||
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
"""Get nodes edges as a batch using UNWIND
|
"""Get nodes edges as a batch using UNWIND
|
||||||
|
|
||||||
Default implementation fetches node edges one by one.
|
Default implementation fetches node edges one by one.
|
||||||
|
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
node_dict = dict(node)
|
node_dict = dict(node)
|
||||||
# Remove the 'base' label if present in a 'labels' property
|
# Remove the 'base' label if present in a 'labels' property
|
||||||
if "labels" in node_dict:
|
if "labels" in node_dict:
|
||||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
node_dict["labels"] = [
|
||||||
|
label for label in node_dict["labels"] if label != "base"
|
||||||
|
]
|
||||||
nodes[entity_id] = node_dict
|
nodes[entity_id] = node_dict
|
||||||
await result.consume() # Make sure to consume the result fully
|
await result.consume() # Make sure to consume the result fully
|
||||||
return nodes
|
return nodes
|
||||||
@@ -437,7 +439,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
async def edge_degrees_batch(
|
||||||
|
self, edge_pairs: list[tuple[str, str]]
|
||||||
|
) -> dict[tuple[str, str], int]:
|
||||||
"""
|
"""
|
||||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||||
in batch using the already implemented node_degrees_batch.
|
in batch using the already implemented node_degrees_batch.
|
||||||
@@ -547,7 +551,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
async def get_edges_batch(
|
||||||
|
self, pairs: list[dict[str, str]]
|
||||||
|
) -> dict[tuple[str, str], dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||||
|
|
||||||
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
if edges and len(edges) > 0:
|
if edges and len(edges) > 0:
|
||||||
edge_props = edges[0] # choose the first if multiple exist
|
edge_props = edges[0] # choose the first if multiple exist
|
||||||
# Ensure required keys exist with defaults
|
# Ensure required keys exist with defaults
|
||||||
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
|
for key, default in {
|
||||||
|
"weight": 0.0,
|
||||||
|
"source_id": None,
|
||||||
|
"description": None,
|
||||||
|
"keywords": None,
|
||||||
|
}.items():
|
||||||
if key not in edge_props:
|
if key not in edge_props:
|
||||||
edge_props[key] = default
|
edge_props[key] = default
|
||||||
edges_dict[(src, tgt)] = edge_props
|
edges_dict[(src, tgt)] = edge_props
|
||||||
else:
|
else:
|
||||||
# No edge found – set default edge properties
|
# No edge found – set default edge properties
|
||||||
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
|
edges_dict[(src, tgt)] = {
|
||||||
|
"weight": 0.0,
|
||||||
|
"source_id": None,
|
||||||
|
"description": None,
|
||||||
|
"keywords": None,
|
||||||
|
}
|
||||||
await result.consume()
|
await result.consume()
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
@@ -644,7 +660,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
async def get_nodes_edges_batch(
|
||||||
|
self, node_ids: list[str]
|
||||||
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
"""
|
"""
|
||||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||||
|
|
||||||
@@ -654,7 +672,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||||
"""
|
"""
|
||||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
async with self._driver.session(
|
||||||
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
|
) as session:
|
||||||
query = """
|
query = """
|
||||||
UNWIND $node_ids AS id
|
UNWIND $node_ids AS id
|
||||||
MATCH (n:base {entity_id: id})
|
MATCH (n:base {entity_id: id})
|
||||||
|
@@ -1472,16 +1472,15 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
formatted_ids = ", ".join(
|
||||||
|
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||||
|
)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
UNWIND [%s] AS node_id
|
UNWIND [%s] AS node_id
|
||||||
MATCH (n:base {entity_id: node_id})
|
MATCH (n:base {entity_id: node_id})
|
||||||
RETURN node_id, n
|
RETURN node_id, n
|
||||||
$$) AS (node_id text, n agtype)""" % (
|
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
|
||||||
self.graph_name,
|
|
||||||
formatted_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
results = await self._query(query)
|
results = await self._query(query)
|
||||||
|
|
||||||
@@ -1492,7 +1491,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
node_dict = result["n"]["properties"]
|
node_dict = result["n"]["properties"]
|
||||||
# Remove the 'base' label if present in a 'labels' property
|
# Remove the 'base' label if present in a 'labels' property
|
||||||
if "labels" in node_dict:
|
if "labels" in node_dict:
|
||||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
node_dict["labels"] = [
|
||||||
|
label for label in node_dict["labels"] if label != "base"
|
||||||
|
]
|
||||||
nodes_dict[result["node_id"]] = node_dict
|
nodes_dict[result["node_id"]] = node_dict
|
||||||
|
|
||||||
return nodes_dict
|
return nodes_dict
|
||||||
@@ -1512,7 +1513,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
formatted_ids = ", ".join(
|
||||||
|
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||||
|
)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
UNWIND [%s] AS node_id
|
UNWIND [%s] AS node_id
|
||||||
@@ -1521,7 +1524,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
RETURN node_id, count(r) AS degree
|
RETURN node_id, count(r) AS degree
|
||||||
$$) AS (node_id text, degree bigint)""" % (
|
$$) AS (node_id text, degree bigint)""" % (
|
||||||
self.graph_name,
|
self.graph_name,
|
||||||
formatted_ids
|
formatted_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await self._query(query)
|
results = await self._query(query)
|
||||||
@@ -1539,7 +1542,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
return degrees_dict
|
return degrees_dict
|
||||||
|
|
||||||
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
async def edge_degrees_batch(
|
||||||
|
self, edges: list[tuple[str, str]]
|
||||||
|
) -> dict[tuple[str, str], int]:
|
||||||
"""
|
"""
|
||||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||||
in batch using the already implemented node_degrees_batch.
|
in batch using the already implemented node_degrees_batch.
|
||||||
@@ -1570,7 +1575,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
return edge_degrees_dict
|
return edge_degrees_dict
|
||||||
|
|
||||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
async def get_edges_batch(
|
||||||
|
self, pairs: list[dict[str, str]]
|
||||||
|
) -> dict[tuple[str, str], dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||||
|
|
||||||
@@ -1587,8 +1594,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
src_nodes = []
|
src_nodes = []
|
||||||
tgt_nodes = []
|
tgt_nodes = []
|
||||||
for pair in pairs:
|
for pair in pairs:
|
||||||
src_nodes.append(pair["src"].replace('"', ''))
|
src_nodes.append(pair["src"].replace('"', ""))
|
||||||
tgt_nodes.append(pair["tgt"].replace('"', ''))
|
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
||||||
|
|
||||||
# 构建查询,使用数组索引来匹配源节点和目标节点
|
# 构建查询,使用数组索引来匹配源节点和目标节点
|
||||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
||||||
@@ -1607,11 +1614,15 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
edges_dict = {}
|
edges_dict = {}
|
||||||
for result in results:
|
for result in results:
|
||||||
if result["source"] and result["target"] and result["edge_properties"]:
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
edges_dict[(result["source"], result["target"])] = result[
|
||||||
|
"edge_properties"
|
||||||
|
]
|
||||||
|
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
async def get_nodes_edges_batch(
|
||||||
|
self, node_ids: list[str]
|
||||||
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
"""
|
"""
|
||||||
Get all edges for multiple nodes in a single batch operation.
|
Get all edges for multiple nodes in a single batch operation.
|
||||||
|
|
||||||
@@ -1625,7 +1636,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
formatted_ids = ", ".join(
|
||||||
|
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||||
|
)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
UNWIND [%s] AS node_id
|
UNWIND [%s] AS node_id
|
||||||
@@ -1634,7 +1647,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
RETURN node_id, connected.entity_id AS connected_id
|
RETURN node_id, connected.entity_id AS connected_id
|
||||||
$$) AS (node_id text, connected_id text)""" % (
|
$$) AS (node_id text, connected_id text)""" % (
|
||||||
self.graph_name,
|
self.graph_name,
|
||||||
formatted_ids
|
formatted_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await self._query(query)
|
results = await self._query(query)
|
||||||
|
@@ -1330,7 +1330,7 @@ async def _get_node_data(
|
|||||||
# Call the batch node retrieval and degree functions concurrently.
|
# Call the batch node retrieval and degree functions concurrently.
|
||||||
nodes_dict, degrees_dict = await asyncio.gather(
|
nodes_dict, degrees_dict = await asyncio.gather(
|
||||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||||
knowledge_graph_inst.node_degrees_batch(node_ids)
|
knowledge_graph_inst.node_degrees_batch(node_ids),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now, if you need the node data and degree in order:
|
# Now, if you need the node data and degree in order:
|
||||||
@@ -1474,8 +1474,12 @@ async def _find_most_related_text_unit_from_entities(
|
|||||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||||
|
|
||||||
# Batch retrieve one-hop node data using get_nodes_batch
|
# Batch retrieve one-hop node data using get_nodes_batch
|
||||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
|
||||||
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
all_one_hop_nodes
|
||||||
|
)
|
||||||
|
all_one_hop_nodes_data = [
|
||||||
|
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
|
||||||
|
]
|
||||||
|
|
||||||
# Add null check for node data
|
# Add null check for node data
|
||||||
all_one_hop_text_units_lookup = {
|
all_one_hop_text_units_lookup = {
|
||||||
@@ -1575,7 +1579,7 @@ async def _find_most_related_edges_from_entities(
|
|||||||
# Call the batched functions concurrently.
|
# Call the batched functions concurrently.
|
||||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
||||||
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
|
|||||||
}
|
}
|
||||||
all_edges_data.append(combined)
|
all_edges_data.append(combined)
|
||||||
|
|
||||||
|
|
||||||
all_edges_data = sorted(
|
all_edges_data = sorted(
|
||||||
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||||
)
|
)
|
||||||
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
|
|||||||
# Call the batched functions concurrently.
|
# Call the batched functions concurrently.
|
||||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reconstruct edge_datas list in the same order as results.
|
# Reconstruct edge_datas list in the same order as results.
|
||||||
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
|
|||||||
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
||||||
nodes_dict, degrees_dict = await asyncio.gather(
|
nodes_dict, degrees_dict = await asyncio.gather(
|
||||||
knowledge_graph_inst.get_nodes_batch(entity_names),
|
knowledge_graph_inst.get_nodes_batch(entity_names),
|
||||||
knowledge_graph_inst.node_degrees_batch(entity_names)
|
knowledge_graph_inst.node_degrees_batch(entity_names),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Rebuild the list in the same order as entity_names
|
# Rebuild the list in the same order as entity_names
|
||||||
|
@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
|
|||||||
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
|
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
|
||||||
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
|
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
|
||||||
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
|
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
|
||||||
assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配"
|
assert (
|
||||||
assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配"
|
nodes_dict[node1_id]["description"] == node1_data["description"]
|
||||||
assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配"
|
), f"{node1_id} 描述不匹配"
|
||||||
|
assert (
|
||||||
|
nodes_dict[node2_id]["description"] == node2_data["description"]
|
||||||
|
), f"{node2_id} 描述不匹配"
|
||||||
|
assert (
|
||||||
|
nodes_dict[node3_id]["description"] == node3_data["description"]
|
||||||
|
), f"{node3_id} 描述不匹配"
|
||||||
|
|
||||||
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
|
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
|
||||||
print("== 测试 node_degrees_batch")
|
print("== 测试 node_degrees_batch")
|
||||||
node_degrees = await storage.node_degrees_batch(node_ids)
|
node_degrees = await storage.node_degrees_batch(node_ids)
|
||||||
print(f"批量获取节点度数结果: {node_degrees}")
|
print(f"批量获取节点度数结果: {node_degrees}")
|
||||||
assert len(node_degrees) == 3, f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
|
assert (
|
||||||
|
len(node_degrees) == 3
|
||||||
|
), f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
|
||||||
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
|
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
|
||||||
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
|
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
|
||||||
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
|
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
|
||||||
assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
|
assert (
|
||||||
assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
|
node_degrees[node1_id] == 3
|
||||||
assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
|
), f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
|
||||||
|
assert (
|
||||||
|
node_degrees[node2_id] == 2
|
||||||
|
), f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
|
||||||
|
assert (
|
||||||
|
node_degrees[node3_id] == 3
|
||||||
|
), f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
|
||||||
|
|
||||||
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
|
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
|
||||||
print("== 测试 edge_degrees_batch")
|
print("== 测试 edge_degrees_batch")
|
||||||
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
|
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
|
||||||
edge_degrees = await storage.edge_degrees_batch(edges)
|
edge_degrees = await storage.edge_degrees_batch(edges)
|
||||||
print(f"批量获取边度数结果: {edge_degrees}")
|
print(f"批量获取边度数结果: {edge_degrees}")
|
||||||
assert len(edge_degrees) == 3, f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
|
assert (
|
||||||
assert (node1_id, node2_id) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
len(edge_degrees) == 3
|
||||||
assert (node2_id, node3_id) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
), f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
|
||||||
assert (node3_id, node4_id) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
assert (
|
||||||
|
node1_id,
|
||||||
|
node2_id,
|
||||||
|
) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
||||||
|
assert (
|
||||||
|
node2_id,
|
||||||
|
node3_id,
|
||||||
|
) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
||||||
|
assert (
|
||||||
|
node3_id,
|
||||||
|
node4_id,
|
||||||
|
) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
||||||
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
|
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
|
||||||
assert edge_degrees[(node1_id, node2_id)] == 5, f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
|
assert (
|
||||||
assert edge_degrees[(node2_id, node3_id)] == 5, f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
|
edge_degrees[(node1_id, node2_id)] == 5
|
||||||
assert edge_degrees[(node3_id, node4_id)] == 5, f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
|
), f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
|
||||||
|
assert (
|
||||||
|
edge_degrees[(node2_id, node3_id)] == 5
|
||||||
|
), f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
|
||||||
|
assert (
|
||||||
|
edge_degrees[(node3_id, node4_id)] == 5
|
||||||
|
), f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
|
||||||
|
|
||||||
# 5. 测试 get_edges_batch - 批量获取多个边的属性
|
# 5. 测试 get_edges_batch - 批量获取多个边的属性
|
||||||
print("== 测试 get_edges_batch")
|
print("== 测试 get_edges_batch")
|
||||||
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
|
|||||||
edges_dict = await storage.get_edges_batch(edge_dicts)
|
edges_dict = await storage.get_edges_batch(edge_dicts)
|
||||||
print(f"批量获取边属性结果: {edges_dict.keys()}")
|
print(f"批量获取边属性结果: {edges_dict.keys()}")
|
||||||
assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
|
assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
|
||||||
assert (node1_id, node2_id) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
assert (
|
||||||
assert (node2_id, node3_id) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
node1_id,
|
||||||
assert (node3_id, node4_id) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
node2_id,
|
||||||
assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"边 {node1_id} -> {node2_id} 关系不匹配"
|
) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
||||||
assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"边 {node2_id} -> {node3_id} 关系不匹配"
|
assert (
|
||||||
assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"边 {node3_id} -> {node4_id} 关系不匹配"
|
node2_id,
|
||||||
|
node3_id,
|
||||||
|
) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
||||||
|
assert (
|
||||||
|
node3_id,
|
||||||
|
node4_id,
|
||||||
|
) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
||||||
|
assert (
|
||||||
|
edges_dict[(node1_id, node2_id)]["relationship"]
|
||||||
|
== edge1_data["relationship"]
|
||||||
|
), f"边 {node1_id} -> {node2_id} 关系不匹配"
|
||||||
|
assert (
|
||||||
|
edges_dict[(node2_id, node3_id)]["relationship"]
|
||||||
|
== edge2_data["relationship"]
|
||||||
|
), f"边 {node2_id} -> {node3_id} 关系不匹配"
|
||||||
|
assert (
|
||||||
|
edges_dict[(node3_id, node4_id)]["relationship"]
|
||||||
|
== edge5_data["relationship"]
|
||||||
|
), f"边 {node3_id} -> {node4_id} 关系不匹配"
|
||||||
|
|
||||||
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
|
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
|
||||||
print("== 测试 get_nodes_edges_batch")
|
print("== 测试 get_nodes_edges_batch")
|
||||||
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
|
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
|
||||||
print(f"批量获取节点边结果: {nodes_edges.keys()}")
|
print(f"批量获取节点边结果: {nodes_edges.keys()}")
|
||||||
assert len(nodes_edges) == 2, f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
|
assert (
|
||||||
|
len(nodes_edges) == 2
|
||||||
|
), f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
|
||||||
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
|
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
|
||||||
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
|
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
|
||||||
assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
|
assert (
|
||||||
assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
|
len(nodes_edges[node1_id]) == 3
|
||||||
|
), f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
|
||||||
|
assert (
|
||||||
|
len(nodes_edges[node3_id]) == 3
|
||||||
|
), f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
|
||||||
|
|
||||||
# 7. 清理数据
|
# 7. 清理数据
|
||||||
print("== 测试 drop")
|
print("== 测试 drop")
|
||||||
result = await storage.drop()
|
result = await storage.drop()
|
||||||
print(f"清理结果: {result}")
|
print(f"清理结果: {result}")
|
||||||
assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
|
assert (
|
||||||
|
result["status"] == "success"
|
||||||
|
), f"清理应成功,实际状态为 {result['status']}"
|
||||||
|
|
||||||
print("\n批量操作测试完成")
|
print("\n批量操作测试完成")
|
||||||
return True
|
return True
|
||||||
|
Reference in New Issue
Block a user