Inject TiDB同LightRAG storage when needed
This commit is contained in:
@@ -102,6 +102,8 @@ class TiDB:
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
db: TiDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -208,6 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
# should pass db object to self.db
|
||||
db: TiDB = None
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -329,6 +333,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
# should pass db object to self.db
|
||||
db: TiDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
|
@@ -16,8 +16,6 @@ from .base import (
|
||||
QueryParam,
|
||||
StorageNameSpace,
|
||||
)
|
||||
from .kg.oracle_impl import OracleDB
|
||||
from .kg.postgres_impl import PostgreSQLDB
|
||||
from .namespace import NameSpace, make_namespace
|
||||
from .operate import (
|
||||
chunking_by_token_size,
|
||||
@@ -446,6 +444,7 @@ class LightRAG:
|
||||
}
|
||||
|
||||
# 初始化 OracleDB 对象
|
||||
from .kg.oracle_impl import OracleDB
|
||||
oracle_db = OracleDB(dbconfig)
|
||||
# Check if DB tables exist, if not, tables will be created
|
||||
loop = always_get_an_event_loop()
|
||||
@@ -459,6 +458,55 @@ class LightRAG:
|
||||
if self.graph_storage == "OracleGraphStorage":
|
||||
self.graph_storage_cls.db = oracle_db
|
||||
|
||||
# 检查是否使用了 TiDB 存储实现
|
||||
if (
|
||||
self.kv_storage == "TiDBKVStorage"
|
||||
or self.vector_storage == "TiDBVectorDBStorage"
|
||||
or self.graph_storage == "TiDBGraphStorage"
|
||||
):
|
||||
# 从环境变量或配置文件获取参数
|
||||
dbconfig = {
|
||||
"host": os.environ.get(
|
||||
"TIDB_HOST",
|
||||
config.get("tidb", "host", fallback="localhost"),
|
||||
),
|
||||
"port": os.environ.get(
|
||||
"TIDB_PORT",
|
||||
config.get("tidb", "port", fallback=4000)
|
||||
),
|
||||
"user": os.environ.get(
|
||||
"TIDB_USER",
|
||||
config.get("tidb", "user", fallback=None),
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"TIDB_PASSWORD",
|
||||
config.get("tidb", "password", fallback=None),
|
||||
),
|
||||
"database": os.environ.get(
|
||||
"TIDB_DATABASE",
|
||||
config.get("tidb", "database", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"TIDB_WORKSPACE",
|
||||
config.get("tidb", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
# 初始化 TiDB 对象
|
||||
from .kg.tidb_impl import TiDB
|
||||
tidb_db = TiDB(dbconfig)
|
||||
# Check if DB tables exist, if not, tables will be created
|
||||
loop = always_get_an_event_loop()
|
||||
loop.run_until_complete(tidb_db.check_tables())
|
||||
|
||||
# 只对 TiDB 实现的存储类注入 db 对象
|
||||
if self.kv_storage == "TiDBKVStorage":
|
||||
self.key_string_value_json_storage_cls.db = tidb_db
|
||||
if self.vector_storage == "TiDBVectorDBStorage":
|
||||
self.vector_db_storage_cls.db = tidb_db
|
||||
if self.graph_storage == "TiDBGraphStorage":
|
||||
self.graph_storage_cls.db = tidb_db
|
||||
|
||||
# 检查是否使用了 PostgreSQL 存储实现
|
||||
if (
|
||||
self.kv_storage == "PGKVStorage"
|
||||
@@ -498,6 +546,7 @@ class LightRAG:
|
||||
}
|
||||
|
||||
# 初始化 PostgreSQLDB 对象
|
||||
from .kg.postgres_impl import PostgreSQLDB
|
||||
postgres_db = PostgreSQLDB(dbconfig)
|
||||
# Initialize and check tables
|
||||
loop = always_get_an_event_loop()
|
||||
|
Reference in New Issue
Block a user