From c01693402173d2ffdb1c6140c9f4aa06a815760b Mon Sep 17 00:00:00 2001 From: Samuel Chan Date: Sun, 12 Jan 2025 21:38:39 +0800 Subject: [PATCH] Revise the AGE implementation on get_node_edges, to align with Neo4j behavior. --- lightrag/kg/postgres_impl.py | 14 +++++++++++--- lightrag/kg/postgres_impl_test.py | 8 ++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index ccbff679..b93a345b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -141,13 +141,16 @@ class PostgreSQLDB: await connection.execute(sql) else: await connection.execute(sql, *data.values()) - except asyncpg.exceptions.UniqueViolationError as e: + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + ) as e: if upsert: print("Key value duplicate, but upsert succeeded.") else: logger.error(f"Upsert error: {e}") except Exception as e: - logger.error(f"PostgreSQL database error: {e}") + logger.error(f"PostgreSQL database error: {e.__class__} - {e}") print(sql) print(data) raise @@ -885,7 +888,12 @@ class PGGraphStorage(BaseGraphStorage): ) if source_label and target_label: - edges.append((source_label, target_label)) + edges.append( + ( + PGGraphStorage._decode_graph_label(source_label), + PGGraphStorage._decode_graph_label(target_label), + ) + ) return edges diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py index dc046311..274f03de 100644 --- a/lightrag/kg/postgres_impl_test.py +++ b/lightrag/kg/postgres_impl_test.py @@ -61,7 +61,7 @@ db = PostgreSQLDB( "port": 15432, "user": "rag", "password": "rag", - "database": "rag", + "database": "r1", } ) @@ -74,8 +74,12 @@ async def query_with_age(): embedding_func=None, ) graph.db = db - res = await graph.get_node('"CHRISTMAS-TIME"') + res = await graph.get_node('"A CHRISTMAS CAROL"') print("Node is: ", res) + res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG") + print("Edge is: ", res) + res = await graph.get_node_edges('"SCROOGE"') + print("Node Edges are: ", res) async def create_edge_with_age():