improve conditional checks for db instance

This commit is contained in:
ArnoChen
2025-02-19 04:53:15 +08:00
parent ae7a850d4e
commit f50604b2d3
4 changed files with 70 additions and 42 deletions

View File

@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
import numpy as np
import configparser
from tqdm.asyncio import tqdm as tqdm_async
@@ -27,7 +27,11 @@ if not pm.is_installed("motor"):
pm.install("motor")
try:
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
@@ -79,19 +83,23 @@ class ClientManager:
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self):
self._collection_name = self.namespace
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id})
@@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage):
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self):
self._collection_name = self.namespace
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id})
@@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage):
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
"""
db: AsyncIOMotorDatabase = field(init=False)
collection: AsyncIOMotorCollection = field(init=False)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
@@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage):
self._collection_name = self.namespace
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection(
self.db, self._collection_name
@@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage):
logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self.collection = None
#
# -------------------------------------------------------------------------
@@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
@@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index."""

View File

@@ -3,7 +3,7 @@ import asyncio
# import html
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
import configparser
@@ -242,8 +242,7 @@ class ClientManager:
@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: OracleDB
db: OracleDB = field(init=False)
meta_fields = None
def __post_init__(self):
@@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage):
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(init=False)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
db: OracleDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None

View File

@@ -3,7 +3,7 @@ import inspect
import json
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final
import numpy as np
import configparser
@@ -246,18 +246,17 @@ class ClientManager:
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: PostgreSQLDB
db: PostgreSQLDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(init=False)
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(init=False)
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
}
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None

View File

@@ -1,6 +1,6 @@
import asyncio
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
@@ -166,19 +166,18 @@ class ClientManager:
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: TiDB
db: TiDB = field(init=False)
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -280,6 +279,8 @@ class TiDBKVStorage(BaseKVStorage):
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB = field(init=False)
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -294,11 +295,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -421,18 +422,17 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: TiDB
db: TiDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None