diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index ca6bcfb2..f34fe4b1 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -361,7 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage): @dataclass class OracleGraphStorage(BaseGraphStorage): - """基于Oracle的图存储模块""" + # should pass db object to self.db + db: OracleDB = None def __post_init__(self): """从graphml文件加载图""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index de0e4f59..6b9161be 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1,5 +1,6 @@ import asyncio import os +import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial @@ -15,6 +16,7 @@ from .base import ( QueryParam, StorageNameSpace, ) +from .kg.oracle_impl import OracleDB from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, @@ -35,6 +37,9 @@ from .utils import ( set_logger, ) +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + # Storage type and implementation compatibility validation table STORAGE_IMPLEMENTATIONS = { "KV_STORAGE": { @@ -389,6 +394,53 @@ class LightRAG: self.graph_storage_cls, global_config=global_config ) + # 检查是否使用了 Oracle 存储实现 + if ( + self.kv_storage == "OracleKVStorage" + or self.vector_storage == "OracleVectorDBStorage" + or self.graph_storage == "OracleGraphStorage" + ): + # 从环境变量或配置文件获取参数 + dbconfig = { + "user": os.environ.get( + "ORACLE_USER", config.get("oracle", "user", fallback=None) + ), + "password": os.environ.get( + "ORACLE_PASSWORD", + config.get("oracle", "password", fallback=None), + ), + "dsn": os.environ.get( + "ORACLE_DSN", config.get("oracle", "dsn", fallback=None) + ), + "config_dir": os.environ.get( + "ORACLE_CONFIG_DIR", + config.get("oracle", "config_dir", fallback=None), + ), + "wallet_location": os.environ.get( + "ORACLE_WALLET_LOCATION", + config.get("oracle", "wallet_location", fallback=None), + ), + "wallet_password": os.environ.get( + "ORACLE_WALLET_PASSWORD", + config.get("oracle", "wallet_password", fallback=None), + ), + "workspace": os.environ.get( + "ORACLE_WORKSPACE", + config.get("oracle", "workspace", fallback="default"), + ), + } + + # 初始化 OracleDB 对象 + oracle_db = OracleDB(dbconfig) + + # 只对 Oracle 实现的存储类注入 db 对象 + if self.kv_storage == "OracleKVStorage": + self.key_string_value_json_storage_cls.db = oracle_db + if self.vector_storage == "OracleVectorDBStorage": + self.vector_db_storage_cls.db = oracle_db + if self.graph_storage == "OracleGraphStorage": + self.graph_storage_cls.db = oracle_db + self.json_doc_status_storage = self.key_string_value_json_storage_cls( namespace=self.namespace_prefix + "json_doc_status_storage", embedding_func=None,