improve conditional checks for db instance

This commit is contained in:
ArnoChen
2025-02-19 04:53:15 +08:00
parent ae7a850d4e
commit f50604b2d3
4 changed files with 70 additions and 42 deletions

View File

@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
import numpy as np
import configparser
from tqdm.asyncio import tqdm as tqdm_async
@@ -27,7 +27,11 @@ if not pm.is_installed("motor"):
pm.install("motor")
try:
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
@@ -79,19 +83,23 @@ class ClientManager:
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
# Dataclass post-init hook: remember the namespace as the name of the MongoDB
# collection that initialize() later opens for this KV store.
def __post_init__(self):
self._collection_name = self.namespace
async def initialize(self):
    """Obtain the MongoDB handle and open this store's KV collection.

    Does nothing when a database handle is already attached.
    """
    # `db` is declared as field(init=False) without a default, so the
    # attribute does not exist on a fresh instance; a bare `self.db` would
    # raise AttributeError. getattr() treats "missing" the same as None.
    # NOTE(review): nesting was reconstructed from a flattened diff; confirm
    # that the collection binding and log line belong under this guard.
    if getattr(self, "db", None) is None:
        self.db = await ClientManager.get_client()
        self._data = await get_or_create_collection(self.db, self._collection_name)
        logger.debug(f"Use MongoDB as KV {self._collection_name}")
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and clear cached references.

    Safe to call on an instance that was never initialized or that has
    already been finalized; in both cases nothing is released.
    """
    # `db` is a field(init=False) with no default, so it may not exist yet;
    # getattr() avoids an AttributeError when finalize() precedes initialize().
    db = getattr(self, "db", None)
    if db is not None:
        await ClientManager.release_client(db)
        self.db = None
        self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id})
@@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage):
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
# Dataclass post-init hook: remember the namespace as the name of the MongoDB
# collection that initialize() later opens for document-status records.
def __post_init__(self):
self._collection_name = self.namespace
async def initialize(self):
    """Obtain the MongoDB handle and open the document-status collection.

    Does nothing when a database handle is already attached.
    """
    # `db` is declared as field(init=False) without a default, so the
    # attribute does not exist on a fresh instance; a bare `self.db` would
    # raise AttributeError. getattr() treats "missing" the same as None.
    # NOTE(review): nesting was reconstructed from a flattened diff; confirm
    # that the collection binding and log line belong under this guard.
    if getattr(self, "db", None) is None:
        self.db = await ClientManager.get_client()
        self._data = await get_or_create_collection(self.db, self._collection_name)
        logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and clear cached references.

    Safe to call on an instance that was never initialized or that has
    already been finalized; in both cases nothing is released.
    """
    # `db` is a field(init=False) with no default, so it may not exist yet;
    # getattr() avoids an AttributeError when finalize() precedes initialize().
    db = getattr(self, "db", None)
    if db is not None:
        await ClientManager.release_client(db)
        self.db = None
        self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id})
@@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage):
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
"""
db: AsyncIOMotorDatabase = field(init=False)
collection: AsyncIOMotorCollection = field(init=False)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
@@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage):
self._collection_name = self.namespace
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection(
self.db, self._collection_name
@@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage):
logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and clear cached references.

    Safe to call on an instance that was never initialized or that has
    already been finalized; in both cases nothing is released.
    """
    # `db` is a field(init=False) with no default, so it may not exist yet;
    # getattr() avoids an AttributeError when finalize() precedes initialize().
    db = getattr(self, "db", None)
    if db is not None:
        await ClientManager.release_client(db)
        self.db = None
        self.collection = None
#
# -------------------------------------------------------------------------
@@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
@@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and clear cached references.

    Safe to call on an instance that was never initialized or that has
    already been finalized; in both cases nothing is released.
    """
    # `db` is a field(init=False) with no default, so it may not exist yet;
    # getattr() avoids an AttributeError when finalize() precedes initialize().
    db = getattr(self, "db", None)
    if db is not None:
        await ClientManager.release_client(db)
        self.db = None
        self._data = None
async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index."""