improve conditional checks for db instance

This commit is contained in:
ArnoChen
2025-02-19 04:53:15 +08:00
parent ae7a850d4e
commit f50604b2d3
4 changed files with 70 additions and 42 deletions

View File

@@ -3,7 +3,7 @@ import asyncio
# import html
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
import configparser
@@ -242,8 +242,7 @@ class ClientManager:
@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: OracleDB
db: OracleDB = field(init=False)
meta_fields = None
def __post_init__(self):
@@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage):
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(init=False)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
db: OracleDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None