Merge pull request #1383 from HKUDS/graph-storage-batch-query

Refactor: Optimizing graph storage performance
This commit is contained in:
Daniel.y
2025-04-16 15:28:38 +08:00
committed by GitHub
7 changed files with 1318 additions and 84 deletions

View File

@@ -11,3 +11,26 @@ services:
env_file: env_file:
- .env - .env
restart: unless-stopped restart: unless-stopped
neo4j:
image: neo4j:5.26.4-community
container_name: lightrag-server_neo4j-community
restart: always
ports:
- "7474:7474"
- "7687:7687"
environment:
- NEO4J_AUTH=${NEO4J_USERNAME}/${NEO4J_PASSWORD}
- NEO4J_apoc_export_file_enabled=true
- NEO4J_server_bolt_listen__address=0.0.0.0:7687
- NEO4J_server_bolt_advertised__address=neo4j:7687
volumes:
- ./neo4j/plugins:/var/lib/neo4j/plugins
- lightrag_neo4j_import:/var/lib/neo4j/import
- lightrag_neo4j_data:/data
- lightrag_neo4j_backups:/backups
volumes:
lightrag_neo4j_import:
lightrag_neo4j_data:
lightrag_neo4j_backups:

View File

@@ -1 +1 @@
__api_version__ = "0151" __api_version__ = "0152"

View File

@@ -361,6 +361,81 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist or None if the node doesn't exist
""" """
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND
Default implementation fetches nodes one by one.
Override this method for better performance in storage backends
that support batch operations.
"""
result = {}
for node_id in node_ids:
node = await self.get_node(node_id)
if node is not None:
result[node_id] = node
return result
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Node degrees as a batch using UNWIND
Default implementation fetches node degrees one by one.
Override this method for better performance in storage backends
that support batch operations.
"""
result = {}
for node_id in node_ids:
degree = await self.node_degree(node_id)
result[node_id] = degree
return result
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one.
Override this method for better performance in storage backends
that support batch operations.
"""
result = {}
for src_id, tgt_id in edge_pairs:
degree = await self.edge_degree(src_id, tgt_id)
result[(src_id, tgt_id)] = degree
return result
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND
Default implementation fetches edges one by one.
Override this method for better performance in storage backends
that support batch operations.
"""
result = {}
for pair in pairs:
src_id = pair["src"]
tgt_id = pair["tgt"]
edge = await self.get_edge(src_id, tgt_id)
if edge is not None:
result[(src_id, tgt_id)] = edge
return result
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one.
Override this method for better performance in storage backends
that support batch operations.
"""
result = {}
for node_id in node_ids:
edges = await self.get_node_edges(node_id)
result[node_id] = edges if edges is not None else []
return result
@abstractmethod @abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph. """Insert a new node or update an existing node in the graph.

View File

@@ -308,6 +308,39 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node for {node_id}: {str(e)}") logger.error(f"Error getting node for {node_id}: {str(e)}")
raise raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, n
"""
result = await session.run(query, node_ids=node_ids)
nodes = {}
async for record in result:
entity_id = record["entity_id"]
node = record["n"]
node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully
return nodes
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label. """Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node. If multiple nodes have the same label, returns the degree of the first node.
@@ -351,6 +384,41 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node degree for {node_id}: {str(e)}") logger.error(f"Error getting node degree for {node_id}: {str(e)}")
raise raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
"""
result = await session.run(query, node_ids=node_ids)
degrees = {}
async for record in result:
entity_id = record["entity_id"]
degrees[entity_id] = record["degree"]
await result.consume() # Ensure result is fully consumed
# For any node_id that did not return a record, set degree to 0.
for nid in node_ids:
if nid not in degrees:
logger.warning(f"No node found with label '{nid}'")
degrees[nid] = 0
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes. """Get the total degree (sum of relationships) of two nodes.
@@ -371,6 +439,32 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edge_pairs: List of (src, tgt) tuples.
Returns:
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
"""
# Collect unique node IDs from all edge pairs.
unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes in one go.
degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair.
edge_degrees = {}
for src, tgt in edge_pairs:
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
@@ -457,6 +551,55 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $pairs AS pair
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
"""
result = await session.run(query, pairs=pairs)
edges_dict = {}
async for record in result:
src = record["src_id"]
tgt = record["tgt_id"]
edges = record["edges"]
if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults
for key, default in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_props:
edge_props[key] = default
edges_dict[(src, tgt)] = edge_props
else:
# No edge found set default edge properties
edges_dict[(src, tgt)] = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
await result.consume()
return edges_dict
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label. """Retrieves all edges (relationships) for a particular node identified by its label.
@@ -517,6 +660,64 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise raise
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes in one query using UNWIND.
For each node, returns both outgoing and incoming edges to properly represent
the undirected graph nature.
Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
For each node, the list includes both:
- Outgoing edges: (queried_node, connected_node)
- Incoming edges: (connected_node, queried_node)
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
# Query to get both outgoing and incoming edges
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
OPTIONAL MATCH (n)-[r]-(connected:base)
RETURN id AS queried_id, n.entity_id AS node_entity_id,
connected.entity_id AS connected_entity_id,
startNode(r).entity_id AS start_entity_id
"""
result = await session.run(query, node_ids=node_ids)
# Initialize the dictionary with empty lists for each node ID
edges_dict = {node_id: [] for node_id in node_ids}
# Process results to include both outgoing and incoming edges
async for record in result:
queried_id = record["queried_id"]
node_entity_id = record["node_entity_id"]
connected_entity_id = record["connected_entity_id"]
start_entity_id = record["start_entity_id"]
# Skip if either node is None
if not node_entity_id or not connected_entity_id:
continue
# Determine the actual direction of the edge
# If the start node is the queried node, it's an outgoing edge
# Otherwise, it's an incoming edge
if start_entity_id == node_entity_id:
# Outgoing edge: (queried_node -> connected_node)
edges_dict[queried_id].append((node_entity_id, connected_entity_id))
else:
# Incoming edge: (connected_node -> queried_node)
edges_dict[queried_id].append((connected_entity_id, node_entity_id))
await result.consume() # Ensure results are fully consumed
return edges_dict
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),

