Fix linting

This commit is contained in:
yangdx
2025-04-15 12:34:04 +08:00
parent 22b03ee1bb
commit 1de74c9228
7 changed files with 249 additions and 150 deletions

View File

@@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND """Get nodes as a batch using UNWIND
Default implementation fetches nodes one by one. Default implementation fetches nodes one by one.
Override this method for better performance in storage backends Override this method for better performance in storage backends
that support batch operations. that support batch operations.
@@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Node degrees as a batch using UNWIND """Node degrees as a batch using UNWIND
Default implementation fetches node degrees one by one. Default implementation fetches node degrees one by one.
Override this method for better performance in storage backends Override this method for better performance in storage backends
that support batch operations. that support batch operations.
@@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = degree result[node_id] = degree
return result return result
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]: async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch """Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one. Default implementation calculates edge degrees one by one.
Override this method for better performance in storage backends Override this method for better performance in storage backends
that support batch operations. that support batch operations.
@@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = degree result[(src_id, tgt_id)] = degree
return result return result
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND """Get edges as a batch using UNWIND
Default implementation fetches edges one by one. Default implementation fetches edges one by one.
Override this method for better performance in storage backends Override this method for better performance in storage backends
that support batch operations. that support batch operations.
@@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = edge result[(src_id, tgt_id)] = edge
return result return result
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND """Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one. Default implementation fetches node edges one by one.
Override this method for better performance in storage backends Override this method for better performance in storage backends
that support batch operations. that support batch operations.

View File

@@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage):
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
Args: Args:
node_ids: List of node entity IDs to fetch. node_ids: List of node entity IDs to fetch.
Returns: Returns:
A dictionary mapping each node_id to its node data (or None if not found). A dictionary mapping each node_id to its node data (or None if not found).
""" """
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
node_dict = dict(node) node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property # Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict: if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes[entity_id] = node_dict nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully await result.consume() # Make sure to consume the result fully
return nodes return nodes
@@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage):
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
Args: Args:
node_ids: List of node labels (entity_id values) to look up. node_ids: List of node labels (entity_id values) to look up.
Returns: Returns:
A dictionary mapping each node_id to its degree (number of relationships). A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0. If a node is not found, its degree will be set to 0.
""" """
async with self._driver.session( async with self._driver.session(
@@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage):
entity_id = record["entity_id"] entity_id = record["entity_id"]
degrees[entity_id] = record["degree"] degrees[entity_id] = record["degree"]
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
# For any node_id that did not return a record, set degree to 0. # For any node_id that did not return a record, set degree to 0.
for nid in node_ids: for nid in node_ids:
if nid not in degrees: if nid not in degrees:
logger.warning(f"No node found with label '{nid}'") logger.warning(f"No node found with label '{nid}'")
degrees[nid] = 0 degrees[nid] = 0
logger.debug(f"Neo4j batch node degree query returned: {degrees}") logger.debug(f"Neo4j batch node degree query returned: {degrees}")
return degrees return degrees
@@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]: async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
""" """
Calculate the combined degree for each edge (sum of the source and target node degrees) Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch. in batch using the already implemented node_degrees_batch.
Args: Args:
edge_pairs: List of (src, tgt) tuples. edge_pairs: List of (src, tgt) tuples.
Returns: Returns:
A dictionary mapping each (src, tgt) tuple to the sum of their degrees. A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
""" """
# Collect unique node IDs from all edge pairs. # Collect unique node IDs from all edge pairs.
unique_node_ids = {src for src, _ in edge_pairs} unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs}) unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes in one go. # Get degrees for all nodes in one go.
degrees = await self.node_degrees_batch(list(unique_node_ids)) degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair. # Sum up degrees for each edge pair.
edge_degrees = {} edge_degrees = {}
for src, tgt in edge_pairs: for src, tgt in edge_pairs:
@@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args: Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns: Returns:
A dictionary mapping (src, tgt) tuples to their edge properties. A dictionary mapping (src, tgt) tuples to their edge properties.
""" """
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
if edges and len(edges) > 0: if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults # Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items(): for key, default in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_props: if key not in edge_props:
edge_props[key] = default edge_props[key] = default
edges_dict[(src, tgt)] = edge_props edges_dict[(src, tgt)] = edge_props
else: else:
# No edge found set default edge properties # No edge found set default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None} edges_dict[(src, tgt)] = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
await result.consume() await result.consume()
return edges_dict return edges_dict
@@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
""" """
Batch retrieve edges for multiple nodes in one query using UNWIND. Batch retrieve edges for multiple nodes in one query using UNWIND.
Args: Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges. node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns: Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target). A dictionary mapping each node ID to its list of edge tuples (source, target).
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """ query = """
UNWIND $node_ids AS id UNWIND $node_ids AS id
MATCH (n:base {entity_id: id}) MATCH (n:base {entity_id: id})

View File

@@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage):
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
Args: Args:
node_ids: List of node entity IDs to fetch. node_ids: List of node entity IDs to fetch.
Returns: Returns:
A dictionary mapping each node_id to its node data (or None if not found). A dictionary mapping each node_id to its node data (or None if not found).
""" """
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id}) MATCH (n:base {entity_id: node_id})
RETURN node_id, n RETURN node_id, n
$$) AS (node_id text, n agtype)""" % ( $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
self.graph_name,
formatted_ids
)
results = await self._query(query) results = await self._query(query)
# Build result dictionary # Build result dictionary
nodes_dict = {} nodes_dict = {}
for result in results: for result in results:
@@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = result["n"]["properties"] node_dict = result["n"]["properties"]
# Remove the 'base' label if present in a 'labels' property # Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict: if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes_dict[result["node_id"]] = node_dict nodes_dict[result["node_id"]] = node_dict
return nodes_dict return nodes_dict
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
Args: Args:
node_ids: List of node labels (entity_id values) to look up. node_ids: List of node labels (entity_id values) to look up.
Returns: Returns:
A dictionary mapping each node_id to its degree (number of relationships). A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0. If a node is not found, its degree will be set to 0.
""" """
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id}) MATCH (n:base {entity_id: node_id})
@@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, count(r) AS degree RETURN node_id, count(r) AS degree
$$) AS (node_id text, degree bigint)""" % ( $$) AS (node_id text, degree bigint)""" % (
self.graph_name, self.graph_name,
formatted_ids formatted_ids,
) )
results = await self._query(query) results = await self._query(query)
# Build result dictionary # Build result dictionary
degrees_dict = {} degrees_dict = {}
for result in results: for result in results:
if result["node_id"] is not None: if result["node_id"] is not None:
degrees_dict[result["node_id"]] = int(result["degree"]) degrees_dict[result["node_id"]] = int(result["degree"])
# Ensure all requested node_ids are in the result dictionary # Ensure all requested node_ids are in the result dictionary
for node_id in node_ids: for node_id in node_ids:
if node_id not in degrees_dict: if node_id not in degrees_dict:
degrees_dict[node_id] = 0 degrees_dict[node_id] = 0
return degrees_dict return degrees_dict
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]: async def edge_degrees_batch(
self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
""" """
Calculate the combined degree for each edge (sum of the source and target node degrees) Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch. in batch using the already implemented node_degrees_batch.
Args: Args:
edges: List of (source_node_id, target_node_id) tuples edges: List of (source_node_id, target_node_id) tuples
Returns: Returns:
Dictionary mapping edge tuples to their combined degrees Dictionary mapping edge tuples to their combined degrees
""" """
if not edges: if not edges:
return {} return {}
# Use node_degrees_batch to get all node degrees efficiently # Use node_degrees_batch to get all node degrees efficiently
all_nodes = set() all_nodes = set()
for src, tgt in edges: for src, tgt in edges:
all_nodes.add(src) all_nodes.add(src)
all_nodes.add(tgt) all_nodes.add(tgt)
node_degrees = await self.node_degrees_batch(list(all_nodes)) node_degrees = await self.node_degrees_batch(list(all_nodes))
# Calculate edge degrees # Calculate edge degrees
edge_degrees_dict = {} edge_degrees_dict = {}
for src, tgt in edges: for src, tgt in edges:
src_degree = node_degrees.get(src, 0) src_degree = node_degrees.get(src, 0)
tgt_degree = node_degrees.get(tgt, 0) tgt_degree = node_degrees.get(tgt, 0)
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
return edge_degrees_dict return edge_degrees_dict
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args: Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns: Returns:
A dictionary mapping (src, tgt) tuples to their edge properties. A dictionary mapping (src, tgt) tuples to their edge properties.
""" """
if not pairs: if not pairs:
return {} return {}
# 从字典列表中提取源节点和目标节点ID # 从字典列表中提取源节点和目标节点ID
src_nodes = [] src_nodes = []
tgt_nodes = [] tgt_nodes = []
for pair in pairs: for pair in pairs:
src_nodes.append(pair["src"].replace('"', '')) src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', '')) tgt_nodes.append(pair["tgt"].replace('"', ""))
# 构建查询,使用数组索引来匹配源节点和目标节点 # 构建查询,使用数组索引来匹配源节点和目标节点
src_array = ", ".join([f'"{src}"' for src in src_nodes]) src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}}) MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)""" $$) AS (source text, target text, edge_properties agtype)"""
results = await self._query(query) results = await self._query(query)
# 构建结果字典 # 构建结果字典
edges_dict = {} edges_dict = {}
for result in results: for result in results:
if result["source"] and result["target"] and result["edge_properties"]: if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result["edge_properties"] edges_dict[(result["source"], result["target"])] = result[
"edge_properties"
]
return edges_dict return edges_dict
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
""" """
Get all edges for multiple nodes in a single batch operation. Get all edges for multiple nodes in a single batch operation.
Args: Args:
node_ids: List of node IDs to get edges for node_ids: List of node IDs to get edges for
Returns: Returns:
Dictionary mapping node IDs to lists of (source, target) edge tuples Dictionary mapping node IDs to lists of (source, target) edge tuples
""" """
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id}) MATCH (n:base {entity_id: node_id})
@@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, connected.entity_id AS connected_id RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % ( $$) AS (node_id text, connected_id text)""" % (
self.graph_name, self.graph_name,
formatted_ids formatted_ids,
) )
results = await self._query(query) results = await self._query(query)
# Build result dictionary # Build result dictionary
nodes_edges_dict = {node_id: [] for node_id in node_ids} nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in results: for result in results:
@@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage):
nodes_edges_dict[result["node_id"]].append( nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"]) (result["node_id"], result["connected_id"])
) )
return nodes_edges_dict return nodes_edges_dict
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """
Get all labels (node IDs) in the graph. Get all labels (node IDs) in the graph.

