diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index d9eeb2dd..7c75e2d8 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -102,6 +102,8 @@ class TiDB: @dataclass class TiDBKVStorage(BaseKVStorage): # should pass db object to self.db + db: TiDB = None + def __post_init__(self): self._data = {} self._max_batch_size = self.global_config["embedding_batch_num"] @@ -208,6 +210,8 @@ class TiDBKVStorage(BaseKVStorage): @dataclass class TiDBVectorDBStorage(BaseVectorStorage): + # should pass db object to self.db + db: TiDB = None cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): @@ -329,6 +333,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): @dataclass class TiDBGraphStorage(BaseGraphStorage): + # should pass db object to self.db + db: TiDB = None + def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5518023d..8ef65951 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -16,8 +16,6 @@ from .base import ( QueryParam, StorageNameSpace, ) -from .kg.oracle_impl import OracleDB -from .kg.postgres_impl import PostgreSQLDB from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, @@ -446,6 +444,7 @@ class LightRAG: } # 初始化 OracleDB 对象 + from .kg.oracle_impl import OracleDB oracle_db = OracleDB(dbconfig) # Check if DB tables exist, if not, tables will be created loop = always_get_an_event_loop() @@ -459,6 +458,55 @@ class LightRAG: if self.graph_storage == "OracleGraphStorage": self.graph_storage_cls.db = oracle_db + # 检查是否使用了 TiDB 存储实现 + if ( + self.kv_storage == "TiDBKVStorage" + or self.vector_storage == "TiDBVectorDBStorage" + or self.graph_storage == "TiDBGraphStorage" + ): + # 从环境变量或配置文件获取参数 + dbconfig = { + "host": os.environ.get( + "TIDB_HOST", + config.get("tidb", "host", fallback="localhost"), + ), + "port": os.environ.get( + "TIDB_PORT", + config.get("tidb", "port", fallback=4000) + ), + "user": os.environ.get( + "TIDB_USER", + config.get("tidb", "user", fallback=None), + ), + "password": os.environ.get( + "TIDB_PASSWORD", + config.get("tidb", "password", fallback=None), + ), + "database": os.environ.get( + "TIDB_DATABASE", + config.get("tidb", "database", fallback=None), + ), + "workspace": os.environ.get( + "TIDB_WORKSPACE", + config.get("tidb", "workspace", fallback="default"), + ), + } + + # 初始化 TiDB 对象 + from .kg.tidb_impl import TiDB + tidb_db = TiDB(dbconfig) + # Check if DB tables exist, if not, tables will be created + loop = always_get_an_event_loop() + loop.run_until_complete(tidb_db.check_tables()) + + # 只对 TiDB 实现的存储类注入 db 对象 + if self.kv_storage == "TiDBKVStorage": + self.key_string_value_json_storage_cls.db = tidb_db + if self.vector_storage == "TiDBVectorDBStorage": + self.vector_db_storage_cls.db = tidb_db + if self.graph_storage == "TiDBGraphStorage": + self.graph_storage_cls.db = tidb_db + # 检查是否使用了 PostgreSQL 存储实现 if ( self.kv_storage == "PGKVStorage" @@ -498,6 +546,7 @@ class LightRAG: } # 初始化 PostgreSQLDB 对象 + from .kg.postgres_impl import PostgreSQLDB postgres_db = PostgreSQLDB(dbconfig) # Initialize and check tables loop = always_get_an_event_loop()