Merge pull request #1087 from JoramMillenaar/fix--postgres-impl

Fixed Postgres query parsing issues
This commit is contained in:
zrguo
2025-03-17 15:57:41 +08:00
committed by GitHub

View File

@@ -962,14 +962,7 @@ class PGGraphStorage(BaseGraphStorage):
vertices.get(edge["end_id"], {}), vertices.get(edge["end_id"], {}),
) )
else: else:
if v is None: d[k] = json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v
d[k] = v
elif isinstance(v, str) and (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
elif isinstance(v, str):
d[k] = json.loads(v)
else:
d[k] = v
return d return d
@@ -1411,9 +1404,7 @@ class PGGraphStorage(BaseGraphStorage):
embed_func = self._node_embed_algorithms[algorithm] embed_func = self._node_embed_algorithms[algorithm]
return await embed_func() return await embed_func()
async def get_knowledge_graph( async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
""" """
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
@@ -1426,29 +1417,22 @@ class PGGraphStorage(BaseGraphStorage):
""" """
MAX_GRAPH_NODES = 1000 MAX_GRAPH_NODES = 1000
# Build the query based on whether we want the full graph or a specific subgraph.
if node_label == "*": if node_label == "*":
query = """SELECT * FROM cypher('%s', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity) MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity) OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m RETURN n, r, m
LIMIT %d LIMIT {MAX_GRAPH_NODES}
$$) AS (n agtype, r agtype, m agtype)""" % ( $$) AS (n agtype, r agtype, m agtype)"""
self.graph_name,
MAX_GRAPH_NODES,
)
else: else:
encoded_node_label = self._encode_graph_label(node_label.strip('"')) encoded_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:Entity {{node_id: "{encoded_label}"}})
OPTIONAL MATCH p = (n)-[*..%d]-(m) OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d LIMIT {MAX_GRAPH_NODES}
$$) AS (nodes agtype, relationships agtype)""" % ( $$) AS (nodes agtype, relationships agtype)"""
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
results = await self._query(query) results = await self._query(query)
@@ -1456,61 +1440,48 @@ class PGGraphStorage(BaseGraphStorage):
edges = [] edges = []
unique_edge_ids = set() unique_edge_ids = set()
for result in results: def add_node(node_data: dict):
if node_label == "*": node_id = self._decode_graph_label(node_data["node_id"])
if result["n"]: if node_id not in nodes:
node = result["n"] nodes[node_id] = node_data
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["m"]: def add_edge(edge_data: list):
node = result["m"] src_id = self._decode_graph_label(edge_data[0]["node_id"])
node_id = self._decode_graph_label(node["node_id"]) tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
if node_id not in nodes: edge_key = f"{src_id},{tgt_id}"
nodes[node_id] = node if edge_key not in unique_edge_ids:
if result["r"]: unique_edge_ids.add(edge_key)
edge = result["r"] edges.append((edge_key, src_id, tgt_id, {"source": edge_data[0], "target": edge_data[2]}))
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
else:
if result["nodes"]:
for node in result["nodes"]:
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["relationships"]: # Process the query results.
for edge in result["relationships"]: # src --DIRECTED--> target if node_label == "*":
src_id = self._decode_graph_label(edge[0]["node_id"]) for result in results:
tgt_id = self._decode_graph_label(edge[2]["node_id"]) if result.get("n"):
id = src_id + "," + tgt_id add_node(result["n"])
if id in unique_edge_ids: if result.get("m"):
continue add_node(result["m"])
else: if result.get("r"):
unique_edge_ids.add(id) add_edge(result["r"])
edges.append( else:
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]}) for result in results:
) for node in result.get("nodes", []):
add_node(node)
for edge in result.get("relationships", []):
add_edge(edge)
# Construct and return the KnowledgeGraph.
kg = KnowledgeGraph( kg = KnowledgeGraph(
nodes=[ nodes=[
KnowledgeGraphNode( KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
id=node_id, labels=[node_id], properties=nodes[node_id] for node_id, node_data in nodes.items()
)
for node_id in nodes
], ],
edges=[ edges=[
KnowledgeGraphEdge( KnowledgeGraphEdge(id=edge_id, type="DIRECTED", source=src, target=tgt, properties=props)
id=id, type="DIRECTED", source=src, target=tgt, properties=props for edge_id, src, tgt, props in edges
)
for id, src, tgt, props in edges
], ],
) )
return kg return kg
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the storage""" """Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"] drop_sql = SQL_TEMPLATES["drop_vdb_entity"]