From 48377d91ef2284ed1329f76fad2d2785202329ce Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Wed, 19 Feb 2025 14:26:46 +0100 Subject: [PATCH] back age --- lightrag/kg/postgres_impl.py | 62 ++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 044bf4c1..8f1128a5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -86,29 +86,40 @@ class PostgreSQLDB: ) raise - async def check_graph_requirement(self, graph_name: str): - async with self.pool.acquire() as connection: # type: ignore - try: - await connection.execute( - 'SET search_path = ag_catalog, "$user", public' - ) # type: ignore - await connection.execute(f"select create_graph('{graph_name}')") # type: ignore - except ( - asyncpg.exceptions.InvalidSchemaNameError, - asyncpg.exceptions.UniqueViolationError, - ): - pass + @staticmethod + async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None: + """Set the Apache AGE environment and creates a graph if it does not exist. + + This method: + - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema. + - Attempts to create a new graph with the provided `graph_name` if it does not already exist. + - Silently ignores errors related to the graph already existing. + + """ + try: + await connection.execute( # type: ignore + 'SET search_path = ag_catalog, "$user", public' + ) + await connection.execute( # type: ignore + f"select create_graph('{graph_name}')" + ) + except ( + asyncpg.exceptions.InvalidSchemaNameError, + asyncpg.exceptions.UniqueViolationError, + ): + pass async def check_tables(self): for k, v in TABLES.items(): try: await self.query(f"SELECT 1 FROM {k} LIMIT 1") - except Exception as e: - logger.error(f"PostgreSQL database error: {e}") + except Exception: try: logger.info(f"PostgreSQL, Try Creating table {k} in database") await self.execute(v["ddl"]) - logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database") + logger.info( + f"PostgreSQL, Creation success table {k} in PostgreSQL database" + ) except Exception as e: logger.error( f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}" @@ -120,8 +131,15 @@ class PostgreSQLDB: sql: str, params: dict[str, Any] | None = None, multirows: bool = False, + with_age: bool = False, + graph_name: str | None = None, ) -> dict[str, Any] | None | list[dict[str, Any]]: async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) # type: ignore + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") + try: if params: rows = await connection.fetch(sql, *params.values()) @@ -142,9 +160,7 @@ class PostgreSQLDB: data = None return data except Exception as e: - logger.error( - f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}" - ) + logger.error(f"PostgreSQL database, error:{e}") raise async def execute( @@ -152,6 +168,8 @@ class PostgreSQLDB: sql: str, data: dict[str, Any] | None = None, upsert: bool = False, + with_age: bool = False, + graph_name: str | None = None, ): try: async with self.pool.acquire() as connection: # type: ignore @@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage): async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() - # `check_graph_requirement` is required to be executed after `get_client` - # to ensure the graph is created before any query is executed. - await self.db.check_graph_requirement(self.graph_name) async def finalize(self): if self.db is not None: @@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage): data = await self.db.query( query, multirows=True, + with_age=True, + graph_name=self.graph_name, ) else: data = await self.db.execute( query, upsert=upsert, + with_age=True, + graph_name=self.graph_name, ) + except Exception as e: raise PGGraphQueryException( {