Changed node label from 'Entity' to 'base' and fix edge deletion error in PostgreSQL AGE graph

This commit is contained in:
yangdx
2025-04-02 14:03:56 +08:00
parent fc3208cf5b
commit 554d290993

View File

@@ -1258,7 +1258,7 @@ class PGGraphStorage(BaseGraphStorage):
entity_name_label = self._encode_graph_label(node_id.strip('"')) entity_name_label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:base {node_id: "%s"})
RETURN count(n) > 0 AS node_exists RETURN count(n) > 0 AS node_exists
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
@@ -1271,7 +1271,7 @@ class PGGraphStorage(BaseGraphStorage):
tgt_label = self._encode_graph_label(target_node_id.strip('"')) tgt_label = self._encode_graph_label(target_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"})
RETURN COUNT(r) > 0 AS edge_exists RETURN COUNT(r) > 0 AS edge_exists
$$) AS (edge_exists bool)""" % ( $$) AS (edge_exists bool)""" % (
self.graph_name, self.graph_name,
@@ -1286,7 +1286,7 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
label = self._encode_graph_label(node_id.strip('"')) label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:base {node_id: "%s"})
RETURN n RETURN n
$$) AS (n agtype)""" % (self.graph_name, label) $$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query) record = await self._query(query)
@@ -1301,7 +1301,7 @@ class PGGraphStorage(BaseGraphStorage):
label = self._encode_graph_label(node_id.strip('"')) label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})-[]->(x) MATCH (n:base {node_id: "%s"})-[]->(x)
RETURN count(x) AS total_edge_count RETURN count(x) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label) $$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0] record = (await self._query(query))[0]
@@ -1329,7 +1329,7 @@ class PGGraphStorage(BaseGraphStorage):
tgt_label = self._encode_graph_label(target_node_id.strip('"')) tgt_label = self._encode_graph_label(target_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
$$) AS (edge_properties agtype)""" % ( $$) AS (edge_properties agtype)""" % (
@@ -1351,8 +1351,8 @@ class PGGraphStorage(BaseGraphStorage):
label = self._encode_graph_label(source_node_id.strip('"')) label = self._encode_graph_label(source_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:base {node_id: "%s"})
OPTIONAL MATCH (n)-[]-(connected) OPTIONAL MATCH (n)-[]-(connected:base)
RETURN n, connected RETURN n, connected
$$) AS (n agtype, connected agtype)""" % ( $$) AS (n agtype, connected agtype)""" % (
self.graph_name, self.graph_name,
@@ -1396,7 +1396,7 @@ class PGGraphStorage(BaseGraphStorage):
properties = node_data properties = node_data
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MERGE (n:Entity {node_id: "%s"}) MERGE (n:base {node_id: "%s"})
SET n += %s SET n += %s
RETURN n RETURN n
$$) AS (n agtype)""" % ( $$) AS (n agtype)""" % (
@@ -1433,9 +1433,9 @@ class PGGraphStorage(BaseGraphStorage):
edge_properties = edge_data edge_properties = edge_data
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (source:Entity {node_id: "%s"}) MATCH (source:base {node_id: "%s"})
WITH source WITH source
MATCH (target:Entity {node_id: "%s"}) MATCH (target:base {node_id: "%s"})
MERGE (source)-[r:DIRECTED]->(target) MERGE (source)-[r:DIRECTED]->(target)
SET r += %s SET r += %s
RETURN r RETURN r
@@ -1466,7 +1466,7 @@ class PGGraphStorage(BaseGraphStorage):
label = self._encode_graph_label(node_id.strip('"')) label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:base {entity_id: "%s"})
DETACH DELETE n DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, label) $$) AS (n agtype)""" % (self.graph_name, label)
@@ -1489,8 +1489,8 @@ class PGGraphStorage(BaseGraphStorage):
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity) MATCH (n:base)
WHERE n.node_id IN [%s] WHERE n.nentity_id IN [%s]
DETACH DELETE n DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, node_id_list) $$) AS (n agtype)""" % (self.graph_name, node_id_list)
@@ -1507,26 +1507,21 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
""" """
encoded_edges = [ for source, target in edges:
( src_label = self._encode_graph_label(source.strip('"'))
self._encode_graph_label(src.strip('"')), tgt_label = self._encode_graph_label(target.strip('"'))
self._encode_graph_label(tgt.strip('"')),
)
for src, tgt in edges
]
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity)-[r]->(b:Entity) MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
WHERE [a.node_id, b.node_id] IN [%s] DELETE r
DELETE r $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
$$) AS (r agtype)""" % (self.graph_name, edge_list)
try: try:
await self._query(query, readonly=False) await self._query(query, readonly=False)
except Exception as e: logger.debug(f"Deleted edge from '{source}' to '{target}'")
logger.error("Error during edge removal: {%s}", e) except Exception as e:
raise logger.error(f"Error during edge deletion: {str(e)}")
raise
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """
@@ -1537,8 +1532,10 @@ class PGGraphStorage(BaseGraphStorage):
""" """
query = ( query = (
"""SELECT * FROM cypher('%s', $$ """SELECT * FROM cypher('%s', $$
MATCH (n:Entity) MATCH (n:base)
RETURN DISTINCT n.node_id AS label WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
$$) AS (label text)""" $$) AS (label text)"""
% self.graph_name % self.graph_name
) )
@@ -1584,15 +1581,15 @@ class PGGraphStorage(BaseGraphStorage):
# Build the query based on whether we want the full graph or a specific subgraph. # Build the query based on whether we want the full graph or a specific subgraph.
if node_label == "*": if node_label == "*":
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity) MATCH (n:base)
OPTIONAL MATCH (n)-[r]->(m:Entity) OPTIONAL MATCH (n)-[r]->(m:base)
RETURN n, r, m RETURN n, r, m
LIMIT {MAX_GRAPH_NODES} LIMIT {MAX_GRAPH_NODES}
$$) AS (n agtype, r agtype, m agtype)""" $$) AS (n agtype, r agtype, m agtype)"""
else: else:
encoded_label = self._encode_graph_label(node_label.strip('"')) encoded_label = self._encode_graph_label(node_label.strip('"'))
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity {{node_id: "{encoded_label}"}}) MATCH (n:base {{entity_id: "{encoded_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT {MAX_GRAPH_NODES} LIMIT {MAX_GRAPH_NODES}