View File

@@ -1323,14 +1323,14 @@ async def _get_node_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
# Call the batch node retrieval and degree functions concurrently. # Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids), knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids) knowledge_graph_inst.node_degrees_batch(node_ids),
) )
# Now, if you need the node data and degree in order: # Now, if you need the node data and degree in order:
@@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities(
for dp in node_datas for dp in node_datas
if dp["source_id"] is not None if dp["source_id"] is not None
] ]
node_names = [dp["entity_name"] for dp in node_datas] node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
# Build the edges list in the same order as node_datas. # Build the edges list in the same order as node_datas.
@@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes.update([e[1] for e in this_edges]) all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes) all_one_hop_nodes = list(all_one_hop_nodes)
# Batch retrieve one-hop node data using get_nodes_batch # Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes) all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes] all_one_hop_nodes
)
all_one_hop_nodes_data = [
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
]
# Add null check for node data # Add null check for node data
all_one_hop_text_units_lookup = { all_one_hop_text_units_lookup = {
@@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities(
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
# For edge degrees, use tuples. # For edge degrees, use tuples.
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
# Call the batched functions concurrently. # Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather( edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples) knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
) )
# Reconstruct edge_datas list in the same order as the deduplicated results. # Reconstruct edge_datas list in the same order as the deduplicated results.
all_edges_data = [] all_edges_data = []
for pair in all_edges: for pair in all_edges:
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
} }
all_edges_data.append(combined) all_edges_data.append(combined)
all_edges_data = sorted( all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
# Call the batched functions concurrently. # Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather( edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples) knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
) )
# Reconstruct edge_datas list in the same order as results. # Reconstruct edge_datas list in the same order as results.
@@ -1652,7 +1655,7 @@ async def _get_edge_data(
**edge_props, **edge_props,
} }
edge_datas.append(combined) edge_datas.append(combined)
edge_datas = sorted( edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
# Batch approach: Retrieve nodes and their degrees concurrently with one query each. # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
nodes_dict, degrees_dict = await asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(entity_names), knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.node_degrees_batch(entity_names) knowledge_graph_inst.node_degrees_batch(entity_names),
) )
# Rebuild the list in the same order as entity_names # Rebuild the list in the same order as entity_names

