Fix get_node error for PostgreSQL graph storage
This commit is contained in:
@@ -297,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get an edge by its source and target node ids."""
|
"""Get node by its label identifier, return only node properties"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
"""Get all edges connected to a node."""
|
"""Get edge properties between two nodes"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
|
@@ -267,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier.
|
"""Get node by its label identifier, return only node properties
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: The node label to look up
|
node_id: The node label to look up
|
||||||
|
@@ -1194,6 +1194,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
return single_result["edge_exists"]
|
return single_result["edge_exists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
|
"""Get node by its label identifier, return only node properties"""
|
||||||
|
|
||||||
label = node_id.strip('"')
|
label = node_id.strip('"')
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})
|
MATCH (n:base {entity_id: "%s"})
|
||||||
@@ -1202,7 +1204,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
record = await self._query(query)
|
record = await self._query(query)
|
||||||
if record:
|
if record:
|
||||||
node = record[0]
|
node = record[0]
|
||||||
node_dict = node["n"]
|
node_dict = node["n"]["properties"]
|
||||||
|
|
||||||
return node_dict
|
return node_dict
|
||||||
return None
|
return None
|
||||||
@@ -1235,6 +1237,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
|
"""Get edge properties between two nodes"""
|
||||||
|
|
||||||
src_label = source_node_id.strip('"')
|
src_label = source_node_id.strip('"')
|
||||||
tgt_label = target_node_id.strip('"')
|
tgt_label = target_node_id.strip('"')
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user