diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 3c88a05b..044bf4c1 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -88,7 +88,16 @@ class PostgreSQLDB:
 
     async def check_graph_requirement(self, graph_name: str):
         async with self.pool.acquire() as connection:  # type: ignore
-            await self._prerequisite(connection, graph_name)  # type: ignore
+            try:
+                await connection.execute(
+                    'SET search_path = ag_catalog, "$user", public'
+                )  # type: ignore
+                await connection.execute(f"select create_graph('{graph_name}')")  # type: ignore
+            except (
+                asyncpg.exceptions.InvalidSchemaNameError,
+                asyncpg.exceptions.UniqueViolationError,
+            ):
+                pass
 
     async def check_tables(self):
         for k, v in TABLES.items():
@@ -106,8 +115,6 @@ class PostgreSQLDB:
                     )
                     raise e
 
-        logger.info("Finished checking all tables in PostgreSQL database")
-
     async def query(
         self,
         sql: str,
@@ -164,17 +171,6 @@ class PostgreSQLDB:
             logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
             raise
 
-    @staticmethod
-    async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
-        try:
-            await conn.execute('SET search_path = ag_catalog, "$user", public')  # type: ignore
-            await conn.execute(f"select create_graph('{graph_name}')")  # type: ignore
-        except (
-            asyncpg.exceptions.InvalidSchemaNameError,
-            asyncpg.exceptions.UniqueViolationError,
-        ):
-            pass
-
 
 class ClientManager:
     _instances: dict[str, Any] = {"db": None, "ref_count": 0}
@@ -554,24 +550,7 @@ class PGDocStatusStorage(DocStatusStorage):
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         """Get doc_chunks data by id"""
-        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
-            ids=",".join([f"'{id}'" for id in ids])
-        )
-        params = {"workspace": self.db.workspace}
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            array_res = await self.db.query(sql, params, multirows=True)
-            modes = set()
-            dict_res: dict[str, dict] = {}
-            for row in array_res:
-                modes.add(row["mode"])
-            for mode in modes:
-                if mode not in dict_res:
-                    dict_res[mode] = {}
-            for row in array_res:
-                dict_res[row["mode"]][row["id"]] = row
-            return [{k: v} for k, v in dict_res.items()]
-        else:
-            return await self.db.query(sql, params, multirows=True)
+        raise NotImplementedError
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
@@ -677,6 +656,8 @@ class PGGraphStorage(BaseGraphStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
+            # `check_graph_requirement` is required to be executed after `get_client`
+            # to ensure the graph is created before any query is executed.
             await self.db.check_graph_requirement(self.graph_name)
 
     async def finalize(self):