Merge pull request #846 from ArnoChenFx/db-connection-and-storage-lifecycle

Refactor Database Connection Management and Improve Storage Lifecycle Handling
This commit is contained in:
Yannick Stephan
2025-02-18 22:39:31 +01:00
committed by GitHub
11 changed files with 540 additions and 416 deletions

View File

@@ -3,10 +3,10 @@ import inspect
import json
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
@@ -181,15 +181,84 @@ class PostgreSQLDB:
pass
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> PostgreSQLDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = PostgreSQLDB(config)
await db.initdb()
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: PostgreSQLDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
await db.pool.close()
logger.info("Closed PostgreSQL database connection pool")
cls._instances["db"] = None
else:
await db.pool.close()
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: PostgreSQLDB
db: PostgreSQLDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -308,6 +377,8 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -318,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
def _upsert_chunks(self, item: dict):
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -426,6 +506,17 @@ class PGVectorStorage(BaseVectorStorage):
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(default=None)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
@@ -565,6 +656,8 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(default=None)
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -575,6 +668,15 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass