Fix linting
This commit is contained in:
@@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""Get nodes as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches nodes one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""Node degrees as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches node degrees one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[node_id] = degree
|
||||
return result
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
async def edge_degrees_batch(
|
||||
self, edge_pairs: list[tuple[str, str]]
|
||||
) -> dict[tuple[str, str], int]:
|
||||
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
||||
|
||||
|
||||
Default implementation calculates edge degrees one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[(src_id, tgt_id)] = degree
|
||||
return result
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
async def get_edges_batch(
|
||||
self, pairs: list[dict[str, str]]
|
||||
) -> dict[tuple[str, str], dict]:
|
||||
"""Get edges as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches edges one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[(src_id, tgt_id)] = edge
|
||||
return result
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""Get nodes edges as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches node edges one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
|
@@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Retrieve multiple nodes in one query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node entity IDs to fetch.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its node data (or None if not found).
|
||||
"""
|
||||
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
node_dict = dict(node)
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
node_dict["labels"] = [
|
||||
label for label in node_dict["labels"] if label != "base"
|
||||
]
|
||||
nodes[entity_id] = node_dict
|
||||
await result.consume() # Make sure to consume the result fully
|
||||
return nodes
|
||||
@@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node labels (entity_id values) to look up.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
If a node is not found, its degree will be set to 0.
|
||||
"""
|
||||
async with self._driver.session(
|
||||
@@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_id = record["entity_id"]
|
||||
degrees[entity_id] = record["degree"]
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
|
||||
# For any node_id that did not return a record, set degree to 0.
|
||||
for nid in node_ids:
|
||||
if nid not in degrees:
|
||||
logger.warning(f"No node found with label '{nid}'")
|
||||
degrees[nid] = 0
|
||||
|
||||
|
||||
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
||||
return degrees
|
||||
|
||||
@@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
return degrees
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
|
||||
async def edge_degrees_batch(
|
||||
self, edge_pairs: list[tuple[str, str]]
|
||||
) -> dict[tuple[str, str], int]:
|
||||
"""
|
||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||
in batch using the already implemented node_degrees_batch.
|
||||
|
||||
|
||||
Args:
|
||||
edge_pairs: List of (src, tgt) tuples.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
||||
"""
|
||||
# Collect unique node IDs from all edge pairs.
|
||||
unique_node_ids = {src for src, _ in edge_pairs}
|
||||
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
||||
|
||||
|
||||
# Get degrees for all nodes in one go.
|
||||
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
||||
|
||||
|
||||
# Sum up degrees for each edge pair.
|
||||
edge_degrees = {}
|
||||
for src, tgt in edge_pairs:
|
||||
@@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
async def get_edges_batch(
|
||||
self, pairs: list[dict[str, str]]
|
||||
) -> dict[tuple[str, str], dict]:
|
||||
"""
|
||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||
|
||||
|
||||
Args:
|
||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||
"""
|
||||
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if edges and len(edges) > 0:
|
||||
edge_props = edges[0] # choose the first if multiple exist
|
||||
# Ensure required keys exist with defaults
|
||||
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
|
||||
for key, default in {
|
||||
"weight": 0.0,
|
||||
"source_id": None,
|
||||
"description": None,
|
||||
"keywords": None,
|
||||
}.items():
|
||||
if key not in edge_props:
|
||||
edge_props[key] = default
|
||||
edges_dict[(src, tgt)] = edge_props
|
||||
else:
|
||||
# No edge found – set default edge properties
|
||||
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
|
||||
edges_dict[(src, tgt)] = {
|
||||
"weight": 0.0,
|
||||
"source_id": None,
|
||||
"description": None,
|
||||
"keywords": None,
|
||||
}
|
||||
await result.consume()
|
||||
return edges_dict
|
||||
|
||||
@@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
|
@@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Retrieve multiple nodes in one query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node entity IDs to fetch.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its node data (or None if not found).
|
||||
"""
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
||||
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
RETURN node_id, n
|
||||
$$) AS (node_id text, n agtype)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids
|
||||
)
|
||||
|
||||
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# Build result dictionary
|
||||
nodes_dict = {}
|
||||
for result in results:
|
||||
@@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
node_dict = result["n"]["properties"]
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
node_dict["labels"] = [
|
||||
label for label in node_dict["labels"] if label != "base"
|
||||
]
|
||||
nodes_dict[result["node_id"]] = node_dict
|
||||
|
||||
|
||||
return nodes_dict
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node labels (entity_id values) to look up.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
If a node is not found, its degree will be set to 0.
|
||||
"""
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
||||
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
@@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
RETURN node_id, count(r) AS degree
|
||||
$$) AS (node_id text, degree bigint)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# Build result dictionary
|
||||
degrees_dict = {}
|
||||
for result in results:
|
||||
if result["node_id"] is not None:
|
||||
degrees_dict[result["node_id"]] = int(result["degree"])
|
||||
|
||||
|
||||
# Ensure all requested node_ids are in the result dictionary
|
||||
for node_id in node_ids:
|
||||
if node_id not in degrees_dict:
|
||||
degrees_dict[node_id] = 0
|
||||
|
||||
|
||||
return degrees_dict
|
||||
|
||||
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
|
||||
async def edge_degrees_batch(
|
||||
self, edges: list[tuple[str, str]]
|
||||
) -> dict[tuple[str, str], int]:
|
||||
"""
|
||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||
in batch using the already implemented node_degrees_batch.
|
||||
|
||||
|
||||
Args:
|
||||
edges: List of (source_node_id, target_node_id) tuples
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping edge tuples to their combined degrees
|
||||
"""
|
||||
if not edges:
|
||||
return {}
|
||||
|
||||
|
||||
# Use node_degrees_batch to get all node degrees efficiently
|
||||
all_nodes = set()
|
||||
for src, tgt in edges:
|
||||
all_nodes.add(src)
|
||||
all_nodes.add(tgt)
|
||||
|
||||
|
||||
node_degrees = await self.node_degrees_batch(list(all_nodes))
|
||||
|
||||
|
||||
# Calculate edge degrees
|
||||
edge_degrees_dict = {}
|
||||
for src, tgt in edges:
|
||||
src_degree = node_degrees.get(src, 0)
|
||||
tgt_degree = node_degrees.get(tgt, 0)
|
||||
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
|
||||
|
||||
|
||||
return edge_degrees_dict
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
|
||||
async def get_edges_batch(
|
||||
self, pairs: list[dict[str, str]]
|
||||
) -> dict[tuple[str, str], dict]:
|
||||
"""
|
||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||
|
||||
|
||||
Args:
|
||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||
"""
|
||||
if not pairs:
|
||||
return {}
|
||||
|
||||
|
||||
# 从字典列表中提取源节点和目标节点ID
|
||||
src_nodes = []
|
||||
tgt_nodes = []
|
||||
for pair in pairs:
|
||||
src_nodes.append(pair["src"].replace('"', ''))
|
||||
tgt_nodes.append(pair["tgt"].replace('"', ''))
|
||||
|
||||
src_nodes.append(pair["src"].replace('"', ""))
|
||||
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
||||
|
||||
# 构建查询,使用数组索引来匹配源节点和目标节点
|
||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
||||
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
||||
|
||||
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
||||
UNWIND range(0, size(sources)-1) AS i
|
||||
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
||||
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
||||
$$) AS (source text, target text, edge_properties agtype)"""
|
||||
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# 构建结果字典
|
||||
edges_dict = {}
|
||||
for result in results:
|
||||
if result["source"] and result["target"] and result["edge_properties"]:
|
||||
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
||||
|
||||
edges_dict[(result["source"], result["target"])] = result[
|
||||
"edge_properties"
|
||||
]
|
||||
|
||||
return edges_dict
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Get all edges for multiple nodes in a single batch operation.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs to get edges for
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
||||
"""
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
||||
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
@@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
RETURN node_id, connected.entity_id AS connected_id
|
||||
$$) AS (node_id text, connected_id text)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# Build result dictionary
|
||||
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
||||
for result in results:
|
||||
@@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
nodes_edges_dict[result["node_id"]].append(
|
||||
(result["node_id"], result["connected_id"])
|
||||
)
|
||||
|
||||
|
||||
return nodes_edges_dict
|
||||
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all labels (node IDs) in the graph.
|
||||
|
@@ -1323,14 +1323,14 @@ async def _get_node_data(
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids)
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids),
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
@@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities(
|
||||
for dp in node_datas
|
||||
if dp["source_id"] is not None
|
||||
]
|
||||
|
||||
|
||||
node_names = [dp["entity_name"] for dp in node_datas]
|
||||
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
|
||||
# Build the edges list in the same order as node_datas.
|
||||
@@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_one_hop_nodes.update([e[1] for e in this_edges])
|
||||
|
||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||
|
||||
|
||||
# Batch retrieve one-hop node data using get_nodes_batch
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
|
||||
all_one_hop_nodes
|
||||
)
|
||||
all_one_hop_nodes_data = [
|
||||
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
|
||||
]
|
||||
|
||||
# Add null check for node data
|
||||
all_one_hop_text_units_lookup = {
|
||||
@@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities(
|
||||
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
|
||||
# For edge degrees, use tuples.
|
||||
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
|
||||
|
||||
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
||||
)
|
||||
|
||||
|
||||
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
||||
all_edges_data = []
|
||||
for pair in all_edges:
|
||||
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
|
||||
}
|
||||
all_edges_data.append(combined)
|
||||
|
||||
|
||||
all_edges_data = sorted(
|
||||
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
||||
)
|
||||
|
||||
# Reconstruct edge_datas list in the same order as results.
|
||||
@@ -1652,7 +1655,7 @@ async def _get_edge_data(
|
||||
**edge_props,
|
||||
}
|
||||
edge_datas.append(combined)
|
||||
|
||||
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
|
||||
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(entity_names),
|
||||
knowledge_graph_inst.node_degrees_batch(entity_names)
|
||||
knowledge_graph_inst.node_degrees_batch(entity_names),
|
||||
)
|
||||
|
||||
# Rebuild the list in the same order as entity_names
|
||||
|
@@ -136,7 +136,7 @@ interface GraphState {
|
||||
// Version counter to trigger data refresh
|
||||
graphDataVersion: number
|
||||
incrementGraphDataVersion: () => void
|
||||
|
||||
|
||||
// Methods for updating graph elements and UI state together
|
||||
updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void>
|
||||
updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void>
|
||||
@@ -252,40 +252,40 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
// Get current state
|
||||
const state = get()
|
||||
const { sigmaGraph, rawGraph } = state
|
||||
|
||||
|
||||
// Validate graph state
|
||||
if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId)
|
||||
|
||||
|
||||
console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue)
|
||||
|
||||
|
||||
// For entity_id changes (node renaming) with NetworkX graph storage
|
||||
if ((nodeId === entityId) && (propertyName === 'entity_id')) {
|
||||
// Create new node with updated ID but same attributes
|
||||
sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue })
|
||||
|
||||
|
||||
const edgesToUpdate: EdgeToUpdate[] = []
|
||||
|
||||
|
||||
// Process all edges connected to this node
|
||||
sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => {
|
||||
const otherNode = source === nodeId ? target : source
|
||||
const isOutgoing = source === nodeId
|
||||
|
||||
|
||||
// Get original edge dynamic ID for later reference
|
||||
const originalEdgeDynamicId = edge
|
||||
const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId]
|
||||
|
||||
|
||||
// Create new edge with updated node reference
|
||||
const newEdgeId = sigmaGraph.addEdge(
|
||||
isOutgoing ? newValue : otherNode,
|
||||
isOutgoing ? otherNode : newValue,
|
||||
attributes
|
||||
)
|
||||
|
||||
|
||||
// Track edges that need updating in the raw graph
|
||||
if (edgeIndexInRawGraph !== undefined) {
|
||||
edgesToUpdate.push({
|
||||
@@ -294,14 +294,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
edgeIndex: edgeIndexInRawGraph
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// Remove the old edge
|
||||
sigmaGraph.dropEdge(edge)
|
||||
})
|
||||
|
||||
|
||||
// Remove the old node after all edges are processed
|
||||
sigmaGraph.dropNode(nodeId)
|
||||
|
||||
|
||||
// Update node reference in raw graph data
|
||||
const nodeIndex = rawGraph.nodeIdMap[nodeId]
|
||||
if (nodeIndex !== undefined) {
|
||||
@@ -311,7 +311,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
delete rawGraph.nodeIdMap[nodeId]
|
||||
rawGraph.nodeIdMap[newValue] = nodeIndex
|
||||
}
|
||||
|
||||
|
||||
// Update all edge references in raw graph data
|
||||
edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => {
|
||||
if (rawGraph.edges[edgeIndex]) {
|
||||
@@ -322,14 +322,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
if (rawGraph.edges[edgeIndex].target === nodeId) {
|
||||
rawGraph.edges[edgeIndex].target = newValue
|
||||
}
|
||||
|
||||
|
||||
// Update dynamic ID mappings
|
||||
rawGraph.edges[edgeIndex].dynamicId = newEdgeId
|
||||
delete rawGraph.edgeDynamicIdMap[originalDynamicId]
|
||||
rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
// Update selected node in store
|
||||
set({ selectedNode: newValue, moveToSelectedNode: true })
|
||||
} else {
|
||||
@@ -342,7 +342,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Trigger a re-render by incrementing the version counter
|
||||
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
|
||||
}
|
||||
@@ -351,17 +351,17 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
throw new Error('Failed to update node in graph')
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => {
|
||||
// Get current state
|
||||
const state = get()
|
||||
const { sigmaGraph, rawGraph } = state
|
||||
|
||||
|
||||
// Validate graph state
|
||||
if (!sigmaGraph || !rawGraph) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
const edgeIndex = rawGraph.edgeIdMap[String(edgeId)]
|
||||
if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) {
|
||||
@@ -370,10 +370,10 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
||||
sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Trigger a re-render by incrementing the version counter
|
||||
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
|
||||
|
||||
|
||||
// Update selected edge in store to ensure UI reflects changes
|
||||
set({ selectedEdge: dynamicId })
|
||||
} catch (error) {
|
||||
|
@@ -3,7 +3,7 @@ import { useGraphStore } from '@/stores/graph'
|
||||
/**
|
||||
* Update node in the graph visualization
|
||||
* This function is now a wrapper around the store's updateNodeAndSelect method
|
||||
*
|
||||
*
|
||||
* @param nodeId - ID of the node to update
|
||||
* @param entityId - ID of the entity
|
||||
* @param propertyName - Name of the property being updated
|
||||
|
@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
|
||||
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
|
||||
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
|
||||
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
|
||||
assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配"
|
||||
assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配"
|
||||
assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配"
|
||||
assert (
|
||||
nodes_dict[node1_id]["description"] == node1_data["description"]
|
||||
), f"{node1_id} 描述不匹配"
|
||||
assert (
|
||||
nodes_dict[node2_id]["description"] == node2_data["description"]
|
||||
), f"{node2_id} 描述不匹配"
|
||||
assert (
|
||||
nodes_dict[node3_id]["description"] == node3_data["description"]
|
||||
), f"{node3_id} 描述不匹配"
|
||||
|
||||
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
|
||||
print("== 测试 node_degrees_batch")
|
||||
node_degrees = await storage.node_degrees_batch(node_ids)
|
||||
print(f"批量获取节点度数结果: {node_degrees}")
|
||||
assert len(node_degrees) == 3, f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
|
||||
assert (
|
||||
len(node_degrees) == 3
|
||||
), f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
|
||||
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
|
||||
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
|
||||
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
|
||||
assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
|
||||
assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
|
||||
assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
|
||||
assert (
|
||||
node_degrees[node1_id] == 3
|
||||
), f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
|
||||
assert (
|
||||
node_degrees[node2_id] == 2
|
||||
), f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
|
||||
assert (
|
||||
node_degrees[node3_id] == 3
|
||||
), f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
|
||||
|
||||
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
|
||||
print("== 测试 edge_degrees_batch")
|
||||
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
|
||||
edge_degrees = await storage.edge_degrees_batch(edges)
|
||||
print(f"批量获取边度数结果: {edge_degrees}")
|
||||
assert len(edge_degrees) == 3, f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
|
||||
assert (node1_id, node2_id) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
||||
assert (node2_id, node3_id) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
||||
assert (node3_id, node4_id) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
||||
assert (
|
||||
len(edge_degrees) == 3
|
||||
), f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
|
||||
assert (
|
||||
node1_id,
|
||||
node2_id,
|
||||
) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
||||
assert (
|
||||
node2_id,
|
||||
node3_id,
|
||||
) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
||||
assert (
|
||||
node3_id,
|
||||
node4_id,
|
||||
) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
||||
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
|
||||
assert edge_degrees[(node1_id, node2_id)] == 5, f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
|
||||
assert edge_degrees[(node2_id, node3_id)] == 5, f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
|
||||
assert edge_degrees[(node3_id, node4_id)] == 5, f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
|
||||
assert (
|
||||
edge_degrees[(node1_id, node2_id)] == 5
|
||||
), f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
|
||||
assert (
|
||||
edge_degrees[(node2_id, node3_id)] == 5
|
||||
), f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
|
||||
assert (
|
||||
edge_degrees[(node3_id, node4_id)] == 5
|
||||
), f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
|
||||
|
||||
# 5. 测试 get_edges_batch - 批量获取多个边的属性
|
||||
print("== 测试 get_edges_batch")
|
||||
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
|
||||
edges_dict = await storage.get_edges_batch(edge_dicts)
|
||||
print(f"批量获取边属性结果: {edges_dict.keys()}")
|
||||
assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
|
||||
assert (node1_id, node2_id) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
||||
assert (node2_id, node3_id) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
||||
assert (node3_id, node4_id) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
||||
assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"边 {node1_id} -> {node2_id} 关系不匹配"
|
||||
assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"边 {node2_id} -> {node3_id} 关系不匹配"
|
||||
assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"边 {node3_id} -> {node4_id} 关系不匹配"
|
||||
assert (
|
||||
node1_id,
|
||||
node2_id,
|
||||
) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
||||
assert (
|
||||
node2_id,
|
||||
node3_id,
|
||||
) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
||||
assert (
|
||||
node3_id,
|
||||
node4_id,
|
||||
) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
||||
assert (
|
||||
edges_dict[(node1_id, node2_id)]["relationship"]
|
||||
== edge1_data["relationship"]
|
||||
), f"边 {node1_id} -> {node2_id} 关系不匹配"
|
||||
assert (
|
||||
edges_dict[(node2_id, node3_id)]["relationship"]
|
||||
== edge2_data["relationship"]
|
||||
), f"边 {node2_id} -> {node3_id} 关系不匹配"
|
||||
assert (
|
||||
edges_dict[(node3_id, node4_id)]["relationship"]
|
||||
== edge5_data["relationship"]
|
||||
), f"边 {node3_id} -> {node4_id} 关系不匹配"
|
||||
|
||||
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
|
||||
print("== 测试 get_nodes_edges_batch")
|
||||
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
|
||||
print(f"批量获取节点边结果: {nodes_edges.keys()}")
|
||||
assert len(nodes_edges) == 2, f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
|
||||
assert (
|
||||
len(nodes_edges) == 2
|
||||
), f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
|
||||
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
|
||||
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
|
||||
assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
|
||||
assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
|
||||
assert (
|
||||
len(nodes_edges[node1_id]) == 3
|
||||
), f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
|
||||
assert (
|
||||
len(nodes_edges[node3_id]) == 3
|
||||
), f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
|
||||
|
||||
# 7. 清理数据
|
||||
print("== 测试 drop")
|
||||
result = await storage.drop()
|
||||
print(f"清理结果: {result}")
|
||||
assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
|
||||
assert (
|
||||
result["status"] == "success"
|
||||
), f"清理应成功,实际状态为 {result['status']}"
|
||||
|
||||
print("\n批量操作测试完成")
|
||||
return True
|
||||
@@ -630,7 +687,7 @@ async def main():
|
||||
if basic_result:
|
||||
ASCIIColors.cyan("\n=== 开始高级测试 ===")
|
||||
advanced_result = await test_graph_advanced(storage)
|
||||
|
||||
|
||||
if advanced_result:
|
||||
ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
|
||||
await test_graph_batch_operations(storage)
|
||||
|
Reference in New Issue
Block a user