refactor: improve database initialization by centralizing db instance injection
- Move db configs to separate methods - Remove db field defaults in storage classes - Add _initialize_database_if_needed method - Inject db instances during initialization - Clean up storage implementation code
This commit is contained in:
@@ -177,7 +177,8 @@ class PostgreSQLDB:
|
||||
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class PGVectorStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
db: PostgreSQLDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
"""PostgreSQL implementation of document status storage"""
|
||||
|
||||
db: PostgreSQLDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# db instance must be injected before use
|
||||
db: PostgreSQLDB
|
||||
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
@@ -577,19 +577,15 @@ class PGGraphQueryException(Exception):
|
||||
|
||||
@dataclass
|
||||
class PGGraphStorage(BaseGraphStorage):
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self.graph_name = os.environ["AGE_GRAPH_NAME"]
|
||||
def __post_init__(self):
|
||||
self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
Reference in New Issue
Block a user