Merge pull request #867 from YanSte/postgres-update

Improved the PostgreSQL code and its execution
This commit is contained in:
Yannick Stephan
2025-02-19 13:55:42 +01:00
committed by GitHub
7 changed files with 104 additions and 164 deletions

View File

@@ -1,9 +1,11 @@
import networkx as nx
import pipmaster as pm
if not pm.is_installed("pyvis"):
pm.install("pyvis")
if not pm.is_installed("networkx"):
pm.install("networkx")
import networkx as nx
from pyvis.network import Network
import random

View File

@@ -797,8 +797,8 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
db: AsyncIOMotorDatabase | None = field(default=None)
_data: AsyncIOMotorCollection | None = field(default=None)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -43,10 +43,6 @@ config.read("config.ini", "utf-8")
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,

View File

@@ -15,11 +15,10 @@ from lightrag.base import (
)
import pipmaster as pm
if not pm.is_installed("graspologic"):
pm.install("graspologic")
if not pm.is_installed("networkx"):
pm.install("networkx")
if not pm.is_installed("graspologic"):
pm.install("graspologic")
try:
from graspologic import embed

View File

@@ -178,11 +178,11 @@ class OracleDB:
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -398,7 +398,7 @@ class OracleKVStorage(BaseKVStorage):
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(default=None)
db: OracleDB | None = field(default=None)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -1,10 +1,9 @@
import asyncio
import inspect
import json
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final
from typing import Any, Union, final
import numpy as np
import configparser
@@ -41,6 +40,7 @@ if not pm.is_installed("asyncpg"):
try:
import asyncpg
from asyncpg import Pool
except ImportError as e:
raise ImportError(
@@ -49,8 +49,7 @@ except ImportError as e:
class PostgreSQLDB:
def __init__(self, config, **kwargs):
self.pool = None
def __init__(self, config: dict[str, Any], **kwargs: Any):
self.host = config.get("host", "localhost")
self.port = config.get("port", 5432)
self.user = config.get("user", "postgres")
@@ -59,7 +58,7 @@ class PostgreSQLDB:
self.workspace = config.get("workspace", "default")
self.max = 12
self.increment = 1
logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
self.pool: Pool | None = None
if self.user is None or self.password is None or self.database is None:
raise ValueError(
@@ -68,7 +67,7 @@ class PostgreSQLDB:
async def initdb(self):
try:
self.pool = await asyncpg.create_pool(
self.pool = await asyncpg.create_pool( # type: ignore
user=self.user,
password=self.password,
database=self.database,
@@ -79,43 +78,51 @@ class PostgreSQLDB:
)
logger.info(
f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
)
except Exception as e:
logger.error(
f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}"
)
logger.error(f"PostgreSQL database error: {e}")
raise
async def check_graph_requirement(self, graph_name: str):
# Prepare the graph named `graph_name` on a pooled connection: first put
# ag_catalog at the front of the session search_path, then call
# create_graph() with that name.
# NOTE(review): graph_name is interpolated directly into the SQL text;
# presumably it only ever comes from trusted configuration — confirm.
async with self.pool.acquire() as connection: # type: ignore
try:
await connection.execute(
'SET search_path = ag_catalog, "$user", public'
) # type: ignore
await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
# Deliberately ignored. Presumably these are raised when the graph
# already exists — TODO confirm against Apache AGE behaviour.
pass
async def check_tables(self):
for k, v in TABLES.items():
try:
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
except Exception as e:
logger.error(f"Failed to check table {k} in PostgreSQL database")
logger.error(f"PostgreSQL database error: {e}")
try:
logger.info(f"PostgreSQL, Try Creating table {k} in database")
await self.execute(v["ddl"])
logger.info(f"Created table {k} in PostgreSQL database")
logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
except Exception as e:
logger.error(f"Failed to create table {k} in PostgreSQL database")
logger.error(f"PostgreSQL database error: {e}")
logger.info("Finished checking all tables in PostgreSQL database")
logger.error(
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
)
raise e
async def query(
self,
sql: str,
params: dict = None,
params: dict[str, Any] | None = None,
multirows: bool = False,
for_age: bool = False,
graph_name: str = None,
) -> Union[dict, None, list[dict]]:
async with self.pool.acquire() as connection:
) -> dict[str, Any] | None | list[dict[str, Any]]:
async with self.pool.acquire() as connection: # type: ignore
try:
if for_age:
await PostgreSQLDB._prerequisite(connection, graph_name)
if params:
rows = await connection.fetch(sql, *params.values())
else:
@@ -143,20 +150,15 @@ class PostgreSQLDB:
async def execute(
self,
sql: str,
data: Union[list, dict] = None,
for_age: bool = False,
graph_name: str = None,
data: dict[str, Any] | None = None,
upsert: bool = False,
):
try:
async with self.pool.acquire() as connection:
if for_age:
await PostgreSQLDB._prerequisite(connection, graph_name)
async with self.pool.acquire() as connection: # type: ignore
if data is None:
await connection.execute(sql)
await connection.execute(sql) # type: ignore
else:
await connection.execute(sql, *data.values())
await connection.execute(sql, *data.values()) # type: ignore
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
@@ -169,24 +171,13 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise
@staticmethod
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
try:
await conn.execute('SET search_path = ag_catalog, "$user", public')
await conn.execute(f"""select create_graph('{graph_name}')""")
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -377,7 +368,7 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(default=None)
db: PostgreSQLDB | None = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -398,10 +389,10 @@ class PGVectorStorage(BaseVectorStorage):
await ClientManager.release_client(self.db)
self.db = None
def _upsert_chunks(self, item: dict):
def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
data = {
data: dict[str, Any] = {
"workspace": self.db.workspace,
"id": item["__id__"],
"tokens": item["tokens"],
@@ -416,9 +407,9 @@ class PGVectorStorage(BaseVectorStorage):
return upsert_sql, data
def _upsert_entities(self, item: dict):
def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
upsert_sql = SQL_TEMPLATES["upsert_entity"]
data = {
data: dict[str, Any] = {
"workspace": self.db.workspace,
"id": item["__id__"],
"entity_name": item["entity_name"],
@@ -427,9 +418,9 @@ class PGVectorStorage(BaseVectorStorage):
}
return upsert_sql, data
def _upsert_relationships(self, item: dict):
def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
data = {
data: dict[str, Any] = {
"workspace": self.db.workspace,
"id": item["__id__"],
"source_id": item["src_id"],
@@ -558,16 +549,16 @@ class PGDocStatusStorage(DocStatusStorage):
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
raise NotImplementedError
async def get_status_counts(self) -> Dict[str, int]:
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"
FROM LIGHTRAG_DOC_STATUS
where workspace=$1 GROUP BY STATUS
"""
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
counts = {}
for doc in result:
counts[doc["status"]] = doc["count"]
@@ -575,7 +566,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_docs_by_status(
self, status: DocStatus
) -> Dict[str, DocProcessingStatus]:
) -> dict[str, DocProcessingStatus]:
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status.value}
@@ -602,7 +593,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
data: dictionary of document IDs and their status data
"""
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7)
@@ -627,7 +618,6 @@ class PGDocStatusStorage(DocStatusStorage):
"status": v["status"],
},
)
return data
async def drop(self) -> None:
"""Drop the storage"""
@@ -638,7 +628,7 @@ class PGDocStatusStorage(DocStatusStorage):
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
def __init__(self, exception: Union[str, Dict]) -> None:
def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
if isinstance(exception, dict):
self.message = exception["message"] if "message" in exception else "unknown"
self.details = exception["details"] if "details" in exception else "unknown"
@@ -656,21 +646,19 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(default=None)
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
def __post_init__(self):
# Graph name: the storage namespace when it is truthy, otherwise the
# AGE_GRAPH_NAME environment variable, falling back to "lightrag".
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
# Lookup table from a name ("node2vec") to the corresponding bound method.
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
# No database client is attached yet; initialize() assigns one.
self.db: PostgreSQLDB | None = None
async def initialize(self):
# Fetch a client from ClientManager only when none is attached yet.
if self.db is None:
self.db = await ClientManager.get_client()
# `check_graph_requirement` is required to be executed after `get_client`
# to ensure the graph is created before any query is executed.
await self.db.check_graph_requirement(self.graph_name)
async def finalize(self):
if self.db is not None:
@@ -682,7 +670,7 @@ class PGGraphStorage(BaseGraphStorage):
pass
@staticmethod
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
"""
Convert a record returned from an age query to a dictionary
@@ -690,7 +678,7 @@ class PGGraphStorage(BaseGraphStorage):
record (): a record from an age query result
Returns:
Dict[str, Any]: a dictionary representation of the record where
dict[str, Any]: a dictionary representation of the record where
the dictionary key is the field name and the value is the
value converted to a python type
"""
@@ -745,14 +733,14 @@ class PGGraphStorage(BaseGraphStorage):
@staticmethod
def _format_properties(
properties: Dict[str, Any], _id: Union[str, None] = None
properties: dict[str, Any], _id: Union[str, None] = None
) -> str:
"""
Convert a dictionary of properties to a string representation that
can be used in a cypher query insert/merge statement.
Args:
properties (Dict[str,str]): a dictionary containing node/edge properties
properties (dict[str,str]): a dictionary containing node/edge properties
_id (Union[str, None]): the id of the node or None if none exists
Returns:
@@ -820,8 +808,11 @@ class PGGraphStorage(BaseGraphStorage):
return field.replace("(", "_").replace(")", "")
async def _query(
self, query: str, readonly: bool = True, upsert: bool = False
) -> List[Dict[str, Any]]:
self,
query: str,
readonly: bool = True,
upsert: bool = False,
) -> list[dict[str, Any]]:
"""
Query the graph by taking a cypher query, converting it to an
age compatible query, executing it and converting the result
@@ -831,32 +822,24 @@ class PGGraphStorage(BaseGraphStorage):
params (dict): parameters for the query
Returns:
List[Dict[str, Any]]: a list of dictionaries containing the result set
list[dict[str, Any]]: a list of dictionaries containing the result set
"""
# convert cypher query to pgsql/age query
wrapped_query = query
# execute the query, rolling back on an error
try:
if readonly:
data = await self.db.query(
wrapped_query,
query,
multirows=True,
for_age=True,
graph_name=self.graph_name,
)
else:
data = await self.db.execute(
wrapped_query,
for_age=True,
graph_name=self.graph_name,
query,
upsert=upsert,
)
except Exception as e:
raise PGGraphQueryException(
{
"message": f"Error executing graph query: {query}",
"wrapped": wrapped_query,
"wrapped": query,
"detail": str(e),
}
) from e
@@ -865,12 +848,12 @@ class PGGraphStorage(BaseGraphStorage):
result = []
# decode records
else:
result = [PGGraphStorage._record_to_dict(d) for d in data]
result = [self._record_to_dict(d) for d in data]
return result
async def has_node(self, node_id: str) -> bool:
entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
entity_name_label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
@@ -878,18 +861,12 @@ class PGGraphStorage(BaseGraphStorage):
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
single_result["node_exists"],
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
src_label = self._encode_graph_label(source_node_id.strip('"'))
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
@@ -901,16 +878,11 @@ class PGGraphStorage(BaseGraphStorage):
)
single_result = (await self._query(query))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
single_result["edge_exists"],
)
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> dict[str, str] | None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
RETURN n
@@ -919,17 +891,12 @@ class PGGraphStorage(BaseGraphStorage):
if record:
node = record[0]
node_dict = node["n"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name,
query,
node_dict,
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})-[]->(x)
@@ -938,12 +905,7 @@ class PGGraphStorage(BaseGraphStorage):
record = (await self._query(query))[0]
if record:
edge_count = int(record["total_edge_count"])
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
edge_count,
)
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -955,18 +917,14 @@ class PGGraphStorage(BaseGraphStorage):
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
"{%s}:query:src_Degree+trg_degree:result:{%s}",
inspect.currentframe().f_code.co_name,
degrees,
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
src_label = self._encode_graph_label(source_node_id.strip('"'))
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
@@ -980,20 +938,15 @@ class PGGraphStorage(BaseGraphStorage):
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result,
)
return result
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
:return: list of dictionaries containing edge information
"""
label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
label = self._encode_graph_label(source_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
@@ -1024,8 +977,8 @@ class PGGraphStorage(BaseGraphStorage):
if source_label and target_label:
edges.append(
(
PGGraphStorage._decode_graph_label(source_label),
PGGraphStorage._decode_graph_label(target_label),
self._decode_graph_label(source_label),
self._decode_graph_label(target_label),
)
)
@@ -1037,7 +990,7 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
label = self._encode_graph_label(node_id.strip('"'))
properties = node_data
query = """SELECT * FROM cypher('%s', $$
@@ -1047,18 +1000,14 @@ class PGGraphStorage(BaseGraphStorage):
$$) AS (n agtype)""" % (
self.graph_name,
label,
PGGraphStorage._format_properties(properties),
self._format_properties(properties),
)
try:
await self._query(query, readonly=False, upsert=True)
logger.debug(
"Upserted node with label '{%s}' and properties: {%s}",
label,
properties,
)
except Exception as e:
logger.error("Error during upsert: {%s}", e)
logger.error("POSTGRES, Error during upsert: {%s}", e)
raise
@retry(
@@ -1075,10 +1024,10 @@ class PGGraphStorage(BaseGraphStorage):
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
edge_data (dict): dictionary of properties to set on the edge
"""
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
src_label = self._encode_graph_label(source_node_id.strip('"'))
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
edge_properties = edge_data
query = """SELECT * FROM cypher('%s', $$
@@ -1092,17 +1041,12 @@ class PGGraphStorage(BaseGraphStorage):
self.graph_name,
src_label,
tgt_label,
PGGraphStorage._format_properties(edge_properties),
self._format_properties(edge_properties),
)
# logger.info(f"-- inserting edge after formatted: {params}")
try:
await self._query(query, readonly=False, upsert=True)
logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
src_label,
tgt_label,
edge_properties,
)
except Exception as e:
logger.error("Error during edge upsert: {%s}", e)
raise

View File

@@ -58,7 +58,6 @@ class TiDB:
logger.error(f"Failed to check table {k} in TiDB database")
logger.error(f"TiDB database error: {e}")
try:
# print(v["ddl"])
await self.execute(v["ddl"])
logger.info(f"Created table {k} in TiDB database")
except Exception as e:
@@ -106,11 +105,11 @@ class TiDB:
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -278,7 +277,7 @@ class TiDBKVStorage(BaseKVStorage):
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB = field(default=None)
db: TiDB | None = field(default=None)
def __post_init__(self):
self._client_file_name = os.path.join(