improve conditional checks for db instance
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
@@ -27,7 +27,11 @@ if not pm.is_installed("motor"):
|
|||||||
pm.install("motor")
|
pm.install("motor")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
from motor.motor_asyncio import (
|
||||||
|
AsyncIOMotorClient,
|
||||||
|
AsyncIOMotorDatabase,
|
||||||
|
AsyncIOMotorCollection,
|
||||||
|
)
|
||||||
from pymongo.operations import SearchIndexModel
|
from pymongo.operations import SearchIndexModel
|
||||||
from pymongo.errors import PyMongoError
|
from pymongo.errors import PyMongoError
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -79,19 +83,23 @@ class ClientManager:
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoKVStorage(BaseKVStorage):
|
class MongoKVStorage(BaseKVStorage):
|
||||||
|
db: AsyncIOMotorDatabase = field(init=False)
|
||||||
|
_data: AsyncIOMotorCollection = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._collection_name = self.namespace
|
self._collection_name = self.namespace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
self._data = None
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
return await self._data.find_one({"_id": id})
|
return await self._data.find_one({"_id": id})
|
||||||
@@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoDocStatusStorage(DocStatusStorage):
|
class MongoDocStatusStorage(DocStatusStorage):
|
||||||
|
db: AsyncIOMotorDatabase = field(init=False)
|
||||||
|
_data: AsyncIOMotorCollection = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._collection_name = self.namespace
|
self._collection_name = self.namespace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
self._data = None
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
return await self._data.find_one({"_id": id})
|
return await self._data.find_one({"_id": id})
|
||||||
@@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MongoGraphStorage(BaseGraphStorage):
|
class MongoGraphStorage(BaseGraphStorage):
|
||||||
"""
|
"""
|
||||||
A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries.
|
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
db: AsyncIOMotorDatabase = field(init=False)
|
||||||
|
collection: AsyncIOMotorCollection = field(init=False)
|
||||||
|
|
||||||
def __init__(self, namespace, global_config, embedding_func):
|
def __init__(self, namespace, global_config, embedding_func):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
@@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
self._collection_name = self.namespace
|
self._collection_name = self.namespace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
self.collection = await get_or_create_collection(
|
self.collection = await get_or_create_collection(
|
||||||
self.db, self._collection_name
|
self.db, self._collection_name
|
||||||
@@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
self.collection = None
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
@@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoVectorDBStorage(BaseVectorStorage):
|
class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
|
db: AsyncIOMotorDatabase = field(init=False)
|
||||||
|
_data: AsyncIOMotorCollection = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||||
@@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
|
|
||||||
@@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|||||||
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
self._data = None
|
||||||
|
|
||||||
async def create_vector_index_if_not_exists(self):
|
async def create_vector_index_if_not_exists(self):
|
||||||
"""Creates an Atlas Vector Search index."""
|
"""Creates an Atlas Vector Search index."""
|
||||||
|
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
|
|
||||||
# import html
|
# import html
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Union, final
|
from typing import Any, Union, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
@@ -242,8 +242,7 @@ class ClientManager:
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleKVStorage(BaseKVStorage):
|
class OracleKVStorage(BaseKVStorage):
|
||||||
# db instance must be injected before use
|
db: OracleDB = field(init=False)
|
||||||
# db: OracleDB
|
|
||||||
meta_fields = None
|
meta_fields = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
|
db: OracleDB = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
@@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleGraphStorage(BaseGraphStorage):
|
class OracleGraphStorage(BaseGraphStorage):
|
||||||
|
db: OracleDB = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
|
@@ -3,7 +3,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Union, final
|
from typing import Any, Dict, List, Union, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
@@ -246,18 +246,17 @@ class ClientManager:
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGKVStorage(BaseKVStorage):
|
class PGKVStorage(BaseKVStorage):
|
||||||
# db instance must be injected before use
|
db: PostgreSQLDB = field(init=False)
|
||||||
# db: PostgreSQLDB
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGVectorStorage(BaseVectorStorage):
|
class PGVectorStorage(BaseVectorStorage):
|
||||||
|
db: PostgreSQLDB = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGDocStatusStorage(DocStatusStorage):
|
class PGDocStatusStorage(DocStatusStorage):
|
||||||
|
db: PostgreSQLDB = field(init=False)
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGGraphStorage(BaseGraphStorage):
|
class PGGraphStorage(BaseGraphStorage):
|
||||||
|
db: PostgreSQLDB = field(init=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_nx_graph(file_name):
|
def load_nx_graph(file_name):
|
||||||
print("no preloading of graph with AGE in production")
|
print("no preloading of graph with AGE in production")
|
||||||
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Union, final
|
from typing import Any, Union, final
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -166,19 +166,18 @@ class ClientManager:
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBKVStorage(BaseKVStorage):
|
class TiDBKVStorage(BaseKVStorage):
|
||||||
# db instance must be injected before use
|
db: TiDB = field(init=False)
|
||||||
# db: TiDB
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -280,6 +279,8 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
|
db: TiDB = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||||
@@ -294,11 +295,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
@@ -421,18 +422,17 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBGraphStorage(BaseGraphStorage):
|
class TiDBGraphStorage(BaseGraphStorage):
|
||||||
# db instance must be injected before use
|
db: TiDB = field(init=False)
|
||||||
# db: TiDB
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if not hasattr(self, "db") or self.db is None:
|
if self.db is None:
|
||||||
self.db = await ClientManager.get_client()
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if hasattr(self, "db") and self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user