Inject TiDB同LightRAG storage when needed
This commit is contained in:
@@ -102,6 +102,8 @@ class TiDB:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TiDBKVStorage(BaseKVStorage):
|
class TiDBKVStorage(BaseKVStorage):
|
||||||
# should pass db object to self.db
|
# should pass db object to self.db
|
||||||
|
db: TiDB = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
@@ -208,6 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
|
# should pass db object to self.db
|
||||||
|
db: TiDB = None
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -329,6 +333,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBGraphStorage(BaseGraphStorage):
|
class TiDBGraphStorage(BaseGraphStorage):
|
||||||
|
# should pass db object to self.db
|
||||||
|
db: TiDB = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
|
@@ -16,8 +16,6 @@ from .base import (
|
|||||||
QueryParam,
|
QueryParam,
|
||||||
StorageNameSpace,
|
StorageNameSpace,
|
||||||
)
|
)
|
||||||
from .kg.oracle_impl import OracleDB
|
|
||||||
from .kg.postgres_impl import PostgreSQLDB
|
|
||||||
from .namespace import NameSpace, make_namespace
|
from .namespace import NameSpace, make_namespace
|
||||||
from .operate import (
|
from .operate import (
|
||||||
chunking_by_token_size,
|
chunking_by_token_size,
|
||||||
@@ -446,6 +444,7 @@ class LightRAG:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 初始化 OracleDB 对象
|
# 初始化 OracleDB 对象
|
||||||
|
from .kg.oracle_impl import OracleDB
|
||||||
oracle_db = OracleDB(dbconfig)
|
oracle_db = OracleDB(dbconfig)
|
||||||
# Check if DB tables exist, if not, tables will be created
|
# Check if DB tables exist, if not, tables will be created
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
@@ -459,6 +458,55 @@ class LightRAG:
|
|||||||
if self.graph_storage == "OracleGraphStorage":
|
if self.graph_storage == "OracleGraphStorage":
|
||||||
self.graph_storage_cls.db = oracle_db
|
self.graph_storage_cls.db = oracle_db
|
||||||
|
|
||||||
|
# 检查是否使用了 TiDB 存储实现
|
||||||
|
if (
|
||||||
|
self.kv_storage == "TiDBKVStorage"
|
||||||
|
or self.vector_storage == "TiDBVectorDBStorage"
|
||||||
|
or self.graph_storage == "TiDBGraphStorage"
|
||||||
|
):
|
||||||
|
# 从环境变量或配置文件获取参数
|
||||||
|
dbconfig = {
|
||||||
|
"host": os.environ.get(
|
||||||
|
"TIDB_HOST",
|
||||||
|
config.get("tidb", "host", fallback="localhost"),
|
||||||
|
),
|
||||||
|
"port": os.environ.get(
|
||||||
|
"TIDB_PORT",
|
||||||
|
config.get("tidb", "port", fallback=4000)
|
||||||
|
),
|
||||||
|
"user": os.environ.get(
|
||||||
|
"TIDB_USER",
|
||||||
|
config.get("tidb", "user", fallback=None),
|
||||||
|
),
|
||||||
|
"password": os.environ.get(
|
||||||
|
"TIDB_PASSWORD",
|
||||||
|
config.get("tidb", "password", fallback=None),
|
||||||
|
),
|
||||||
|
"database": os.environ.get(
|
||||||
|
"TIDB_DATABASE",
|
||||||
|
config.get("tidb", "database", fallback=None),
|
||||||
|
),
|
||||||
|
"workspace": os.environ.get(
|
||||||
|
"TIDB_WORKSPACE",
|
||||||
|
config.get("tidb", "workspace", fallback="default"),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 初始化 TiDB 对象
|
||||||
|
from .kg.tidb_impl import TiDB
|
||||||
|
tidb_db = TiDB(dbconfig)
|
||||||
|
# Check if DB tables exist, if not, tables will be created
|
||||||
|
loop = always_get_an_event_loop()
|
||||||
|
loop.run_until_complete(tidb_db.check_tables())
|
||||||
|
|
||||||
|
# 只对 TiDB 实现的存储类注入 db 对象
|
||||||
|
if self.kv_storage == "TiDBKVStorage":
|
||||||
|
self.key_string_value_json_storage_cls.db = tidb_db
|
||||||
|
if self.vector_storage == "TiDBVectorDBStorage":
|
||||||
|
self.vector_db_storage_cls.db = tidb_db
|
||||||
|
if self.graph_storage == "TiDBGraphStorage":
|
||||||
|
self.graph_storage_cls.db = tidb_db
|
||||||
|
|
||||||
# 检查是否使用了 PostgreSQL 存储实现
|
# 检查是否使用了 PostgreSQL 存储实现
|
||||||
if (
|
if (
|
||||||
self.kv_storage == "PGKVStorage"
|
self.kv_storage == "PGKVStorage"
|
||||||
@@ -498,6 +546,7 @@ class LightRAG:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 初始化 PostgreSQLDB 对象
|
# 初始化 PostgreSQLDB 对象
|
||||||
|
from .kg.postgres_impl import PostgreSQLDB
|
||||||
postgres_db = PostgreSQLDB(dbconfig)
|
postgres_db = PostgreSQLDB(dbconfig)
|
||||||
# Initialize and check tables
|
# Initialize and check tables
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
|
Reference in New Issue
Block a user