Refactoring PostgreSQL AGE graph db implementation
This commit is contained in:
@@ -1064,31 +1064,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if v.startswith("[") and v.endswith("]"):
|
if v.startswith("[") and v.endswith("]"):
|
||||||
if "::vertex" in v:
|
if "::vertex" in v:
|
||||||
v = v.replace("::vertex", "")
|
v = v.replace("::vertex", "")
|
||||||
vertexes = json.loads(v)
|
d[k] = json.loads(v)
|
||||||
dl = []
|
|
||||||
for vertex in vertexes:
|
|
||||||
prop = vertex.get("properties")
|
|
||||||
if not prop:
|
|
||||||
prop = {}
|
|
||||||
prop["label"] = PGGraphStorage._decode_graph_label(
|
|
||||||
prop["node_id"]
|
|
||||||
)
|
|
||||||
dl.append(prop)
|
|
||||||
d[k] = dl
|
|
||||||
|
|
||||||
elif "::edge" in v:
|
elif "::edge" in v:
|
||||||
v = v.replace("::edge", "")
|
v = v.replace("::edge", "")
|
||||||
edges = json.loads(v)
|
d[k] = json.loads(v)
|
||||||
dl = []
|
|
||||||
for edge in edges:
|
|
||||||
dl.append(
|
|
||||||
(
|
|
||||||
vertices[edge["start_id"]],
|
|
||||||
edge["label"],
|
|
||||||
vertices[edge["end_id"]],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
d[k] = dl
|
|
||||||
else:
|
else:
|
||||||
print("WARNING: unsupported type")
|
print("WARNING: unsupported type")
|
||||||
continue
|
continue
|
||||||
@@ -1097,26 +1077,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
dtype = v.split("::")[-1]
|
dtype = v.split("::")[-1]
|
||||||
v = v.split("::")[0]
|
v = v.split("::")[0]
|
||||||
if dtype == "vertex":
|
if dtype == "vertex":
|
||||||
vertex = json.loads(v)
|
d[k] = json.loads(v)
|
||||||
field = vertex.get("properties")
|
|
||||||
if not field:
|
|
||||||
field = {}
|
|
||||||
field["label"] = PGGraphStorage._decode_graph_label(
|
|
||||||
field["node_id"]
|
|
||||||
)
|
|
||||||
d[k] = field
|
|
||||||
# convert edge from id-label->id by replacing id with node information
|
|
||||||
# we only do this if the vertex was also returned in the query
|
|
||||||
# this is an attempt to be consistent with neo4j implementation
|
|
||||||
elif dtype == "edge":
|
elif dtype == "edge":
|
||||||
edge = json.loads(v)
|
d[k] = json.loads(v)
|
||||||
d[k] = (
|
|
||||||
vertices.get(edge["start_id"], {}),
|
|
||||||
edge[
|
|
||||||
"label"
|
|
||||||
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
|
||||||
vertices.get(edge["end_id"], {}),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
d[k] = (
|
d[k] = (
|
||||||
json.loads(v)
|
json.loads(v)
|
||||||
@@ -1152,56 +1115,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return "{" + ", ".join(props) + "}"
|
return "{" + ", ".join(props) + "}"
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _encode_graph_label(label: str) -> str:
|
|
||||||
"""
|
|
||||||
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
|
||||||
|
|
||||||
Args:
|
|
||||||
label (str): the original label
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: the encoded label
|
|
||||||
"""
|
|
||||||
return "x" + label.encode().hex()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _decode_graph_label(encoded_label: str) -> str:
|
|
||||||
"""
|
|
||||||
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
|
||||||
|
|
||||||
Args:
|
|
||||||
encoded_label (str): the encoded label
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: the decoded label
|
|
||||||
"""
|
|
||||||
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_col_name(field: str, idx: int) -> str:
|
|
||||||
"""
|
|
||||||
Convert a cypher return field to a pgsql select field
|
|
||||||
If possible keep the cypher column name, but create a generic name if necessary
|
|
||||||
|
|
||||||
Args:
|
|
||||||
field (str): a return field from a cypher query to be formatted for pgsql
|
|
||||||
idx (int): the position of the field in the return statement
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: the field to be used in the pgsql select statement
|
|
||||||
"""
|
|
||||||
# remove white space
|
|
||||||
field = field.strip()
|
|
||||||
# if an alias is provided for the field, use it
|
|
||||||
if " as " in field:
|
|
||||||
return field.split(" as ")[-1].strip()
|
|
||||||
# if the return value is an unnamed primitive, give it a generic name
|
|
||||||
if field.isnumeric() or field in ("true", "false", "null"):
|
|
||||||
return f"column_{idx}"
|
|
||||||
# otherwise return the value stripping out some common special chars
|
|
||||||
return field.replace("(", "_").replace(")", "")
|
|
||||||
|
|
||||||
async def _query(
|
async def _query(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -1252,10 +1165,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
entity_name_label = self._encode_graph_label(node_id.strip('"'))
|
entity_name_label = node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {node_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
RETURN count(n) > 0 AS node_exists
|
RETURN count(n) > 0 AS node_exists
|
||||||
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
||||||
|
|
||||||
@@ -1264,11 +1177,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return single_result["node_exists"]
|
return single_result["node_exists"]
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
src_label = self._encode_graph_label(source_node_id.strip('"'))
|
src_label = source_node_id.strip('"')
|
||||||
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
tgt_label = target_node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"})
|
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
||||||
RETURN COUNT(r) > 0 AS edge_exists
|
RETURN COUNT(r) > 0 AS edge_exists
|
||||||
$$) AS (edge_exists bool)""" % (
|
$$) AS (edge_exists bool)""" % (
|
||||||
self.graph_name,
|
self.graph_name,
|
||||||
@@ -1281,13 +1194,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return single_result["edge_exists"]
|
return single_result["edge_exists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
label = self._encode_graph_label(node_id.strip('"'))
|
label = node_id.strip('"')
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {node_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
RETURN n
|
RETURN n
|
||||||
$$) AS (n agtype)""" % (self.graph_name, label)
|
$$) AS (n agtype)""" % (self.graph_name, label)
|
||||||
record = await self._query(query)
|
record = await self._query(query)
|
||||||
if record:
|
if record:
|
||||||
|
print(f"Record: {record}")
|
||||||
node = record[0]
|
node = record[0]
|
||||||
node_dict = node["n"]
|
node_dict = node["n"]
|
||||||
|
|
||||||
@@ -1295,10 +1209,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
label = self._encode_graph_label(node_id.strip('"'))
|
label = node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {node_id: "%s"})-[]->(x)
|
MATCH (n:base {entity_id: "%s"})-[]->(x)
|
||||||
RETURN count(x) AS total_edge_count
|
RETURN count(x) AS total_edge_count
|
||||||
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
||||||
record = (await self._query(query))[0]
|
record = (await self._query(query))[0]
|
||||||
@@ -1322,11 +1236,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
src_label = self._encode_graph_label(source_node_id.strip('"'))
|
src_label = source_node_id.strip('"')
|
||||||
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
tgt_label = target_node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
|
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
||||||
RETURN properties(r) as edge_properties
|
RETURN properties(r) as edge_properties
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
$$) AS (edge_properties agtype)""" % (
|
$$) AS (edge_properties agtype)""" % (
|
||||||
@@ -1336,6 +1250,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
record = await self._query(query)
|
record = await self._query(query)
|
||||||
if record and record[0] and record[0]["edge_properties"]:
|
if record and record[0] and record[0]["edge_properties"]:
|
||||||
|
print(f"Record: {record}")
|
||||||
result = record[0]["edge_properties"]
|
result = record[0]["edge_properties"]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -1345,10 +1260,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
:return: list of dictionaries containing edge information
|
:return: list of dictionaries containing edge information
|
||||||
"""
|
"""
|
||||||
label = self._encode_graph_label(source_node_id.strip('"'))
|
label = source_node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {node_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
OPTIONAL MATCH (n)-[]-(connected:base)
|
OPTIONAL MATCH (n)-[]-(connected:base)
|
||||||
RETURN n, connected
|
RETURN n, connected
|
||||||
$$) AS (n agtype, connected agtype)""" % (
|
$$) AS (n agtype, connected agtype)""" % (
|
||||||
@@ -1362,24 +1277,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
source_node = record["n"] if record["n"] else None
|
source_node = record["n"] if record["n"] else None
|
||||||
connected_node = record["connected"] if record["connected"] else None
|
connected_node = record["connected"] if record["connected"] else None
|
||||||
|
|
||||||
source_label = (
|
if (
|
||||||
source_node["node_id"]
|
source_node
|
||||||
if source_node and source_node["node_id"]
|
and connected_node
|
||||||
else None
|
and "properties" in source_node
|
||||||
)
|
and "properties" in connected_node
|
||||||
target_label = (
|
):
|
||||||
connected_node["node_id"]
|
source_label = source_node["properties"].get("entity_id")
|
||||||
if connected_node and connected_node["node_id"]
|
target_label = connected_node["properties"].get("entity_id")
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if source_label and target_label:
|
if source_label and target_label:
|
||||||
edges.append(
|
edges.append((source_label, target_label))
|
||||||
(
|
|
||||||
self._decode_graph_label(source_label),
|
|
||||||
self._decode_graph_label(target_label),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@@ -1389,17 +1297,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
retry=retry_if_exception_type((PGGraphQueryException,)),
|
retry=retry_if_exception_type((PGGraphQueryException,)),
|
||||||
)
|
)
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
label = self._encode_graph_label(node_id.strip('"'))
|
label = node_id.strip('"')
|
||||||
properties = node_data
|
properties = self._format_properties(node_data)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MERGE (n:base {node_id: "%s"})
|
MERGE (n:base {entity_id: "%s"})
|
||||||
SET n += %s
|
SET n += %s
|
||||||
RETURN n
|
RETURN n
|
||||||
$$) AS (n agtype)""" % (
|
$$) AS (n agtype)""" % (
|
||||||
self.graph_name,
|
self.graph_name,
|
||||||
label,
|
label,
|
||||||
self._format_properties(properties),
|
properties,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1425,14 +1333,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
target_node_id (str): Label of the target node (used as identifier)
|
target_node_id (str): Label of the target node (used as identifier)
|
||||||
edge_data (dict): dictionary of properties to set on the edge
|
edge_data (dict): dictionary of properties to set on the edge
|
||||||
"""
|
"""
|
||||||
src_label = self._encode_graph_label(source_node_id.strip('"'))
|
src_label = source_node_id.strip('"')
|
||||||
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
tgt_label = target_node_id.strip('"')
|
||||||
edge_properties = edge_data
|
edge_properties = self._format_properties(edge_data)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (source:base {node_id: "%s"})
|
MATCH (source:base {entity_id: "%s"})
|
||||||
WITH source
|
WITH source
|
||||||
MATCH (target:base {node_id: "%s"})
|
MATCH (target:base {entity_id: "%s"})
|
||||||
MERGE (source)-[r:DIRECTED]->(target)
|
MERGE (source)-[r:DIRECTED]->(target)
|
||||||
SET r += %s
|
SET r += %s
|
||||||
RETURN r
|
RETURN r
|
||||||
@@ -1440,7 +1348,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
self.graph_name,
|
self.graph_name,
|
||||||
src_label,
|
src_label,
|
||||||
tgt_label,
|
tgt_label,
|
||||||
self._format_properties(edge_properties),
|
edge_properties,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1460,7 +1368,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Args:
|
Args:
|
||||||
node_id (str): The ID of the node to delete.
|
node_id (str): The ID of the node to delete.
|
||||||
"""
|
"""
|
||||||
label = self._encode_graph_label(node_id.strip('"'))
|
label = node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
@@ -1480,14 +1388,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Args:
|
Args:
|
||||||
node_ids (list[str]): A list of node IDs to remove.
|
node_ids (list[str]): A list of node IDs to remove.
|
||||||
"""
|
"""
|
||||||
encoded_node_ids = [
|
node_ids = [node_id.strip('"') for node_id in node_ids]
|
||||||
self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
|
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
||||||
]
|
|
||||||
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base)
|
MATCH (n:base)
|
||||||
WHERE n.nentity_id IN [%s]
|
WHERE n.entity_id IN [%s]
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
||||||
|
|
||||||
@@ -1505,11 +1411,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
||||||
"""
|
"""
|
||||||
for source, target in edges:
|
for source, target in edges:
|
||||||
src_label = self._encode_graph_label(source.strip('"'))
|
src_label = source.strip('"')
|
||||||
tgt_label = self._encode_graph_label(target.strip('"'))
|
tgt_label = target.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
|
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
||||||
DELETE r
|
DELETE r
|
||||||
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
||||||
|
|
||||||
@@ -1560,95 +1466,98 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return await embed_func()
|
return await embed_func()
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 5
|
self,
|
||||||
|
node_label: str,
|
||||||
|
max_depth: int = 3,
|
||||||
|
max_nodes: int = MAX_GRAPH_NODES,
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""
|
"""
|
||||||
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
|
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label (str): The label of the node to start from. If "*", the entire graph is returned.
|
node_label: Label of the starting node, * means all nodes
|
||||||
max_depth (int): The maximum depth to traverse from the starting node.
|
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||||
|
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraph: The retrieved subgraph.
|
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||||
|
indicating whether the graph was truncated due to max_nodes limit
|
||||||
"""
|
"""
|
||||||
MAX_GRAPH_NODES = 1000
|
|
||||||
|
|
||||||
# Build the query based on whether we want the full graph or a specific subgraph.
|
# Build the query based on whether we want the full graph or a specific subgraph.
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
MATCH (n:base)
|
MATCH (n:base)
|
||||||
OPTIONAL MATCH (n)-[r]->(m:base)
|
OPTIONAL MATCH (n)-[r]->(target:base)
|
||||||
RETURN n, r, m
|
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
||||||
LIMIT {MAX_GRAPH_NODES}
|
LIMIT {MAX_GRAPH_NODES}
|
||||||
$$) AS (n agtype, r agtype, m agtype)"""
|
$$) AS (n agtype, r agtype)"""
|
||||||
else:
|
else:
|
||||||
encoded_label = self._encode_graph_label(node_label.strip('"'))
|
strip_label = node_label.strip('"')
|
||||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
MATCH (n:base {{entity_id: "{encoded_label}"}})
|
MATCH (n:base {{entity_id: "{strip_label}"}})
|
||||||
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
||||||
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
RETURN nodes(p) AS n, relationships(p) AS r
|
||||||
LIMIT {MAX_GRAPH_NODES}
|
LIMIT {max_nodes}
|
||||||
$$) AS (nodes agtype, relationships agtype)"""
|
$$) AS (n agtype, r agtype)"""
|
||||||
|
|
||||||
results = await self._query(query)
|
results = await self._query(query)
|
||||||
|
|
||||||
nodes = {}
|
# Process the query results with deduplication by node and edge IDs
|
||||||
edges = []
|
nodes_dict = {}
|
||||||
unique_edge_ids = set()
|
edges_dict = {}
|
||||||
|
for result in results:
|
||||||
def add_node(node_data: dict):
|
# Handle single node cases
|
||||||
node_id = self._decode_graph_label(node_data["node_id"])
|
if result.get("n") and isinstance(result["n"], dict):
|
||||||
if node_id not in nodes:
|
node_id = str(result["n"]["id"])
|
||||||
nodes[node_id] = node_data
|
if node_id not in nodes_dict:
|
||||||
|
nodes_dict[node_id] = KnowledgeGraphNode(
|
||||||
def add_edge(edge_data: list):
|
id=node_id,
|
||||||
src_id = self._decode_graph_label(edge_data[0]["node_id"])
|
labels=[result["n"]["properties"]["entity_id"]],
|
||||||
tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
|
properties=result["n"]["properties"],
|
||||||
edge_key = f"{src_id},{tgt_id}"
|
|
||||||
if edge_key not in unique_edge_ids:
|
|
||||||
unique_edge_ids.add(edge_key)
|
|
||||||
edges.append(
|
|
||||||
(
|
|
||||||
edge_key,
|
|
||||||
src_id,
|
|
||||||
tgt_id,
|
|
||||||
{"source": edge_data[0], "target": edge_data[2]},
|
|
||||||
)
|
)
|
||||||
|
# Handle node list cases
|
||||||
|
elif result.get("n") and isinstance(result["n"], list):
|
||||||
|
for node in result["n"]:
|
||||||
|
if isinstance(node, dict) and "id" in node:
|
||||||
|
node_id = str(node["id"])
|
||||||
|
if node_id not in nodes_dict and "properties" in node:
|
||||||
|
nodes_dict[node_id] = KnowledgeGraphNode(
|
||||||
|
id=node_id,
|
||||||
|
labels=[node["properties"]["entity_id"]],
|
||||||
|
properties=node["properties"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process the query results.
|
# Handle single edge cases
|
||||||
if node_label == "*":
|
if result.get("r") and isinstance(result["r"], dict):
|
||||||
for result in results:
|
edge_id = str(result["r"]["id"])
|
||||||
if result.get("n"):
|
if edge_id not in edges_dict:
|
||||||
add_node(result["n"])
|
edges_dict[edge_id] = KnowledgeGraphEdge(
|
||||||
if result.get("m"):
|
|
||||||
add_node(result["m"])
|
|
||||||
if result.get("r"):
|
|
||||||
add_edge(result["r"])
|
|
||||||
else:
|
|
||||||
for result in results:
|
|
||||||
for node in result.get("nodes", []):
|
|
||||||
add_node(node)
|
|
||||||
for edge in result.get("relationships", []):
|
|
||||||
add_edge(edge)
|
|
||||||
|
|
||||||
# Construct and return the KnowledgeGraph.
|
|
||||||
kg = KnowledgeGraph(
|
|
||||||
nodes=[
|
|
||||||
KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
|
|
||||||
for node_id, node_data in nodes.items()
|
|
||||||
],
|
|
||||||
edges=[
|
|
||||||
KnowledgeGraphEdge(
|
|
||||||
id=edge_id,
|
id=edge_id,
|
||||||
type="DIRECTED",
|
type="DIRECTED",
|
||||||
source=src,
|
source=str(result["r"]["start_id"]),
|
||||||
target=tgt,
|
target=str(result["r"]["end_id"]),
|
||||||
properties=props,
|
properties=result["r"]["properties"],
|
||||||
)
|
)
|
||||||
for edge_id, src, tgt, props in edges
|
# Handle edge list cases
|
||||||
],
|
elif result.get("r") and isinstance(result["r"], list):
|
||||||
|
for edge in result["r"]:
|
||||||
|
if isinstance(edge, dict) and "id" in edge:
|
||||||
|
edge_id = str(edge["id"])
|
||||||
|
if edge_id not in edges_dict:
|
||||||
|
edges_dict[edge_id] = KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type="DIRECTED",
|
||||||
|
source=str(edge["start_id"]),
|
||||||
|
target=str(edge["end_id"]),
|
||||||
|
properties=edge["properties"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Construct and return the KnowledgeGraph with deduplicated nodes and edges
|
||||||
|
kg = KnowledgeGraph(
|
||||||
|
nodes=list(nodes_dict.values()),
|
||||||
|
edges=list(edges_dict.values()),
|
||||||
|
is_truncated=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return kg
|
return kg
|
||||||
|
Reference in New Issue
Block a user