This commit is contained in:
Yannick Stephan
2025-02-19 14:26:46 +01:00
parent fc151e5866
commit 48377d91ef

View File

@@ -86,29 +86,40 @@ class PostgreSQLDB:
) )
raise raise
async def check_graph_requirement(self, graph_name: str): @staticmethod
async with self.pool.acquire() as connection: # type: ignore async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
try: """Set the Apache AGE environment and creates a graph if it does not exist.
await connection.execute(
'SET search_path = ag_catalog, "$user", public' This method:
) # type: ignore - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore - Attempts to create a new graph with the provided `graph_name` if it does not already exist.
except ( - Silently ignores errors related to the graph already existing.
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError, """
): try:
pass await connection.execute( # type: ignore
'SET search_path = ag_catalog, "$user", public'
)
await connection.execute( # type: ignore
f"select create_graph('{graph_name}')"
)
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
async def check_tables(self): async def check_tables(self):
for k, v in TABLES.items(): for k, v in TABLES.items():
try: try:
await self.query(f"SELECT 1 FROM {k} LIMIT 1") await self.query(f"SELECT 1 FROM {k} LIMIT 1")
except Exception as e: except Exception:
logger.error(f"PostgreSQL database error: {e}")
try: try:
logger.info(f"PostgreSQL, Try Creating table {k} in database") logger.info(f"PostgreSQL, Try Creating table {k} in database")
await self.execute(v["ddl"]) await self.execute(v["ddl"])
logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database") logger.info(
f"PostgreSQL, Creation success table {k} in PostgreSQL database"
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}" f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
@@ -120,8 +131,15 @@ class PostgreSQLDB:
sql: str, sql: str,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
multirows: bool = False, multirows: bool = False,
with_age: bool = False,
graph_name: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]: ) -> dict[str, Any] | None | list[dict[str, Any]]:
async with self.pool.acquire() as connection: # type: ignore async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name) # type: ignore
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
try: try:
if params: if params:
rows = await connection.fetch(sql, *params.values()) rows = await connection.fetch(sql, *params.values())
@@ -142,9 +160,7 @@ class PostgreSQLDB:
data = None data = None
return data return data
except Exception as e: except Exception as e:
logger.error( logger.error(f"PostgreSQL database, error:{e}")
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
raise raise
async def execute( async def execute(
@@ -152,6 +168,8 @@ class PostgreSQLDB:
sql: str, sql: str,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
upsert: bool = False, upsert: bool = False,
with_age: bool = False,
graph_name: str | None = None,
): ):
try: try:
async with self.pool.acquire() as connection: # type: ignore async with self.pool.acquire() as connection: # type: ignore
@@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage):
async def initialize(self): async def initialize(self):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
# `check_graph_requirement` is required to be executed after `get_client`
# to ensure the graph is created before any query is executed.
await self.db.check_graph_requirement(self.graph_name)
async def finalize(self): async def finalize(self):
if self.db is not None: if self.db is not None:
@@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage):
data = await self.db.query( data = await self.db.query(
query, query,
multirows=True, multirows=True,
with_age=True,
graph_name=self.graph_name,
) )
else: else:
data = await self.db.execute( data = await self.db.execute(
query, query,
upsert=upsert, upsert=upsert,
with_age=True,
graph_name=self.graph_name,
) )
except Exception as e: except Exception as e:
raise PGGraphQueryException( raise PGGraphQueryException(
{ {