refactor: improve database initialization by centralizing db instance injection

- Move db configs to separate methods
- Remove db field defaults in storage classes
- Add _initialize_database_if_needed method
- Inject db instances during initialization
- Clean up storage implementation code
This commit is contained in:
yangdx
2025-02-12 22:25:34 +08:00
parent cf61bed62c
commit 7b79427097
4 changed files with 205 additions and 200 deletions

View File

@@ -243,6 +243,9 @@ class LightRAG:
graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs."""
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging
current_log_level = logger.level
log_level: int = field(default=current_log_level)
@@ -339,9 +342,6 @@ class LightRAG:
convert_response_to_json
)
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Custom Chunking Function
chunking_func: Callable[
[
@@ -355,6 +355,91 @@ class LightRAG:
list[dict[str, Any]],
] = chunking_by_token_size
def _get_postgres_config(self):
    """Build the PostgreSQL connection settings.

    Each value comes from the matching ``POSTGRES_*`` environment variable
    when that variable is present; otherwise it is looked up in the
    ``postgres`` section of the ``config`` object (defined outside this
    method), with the fallback listed below used when the option is absent.

    Returns:
        dict with the keys ``host``, ``port``, ``user``, ``password``,
        ``database`` and ``workspace``.
    """
    # (result key / config option, environment variable, fallback value)
    settings = (
        ("host", "POSTGRES_HOST", "localhost"),
        ("port", "POSTGRES_PORT", 5432),
        ("user", "POSTGRES_USER", None),
        ("password", "POSTGRES_PASSWORD", None),
        ("database", "POSTGRES_DATABASE", None),
        ("workspace", "POSTGRES_WORKSPACE", "default"),
    )
    return {
        option: os.environ.get(
            env_var, config.get("postgres", option, fallback=default)
        )
        for option, env_var, default in settings
    }
def _get_oracle_config(self):
    """Build the Oracle connection settings.

    Each value comes from the matching ``ORACLE_*`` environment variable
    when that variable is present; otherwise it is looked up in the
    ``oracle`` section of the ``config`` object (defined outside this
    method), with the fallback listed below used when the option is absent.

    Returns:
        dict with the keys ``user``, ``password``, ``dsn``, ``config_dir``,
        ``wallet_location``, ``wallet_password`` and ``workspace``.
    """
    # (result key / config option, environment variable, fallback value)
    settings = (
        ("user", "ORACLE_USER", None),
        ("password", "ORACLE_PASSWORD", None),
        ("dsn", "ORACLE_DSN", None),
        ("config_dir", "ORACLE_CONFIG_DIR", None),
        ("wallet_location", "ORACLE_WALLET_LOCATION", None),
        ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
        ("workspace", "ORACLE_WORKSPACE", "default"),
    )
    return {
        option: os.environ.get(
            env_var, config.get("oracle", option, fallback=default)
        )
        for option, env_var, default in settings
    }
def _get_tidb_config(self):
    """Build the TiDB connection settings.

    Each value comes from the matching ``TIDB_*`` environment variable when
    that variable is present; otherwise it is looked up in the ``tidb``
    section of the ``config`` object (defined outside this method), with
    the fallback listed below used when the option is absent.

    Returns:
        dict with the keys ``host``, ``port``, ``user``, ``password``,
        ``database`` and ``workspace``.
    """
    # (result key / config option, environment variable, fallback value)
    settings = (
        ("host", "TIDB_HOST", "localhost"),
        ("port", "TIDB_PORT", 4000),
        ("user", "TIDB_USER", None),
        ("password", "TIDB_PASSWORD", None),
        ("database", "TIDB_DATABASE", None),
        ("workspace", "TIDB_WORKSPACE", "default"),
    )
    return {
        option: os.environ.get(
            env_var, config.get("tidb", option, fallback=default)
        )
        for option, env_var, default in settings
    }
def verify_storage_implementation(
self, storage_type: str, storage_name: str
) -> None:
@@ -456,167 +541,6 @@ class LightRAG:
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
# Check if Oracle storage implementation is used
if (
self.kv_storage == "OracleKVStorage"
or self.vector_storage == "OracleVectorDBStorage"
or self.graph_storage == "OracleGraphStorage"
):
# Get parameters from environment variables or config file
dbconfig = {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
# Initialize OracleDB object
from .kg.oracle_impl import OracleDB
oracle_db = OracleDB(dbconfig)
# Check if DB tables exist, if not, tables will be created
loop = always_get_an_event_loop()
loop.run_until_complete(oracle_db.check_tables())
# Only inject db object for Oracle storage implementations
if self.kv_storage == "OracleKVStorage":
self.key_string_value_json_storage_cls.db = oracle_db
if self.vector_storage == "OracleVectorDBStorage":
self.vector_db_storage_cls.db = oracle_db
if self.graph_storage == "OracleGraphStorage":
self.graph_storage_cls.db = oracle_db
# Check if TiDB storage implementation is used
if (
self.kv_storage == "TiDBKVStorage"
or self.vector_storage == "TiDBVectorDBStorage"
or self.graph_storage == "TiDBGraphStorage"
):
# Get parameters from environment variables or config file
dbconfig = {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Initialize TiDB object
from .kg.tidb_impl import TiDB
tidb_db = TiDB(dbconfig)
# Check if DB tables exist, if not, tables will be created
loop = always_get_an_event_loop()
loop.run_until_complete(tidb_db.check_tables())
# Only inject db object for TiDB storage implementations
if self.kv_storage == "TiDBKVStorage":
self.key_string_value_json_storage_cls.db = tidb_db
if self.vector_storage == "TiDBVectorDBStorage":
self.vector_db_storage_cls.db = tidb_db
if self.graph_storage == "TiDBGraphStorage":
self.graph_storage_cls.db = tidb_db
# Check if PostgreSQL storage implementation is used
if (
self.kv_storage == "PGKVStorage"
or self.vector_storage == "PGVectorStorage"
or self.graph_storage == "PGGraphStorage"
or self.doc_status_storage == "PGDocStatusStorage"
):
# Read configuration file
config_parser = configparser.ConfigParser()
if os.path.exists("config.ini"):
config_parser.read("config.ini")
# Get parameters from environment variables or config file
dbconfig = {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
# Initialize PostgreSQLDB object
from .kg.postgres_impl import PostgreSQLDB
postgres_db = PostgreSQLDB(dbconfig)
# Initialize and check tables
loop = always_get_an_event_loop()
loop.run_until_complete(postgres_db.initdb())
# Check if DB tables exist, if not, tables will be created
loop.run_until_complete(postgres_db.check_tables())
# Only inject db object for PostgreSQL storage implementations
if self.kv_storage == "PGKVStorage":
self.key_string_value_json_storage_cls.db = postgres_db
if self.vector_storage == "PGVectorStorage":
self.vector_db_storage_cls.db = postgres_db
if self.graph_storage == "PGGraphStorage":
self.graph_storage_cls.db = postgres_db
if self.doc_status_storage == "PGDocStatusStorage":
self.doc_status_storage_cls.db = postgres_db
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
@@ -664,6 +588,13 @@ class LightRAG:
embedding_func=self.embedding_func,
)
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
# What is this for? Is it necessary?
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
@@ -677,16 +608,21 @@ class LightRAG:
embedding_func=self.embedding_func,
)
# self.json_doc_status_storage = self.key_string_value_json_storage_cls(
# namespace=self.namespace_prefix + "json_doc_status_storage",
# embedding_func=None,
# )
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
# Collect all storage instances
storage_instances = [
self.full_docs,
self.text_chunks,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.doc_status,
]
# Initialize database connections if needed
loop = always_get_an_event_loop()
loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
@@ -710,8 +646,81 @@ class LightRAG:
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
async def _initialize_database_if_needed(self, storage_instances: list):
    """Initialize database connections and inject them into the storages that need one.

    For each supported backend (PostgreSQL, Oracle, TiDB) this checks whether
    any of the given storage instances belongs to that backend. If so, a
    single client object is created from the matching ``_get_*_config()``
    result, its tables are checked, and the client is assigned to the ``db``
    attribute of every storage instance of that backend.

    The backend modules are imported only inside the branch that actually
    needs them, so a configuration that uses none of these databases does
    not require their modules (or whatever those modules import) to be
    importable.

    Args:
        storage_instances: the already constructed storage objects to inspect.
    """
    # Materialize once: the collection is scanned once per backend.
    storages = list(storage_instances or ())

    def select(class_names):
        # Match by class name along the MRO rather than with isinstance(),
        # so subclasses of a backend storage are still detected without
        # importing the backend module just to obtain the class objects.
        return [
            storage
            for storage in storages
            if any(cls.__name__ in class_names for cls in type(storage).__mro__)
        ]

    def inject(db, db_label, targets):
        # Share one client object between all storages of the same backend.
        for storage in targets:
            storage.db = db
            logger.info(f"Injected {db_label} to {storage.__class__.__name__}")

    # Check whether PostgreSQL is needed
    postgres_storages = select(
        {"PGKVStorage", "PGVectorStorage", "PGGraphStorage", "PGDocStatusStorage"}
    )
    if postgres_storages:
        from .kg.postgres_impl import PostgreSQLDB

        postgres_db = PostgreSQLDB(self._get_postgres_config())
        await postgres_db.initdb()
        await postgres_db.check_tables()
        inject(postgres_db, "postgres_db", postgres_storages)

    # Check whether Oracle is needed
    oracle_storages = select(
        {"OracleKVStorage", "OracleVectorDBStorage", "OracleGraphStorage"}
    )
    if oracle_storages:
        from .kg.oracle_impl import OracleDB

        oracle_db = OracleDB(self._get_oracle_config())
        await oracle_db.check_tables()
        inject(oracle_db, "oracle_db", oracle_storages)

    # Check whether TiDB is needed
    tidb_storages = select(
        {"TiDBKVStorage", "TiDBVectorDBStorage", "TiDBGraphStorage"}
    )
    if tidb_storages:
        from .kg.tidb_impl import TiDB

        tidb_db = TiDB(self._get_tidb_config())
        await tidb_db.check_tables()
        # Inject the db instance
        inject(tidb_db, "tidb_db", tidb_storages)
def set_storage_client(self, db_client):
# Now only tested on Oracle Database
# Inject db to storage implementation (only tested on Oracle Database)
# Deprecated: set the correct *_storage values when creating LightRAG instead
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,