Inject oracle db to LightRag storage class when needed

This commit is contained in:
yangdx
2025-02-11 03:54:54 +08:00
parent 8cfca5a141
commit a4cf7e66d3
2 changed files with 54 additions and 1 deletions

View File

@@ -361,7 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
"""基于Oracle的图存储模块""" # should pass db object to self.db
db: OracleDB = None
def __post_init__(self): def __post_init__(self):
"""从graphml文件加载图""" """从graphml文件加载图"""

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
import configparser
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@@ -15,6 +16,7 @@ from .base import (
QueryParam, QueryParam,
StorageNameSpace, StorageNameSpace,
) )
from .kg.oracle_impl import OracleDB
from .namespace import NameSpace, make_namespace from .namespace import NameSpace, make_namespace
from .operate import ( from .operate import (
chunking_by_token_size, chunking_by_token_size,
@@ -35,6 +37,9 @@ from .utils import (
set_logger, set_logger,
) )
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table # Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = { STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": { "KV_STORAGE": {
@@ -389,6 +394,53 @@ class LightRAG:
self.graph_storage_cls, global_config=global_config self.graph_storage_cls, global_config=global_config
) )
# 检查是否使用了 Oracle 存储实现
if (
self.kv_storage == "OracleKVStorage"
or self.vector_storage == "OracleVectorDBStorage"
or self.graph_storage == "OracleGraphStorage"
):
# 从环境变量或配置文件获取参数
dbconfig = {
"user": os.environ.get(
"ORACLE_USER", config.get("oracle", "user", fallback=None)
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN", config.get("oracle", "dsn", fallback=None)
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
# 初始化 OracleDB 对象
oracle_db = OracleDB(dbconfig)
# 只对 Oracle 实现的存储类注入 db 对象
if self.kv_storage == "OracleKVStorage":
self.key_string_value_json_storage_cls.db = oracle_db
if self.vector_storage == "OracleVectorDBStorage":
self.vector_db_storage_cls.db = oracle_db
if self.graph_storage == "OracleGraphStorage":
self.graph_storage_cls.db = oracle_db
self.json_doc_status_storage = self.key_string_value_json_storage_cls( self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "json_doc_status_storage", namespace=self.namespace_prefix + "json_doc_status_storage",
embedding_func=None, embedding_func=None,