Changed node label from 'Entity' to 'base' and fix edge deletion error in PostgreSQL AGE graph
This commit is contained in:
@@ -1258,7 +1258,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
entity_name_label = self._encode_graph_label(node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
MATCH (n:base {node_id: "%s"})
|
||||
RETURN count(n) > 0 AS node_exists
|
||||
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
||||
|
||||
@@ -1271,7 +1271,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
|
||||
MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"})
|
||||
RETURN COUNT(r) > 0 AS edge_exists
|
||||
$$) AS (edge_exists bool)""" % (
|
||||
self.graph_name,
|
||||
@@ -1286,7 +1286,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
label = self._encode_graph_label(node_id.strip('"'))
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
MATCH (n:base {node_id: "%s"})
|
||||
RETURN n
|
||||
$$) AS (n agtype)""" % (self.graph_name, label)
|
||||
record = await self._query(query)
|
||||
@@ -1301,7 +1301,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
label = self._encode_graph_label(node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})-[]->(x)
|
||||
MATCH (n:base {node_id: "%s"})-[]->(x)
|
||||
RETURN count(x) AS total_edge_count
|
||||
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
||||
record = (await self._query(query))[0]
|
||||
@@ -1329,7 +1329,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
|
||||
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
$$) AS (edge_properties agtype)""" % (
|
||||
@@ -1351,8 +1351,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
label = self._encode_graph_label(source_node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
OPTIONAL MATCH (n)-[]-(connected)
|
||||
MATCH (n:base {node_id: "%s"})
|
||||
OPTIONAL MATCH (n)-[]-(connected:base)
|
||||
RETURN n, connected
|
||||
$$) AS (n agtype, connected agtype)""" % (
|
||||
self.graph_name,
|
||||
@@ -1396,7 +1396,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
properties = node_data
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MERGE (n:Entity {node_id: "%s"})
|
||||
MERGE (n:base {node_id: "%s"})
|
||||
SET n += %s
|
||||
RETURN n
|
||||
$$) AS (n agtype)""" % (
|
||||
@@ -1433,9 +1433,9 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
edge_properties = edge_data
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (source:Entity {node_id: "%s"})
|
||||
MATCH (source:base {node_id: "%s"})
|
||||
WITH source
|
||||
MATCH (target:Entity {node_id: "%s"})
|
||||
MATCH (target:base {node_id: "%s"})
|
||||
MERGE (source)-[r:DIRECTED]->(target)
|
||||
SET r += %s
|
||||
RETURN r
|
||||
@@ -1466,7 +1466,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
label = self._encode_graph_label(node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
MATCH (n:base {entity_id: "%s"})
|
||||
DETACH DELETE n
|
||||
$$) AS (n agtype)""" % (self.graph_name, label)
|
||||
|
||||
@@ -1489,8 +1489,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity)
|
||||
WHERE n.node_id IN [%s]
|
||||
MATCH (n:base)
|
||||
WHERE n.nentity_id IN [%s]
|
||||
DETACH DELETE n
|
||||
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
||||
|
||||
@@ -1507,26 +1507,21 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Args:
|
||||
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
||||
"""
|
||||
encoded_edges = [
|
||||
(
|
||||
self._encode_graph_label(src.strip('"')),
|
||||
self._encode_graph_label(tgt.strip('"')),
|
||||
)
|
||||
for src, tgt in edges
|
||||
]
|
||||
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
|
||||
for source, target in edges:
|
||||
src_label = self._encode_graph_label(source.strip('"'))
|
||||
tgt_label = self._encode_graph_label(target.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:Entity)-[r]->(b:Entity)
|
||||
WHERE [a.node_id, b.node_id] IN [%s]
|
||||
DELETE r
|
||||
$$) AS (r agtype)""" % (self.graph_name, edge_list)
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
|
||||
DELETE r
|
||||
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
||||
|
||||
try:
|
||||
await self._query(query, readonly=False)
|
||||
except Exception as e:
|
||||
logger.error("Error during edge removal: {%s}", e)
|
||||
raise
|
||||
try:
|
||||
await self._query(query, readonly=False)
|
||||
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge deletion: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
@@ -1537,8 +1532,10 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
query = (
|
||||
"""SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity)
|
||||
RETURN DISTINCT n.node_id AS label
|
||||
MATCH (n:base)
|
||||
WHERE n.entity_id IS NOT NULL
|
||||
RETURN DISTINCT n.entity_id AS label
|
||||
ORDER BY label
|
||||
$$) AS (label text)"""
|
||||
% self.graph_name
|
||||
)
|
||||
@@ -1584,15 +1581,15 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# Build the query based on whether we want the full graph or a specific subgraph.
|
||||
if node_label == "*":
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:Entity)
|
||||
OPTIONAL MATCH (n)-[r]->(m:Entity)
|
||||
MATCH (n:base)
|
||||
OPTIONAL MATCH (n)-[r]->(m:base)
|
||||
RETURN n, r, m
|
||||
LIMIT {MAX_GRAPH_NODES}
|
||||
$$) AS (n agtype, r agtype, m agtype)"""
|
||||
else:
|
||||
encoded_label = self._encode_graph_label(node_label.strip('"'))
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:Entity {{node_id: "{encoded_label}"}})
|
||||
MATCH (n:base {{entity_id: "{encoded_label}"}})
|
||||
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
||||
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
||||
LIMIT {MAX_GRAPH_NODES}
|
||||
|
Reference in New Issue
Block a user