cleaned code
This commit is contained in:
@@ -88,7 +88,16 @@ class PostgreSQLDB:
|
|||||||
|
|
||||||
async def check_graph_requirement(self, graph_name: str):
|
async def check_graph_requirement(self, graph_name: str):
|
||||||
async with self.pool.acquire() as connection: # type: ignore
|
async with self.pool.acquire() as connection: # type: ignore
|
||||||
await self._prerequisite(connection, graph_name) # type: ignore
|
try:
|
||||||
|
await connection.execute(
|
||||||
|
'SET search_path = ag_catalog, "$user", public'
|
||||||
|
) # type: ignore
|
||||||
|
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
|
||||||
|
except (
|
||||||
|
asyncpg.exceptions.InvalidSchemaNameError,
|
||||||
|
asyncpg.exceptions.UniqueViolationError,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
async def check_tables(self):
|
async def check_tables(self):
|
||||||
for k, v in TABLES.items():
|
for k, v in TABLES.items():
|
||||||
@@ -106,8 +115,6 @@ class PostgreSQLDB:
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
logger.info("Finished checking all tables in PostgreSQL database")
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
sql: str,
|
sql: str,
|
||||||
@@ -164,17 +171,6 @@ class PostgreSQLDB:
|
|||||||
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
|
|
||||||
try:
|
|
||||||
await conn.execute('SET search_path = ag_catalog, "$user", public') # type: ignore
|
|
||||||
await conn.execute(f"select create_graph('{graph_name}')") # type: ignore
|
|
||||||
except (
|
|
||||||
asyncpg.exceptions.InvalidSchemaNameError,
|
|
||||||
asyncpg.exceptions.UniqueViolationError,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ClientManager:
|
class ClientManager:
|
||||||
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
|
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
|
||||||
@@ -554,24 +550,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
"""Get doc_chunks data by id"""
|
"""Get doc_chunks data by id"""
|
||||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
raise NotImplementedError
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
|
||||||
)
|
|
||||||
params = {"workspace": self.db.workspace}
|
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
|
||||||
array_res = await self.db.query(sql, params, multirows=True)
|
|
||||||
modes = set()
|
|
||||||
dict_res: dict[str, dict] = {}
|
|
||||||
for row in array_res:
|
|
||||||
modes.add(row["mode"])
|
|
||||||
for mode in modes:
|
|
||||||
if mode not in dict_res:
|
|
||||||
dict_res[mode] = {}
|
|
||||||
for row in array_res:
|
|
||||||
dict_res[row["mode"]][row["id"]] = row
|
|
||||||
return [{k: v} for k, v in dict_res.items()]
|
|
||||||
else:
|
|
||||||
return await self.db.query(sql, params, multirows=True)
|
|
||||||
|
|
||||||
async def get_status_counts(self) -> dict[str, int]:
|
async def get_status_counts(self) -> dict[str, int]:
|
||||||
"""Get counts of documents in each status"""
|
"""Get counts of documents in each status"""
|
||||||
@@ -677,6 +656,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
# `check_graph_requirement` is required to be executed after `get_client`
|
||||||
|
# to ensure the graph is created before any query is executed.
|
||||||
await self.db.check_graph_requirement(self.graph_name)
|
await self.db.check_graph_requirement(self.graph_name)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
|
Reference in New Issue
Block a user