refactor: improve database initialization by centralizing db instance injection
- Move db configs to separate methods - Remove db field defaults in storage classes - Add _initialize_database_if_needed method - Inject db instances during initialization - Clean up storage implementation code
This commit is contained in:
@@ -172,8 +172,8 @@ class OracleDB:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleKVStorage(BaseKVStorage):
|
class OracleKVStorage(BaseKVStorage):
|
||||||
# should pass db object to self.db
|
# db instance must be injected before use
|
||||||
db: OracleDB = None
|
# db: OracleDB
|
||||||
meta_fields = None
|
meta_fields = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -318,8 +318,8 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
# should pass db object to self.db
|
# db instance must be injected before use
|
||||||
db: OracleDB = None
|
# db: OracleDB
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -361,8 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleGraphStorage(BaseGraphStorage):
|
class OracleGraphStorage(BaseGraphStorage):
|
||||||
# should pass db object to self.db
|
# db instance must be injected before use
|
||||||
db: OracleDB = None
|
# db: OracleDB
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""从graphml文件加载图"""
|
"""从graphml文件加载图"""
|
||||||
|
@@ -177,7 +177,8 @@ class PostgreSQLDB:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGKVStorage(BaseKVStorage):
|
class PGKVStorage(BaseKVStorage):
|
||||||
db: PostgreSQLDB = None
|
# db instance must be injected before use
|
||||||
|
# db: PostgreSQLDB
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
@@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGVectorStorage(BaseVectorStorage):
|
class PGVectorStorage(BaseVectorStorage):
|
||||||
|
# db instance must be injected before use
|
||||||
|
# db: PostgreSQLDB
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||||
db: PostgreSQLDB = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
@@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
class PGDocStatusStorage(DocStatusStorage):
|
class PGDocStatusStorage(DocStatusStorage):
|
||||||
"""PostgreSQL implementation of document status storage"""
|
"""PostgreSQL implementation of document status storage"""
|
||||||
|
|
||||||
db: PostgreSQLDB = None
|
# db instance must be injected before use
|
||||||
|
db: PostgreSQLDB
|
||||||
def __post_init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||||
"""Return keys that don't exist in storage"""
|
"""Return keys that don't exist in storage"""
|
||||||
@@ -577,19 +577,15 @@ class PGGraphQueryException(Exception):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGGraphStorage(BaseGraphStorage):
|
class PGGraphStorage(BaseGraphStorage):
|
||||||
db: PostgreSQLDB = None
|
# db instance must be injected before use
|
||||||
|
# db: PostgreSQLDB
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_nx_graph(file_name):
|
def load_nx_graph(file_name):
|
||||||
print("no preloading of graph with AGE in production")
|
print("no preloading of graph with AGE in production")
|
||||||
|
|
||||||
def __init__(self, namespace, global_config, embedding_func):
|
def __post_init__(self):
|
||||||
super().__init__(
|
self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
||||||
namespace=namespace,
|
|
||||||
global_config=global_config,
|
|
||||||
embedding_func=embedding_func,
|
|
||||||
)
|
|
||||||
self.graph_name = os.environ["AGE_GRAPH_NAME"]
|
|
||||||
self._node_embed_algorithms = {
|
self._node_embed_algorithms = {
|
||||||
"node2vec": self._node2vec_embed,
|
"node2vec": self._node2vec_embed,
|
||||||
}
|
}
|
||||||
|
@@ -101,8 +101,8 @@ class TiDB:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBKVStorage(BaseKVStorage):
|
class TiDBKVStorage(BaseKVStorage):
|
||||||
# should pass db object to self.db
|
# db instance must be injected before use
|
||||||
db: TiDB = None
|
# db: TiDB
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
@@ -210,8 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
# should pass db object to self.db
|
# db instance must be injected before use
|
||||||
db: TiDB = None
|
# db: TiDB
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -333,8 +333,8 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBGraphStorage(BaseGraphStorage):
|
class TiDBGraphStorage(BaseGraphStorage):
|
||||||
# should pass db object to self.db
|
# db instance must be injected before use
|
||||||
db: TiDB = None
|
# db: TiDB
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
@@ -243,6 +243,9 @@ class LightRAG:
|
|||||||
graph_storage: str = field(default="NetworkXStorage")
|
graph_storage: str = field(default="NetworkXStorage")
|
||||||
"""Storage backend for knowledge graphs."""
|
"""Storage backend for knowledge graphs."""
|
||||||
|
|
||||||
|
doc_status_storage: str = field(default="JsonDocStatusStorage")
|
||||||
|
"""Storage type for tracking document processing statuses."""
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
current_log_level = logger.level
|
current_log_level = logger.level
|
||||||
log_level: int = field(default=current_log_level)
|
log_level: int = field(default=current_log_level)
|
||||||
@@ -339,9 +342,6 @@ class LightRAG:
|
|||||||
convert_response_to_json
|
convert_response_to_json
|
||||||
)
|
)
|
||||||
|
|
||||||
doc_status_storage: str = field(default="JsonDocStatusStorage")
|
|
||||||
"""Storage type for tracking document processing statuses."""
|
|
||||||
|
|
||||||
# Custom Chunking Function
|
# Custom Chunking Function
|
||||||
chunking_func: Callable[
|
chunking_func: Callable[
|
||||||
[
|
[
|
||||||
@@ -355,6 +355,91 @@ class LightRAG:
|
|||||||
list[dict[str, Any]],
|
list[dict[str, Any]],
|
||||||
] = chunking_by_token_size
|
] = chunking_by_token_size
|
||||||
|
|
||||||
|
def _get_postgres_config(self):
|
||||||
|
return {
|
||||||
|
"host": os.environ.get(
|
||||||
|
"POSTGRES_HOST",
|
||||||
|
config.get("postgres", "host", fallback="localhost"),
|
||||||
|
),
|
||||||
|
"port": os.environ.get(
|
||||||
|
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
|
||||||
|
),
|
||||||
|
"user": os.environ.get(
|
||||||
|
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
|
||||||
|
),
|
||||||
|
"password": os.environ.get(
|
||||||
|
"POSTGRES_PASSWORD",
|
||||||
|
config.get("postgres", "password", fallback=None),
|
||||||
|
),
|
||||||
|
"database": os.environ.get(
|
||||||
|
"POSTGRES_DATABASE",
|
||||||
|
config.get("postgres", "database", fallback=None),
|
||||||
|
),
|
||||||
|
"workspace": os.environ.get(
|
||||||
|
"POSTGRES_WORKSPACE",
|
||||||
|
config.get("postgres", "workspace", fallback="default"),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_oracle_config(self):
|
||||||
|
return {
|
||||||
|
"user": os.environ.get(
|
||||||
|
"ORACLE_USER",
|
||||||
|
config.get("oracle", "user", fallback=None),
|
||||||
|
),
|
||||||
|
"password": os.environ.get(
|
||||||
|
"ORACLE_PASSWORD",
|
||||||
|
config.get("oracle", "password", fallback=None),
|
||||||
|
),
|
||||||
|
"dsn": os.environ.get(
|
||||||
|
"ORACLE_DSN",
|
||||||
|
config.get("oracle", "dsn", fallback=None),
|
||||||
|
),
|
||||||
|
"config_dir": os.environ.get(
|
||||||
|
"ORACLE_CONFIG_DIR",
|
||||||
|
config.get("oracle", "config_dir", fallback=None),
|
||||||
|
),
|
||||||
|
"wallet_location": os.environ.get(
|
||||||
|
"ORACLE_WALLET_LOCATION",
|
||||||
|
config.get("oracle", "wallet_location", fallback=None),
|
||||||
|
),
|
||||||
|
"wallet_password": os.environ.get(
|
||||||
|
"ORACLE_WALLET_PASSWORD",
|
||||||
|
config.get("oracle", "wallet_password", fallback=None),
|
||||||
|
),
|
||||||
|
"workspace": os.environ.get(
|
||||||
|
"ORACLE_WORKSPACE",
|
||||||
|
config.get("oracle", "workspace", fallback="default"),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_tidb_config(self):
|
||||||
|
return {
|
||||||
|
"host": os.environ.get(
|
||||||
|
"TIDB_HOST",
|
||||||
|
config.get("tidb", "host", fallback="localhost"),
|
||||||
|
),
|
||||||
|
"port": os.environ.get(
|
||||||
|
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
|
||||||
|
),
|
||||||
|
"user": os.environ.get(
|
||||||
|
"TIDB_USER",
|
||||||
|
config.get("tidb", "user", fallback=None),
|
||||||
|
),
|
||||||
|
"password": os.environ.get(
|
||||||
|
"TIDB_PASSWORD",
|
||||||
|
config.get("tidb", "password", fallback=None),
|
||||||
|
),
|
||||||
|
"database": os.environ.get(
|
||||||
|
"TIDB_DATABASE",
|
||||||
|
config.get("tidb", "database", fallback=None),
|
||||||
|
),
|
||||||
|
"workspace": os.environ.get(
|
||||||
|
"TIDB_WORKSPACE",
|
||||||
|
config.get("tidb", "workspace", fallback="default"),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def verify_storage_implementation(
|
def verify_storage_implementation(
|
||||||
self, storage_type: str, storage_name: str
|
self, storage_type: str, storage_name: str
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -456,167 +541,6 @@ class LightRAG:
|
|||||||
# Initialize document status storage
|
# Initialize document status storage
|
||||||
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
||||||
|
|
||||||
# Check if Oracle storage implementation is used
|
|
||||||
if (
|
|
||||||
self.kv_storage == "OracleKVStorage"
|
|
||||||
or self.vector_storage == "OracleVectorDBStorage"
|
|
||||||
or self.graph_storage == "OracleGraphStorage"
|
|
||||||
):
|
|
||||||
# Get parameters from environment variables or config file
|
|
||||||
dbconfig = {
|
|
||||||
"user": os.environ.get(
|
|
||||||
"ORACLE_USER",
|
|
||||||
config.get("oracle", "user", fallback=None),
|
|
||||||
),
|
|
||||||
"password": os.environ.get(
|
|
||||||
"ORACLE_PASSWORD",
|
|
||||||
config.get("oracle", "password", fallback=None),
|
|
||||||
),
|
|
||||||
"dsn": os.environ.get(
|
|
||||||
"ORACLE_DSN",
|
|
||||||
config.get("oracle", "dsn", fallback=None),
|
|
||||||
),
|
|
||||||
"config_dir": os.environ.get(
|
|
||||||
"ORACLE_CONFIG_DIR",
|
|
||||||
config.get("oracle", "config_dir", fallback=None),
|
|
||||||
),
|
|
||||||
"wallet_location": os.environ.get(
|
|
||||||
"ORACLE_WALLET_LOCATION",
|
|
||||||
config.get("oracle", "wallet_location", fallback=None),
|
|
||||||
),
|
|
||||||
"wallet_password": os.environ.get(
|
|
||||||
"ORACLE_WALLET_PASSWORD",
|
|
||||||
config.get("oracle", "wallet_password", fallback=None),
|
|
||||||
),
|
|
||||||
"workspace": os.environ.get(
|
|
||||||
"ORACLE_WORKSPACE",
|
|
||||||
config.get("oracle", "workspace", fallback="default"),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize OracleDB object
|
|
||||||
from .kg.oracle_impl import OracleDB
|
|
||||||
|
|
||||||
oracle_db = OracleDB(dbconfig)
|
|
||||||
# Check if DB tables exist, if not, tables will be created
|
|
||||||
loop = always_get_an_event_loop()
|
|
||||||
loop.run_until_complete(oracle_db.check_tables())
|
|
||||||
|
|
||||||
# Only inject db object for Oracle storage implementations
|
|
||||||
if self.kv_storage == "OracleKVStorage":
|
|
||||||
self.key_string_value_json_storage_cls.db = oracle_db
|
|
||||||
if self.vector_storage == "OracleVectorDBStorage":
|
|
||||||
self.vector_db_storage_cls.db = oracle_db
|
|
||||||
if self.graph_storage == "OracleGraphStorage":
|
|
||||||
self.graph_storage_cls.db = oracle_db
|
|
||||||
|
|
||||||
# Check if TiDB storage implementation is used
|
|
||||||
if (
|
|
||||||
self.kv_storage == "TiDBKVStorage"
|
|
||||||
or self.vector_storage == "TiDBVectorDBStorage"
|
|
||||||
or self.graph_storage == "TiDBGraphStorage"
|
|
||||||
):
|
|
||||||
# Get parameters from environment variables or config file
|
|
||||||
dbconfig = {
|
|
||||||
"host": os.environ.get(
|
|
||||||
"TIDB_HOST",
|
|
||||||
config.get("tidb", "host", fallback="localhost"),
|
|
||||||
),
|
|
||||||
"port": os.environ.get(
|
|
||||||
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
|
|
||||||
),
|
|
||||||
"user": os.environ.get(
|
|
||||||
"TIDB_USER",
|
|
||||||
config.get("tidb", "user", fallback=None),
|
|
||||||
),
|
|
||||||
"password": os.environ.get(
|
|
||||||
"TIDB_PASSWORD",
|
|
||||||
config.get("tidb", "password", fallback=None),
|
|
||||||
),
|
|
||||||
"database": os.environ.get(
|
|
||||||
"TIDB_DATABASE",
|
|
||||||
config.get("tidb", "database", fallback=None),
|
|
||||||
),
|
|
||||||
"workspace": os.environ.get(
|
|
||||||
"TIDB_WORKSPACE",
|
|
||||||
config.get("tidb", "workspace", fallback="default"),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize TiDB object
|
|
||||||
from .kg.tidb_impl import TiDB
|
|
||||||
|
|
||||||
tidb_db = TiDB(dbconfig)
|
|
||||||
# Check if DB tables exist, if not, tables will be created
|
|
||||||
loop = always_get_an_event_loop()
|
|
||||||
loop.run_until_complete(tidb_db.check_tables())
|
|
||||||
|
|
||||||
# Only inject db object for TiDB storage implementations
|
|
||||||
if self.kv_storage == "TiDBKVStorage":
|
|
||||||
self.key_string_value_json_storage_cls.db = tidb_db
|
|
||||||
if self.vector_storage == "TiDBVectorDBStorage":
|
|
||||||
self.vector_db_storage_cls.db = tidb_db
|
|
||||||
if self.graph_storage == "TiDBGraphStorage":
|
|
||||||
self.graph_storage_cls.db = tidb_db
|
|
||||||
|
|
||||||
# Check if PostgreSQL storage implementation is used
|
|
||||||
if (
|
|
||||||
self.kv_storage == "PGKVStorage"
|
|
||||||
or self.vector_storage == "PGVectorStorage"
|
|
||||||
or self.graph_storage == "PGGraphStorage"
|
|
||||||
or self.doc_status_storage == "PGDocStatusStorage"
|
|
||||||
):
|
|
||||||
# Read configuration file
|
|
||||||
config_parser = configparser.ConfigParser()
|
|
||||||
if os.path.exists("config.ini"):
|
|
||||||
config_parser.read("config.ini")
|
|
||||||
|
|
||||||
# Get parameters from environment variables or config file
|
|
||||||
dbconfig = {
|
|
||||||
"host": os.environ.get(
|
|
||||||
"POSTGRES_HOST",
|
|
||||||
config.get("postgres", "host", fallback="localhost"),
|
|
||||||
),
|
|
||||||
"port": os.environ.get(
|
|
||||||
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
|
|
||||||
),
|
|
||||||
"user": os.environ.get(
|
|
||||||
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
|
|
||||||
),
|
|
||||||
"password": os.environ.get(
|
|
||||||
"POSTGRES_PASSWORD",
|
|
||||||
config.get("postgres", "password", fallback=None),
|
|
||||||
),
|
|
||||||
"database": os.environ.get(
|
|
||||||
"POSTGRES_DATABASE",
|
|
||||||
config.get("postgres", "database", fallback=None),
|
|
||||||
),
|
|
||||||
"workspace": os.environ.get(
|
|
||||||
"POSTGRES_WORKSPACE",
|
|
||||||
config.get("postgres", "workspace", fallback="default"),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize PostgreSQLDB object
|
|
||||||
from .kg.postgres_impl import PostgreSQLDB
|
|
||||||
|
|
||||||
postgres_db = PostgreSQLDB(dbconfig)
|
|
||||||
# Initialize and check tables
|
|
||||||
loop = always_get_an_event_loop()
|
|
||||||
loop.run_until_complete(postgres_db.initdb())
|
|
||||||
# Check if DB tables exist, if not, tables will be created
|
|
||||||
loop.run_until_complete(postgres_db.check_tables())
|
|
||||||
|
|
||||||
# Only inject db object for PostgreSQL storage implementations
|
|
||||||
if self.kv_storage == "PGKVStorage":
|
|
||||||
self.key_string_value_json_storage_cls.db = postgres_db
|
|
||||||
if self.vector_storage == "PGVectorStorage":
|
|
||||||
self.vector_db_storage_cls.db = postgres_db
|
|
||||||
if self.graph_storage == "PGGraphStorage":
|
|
||||||
self.graph_storage_cls.db = postgres_db
|
|
||||||
if self.doc_status_storage == "PGDocStatusStorage":
|
|
||||||
self.doc_status_storage_cls.db = postgres_db
|
|
||||||
|
|
||||||
self.llm_response_cache = self.key_string_value_json_storage_cls(
|
self.llm_response_cache = self.key_string_value_json_storage_cls(
|
||||||
namespace=make_namespace(
|
namespace=make_namespace(
|
||||||
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
||||||
@@ -664,6 +588,13 @@ class LightRAG:
|
|||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize document status storage
|
||||||
|
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
|
||||||
|
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
||||||
|
global_config=global_config,
|
||||||
|
embedding_func=None,
|
||||||
|
)
|
||||||
|
|
||||||
# What's for, Is this nessisary ?
|
# What's for, Is this nessisary ?
|
||||||
if self.llm_response_cache and hasattr(
|
if self.llm_response_cache and hasattr(
|
||||||
self.llm_response_cache, "global_config"
|
self.llm_response_cache, "global_config"
|
||||||
@@ -677,16 +608,21 @@ class LightRAG:
|
|||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.json_doc_status_storage = self.key_string_value_json_storage_cls(
|
|
||||||
# namespace=self.namespace_prefix + "json_doc_status_storage",
|
|
||||||
# embedding_func=None,
|
|
||||||
# )
|
|
||||||
|
|
||||||
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
|
# Collect all storage instances
|
||||||
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
storage_instances = [
|
||||||
global_config=global_config,
|
self.full_docs,
|
||||||
embedding_func=None,
|
self.text_chunks,
|
||||||
)
|
self.chunk_entity_relation_graph,
|
||||||
|
self.entities_vdb,
|
||||||
|
self.relationships_vdb,
|
||||||
|
self.chunks_vdb,
|
||||||
|
self.doc_status,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Initialize database connections if needed
|
||||||
|
loop = always_get_an_event_loop()
|
||||||
|
loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
|
||||||
|
|
||||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||||
partial(
|
partial(
|
||||||
@@ -710,8 +646,81 @@ class LightRAG:
|
|||||||
storage_class = lazy_external_import(import_path, storage_name)
|
storage_class = lazy_external_import(import_path, storage_name)
|
||||||
return storage_class
|
return storage_class
|
||||||
|
|
||||||
|
async def _initialize_database_if_needed(self, storage_instances: list):
|
||||||
|
"""Intialize database connection and inject it to storage implementation if needed"""
|
||||||
|
from .kg.postgres_impl import PostgreSQLDB
|
||||||
|
from .kg.oracle_impl import OracleDB
|
||||||
|
from .kg.tidb_impl import TiDB
|
||||||
|
from .kg.postgres_impl import (
|
||||||
|
PGKVStorage,
|
||||||
|
PGVectorStorage,
|
||||||
|
PGGraphStorage,
|
||||||
|
PGDocStatusStorage,
|
||||||
|
)
|
||||||
|
from .kg.oracle_impl import (
|
||||||
|
OracleKVStorage,
|
||||||
|
OracleVectorDBStorage,
|
||||||
|
OracleGraphStorage,
|
||||||
|
)
|
||||||
|
from .kg.tidb_impl import (
|
||||||
|
TiDBKVStorage,
|
||||||
|
TiDBVectorDBStorage,
|
||||||
|
TiDBGraphStorage)
|
||||||
|
|
||||||
|
# Checking if PostgreSQL is needed
|
||||||
|
if any(
|
||||||
|
isinstance(
|
||||||
|
storage,
|
||||||
|
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||||
|
)
|
||||||
|
for storage in storage_instances
|
||||||
|
):
|
||||||
|
postgres_db = PostgreSQLDB(self._get_postgres_config())
|
||||||
|
await postgres_db.initdb()
|
||||||
|
await postgres_db.check_tables()
|
||||||
|
for storage in storage_instances:
|
||||||
|
if isinstance(
|
||||||
|
storage,
|
||||||
|
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||||
|
):
|
||||||
|
storage.db = postgres_db
|
||||||
|
logger.info(f"Injected postgres_db to {storage.__class__.__name__}")
|
||||||
|
|
||||||
|
# Checking if Oracle is needed
|
||||||
|
if any(
|
||||||
|
isinstance(
|
||||||
|
storage, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
|
||||||
|
)
|
||||||
|
for storage in storage_instances
|
||||||
|
):
|
||||||
|
oracle_db = OracleDB(self._get_oracle_config())
|
||||||
|
await oracle_db.check_tables()
|
||||||
|
for storage in storage_instances:
|
||||||
|
if isinstance(
|
||||||
|
storage,
|
||||||
|
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
||||||
|
):
|
||||||
|
storage.db = oracle_db
|
||||||
|
logger.info(f"Injected oracle_db to {storage.__class__.__name__}")
|
||||||
|
|
||||||
|
# Checking if TiDB is needed
|
||||||
|
if any(
|
||||||
|
isinstance(storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage))
|
||||||
|
for storage in storage_instances
|
||||||
|
):
|
||||||
|
tidb_db = TiDB(self._get_tidb_config())
|
||||||
|
await tidb_db.check_tables()
|
||||||
|
# 注入db实例
|
||||||
|
for storage in storage_instances:
|
||||||
|
if isinstance(
|
||||||
|
storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)
|
||||||
|
):
|
||||||
|
storage.db = tidb_db
|
||||||
|
logger.info(f"Injected tidb_db to {storage.__class__.__name__}")
|
||||||
|
|
||||||
def set_storage_client(self, db_client):
|
def set_storage_client(self, db_client):
|
||||||
# Now only tested on Oracle Database
|
# Inject db to storage implementation (only tested on Oracle Database
|
||||||
|
# Deprecated, seting correct value to *_storage creating LightRAG insteaded
|
||||||
for storage in [
|
for storage in [
|
||||||
self.vector_db_storage_cls,
|
self.vector_db_storage_cls,
|
||||||
self.graph_storage_cls,
|
self.graph_storage_cls,
|
||||||
|
Reference in New Issue
Block a user