fix: add node_id normalization to handle backslash escaping in PostgreSQL Cypher queries
This commit is contained in:
@@ -1037,6 +1037,23 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
def __post_init__(self):
    """Pick the graph name and start with no database handle attached.

    The instance namespace is used as the graph name when it is truthy;
    otherwise the ``AGE_GRAPH_NAME`` environment variable is consulted,
    falling back to ``"lightrag"``.
    """
    graph_name = self.namespace
    if not graph_name:
        # No usable namespace: fall back to the environment, then the default.
        graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
    self.graph_name = graph_name

    # The database handle is assigned later; it starts out unset here.
    self.db: PostgreSQLDB | None = None
|
||||
|
||||
@staticmethod
def _normalize_node_id(node_id: str) -> str:
    """
    Normalize node ID to ensure special characters are properly handled in Cypher queries.

    The result is meant to be placed between double quotes in a Cypher
    string literal (the queries in this class use the form
    ``{entity_id: "%s"}``). Surrounding double quotes are stripped, then
    backslashes and any remaining embedded double quotes are
    backslash-escaped so they cannot end the literal early.

    Args:
        node_id: The original node ID

    Returns:
        Normalized node ID suitable for Cypher queries
    """
    # Remove surrounding quotes
    normalized_id = node_id.strip('"')
    # Escape backslashes first, so the backslashes introduced for quotes
    # below are not themselves doubled.
    normalized_id = normalized_id.replace("\\", "\\\\")
    # Escape embedded double quotes; left raw they would close the
    # double-quoted Cypher literal and corrupt the query.
    normalized_id = normalized_id.replace('"', '\\"')
    return normalized_id
|
||||
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
@@ -1222,7 +1239,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
return result
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
entity_name_label = self._normalize_node_id(node_id)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:base {entity_id: "%s"})
|
||||
@@ -1234,8 +1251,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
return single_result["node_exists"]
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
src_label = source_node_id.strip('"')
|
||||
tgt_label = target_node_id.strip('"')
|
||||
src_label = self._normalize_node_id(source_node_id)
|
||||
tgt_label = self._normalize_node_id(target_node_id)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||
@@ -1253,7 +1270,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier, return only node properties"""
|
||||
|
||||
label = node_id.strip('"')
|
||||
label = self._normalize_node_id(node_id)
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:base {entity_id: "%s"})
|
||||
RETURN n
|
||||
@@ -1267,7 +1284,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
return None
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
label = node_id.strip('"')
|
||||
label = self._normalize_node_id(node_id)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:base {entity_id: "%s"})-[r]-()
|
||||
@@ -1295,8 +1312,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
) -> dict[str, str] | None:
|
||||
"""Get edge properties between two nodes"""
|
||||
|
||||
src_label = source_node_id.strip('"')
|
||||
tgt_label = target_node_id.strip('"')
|
||||
src_label = self._normalize_node_id(source_node_id)
|
||||
tgt_label = self._normalize_node_id(target_node_id)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||
@@ -1318,7 +1335,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
:return: list of dictionaries containing edge information
|
||||
"""
|
||||
label = source_node_id.strip('"')
|
||||
label = self._normalize_node_id(source_node_id)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:base {entity_id: "%s"})
|
||||
@@ -1367,7 +1384,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
"PostgreSQL: node properties must contain an 'entity_id' field"
|
||||
)
|
||||
|
||||
label = node_id.strip('"')
|
||||
label = self._normalize_node_id(node_id)
|
||||
properties = self._format_properties(node_data)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
@@ -1403,8 +1420,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
target_node_id (str): Label of the target node (used as identifier)
|
||||
edge_data (dict): dictionary of properties to set on the edge
|
||||
"""
|
||||
src_label = source_node_id.strip('"')
|
||||
tgt_label = target_node_id.strip('"')
|
||||
src_label = self._normalize_node_id(source_node_id)
|
||||
tgt_label = self._normalize_node_id(target_node_id)
|
||||
edge_properties = self._format_properties(edge_data)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
@@ -1437,7 +1454,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Args:
|
||||
node_id (str): The ID of the node to delete.
|
||||
"""
|
||||
label = node_id.strip('"')
|
||||
label = self._normalize_node_id(node_id)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:base {entity_id: "%s"})
|
||||
@@ -1457,7 +1474,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Args:
|
||||
node_ids (list[str]): A list of node IDs to remove.
|
||||
"""
|
||||
node_ids = [node_id.strip('"') for node_id in node_ids]
|
||||
node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
|
||||
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
@@ -1480,8 +1497,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
||||
"""
|
||||
for source, target in edges:
|
||||
src_label = source.strip('"')
|
||||
tgt_label = target.strip('"')
|
||||
src_label = self._normalize_node_id(source)
|
||||
tgt_label = self._normalize_node_id(target)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||
@@ -1510,7 +1527,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
@@ -1553,7 +1570,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||
@@ -1650,8 +1667,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
src_nodes = []
|
||||
tgt_nodes = []
|
||||
for pair in pairs:
|
||||
src_nodes.append(pair["src"].replace('"', ""))
|
||||
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
||||
src_nodes.append(self._normalize_node_id(pair["src"]))
|
||||
tgt_nodes.append(self._normalize_node_id(pair["tgt"]))
|
||||
|
||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
||||
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
||||
@@ -1706,7 +1723,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||
@@ -1794,7 +1811,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
RETURN count(distinct n) AS total_nodes
|
||||
$$) AS (total_nodes bigint)"""
|
||||
else:
|
||||
strip_label = node_label.strip('"')
|
||||
strip_label = self._normalize_node_id(node_label)
|
||||
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:base {{entity_id: "{strip_label}"}})
|
||||
OPTIONAL MATCH p = (n)-[*..{max_depth}]-()
|
||||
@@ -1814,7 +1831,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
LIMIT {max_nodes}
|
||||
$$) AS (n agtype, r agtype)"""
|
||||
else:
|
||||
strip_label = node_label.strip('"')
|
||||
strip_label = self._normalize_node_id(node_label)
|
||||
if total_nodes > 0:
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (node:base {{entity_id: "{strip_label}"}})
|
||||
|
Reference in New Issue
Block a user