Merge pull request #846 from ArnoChenFx/db-connection-and-storage-lifecycle
Refactor Database Connection Management and Improve Storage Lifecycle Handling
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
import configparser
|
||||
import asyncio
|
||||
@@ -26,8 +26,11 @@ if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
try:
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from motor.motor_asyncio import (
|
||||
AsyncIOMotorClient,
|
||||
AsyncIOMotorDatabase,
|
||||
AsyncIOMotorCollection,
|
||||
)
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
except ImportError as e:
|
||||
@@ -39,31 +42,63 @@ config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
class ClientManager:
|
||||
_instances = {"db": None, "ref_count": 0}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_client(cls) -> AsyncIOMotorDatabase:
|
||||
async with cls._lock:
|
||||
if cls._instances["db"] is None:
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb",
|
||||
"uri",
|
||||
fallback="mongodb://root:root@localhost:27017/",
|
||||
),
|
||||
)
|
||||
database_name = os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
db = client.get_database(database_name)
|
||||
cls._instances["db"] = db
|
||||
cls._instances["ref_count"] = 0
|
||||
cls._instances["ref_count"] += 1
|
||||
return cls._instances["db"]
|
||||
|
||||
@classmethod
|
||||
async def release_client(cls, db: AsyncIOMotorDatabase):
|
||||
async with cls._lock:
|
||||
if db is not None:
|
||||
if db is cls._instances["db"]:
|
||||
cls._instances["ref_count"] -= 1
|
||||
if cls._instances["ref_count"] == 0:
|
||||
cls._instances["db"] = None
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MongoKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
_data: AsyncIOMotorCollection = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self._collection_name = self.namespace
|
||||
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
async def finalize(self):
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
self._data = None
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
return await self._data.find_one({"_id": id})
|
||||
@@ -120,28 +155,23 @@ class MongoKVStorage(BaseKVStorage):
|
||||
@final
|
||||
@dataclass
|
||||
class MongoDocStatusStorage(DocStatusStorage):
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
_data: AsyncIOMotorCollection = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
|
||||
logger.debug(f"Use MongoDB as doc status {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
async def finalize(self):
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
self._data = None
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
return await self._data.find_one({"_id": id})
|
||||
@@ -202,36 +232,33 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||
@dataclass
|
||||
class MongoGraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries.
|
||||
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
|
||||
"""
|
||||
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
collection: AsyncIOMotorCollection = field(default=None)
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self.collection = database.get_collection(self._collection_name)
|
||||
|
||||
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self.collection = await get_or_create_collection(
|
||||
self.db, self._collection_name
|
||||
)
|
||||
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
async def finalize(self):
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
self.collection = None
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
@@ -770,6 +797,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
@final
|
||||
@dataclass
|
||||
class MongoVectorDBStorage(BaseVectorStorage):
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
_data: AsyncIOMotorCollection = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
@@ -778,41 +808,36 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
# Ensure vector index exists
|
||||
await self.create_vector_index_if_not_exists()
|
||||
|
||||
# Ensure vector index exists
|
||||
self.create_vector_index(uri, database.name, self._collection_name)
|
||||
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
||||
|
||||
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
|
||||
async def finalize(self):
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
self._data = None
|
||||
|
||||
async def create_vector_index_if_not_exists(self):
|
||||
"""Creates an Atlas Vector Search index."""
|
||||
client = MongoClient(uri)
|
||||
collection = client.get_database(database_name).get_collection(
|
||||
self._collection_name
|
||||
)
|
||||
|
||||
try:
|
||||
index_name = "vector_knn_index"
|
||||
|
||||
indexes = await self._data.list_search_indexes().to_list(length=None)
|
||||
for index in indexes:
|
||||
if index["name"] == index_name:
|
||||
logger.debug("vector index already exist")
|
||||
return
|
||||
|
||||
search_index_model = SearchIndexModel(
|
||||
definition={
|
||||
"fields": [
|
||||
@@ -824,11 +849,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
}
|
||||
]
|
||||
},
|
||||
name="vector_knn_index",
|
||||
name=index_name,
|
||||
type="vectorSearch",
|
||||
)
|
||||
|
||||
collection.create_search_index(search_index_model)
|
||||
await self._data.create_search_index(search_index_model)
|
||||
logger.info("Vector index created successfully.")
|
||||
|
||||
except PyMongoError as _:
|
||||
@@ -913,15 +938,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
|
||||
"""Check if the collection exists. if not, create it."""
|
||||
client = MongoClient(uri)
|
||||
database = client.get_database(database_name)
|
||||
|
||||
collection_names = database.list_collection_names()
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
|
||||
|
||||
if collection_name not in collection_names:
|
||||
database.create_collection(collection_name)
|
||||
collection = await db.create_collection(collection_name)
|
||||
logger.info(f"Created collection: {collection_name}")
|
||||
return collection
|
||||
else:
|
||||
logger.debug(f"Collection '{collection_name}' already exists.")
|
||||
return db.get_collection(collection_name)
|
||||
|
Reference in New Issue
Block a user