improve conditional checks for db instance

This commit is contained in:
ArnoChen
2025-02-19 04:53:15 +08:00
parent ae7a850d4e
commit f50604b2d3
4 changed files with 70 additions and 42 deletions

View File

@@ -1,5 +1,5 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass, field
import numpy as np import numpy as np
import configparser import configparser
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
@@ -27,7 +27,11 @@ if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
try: try:
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError
except ImportError as e: except ImportError as e:
@@ -79,19 +83,23 @@ class ClientManager:
@final @final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}") logger.debug(f"Use MongoDB as KV {self._collection_name}")
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class MongoDocStatusStorage(DocStatusStorage): class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}") logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage):
@dataclass @dataclass
class MongoGraphStorage(BaseGraphStorage): class MongoGraphStorage(BaseGraphStorage):
""" """
A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries. A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
""" """
db: AsyncIOMotorDatabase = field(init=False)
collection: AsyncIOMotorCollection = field(init=False)
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func):
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
@@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage):
self._collection_name = self.namespace self._collection_name = self.namespace
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection( self.collection = await get_or_create_collection(
self.db, self._collection_name self.db, self._collection_name
@@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage):
logger.debug(f"Use MongoDB as KG {self._collection_name}") logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
self.collection = None
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage):
@final @final
@dataclass @dataclass
class MongoVectorDBStorage(BaseVectorStorage): class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(init=False)
_data: AsyncIOMotorCollection = field(init=False)
def __post_init__(self): def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
@@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug(f"Use MongoDB as VDB {self._collection_name}") logger.debug(f"Use MongoDB as VDB {self._collection_name}")
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
self._data = None
async def create_vector_index_if_not_exists(self): async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index.""" """Creates an Atlas Vector Search index."""

View File

@@ -3,7 +3,7 @@ import asyncio
# import html # import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, Union, final from typing import Any, Union, final
import numpy as np import numpy as np
import configparser import configparser
@@ -242,8 +242,7 @@ class ClientManager:
@final @final
@dataclass @dataclass
class OracleKVStorage(BaseKVStorage): class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use db: OracleDB = field(init=False)
# db: OracleDB
meta_fields = None meta_fields = None
def __post_init__(self): def __post_init__(self):
@@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10) self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(init=False)
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
@@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
@final @final
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
db: OracleDB = field(init=False)
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10) self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None

View File

@@ -3,7 +3,7 @@ import inspect
import json import json
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final from typing import Any, Dict, List, Union, final
import numpy as np import numpy as np
import configparser import configparser
@@ -246,18 +246,17 @@ class ClientManager:
@final @final
@dataclass @dataclass
class PGKVStorage(BaseKVStorage): class PGKVStorage(BaseKVStorage):
# db instance must be injected before use db: PostgreSQLDB = field(init=False)
# db: PostgreSQLDB
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(init=False)
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
@final @final
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(init=False)
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
@final @final
@dataclass @dataclass
class PGGraphStorage(BaseGraphStorage): class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(init=False)
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print("no preloading of graph with AGE in production") print("no preloading of graph with AGE in production")
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
} }
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
import os import os
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, Union, final from typing import Any, Union, final
import numpy as np import numpy as np
@@ -166,19 +166,18 @@ class ClientManager:
@final @final
@dataclass @dataclass
class TiDBKVStorage(BaseKVStorage): class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use db: TiDB = field(init=False)
# db: TiDB
def __post_init__(self): def __post_init__(self):
self._data = {} self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -280,6 +279,8 @@ class TiDBKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB = field(init=False)
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -294,11 +295,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@@ -421,18 +422,17 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@final @final
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use db: TiDB = field(init=False)
# db: TiDB
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
if not hasattr(self, "db") or self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
async def finalize(self): async def finalize(self):
if hasattr(self, "db") and self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None