Merge pull request #1087 from JoramMillenaar/fix--postgres-impl
Fixed Postgres query parsing issues
This commit is contained in:
@@ -962,14 +962,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
vertices.get(edge["end_id"], {}),
|
||||
)
|
||||
else:
|
||||
if v is None:
|
||||
d[k] = v
|
||||
elif isinstance(v, str) and (v.count("{") < 1 and v.count("[") < 1):
|
||||
d[k] = v
|
||||
elif isinstance(v, str):
|
||||
d[k] = json.loads(v)
|
||||
else:
|
||||
d[k] = v
|
||||
d[k] = json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v
|
||||
|
||||
return d
|
||||
|
||||
@@ -1411,9 +1404,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
embed_func = self._node_embed_algorithms[algorithm]
|
||||
return await embed_func()
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
|
||||
|
||||
@@ -1426,29 +1417,22 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
MAX_GRAPH_NODES = 1000
|
||||
|
||||
# Build the query based on whether we want the full graph or a specific subgraph.
|
||||
if node_label == "*":
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:Entity)
|
||||
OPTIONAL MATCH (n)-[r]->(m:Entity)
|
||||
RETURN n, r, m
|
||||
LIMIT %d
|
||||
$$) AS (n agtype, r agtype, m agtype)""" % (
|
||||
self.graph_name,
|
||||
MAX_GRAPH_NODES,
|
||||
)
|
||||
LIMIT {MAX_GRAPH_NODES}
|
||||
$$) AS (n agtype, r agtype, m agtype)"""
|
||||
else:
|
||||
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
||||
encoded_label = self._encode_graph_label(node_label.strip('"'))
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:Entity {{node_id: "{encoded_label}"}})
|
||||
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
||||
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
||||
LIMIT %d
|
||||
$$) AS (nodes agtype, relationships agtype)""" % (
|
||||
self.graph_name,
|
||||
encoded_node_label,
|
||||
max_depth,
|
||||
MAX_GRAPH_NODES,
|
||||
)
|
||||
LIMIT {MAX_GRAPH_NODES}
|
||||
$$) AS (nodes agtype, relationships agtype)"""
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
@@ -1456,61 +1440,48 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
edges = []
|
||||
unique_edge_ids = set()
|
||||
|
||||
for result in results:
|
||||
def add_node(node_data: dict):
|
||||
node_id = self._decode_graph_label(node_data["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node_data
|
||||
|
||||
def add_edge(edge_data: list):
|
||||
src_id = self._decode_graph_label(edge_data[0]["node_id"])
|
||||
tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
|
||||
edge_key = f"{src_id},{tgt_id}"
|
||||
if edge_key not in unique_edge_ids:
|
||||
unique_edge_ids.add(edge_key)
|
||||
edges.append((edge_key, src_id, tgt_id, {"source": edge_data[0], "target": edge_data[2]}))
|
||||
|
||||
# Process the query results.
|
||||
if node_label == "*":
|
||||
if result["n"]:
|
||||
node = result["n"]
|
||||
node_id = self._decode_graph_label(node["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node
|
||||
|
||||
if result["m"]:
|
||||
node = result["m"]
|
||||
node_id = self._decode_graph_label(node["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node
|
||||
if result["r"]:
|
||||
edge = result["r"]
|
||||
src_id = self._decode_graph_label(edge["start_id"])
|
||||
tgt_id = self._decode_graph_label(edge["end_id"])
|
||||
edges.append((src_id, tgt_id))
|
||||
for result in results:
|
||||
if result.get("n"):
|
||||
add_node(result["n"])
|
||||
if result.get("m"):
|
||||
add_node(result["m"])
|
||||
if result.get("r"):
|
||||
add_edge(result["r"])
|
||||
else:
|
||||
if result["nodes"]:
|
||||
for node in result["nodes"]:
|
||||
node_id = self._decode_graph_label(node["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node
|
||||
|
||||
if result["relationships"]:
|
||||
for edge in result["relationships"]: # src --DIRECTED--> target
|
||||
src_id = self._decode_graph_label(edge[0]["node_id"])
|
||||
tgt_id = self._decode_graph_label(edge[2]["node_id"])
|
||||
id = src_id + "," + tgt_id
|
||||
if id in unique_edge_ids:
|
||||
continue
|
||||
else:
|
||||
unique_edge_ids.add(id)
|
||||
edges.append(
|
||||
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
|
||||
)
|
||||
for result in results:
|
||||
for node in result.get("nodes", []):
|
||||
add_node(node)
|
||||
for edge in result.get("relationships", []):
|
||||
add_edge(edge)
|
||||
|
||||
# Construct and return the KnowledgeGraph.
|
||||
kg = KnowledgeGraph(
|
||||
nodes=[
|
||||
KnowledgeGraphNode(
|
||||
id=node_id, labels=[node_id], properties=nodes[node_id]
|
||||
)
|
||||
for node_id in nodes
|
||||
KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
|
||||
for node_id, node_data in nodes.items()
|
||||
],
|
||||
edges=[
|
||||
KnowledgeGraphEdge(
|
||||
id=id, type="DIRECTED", source=src, target=tgt, properties=props
|
||||
)
|
||||
for id, src, tgt, props in edges
|
||||
KnowledgeGraphEdge(id=edge_id, type="DIRECTED", source=src, target=tgt, properties=props)
|
||||
for edge_id, src, tgt, props in edges
|
||||
],
|
||||
)
|
||||
|
||||
return kg
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
|
||||
|
Reference in New Issue
Block a user