cleaned code

This commit is contained in:
Yannick Stephan
2025-02-19 13:50:38 +01:00
parent 495b0ddbe0
commit fc151e5866

View File

@@ -88,7 +88,16 @@ class PostgreSQLDB:
async def check_graph_requirement(self, graph_name: str):
async with self.pool.acquire() as connection: # type: ignore
await self._prerequisite(connection, graph_name) # type: ignore
try:
await connection.execute(
'SET search_path = ag_catalog, "$user", public'
) # type: ignore
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
async def check_tables(self):
for k, v in TABLES.items():
@@ -106,8 +115,6 @@ class PostgreSQLDB:
)
raise e
logger.info("Finished checking all tables in PostgreSQL database")
async def query(
self,
sql: str,
@@ -164,17 +171,6 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise
@staticmethod
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
try:
await conn.execute('SET search_path = ag_catalog, "$user", public') # type: ignore
await conn.execute(f"select create_graph('{graph_name}')") # type: ignore
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
class ClientManager:
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
@@ -554,24 +550,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
modes = set()
dict_res: dict[str, dict] = {}
for row in array_res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in array_res:
dict_res[row["mode"]][row["id"]] = row
return [{k: v} for k, v in dict_res.items()]
else:
return await self.db.query(sql, params, multirows=True)
raise NotImplementedError
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
@@ -677,6 +656,8 @@ class PGGraphStorage(BaseGraphStorage):
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
# `check_graph_requirement` is required to be executed after `get_client`
# to ensure the graph is created before any query is executed.
await self.db.check_graph_requirement(self.graph_name)
async def finalize(self):