Fixed some query parsing issues

This commit is contained in:
jofoks
2025-03-13 11:30:52 -07:00
parent 2ffd7f9111
commit edc95126de

View File

@@ -950,10 +950,7 @@ class PGGraphStorage(BaseGraphStorage):
vertices.get(edge["end_id"], {}),
)
else:
if v is None or (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
else:
d[k] = json.loads(v) if isinstance(v, str) else v
d[k] = json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v
return d
@@ -1395,9 +1392,7 @@ class PGGraphStorage(BaseGraphStorage):
embed_func = self._node_embed_algorithms[algorithm]
return await embed_func()
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
@@ -1410,29 +1405,22 @@ class PGGraphStorage(BaseGraphStorage):
"""
MAX_GRAPH_NODES = 1000
# Build the query based on whether we want the full graph or a specific subgraph.
if node_label == "*":
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (
self.graph_name,
MAX_GRAPH_NODES,
)
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT {MAX_GRAPH_NODES}
$$) AS (n agtype, r agtype, m agtype)"""
else:
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype, relationships agtype)""" % (
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
encoded_label = self._encode_graph_label(node_label.strip('"'))
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity {{node_id: "{encoded_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT {MAX_GRAPH_NODES}
$$) AS (nodes agtype, relationships agtype)"""
results = await self._query(query)
@@ -1440,61 +1428,48 @@ class PGGraphStorage(BaseGraphStorage):
edges = []
unique_edge_ids = set()
for result in results:
if node_label == "*":
if result["n"]:
node = result["n"]
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
def add_node(node_data: dict):
node_id = self._decode_graph_label(node_data["node_id"])
if node_id not in nodes:
nodes[node_id] = node_data
if result["m"]:
node = result["m"]
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["r"]:
edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
else:
if result["nodes"]:
for node in result["nodes"]:
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
def add_edge(edge_data: list):
src_id = self._decode_graph_label(edge_data[0]["node_id"])
tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
edge_key = f"{src_id},{tgt_id}"
if edge_key not in unique_edge_ids:
unique_edge_ids.add(edge_key)
edges.append((edge_key, src_id, tgt_id, {"source": edge_data[0], "target": edge_data[2]}))
if result["relationships"]:
for edge in result["relationships"]: # src --DIRECTED--> target
src_id = self._decode_graph_label(edge[0]["node_id"])
tgt_id = self._decode_graph_label(edge[2]["node_id"])
id = src_id + "," + tgt_id
if id in unique_edge_ids:
continue
else:
unique_edge_ids.add(id)
edges.append(
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
)
# Process the query results.
if node_label == "*":
for result in results:
if result.get("n"):
add_node(result["n"])
if result.get("m"):
add_node(result["m"])
if result.get("r"):
add_edge(result["r"])
else:
for result in results:
for node in result.get("nodes", []):
add_node(node)
for edge in result.get("relationships", []):
add_edge(edge)
# Construct and return the KnowledgeGraph.
kg = KnowledgeGraph(
nodes=[
KnowledgeGraphNode(
id=node_id, labels=[node_id], properties=nodes[node_id]
)
for node_id in nodes
KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
for node_id, node_data in nodes.items()
],
edges=[
KnowledgeGraphEdge(
id=id, type="DIRECTED", source=src, target=tgt, properties=props
)
for id, src, tgt, props in edges
KnowledgeGraphEdge(id=edge_id, type="DIRECTED", source=src, target=tgt, properties=props)
for edge_id, src, tgt, props in edges
],
)
return kg
async def drop(self) -> None:
"""Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]