From 3e3338a1440253de0a54e43133e406db30bb9e20 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 3 Apr 2025 15:40:31 +0800 Subject: [PATCH] Fix get_node error for PostgreSQL graph storage --- lightrag/base.py | 4 ++-- lightrag/kg/neo4j_impl.py | 2 +- lightrag/kg/postgres_impl.py | 6 +++++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 0d387cc3..3640fa1f 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -297,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC): @abstractmethod async def get_node(self, node_id: str) -> dict[str, str] | None: - """Get an edge by its source and target node ids.""" + """Get node by its label identifier, return only node properties""" @abstractmethod async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - """Get all edges connected to a node.""" + """Get edge properties between two nodes""" @abstractmethod async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 33d986a9..72978b60 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -267,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage): raise async def get_node(self, node_id: str) -> dict[str, str] | None: - """Get node by its label identifier. + """Get node by its label identifier, return only node properties Args: node_id: The node label to look up diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 4b116d96..0038f99f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1194,6 +1194,8 @@ class PGGraphStorage(BaseGraphStorage): return single_result["edge_exists"] async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties""" + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"}) @@ -1202,7 +1204,7 @@ class PGGraphStorage(BaseGraphStorage): record = await self._query(query) if record: node = record[0] - node_dict = node["n"] + node_dict = node["n"]["properties"] return node_dict return None @@ -1235,6 +1237,8 @@ class PGGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get edge properties between two nodes""" + src_label = source_node_id.strip('"') tgt_label = target_node_id.strip('"')