fix: add node_id normalization to handle backslash escaping in PostgreSQL Cypher queries
This commit is contained in:
@@ -1037,6 +1037,23 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
||||||
self.db: PostgreSQLDB | None = None
|
self.db: PostgreSQLDB | None = None
|
||||||
|
|
||||||
|
@staticmethod
def _normalize_node_id(node_id: str) -> str:
    """
    Normalize a node ID so it can be embedded in a double-quoted Cypher string.

    Callers interpolate the returned value between double quotes inside the
    Cypher text (e.g. ``{entity_id: "%s"}``), so every character that would
    terminate or alter that string literal has to be neutralised here.

    Args:
        node_id: The original node ID, optionally wrapped in double quotes.

    Returns:
        The node ID without surrounding double quotes, with backslashes and
        embedded double quotes backslash-escaped.
    """
    # Drop the quotes that wrap the ID (leading/trailing only).
    normalized_id = node_id.strip('"')
    # Escape backslashes first so the backslashes introduced for the quotes
    # below are not doubled a second time.
    normalized_id = normalized_id.replace("\\", "\\\\")
    # Escape embedded double quotes; left as-is they would close the Cypher
    # string literal early and break (or inject into) the query.
    normalized_id = normalized_id.replace('"', '\\"')
    # NOTE(review): the callers wrap the Cypher text in $$ ... $$; an ID that
    # itself contains "$$" is not handled here - confirm whether that can occur.
    return normalized_id
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
if self.db is None:
|
||||||
@@ -1222,7 +1239,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = self._normalize_node_id(node_id)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
@@ -1234,8 +1251,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return single_result["node_exists"]
|
return single_result["node_exists"]
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
src_label = source_node_id.strip('"')
|
src_label = self._normalize_node_id(source_node_id)
|
||||||
tgt_label = target_node_id.strip('"')
|
tgt_label = self._normalize_node_id(target_node_id)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||||
@@ -1253,7 +1270,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier, return only node properties"""
|
"""Get node by its label identifier, return only node properties"""
|
||||||
|
|
||||||
label = node_id.strip('"')
|
label = self._normalize_node_id(node_id)
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
RETURN n
|
RETURN n
|
||||||
@@ -1267,7 +1284,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
label = node_id.strip('"')
|
label = self._normalize_node_id(node_id)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})-[r]-()
|
MATCH (n:base {entity_id: "%s"})-[r]-()
|
||||||
@@ -1295,8 +1312,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
"""Get edge properties between two nodes"""
|
"""Get edge properties between two nodes"""
|
||||||
|
|
||||||
src_label = source_node_id.strip('"')
|
src_label = self._normalize_node_id(source_node_id)
|
||||||
tgt_label = target_node_id.strip('"')
|
tgt_label = self._normalize_node_id(target_node_id)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||||
@@ -1318,7 +1335,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
:return: list of dictionaries containing edge information
|
:return: list of dictionaries containing edge information
|
||||||
"""
|
"""
|
||||||
label = source_node_id.strip('"')
|
label = self._normalize_node_id(source_node_id)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
@@ -1367,7 +1384,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
"PostgreSQL: node properties must contain an 'entity_id' field"
|
"PostgreSQL: node properties must contain an 'entity_id' field"
|
||||||
)
|
)
|
||||||
|
|
||||||
label = node_id.strip('"')
|
label = self._normalize_node_id(node_id)
|
||||||
properties = self._format_properties(node_data)
|
properties = self._format_properties(node_data)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
@@ -1403,8 +1420,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
target_node_id (str): Label of the target node (used as identifier)
|
target_node_id (str): Label of the target node (used as identifier)
|
||||||
edge_data (dict): dictionary of properties to set on the edge
|
edge_data (dict): dictionary of properties to set on the edge
|
||||||
"""
|
"""
|
||||||
src_label = source_node_id.strip('"')
|
src_label = self._normalize_node_id(source_node_id)
|
||||||
tgt_label = target_node_id.strip('"')
|
tgt_label = self._normalize_node_id(target_node_id)
|
||||||
edge_properties = self._format_properties(edge_data)
|
edge_properties = self._format_properties(edge_data)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
@@ -1437,7 +1454,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Args:
|
Args:
|
||||||
node_id (str): The ID of the node to delete.
|
node_id (str): The ID of the node to delete.
|
||||||
"""
|
"""
|
||||||
label = node_id.strip('"')
|
label = self._normalize_node_id(node_id)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
@@ -1457,7 +1474,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Args:
|
Args:
|
||||||
node_ids (list[str]): A list of node IDs to remove.
|
node_ids (list[str]): A list of node IDs to remove.
|
||||||
"""
|
"""
|
||||||
node_ids = [node_id.strip('"') for node_id in node_ids]
|
node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
|
||||||
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
@@ -1480,8 +1497,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
||||||
"""
|
"""
|
||||||
for source, target in edges:
|
for source, target in edges:
|
||||||
src_label = source.strip('"')
|
src_label = self._normalize_node_id(source)
|
||||||
tgt_label = target.strip('"')
|
tgt_label = self._normalize_node_id(target)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||||
@@ -1510,7 +1527,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(
|
formatted_ids = ", ".join(
|
||||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||||
)
|
)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
@@ -1553,7 +1570,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(
|
formatted_ids = ", ".join(
|
||||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||||
)
|
)
|
||||||
|
|
||||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||||
@@ -1650,8 +1667,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
src_nodes = []
|
src_nodes = []
|
||||||
tgt_nodes = []
|
tgt_nodes = []
|
||||||
for pair in pairs:
|
for pair in pairs:
|
||||||
src_nodes.append(pair["src"].replace('"', ""))
|
src_nodes.append(self._normalize_node_id(pair["src"]))
|
||||||
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
tgt_nodes.append(self._normalize_node_id(pair["tgt"]))
|
||||||
|
|
||||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
||||||
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
||||||
@@ -1706,7 +1723,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(
|
formatted_ids = ", ".join(
|
||||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||||
)
|
)
|
||||||
|
|
||||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||||
@@ -1794,7 +1811,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
RETURN count(distinct n) AS total_nodes
|
RETURN count(distinct n) AS total_nodes
|
||||||
$$) AS (total_nodes bigint)"""
|
$$) AS (total_nodes bigint)"""
|
||||||
else:
|
else:
|
||||||
strip_label = node_label.strip('"')
|
strip_label = self._normalize_node_id(node_label)
|
||||||
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
MATCH (n:base {{entity_id: "{strip_label}"}})
|
MATCH (n:base {{entity_id: "{strip_label}"}})
|
||||||
OPTIONAL MATCH p = (n)-[*..{max_depth}]-()
|
OPTIONAL MATCH p = (n)-[*..{max_depth}]-()
|
||||||
@@ -1814,7 +1831,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
LIMIT {max_nodes}
|
LIMIT {max_nodes}
|
||||||
$$) AS (n agtype, r agtype)"""
|
$$) AS (n agtype, r agtype)"""
|
||||||
else:
|
else:
|
||||||
strip_label = node_label.strip('"')
|
strip_label = self._normalize_node_id(node_label)
|
||||||
if total_nodes > 0:
|
if total_nodes > 0:
|
||||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
MATCH (node:base {{entity_id: "{strip_label}"}})
|
MATCH (node:base {{entity_id: "{strip_label}"}})
|
||||||
|
Reference in New Issue
Block a user