|
|
@@ -1,10 +1,9 @@
|
|
|
|
import asyncio
|
|
|
|
import asyncio
|
|
|
|
import inspect
|
|
|
|
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from typing import Any, Dict, List, Union, final
|
|
|
|
from typing import Any, Union, final
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import configparser
|
|
|
|
import configparser
|
|
|
|
|
|
|
|
|
|
|
@@ -41,6 +40,7 @@ if not pm.is_installed("asyncpg"):
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
import asyncpg
|
|
|
|
import asyncpg
|
|
|
|
|
|
|
|
from asyncpg import Pool
|
|
|
|
|
|
|
|
|
|
|
|
except ImportError as e:
|
|
|
|
except ImportError as e:
|
|
|
|
raise ImportError(
|
|
|
|
raise ImportError(
|
|
|
@@ -49,8 +49,7 @@ except ImportError as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PostgreSQLDB:
|
|
|
|
class PostgreSQLDB:
|
|
|
|
def __init__(self, config, **kwargs):
|
|
|
|
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
|
|
|
self.pool = None
|
|
|
|
|
|
|
|
self.host = config.get("host", "localhost")
|
|
|
|
self.host = config.get("host", "localhost")
|
|
|
|
self.port = config.get("port", 5432)
|
|
|
|
self.port = config.get("port", 5432)
|
|
|
|
self.user = config.get("user", "postgres")
|
|
|
|
self.user = config.get("user", "postgres")
|
|
|
@@ -59,7 +58,7 @@ class PostgreSQLDB:
|
|
|
|
self.workspace = config.get("workspace", "default")
|
|
|
|
self.workspace = config.get("workspace", "default")
|
|
|
|
self.max = 12
|
|
|
|
self.max = 12
|
|
|
|
self.increment = 1
|
|
|
|
self.increment = 1
|
|
|
|
logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
|
|
|
|
self.pool: Pool | None = None
|
|
|
|
|
|
|
|
|
|
|
|
if self.user is None or self.password is None or self.database is None:
|
|
|
|
if self.user is None or self.password is None or self.database is None:
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
@@ -68,7 +67,7 @@ class PostgreSQLDB:
|
|
|
|
|
|
|
|
|
|
|
|
async def initdb(self):
|
|
|
|
async def initdb(self):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
self.pool = await asyncpg.create_pool(
|
|
|
|
self.pool = await asyncpg.create_pool( # type: ignore
|
|
|
|
user=self.user,
|
|
|
|
user=self.user,
|
|
|
|
password=self.password,
|
|
|
|
password=self.password,
|
|
|
|
database=self.database,
|
|
|
|
database=self.database,
|
|
|
@@ -79,43 +78,51 @@ class PostgreSQLDB:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
logger.info(
|
|
|
|
f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
|
|
|
|
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(
|
|
|
|
logger.error(
|
|
|
|
f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
|
|
|
|
f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
logger.error(f"PostgreSQL database error: {e}")
|
|
|
|
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def check_graph_requirement(self, graph_name: str):
|
|
|
|
|
|
|
|
async with self.pool.acquire() as connection: # type: ignore
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
await connection.execute(
|
|
|
|
|
|
|
|
'SET search_path = ag_catalog, "$user", public'
|
|
|
|
|
|
|
|
) # type: ignore
|
|
|
|
|
|
|
|
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
|
|
|
|
|
|
|
|
except (
|
|
|
|
|
|
|
|
asyncpg.exceptions.InvalidSchemaNameError,
|
|
|
|
|
|
|
|
asyncpg.exceptions.UniqueViolationError,
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
async def check_tables(self):
|
|
|
|
async def check_tables(self):
|
|
|
|
for k, v in TABLES.items():
|
|
|
|
for k, v in TABLES.items():
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
|
|
|
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(f"Failed to check table {k} in PostgreSQL database")
|
|
|
|
|
|
|
|
logger.error(f"PostgreSQL database error: {e}")
|
|
|
|
logger.error(f"PostgreSQL database error: {e}")
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
|
|
|
|
logger.info(f"PostgreSQL, Try Creating table {k} in database")
|
|
|
|
await self.execute(v["ddl"])
|
|
|
|
await self.execute(v["ddl"])
|
|
|
|
logger.info(f"Created table {k} in PostgreSQL database")
|
|
|
|
logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(f"Failed to create table {k} in PostgreSQL database")
|
|
|
|
logger.error(
|
|
|
|
logger.error(f"PostgreSQL database error: {e}")
|
|
|
|
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
|
|
|
|
|
|
|
|
)
|
|
|
|
logger.info("Finished checking all tables in PostgreSQL database")
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
|
|
async def query(
|
|
|
|
async def query(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
sql: str,
|
|
|
|
sql: str,
|
|
|
|
params: dict = None,
|
|
|
|
params: dict[str, Any] | None = None,
|
|
|
|
multirows: bool = False,
|
|
|
|
multirows: bool = False,
|
|
|
|
for_age: bool = False,
|
|
|
|
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
|
|
|
graph_name: str = None,
|
|
|
|
async with self.pool.acquire() as connection: # type: ignore
|
|
|
|
) -> Union[dict, None, list[dict]]:
|
|
|
|
|
|
|
|
async with self.pool.acquire() as connection:
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if for_age:
|
|
|
|
|
|
|
|
await PostgreSQLDB._prerequisite(connection, graph_name)
|
|
|
|
|
|
|
|
if params:
|
|
|
|
if params:
|
|
|
|
rows = await connection.fetch(sql, *params.values())
|
|
|
|
rows = await connection.fetch(sql, *params.values())
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@@ -143,20 +150,15 @@ class PostgreSQLDB:
|
|
|
|
async def execute(
|
|
|
|
async def execute(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
sql: str,
|
|
|
|
sql: str,
|
|
|
|
data: Union[list, dict] = None,
|
|
|
|
data: dict[str, Any] | None = None,
|
|
|
|
for_age: bool = False,
|
|
|
|
|
|
|
|
graph_name: str = None,
|
|
|
|
|
|
|
|
upsert: bool = False,
|
|
|
|
upsert: bool = False,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
async with self.pool.acquire() as connection:
|
|
|
|
async with self.pool.acquire() as connection: # type: ignore
|
|
|
|
if for_age:
|
|
|
|
|
|
|
|
await PostgreSQLDB._prerequisite(connection, graph_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if data is None:
|
|
|
|
if data is None:
|
|
|
|
await connection.execute(sql)
|
|
|
|
await connection.execute(sql) # type: ignore
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
await connection.execute(sql, *data.values())
|
|
|
|
await connection.execute(sql, *data.values()) # type: ignore
|
|
|
|
except (
|
|
|
|
except (
|
|
|
|
asyncpg.exceptions.UniqueViolationError,
|
|
|
|
asyncpg.exceptions.UniqueViolationError,
|
|
|
|
asyncpg.exceptions.DuplicateTableError,
|
|
|
|
asyncpg.exceptions.DuplicateTableError,
|
|
|
@@ -169,24 +171,13 @@ class PostgreSQLDB:
|
|
|
|
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
|
|
|
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
await conn.execute('SET search_path = ag_catalog, "$user", public')
|
|
|
|
|
|
|
|
await conn.execute(f"""select create_graph('{graph_name}')""")
|
|
|
|
|
|
|
|
except (
|
|
|
|
|
|
|
|
asyncpg.exceptions.InvalidSchemaNameError,
|
|
|
|
|
|
|
|
asyncpg.exceptions.UniqueViolationError,
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClientManager:
|
|
|
|
class ClientManager:
|
|
|
|
_instances = {"db": None, "ref_count": 0}
|
|
|
|
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
|
|
|
|
_lock = asyncio.Lock()
|
|
|
|
_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_config():
|
|
|
|
def get_config() -> dict[str, Any]:
|
|
|
|
config = configparser.ConfigParser()
|
|
|
|
config = configparser.ConfigParser()
|
|
|
|
config.read("config.ini", "utf-8")
|
|
|
|
config.read("config.ini", "utf-8")
|
|
|
|
|
|
|
|
|
|
|
@@ -377,7 +368,7 @@ class PGKVStorage(BaseKVStorage):
|
|
|
|
@final
|
|
|
|
@final
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
class PGVectorStorage(BaseVectorStorage):
|
|
|
|
class PGVectorStorage(BaseVectorStorage):
|
|
|
|
db: PostgreSQLDB = field(default=None)
|
|
|
|
db: PostgreSQLDB | None = field(default=None)
|
|
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
def __post_init__(self):
|
|
|
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
@@ -398,10 +389,10 @@ class PGVectorStorage(BaseVectorStorage):
|
|
|
|
await ClientManager.release_client(self.db)
|
|
|
|
await ClientManager.release_client(self.db)
|
|
|
|
self.db = None
|
|
|
|
self.db = None
|
|
|
|
|
|
|
|
|
|
|
|
def _upsert_chunks(self, item: dict):
|
|
|
|
def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
|
|
|
|
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
|
|
|
|
data = {
|
|
|
|
data: dict[str, Any] = {
|
|
|
|
"workspace": self.db.workspace,
|
|
|
|
"workspace": self.db.workspace,
|
|
|
|
"id": item["__id__"],
|
|
|
|
"id": item["__id__"],
|
|
|
|
"tokens": item["tokens"],
|
|
|
|
"tokens": item["tokens"],
|
|
|
@@ -416,9 +407,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|
|
|
|
|
|
|
|
|
|
|
return upsert_sql, data
|
|
|
|
return upsert_sql, data
|
|
|
|
|
|
|
|
|
|
|
|
def _upsert_entities(self, item: dict):
|
|
|
|
def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
|
|
|
upsert_sql = SQL_TEMPLATES["upsert_entity"]
|
|
|
|
upsert_sql = SQL_TEMPLATES["upsert_entity"]
|
|
|
|
data = {
|
|
|
|
data: dict[str, Any] = {
|
|
|
|
"workspace": self.db.workspace,
|
|
|
|
"workspace": self.db.workspace,
|
|
|
|
"id": item["__id__"],
|
|
|
|
"id": item["__id__"],
|
|
|
|
"entity_name": item["entity_name"],
|
|
|
|
"entity_name": item["entity_name"],
|
|
|
@@ -427,9 +418,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return upsert_sql, data
|
|
|
|
return upsert_sql, data
|
|
|
|
|
|
|
|
|
|
|
|
def _upsert_relationships(self, item: dict):
|
|
|
|
def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
|
|
|
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
|
|
|
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
|
|
|
data = {
|
|
|
|
data: dict[str, Any] = {
|
|
|
|
"workspace": self.db.workspace,
|
|
|
|
"workspace": self.db.workspace,
|
|
|
|
"id": item["__id__"],
|
|
|
|
"id": item["__id__"],
|
|
|
|
"source_id": item["src_id"],
|
|
|
|
"source_id": item["src_id"],
|
|
|
@@ -558,16 +549,16 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
|
|
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
|
|
|
|
|
|
|
"""Get doc_chunks data by id"""
|
|
|
|
raise NotImplementedError
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
async def get_status_counts(self) -> Dict[str, int]:
|
|
|
|
async def get_status_counts(self) -> dict[str, int]:
|
|
|
|
"""Get counts of documents in each status"""
|
|
|
|
"""Get counts of documents in each status"""
|
|
|
|
sql = """SELECT status as "status", COUNT(1) as "count"
|
|
|
|
sql = """SELECT status as "status", COUNT(1) as "count"
|
|
|
|
FROM LIGHTRAG_DOC_STATUS
|
|
|
|
FROM LIGHTRAG_DOC_STATUS
|
|
|
|
where workspace=$1 GROUP BY STATUS
|
|
|
|
where workspace=$1 GROUP BY STATUS
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
|
|
|
|
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
|
|
|
|
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
|
|
|
|
|
|
|
|
counts = {}
|
|
|
|
counts = {}
|
|
|
|
for doc in result:
|
|
|
|
for doc in result:
|
|
|
|
counts[doc["status"]] = doc["count"]
|
|
|
|
counts[doc["status"]] = doc["count"]
|
|
|
@@ -575,7 +566,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
|
|
|
|
|
|
|
|
|
|
async def get_docs_by_status(
|
|
|
|
async def get_docs_by_status(
|
|
|
|
self, status: DocStatus
|
|
|
|
self, status: DocStatus
|
|
|
|
) -> Dict[str, DocProcessingStatus]:
|
|
|
|
) -> dict[str, DocProcessingStatus]:
|
|
|
|
"""all documents with a specific status"""
|
|
|
|
"""all documents with a specific status"""
|
|
|
|
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
|
|
|
|
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
|
|
|
|
params = {"workspace": self.db.workspace, "status": status.value}
|
|
|
|
params = {"workspace": self.db.workspace, "status": status.value}
|
|
|
@@ -602,7 +593,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
|
|
"""Update or insert document status
|
|
|
|
"""Update or insert document status
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
data: Dictionary of document IDs and their status data
|
|
|
|
data: dictionary of document IDs and their status data
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
|
|
|
|
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
|
|
|
|
values($1,$2,$3,$4,$5,$6,$7)
|
|
|
|
values($1,$2,$3,$4,$5,$6,$7)
|
|
|
@@ -627,7 +618,6 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
|
|
"status": v["status"],
|
|
|
|
"status": v["status"],
|
|
|
|
},
|
|
|
|
},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def drop(self) -> None:
|
|
|
|
async def drop(self) -> None:
|
|
|
|
"""Drop the storage"""
|
|
|
|
"""Drop the storage"""
|
|
|
@@ -638,7 +628,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
|
|
class PGGraphQueryException(Exception):
|
|
|
|
class PGGraphQueryException(Exception):
|
|
|
|
"""Exception for the AGE queries."""
|
|
|
|
"""Exception for the AGE queries."""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, exception: Union[str, Dict]) -> None:
|
|
|
|
def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
|
|
|
|
if isinstance(exception, dict):
|
|
|
|
if isinstance(exception, dict):
|
|
|
|
self.message = exception["message"] if "message" in exception else "unknown"
|
|
|
|
self.message = exception["message"] if "message" in exception else "unknown"
|
|
|
|
self.details = exception["details"] if "details" in exception else "unknown"
|
|
|
|
self.details = exception["details"] if "details" in exception else "unknown"
|
|
|
@@ -656,21 +646,19 @@ class PGGraphQueryException(Exception):
|
|
|
|
@final
|
|
|
|
@final
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
class PGGraphStorage(BaseGraphStorage):
|
|
|
|
class PGGraphStorage(BaseGraphStorage):
|
|
|
|
db: PostgreSQLDB = field(default=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
def load_nx_graph(file_name):
|
|
|
|
|
|
|
|
print("no preloading of graph with AGE in production")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
def __post_init__(self):
|
|
|
|
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
|
|
|
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
|
|
|
self._node_embed_algorithms = {
|
|
|
|
self._node_embed_algorithms = {
|
|
|
|
"node2vec": self._node2vec_embed,
|
|
|
|
"node2vec": self._node2vec_embed,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
self.db: PostgreSQLDB | None = None
|
|
|
|
|
|
|
|
|
|
|
|
async def initialize(self):
|
|
|
|
async def initialize(self):
|
|
|
|
if self.db is None:
|
|
|
|
if self.db is None:
|
|
|
|
self.db = await ClientManager.get_client()
|
|
|
|
self.db = await ClientManager.get_client()
|
|
|
|
|
|
|
|
# `check_graph_requirement` is required to be executed after `get_client`
|
|
|
|
|
|
|
|
# to ensure the graph is created before any query is executed.
|
|
|
|
|
|
|
|
await self.db.check_graph_requirement(self.graph_name)
|
|
|
|
|
|
|
|
|
|
|
|
async def finalize(self):
|
|
|
|
async def finalize(self):
|
|
|
|
if self.db is not None:
|
|
|
|
if self.db is not None:
|
|
|
@@ -682,7 +670,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
|
|
|
def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Convert a record returned from an age query to a dictionary
|
|
|
|
Convert a record returned from an age query to a dictionary
|
|
|
|
|
|
|
|
|
|
|
@@ -690,7 +678,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
record (): a record from an age query result
|
|
|
|
record (): a record from an age query result
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
Dict[str, Any]: a dictionary representation of the record where
|
|
|
|
dict[str, Any]: a dictionary representation of the record where
|
|
|
|
the dictionary key is the field name and the value is the
|
|
|
|
the dictionary key is the field name and the value is the
|
|
|
|
value converted to a python type
|
|
|
|
value converted to a python type
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@@ -745,14 +733,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _format_properties(
|
|
|
|
def _format_properties(
|
|
|
|
properties: Dict[str, Any], _id: Union[str, None] = None
|
|
|
|
properties: dict[str, Any], _id: Union[str, None] = None
|
|
|
|
) -> str:
|
|
|
|
) -> str:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Convert a dictionary of properties to a string representation that
|
|
|
|
Convert a dictionary of properties to a string representation that
|
|
|
|
can be used in a cypher query insert/merge statement.
|
|
|
|
can be used in a cypher query insert/merge statement.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
properties (Dict[str,str]): a dictionary containing node/edge properties
|
|
|
|
properties (dict[str,str]): a dictionary containing node/edge properties
|
|
|
|
_id (Union[str, None]): the id of the node or None if none exists
|
|
|
|
_id (Union[str, None]): the id of the node or None if none exists
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
@@ -820,8 +808,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
return field.replace("(", "_").replace(")", "")
|
|
|
|
return field.replace("(", "_").replace(")", "")
|
|
|
|
|
|
|
|
|
|
|
|
async def _query(
|
|
|
|
async def _query(
|
|
|
|
self, query: str, readonly: bool = True, upsert: bool = False
|
|
|
|
self,
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
query: str,
|
|
|
|
|
|
|
|
readonly: bool = True,
|
|
|
|
|
|
|
|
upsert: bool = False,
|
|
|
|
|
|
|
|
) -> list[dict[str, Any]]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Query the graph by taking a cypher query, converting it to an
|
|
|
|
Query the graph by taking a cypher query, converting it to an
|
|
|
|
age compatible query, executing it and converting the result
|
|
|
|
age compatible query, executing it and converting the result
|
|
|
@@ -831,32 +822,24 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
params (dict): parameters for the query
|
|
|
|
params (dict): parameters for the query
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
List[Dict[str, Any]]: a list of dictionaries containing the result set
|
|
|
|
list[dict[str, Any]]: a list of dictionaries containing the result set
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# convert cypher query to pgsql/age query
|
|
|
|
|
|
|
|
wrapped_query = query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# execute the query, rolling back on an error
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if readonly:
|
|
|
|
if readonly:
|
|
|
|
data = await self.db.query(
|
|
|
|
data = await self.db.query(
|
|
|
|
wrapped_query,
|
|
|
|
query,
|
|
|
|
multirows=True,
|
|
|
|
multirows=True,
|
|
|
|
for_age=True,
|
|
|
|
|
|
|
|
graph_name=self.graph_name,
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
data = await self.db.execute(
|
|
|
|
data = await self.db.execute(
|
|
|
|
wrapped_query,
|
|
|
|
query,
|
|
|
|
for_age=True,
|
|
|
|
|
|
|
|
graph_name=self.graph_name,
|
|
|
|
|
|
|
|
upsert=upsert,
|
|
|
|
upsert=upsert,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
raise PGGraphQueryException(
|
|
|
|
raise PGGraphQueryException(
|
|
|
|
{
|
|
|
|
{
|
|
|
|
"message": f"Error executing graph query: {query}",
|
|
|
|
"message": f"Error executing graph query: {query}",
|
|
|
|
"wrapped": wrapped_query,
|
|
|
|
"wrapped": query,
|
|
|
|
"detail": str(e),
|
|
|
|
"detail": str(e),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
) from e
|
|
|
|
) from e
|
|
|
@@ -865,12 +848,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
result = []
|
|
|
|
result = []
|
|
|
|
# decode records
|
|
|
|
# decode records
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
result = [PGGraphStorage._record_to_dict(d) for d in data]
|
|
|
|
result = [self._record_to_dict(d) for d in data]
|
|
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
async def has_node(self, node_id: str) -> bool:
|
|
|
|
async def has_node(self, node_id: str) -> bool:
|
|
|
|
entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
|
|
|
entity_name_label = self._encode_graph_label(node_id.strip('"'))
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
MATCH (n:Entity {node_id: "%s"})
|
|
|
|
MATCH (n:Entity {node_id: "%s"})
|
|
|
@@ -878,18 +861,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
|
|
|
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
|
|
|
|
|
|
|
|
|
|
|
single_result = (await self._query(query))[0]
|
|
|
|
single_result = (await self._query(query))[0]
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"{%s}:query:{%s}:result:{%s}",
|
|
|
|
|
|
|
|
inspect.currentframe().f_code.co_name,
|
|
|
|
|
|
|
|
query,
|
|
|
|
|
|
|
|
single_result["node_exists"],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return single_result["node_exists"]
|
|
|
|
return single_result["node_exists"]
|
|
|
|
|
|
|
|
|
|
|
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
|
|
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
|
|
|
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
src_label = self._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
|
|
|
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
|
|
|
|
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
|
|
|
@@ -901,16 +878,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
single_result = (await self._query(query))[0]
|
|
|
|
single_result = (await self._query(query))[0]
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"{%s}:query:{%s}:result:{%s}",
|
|
|
|
|
|
|
|
inspect.currentframe().f_code.co_name,
|
|
|
|
|
|
|
|
query,
|
|
|
|
|
|
|
|
single_result["edge_exists"],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return single_result["edge_exists"]
|
|
|
|
return single_result["edge_exists"]
|
|
|
|
|
|
|
|
|
|
|
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
|
|
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
|
|
|
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
|
|
|
label = self._encode_graph_label(node_id.strip('"'))
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
MATCH (n:Entity {node_id: "%s"})
|
|
|
|
MATCH (n:Entity {node_id: "%s"})
|
|
|
|
RETURN n
|
|
|
|
RETURN n
|
|
|
@@ -919,17 +891,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
if record:
|
|
|
|
if record:
|
|
|
|
node = record[0]
|
|
|
|
node = record[0]
|
|
|
|
node_dict = node["n"]
|
|
|
|
node_dict = node["n"]
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"{%s}: query: {%s}, result: {%s}",
|
|
|
|
|
|
|
|
inspect.currentframe().f_code.co_name,
|
|
|
|
|
|
|
|
query,
|
|
|
|
|
|
|
|
node_dict,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return node_dict
|
|
|
|
return node_dict
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
async def node_degree(self, node_id: str) -> int:
|
|
|
|
async def node_degree(self, node_id: str) -> int:
|
|
|
|
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
|
|
|
label = self._encode_graph_label(node_id.strip('"'))
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
MATCH (n:Entity {node_id: "%s"})-[]->(x)
|
|
|
|
MATCH (n:Entity {node_id: "%s"})-[]->(x)
|
|
|
@@ -938,12 +905,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
record = (await self._query(query))[0]
|
|
|
|
record = (await self._query(query))[0]
|
|
|
|
if record:
|
|
|
|
if record:
|
|
|
|
edge_count = int(record["total_edge_count"])
|
|
|
|
edge_count = int(record["total_edge_count"])
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"{%s}:query:{%s}:result:{%s}",
|
|
|
|
|
|
|
|
inspect.currentframe().f_code.co_name,
|
|
|
|
|
|
|
|
query,
|
|
|
|
|
|
|
|
edge_count,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return edge_count
|
|
|
|
return edge_count
|
|
|
|
|
|
|
|
|
|
|
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
|
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
|
@@ -955,18 +917,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
trg_degree = 0 if trg_degree is None else trg_degree
|
|
|
|
trg_degree = 0 if trg_degree is None else trg_degree
|
|
|
|
|
|
|
|
|
|
|
|
degrees = int(src_degree) + int(trg_degree)
|
|
|
|
degrees = int(src_degree) + int(trg_degree)
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"{%s}:query:src_Degree+trg_degree:result:{%s}",
|
|
|
|
|
|
|
|
inspect.currentframe().f_code.co_name,
|
|
|
|
|
|
|
|
degrees,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return degrees
|
|
|
|
return degrees
|
|
|
|
|
|
|
|
|
|
|
|
async def get_edge(
|
|
|
|
async def get_edge(
|
|
|
|
self, source_node_id: str, target_node_id: str
|
|
|
|
self, source_node_id: str, target_node_id: str
|
|
|
|
) -> dict[str, str] | None:
|
|
|
|
) -> dict[str, str] | None:
|
|
|
|
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
src_label = self._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
|
|
|
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
|
|
|
|
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
|
|
|
@@ -980,20 +938,15 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
record = await self._query(query)
|
|
|
|
record = await self._query(query)
|
|
|
|
if record and record[0] and record[0]["edge_properties"]:
|
|
|
|
if record and record[0] and record[0]["edge_properties"]:
|
|
|
|
result = record[0]["edge_properties"]
|
|
|
|
result = record[0]["edge_properties"]
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"{%s}:query:{%s}:result:{%s}",
|
|
|
|
|
|
|
|
inspect.currentframe().f_code.co_name,
|
|
|
|
|
|
|
|
query,
|
|
|
|
|
|
|
|
result,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return result
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
|
|
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
|
|
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
|
|
|
:return: List of dictionaries containing edge information
|
|
|
|
:return: list of dictionaries containing edge information
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
label = self._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
MATCH (n:Entity {node_id: "%s"})
|
|
|
|
MATCH (n:Entity {node_id: "%s"})
|
|
|
@@ -1024,8 +977,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
if source_label and target_label:
|
|
|
|
if source_label and target_label:
|
|
|
|
edges.append(
|
|
|
|
edges.append(
|
|
|
|
(
|
|
|
|
(
|
|
|
|
PGGraphStorage._decode_graph_label(source_label),
|
|
|
|
self._decode_graph_label(source_label),
|
|
|
|
PGGraphStorage._decode_graph_label(target_label),
|
|
|
|
self._decode_graph_label(target_label),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@@ -1037,7 +990,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
retry=retry_if_exception_type((PGGraphQueryException,)),
|
|
|
|
retry=retry_if_exception_type((PGGraphQueryException,)),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
|
|
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
|
|
|
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
|
|
|
label = self._encode_graph_label(node_id.strip('"'))
|
|
|
|
properties = node_data
|
|
|
|
properties = node_data
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
@@ -1047,18 +1000,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
$$) AS (n agtype)""" % (
|
|
|
|
$$) AS (n agtype)""" % (
|
|
|
|
self.graph_name,
|
|
|
|
self.graph_name,
|
|
|
|
label,
|
|
|
|
label,
|
|
|
|
PGGraphStorage._format_properties(properties),
|
|
|
|
self._format_properties(properties),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
await self._query(query, readonly=False, upsert=True)
|
|
|
|
await self._query(query, readonly=False, upsert=True)
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"Upserted node with label '{%s}' and properties: {%s}",
|
|
|
|
|
|
|
|
label,
|
|
|
|
|
|
|
|
properties,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error("Error during upsert: {%s}", e)
|
|
|
|
logger.error("POSTGRES, Error during upsert: {%s}", e)
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
@retry(
|
|
|
|
@retry(
|
|
|
@@ -1075,10 +1024,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
source_node_id (str): Label of the source node (used as identifier)
|
|
|
|
source_node_id (str): Label of the source node (used as identifier)
|
|
|
|
target_node_id (str): Label of the target node (used as identifier)
|
|
|
|
target_node_id (str): Label of the target node (used as identifier)
|
|
|
|
edge_data (dict): Dictionary of properties to set on the edge
|
|
|
|
edge_data (dict): dictionary of properties to set on the edge
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
src_label = self._encode_graph_label(source_node_id.strip('"'))
|
|
|
|
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
|
|
|
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
|
|
|
edge_properties = edge_data
|
|
|
|
edge_properties = edge_data
|
|
|
|
|
|
|
|
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
|
query = """SELECT * FROM cypher('%s', $$
|
|
|
@@ -1092,17 +1041,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
|
|
self.graph_name,
|
|
|
|
self.graph_name,
|
|
|
|
src_label,
|
|
|
|
src_label,
|
|
|
|
tgt_label,
|
|
|
|
tgt_label,
|
|
|
|
PGGraphStorage._format_properties(edge_properties),
|
|
|
|
self._format_properties(edge_properties),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# logger.info(f"-- inserting edge after formatted: {params}")
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
await self._query(query, readonly=False, upsert=True)
|
|
|
|
await self._query(query, readonly=False, upsert=True)
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
|
|
|
|
|
|
|
src_label,
|
|
|
|
|
|
|
|
tgt_label,
|
|
|
|
|
|
|
|
edge_properties,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error("Error during edge upsert: {%s}", e)
|
|
|
|
logger.error("Error during edge upsert: {%s}", e)
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|