cleaned code
This commit is contained in:
@@ -88,7 +88,16 @@ class PostgreSQLDB:
|
||||
|
||||
async def check_graph_requirement(self, graph_name: str):
|
||||
async with self.pool.acquire() as connection: # type: ignore
|
||||
await self._prerequisite(connection, graph_name) # type: ignore
|
||||
try:
|
||||
await connection.execute(
|
||||
'SET search_path = ag_catalog, "$user", public'
|
||||
) # type: ignore
|
||||
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
|
||||
except (
|
||||
asyncpg.exceptions.InvalidSchemaNameError,
|
||||
asyncpg.exceptions.UniqueViolationError,
|
||||
):
|
||||
pass
|
||||
|
||||
async def check_tables(self):
|
||||
for k, v in TABLES.items():
|
||||
@@ -106,8 +115,6 @@ class PostgreSQLDB:
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info("Finished checking all tables in PostgreSQL database")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
sql: str,
|
||||
@@ -164,17 +171,6 @@ class PostgreSQLDB:
|
||||
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
|
||||
try:
|
||||
await conn.execute('SET search_path = ag_catalog, "$user", public') # type: ignore
|
||||
await conn.execute(f"select create_graph('{graph_name}')") # type: ignore
|
||||
except (
|
||||
asyncpg.exceptions.InvalidSchemaNameError,
|
||||
asyncpg.exceptions.UniqueViolationError,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class ClientManager:
|
||||
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
|
||||
@@ -554,24 +550,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get doc_chunks data by id"""
|
||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
params = {"workspace": self.db.workspace}
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
array_res = await self.db.query(sql, params, multirows=True)
|
||||
modes = set()
|
||||
dict_res: dict[str, dict] = {}
|
||||
for row in array_res:
|
||||
modes.add(row["mode"])
|
||||
for mode in modes:
|
||||
if mode not in dict_res:
|
||||
dict_res[mode] = {}
|
||||
for row in array_res:
|
||||
dict_res[row["mode"]][row["id"]] = row
|
||||
return [{k: v} for k, v in dict_res.items()]
|
||||
else:
|
||||
return await self.db.query(sql, params, multirows=True)
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_status_counts(self) -> dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
@@ -677,6 +656,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
# `check_graph_requirement` is required to be executed after `get_client`
|
||||
# to ensure the graph is created before any query is executed.
|
||||
await self.db.check_graph_requirement(self.graph_name)
|
||||
|
||||
async def finalize(self):
|
||||
|
Reference in New Issue
Block a user