Merge pull request #869 from YanSte/back-age

Fallback Age
This commit is contained in:
Yannick Stephan
2025-02-19 14:27:41 +01:00
committed by GitHub

View File

@@ -86,29 +86,40 @@ class PostgreSQLDB:
)
raise
async def check_graph_requirement(self, graph_name: str):
async with self.pool.acquire() as connection: # type: ignore
try:
await connection.execute(
'SET search_path = ag_catalog, "$user", public'
) # type: ignore
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
@staticmethod
async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
"""Set the Apache AGE environment and creates a graph if it does not exist.
This method:
- Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
- Attempts to create a new graph with the provided `graph_name` if it does not already exist.
- Silently ignores errors related to the graph already existing.
"""
try:
await connection.execute( # type: ignore
'SET search_path = ag_catalog, "$user", public'
)
await connection.execute( # type: ignore
f"select create_graph('{graph_name}')"
)
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
async def check_tables(self):
for k, v in TABLES.items():
try:
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
except Exception as e:
logger.error(f"PostgreSQL database error: {e}")
except Exception:
try:
logger.info(f"PostgreSQL, Try Creating table {k} in database")
await self.execute(v["ddl"])
logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
logger.info(
f"PostgreSQL, Creation success table {k} in PostgreSQL database"
)
except Exception as e:
logger.error(
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
@@ -120,8 +131,15 @@ class PostgreSQLDB:
sql: str,
params: dict[str, Any] | None = None,
multirows: bool = False,
with_age: bool = False,
graph_name: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]:
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name) # type: ignore
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
try:
if params:
rows = await connection.fetch(sql, *params.values())
@@ -142,9 +160,7 @@ class PostgreSQLDB:
data = None
return data
except Exception as e:
logger.error(
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
logger.error(f"PostgreSQL database, error:{e}")
raise
async def execute(
@@ -152,6 +168,8 @@ class PostgreSQLDB:
sql: str,
data: dict[str, Any] | None = None,
upsert: bool = False,
with_age: bool = False,
graph_name: str | None = None,
):
try:
async with self.pool.acquire() as connection: # type: ignore
@@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage):
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
# `check_graph_requirement` is required to be executed after `get_client`
# to ensure the graph is created before any query is executed.
await self.db.check_graph_requirement(self.graph_name)
async def finalize(self):
if self.db is not None:
@@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage):
data = await self.db.query(
query,
multirows=True,
with_age=True,
graph_name=self.graph_name,
)
else:
data = await self.db.execute(
query,
upsert=upsert,
with_age=True,
graph_name=self.graph_name,
)
except Exception as e:
raise PGGraphQueryException(
{