Inject TiDB同LightRAG storage when needed

This commit is contained in:
yangdx
2025-02-11 04:27:45 +08:00
parent 5408e7ea02
commit c5c606f491
2 changed files with 58 additions and 2 deletions

View File

@@ -16,8 +16,6 @@ from .base import (
QueryParam,
StorageNameSpace,
)
from .kg.oracle_impl import OracleDB
from .kg.postgres_impl import PostgreSQLDB
from .namespace import NameSpace, make_namespace
from .operate import (
chunking_by_token_size,
@@ -446,6 +444,7 @@ class LightRAG:
}
# 初始化 OracleDB 对象
from .kg.oracle_impl import OracleDB
oracle_db = OracleDB(dbconfig)
# Check if DB tables exist, if not, tables will be created
loop = always_get_an_event_loop()
@@ -459,6 +458,55 @@ class LightRAG:
if self.graph_storage == "OracleGraphStorage":
self.graph_storage_cls.db = oracle_db
# 检查是否使用了 TiDB 存储实现
if (
self.kv_storage == "TiDBKVStorage"
or self.vector_storage == "TiDBVectorDBStorage"
or self.graph_storage == "TiDBGraphStorage"
):
# 从环境变量或配置文件获取参数
dbconfig = {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT",
config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# 初始化 TiDB 对象
from .kg.tidb_impl import TiDB
tidb_db = TiDB(dbconfig)
# Check if DB tables exist, if not, tables will be created
loop = always_get_an_event_loop()
loop.run_until_complete(tidb_db.check_tables())
# 只对 TiDB 实现的存储类注入 db 对象
if self.kv_storage == "TiDBKVStorage":
self.key_string_value_json_storage_cls.db = tidb_db
if self.vector_storage == "TiDBVectorDBStorage":
self.vector_db_storage_cls.db = tidb_db
if self.graph_storage == "TiDBGraphStorage":
self.graph_storage_cls.db = tidb_db
# 检查是否使用了 PostgreSQL 存储实现
if (
self.kv_storage == "PGKVStorage"
@@ -498,6 +546,7 @@ class LightRAG:
}
# 初始化 PostgreSQLDB 对象
from .kg.postgres_impl import PostgreSQLDB
postgres_db = PostgreSQLDB(dbconfig)
# Initialize and check tables
loop = always_get_an_event_loop()