Refactoring PostgreSQL AGE graph db implementation

This commit is contained in:
yangdx
2025-04-03 15:16:48 +08:00
parent 8878c0e998
commit 33f5629d8a

View File

@@ -1064,31 +1064,11 @@ class PGGraphStorage(BaseGraphStorage):
if v.startswith("[") and v.endswith("]"): if v.startswith("[") and v.endswith("]"):
if "::vertex" in v: if "::vertex" in v:
v = v.replace("::vertex", "") v = v.replace("::vertex", "")
vertexes = json.loads(v) d[k] = json.loads(v)
dl = []
for vertex in vertexes:
prop = vertex.get("properties")
if not prop:
prop = {}
prop["label"] = PGGraphStorage._decode_graph_label(
prop["node_id"]
)
dl.append(prop)
d[k] = dl
elif "::edge" in v: elif "::edge" in v:
v = v.replace("::edge", "") v = v.replace("::edge", "")
edges = json.loads(v) d[k] = json.loads(v)
dl = []
for edge in edges:
dl.append(
(
vertices[edge["start_id"]],
edge["label"],
vertices[edge["end_id"]],
)
)
d[k] = dl
else: else:
print("WARNING: unsupported type") print("WARNING: unsupported type")
continue continue
@@ -1097,26 +1077,9 @@ class PGGraphStorage(BaseGraphStorage):
dtype = v.split("::")[-1] dtype = v.split("::")[-1]
v = v.split("::")[0] v = v.split("::")[0]
if dtype == "vertex": if dtype == "vertex":
vertex = json.loads(v) d[k] = json.loads(v)
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(
field["node_id"]
)
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge": elif dtype == "edge":
edge = json.loads(v) d[k] = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
else: else:
d[k] = ( d[k] = (
json.loads(v) json.loads(v)
@@ -1152,56 +1115,6 @@ class PGGraphStorage(BaseGraphStorage):
) )
return "{" + ", ".join(props) + "}" return "{" + ", ".join(props) + "}"
@staticmethod
def _encode_graph_label(label: str) -> str:
"""
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
Args:
label (str): the original label
Returns:
str: the encoded label
"""
return "x" + label.encode().hex()
@staticmethod
def _decode_graph_label(encoded_label: str) -> str:
"""
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
Args:
encoded_label (str): the encoded label
Returns:
str: the decoded label
"""
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
@staticmethod
def _get_col_name(field: str, idx: int) -> str:
"""
Convert a cypher return field to a pgsql select field
If possible keep the cypher column name, but create a generic name if necessary
Args:
field (str): a return field from a cypher query to be formatted for pgsql
idx (int): the position of the field in the return statement
Returns:
str: the field to be used in the pgsql select statement
"""
# remove white space
field = field.strip()
# if an alias is provided for the field, use it
if " as " in field:
return field.split(" as ")[-1].strip()
# if the return value is an unnamed primitive, give it a generic name
if field.isnumeric() or field in ("true", "false", "null"):
return f"column_{idx}"
# otherwise return the value stripping out some common special chars
return field.replace("(", "_").replace(")", "")
async def _query( async def _query(
self, self,
query: str, query: str,
@@ -1252,10 +1165,10 @@ class PGGraphStorage(BaseGraphStorage):
return result return result
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = self._encode_graph_label(node_id.strip('"')) entity_name_label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {node_id: "%s"}) MATCH (n:base {entity_id: "%s"})
RETURN count(n) > 0 AS node_exists RETURN count(n) > 0 AS node_exists
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
@@ -1264,11 +1177,11 @@ class PGGraphStorage(BaseGraphStorage):
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = self._encode_graph_label(source_node_id.strip('"')) src_label = source_node_id.strip('"')
tgt_label = self._encode_graph_label(target_node_id.strip('"')) tgt_label = target_node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
RETURN COUNT(r) > 0 AS edge_exists RETURN COUNT(r) > 0 AS edge_exists
$$) AS (edge_exists bool)""" % ( $$) AS (edge_exists bool)""" % (
self.graph_name, self.graph_name,
@@ -1281,13 +1194,14 @@ class PGGraphStorage(BaseGraphStorage):
return single_result["edge_exists"] return single_result["edge_exists"]
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
label = self._encode_graph_label(node_id.strip('"')) label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {node_id: "%s"}) MATCH (n:base {entity_id: "%s"})
RETURN n RETURN n
$$) AS (n agtype)""" % (self.graph_name, label) $$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query) record = await self._query(query)
if record: if record:
print(f"Record: {record}")
node = record[0] node = record[0]
node_dict = node["n"] node_dict = node["n"]
@@ -1295,10 +1209,10 @@ class PGGraphStorage(BaseGraphStorage):
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
label = self._encode_graph_label(node_id.strip('"')) label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {node_id: "%s"})-[]->(x) MATCH (n:base {entity_id: "%s"})-[]->(x)
RETURN count(x) AS total_edge_count RETURN count(x) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label) $$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0] record = (await self._query(query))[0]
@@ -1322,11 +1236,11 @@ class PGGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
src_label = self._encode_graph_label(source_node_id.strip('"')) src_label = source_node_id.strip('"')
tgt_label = self._encode_graph_label(target_node_id.strip('"')) tgt_label = target_node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
$$) AS (edge_properties agtype)""" % ( $$) AS (edge_properties agtype)""" % (
@@ -1336,6 +1250,7 @@ class PGGraphStorage(BaseGraphStorage):
) )
record = await self._query(query) record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]: if record and record[0] and record[0]["edge_properties"]:
print(f"Record: {record}")
result = record[0]["edge_properties"] result = record[0]["edge_properties"]
return result return result
@@ -1345,10 +1260,10 @@ class PGGraphStorage(BaseGraphStorage):
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: list of dictionaries containing edge information :return: list of dictionaries containing edge information
""" """
label = self._encode_graph_label(source_node_id.strip('"')) label = source_node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {node_id: "%s"}) MATCH (n:base {entity_id: "%s"})
OPTIONAL MATCH (n)-[]-(connected:base) OPTIONAL MATCH (n)-[]-(connected:base)
RETURN n, connected RETURN n, connected
$$) AS (n agtype, connected agtype)""" % ( $$) AS (n agtype, connected agtype)""" % (
@@ -1362,24 +1277,17 @@ class PGGraphStorage(BaseGraphStorage):
source_node = record["n"] if record["n"] else None source_node = record["n"] if record["n"] else None
connected_node = record["connected"] if record["connected"] else None connected_node = record["connected"] if record["connected"] else None
source_label = ( if (
source_node["node_id"] source_node
if source_node and source_node["node_id"] and connected_node
else None and "properties" in source_node
) and "properties" in connected_node
target_label = ( ):
connected_node["node_id"] source_label = source_node["properties"].get("entity_id")
if connected_node and connected_node["node_id"] target_label = connected_node["properties"].get("entity_id")
else None
)
if source_label and target_label: if source_label and target_label:
edges.append( edges.append((source_label, target_label))
(
self._decode_graph_label(source_label),
self._decode_graph_label(target_label),
)
)
return edges return edges
@@ -1389,17 +1297,17 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)), retry=retry_if_exception_type((PGGraphQueryException,)),
) )
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
label = self._encode_graph_label(node_id.strip('"')) label = node_id.strip('"')
properties = node_data properties = self._format_properties(node_data)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MERGE (n:base {node_id: "%s"}) MERGE (n:base {entity_id: "%s"})
SET n += %s SET n += %s
RETURN n RETURN n
$$) AS (n agtype)""" % ( $$) AS (n agtype)""" % (
self.graph_name, self.graph_name,
label, label,
self._format_properties(properties), properties,
) )
try: try:
@@ -1425,14 +1333,14 @@ class PGGraphStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): dictionary of properties to set on the edge edge_data (dict): dictionary of properties to set on the edge
""" """
src_label = self._encode_graph_label(source_node_id.strip('"')) src_label = source_node_id.strip('"')
tgt_label = self._encode_graph_label(target_node_id.strip('"')) tgt_label = target_node_id.strip('"')
edge_properties = edge_data edge_properties = self._format_properties(edge_data)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (source:base {node_id: "%s"}) MATCH (source:base {entity_id: "%s"})
WITH source WITH source
MATCH (target:base {node_id: "%s"}) MATCH (target:base {entity_id: "%s"})
MERGE (source)-[r:DIRECTED]->(target) MERGE (source)-[r:DIRECTED]->(target)
SET r += %s SET r += %s
RETURN r RETURN r
@@ -1440,7 +1348,7 @@ class PGGraphStorage(BaseGraphStorage):
self.graph_name, self.graph_name,
src_label, src_label,
tgt_label, tgt_label,
self._format_properties(edge_properties), edge_properties,
) )
try: try:
@@ -1460,7 +1368,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_id (str): The ID of the node to delete. node_id (str): The ID of the node to delete.
""" """
label = self._encode_graph_label(node_id.strip('"')) label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
@@ -1480,14 +1388,12 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_ids (list[str]): A list of node IDs to remove. node_ids (list[str]): A list of node IDs to remove.
""" """
encoded_node_ids = [ node_ids = [node_id.strip('"') for node_id in node_ids]
self._encode_graph_label(node_id.strip('"')) for node_id in node_ids node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
]
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base) MATCH (n:base)
WHERE n.nentity_id IN [%s] WHERE n.entity_id IN [%s]
DETACH DELETE n DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, node_id_list) $$) AS (n agtype)""" % (self.graph_name, node_id_list)
@@ -1505,11 +1411,11 @@ class PGGraphStorage(BaseGraphStorage):
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
""" """
for source, target in edges: for source, target in edges:
src_label = self._encode_graph_label(source.strip('"')) src_label = source.strip('"')
tgt_label = self._encode_graph_label(target.strip('"')) tgt_label = target.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"}) MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
DELETE r DELETE r
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
@@ -1560,95 +1466,98 @@ class PGGraphStorage(BaseGraphStorage):
return await embed_func() return await embed_func()
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args: Args:
node_label (str): The label of the node to start from. If "*", the entire graph is returned. node_label: Label of the starting node, * means all nodes
max_depth (int): The maximum depth to traverse from the starting node. max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
Returns: Returns:
KnowledgeGraph: The retrieved subgraph. KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
""" """
MAX_GRAPH_NODES = 1000
# Build the query based on whether we want the full graph or a specific subgraph. # Build the query based on whether we want the full graph or a specific subgraph.
if node_label == "*": if node_label == "*":
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base) MATCH (n:base)
OPTIONAL MATCH (n)-[r]->(m:base) OPTIONAL MATCH (n)-[r]->(target:base)
RETURN n, r, m RETURN collect(distinct n) AS n, collect(distinct r) AS r
LIMIT {MAX_GRAPH_NODES} LIMIT {MAX_GRAPH_NODES}
$$) AS (n agtype, r agtype, m agtype)""" $$) AS (n agtype, r agtype)"""
else: else:
encoded_label = self._encode_graph_label(node_label.strip('"')) strip_label = node_label.strip('"')
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: "{encoded_label}"}}) MATCH (n:base {{entity_id: "{strip_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS n, relationships(p) AS r
LIMIT {MAX_GRAPH_NODES} LIMIT {max_nodes}
$$) AS (nodes agtype, relationships agtype)""" $$) AS (n agtype, r agtype)"""
results = await self._query(query) results = await self._query(query)
nodes = {} # Process the query results with deduplication by node and edge IDs
edges = [] nodes_dict = {}
unique_edge_ids = set() edges_dict = {}
for result in results:
def add_node(node_data: dict): # Handle single node cases
node_id = self._decode_graph_label(node_data["node_id"]) if result.get("n") and isinstance(result["n"], dict):
if node_id not in nodes: node_id = str(result["n"]["id"])
nodes[node_id] = node_data if node_id not in nodes_dict:
nodes_dict[node_id] = KnowledgeGraphNode(
def add_edge(edge_data: list): id=node_id,
src_id = self._decode_graph_label(edge_data[0]["node_id"]) labels=[result["n"]["properties"]["entity_id"]],
tgt_id = self._decode_graph_label(edge_data[2]["node_id"]) properties=result["n"]["properties"],
edge_key = f"{src_id},{tgt_id}"
if edge_key not in unique_edge_ids:
unique_edge_ids.add(edge_key)
edges.append(
(
edge_key,
src_id,
tgt_id,
{"source": edge_data[0], "target": edge_data[2]},
) )
) # Handle node list cases
elif result.get("n") and isinstance(result["n"], list):
for node in result["n"]:
if isinstance(node, dict) and "id" in node:
node_id = str(node["id"])
if node_id not in nodes_dict and "properties" in node:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node["properties"]["entity_id"]],
properties=node["properties"],
)
# Process the query results. # Handle single edge cases
if node_label == "*": if result.get("r") and isinstance(result["r"], dict):
for result in results: edge_id = str(result["r"]["id"])
if result.get("n"): if edge_id not in edges_dict:
add_node(result["n"]) edges_dict[edge_id] = KnowledgeGraphEdge(
if result.get("m"): id=edge_id,
add_node(result["m"]) type="DIRECTED",
if result.get("r"): source=str(result["r"]["start_id"]),
add_edge(result["r"]) target=str(result["r"]["end_id"]),
else: properties=result["r"]["properties"],
for result in results: )
for node in result.get("nodes", []): # Handle edge list cases
add_node(node) elif result.get("r") and isinstance(result["r"], list):
for edge in result.get("relationships", []): for edge in result["r"]:
add_edge(edge) if isinstance(edge, dict) and "id" in edge:
edge_id = str(edge["id"])
if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(edge["start_id"]),
target=str(edge["end_id"]),
properties=edge["properties"],
)
# Construct and return the KnowledgeGraph. # Construct and return the KnowledgeGraph with deduplicated nodes and edges
kg = KnowledgeGraph( kg = KnowledgeGraph(
nodes=[ nodes=list(nodes_dict.values()),
KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data) edges=list(edges_dict.values()),
for node_id, node_data in nodes.items() is_truncated=False,
],
edges=[
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=src,
target=tgt,
properties=props,
)
for edge_id, src, tgt, props in edges
],
) )
return kg return kg