diff --git a/lightrag/base.py b/lightrag/base.py index 969a5a8c..b1f63fa5 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC): async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """Get nodes as a batch using UNWIND - + Default implementation fetches nodes one by one. Override this method for better performance in storage backends that support batch operations. @@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC): async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """Node degrees as a batch using UNWIND - + Default implementation fetches node degrees one by one. Override this method for better performance in storage backends that support batch operations. @@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[node_id] = degree return result - async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]: + async def edge_degrees_batch( + self, edge_pairs: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: """Edge degrees as a batch using UNWIND also uses node_degrees_batch - + Default implementation calculates edge degrees one by one. Override this method for better performance in storage backends that support batch operations. @@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[(src_id, tgt_id)] = degree return result - async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: """Get edges as a batch using UNWIND - + Default implementation fetches edges one by one. Override this method for better performance in storage backends that support batch operations. @@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[(src_id, tgt_id)] = edge return result - async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: """Get nodes edges as a batch using UNWIND - + Default implementation fetches node edges one by one. Override this method for better performance in storage backends that support batch operations. diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 44b3e20e..1b712462 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage): async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. - + Args: node_ids: List of node entity IDs to fetch. - + Returns: A dictionary mapping each node_id to its node data (or None if not found). """ @@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage): node_dict = dict(node) # Remove the 'base' label if present in a 'labels' property if "labels" in node_dict: - node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + node_dict["labels"] = [ + label for label in node_dict["labels"] if label != "base" + ] nodes[entity_id] = node_dict await result.consume() # Make sure to consume the result fully return nodes @@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage): async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. - + Args: node_ids: List of node labels (entity_id values) to look up. - + Returns: - A dictionary mapping each node_id to its degree (number of relationships). + A dictionary mapping each node_id to its degree (number of relationships). If a node is not found, its degree will be set to 0. """ async with self._driver.session( @@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage): entity_id = record["entity_id"] degrees[entity_id] = record["degree"] await result.consume() # Ensure result is fully consumed - + # For any node_id that did not return a record, set degree to 0. for nid in node_ids: if nid not in degrees: logger.warning(f"No node found with label '{nid}'") degrees[nid] = 0 - + logger.debug(f"Neo4j batch node degree query returned: {degrees}") return degrees @@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage): degrees = int(src_degree) + int(trg_degree) return degrees - - async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]: + + async def edge_degrees_batch( + self, edge_pairs: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: """ Calculate the combined degree for each edge (sum of the source and target node degrees) in batch using the already implemented node_degrees_batch. - + Args: edge_pairs: List of (src, tgt) tuples. - + Returns: A dictionary mapping each (src, tgt) tuple to the sum of their degrees. """ # Collect unique node IDs from all edge pairs. unique_node_ids = {src for src, _ in edge_pairs} unique_node_ids.update({tgt for _, tgt in edge_pairs}) - + # Get degrees for all nodes in one go. degrees = await self.node_degrees_batch(list(unique_node_ids)) - + # Sum up degrees for each edge pair. edge_degrees = {} for src, tgt in edge_pairs: @@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage): ) raise - async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: """ Retrieve edge properties for multiple (src, tgt) pairs in one query. - + Args: pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] - + Returns: A dictionary mapping (src, tgt) tuples to their edge properties. """ @@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage): if edges and len(edges) > 0: edge_props = edges[0] # choose the first if multiple exist # Ensure required keys exist with defaults - for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items(): + for key, default in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): if key not in edge_props: edge_props[key] = default edges_dict[(src, tgt)] = edge_props else: # No edge found – set default edge properties - edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None} + edges_dict[(src, tgt)] = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } await result.consume() return edges_dict @@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage): logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") raise - async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: """ Batch retrieve edges for multiple nodes in one query using UNWIND. - + Args: node_ids: List of node IDs (entity_id) for which to retrieve edges. - + Returns: A dictionary mapping each node ID to its list of edge tuples (source, target). """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $node_ids AS id MATCH (n:base {entity_id: id}) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index f567055c..872bc6bb 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage): async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. - + Args: node_ids: List of node entity IDs to fetch. - + Returns: A dictionary mapping each node_id to its node data (or None if not found). """ if not node_ids: return {} - + # Format node IDs for the query - formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) - + formatted_ids = ", ".join( + ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] + ) + query = """SELECT * FROM cypher('%s', $$ UNWIND [%s] AS node_id MATCH (n:base {entity_id: node_id}) RETURN node_id, n - $$) AS (node_id text, n agtype)""" % ( - self.graph_name, - formatted_ids - ) - + $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids) + results = await self._query(query) - + # Build result dictionary nodes_dict = {} for result in results: @@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage): node_dict = result["n"]["properties"] # Remove the 'base' label if present in a 'labels' property if "labels" in node_dict: - node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + node_dict["labels"] = [ + label for label in node_dict["labels"] if label != "base" + ] nodes_dict[result["node_id"]] = node_dict - + return nodes_dict async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. - + Args: node_ids: List of node labels (entity_id values) to look up. - + Returns: - A dictionary mapping each node_id to its degree (number of relationships). + A dictionary mapping each node_id to its degree (number of relationships). If a node is not found, its degree will be set to 0. """ if not node_ids: return {} - + # Format node IDs for the query - formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) - + formatted_ids = ", ".join( + ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] + ) + query = """SELECT * FROM cypher('%s', $$ UNWIND [%s] AS node_id MATCH (n:base {entity_id: node_id}) @@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage): RETURN node_id, count(r) AS degree $$) AS (node_id text, degree bigint)""" % ( self.graph_name, - formatted_ids + formatted_ids, ) - + results = await self._query(query) - + # Build result dictionary degrees_dict = {} for result in results: if result["node_id"] is not None: degrees_dict[result["node_id"]] = int(result["degree"]) - + # Ensure all requested node_ids are in the result dictionary for node_id in node_ids: if node_id not in degrees_dict: degrees_dict[node_id] = 0 - + return degrees_dict - - async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]: + + async def edge_degrees_batch( + self, edges: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: """ Calculate the combined degree for each edge (sum of the source and target node degrees) in batch using the already implemented node_degrees_batch. - + Args: edges: List of (source_node_id, target_node_id) tuples - + Returns: Dictionary mapping edge tuples to their combined degrees """ if not edges: return {} - + # Use node_degrees_batch to get all node degrees efficiently all_nodes = set() for src, tgt in edges: all_nodes.add(src) all_nodes.add(tgt) - + node_degrees = await self.node_degrees_batch(list(all_nodes)) - + # Calculate edge degrees edge_degrees_dict = {} for src, tgt in edges: src_degree = node_degrees.get(src, 0) tgt_degree = node_degrees.get(tgt, 0) edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree - + return edge_degrees_dict - - async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: + + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: """ Retrieve edge properties for multiple (src, tgt) pairs in one query. - + Args: pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] - + Returns: A dictionary mapping (src, tgt) tuples to their edge properties. """ if not pairs: return {} - + # 从字典列表中提取源节点和目标节点ID src_nodes = [] tgt_nodes = [] for pair in pairs: - src_nodes.append(pair["src"].replace('"', '')) - tgt_nodes.append(pair["tgt"].replace('"', '')) - + src_nodes.append(pair["src"].replace('"', "")) + tgt_nodes.append(pair["tgt"].replace('"', "")) + # 构建查询,使用数组索引来匹配源节点和目标节点 src_array = ", ".join([f'"{src}"' for src in src_nodes]) tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) - + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ WITH [{src_array}] AS sources, [{tgt_array}] AS targets UNWIND range(0, size(sources)-1) AS i MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}}) RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties $$) AS (source text, target text, edge_properties agtype)""" - + results = await self._query(query) - + # 构建结果字典 edges_dict = {} for result in results: if result["source"] and result["target"] and result["edge_properties"]: - edges_dict[(result["source"], result["target"])] = result["edge_properties"] - + edges_dict[(result["source"], result["target"])] = result[ + "edge_properties" + ] + return edges_dict - - async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: """ Get all edges for multiple nodes in a single batch operation. - + Args: node_ids: List of node IDs to get edges for - + Returns: Dictionary mapping node IDs to lists of (source, target) edge tuples """ if not node_ids: return {} - + # Format node IDs for the query - formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) - + formatted_ids = ", ".join( + ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] + ) + query = """SELECT * FROM cypher('%s', $$ UNWIND [%s] AS node_id MATCH (n:base {entity_id: node_id}) @@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage): RETURN node_id, connected.entity_id AS connected_id $$) AS (node_id text, connected_id text)""" % ( self.graph_name, - formatted_ids + formatted_ids, ) - + results = await self._query(query) - + # Build result dictionary nodes_edges_dict = {node_id: [] for node_id in node_ids} for result in results: @@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage): nodes_edges_dict[result["node_id"]].append( (result["node_id"], result["connected_id"]) ) - + return nodes_edges_dict - + async def get_all_labels(self) -> list[str]: """ Get all labels (node IDs) in the graph. diff --git a/lightrag/operate.py b/lightrag/operate.py index c167b001..fca5ff32 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1323,14 +1323,14 @@ async def _get_node_data( if not len(results): return "", "", "" - + # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] # Call the batch node retrieval and degree functions concurrently. nodes_dict, degrees_dict = await asyncio.gather( - knowledge_graph_inst.get_nodes_batch(node_ids), - knowledge_graph_inst.node_degrees_batch(node_ids) + knowledge_graph_inst.get_nodes_batch(node_ids), + knowledge_graph_inst.node_degrees_batch(node_ids), ) # Now, if you need the node data and degree in order: @@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities( for dp in node_datas if dp["source_id"] is not None ] - + node_names = [dp["entity_name"] for dp in node_datas] batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) # Build the edges list in the same order as node_datas. @@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities( all_one_hop_nodes.update([e[1] for e in this_edges]) all_one_hop_nodes = list(all_one_hop_nodes) - + # Batch retrieve one-hop node data using get_nodes_batch - all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes) - all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes] + all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch( + all_one_hop_nodes + ) + all_one_hop_nodes_data = [ + all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes + ] # Add null check for node data all_one_hop_text_units_lookup = { @@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities( edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] # For edge degrees, use tuples. edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples - + # Call the batched functions concurrently. edge_data_dict, edge_degrees_dict = await asyncio.gather( knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), - knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples) + knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), ) - + # Reconstruct edge_datas list in the same order as the deduplicated results. all_edges_data = [] for pair in all_edges: @@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities( } all_edges_data.append(combined) - all_edges_data = sorted( all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True ) @@ -1634,7 +1637,7 @@ async def _get_edge_data( # Call the batched functions concurrently. edge_data_dict, edge_degrees_dict = await asyncio.gather( knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), - knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples) + knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), ) # Reconstruct edge_datas list in the same order as results. @@ -1652,7 +1655,7 @@ async def _get_edge_data( **edge_props, } edge_datas.append(combined) - + edge_datas = sorted( edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True ) @@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships( # Batch approach: Retrieve nodes and their degrees concurrently with one query each. nodes_dict, degrees_dict = await asyncio.gather( knowledge_graph_inst.get_nodes_batch(entity_names), - knowledge_graph_inst.node_degrees_batch(entity_names) + knowledge_graph_inst.node_degrees_batch(entity_names), ) # Rebuild the list in the same order as entity_names diff --git a/lightrag_webui/src/stores/graph.ts b/lightrag_webui/src/stores/graph.ts index c4c6a285..fb035cb9 100644 --- a/lightrag_webui/src/stores/graph.ts +++ b/lightrag_webui/src/stores/graph.ts @@ -136,7 +136,7 @@ interface GraphState { // Version counter to trigger data refresh graphDataVersion: number incrementGraphDataVersion: () => void - + // Methods for updating graph elements and UI state together updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise @@ -252,40 +252,40 @@ const useGraphStoreBase = create()((set, get) => ({ // Get current state const state = get() const { sigmaGraph, rawGraph } = state - + // Validate graph state if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) { return } - + try { const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId) - + console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue) - + // For entity_id changes (node renaming) with NetworkX graph storage if ((nodeId === entityId) && (propertyName === 'entity_id')) { // Create new node with updated ID but same attributes sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue }) - + const edgesToUpdate: EdgeToUpdate[] = [] - + // Process all edges connected to this node sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => { const otherNode = source === nodeId ? target : source const isOutgoing = source === nodeId - + // Get original edge dynamic ID for later reference const originalEdgeDynamicId = edge const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId] - + // Create new edge with updated node reference const newEdgeId = sigmaGraph.addEdge( isOutgoing ? newValue : otherNode, isOutgoing ? otherNode : newValue, attributes ) - + // Track edges that need updating in the raw graph if (edgeIndexInRawGraph !== undefined) { edgesToUpdate.push({ @@ -294,14 +294,14 @@ const useGraphStoreBase = create()((set, get) => ({ edgeIndex: edgeIndexInRawGraph }) } - + // Remove the old edge sigmaGraph.dropEdge(edge) }) - + // Remove the old node after all edges are processed sigmaGraph.dropNode(nodeId) - + // Update node reference in raw graph data const nodeIndex = rawGraph.nodeIdMap[nodeId] if (nodeIndex !== undefined) { @@ -311,7 +311,7 @@ const useGraphStoreBase = create()((set, get) => ({ delete rawGraph.nodeIdMap[nodeId] rawGraph.nodeIdMap[newValue] = nodeIndex } - + // Update all edge references in raw graph data edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => { if (rawGraph.edges[edgeIndex]) { @@ -322,14 +322,14 @@ const useGraphStoreBase = create()((set, get) => ({ if (rawGraph.edges[edgeIndex].target === nodeId) { rawGraph.edges[edgeIndex].target = newValue } - + // Update dynamic ID mappings rawGraph.edges[edgeIndex].dynamicId = newEdgeId delete rawGraph.edgeDynamicIdMap[originalDynamicId] rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex } }) - + // Update selected node in store set({ selectedNode: newValue, moveToSelectedNode: true }) } else { @@ -342,7 +342,7 @@ const useGraphStoreBase = create()((set, get) => ({ sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue) } } - + // Trigger a re-render by incrementing the version counter set((state) => ({ graphDataVersion: state.graphDataVersion + 1 })) } @@ -351,17 +351,17 @@ const useGraphStoreBase = create()((set, get) => ({ throw new Error('Failed to update node in graph') } }, - + updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => { // Get current state const state = get() const { sigmaGraph, rawGraph } = state - + // Validate graph state if (!sigmaGraph || !rawGraph) { return } - + try { const edgeIndex = rawGraph.edgeIdMap[String(edgeId)] if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) { @@ -370,10 +370,10 @@ const useGraphStoreBase = create()((set, get) => ({ sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue) } } - + // Trigger a re-render by incrementing the version counter set((state) => ({ graphDataVersion: state.graphDataVersion + 1 })) - + // Update selected edge in store to ensure UI reflects changes set({ selectedEdge: dynamicId }) } catch (error) { diff --git a/lightrag_webui/src/utils/graphOperations.ts b/lightrag_webui/src/utils/graphOperations.ts index 5bc9c5ba..9a506b51 100644 --- a/lightrag_webui/src/utils/graphOperations.ts +++ b/lightrag_webui/src/utils/graphOperations.ts @@ -3,7 +3,7 @@ import { useGraphStore } from '@/stores/graph' /** * Update node in the graph visualization * This function is now a wrapper around the store's updateNodeAndSelect method - * + * * @param nodeId - ID of the node to update * @param entityId - ID of the entity * @param propertyName - Name of the property being updated diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index eb59594b..344ba118 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage): assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中" assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中" assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中" - assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配" - assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配" - assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配" + assert ( + nodes_dict[node1_id]["description"] == node1_data["description"] + ), f"{node1_id} 描述不匹配" + assert ( + nodes_dict[node2_id]["description"] == node2_data["description"] + ), f"{node2_id} 描述不匹配" + assert ( + nodes_dict[node3_id]["description"] == node3_data["description"] + ), f"{node3_id} 描述不匹配" # 3. 测试 node_degrees_batch - 批量获取多个节点的度数 print("== 测试 node_degrees_batch") node_degrees = await storage.node_degrees_batch(node_ids) print(f"批量获取节点度数结果: {node_degrees}") - assert len(node_degrees) == 3, f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个" + assert ( + len(node_degrees) == 3 + ), f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个" assert node1_id in node_degrees, f"{node1_id} 应在返回结果中" assert node2_id in node_degrees, f"{node2_id} 应在返回结果中" assert node3_id in node_degrees, f"{node3_id} 应在返回结果中" - assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}" - assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}" - assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}" + assert ( + node_degrees[node1_id] == 3 + ), f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}" + assert ( + node_degrees[node2_id] == 2 + ), f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}" + assert ( + node_degrees[node3_id] == 3 + ), f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}" # 4. 测试 edge_degrees_batch - 批量获取多个边的度数 print("== 测试 edge_degrees_batch") edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)] edge_degrees = await storage.edge_degrees_batch(edges) print(f"批量获取边度数结果: {edge_degrees}") - assert len(edge_degrees) == 3, f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条" - assert (node1_id, node2_id) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中" - assert (node2_id, node3_id) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中" - assert (node3_id, node4_id) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中" + assert ( + len(edge_degrees) == 3 + ), f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条" + assert ( + node1_id, + node2_id, + ) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中" + assert ( + node2_id, + node3_id, + ) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中" + assert ( + node3_id, + node4_id, + ) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中" # 验证边的度数是否正确(源节点度数 + 目标节点度数) - assert edge_degrees[(node1_id, node2_id)] == 5, f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}" - assert edge_degrees[(node2_id, node3_id)] == 5, f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}" - assert edge_degrees[(node3_id, node4_id)] == 5, f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}" + assert ( + edge_degrees[(node1_id, node2_id)] == 5 + ), f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}" + assert ( + edge_degrees[(node2_id, node3_id)] == 5 + ), f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}" + assert ( + edge_degrees[(node3_id, node4_id)] == 5 + ), f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}" # 5. 测试 get_edges_batch - 批量获取多个边的属性 print("== 测试 get_edges_batch") @@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage): edges_dict = await storage.get_edges_batch(edge_dicts) print(f"批量获取边属性结果: {edges_dict.keys()}") assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条" - assert (node1_id, node2_id) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中" - assert (node2_id, node3_id) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中" - assert (node3_id, node4_id) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中" - assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"边 {node1_id} -> {node2_id} 关系不匹配" - assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"边 {node2_id} -> {node3_id} 关系不匹配" - assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"边 {node3_id} -> {node4_id} 关系不匹配" + assert ( + node1_id, + node2_id, + ) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中" + assert ( + node2_id, + node3_id, + ) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中" + assert ( + node3_id, + node4_id, + ) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中" + assert ( + edges_dict[(node1_id, node2_id)]["relationship"] + == edge1_data["relationship"] + ), f"边 {node1_id} -> {node2_id} 关系不匹配" + assert ( + edges_dict[(node2_id, node3_id)]["relationship"] + == edge2_data["relationship"] + ), f"边 {node2_id} -> {node3_id} 关系不匹配" + assert ( + edges_dict[(node3_id, node4_id)]["relationship"] + == edge5_data["relationship"] + ), f"边 {node3_id} -> {node4_id} 关系不匹配" # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边 print("== 测试 get_nodes_edges_batch") nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id]) print(f"批量获取节点边结果: {nodes_edges.keys()}") - assert len(nodes_edges) == 2, f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个" + assert ( + len(nodes_edges) == 2 + ), f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个" assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中" assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中" - assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条" - assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条" + assert ( + len(nodes_edges[node1_id]) == 3 + ), f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条" + assert ( + len(nodes_edges[node3_id]) == 3 + ), f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条" # 7. 清理数据 print("== 测试 drop") result = await storage.drop() print(f"清理结果: {result}") - assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}" + assert ( + result["status"] == "success" + ), f"清理应成功,实际状态为 {result['status']}" print("\n批量操作测试完成") return True @@ -630,7 +687,7 @@ async def main(): if basic_result: ASCIIColors.cyan("\n=== 开始高级测试 ===") advanced_result = await test_graph_advanced(storage) - + if advanced_result: ASCIIColors.cyan("\n=== 开始批量操作测试 ===") await test_graph_batch_operations(storage)