View File

@@ -136,7 +136,7 @@ interface GraphState {
// Version counter to trigger data refresh // Version counter to trigger data refresh
graphDataVersion: number graphDataVersion: number
incrementGraphDataVersion: () => void incrementGraphDataVersion: () => void
// Methods for updating graph elements and UI state together // Methods for updating graph elements and UI state together
updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void> updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void>
updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void> updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void>
@@ -252,40 +252,40 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
// Get current state // Get current state
const state = get() const state = get()
const { sigmaGraph, rawGraph } = state const { sigmaGraph, rawGraph } = state
// Validate graph state // Validate graph state
if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) { if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) {
return return
} }
try { try {
const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId) const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId)
console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue) console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue)
// For entity_id changes (node renaming) with NetworkX graph storage // For entity_id changes (node renaming) with NetworkX graph storage
if ((nodeId === entityId) && (propertyName === 'entity_id')) { if ((nodeId === entityId) && (propertyName === 'entity_id')) {
// Create new node with updated ID but same attributes // Create new node with updated ID but same attributes
sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue }) sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue })
const edgesToUpdate: EdgeToUpdate[] = [] const edgesToUpdate: EdgeToUpdate[] = []
// Process all edges connected to this node // Process all edges connected to this node
sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => { sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => {
const otherNode = source === nodeId ? target : source const otherNode = source === nodeId ? target : source
const isOutgoing = source === nodeId const isOutgoing = source === nodeId
// Get original edge dynamic ID for later reference // Get original edge dynamic ID for later reference
const originalEdgeDynamicId = edge const originalEdgeDynamicId = edge
const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId] const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId]
// Create new edge with updated node reference // Create new edge with updated node reference
const newEdgeId = sigmaGraph.addEdge( const newEdgeId = sigmaGraph.addEdge(
isOutgoing ? newValue : otherNode, isOutgoing ? newValue : otherNode,
isOutgoing ? otherNode : newValue, isOutgoing ? otherNode : newValue,
attributes attributes
) )
// Track edges that need updating in the raw graph // Track edges that need updating in the raw graph
if (edgeIndexInRawGraph !== undefined) { if (edgeIndexInRawGraph !== undefined) {
edgesToUpdate.push({ edgesToUpdate.push({
@@ -294,14 +294,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
edgeIndex: edgeIndexInRawGraph edgeIndex: edgeIndexInRawGraph
}) })
} }
// Remove the old edge // Remove the old edge
sigmaGraph.dropEdge(edge) sigmaGraph.dropEdge(edge)
}) })
// Remove the old node after all edges are processed // Remove the old node after all edges are processed
sigmaGraph.dropNode(nodeId) sigmaGraph.dropNode(nodeId)
// Update node reference in raw graph data // Update node reference in raw graph data
const nodeIndex = rawGraph.nodeIdMap[nodeId] const nodeIndex = rawGraph.nodeIdMap[nodeId]
if (nodeIndex !== undefined) { if (nodeIndex !== undefined) {
@@ -311,7 +311,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
delete rawGraph.nodeIdMap[nodeId] delete rawGraph.nodeIdMap[nodeId]
rawGraph.nodeIdMap[newValue] = nodeIndex rawGraph.nodeIdMap[newValue] = nodeIndex
} }
// Update all edge references in raw graph data // Update all edge references in raw graph data
edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => { edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => {
if (rawGraph.edges[edgeIndex]) { if (rawGraph.edges[edgeIndex]) {
@@ -322,14 +322,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
if (rawGraph.edges[edgeIndex].target === nodeId) { if (rawGraph.edges[edgeIndex].target === nodeId) {
rawGraph.edges[edgeIndex].target = newValue rawGraph.edges[edgeIndex].target = newValue
} }
// Update dynamic ID mappings // Update dynamic ID mappings
rawGraph.edges[edgeIndex].dynamicId = newEdgeId rawGraph.edges[edgeIndex].dynamicId = newEdgeId
delete rawGraph.edgeDynamicIdMap[originalDynamicId] delete rawGraph.edgeDynamicIdMap[originalDynamicId]
rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex
} }
}) })
// Update selected node in store // Update selected node in store
set({ selectedNode: newValue, moveToSelectedNode: true }) set({ selectedNode: newValue, moveToSelectedNode: true })
} else { } else {
@@ -342,7 +342,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue) sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue)
} }
} }
// Trigger a re-render by incrementing the version counter // Trigger a re-render by incrementing the version counter
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 })) set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
} }
@@ -351,17 +351,17 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
throw new Error('Failed to update node in graph') throw new Error('Failed to update node in graph')
} }
}, },
updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => { updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => {
// Get current state // Get current state
const state = get() const state = get()
const { sigmaGraph, rawGraph } = state const { sigmaGraph, rawGraph } = state
// Validate graph state // Validate graph state
if (!sigmaGraph || !rawGraph) { if (!sigmaGraph || !rawGraph) {
return return
} }
try { try {
const edgeIndex = rawGraph.edgeIdMap[String(edgeId)] const edgeIndex = rawGraph.edgeIdMap[String(edgeId)]
if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) { if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) {
@@ -370,10 +370,10 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue) sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue)
} }
} }
// Trigger a re-render by incrementing the version counter // Trigger a re-render by incrementing the version counter
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 })) set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
// Update selected edge in store to ensure UI reflects changes // Update selected edge in store to ensure UI reflects changes
set({ selectedEdge: dynamicId }) set({ selectedEdge: dynamicId })
} catch (error) { } catch (error) {

View File

@@ -3,7 +3,7 @@ import { useGraphStore } from '@/stores/graph'
/** /**
* Update node in the graph visualization * Update node in the graph visualization
* This function is now a wrapper around the store's updateNodeAndSelect method * This function is now a wrapper around the store's updateNodeAndSelect method
* *
* @param nodeId - ID of the node to update * @param nodeId - ID of the node to update
* @param entityId - ID of the entity * @param entityId - ID of the entity
* @param propertyName - Name of the property being updated * @param propertyName - Name of the property being updated

View File

@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中" assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中" assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中" assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配" assert (
assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配" nodes_dict[node1_id]["description"] == node1_data["description"]
assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配" ), f"{node1_id} 描述不匹配"
assert (
nodes_dict[node2_id]["description"] == node2_data["description"]
), f"{node2_id} 描述不匹配"
assert (
nodes_dict[node3_id]["description"] == node3_data["description"]
), f"{node3_id} 描述不匹配"
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数 # 3. 测试 node_degrees_batch - 批量获取多个节点的度数
print("== 测试 node_degrees_batch") print("== 测试 node_degrees_batch")
node_degrees = await storage.node_degrees_batch(node_ids) node_degrees = await storage.node_degrees_batch(node_ids)
print(f"批量获取节点度数结果: {node_degrees}") print(f"批量获取节点度数结果: {node_degrees}")
assert len(node_degrees) == 3, f"应返回3个节点的度数实际返回 {len(node_degrees)}" assert (
len(node_degrees) == 3
), f"应返回3个节点的度数实际返回 {len(node_degrees)}"
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中" assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中" assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中" assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}" assert (
assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}" node_degrees[node1_id] == 3
assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}" ), f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}"
assert (
node_degrees[node2_id] == 2
), f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}"
assert (
node_degrees[node3_id] == 3
), f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}"
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数 # 4. 测试 edge_degrees_batch - 批量获取多个边的度数
print("== 测试 edge_degrees_batch") print("== 测试 edge_degrees_batch")
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)] edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
edge_degrees = await storage.edge_degrees_batch(edges) edge_degrees = await storage.edge_degrees_batch(edges)
print(f"批量获取边度数结果: {edge_degrees}") print(f"批量获取边度数结果: {edge_degrees}")
assert len(edge_degrees) == 3, f"应返回3条边的度数实际返回 {len(edge_degrees)}" assert (
assert (node1_id, node2_id) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中" len(edge_degrees) == 3
assert (node2_id, node3_id) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中" ), f"应返回3条边的度数实际返回 {len(edge_degrees)}"
assert (node3_id, node4_id) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中" assert (
node1_id,
node2_id,
) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (
node2_id,
node3_id,
) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中"
# 验证边的度数是否正确(源节点度数 + 目标节点度数) # 验证边的度数是否正确(源节点度数 + 目标节点度数)
assert edge_degrees[(node1_id, node2_id)] == 5, f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}" assert (
assert edge_degrees[(node2_id, node3_id)] == 5, f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}" edge_degrees[(node1_id, node2_id)] == 5
assert edge_degrees[(node3_id, node4_id)] == 5, f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}" ), f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}"
assert (
edge_degrees[(node2_id, node3_id)] == 5
), f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}"
assert (
edge_degrees[(node3_id, node4_id)] == 5
), f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}"
# 5. 测试 get_edges_batch - 批量获取多个边的属性 # 5. 测试 get_edges_batch - 批量获取多个边的属性
print("== 测试 get_edges_batch") print("== 测试 get_edges_batch")
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
edges_dict = await storage.get_edges_batch(edge_dicts) edges_dict = await storage.get_edges_batch(edge_dicts)
print(f"批量获取边属性结果: {edges_dict.keys()}") print(f"批量获取边属性结果: {edges_dict.keys()}")
assert len(edges_dict) == 3, f"应返回3条边的属性实际返回 {len(edges_dict)}" assert len(edges_dict) == 3, f"应返回3条边的属性实际返回 {len(edges_dict)}"
assert (node1_id, node2_id) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中" assert (
assert (node2_id, node3_id) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中" node1_id,
assert (node3_id, node4_id) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中" node2_id,
assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"{node1_id} -> {node2_id} 关系不匹配" ) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中"
assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"{node2_id} -> {node3_id} 关系不匹配" assert (
assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"{node3_id} -> {node4_id} 关系不匹配" node2_id,
node3_id,
) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中"
assert (
edges_dict[(node1_id, node2_id)]["relationship"]
== edge1_data["relationship"]
), f"{node1_id} -> {node2_id} 关系不匹配"
assert (
edges_dict[(node2_id, node3_id)]["relationship"]
== edge2_data["relationship"]
), f"{node2_id} -> {node3_id} 关系不匹配"
assert (
edges_dict[(node3_id, node4_id)]["relationship"]
== edge5_data["relationship"]
), f"{node3_id} -> {node4_id} 关系不匹配"
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边 # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
print("== 测试 get_nodes_edges_batch") print("== 测试 get_nodes_edges_batch")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id]) nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
print(f"批量获取节点边结果: {nodes_edges.keys()}") print(f"批量获取节点边结果: {nodes_edges.keys()}")
assert len(nodes_edges) == 2, f"应返回2个节点的边实际返回 {len(nodes_edges)}" assert (
len(nodes_edges) == 2
), f"应返回2个节点的边实际返回 {len(nodes_edges)}"
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中" assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中" assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}" assert (
assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}" len(nodes_edges[node1_id]) == 3
), f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}"
assert (
len(nodes_edges[node3_id]) == 3
), f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}"
# 7. 清理数据 # 7. 清理数据
print("== 测试 drop") print("== 测试 drop")
result = await storage.drop() result = await storage.drop()
print(f"清理结果: {result}") print(f"清理结果: {result}")
assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}" assert (
result["status"] == "success"
), f"清理应成功,实际状态为 {result['status']}"
print("\n批量操作测试完成") print("\n批量操作测试完成")
return True return True
@@ -630,7 +687,7 @@ async def main():
if basic_result: if basic_result:
ASCIIColors.cyan("\n=== 开始高级测试 ===") ASCIIColors.cyan("\n=== 开始高级测试 ===")
advanced_result = await test_graph_advanced(storage) advanced_result = await test_graph_advanced(storage)
if advanced_result: if advanced_result:
ASCIIColors.cyan("\n=== 开始批量操作测试 ===") ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
await test_graph_batch_operations(storage) await test_graph_batch_operations(storage)