Merge pull request #846 from ArnoChenFx/db-connection-and-storage-lifecycle

Refactor Database Connection Management and Improve Storage Lifecycle Handling
Authored by Yannick Stephan on 2025-02-18 22:39:31 +01:00; committed by GitHub.
11 changed files with 540 additions and 416 deletions
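
The change, repeated for each backend below, replaces per-storage client construction in __post_init__ with a module-level ClientManager that hands out one shared client under an asyncio.Lock with reference counting, plus async initialize()/finalize() hooks on every storage class. A minimal, self-contained sketch of that pattern (FakeClient stands in for AsyncIOMotorClient, OracleDB, PostgreSQLDB, or TiDB and is illustrative only):

    import asyncio

    class FakeClient:
        """Illustrative stand-in for the real driver client / connection pool."""
        async def close(self):
            print("client closed")

    class ClientManager:
        _instances = {"db": None, "ref_count": 0}
        _lock = asyncio.Lock()

        @classmethod
        async def get_client(cls) -> FakeClient:
            async with cls._lock:
                if cls._instances["db"] is None:
                    # The real managers build the driver client from env vars / config.ini here.
                    cls._instances["db"] = FakeClient()
                    cls._instances["ref_count"] = 0
                cls._instances["ref_count"] += 1
                return cls._instances["db"]

        @classmethod
        async def release_client(cls, db: FakeClient):
            async with cls._lock:
                if db is not None and db is cls._instances["db"]:
                    cls._instances["ref_count"] -= 1
                    if cls._instances["ref_count"] == 0:
                        # The Oracle/PostgreSQL managers close their connection pool at this point.
                        await db.close()
                        cls._instances["db"] = None

    class SomeStorage:
        """Lifecycle shape shared by the patched storage classes."""
        def __init__(self):
            self.db = None

        async def initialize(self):
            if self.db is None:
                self.db = await ClientManager.get_client()

        async def finalize(self):
            if self.db is not None:
                await ClientManager.release_client(self.db)
                self.db = None

    async def demo():
        a, b = SomeStorage(), SomeStorage()
        await a.initialize()
        await b.initialize()      # same shared client, ref_count == 2
        await a.finalize()        # ref_count == 1, client stays open
        await b.finalize()        # ref_count == 0, client is closed and dropped

    asyncio.run(demo())

In the diff itself, the Mongo and TiDB managers simply drop the cached instance on the last release, while the Oracle and PostgreSQL managers also close their connection pools.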

View File

@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
import numpy as np
import configparser
import asyncio
@@ -26,8 +26,11 @@ if not pm.is_installed("motor"):
pm.install("motor")
try:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
@@ -39,31 +42,63 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod
async def get_client(cls) -> AsyncIOMotorDatabase:
async with cls._lock:
if cls._instances["db"] is None:
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb",
"uri",
fallback="mongodb://root:root@localhost:27017/",
),
)
database_name = os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
client = AsyncIOMotorClient(uri)
db = client.get_database(database_name)
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: AsyncIOMotorDatabase):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self):
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id})
@@ -120,28 +155,23 @@ class MongoKVStorage(BaseKVStorage):
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as doc status {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id})
@@ -202,36 +232,33 @@ class MongoDocStatusStorage(DocStatusStorage):
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
"""
db: AsyncIOMotorDatabase = field(default=None)
collection: AsyncIOMotorCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self.collection = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection(
self.db, self._collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self.collection = None
#
# -------------------------------------------------------------------------
@@ -770,6 +797,9 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -778,41 +808,36 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
self._max_batch_size = self.global_config["embedding_batch_num"]
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
# Ensure vector index exists
await self.create_vector_index_if_not_exists()
# Ensure vector index exists
self.create_vector_index(uri, database.name, self._collection_name)
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index."""
client = MongoClient(uri)
collection = client.get_database(database_name).get_collection(
self._collection_name
)
try:
index_name = "vector_knn_index"
indexes = await self._data.list_search_indexes().to_list(length=None)
for index in indexes:
if index["name"] == index_name:
logger.debug("vector index already exist")
return
search_index_model = SearchIndexModel(
definition={
"fields": [
@@ -824,11 +849,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
}
]
},
name="vector_knn_index",
name=index_name,
type="vectorSearch",
)
collection.create_search_index(search_index_model)
await self._data.create_search_index(search_index_model)
logger.info("Vector index created successfully.")
except PyMongoError as _:
@@ -913,15 +938,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it."""
client = MongoClient(uri)
database = client.get_database(database_name)
collection_names = database.list_collection_names()
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()
if collection_name not in collection_names:
database.create_collection(collection_name)
collection = await db.create_collection(collection_name)
logger.info(f"Created collection: {collection_name}")
return collection
else:
logger.debug(f"Collection '{collection_name}' already exists.")
return db.get_collection(collection_name)
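
From the caller's side, the new Mongo lifecycle looks roughly like the sketch below. It assumes the lightrag.kg.mongo_impl module path and the namespace/global_config/embedding_func constructor fields implied by MongoGraphStorage.__init__ above, uses placeholder namespaces and config values, and needs a reachable MongoDB (configured via MONGO_URI / MONGO_DATABASE or config.ini):

    import asyncio

    import numpy as np

    from lightrag.kg.mongo_impl import MongoGraphStorage, MongoKVStorage  # assumed module path

    async def fake_embedding(texts: list[str]) -> np.ndarray:
        return np.zeros((len(texts), 8))  # placeholder embedding function

    async def main():
        kv = MongoKVStorage(
            namespace="full_docs",                      # illustrative namespace
            global_config={"embedding_batch_num": 10},  # illustrative config
            embedding_func=fake_embedding,
        )
        graph = MongoGraphStorage(
            namespace="chunk_entity_relation",
            global_config={"embedding_batch_num": 10},
            embedding_func=fake_embedding,
        )

        # Both storages receive the same AsyncIOMotorDatabase from ClientManager,
        # so the ref count is 2 after these two calls.
        await kv.initialize()
        await graph.initialize()
        try:
            print(await kv.get_by_id("doc-1"))
        finally:
            # Releasing both drops the ref count to 0 and clears the cached client.
            await kv.finalize()
            await graph.finalize()

    asyncio.run(main())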

View File

@@ -2,11 +2,11 @@ import array
import asyncio
# import html
# import os
from dataclasses import dataclass
import os
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
@@ -177,17 +177,91 @@ class OracleDB:
raise
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> OracleDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = OracleDB(config)
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: OracleDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
await db.pool.close()
logger.info("Closed OracleDB database connection pool")
cls._instances["db"] = None
else:
await db.pool.close()
@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: OracleDB
db: OracleDB = field(default=None)
meta_fields = None
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -324,6 +398,8 @@ class OracleKVStorage(BaseKVStorage):
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(default=None)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -333,6 +409,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
@@ -369,9 +454,20 @@ class OracleVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
db: OracleDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### insert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
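
The Oracle manager above (and the PostgreSQL and TiDB ones below) resolves every connection setting with the same precedence: environment variable first, then the matching key in config.ini, then a literal fallback. A minimal sketch of that lookup, copied from the ORACLE_USER case:

    import configparser
    import os

    config = configparser.ConfigParser()
    config.read("config.ini", "utf-8")

    # The environment variable wins; the value from config.ini is used only when
    # the variable is unset; fallback= (often None or a literal default) is the last resort.
    user = os.environ.get(
        "ORACLE_USER",
        config.get("oracle", "user", fallback=None),
    )
    print(user)

One side effect worth noting: a port resolved from the environment or config.ini arrives as a string, while the literal fallbacks (5432, 4000) are ints, so the consuming classes presumably normalize the type.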

View File

@@ -3,10 +3,10 @@ import inspect
import json
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
@@ -181,15 +181,84 @@ class PostgreSQLDB:
pass
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> PostgreSQLDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = PostgreSQLDB(config)
await db.initdb()
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: PostgreSQLDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
await db.pool.close()
logger.info("Closed PostgreSQL database connection pool")
cls._instances["db"] = None
else:
await db.pool.close()
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: PostgreSQLDB
db: PostgreSQLDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -308,6 +377,8 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -318,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
def _upsert_chunks(self, item: dict):
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -426,6 +506,17 @@ class PGVectorStorage(BaseVectorStorage):
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(default=None)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
@@ -565,6 +656,8 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(default=None)
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -575,6 +668,15 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass
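
Because each PostgreSQL storage declares db: PostgreSQLDB = field(default=None) and initialize() only asks ClientManager for a client when self.db is None, a caller can still inject its own PostgreSQLDB; on finalize(), release_client() then takes its else branch and closes that private pool directly, since it is not the cached instance. A hedged sketch of that path, assuming the lightrag.kg.postgres_impl module path, the base-class fields used above, and illustrative connection values:

    import asyncio

    from lightrag.kg.postgres_impl import PGKVStorage, PostgreSQLDB  # assumed module path

    async def main():
        # A private PostgreSQLDB built from the same keys ClientManager.get_config() returns.
        db = PostgreSQLDB(
            {
                "host": "localhost",
                "port": 5432,
                "user": "rag",           # illustrative credentials
                "password": "rag",
                "database": "lightrag",
                "workspace": "default",
            }
        )
        await db.initdb()
        await db.check_tables()

        storage = PGKVStorage(
            namespace="full_docs",                      # illustrative namespace
            global_config={"embedding_batch_num": 10},
            embedding_func=None,                        # not used by get_by_id
            db=db,                                      # injected, so initialize() keeps it
        )
        await storage.initialize()    # self.db is already set; no shared client is acquired
        try:
            print(await storage.get_by_id("doc-1"))
        finally:
            await storage.finalize()  # release_client() closes this non-shared pool

    asyncio.run(main())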

View File

@@ -1,6 +1,6 @@
import asyncio
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
@@ -13,6 +13,7 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger
import pipmaster as pm
import configparser
if not pm.is_installed("pymysql"):
pm.install("pymysql")
@@ -104,16 +105,81 @@ class TiDB:
raise
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> TiDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = TiDB(config)
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: TiDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: TiDB
db: TiDB = field(default=None)
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -184,7 +250,7 @@ class TiDBKVStorage(BaseKVStorage):
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": f'{item["__vector__"].tolist()}',
"content_vector": f"{item['__vector__'].tolist()}",
"workspace": self.db.workspace,
}
)
@@ -212,6 +278,8 @@ class TiDBKVStorage(BaseKVStorage):
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB = field(default=None)
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -225,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
@@ -282,7 +359,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f'{item["content_vector"].tolist()}',
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
}
# update entity_id if node inserted by graph_storage_instance before
@@ -304,7 +381,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f'{item["content_vector"].tolist()}',
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
}
# update relation_id if node inserted by graph_storage_instance before
@@ -337,12 +414,20 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: TiDB
db: TiDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
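
Unlike the Oracle and PostgreSQL managers, the TiDB release_client above only drops the cached instance; there is no pool to close. The storage-side lifecycle is the same as for the other backends. A hedged usage sketch, assuming the lightrag.kg.tidb_impl module path and TIDB_* settings supplied via the environment or config.ini:

    import asyncio

    import numpy as np

    from lightrag.kg.tidb_impl import TiDBGraphStorage  # assumed module path

    async def fake_embedding(texts: list[str]) -> np.ndarray:
        return np.zeros((len(texts), 8))  # placeholder embedding function

    async def main():
        graph = TiDBGraphStorage(
            namespace="chunk_entity_relation",          # illustrative namespace
            global_config={"embedding_batch_num": 10},  # __post_init__ reads this key
            embedding_func=fake_embedding,
        )
        await graph.initialize()      # acquires the shared TiDB client via ClientManager
        try:
            # Graph reads/writes (e.g. the upsert_node shown above) would go here.
            pass
        finally:
            await graph.finalize()    # last release clears the cached instance

    asyncio.run(main())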