Fix get_node error for PostgreSQL graph storage

This commit is contained in:
yangdx
2025-04-03 15:40:31 +08:00
parent b48d5e62e3
commit 3e3338a144
3 changed files with 8 additions and 4 deletions

View File

@@ -297,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get an edge by its source and target node ids.""" """Get node by its label identifier, return only node properties"""
@abstractmethod @abstractmethod
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get all edges connected to a node.""" """Get edge properties between two nodes"""
@abstractmethod @abstractmethod
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:

View File

@@ -267,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
raise raise
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier. """Get node by its label identifier, return only node properties
Args: Args:
node_id: The node label to look up node_id: The node label to look up

View File

@@ -1194,6 +1194,8 @@ class PGGraphStorage(BaseGraphStorage):
return single_result["edge_exists"] return single_result["edge_exists"]
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
label = node_id.strip('"') label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"}) MATCH (n:base {entity_id: "%s"})
@@ -1202,7 +1204,7 @@ class PGGraphStorage(BaseGraphStorage):
record = await self._query(query) record = await self._query(query)
if record: if record:
node = record[0] node = record[0]
node_dict = node["n"] node_dict = node["n"]["properties"]
return node_dict return node_dict
return None return None
@@ -1235,6 +1237,8 @@ class PGGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
src_label = source_node_id.strip('"') src_label = source_node_id.strip('"')
tgt_label = target_node_id.strip('"') tgt_label = target_node_id.strip('"')