From 554d290993794908e94a18ddda1b2a93d151a35d Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 2 Apr 2025 14:03:56 +0800 Subject: [PATCH] Changed node label from 'Entity' to 'base' and fix edge deletion error in PostgreSQL AGE graph --- lightrag/kg/postgres_impl.py | 69 +++++++++++++++++------------------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a23aa8da..837d48ed 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1258,7 +1258,7 @@ class PGGraphStorage(BaseGraphStorage): entity_name_label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) + MATCH (n:base {node_id: "%s"}) RETURN count(n) > 0 AS node_exists $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) @@ -1271,7 +1271,7 @@ class PGGraphStorage(BaseGraphStorage): tgt_label = self._encode_graph_label(target_node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) + MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"}) RETURN COUNT(r) > 0 AS edge_exists $$) AS (edge_exists bool)""" % ( self.graph_name, @@ -1286,7 +1286,7 @@ class PGGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> dict[str, str] | None: label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) + MATCH (n:base {node_id: "%s"}) RETURN n $$) AS (n agtype)""" % (self.graph_name, label) record = await self._query(query) @@ -1301,7 +1301,7 @@ class PGGraphStorage(BaseGraphStorage): label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"})-[]->(x) + MATCH (n:base {node_id: "%s"})-[]->(x) RETURN count(x) AS total_edge_count $$) AS (total_edge_count integer)""" % (self.graph_name, label) record = (await self._query(query))[0] @@ -1329,7 +1329,7 @@ class PGGraphStorage(BaseGraphStorage): tgt_label = self._encode_graph_label(target_node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) + MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"}) RETURN properties(r) as edge_properties LIMIT 1 $$) AS (edge_properties agtype)""" % ( @@ -1351,8 +1351,8 @@ class PGGraphStorage(BaseGraphStorage): label = self._encode_graph_label(source_node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) - OPTIONAL MATCH (n)-[]-(connected) + MATCH (n:base {node_id: "%s"}) + OPTIONAL MATCH (n)-[]-(connected:base) RETURN n, connected $$) AS (n agtype, connected agtype)""" % ( self.graph_name, @@ -1396,7 +1396,7 @@ class PGGraphStorage(BaseGraphStorage): properties = node_data query = """SELECT * FROM cypher('%s', $$ - MERGE (n:Entity {node_id: "%s"}) + MERGE (n:base {node_id: "%s"}) SET n += %s RETURN n $$) AS (n agtype)""" % ( @@ -1433,9 +1433,9 @@ class PGGraphStorage(BaseGraphStorage): edge_properties = edge_data query = """SELECT * FROM cypher('%s', $$ - MATCH (source:Entity {node_id: "%s"}) + MATCH (source:base {node_id: "%s"}) WITH source - MATCH (target:Entity {node_id: "%s"}) + MATCH (target:base {node_id: "%s"}) MERGE (source)-[r:DIRECTED]->(target) SET r += %s RETURN r @@ -1466,7 +1466,7 @@ class PGGraphStorage(BaseGraphStorage): label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) DETACH DELETE n $$) AS (n agtype)""" % (self.graph_name, label) @@ -1489,8 +1489,8 @@ class PGGraphStorage(BaseGraphStorage): node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - WHERE n.node_id IN [%s] + MATCH (n:base) + WHERE n.nentity_id IN [%s] DETACH DELETE n $$) AS (n agtype)""" % (self.graph_name, node_id_list) @@ -1507,26 +1507,21 @@ class PGGraphStorage(BaseGraphStorage): Args: edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). """ - encoded_edges = [ - ( - self._encode_graph_label(src.strip('"')), - self._encode_graph_label(tgt.strip('"')), - ) - for src, tgt in edges - ] - edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) + for source, target in edges: + src_label = self._encode_graph_label(source.strip('"')) + tgt_label = self._encode_graph_label(target.strip('"')) - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity)-[r]->(b:Entity) - WHERE [a.node_id, b.node_id] IN [%s] - DELETE r - $$) AS (r agtype)""" % (self.graph_name, edge_list) + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"}) + DELETE r + $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) - try: - await self._query(query, readonly=False) - except Exception as e: - logger.error("Error during edge removal: {%s}", e) - raise + try: + await self._query(query, readonly=False) + logger.debug(f"Deleted edge from '{source}' to '{target}'") + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def get_all_labels(self) -> list[str]: """ @@ -1537,8 +1532,10 @@ class PGGraphStorage(BaseGraphStorage): """ query = ( """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - RETURN DISTINCT n.node_id AS label + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label $$) AS (label text)""" % self.graph_name ) @@ -1584,15 +1581,15 @@ class PGGraphStorage(BaseGraphStorage): # Build the query based on whether we want the full graph or a specific subgraph. if node_label == "*": query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:Entity) - OPTIONAL MATCH (n)-[r]->(m:Entity) + MATCH (n:base) + OPTIONAL MATCH (n)-[r]->(m:base) RETURN n, r, m LIMIT {MAX_GRAPH_NODES} $$) AS (n agtype, r agtype, m agtype)""" else: encoded_label = self._encode_graph_label(node_label.strip('"')) query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:Entity {{node_id: "{encoded_label}"}}) + MATCH (n:base {{entity_id: "{encoded_label}"}}) OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) RETURN nodes(p) AS nodes, relationships(p) AS relationships LIMIT {MAX_GRAPH_NODES}