improve conditional checks for db instance
This commit is contained in:
@@ -3,7 +3,7 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Union, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
@@ -246,18 +246,17 @@ class ClientManager:
|
||||
@final
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
db: PostgreSQLDB = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
|
||||
@final
|
||||
@dataclass
|
||||
class PGVectorStorage(BaseVectorStorage):
|
||||
db: PostgreSQLDB = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
@final
|
||||
@dataclass
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
db: PostgreSQLDB = field(init=False)
|
||||
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
|
||||
@final
|
||||
@dataclass
|
||||
class PGGraphStorage(BaseGraphStorage):
|
||||
db: PostgreSQLDB = field(init=False)
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
if self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
if self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
|
Reference in New Issue
Block a user