@@ -86,29 +86,40 @@ class PostgreSQLDB:
|
||||
)
|
||||
raise
|
||||
|
||||
async def check_graph_requirement(self, graph_name: str):
|
||||
async with self.pool.acquire() as connection: # type: ignore
|
||||
try:
|
||||
await connection.execute(
|
||||
'SET search_path = ag_catalog, "$user", public'
|
||||
) # type: ignore
|
||||
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
|
||||
except (
|
||||
asyncpg.exceptions.InvalidSchemaNameError,
|
||||
asyncpg.exceptions.UniqueViolationError,
|
||||
):
|
||||
pass
|
||||
@staticmethod
|
||||
async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
|
||||
"""Set the Apache AGE environment and creates a graph if it does not exist.
|
||||
|
||||
This method:
|
||||
- Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
|
||||
- Attempts to create a new graph with the provided `graph_name` if it does not already exist.
|
||||
- Silently ignores errors related to the graph already existing.
|
||||
|
||||
"""
|
||||
try:
|
||||
await connection.execute( # type: ignore
|
||||
'SET search_path = ag_catalog, "$user", public'
|
||||
)
|
||||
await connection.execute( # type: ignore
|
||||
f"select create_graph('{graph_name}')"
|
||||
)
|
||||
except (
|
||||
asyncpg.exceptions.InvalidSchemaNameError,
|
||||
asyncpg.exceptions.UniqueViolationError,
|
||||
):
|
||||
pass
|
||||
|
||||
async def check_tables(self):
|
||||
for k, v in TABLES.items():
|
||||
try:
|
||||
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL database error: {e}")
|
||||
except Exception:
|
||||
try:
|
||||
logger.info(f"PostgreSQL, Try Creating table {k} in database")
|
||||
await self.execute(v["ddl"])
|
||||
logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
|
||||
logger.info(
|
||||
f"PostgreSQL, Creation success table {k} in PostgreSQL database"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
|
||||
@@ -120,8 +131,15 @@ class PostgreSQLDB:
|
||||
sql: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
multirows: bool = False,
|
||||
with_age: bool = False,
|
||||
graph_name: str | None = None,
|
||||
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
||||
async with self.pool.acquire() as connection: # type: ignore
|
||||
if with_age and graph_name:
|
||||
await self.configure_age(connection, graph_name) # type: ignore
|
||||
elif with_age and not graph_name:
|
||||
raise ValueError("Graph name is required when with_age is True")
|
||||
|
||||
try:
|
||||
if params:
|
||||
rows = await connection.fetch(sql, *params.values())
|
||||
@@ -142,9 +160,7 @@ class PostgreSQLDB:
|
||||
data = None
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
|
||||
)
|
||||
logger.error(f"PostgreSQL database, error:{e}")
|
||||
raise
|
||||
|
||||
async def execute(
|
||||
@@ -152,6 +168,8 @@ class PostgreSQLDB:
|
||||
sql: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
upsert: bool = False,
|
||||
with_age: bool = False,
|
||||
graph_name: str | None = None,
|
||||
):
|
||||
try:
|
||||
async with self.pool.acquire() as connection: # type: ignore
|
||||
@@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
# `check_graph_requirement` is required to be executed after `get_client`
|
||||
# to ensure the graph is created before any query is executed.
|
||||
await self.db.check_graph_requirement(self.graph_name)
|
||||
|
||||
async def finalize(self):
|
||||
if self.db is not None:
|
||||
@@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
data = await self.db.query(
|
||||
query,
|
||||
multirows=True,
|
||||
with_age=True,
|
||||
graph_name=self.graph_name,
|
||||
)
|
||||
else:
|
||||
data = await self.db.execute(
|
||||
query,
|
||||
upsert=upsert,
|
||||
with_age=True,
|
||||
graph_name=self.graph_name,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PGGraphQueryException(
|
||||
{
|
||||
|
Reference in New Issue
Block a user