View File

@@ -1027,6 +1027,28 @@ class PGGraphStorage(BaseGraphStorage):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
node1_id = "dummy_entity"
node1_data = {
"entity_id": node1_id,
"description": "dummy description",
"keywords": "dummy,keywords",
"entity_type": "dummy_type",
}
await self.upsert_node(node1_id, node1_data)
await self.delete_node(node1_id)
query = (
"""CREATE INDEX entity_id_gin_idxSELECT ON %s."base" USING gin (properties);"""
% (self.graph_name)
)
await self.db.execute(
query,
upsert=True,
with_age=True,
graph_name=self.graph_name,
)
async def finalize(self): async def finalize(self):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
@@ -1233,8 +1255,8 @@ class PGGraphStorage(BaseGraphStorage):
label = node_id.strip('"') label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})-[]-(x) MATCH (n:base {entity_id: "%s"})-[r]-()
RETURN count(x) AS total_edge_count RETURN count(r) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label) $$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0] record = (await self._query(query))[0]
if record: if record:
@@ -1262,7 +1284,7 @@ class PGGraphStorage(BaseGraphStorage):
tgt_label = target_node_id.strip('"') tgt_label = target_node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
$$) AS (edge_properties agtype)""" % ( $$) AS (edge_properties agtype)""" % (
@@ -1285,7 +1307,7 @@ class PGGraphStorage(BaseGraphStorage):
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
OPTIONAL MATCH (n)-[]-(connected:base) OPTIONAL MATCH (n)-[]-(connected)
RETURN n, connected RETURN n, connected
$$) AS (n agtype, connected agtype)""" % ( $$) AS (n agtype, connected agtype)""" % (
self.graph_name, self.graph_name,
@@ -1374,7 +1396,7 @@ class PGGraphStorage(BaseGraphStorage):
MATCH (source:base {entity_id: "%s"}) MATCH (source:base {entity_id: "%s"})
WITH source WITH source
MATCH (target:base {entity_id: "%s"}) MATCH (target:base {entity_id: "%s"})
MERGE (source)-[r:DIRECTED]->(target) MERGE (source)-[r:DIRECTED]-(target)
SET r += %s SET r += %s
RETURN r RETURN r
$$) AS (r agtype)""" % ( $$) AS (r agtype)""" % (
@@ -1447,7 +1469,7 @@ class PGGraphStorage(BaseGraphStorage):
tgt_label = target.strip('"') tgt_label = target.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
DELETE r DELETE r
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
@@ -1458,6 +1480,259 @@ class PGGraphStorage(BaseGraphStorage):
logger.error(f"Error during edge deletion: {str(e)}") logger.error(f"Error during edge deletion: {str(e)}")
raise raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
# Build result dictionary
nodes_dict = {}
for result in results:
if result["node_id"] and result["n"]:
node_dict = result["n"]["properties"]
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes_dict[result["node_id"]] = node_dict
return nodes_dict
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Calculates the total degree by counting distinct relationships.
Uses separate queries for outgoing and incoming edges.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (total number of relationships).
If a node is not found, its degree will be set to 0.
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]->(a)
RETURN node_id, count(a) AS out_degree
$$) AS (node_id text, out_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)<-[r]-(b)
RETURN node_id, count(b) AS in_degree
$$) AS (node_id text, in_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
out_degrees = {}
in_degrees = {}
for result in outgoing_results:
if result["node_id"] is not None:
out_degrees[result["node_id"]] = int(result["out_degree"])
for result in incoming_results:
if result["node_id"] is not None:
in_degrees[result["node_id"]] = int(result["in_degree"])
degrees_dict = {}
for node_id in node_ids:
out_degree = out_degrees.get(node_id, 0)
in_degree = in_degrees.get(node_id, 0)
degrees_dict[node_id] = out_degree + in_degree
return degrees_dict
async def edge_degrees_batch(
self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edges: List of (source_node_id, target_node_id) tuples
Returns:
Dictionary mapping edge tuples to their combined degrees
"""
if not edges:
return {}
# Use node_degrees_batch to get all node degrees efficiently
all_nodes = set()
for src, tgt in edges:
all_nodes.add(src)
all_nodes.add(tgt)
node_degrees = await self.node_degrees_batch(list(all_nodes))
# Calculate edge degrees
edge_degrees_dict = {}
for src, tgt in edges:
src_degree = node_degrees.get(src, 0)
tgt_degree = node_degrees.get(tgt, 0)
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
return edge_degrees_dict
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Get forward and backward edges seperately and merge them before return
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
if not pairs:
return {}
src_nodes = []
tgt_nodes = []
for pair in pairs:
src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', ""))
src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
forward_results = await self._query(forward_query)
backward_results = await self._query(backward_query)
edges_dict = {}
for result in forward_results:
if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result[
"edge_properties"
]
for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result[
"edge_properties"
]
return edges_dict
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
Args:
node_ids: List of node IDs to get edges for
Returns:
Dictionary mapping node IDs to lists of (source, target) edge tuples
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)-[]->(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)<-[]-(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in outgoing_results:
if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"])
)
for result in incoming_results:
if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append(
(result["connected_id"], result["node_id"])
)
return nodes_edges_dict
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """
Get all labels (node IDs) in the graph. Get all labels (node IDs) in the graph.
@@ -1507,8 +1782,8 @@ class PGGraphStorage(BaseGraphStorage):
strip_label = node_label.strip('"') strip_label = node_label.strip('"')
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: "{strip_label}"}}) MATCH (n:base {{entity_id: "{strip_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) OPTIONAL MATCH p = (n)-[*..{max_depth}]-()
RETURN count(distinct m) AS total_nodes RETURN count(nodes(p)) AS total_nodes
$$) AS (total_nodes bigint)""" $$) AS (total_nodes bigint)"""
count_result = await self._query(count_query) count_result = await self._query(count_query)
@@ -1518,19 +1793,25 @@ class PGGraphStorage(BaseGraphStorage):
# Now get the actual data with limit # Now get the actual data with limit
if node_label == "*": if node_label == "*":
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base) MATCH (node:base)
OPTIONAL MATCH (n)-[r]->(target:base) OPTIONAL MATCH (node)-[r]->()
RETURN collect(distinct n) AS n, collect(distinct r) AS r RETURN collect(distinct node) AS n, collect(distinct r) AS r
LIMIT {max_nodes} LIMIT {max_nodes}
$$) AS (n agtype, r agtype)""" $$) AS (n agtype, r agtype)"""
else: else:
strip_label = node_label.strip('"') strip_label = node_label.strip('"')
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ if total_nodes > 0:
MATCH (n:base {{entity_id: "{strip_label}"}}) query = f"""SELECT * FROM cypher('{self.graph_name}', $$
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) MATCH (node:base {{entity_id: "{strip_label}"}})
RETURN nodes(p) AS n, relationships(p) AS r OPTIONAL MATCH p = (node)-[*..{max_depth}]-()
LIMIT {max_nodes} RETURN nodes(p) AS n, relationships(p) AS r
$$) AS (n agtype, r agtype)""" LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
else:
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base {{entity_id: "{strip_label}"}})
RETURN node AS n
$$) AS (n agtype)"""
results = await self._query(query) results = await self._query(query)

View File

@@ -1328,16 +1328,20 @@ async def _get_node_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
# get entity information
node_datas, node_degrees = await asyncio.gather( # Extract all entity IDs from your results list
asyncio.gather( node_ids = [r["entity_name"] for r in results]
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
), # Call the batch node retrieval and degree functions concurrently.
asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] knowledge_graph_inst.get_nodes_batch(node_ids),
), knowledge_graph_inst.node_degrees_batch(node_ids),
) )
# Now, if you need the node data and degree in order:
node_datas = [nodes_dict.get(nid) for nid in node_ids]
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
if not all([n is not None for n in node_datas]): if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged") logger.warning("Some nodes are missing, maybe the storage is damaged")
@@ -1460,9 +1464,12 @@ async def _find_most_related_text_unit_from_entities(
for dp in node_datas for dp in node_datas
if dp["source_id"] is not None if dp["source_id"] is not None
] ]
edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] node_names = [dp["entity_name"] for dp in node_datas]
) batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
# Build the edges list in the same order as node_datas.
edges = [batch_edges_dict.get(name, []) for name in node_names]
all_one_hop_nodes = set() all_one_hop_nodes = set()
for this_edges in edges: for this_edges in edges:
if not this_edges: if not this_edges:
@@ -1470,9 +1477,14 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes.update([e[1] for e in this_edges]) all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes) all_one_hop_nodes = list(all_one_hop_nodes)
all_one_hop_nodes_data = await asyncio.gather(
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes] # Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
all_one_hop_nodes
) )
all_one_hop_nodes_data = [
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
]
# Add null check for node data # Add null check for node data
all_one_hop_text_units_lookup = { all_one_hop_text_units_lookup = {
@@ -1563,17 +1575,30 @@ async def _find_most_related_edges_from_entities(
seen.add(sorted_edge) seen.add(sorted_edge)
all_edges.append(sorted_edge) all_edges.append(sorted_edge)
all_edges_pack, all_edges_degree = await asyncio.gather( # Prepare edge pairs in two forms:
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]), # For the batch edge properties function, use dicts.
asyncio.gather( edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] # For edge degrees, use tuples.
), edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
) )
all_edges_data = [
{"src_tgt": k, "rank": d, **v} # Reconstruct edge_datas list in the same order as the deduplicated results.
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree) all_edges_data = []
if v is not None for pair in all_edges:
] edge_props = edge_data_dict.get(pair)
if edge_props is not None:
combined = {
"src_tgt": pair,
"rank": edge_degrees_dict.get(pair, 0),
**edge_props,
}
all_edges_data.append(combined)
all_edges_data = sorted( all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1608,29 +1633,34 @@ async def _get_edge_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
edge_datas, edge_degree = await asyncio.gather( # Prepare edge pairs in two forms:
asyncio.gather( # For the batch edge properties function, use dicts.
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
), # For edge degrees, use tuples.
asyncio.gather( edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
*[
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) # Call the batched functions concurrently.
for r in results edge_data_dict, edge_degrees_dict = await asyncio.gather(
] knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
), knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
) )
edge_datas = [ # Reconstruct edge_datas list in the same order as results.
{ edge_datas = []
"src_id": k["src_id"], for k in results:
"tgt_id": k["tgt_id"], pair = (k["src_id"], k["tgt_id"])
"rank": d, edge_props = edge_data_dict.get(pair)
"created_at": k.get("__created_at__", None), if edge_props is not None:
**v, # Use edge degree from the batch as rank.
} combined = {
for k, v, d in zip(results, edge_datas, edge_degree) "src_id": k["src_id"],
if v is not None "tgt_id": k["tgt_id"],
] "rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
"created_at": k.get("__created_at__", None),
**edge_props,
}
edge_datas.append(combined)
edge_datas = sorted( edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1736,25 +1766,23 @@ async def _find_most_related_entities_from_relationships(
entity_names.append(e["tgt_id"]) entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"]) seen.add(e["tgt_id"])
node_datas, node_degrees = await asyncio.gather( # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
*[ knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.get_node(entity_name) knowledge_graph_inst.node_degrees_batch(entity_names),
for entity_name in entity_names
]
),
asyncio.gather(
*[
knowledge_graph_inst.node_degree(entity_name)
for entity_name in entity_names
]
),
) )
node_datas = [
{**n, "entity_name": k, "rank": d} # Rebuild the list in the same order as entity_names
for k, n, d in zip(entity_names, node_datas, node_degrees) node_datas = []
if n is not None for entity_name in entity_names:
] node = nodes_dict.get(entity_name)
degree = degrees_dict.get(entity_name, 0)
if node is None:
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
continue
# Combine the node data with the entity name and computed degree (as rank)
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
len_node_datas = len(node_datas) len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size( node_datas = truncate_list_by_token_size(

View File

@@ -210,6 +210,25 @@ async def test_graph_basic(storage):
print(f"读取边属性失败: {node1_id} -> {node2_id}") print(f"读取边属性失败: {node1_id} -> {node2_id}")
assert False, f"未能读取边属性: {node1_id} -> {node2_id}" assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
# 5.1 验证无向图特性 - 读取反向边属性
print(f"读取反向边属性: {node2_id} -> {node1_id}")
reverse_edge_props = await storage.get_edge(node2_id, node1_id)
if reverse_edge_props:
print(f"成功读取反向边属性: {node2_id} -> {node1_id}")
print(f"反向边关系: {reverse_edge_props.get('relationship', '无关系')}")
print(f"反向边描述: {reverse_edge_props.get('description', '无描述')}")
print(f"反向边权重: {reverse_edge_props.get('weight', '无权重')}")
# 验证正向和反向边属性是否相同
assert (
edge_props == reverse_edge_props
), "正向和反向边属性不一致,无向图特性验证失败"
print("无向图特性验证成功:正向和反向边属性一致")
else:
print(f"读取反向边属性失败: {node2_id} -> {node1_id}")
assert (
False
), f"未能读取反向边属性: {node2_id} -> {node1_id},无向图特性验证失败"
print("基本测试完成,数据已保留在数据库中") print("基本测试完成,数据已保留在数据库中")
return True return True
@@ -294,13 +313,31 @@ async def test_graph_advanced(storage):
print(f"节点 {node1_id} 的度数: {node1_degree}") print(f"节点 {node1_id} 的度数: {node1_degree}")
assert node1_degree == 1, f"节点 {node1_id} 的度数应为1实际为 {node1_degree}" assert node1_degree == 1, f"节点 {node1_id} 的度数应为1实际为 {node1_degree}"
# 2.1 测试所有节点的度数
print("== 测试所有节点的度数")
node2_degree = await storage.node_degree(node2_id)
node3_degree = await storage.node_degree(node3_id)
print(f"节点 {node2_id} 的度数: {node2_degree}")
print(f"节点 {node3_id} 的度数: {node3_degree}")
assert node2_degree == 2, f"节点 {node2_id} 的度数应为2实际为 {node2_degree}"
assert node3_degree == 1, f"节点 {node3_id} 的度数应为1实际为 {node3_degree}"
# 3. 测试 edge_degree - 获取边的度数 # 3. 测试 edge_degree - 获取边的度数
print(f"== 测试 edge_degree: {node1_id} -> {node2_id}") print(f"== 测试 edge_degree: {node1_id} -> {node2_id}")
edge_degree = await storage.edge_degree(node1_id, node2_id) edge_degree = await storage.edge_degree(node1_id, node2_id)
print(f"{node1_id} -> {node2_id} 的度数: {edge_degree}") print(f"{node1_id} -> {node2_id} 的度数: {edge_degree}")
assert ( assert (
edge_degree == 3 edge_degree == 3
), f"{node1_id} -> {node2_id} 的度数应为2,实际为 {edge_degree}" ), f"{node1_id} -> {node2_id} 的度数应为3,实际为 {edge_degree}"
# 3.1 测试反向边的度数 - 验证无向图特性
print(f"== 测试反向边的度数: {node2_id} -> {node1_id}")
reverse_edge_degree = await storage.edge_degree(node2_id, node1_id)
print(f"反向边 {node2_id} -> {node1_id} 的度数: {reverse_edge_degree}")
assert (
edge_degree == reverse_edge_degree
), "正向边和反向边的度数不一致,无向图特性验证失败"
print("无向图特性验证成功:正向边和反向边的度数一致")
# 4. 测试 get_node_edges - 获取节点的所有边 # 4. 测试 get_node_edges - 获取节点的所有边
print(f"== 测试 get_node_edges: {node2_id}") print(f"== 测试 get_node_edges: {node2_id}")
@@ -310,6 +347,31 @@ async def test_graph_advanced(storage):
len(node2_edges) == 2 len(node2_edges) == 2
), f"节点 {node2_id} 应有2条边实际有 {len(node2_edges)}" ), f"节点 {node2_id} 应有2条边实际有 {len(node2_edges)}"
# 4.1 验证节点边的无向图特性
print("== 验证节点边的无向图特性")
# 检查是否包含与node1和node3的连接关系无论方向
has_connection_with_node1 = False
has_connection_with_node3 = False
for edge in node2_edges:
# 检查是否有与node1的连接无论方向
if (edge[0] == node1_id and edge[1] == node2_id) or (
edge[0] == node2_id and edge[1] == node1_id
):
has_connection_with_node1 = True
# 检查是否有与node3的连接无论方向
if (edge[0] == node2_id and edge[1] == node3_id) or (
edge[0] == node3_id and edge[1] == node2_id
):
has_connection_with_node3 = True
assert (
has_connection_with_node1
), f"节点 {node2_id} 的边列表中应包含与 {node1_id} 的连接"
assert (
has_connection_with_node3
), f"节点 {node2_id} 的边列表中应包含与 {node3_id} 的连接"
print(f"无向图特性验证成功:节点 {node2_id} 的边列表包含所有相关的边")
# 5. 测试 get_all_labels - 获取所有标签 # 5. 测试 get_all_labels - 获取所有标签
print("== 测试 get_all_labels") print("== 测试 get_all_labels")
all_labels = await storage.get_all_labels() all_labels = await storage.get_all_labels()
@@ -346,6 +408,15 @@ async def test_graph_advanced(storage):
print(f"删除后查询边属性 {node2_id} -> {node3_id}: {edge_props}") print(f"删除后查询边属性 {node2_id} -> {node3_id}: {edge_props}")
assert edge_props is None, f"{node2_id} -> {node3_id} 应已被删除" assert edge_props is None, f"{node2_id} -> {node3_id} 应已被删除"
# 8.1 验证删除边的无向图特性
print(f"== 验证删除边的无向图特性: {node3_id} -> {node2_id}")
reverse_edge_props = await storage.get_edge(node3_id, node2_id)
print(f"删除后查询反向边属性 {node3_id} -> {node2_id}: {reverse_edge_props}")
assert (
reverse_edge_props is None
), f"反向边 {node3_id} -> {node2_id} 也应被删除,无向图特性验证失败"
print("无向图特性验证成功:删除一个方向的边后,反向边也被删除")
# 9. 测试 remove_nodes - 批量删除节点 # 9. 测试 remove_nodes - 批量删除节点
print(f"== 测试 remove_nodes: [{node2_id}, {node3_id}]") print(f"== 测试 remove_nodes: [{node2_id}, {node3_id}]")
await storage.remove_nodes([node2_id, node3_id]) await storage.remove_nodes([node2_id, node3_id])
@@ -377,6 +448,547 @@ async def test_graph_advanced(storage):
return False return False
async def test_graph_batch_operations(storage):
"""
测试图数据库的批量操作:
1. 使用 get_nodes_batch 批量获取多个节点的属性
2. 使用 node_degrees_batch 批量获取多个节点的度数
3. 使用 edge_degrees_batch 批量获取多个边的度数
4. 使用 get_edges_batch 批量获取多个边的属性
5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
"""
try:
# 清理之前的测试数据
print("清理之前的测试数据...\n")
await storage.drop()
# 1. 插入测试数据
# 插入节点1: 人工智能
node1_id = "人工智能"
node1_data = {
"entity_id": node1_id,
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"keywords": "AI,机器学习,深度学习",
"entity_type": "技术领域",
}
print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
# 插入节点2: 机器学习
node2_id = "机器学习"
node2_data = {
"entity_id": node2_id,
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
"keywords": "监督学习,无监督学习,强化学习",
"entity_type": "技术领域",
}
print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
# 插入节点3: 深度学习
node3_id = "深度学习"
node3_data = {
"entity_id": node3_id,
"description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
"keywords": "神经网络,CNN,RNN",
"entity_type": "技术领域",
}
print(f"插入节点3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
# 插入节点4: 自然语言处理
node4_id = "自然语言处理"
node4_data = {
"entity_id": node4_id,
"description": "自然语言处理是人工智能的一个分支,专注于使计算机理解和处理人类语言。",
"keywords": "NLP,文本分析,语言模型",
"entity_type": "技术领域",
}
print(f"插入节点4: {node4_id}")
await storage.upsert_node(node4_id, node4_data)
# 插入节点5: 计算机视觉
node5_id = "计算机视觉"
node5_data = {
"entity_id": node5_id,
"description": "计算机视觉是人工智能的一个分支,专注于使计算机能够从图像或视频中获取信息。",
"keywords": "CV,图像识别,目标检测",
"entity_type": "技术领域",
}
print(f"插入节点5: {node5_id}")
await storage.upsert_node(node5_id, node5_data)
# 插入边1: 人工智能 -> 机器学习
edge1_data = {
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含机器学习这个子领域",
}
print(f"插入边1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
# 插入边2: 机器学习 -> 深度学习
edge2_data = {
"relationship": "包含",
"weight": 1.0,
"description": "机器学习领域包含深度学习这个子领域",
}
print(f"插入边2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data)
# 插入边3: 人工智能 -> 自然语言处理
edge3_data = {
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含自然语言处理这个子领域",
}
print(f"插入边3: {node1_id} -> {node4_id}")
await storage.upsert_edge(node1_id, node4_id, edge3_data)
# 插入边4: 人工智能 -> 计算机视觉
edge4_data = {
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含计算机视觉这个子领域",
}
print(f"插入边4: {node1_id} -> {node5_id}")
await storage.upsert_edge(node1_id, node5_id, edge4_data)
# 插入边5: 深度学习 -> 自然语言处理
edge5_data = {
"relationship": "应用于",
"weight": 0.8,
"description": "深度学习技术应用于自然语言处理领域",
}
print(f"插入边5: {node3_id} -> {node4_id}")
await storage.upsert_edge(node3_id, node4_id, edge5_data)
# 插入边6: 深度学习 -> 计算机视觉
edge6_data = {
"relationship": "应用于",
"weight": 0.8,
"description": "深度学习技术应用于计算机视觉领域",
}
print(f"插入边6: {node3_id} -> {node5_id}")
await storage.upsert_edge(node3_id, node5_id, edge6_data)
# 2. 测试 get_nodes_batch - 批量获取多个节点的属性
print("== 测试 get_nodes_batch")
node_ids = [node1_id, node2_id, node3_id]
nodes_dict = await storage.get_nodes_batch(node_ids)
print(f"批量获取节点属性结果: {nodes_dict.keys()}")
assert len(nodes_dict) == 3, f"应返回3个节点实际返回 {len(nodes_dict)}"
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
assert (
nodes_dict[node1_id]["description"] == node1_data["description"]
), f"{node1_id} 描述不匹配"
assert (
nodes_dict[node2_id]["description"] == node2_data["description"]
), f"{node2_id} 描述不匹配"
assert (
nodes_dict[node3_id]["description"] == node3_data["description"]
), f"{node3_id} 描述不匹配"
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
print("== 测试 node_degrees_batch")
node_degrees = await storage.node_degrees_batch(node_ids)
print(f"批量获取节点度数结果: {node_degrees}")
assert (
len(node_degrees) == 3
), f"应返回3个节点的度数实际返回 {len(node_degrees)}"
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
assert (
node_degrees[node1_id] == 3
), f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}"
assert (
node_degrees[node2_id] == 2
), f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}"
assert (
node_degrees[node3_id] == 3
), f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}"
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
print("== 测试 edge_degrees_batch")
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
edge_degrees = await storage.edge_degrees_batch(edges)
print(f"批量获取边度数结果: {edge_degrees}")
assert (
len(edge_degrees) == 3
), f"应返回3条边的度数实际返回 {len(edge_degrees)}"
assert (
node1_id,
node2_id,
) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (
node2_id,
node3_id,
) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中"
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
assert (
edge_degrees[(node1_id, node2_id)] == 5
), f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}"
assert (
edge_degrees[(node2_id, node3_id)] == 5
), f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}"
assert (
edge_degrees[(node3_id, node4_id)] == 5
), f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}"
# 5. 测试 get_edges_batch - 批量获取多个边的属性
print("== 测试 get_edges_batch")
# 将元组列表转换为Neo4j风格的字典列表
edge_dicts = [{"src": src, "tgt": tgt} for src, tgt in edges]
edges_dict = await storage.get_edges_batch(edge_dicts)
print(f"批量获取边属性结果: {edges_dict.keys()}")
assert len(edges_dict) == 3, f"应返回3条边的属性实际返回 {len(edges_dict)}"
assert (
node1_id,
node2_id,
) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (
node2_id,
node3_id,
) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中"
assert (
edges_dict[(node1_id, node2_id)]["relationship"]
== edge1_data["relationship"]
), f"{node1_id} -> {node2_id} 关系不匹配"
assert (
edges_dict[(node2_id, node3_id)]["relationship"]
== edge2_data["relationship"]
), f"{node2_id} -> {node3_id} 关系不匹配"
assert (
edges_dict[(node3_id, node4_id)]["relationship"]
== edge5_data["relationship"]
), f"{node3_id} -> {node4_id} 关系不匹配"
# 5.1 测试反向边的批量获取 - 验证无向图特性
print("== 测试反向边的批量获取")
# 创建反向边的字典列表
reverse_edge_dicts = [{"src": tgt, "tgt": src} for src, tgt in edges]
reverse_edges_dict = await storage.get_edges_batch(reverse_edge_dicts)
print(f"批量获取反向边属性结果: {reverse_edges_dict.keys()}")
assert (
len(reverse_edges_dict) == 3
), f"应返回3条反向边的属性实际返回 {len(reverse_edges_dict)}"
# 验证正向和反向边的属性是否一致
for (src, tgt), props in edges_dict.items():
assert (
tgt,
src,
) in reverse_edges_dict, f"反向边 {tgt} -> {src} 应在返回结果中"
assert (
props == reverse_edges_dict[(tgt, src)]
), f"{src} -> {tgt} 和反向边 {tgt} -> {src} 的属性不一致"
print("无向图特性验证成功:批量获取的正向和反向边属性一致")
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
print("== 测试 get_nodes_edges_batch")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
print(f"批量获取节点边结果: {nodes_edges.keys()}")
assert (
len(nodes_edges) == 2
), f"应返回2个节点的边实际返回 {len(nodes_edges)}"
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
assert (
len(nodes_edges[node1_id]) == 3
), f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}"
assert (
len(nodes_edges[node3_id]) == 3
), f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}"
# 6.1 验证批量获取节点边的无向图特性
print("== 验证批量获取节点边的无向图特性")
# 检查节点1的边是否包含所有相关的边无论方向
node1_outgoing_edges = [
(src, tgt) for src, tgt in nodes_edges[node1_id] if src == node1_id
]
node1_incoming_edges = [
(src, tgt) for src, tgt in nodes_edges[node1_id] if tgt == node1_id
]
print(f"节点 {node1_id} 的出边: {node1_outgoing_edges}")
print(f"节点 {node1_id} 的入边: {node1_incoming_edges}")
# 检查是否包含到机器学习、自然语言处理和计算机视觉的边
has_edge_to_node2 = any(tgt == node2_id for _, tgt in node1_outgoing_edges)
has_edge_to_node4 = any(tgt == node4_id for _, tgt in node1_outgoing_edges)
has_edge_to_node5 = any(tgt == node5_id for _, tgt in node1_outgoing_edges)
assert has_edge_to_node2, f"节点 {node1_id} 的边列表中应包含到 {node2_id} 的边"
assert has_edge_to_node4, f"节点 {node1_id} 的边列表中应包含到 {node4_id} 的边"
assert has_edge_to_node5, f"节点 {node1_id} 的边列表中应包含到 {node5_id} 的边"
# 检查节点3的边是否包含所有相关的边无论方向
node3_outgoing_edges = [
(src, tgt) for src, tgt in nodes_edges[node3_id] if src == node3_id
]
node3_incoming_edges = [
(src, tgt) for src, tgt in nodes_edges[node3_id] if tgt == node3_id
]
print(f"节点 {node3_id} 的出边: {node3_outgoing_edges}")
print(f"节点 {node3_id} 的入边: {node3_incoming_edges}")
# 检查是否包含与机器学习、自然语言处理和计算机视觉的连接(忽略方向)
has_connection_with_node2 = any(
(src == node2_id and tgt == node3_id)
or (src == node3_id and tgt == node2_id)
for src, tgt in nodes_edges[node3_id]
)
has_connection_with_node4 = any(
(src == node3_id and tgt == node4_id)
or (src == node4_id and tgt == node3_id)
for src, tgt in nodes_edges[node3_id]
)
has_connection_with_node5 = any(
(src == node3_id and tgt == node5_id)
or (src == node5_id and tgt == node3_id)
for src, tgt in nodes_edges[node3_id]
)
assert (
has_connection_with_node2
), f"节点 {node3_id} 的边列表中应包含与 {node2_id} 的连接"
assert (
has_connection_with_node4
), f"节点 {node3_id} 的边列表中应包含与 {node4_id} 的连接"
assert (
has_connection_with_node5
), f"节点 {node3_id} 的边列表中应包含与 {node5_id} 的连接"
print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
# 7. 清理数据
print("== 测试 drop")
result = await storage.drop()
print(f"清理结果: {result}")
assert (
result["status"] == "success"
), f"清理应成功,实际状态为 {result['status']}"
print("\n批量操作测试完成")
return True
except Exception as e:
ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
return False
async def test_graph_undirected_property(storage):
"""
专门测试图存储的无向图特性:
1. 验证插入一个方向的边后,反向查询是否能获得相同的结果
2. 验证边的属性在正向和反向查询中是否一致
3. 验证删除一个方向的边后,另一个方向的边是否也被删除
4. 验证批量操作中的无向图特性
"""
try:
# 清理之前的测试数据
print("清理之前的测试数据...\n")
await storage.drop()
# 1. 插入测试数据
# 插入节点1: 计算机科学
node1_id = "计算机科学"
node1_data = {
"entity_id": node1_id,
"description": "计算机科学是研究计算机及其应用的科学。",
"keywords": "计算机,科学,技术",
"entity_type": "学科",
}
print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
# 插入节点2: 数据结构
node2_id = "数据结构"
node2_data = {
"entity_id": node2_id,
"description": "数据结构是计算机科学中的一个基础概念,用于组织和存储数据。",
"keywords": "数据,结构,组织",
"entity_type": "概念",
}
print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
# 插入节点3: 算法
node3_id = "算法"
node3_data = {
"entity_id": node3_id,
"description": "算法是解决问题的步骤和方法。",
"keywords": "算法,步骤,方法",
"entity_type": "概念",
}
print(f"插入节点3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
# 2. 测试插入边后的无向图特性
print("\n== 测试插入边后的无向图特性")
# 插入边1: 计算机科学 -> 数据结构
edge1_data = {
"relationship": "包含",
"weight": 1.0,
"description": "计算机科学包含数据结构这个概念",
}
print(f"插入边1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
# 验证正向查询
forward_edge = await storage.get_edge(node1_id, node2_id)
print(f"正向边属性: {forward_edge}")
assert forward_edge is not None, f"未能读取正向边属性: {node1_id} -> {node2_id}"
# 验证反向查询
reverse_edge = await storage.get_edge(node2_id, node1_id)
print(f"反向边属性: {reverse_edge}")
assert reverse_edge is not None, f"未能读取反向边属性: {node2_id} -> {node1_id}"
# 验证正向和反向边属性是否一致
assert (
forward_edge == reverse_edge
), "正向和反向边属性不一致,无向图特性验证失败"
print("无向图特性验证成功:正向和反向边属性一致")
# 3. 测试边的度数的无向图特性
print("\n== 测试边的度数的无向图特性")
# 插入边2: 计算机科学 -> 算法
edge2_data = {
"relationship": "包含",
"weight": 1.0,
"description": "计算机科学包含算法这个概念",
}
print(f"插入边2: {node1_id} -> {node3_id}")
await storage.upsert_edge(node1_id, node3_id, edge2_data)
# 验证正向和反向边的度数
forward_degree = await storage.edge_degree(node1_id, node2_id)
reverse_degree = await storage.edge_degree(node2_id, node1_id)
print(f"正向边 {node1_id} -> {node2_id} 的度数: {forward_degree}")
print(f"反向边 {node2_id} -> {node1_id} 的度数: {reverse_degree}")
assert (
forward_degree == reverse_degree
), "正向和反向边的度数不一致,无向图特性验证失败"
print("无向图特性验证成功:正向和反向边的度数一致")
# 4. 测试删除边的无向图特性
print("\n== 测试删除边的无向图特性")
# 删除正向边
print(f"删除边: {node1_id} -> {node2_id}")
await storage.remove_edges([(node1_id, node2_id)])
# 验证正向边是否被删除
forward_edge = await storage.get_edge(node1_id, node2_id)
print(f"删除后查询正向边属性 {node1_id} -> {node2_id}: {forward_edge}")
assert forward_edge is None, f"{node1_id} -> {node2_id} 应已被删除"
# 验证反向边是否也被删除
reverse_edge = await storage.get_edge(node2_id, node1_id)
print(f"删除后查询反向边属性 {node2_id} -> {node1_id}: {reverse_edge}")
assert (
reverse_edge is None
), f"反向边 {node2_id} -> {node1_id} 也应被删除,无向图特性验证失败"
print("无向图特性验证成功:删除一个方向的边后,反向边也被删除")
# 5. 测试批量操作中的无向图特性
print("\n== 测试批量操作中的无向图特性")
# 重新插入边
await storage.upsert_edge(node1_id, node2_id, edge1_data)
# 批量获取边属性
edge_dicts = [
{"src": node1_id, "tgt": node2_id},
{"src": node1_id, "tgt": node3_id},
]
reverse_edge_dicts = [
{"src": node2_id, "tgt": node1_id},
{"src": node3_id, "tgt": node1_id},
]
edges_dict = await storage.get_edges_batch(edge_dicts)
reverse_edges_dict = await storage.get_edges_batch(reverse_edge_dicts)
print(f"批量获取正向边属性结果: {edges_dict.keys()}")
print(f"批量获取反向边属性结果: {reverse_edges_dict.keys()}")
# 验证正向和反向边的属性是否一致
for (src, tgt), props in edges_dict.items():
assert (
tgt,
src,
) in reverse_edges_dict, f"反向边 {tgt} -> {src} 应在返回结果中"
assert (
props == reverse_edges_dict[(tgt, src)]
), f"{src} -> {tgt} 和反向边 {tgt} -> {src} 的属性不一致"
print("无向图特性验证成功:批量获取的正向和反向边属性一致")
# 6. 测试批量获取节点边的无向图特性
print("\n== 测试批量获取节点边的无向图特性")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node2_id])
print(f"批量获取节点边结果: {nodes_edges.keys()}")
# 检查节点1的边是否包含所有相关的边无论方向
node1_edges = nodes_edges[node1_id]
node2_edges = nodes_edges[node2_id]
# 检查节点1是否有到节点2和节点3的边
has_edge_to_node2 = any(
(src == node1_id and tgt == node2_id) for src, tgt in node1_edges
)
has_edge_to_node3 = any(
(src == node1_id and tgt == node3_id) for src, tgt in node1_edges
)
assert has_edge_to_node2, f"节点 {node1_id} 的边列表中应包含到 {node2_id} 的边"
assert has_edge_to_node3, f"节点 {node1_id} 的边列表中应包含到 {node3_id} 的边"
# 检查节点2是否有到节点1的边
has_edge_to_node1 = any(
(src == node2_id and tgt == node1_id)
or (src == node1_id and tgt == node2_id)
for src, tgt in node2_edges
)
assert (
has_edge_to_node1
), f"节点 {node2_id} 的边列表中应包含与 {node1_id} 的连接"
print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
# 7. 清理数据
print("== 测试 drop")
result = await storage.drop()
print(f"清理结果: {result}")
assert (
result["status"] == "success"
), f"清理应成功,实际状态为 {result['status']}"
print("\n无向图特性测试完成")
return True
except Exception as e:
ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
return False
async def main(): async def main():
"""主函数""" """主函数"""
# 显示程序标题 # 显示程序标题
@@ -411,21 +1023,35 @@ async def main():
ASCIIColors.yellow("\n请选择测试类型:") ASCIIColors.yellow("\n请选择测试类型:")
ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)") ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)") ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
ASCIIColors.white("3. 全部测试") ASCIIColors.white("3. 批量操作测试 (批量获取节点、边属性和度数等)")
ASCIIColors.white("4. 无向图特性测试 (验证存储的无向图特性)")
ASCIIColors.white("5. 全部测试")
choice = input("\n请输入选项 (1/2/3): ") choice = input("\n请输入选项 (1/2/3/4/5): ")
if choice == "1": if choice == "1":
await test_graph_basic(storage) await test_graph_basic(storage)
elif choice == "2": elif choice == "2":
await test_graph_advanced(storage) await test_graph_advanced(storage)
elif choice == "3": elif choice == "3":
await test_graph_batch_operations(storage)
elif choice == "4":
await test_graph_undirected_property(storage)
elif choice == "5":
ASCIIColors.cyan("\n=== 开始基本测试 ===") ASCIIColors.cyan("\n=== 开始基本测试 ===")
basic_result = await test_graph_basic(storage) basic_result = await test_graph_basic(storage)
if basic_result: if basic_result:
ASCIIColors.cyan("\n=== 开始高级测试 ===") ASCIIColors.cyan("\n=== 开始高级测试 ===")
await test_graph_advanced(storage) advanced_result = await test_graph_advanced(storage)
if advanced_result:
ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
batch_result = await test_graph_batch_operations(storage)
if batch_result:
ASCIIColors.cyan("\n=== 开始无向图特性测试 ===")
await test_graph_undirected_property(storage)
else: else:
ASCIIColors.red("无效的选项") ASCIIColors.red("无效的选项")