fix: add node_id normalization to handle backslash escaping in PostgreSQL Cypher queries

This commit is contained in:
yangdx
2025-04-17 02:31:56 +08:00
parent d4c4a40c53
commit 70d1eab9e7

View File

@@ -1037,6 +1037,23 @@ class PGGraphStorage(BaseGraphStorage):
def __post_init__(self): def __post_init__(self):
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
self.db: PostgreSQLDB | None = None self.db: PostgreSQLDB | None = None
@staticmethod
def _normalize_node_id(node_id: str) -> str:
    """
    Normalize a node ID so it can be embedded in a double-quoted Cypher string literal.

    Wrapping double quotes are removed, then backslashes and any remaining
    (embedded) double quotes are backslash-escaped. Callers interpolate the
    result between double quotes (e.g. ``{entity_id: "%s"}``), so an unescaped
    embedded quote would terminate the literal early and break or alter the query.

    Args:
        node_id: The original node ID

    Returns:
        Normalized node ID suitable for Cypher queries
    """
    # Remove wrapping quotes
    normalized_id = node_id.strip('"')
    # Escape backslashes first, so the backslashes added for quotes below
    # are not themselves doubled
    normalized_id = normalized_id.replace("\\", "\\\\")
    # Escape embedded double quotes so they stay inside the string literal
    normalized_id = normalized_id.replace('"', '\\"')
    # NOTE(review): the callers wrap their Cypher text in $$ ... $$; an ID that
    # itself contains "$$" is not handled here - confirm whether that can occur.
    return normalized_id
async def initialize(self): async def initialize(self):
if self.db is None: if self.db is None:
@@ -1222,7 +1239,7 @@ class PGGraphStorage(BaseGraphStorage):
return result return result
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"') entity_name_label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
@@ -1234,8 +1251,8 @@ class PGGraphStorage(BaseGraphStorage):
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = source_node_id.strip('"') src_label = self._normalize_node_id(source_node_id)
tgt_label = target_node_id.strip('"') tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
@@ -1253,7 +1270,7 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties""" """Get node by its label identifier, return only node properties"""
label = node_id.strip('"') label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
RETURN n RETURN n
@@ -1267,7 +1284,7 @@ class PGGraphStorage(BaseGraphStorage):
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
label = node_id.strip('"') label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})-[r]-() MATCH (n:base {entity_id: "%s"})-[r]-()
@@ -1295,8 +1312,8 @@ class PGGraphStorage(BaseGraphStorage):
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get edge properties between two nodes""" """Get edge properties between two nodes"""
src_label = source_node_id.strip('"') src_label = self._normalize_node_id(source_node_id)
tgt_label = target_node_id.strip('"') tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
@@ -1318,7 +1335,7 @@ class PGGraphStorage(BaseGraphStorage):
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: list of dictionaries containing edge information :return: list of dictionaries containing edge information
""" """
label = source_node_id.strip('"') label = self._normalize_node_id(source_node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
@@ -1367,7 +1384,7 @@ class PGGraphStorage(BaseGraphStorage):
"PostgreSQL: node properties must contain an 'entity_id' field" "PostgreSQL: node properties must contain an 'entity_id' field"
) )
label = node_id.strip('"') label = self._normalize_node_id(node_id)
properties = self._format_properties(node_data) properties = self._format_properties(node_data)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1403,8 +1420,8 @@ class PGGraphStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): dictionary of properties to set on the edge edge_data (dict): dictionary of properties to set on the edge
""" """
src_label = source_node_id.strip('"') src_label = self._normalize_node_id(source_node_id)
tgt_label = target_node_id.strip('"') tgt_label = self._normalize_node_id(target_node_id)
edge_properties = self._format_properties(edge_data) edge_properties = self._format_properties(edge_data)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1437,7 +1454,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_id (str): The ID of the node to delete. node_id (str): The ID of the node to delete.
""" """
label = node_id.strip('"') label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
@@ -1457,7 +1474,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_ids (list[str]): A list of node IDs to remove. node_ids (list[str]): A list of node IDs to remove.
""" """
node_ids = [node_id.strip('"') for node_id in node_ids] node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1480,8 +1497,8 @@ class PGGraphStorage(BaseGraphStorage):
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
""" """
for source, target in edges: for source, target in edges:
src_label = source.strip('"') src_label = self._normalize_node_id(source)
tgt_label = target.strip('"') tgt_label = self._normalize_node_id(target)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
@@ -1510,7 +1527,7 @@ class PGGraphStorage(BaseGraphStorage):
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join( formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
) )
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1553,7 +1570,7 @@ class PGGraphStorage(BaseGraphStorage):
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join( formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
) )
outgoing_query = """SELECT * FROM cypher('%s', $$ outgoing_query = """SELECT * FROM cypher('%s', $$
@@ -1650,8 +1667,8 @@ class PGGraphStorage(BaseGraphStorage):
src_nodes = [] src_nodes = []
tgt_nodes = [] tgt_nodes = []
for pair in pairs: for pair in pairs:
src_nodes.append(pair["src"].replace('"', "")) src_nodes.append(self._normalize_node_id(pair["src"]))
tgt_nodes.append(pair["tgt"].replace('"', "")) tgt_nodes.append(self._normalize_node_id(pair["tgt"]))
src_array = ", ".join([f'"{src}"' for src in src_nodes]) src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
@@ -1706,7 +1723,7 @@ class PGGraphStorage(BaseGraphStorage):
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join( formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
) )
outgoing_query = """SELECT * FROM cypher('%s', $$ outgoing_query = """SELECT * FROM cypher('%s', $$
@@ -1794,7 +1811,7 @@ class PGGraphStorage(BaseGraphStorage):
RETURN count(distinct n) AS total_nodes RETURN count(distinct n) AS total_nodes
$$) AS (total_nodes bigint)""" $$) AS (total_nodes bigint)"""
else: else:
strip_label = node_label.strip('"') strip_label = self._normalize_node_id(node_label)
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: "{strip_label}"}}) MATCH (n:base {{entity_id: "{strip_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-() OPTIONAL MATCH p = (n)-[*..{max_depth}]-()
@@ -1814,7 +1831,7 @@ class PGGraphStorage(BaseGraphStorage):
LIMIT {max_nodes} LIMIT {max_nodes}
$$) AS (n agtype, r agtype)""" $$) AS (n agtype, r agtype)"""
else: else:
strip_label = node_label.strip('"') strip_label = self._normalize_node_id(node_label)
if total_nodes > 0: if total_nodes > 0:
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base {{entity_id: "{strip_label}"}}) MATCH (node:base {{entity_id: "{strip_label}"}})