Merge pull request #867 from YanSte/postgres-update

Improved PostgreSQL code and execution
This commit is contained in:
Yannick Stephan
2025-02-19 13:55:42 +01:00
committed by GitHub
7 changed files with 104 additions and 164 deletions

View File

@@ -1,9 +1,11 @@
-import networkx as nx
 import pipmaster as pm
 
 if not pm.is_installed("pyvis"):
     pm.install("pyvis")
+if not pm.is_installed("networkx"):
+    pm.install("networkx")
+import networkx as nx
 
 from pyvis.network import Network
 import random
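The fix above is ordering: `networkx` was previously imported before anything guaranteed it was installed. A minimal sketch of the guard pattern the new ordering enforces, using the same `pm.is_installed`/`pm.install` calls as the diff:

```python
import pipmaster as pm

# Install-on-demand guard: check each third-party dependency before importing
# it, so a bare environment bootstraps itself instead of failing with ImportError.
for package in ("pyvis", "networkx"):
    if not pm.is_installed(package):
        pm.install(package)

# Only safe to import after the guards above have run.
import networkx as nx
from pyvis.network import Network
```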

View File

@@ -797,8 +797,8 @@ class MongoGraphStorage(BaseGraphStorage):
 @final
 @dataclass
 class MongoVectorDBStorage(BaseVectorStorage):
-    db: AsyncIOMotorDatabase = field(default=None)
-    _data: AsyncIOMotorCollection = field(default=None)
+    db: AsyncIOMotorDatabase | None = field(default=None)
+    _data: AsyncIOMotorCollection | None = field(default=None)
 
     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
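The `| None` annotations match the `field(default=None)` defaults: without them, a strict type checker flags `None` as incompatible with the declared type. A small sketch of the idiom (the `Database` class here is a stand-in, not the real motor type); note the PEP 604 `X | None` union requires Python 3.10+:

```python
from dataclasses import dataclass, field


class Database:  # stand-in for AsyncIOMotorDatabase
    pass


@dataclass
class Storage:
    # `| None` is required: the field starts as None and is connected later.
    db: Database | None = field(default=None)

    def use(self) -> Database:
        if self.db is None:  # narrow the Optional before use
            raise RuntimeError("storage not initialized")
        return self.db
```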

View File

@@ -43,10 +43,6 @@ config.read("config.ini", "utf-8")
 @final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
-    @staticmethod
-    def load_nx_graph(file_name):
-        print("no preloading of graph with neo4j in production")
-
     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
             namespace=namespace,

View File

@@ -15,11 +15,10 @@ from lightrag.base import (
 )
 import pipmaster as pm
 
-if not pm.is_installed("graspologic"):
-    pm.install("graspologic")
 if not pm.is_installed("networkx"):
     pm.install("networkx")
+if not pm.is_installed("graspologic"):
+    pm.install("graspologic")
 
 try:
     from graspologic import embed

View File

@@ -178,11 +178,11 @@ class OracleDB:
 class ClientManager:
-    _instances = {"db": None, "ref_count": 0}
+    _instances: dict[str, Any] = {"db": None, "ref_count": 0}
     _lock = asyncio.Lock()
 
     @staticmethod
-    def get_config():
+    def get_config() -> dict[str, Any]:
         config = configparser.ConfigParser()
         config.read("config.ini", "utf-8")
@@ -398,7 +398,7 @@ class OracleKVStorage(BaseKVStorage):
 @final
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
-    db: OracleDB = field(default=None)
+    db: OracleDB | None = field(default=None)
 
     def __post_init__(self):
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
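The typing pass on `ClientManager` (annotated `_instances` dict, explicit `get_config` return type) describes a ref-counted singleton that hands out one shared DB handle. The `get_client`/`release_client` names are used elsewhere in this PR; the bodies below are a sketch of the shape these annotations imply, not the full upstream implementation:

```python
import asyncio
from typing import Any


class ClientManager:
    # One shared DB handle plus a reference count, guarded by an asyncio lock.
    _instances: dict[str, Any] = {"db": None, "ref_count": 0}
    _lock = asyncio.Lock()

    @classmethod
    async def get_client(cls) -> Any:
        async with cls._lock:
            if cls._instances["db"] is None:
                cls._instances["db"] = object()  # placeholder for real DB init
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: Any) -> None:
        async with cls._lock:
            cls._instances["ref_count"] -= 1
            if cls._instances["ref_count"] == 0:
                cls._instances["db"] = None  # drop the shared handle when unused
```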

View File

@@ -1,10 +1,9 @@
 import asyncio
-import inspect
 import json
 import os
 import time
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Union, final
+from typing import Any, Union, final
 
 import numpy as np
 import configparser
@@ -41,6 +40,7 @@ if not pm.is_installed("asyncpg"):
 try:
     import asyncpg
+    from asyncpg import Pool
 except ImportError as e:
     raise ImportError(
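`Pool` joins the guarded import so the new `self.pool: Pool | None` annotation below has a real type to refer to. A minimal sketch of the guard idiom; the error message text is an assumption, since the hunk truncates the original:

```python
try:
    import asyncpg
    from asyncpg import Pool  # needed for the `Pool | None` annotation on self.pool
except ImportError as e:
    # Re-raise with an actionable hint rather than a bare ImportError.
    # (Message wording is illustrative; the diff cuts off the original text.)
    raise ImportError(
        "asyncpg is required for the PostgreSQL backend; pip install asyncpg"
    ) from e
```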
@@ -49,8 +49,7 @@ except ImportError as e:
 class PostgreSQLDB:
-    def __init__(self, config, **kwargs):
-        self.pool = None
+    def __init__(self, config: dict[str, Any], **kwargs: Any):
         self.host = config.get("host", "localhost")
         self.port = config.get("port", 5432)
         self.user = config.get("user", "postgres")
@@ -59,7 +58,7 @@ class PostgreSQLDB:
         self.workspace = config.get("workspace", "default")
         self.max = 12
         self.increment = 1
-        logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
+        self.pool: Pool | None = None
 
         if self.user is None or self.password is None or self.database is None:
             raise ValueError(
@@ -68,7 +67,7 @@ class PostgreSQLDB:
     async def initdb(self):
         try:
-            self.pool = await asyncpg.create_pool(
+            self.pool = await asyncpg.create_pool(  # type: ignore
                 user=self.user,
                 password=self.password,
                 database=self.database,
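The pool is now declared as `Pool | None` in `__init__` and only populated by `initdb()`, making the "not yet connected" state explicit. A minimal sketch of that two-phase lifecycle with `asyncpg.create_pool` (credentials and database name are placeholders):

```python
import asyncio

import asyncpg
from asyncpg import Pool


class PostgreSQLDB:
    def __init__(self) -> None:
        # Declared up front as Optional: no connection exists until initdb() runs.
        self.pool: Pool | None = None

    async def initdb(self) -> None:
        self.pool = await asyncpg.create_pool(
            user="postgres",  # placeholder credentials
            password="secret",
            database="lightrag",
            host="localhost",
            port=5432,
        )


async def main() -> None:
    db = PostgreSQLDB()
    await db.initdb()
    assert db.pool is not None  # narrow the Optional before use
    async with db.pool.acquire() as conn:
        print(await conn.fetchval("SELECT 1"))


asyncio.run(main())
```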
@@ -79,43 +78,51 @@ class PostgreSQLDB:
             )
             logger.info(
-                f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+                f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
             )
         except Exception as e:
             logger.error(
-                f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+                f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}"
             )
-            logger.error(f"PostgreSQL database error: {e}")
             raise
 
+    async def check_graph_requirement(self, graph_name: str):
+        async with self.pool.acquire() as connection:  # type: ignore
+            try:
+                await connection.execute(
+                    'SET search_path = ag_catalog, "$user", public'
+                )  # type: ignore
+                await connection.execute(f"select create_graph('{graph_name}')")  # type: ignore
+            except (
+                asyncpg.exceptions.InvalidSchemaNameError,
+                asyncpg.exceptions.UniqueViolationError,
+            ):
+                pass
+
     async def check_tables(self):
         for k, v in TABLES.items():
             try:
                 await self.query(f"SELECT 1 FROM {k} LIMIT 1")
             except Exception as e:
-                logger.error(f"Failed to check table {k} in PostgreSQL database")
                 logger.error(f"PostgreSQL database error: {e}")
                 try:
+                    logger.info(f"PostgreSQL, Try Creating table {k} in database")
                     await self.execute(v["ddl"])
-                    logger.info(f"Created table {k} in PostgreSQL database")
+                    logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
                 except Exception as e:
-                    logger.error(f"Failed to create table {k} in PostgreSQL database")
-                    logger.error(f"PostgreSQL database error: {e}")
-
-        logger.info("Finished checking all tables in PostgreSQL database")
+                    logger.error(
+                        f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
+                    )
+                    raise e
 
     async def query(
         self,
         sql: str,
-        params: dict = None,
+        params: dict[str, Any] | None = None,
         multirows: bool = False,
-        for_age: bool = False,
-        graph_name: str = None,
-    ) -> Union[dict, None, list[dict]]:
-        async with self.pool.acquire() as connection:
+    ) -> dict[str, Any] | None | list[dict[str, Any]]:
+        async with self.pool.acquire() as connection:  # type: ignore
             try:
-                if for_age:
-                    await PostgreSQLDB._prerequisite(connection, graph_name)
                 if params:
                     rows = await connection.fetch(sql, *params.values())
                 else:
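This hunk carries the PR's core behavioral change: instead of running `_prerequisite` on every graph query (the removed `for_age`/`graph_name` plumbing), AGE graph creation now happens once in `check_graph_requirement`, called from `PGGraphStorage.initialize()` later in this diff. The method is idempotent because `create_graph` fails when the graph already exists and that failure is swallowed. The same pattern in isolation, assuming an `asyncpg` pool and the AGE extension installed:

```python
import asyncpg


async def ensure_graph(pool: asyncpg.Pool, graph_name: str) -> None:
    """Create the AGE graph once; treat 'already exists' errors as success."""
    async with pool.acquire() as connection:
        # AGE functions live in ag_catalog, so it must be on the search path.
        await connection.execute('SET search_path = ag_catalog, "$user", public')
        try:
            # graph_name is trusted configuration here; it is interpolated
            # directly into the SQL, exactly as in the diff above.
            await connection.execute(f"select create_graph('{graph_name}')")
        except (
            asyncpg.exceptions.InvalidSchemaNameError,
            asyncpg.exceptions.UniqueViolationError,
        ):
            pass  # graph already exists: nothing to do
```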
@@ -143,20 +150,15 @@ class PostgreSQLDB:
     async def execute(
         self,
         sql: str,
-        data: Union[list, dict] = None,
-        for_age: bool = False,
-        graph_name: str = None,
+        data: dict[str, Any] | None = None,
         upsert: bool = False,
     ):
         try:
-            async with self.pool.acquire() as connection:
-                if for_age:
-                    await PostgreSQLDB._prerequisite(connection, graph_name)
+            async with self.pool.acquire() as connection:  # type: ignore
                 if data is None:
-                    await connection.execute(sql)
+                    await connection.execute(sql)  # type: ignore
                 else:
-                    await connection.execute(sql, *data.values())
+                    await connection.execute(sql, *data.values())  # type: ignore
         except (
             asyncpg.exceptions.UniqueViolationError,
             asyncpg.exceptions.DuplicateTableError,
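Both `query` and `execute` forward parameters as `*params.values()` / `*data.values()`: asyncpg binds positionally to `$1`, `$2`, ... placeholders, not by name, so this only works because Python dicts preserve insertion order and the dicts are built in placeholder order. A small self-contained illustration (the SQL and values are hypothetical):

```python
from typing import Any

import asyncpg


async def fetch_rows(
    pool: asyncpg.Pool, sql: str, params: dict[str, Any]
) -> list[asyncpg.Record]:
    # The dict's insertion order must match the $1, $2, ... placeholder order.
    async with pool.acquire() as connection:
        return await connection.fetch(sql, *params.values())


# Usage sketch:
# rows = await fetch_rows(
#     pool,
#     "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND status=$2",
#     {"workspace": "default", "status": "PENDING"},  # order matters
# )
```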
@@ -169,24 +171,13 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise raise
@staticmethod
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
try:
await conn.execute('SET search_path = ag_catalog, "$user", public')
await conn.execute(f"""select create_graph('{graph_name}')""")
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
class ClientManager: class ClientManager:
_instances = {"db": None, "ref_count": 0} _instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock() _lock = asyncio.Lock()
@staticmethod @staticmethod
def get_config(): def get_config() -> dict[str, Any]:
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@@ -377,7 +368,7 @@ class PGKVStorage(BaseKVStorage):
 @final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
-    db: PostgreSQLDB = field(default=None)
+    db: PostgreSQLDB | None = field(default=None)
 
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -398,10 +389,10 @@ class PGVectorStorage(BaseVectorStorage):
         await ClientManager.release_client(self.db)
         self.db = None
 
-    def _upsert_chunks(self, item: dict):
+    def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         try:
             upsert_sql = SQL_TEMPLATES["upsert_chunk"]
-            data = {
+            data: dict[str, Any] = {
                 "workspace": self.db.workspace,
                 "id": item["__id__"],
                 "tokens": item["tokens"],
@@ -416,9 +407,9 @@ class PGVectorStorage(BaseVectorStorage):
         return upsert_sql, data
 
-    def _upsert_entities(self, item: dict):
+    def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_entity"]
-        data = {
+        data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
             "entity_name": item["entity_name"],
@@ -427,9 +418,9 @@ class PGVectorStorage(BaseVectorStorage):
         }
         return upsert_sql, data
 
-    def _upsert_relationships(self, item: dict):
+    def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
-        data = {
+        data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
             "source_id": item["src_id"],
@@ -558,16 +549,16 @@ class PGDocStatusStorage(DocStatusStorage):
         )
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        """Get doc_chunks data by id"""
         raise NotImplementedError
 
-    async def get_status_counts(self) -> Dict[str, int]:
+    async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         sql = """SELECT status as "status", COUNT(1) as "count"
                  FROM LIGHTRAG_DOC_STATUS
                  where workspace=$1 GROUP BY STATUS
                  """
         result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
-        # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
@@ -575,7 +566,7 @@ class PGDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
-    ) -> Dict[str, DocProcessingStatus]:
+    ) -> dict[str, DocProcessingStatus]:
         """all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
         params = {"workspace": self.db.workspace, "status": status.value}
@@ -602,7 +593,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Update or insert document status """Update or insert document status
Args: Args:
data: Dictionary of document IDs and their status data data: dictionary of document IDs and their status data
""" """
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status) sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7) values($1,$2,$3,$4,$5,$6,$7)
@@ -627,7 +618,6 @@ class PGDocStatusStorage(DocStatusStorage):
"status": v["status"], "status": v["status"],
}, },
) )
return data
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the storage""" """Drop the storage"""
@@ -638,7 +628,7 @@ class PGDocStatusStorage(DocStatusStorage):
 class PGGraphQueryException(Exception):
     """Exception for the AGE queries."""
 
-    def __init__(self, exception: Union[str, Dict]) -> None:
+    def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
         if isinstance(exception, dict):
             self.message = exception["message"] if "message" in exception else "unknown"
             self.details = exception["details"] if "details" in exception else "unknown"
@@ -656,21 +646,19 @@ class PGGraphQueryException(Exception):
 @final
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    db: PostgreSQLDB = field(default=None)
-
-    @staticmethod
-    def load_nx_graph(file_name):
-        print("no preloading of graph with AGE in production")
-
     def __post_init__(self):
         self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }
+        self.db: PostgreSQLDB | None = None
 
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
+            # `check_graph_requirement` is required to be executed after `get_client`
+            # to ensure the graph is created before any query is executed.
+            await self.db.check_graph_requirement(self.graph_name)
 
     async def finalize(self):
         if self.db is not None:
@@ -682,7 +670,7 @@ class PGGraphStorage(BaseGraphStorage):
         pass
 
     @staticmethod
-    def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
+    def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
         """
         Convert a record returned from an age query to a dictionary
@@ -690,7 +678,7 @@ class PGGraphStorage(BaseGraphStorage):
             record (): a record from an age query result
 
         Returns:
-            Dict[str, Any]: a dictionary representation of the record where
+            dict[str, Any]: a dictionary representation of the record where
                 the dictionary key is the field name and the value is the
                 value converted to a python type
         """
@@ -745,14 +733,14 @@ class PGGraphStorage(BaseGraphStorage):
     @staticmethod
     def _format_properties(
-        properties: Dict[str, Any], _id: Union[str, None] = None
+        properties: dict[str, Any], _id: Union[str, None] = None
     ) -> str:
         """
         Convert a dictionary of properties to a string representation that
         can be used in a cypher query insert/merge statement.
 
         Args:
-            properties (Dict[str,str]): a dictionary containing node/edge properties
+            properties (dict[str,str]): a dictionary containing node/edge properties
             _id (Union[str, None]): the id of the node or None if none exists
 
         Returns:
@@ -820,8 +808,11 @@ class PGGraphStorage(BaseGraphStorage):
         return field.replace("(", "_").replace(")", "")
 
     async def _query(
-        self, query: str, readonly: bool = True, upsert: bool = False
-    ) -> List[Dict[str, Any]]:
+        self,
+        query: str,
+        readonly: bool = True,
+        upsert: bool = False,
+    ) -> list[dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an
         age compatible query, executing it and converting the result
@@ -831,32 +822,24 @@ class PGGraphStorage(BaseGraphStorage):
             params (dict): parameters for the query
 
         Returns:
-            List[Dict[str, Any]]: a list of dictionaries containing the result set
+            list[dict[str, Any]]: a list of dictionaries containing the result set
         """
-        # convert cypher query to pgsql/age query
-        wrapped_query = query
-
-        # execute the query, rolling back on an error
         try:
             if readonly:
                 data = await self.db.query(
-                    wrapped_query,
+                    query,
                     multirows=True,
-                    for_age=True,
-                    graph_name=self.graph_name,
                 )
             else:
                 data = await self.db.execute(
-                    wrapped_query,
-                    for_age=True,
-                    graph_name=self.graph_name,
+                    query,
                     upsert=upsert,
                 )
         except Exception as e:
             raise PGGraphQueryException(
                 {
                     "message": f"Error executing graph query: {query}",
-                    "wrapped": wrapped_query,
+                    "wrapped": query,
                     "detail": str(e),
                 }
             ) from e
@@ -865,12 +848,12 @@ class PGGraphStorage(BaseGraphStorage):
             result = []
         # decode records
         else:
-            result = [PGGraphStorage._record_to_dict(d) for d in data]
+            result = [self._record_to_dict(d) for d in data]
 
         return result
 
     async def has_node(self, node_id: str) -> bool:
-        entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+        entity_name_label = self._encode_graph_label(node_id.strip('"'))
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})
@@ -878,18 +861,12 @@ class PGGraphStorage(BaseGraphStorage):
                    $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
 
         single_result = (await self._query(query))[0]
-        logger.debug(
-            "{%s}:query:{%s}:result:{%s}",
-            inspect.currentframe().f_code.co_name,
-            query,
-            single_result["node_exists"],
-        )
 
         return single_result["node_exists"]
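Every read in this class follows the same AGE idiom: a Cypher body wrapped in `cypher('<graph>', $$ ... $$)`, projected through an `AS (column agtype)` list. A sketch of how `has_node`'s query string is assembled; the `RETURN` line is a plausible reconstruction, since the hunk elides the original body:

```python
graph_name = "lightrag"          # example graph name
entity_name_label = "SOME_NODE"  # assumed already passed through _encode_graph_label

# The %s interpolation mirrors has_node above; the RETURN body is an assumption.
query = """SELECT * FROM cypher('%s', $$
             MATCH (n:Entity {node_id: "%s"})
             RETURN count(n) > 0
           $$) AS (node_exists bool)""" % (graph_name, entity_name_label)

# Printing the query yields plain SQL that can be run in psql, provided the
# search_path includes ag_catalog (see check_graph_requirement above).
print(query)
```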
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
-        tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
+        src_label = self._encode_graph_label(source_node_id.strip('"'))
+        tgt_label = self._encode_graph_label(target_node_id.strip('"'))
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
@@ -901,16 +878,11 @@ class PGGraphStorage(BaseGraphStorage):
         )
 
         single_result = (await self._query(query))[0]
-        logger.debug(
-            "{%s}:query:{%s}:result:{%s}",
-            inspect.currentframe().f_code.co_name,
-            query,
-            single_result["edge_exists"],
-        )
 
         return single_result["edge_exists"]
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+        label = self._encode_graph_label(node_id.strip('"'))
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})
                      RETURN n
@@ -919,17 +891,12 @@ class PGGraphStorage(BaseGraphStorage):
         if record:
             node = record[0]
             node_dict = node["n"]
-            logger.debug(
-                "{%s}: query: {%s}, result: {%s}",
-                inspect.currentframe().f_code.co_name,
-                query,
-                node_dict,
-            )
 
             return node_dict
 
         return None
 
     async def node_degree(self, node_id: str) -> int:
-        label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+        label = self._encode_graph_label(node_id.strip('"'))
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})-[]->(x)
@@ -938,12 +905,7 @@ class PGGraphStorage(BaseGraphStorage):
         record = (await self._query(query))[0]
         if record:
             edge_count = int(record["total_edge_count"])
-            logger.debug(
-                "{%s}:query:{%s}:result:{%s}",
-                inspect.currentframe().f_code.co_name,
-                query,
-                edge_count,
-            )
 
             return edge_count
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -955,18 +917,14 @@ class PGGraphStorage(BaseGraphStorage):
         trg_degree = 0 if trg_degree is None else trg_degree
 
         degrees = int(src_degree) + int(trg_degree)
-        logger.debug(
-            "{%s}:query:src_Degree+trg_degree:result:{%s}",
-            inspect.currentframe().f_code.co_name,
-            degrees,
-        )
 
         return degrees
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
-        tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
+        src_label = self._encode_graph_label(source_node_id.strip('"'))
+        tgt_label = self._encode_graph_label(target_node_id.strip('"'))
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
@@ -980,20 +938,15 @@ class PGGraphStorage(BaseGraphStorage):
         record = await self._query(query)
         if record and record[0] and record[0]["edge_properties"]:
             result = record[0]["edge_properties"]
-            logger.debug(
-                "{%s}:query:{%s}:result:{%s}",
-                inspect.currentframe().f_code.co_name,
-                query,
-                result,
-            )
 
             return result
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Retrieves all edges (relationships) for a particular node identified by its label.
-        :return: List of dictionaries containing edge information
+        :return: list of dictionaries containing edge information
         """
-        label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
+        label = self._encode_graph_label(source_node_id.strip('"'))
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})
@@ -1024,8 +977,8 @@ class PGGraphStorage(BaseGraphStorage):
             if source_label and target_label:
                 edges.append(
                     (
-                        PGGraphStorage._decode_graph_label(source_label),
-                        PGGraphStorage._decode_graph_label(target_label),
+                        self._decode_graph_label(source_label),
+                        self._decode_graph_label(target_label),
                     )
                 )
@@ -1037,7 +990,7 @@ class PGGraphStorage(BaseGraphStorage):
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+        label = self._encode_graph_label(node_id.strip('"'))
         properties = node_data
 
         query = """SELECT * FROM cypher('%s', $$
@@ -1047,18 +1000,14 @@ class PGGraphStorage(BaseGraphStorage):
                    $$) AS (n agtype)""" % (
             self.graph_name,
             label,
-            PGGraphStorage._format_properties(properties),
+            self._format_properties(properties),
         )
 
         try:
             await self._query(query, readonly=False, upsert=True)
-            logger.debug(
-                "Upserted node with label '{%s}' and properties: {%s}",
-                label,
-                properties,
-            )
         except Exception as e:
-            logger.error("Error during upsert: {%s}", e)
+            logger.error("POSTGRES, Error during upsert: {%s}", e)
             raise
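`upsert_node` builds its query the same way as the reads, but with three interpolations: graph name, encoded label, and the property map rendered by `_format_properties`. The Cypher body is elided by the hunk; the sketch below uses a plausible MERGE/SET body (an assumption, not the upstream query) and a simplified string-only stand-in for the property formatter:

```python
def format_properties(properties: dict[str, str]) -> str:
    # Simplified stand-in for _format_properties: handles string values only.
    return "{%s}" % ", ".join(f'{k}: "{v}"' for k, v in properties.items())


graph_name = "lightrag"  # example graph name
label = "NODE_A"         # assumed already encoded via _encode_graph_label

# MERGE matches an existing node or creates it; SET then applies the
# properties. This body is illustrative; the diff does not show the original.
query = """SELECT * FROM cypher('%s', $$
             MERGE (n:Entity {node_id: "%s"})
             SET n += %s
           $$) AS (n agtype)""" % (
    graph_name,
    label,
    format_properties({"description": "example node"}),
)
print(query)
```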
@retry( @retry(
@@ -1075,10 +1024,10 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
source_node_id (str): Label of the source node (used as identifier) source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): dictionary of properties to set on the edge
""" """
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) src_label = self._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) tgt_label = self._encode_graph_label(target_node_id.strip('"'))
edge_properties = edge_data edge_properties = edge_data
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1092,17 +1041,12 @@ class PGGraphStorage(BaseGraphStorage):
self.graph_name, self.graph_name,
src_label, src_label,
tgt_label, tgt_label,
PGGraphStorage._format_properties(edge_properties), self._format_properties(edge_properties),
) )
# logger.info(f"-- inserting edge after formatted: {params}")
try: try:
await self._query(query, readonly=False, upsert=True) await self._query(query, readonly=False, upsert=True)
logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
src_label,
tgt_label,
edge_properties,
)
except Exception as e: except Exception as e:
logger.error("Error during edge upsert: {%s}", e) logger.error("Error during edge upsert: {%s}", e)
raise raise

View File

@@ -58,7 +58,6 @@ class TiDB:
logger.error(f"Failed to check table {k} in TiDB database") logger.error(f"Failed to check table {k} in TiDB database")
logger.error(f"TiDB database error: {e}") logger.error(f"TiDB database error: {e}")
try: try:
# print(v["ddl"])
await self.execute(v["ddl"]) await self.execute(v["ddl"])
logger.info(f"Created table {k} in TiDB database") logger.info(f"Created table {k} in TiDB database")
except Exception as e: except Exception as e:
@@ -106,11 +105,11 @@ class TiDB:
 class ClientManager:
-    _instances = {"db": None, "ref_count": 0}
+    _instances: dict[str, Any] = {"db": None, "ref_count": 0}
     _lock = asyncio.Lock()
 
     @staticmethod
-    def get_config():
+    def get_config() -> dict[str, Any]:
         config = configparser.ConfigParser()
         config.read("config.ini", "utf-8")
@@ -278,7 +277,7 @@ class TiDBKVStorage(BaseKVStorage):
 @final
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
-    db: TiDB = field(default=None)
+    db: TiDB | None = field(default=None)
 
     def __post_init__(self):
         self._client_file_name = os.path.join(