diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index eec042a2..66449567 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1037,6 +1037,23 @@ class PGGraphStorage(BaseGraphStorage): def __post_init__(self): self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") self.db: PostgreSQLDB | None = None + + @staticmethod + def _normalize_node_id(node_id: str) -> str: + """ + Normalize node ID to ensure special characters are properly handled in Cypher queries. + + Args: + node_id: The original node ID + + Returns: + Normalized node ID suitable for Cypher queries + """ + # Remove quotes + normalized_id = node_id.strip('"') + # Escape backslashes + normalized_id = normalized_id.replace("\\", "\\\\") + return normalized_id async def initialize(self): if self.db is None: @@ -1222,7 +1239,7 @@ class PGGraphStorage(BaseGraphStorage): return result async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') + entity_name_label = self._normalize_node_id(node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"}) @@ -1234,8 +1251,8 @@ class PGGraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = source_node_id.strip('"') - tgt_label = target_node_id.strip('"') + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) @@ -1253,7 +1270,7 @@ class PGGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties""" - label = node_id.strip('"') + label = self._normalize_node_id(node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"}) RETURN n @@ -1267,7 +1284,7 @@ class PGGraphStorage(BaseGraphStorage): return None async def node_degree(self, node_id: str) -> int: - label = node_id.strip('"') + label = self._normalize_node_id(node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"})-[r]-() @@ -1295,8 +1312,8 @@ class PGGraphStorage(BaseGraphStorage): ) -> dict[str, str] | None: """Get edge properties between two nodes""" - src_label = source_node_id.strip('"') - tgt_label = target_node_id.strip('"') + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) @@ -1318,7 +1335,7 @@ class PGGraphStorage(BaseGraphStorage): Retrieves all edges (relationships) for a particular node identified by its label. :return: list of dictionaries containing edge information """ - label = source_node_id.strip('"') + label = self._normalize_node_id(source_node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"}) @@ -1367,7 +1384,7 @@ class PGGraphStorage(BaseGraphStorage): "PostgreSQL: node properties must contain an 'entity_id' field" ) - label = node_id.strip('"') + label = self._normalize_node_id(node_id) properties = self._format_properties(node_data) query = """SELECT * FROM cypher('%s', $$ @@ -1403,8 +1420,8 @@ class PGGraphStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): dictionary of properties to set on the edge """ - src_label = source_node_id.strip('"') - tgt_label = target_node_id.strip('"') + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) edge_properties = self._format_properties(edge_data) query = """SELECT * FROM cypher('%s', $$ @@ -1437,7 +1454,7 @@ class PGGraphStorage(BaseGraphStorage): Args: node_id (str): The ID of the node to delete. """ - label = node_id.strip('"') + label = self._normalize_node_id(node_id) query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"}) @@ -1457,7 +1474,7 @@ class PGGraphStorage(BaseGraphStorage): Args: node_ids (list[str]): A list of node IDs to remove. """ - node_ids = [node_id.strip('"') for node_id in node_ids] + node_ids = [self._normalize_node_id(node_id) for node_id in node_ids] node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) query = """SELECT * FROM cypher('%s', $$ @@ -1480,8 +1497,8 @@ class PGGraphStorage(BaseGraphStorage): edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). """ for source, target in edges: - src_label = source.strip('"') - tgt_label = target.strip('"') + src_label = self._normalize_node_id(source) + tgt_label = self._normalize_node_id(target) query = """SELECT * FROM cypher('%s', $$ MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) @@ -1510,7 +1527,7 @@ class PGGraphStorage(BaseGraphStorage): # Format node IDs for the query formatted_ids = ", ".join( - ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] + ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] ) query = """SELECT * FROM cypher('%s', $$ @@ -1553,7 +1570,7 @@ class PGGraphStorage(BaseGraphStorage): # Format node IDs for the query formatted_ids = ", ".join( - ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] + ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] ) outgoing_query = """SELECT * FROM cypher('%s', $$ @@ -1650,8 +1667,8 @@ class PGGraphStorage(BaseGraphStorage): src_nodes = [] tgt_nodes = [] for pair in pairs: - src_nodes.append(pair["src"].replace('"', "")) - tgt_nodes.append(pair["tgt"].replace('"', "")) + src_nodes.append(self._normalize_node_id(pair["src"])) + tgt_nodes.append(self._normalize_node_id(pair["tgt"])) src_array = ", ".join([f'"{src}"' for src in src_nodes]) tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) @@ -1706,7 +1723,7 @@ class PGGraphStorage(BaseGraphStorage): # Format node IDs for the query formatted_ids = ", ".join( - ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] + ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] ) outgoing_query = """SELECT * FROM cypher('%s', $$ @@ -1794,7 +1811,7 @@ class PGGraphStorage(BaseGraphStorage): RETURN count(distinct n) AS total_nodes $$) AS (total_nodes bigint)""" else: - strip_label = node_label.strip('"') + strip_label = self._normalize_node_id(node_label) count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ MATCH (n:base {{entity_id: "{strip_label}"}}) OPTIONAL MATCH p = (n)-[*..{max_depth}]-() @@ -1814,7 +1831,7 @@ class PGGraphStorage(BaseGraphStorage): LIMIT {max_nodes} $$) AS (n agtype, r agtype)""" else: - strip_label = node_label.strip('"') + strip_label = self._normalize_node_id(node_label) if total_nodes > 0: query = f"""SELECT * FROM cypher('{self.graph_name}', $$ MATCH (node:base {{entity_id: "{strip_label}"}})