back age
This commit is contained in:
@@ -86,29 +86,40 @@ class PostgreSQLDB:
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def check_graph_requirement(self, graph_name: str):
|
@staticmethod
|
||||||
async with self.pool.acquire() as connection: # type: ignore
|
async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
|
||||||
try:
|
"""Set the Apache AGE environment and creates a graph if it does not exist.
|
||||||
await connection.execute(
|
|
||||||
'SET search_path = ag_catalog, "$user", public'
|
This method:
|
||||||
) # type: ignore
|
- Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
|
||||||
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
|
- Attempts to create a new graph with the provided `graph_name` if it does not already exist.
|
||||||
except (
|
- Silently ignores errors related to the graph already existing.
|
||||||
asyncpg.exceptions.InvalidSchemaNameError,
|
|
||||||
asyncpg.exceptions.UniqueViolationError,
|
"""
|
||||||
):
|
try:
|
||||||
pass
|
await connection.execute( # type: ignore
|
||||||
|
'SET search_path = ag_catalog, "$user", public'
|
||||||
|
)
|
||||||
|
await connection.execute( # type: ignore
|
||||||
|
f"select create_graph('{graph_name}')"
|
||||||
|
)
|
||||||
|
except (
|
||||||
|
asyncpg.exceptions.InvalidSchemaNameError,
|
||||||
|
asyncpg.exceptions.UniqueViolationError,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
async def check_tables(self):
|
async def check_tables(self):
|
||||||
for k, v in TABLES.items():
|
for k, v in TABLES.items():
|
||||||
try:
|
try:
|
||||||
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"PostgreSQL database error: {e}")
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"PostgreSQL, Try Creating table {k} in database")
|
logger.info(f"PostgreSQL, Try Creating table {k} in database")
|
||||||
await self.execute(v["ddl"])
|
await self.execute(v["ddl"])
|
||||||
logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
|
logger.info(
|
||||||
|
f"PostgreSQL, Creation success table {k} in PostgreSQL database"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
|
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
|
||||||
@@ -120,8 +131,15 @@ class PostgreSQLDB:
|
|||||||
sql: str,
|
sql: str,
|
||||||
params: dict[str, Any] | None = None,
|
params: dict[str, Any] | None = None,
|
||||||
multirows: bool = False,
|
multirows: bool = False,
|
||||||
|
with_age: bool = False,
|
||||||
|
graph_name: str | None = None,
|
||||||
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
||||||
async with self.pool.acquire() as connection: # type: ignore
|
async with self.pool.acquire() as connection: # type: ignore
|
||||||
|
if with_age and graph_name:
|
||||||
|
await self.configure_age(connection, graph_name) # type: ignore
|
||||||
|
elif with_age and not graph_name:
|
||||||
|
raise ValueError("Graph name is required when with_age is True")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if params:
|
if params:
|
||||||
rows = await connection.fetch(sql, *params.values())
|
rows = await connection.fetch(sql, *params.values())
|
||||||
@@ -142,9 +160,7 @@ class PostgreSQLDB:
|
|||||||
data = None
|
data = None
|
||||||
return data
|
return data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"PostgreSQL database, error:{e}")
|
||||||
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
@@ -152,6 +168,8 @@ class PostgreSQLDB:
|
|||||||
sql: str,
|
sql: str,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
upsert: bool = False,
|
upsert: bool = False,
|
||||||
|
with_age: bool = False,
|
||||||
|
graph_name: str | None = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
async with self.pool.acquire() as connection: # type: ignore
|
async with self.pool.acquire() as connection: # type: ignore
|
||||||
@@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
# `check_graph_requirement` is required to be executed after `get_client`
|
|
||||||
# to ensure the graph is created before any query is executed.
|
|
||||||
await self.db.check_graph_requirement(self.graph_name)
|
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
if self.db is not None:
|
||||||
@@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
data = await self.db.query(
|
data = await self.db.query(
|
||||||
query,
|
query,
|
||||||
multirows=True,
|
multirows=True,
|
||||||
|
with_age=True,
|
||||||
|
graph_name=self.graph_name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data = await self.db.execute(
|
data = await self.db.execute(
|
||||||
query,
|
query,
|
||||||
upsert=upsert,
|
upsert=upsert,
|
||||||
|
with_age=True,
|
||||||
|
graph_name=self.graph_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise PGGraphQueryException(
|
raise PGGraphQueryException(
|
||||||
{
|
{
|
||||||
|
Reference in New Issue
Block a user