added final, required methods and cleaned import
This commit is contained in:
@@ -4,17 +4,11 @@ import asyncio
|
||||
# import html
|
||||
# import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, final
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("oracledb"):
|
||||
pm.install("oracledb")
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
import oracledb
|
||||
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
@@ -24,6 +18,14 @@ from ..base import (
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
|
||||
try:
|
||||
import oracledb
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"oracledb library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class OracleDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
@@ -170,6 +172,7 @@ class OracleDB:
|
||||
raise
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class OracleKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
@@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
pass
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
#################### query method ###############
|
||||
@@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class OracleGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
|
||||
def __post_init__(self):
|
||||
"""从graphml文件加载图"""
|
||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||
|
||||
#################### insert method ################
|
||||
@@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
return embeddings, nodes_ids
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
"""写入graphhml图文件"""
|
||||
logger.info(
|
||||
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
||||
)
|
||||
pass
|
||||
|
||||
#################### query method #################
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
|
Reference in New Issue
Block a user