fix: add node_id normalization to handle backslash escaping in PostgreSQL Cypher queries

This commit is contained in:
yangdx
2025-04-17 02:31:56 +08:00
parent d4c4a40c53
commit 70d1eab9e7

View File

@@ -1038,6 +1038,23 @@ class PGGraphStorage(BaseGraphStorage):
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
self.db: PostgreSQLDB | None = None
@staticmethod
def _normalize_node_id(node_id: str) -> str:
    """
    Normalize a node ID so it can be embedded in a double-quoted Cypher string literal.

    The callers interpolate the returned value between double quotes
    (e.g. ``{entity_id: "%s"}``), so every character that is special inside
    such a literal has to be neutralised here.

    Args:
        node_id: The original node ID, optionally wrapped in double quotes

    Returns:
        The node ID with surrounding double quotes removed, backslashes
        doubled and any remaining (embedded) double quotes backslash-escaped
    """
    # Drop the wrapping quotes some callers pass along with the ID
    normalized_id = node_id.strip('"')
    # Escape backslashes first, so the backslashes introduced for quotes
    # in the next step are not doubled again
    normalized_id = normalized_id.replace("\\", "\\\\")
    # Escape embedded double quotes; left raw they would terminate the
    # Cypher string literal early (syntax error / query injection)
    normalized_id = normalized_id.replace('"', '\\"')
    # NOTE(review): a literal "$$" inside node_id would still end the
    # dollar-quoted cypher() body used by the callers - not handled here.
    return normalized_id
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
@@ -1222,7 +1239,7 @@ class PGGraphStorage(BaseGraphStorage):
return result
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
entity_name_label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
@@ -1234,8 +1251,8 @@ class PGGraphStorage(BaseGraphStorage):
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = source_node_id.strip('"')
tgt_label = target_node_id.strip('"')
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
@@ -1253,7 +1270,7 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
label = node_id.strip('"')
label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
RETURN n
@@ -1267,7 +1284,7 @@ class PGGraphStorage(BaseGraphStorage):
return None
async def node_degree(self, node_id: str) -> int:
label = node_id.strip('"')
label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})-[r]-()
@@ -1295,8 +1312,8 @@ class PGGraphStorage(BaseGraphStorage):
) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
src_label = source_node_id.strip('"')
tgt_label = target_node_id.strip('"')
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
@@ -1318,7 +1335,7 @@ class PGGraphStorage(BaseGraphStorage):
Retrieves all edges (relationships) for a particular node identified by its label.
:return: list of dictionaries containing edge information
"""
label = source_node_id.strip('"')
label = self._normalize_node_id(source_node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
@@ -1367,7 +1384,7 @@ class PGGraphStorage(BaseGraphStorage):
"PostgreSQL: node properties must contain an 'entity_id' field"
)
label = node_id.strip('"')
label = self._normalize_node_id(node_id)
properties = self._format_properties(node_data)
query = """SELECT * FROM cypher('%s', $$
@@ -1403,8 +1420,8 @@ class PGGraphStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): dictionary of properties to set on the edge
"""
src_label = source_node_id.strip('"')
tgt_label = target_node_id.strip('"')
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
edge_properties = self._format_properties(edge_data)
query = """SELECT * FROM cypher('%s', $$
@@ -1437,7 +1454,7 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_id (str): The ID of the node to delete.
"""
label = node_id.strip('"')
label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
@@ -1457,7 +1474,7 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_ids (list[str]): A list of node IDs to remove.
"""
node_ids = [node_id.strip('"') for node_id in node_ids]
node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
query = """SELECT * FROM cypher('%s', $$
@@ -1480,8 +1497,8 @@ class PGGraphStorage(BaseGraphStorage):
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
"""
for source, target in edges:
src_label = source.strip('"')
tgt_label = target.strip('"')
src_label = self._normalize_node_id(source)
tgt_label = self._normalize_node_id(target)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
@@ -1510,7 +1527,7 @@ class PGGraphStorage(BaseGraphStorage):
# Format node IDs for the query
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$
@@ -1553,7 +1570,7 @@ class PGGraphStorage(BaseGraphStorage):
# Format node IDs for the query
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
)
outgoing_query = """SELECT * FROM cypher('%s', $$
@@ -1650,8 +1667,8 @@ class PGGraphStorage(BaseGraphStorage):
src_nodes = []
tgt_nodes = []
for pair in pairs:
src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', ""))
src_nodes.append(self._normalize_node_id(pair["src"]))
tgt_nodes.append(self._normalize_node_id(pair["tgt"]))
src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
@@ -1706,7 +1723,7 @@ class PGGraphStorage(BaseGraphStorage):
# Format node IDs for the query
formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
)
outgoing_query = """SELECT * FROM cypher('%s', $$
@@ -1794,7 +1811,7 @@ class PGGraphStorage(BaseGraphStorage):
RETURN count(distinct n) AS total_nodes
$$) AS (total_nodes bigint)"""
else:
strip_label = node_label.strip('"')
strip_label = self._normalize_node_id(node_label)
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: "{strip_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-()
@@ -1814,7 +1831,7 @@ class PGGraphStorage(BaseGraphStorage):
LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
else:
strip_label = node_label.strip('"')
strip_label = self._normalize_node_id(node_label)
if total_nodes > 0:
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base {{entity_id: "{strip_label}"}})