Inject oracle db to LightRag storage class when needed

This commit is contained in:
yangdx
2025-02-11 03:54:54 +08:00
parent 8cfca5a141
commit a4cf7e66d3
2 changed files with 54 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
@@ -15,6 +16,7 @@ from .base import (
QueryParam,
StorageNameSpace,
)
from .kg.oracle_impl import OracleDB
from .namespace import NameSpace, make_namespace
from .operate import (
chunking_by_token_size,
@@ -35,6 +37,9 @@ from .utils import (
set_logger,
)
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
@@ -389,6 +394,53 @@ class LightRAG:
self.graph_storage_cls, global_config=global_config
)
# 检查是否使用了 Oracle 存储实现
if (
self.kv_storage == "OracleKVStorage"
or self.vector_storage == "OracleVectorDBStorage"
or self.graph_storage == "OracleGraphStorage"
):
# 从环境变量或配置文件获取参数
dbconfig = {
"user": os.environ.get(
"ORACLE_USER", config.get("oracle", "user", fallback=None)
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN", config.get("oracle", "dsn", fallback=None)
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
# 初始化 OracleDB 对象
oracle_db = OracleDB(dbconfig)
# 只对 Oracle 实现的存储类注入 db 对象
if self.kv_storage == "OracleKVStorage":
self.key_string_value_json_storage_cls.db = oracle_db
if self.vector_storage == "OracleVectorDBStorage":
self.vector_db_storage_cls.db = oracle_db
if self.graph_storage == "OracleGraphStorage":
self.graph_storage_cls.db = oracle_db
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "json_doc_status_storage",
embedding_func=None,