added final, required methods and cleaned import

This commit is contained in:
Yannick Stephan
2025-02-16 14:38:09 +01:00
parent 7848a38a45
commit 3fef8201c6
16 changed files with 209 additions and 316 deletions

View File

@@ -4,17 +4,11 @@ import asyncio
# import html
# import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Union, final
import numpy as np
import pipmaster as pm
if not pm.is_installed("oracledb"):
pm.install("oracledb")
from lightrag.types import KnowledgeGraph
import oracledb
from ..base import (
BaseGraphStorage,
@@ -24,6 +18,14 @@ from ..base import (
from ..namespace import NameSpace, is_namespace
from ..utils import logger
try:
import oracledb
except ImportError as e:
raise ImportError(
"oracledb library is not installed. Please install it to proceed."
) from e
class OracleDB:
def __init__(self, config, **kwargs):
@@ -170,6 +172,7 @@ class OracleDB:
raise
@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage):
raise NotImplementedError
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = None
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pass
async def index_done_callback(self):
async def index_done_callback(self) -> None:
pass
#################### query method ###############
@@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage):
raise NotImplementedError
@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: OracleDB
def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################
@@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage):
return embeddings, nodes_ids
async def index_done_callback(self) -> None:
"""写入graphhml图文件"""
logger.info(
"Node and edge data had been saved into oracle db already, so nothing to do here!"
)
pass
#################### query method #################
async def has_node(self, node_id: str) -> bool: