refactor: improve database initialization by centralizing db instance injection

- Move db configs to separate methods
- Remove db field defaults in storage classes
- Add _initialize_database_if_needed method
- Inject db instances during initialization
- Clean up storage implementation code
This commit is contained in:
yangdx
2025-02-12 22:25:34 +08:00
parent cf61bed62c
commit 7b79427097
4 changed files with 205 additions and 200 deletions

View File

@@ -172,8 +172,8 @@ class OracleDB:
@dataclass @dataclass
class OracleKVStorage(BaseKVStorage): class OracleKVStorage(BaseKVStorage):
# should pass db object to self.db # db instance must be injected before use
db: OracleDB = None # db: OracleDB
meta_fields = None meta_fields = None
def __post_init__(self): def __post_init__(self):
@@ -318,8 +318,8 @@ class OracleKVStorage(BaseKVStorage):
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db # db instance must be injected before use
db: OracleDB = None # db: OracleDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self): def __post_init__(self):
@@ -361,8 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
# should pass db object to self.db # db instance must be injected before use
db: OracleDB = None # db: OracleDB
def __post_init__(self): def __post_init__(self):
"""从graphml文件加载图""" """从graphml文件加载图"""

View File

@@ -177,7 +177,8 @@ class PostgreSQLDB:
@dataclass @dataclass
class PGKVStorage(BaseKVStorage): class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = None # db instance must be injected before use
# db: PostgreSQLDB
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage):
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
db: PostgreSQLDB = None
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage):
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
"""PostgreSQL implementation of document status storage""" """PostgreSQL implementation of document status storage"""
db: PostgreSQLDB = None # db instance must be injected before use
db: PostgreSQLDB
def __post_init__(self):
pass
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, data: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
@@ -577,19 +577,15 @@ class PGGraphQueryException(Exception):
@dataclass @dataclass
class PGGraphStorage(BaseGraphStorage): class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = None # db instance must be injected before use
# db: PostgreSQLDB
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print("no preloading of graph with AGE in production") print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func): def __post_init__(self):
super().__init__( self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self.graph_name = os.environ["AGE_GRAPH_NAME"]
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }

View File

@@ -101,8 +101,8 @@ class TiDB:
@dataclass @dataclass
class TiDBKVStorage(BaseKVStorage): class TiDBKVStorage(BaseKVStorage):
# should pass db object to self.db # db instance must be injected before use
db: TiDB = None # db: TiDB
def __post_init__(self): def __post_init__(self):
self._data = {} self._data = {}
@@ -210,8 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db # db instance must be injected before use
db: TiDB = None # db: TiDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self): def __post_init__(self):
@@ -333,8 +333,8 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# should pass db object to self.db # db instance must be injected before use
db: TiDB = None # db: TiDB
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]

View File

@@ -243,6 +243,9 @@ class LightRAG:
graph_storage: str = field(default="NetworkXStorage") graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs.""" """Storage backend for knowledge graphs."""
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging # Logging
current_log_level = logger.level current_log_level = logger.level
log_level: int = field(default=current_log_level) log_level: int = field(default=current_log_level)
@@ -339,9 +342,6 @@ class LightRAG:
convert_response_to_json convert_response_to_json
) )
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Custom Chunking Function # Custom Chunking Function
chunking_func: Callable[ chunking_func: Callable[
[ [
@@ -355,6 +355,91 @@ class LightRAG:
list[dict[str, Any]], list[dict[str, Any]],
] = chunking_by_token_size ] = chunking_by_token_size
def _get_postgres_config(self):
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config(self):
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config(self):
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
def verify_storage_implementation( def verify_storage_implementation(
self, storage_type: str, storage_name: str self, storage_type: str, storage_name: str
) -> None: ) -> None:
@@ -456,167 +541,6 @@ class LightRAG:
# Initialize document status storage # Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
# Check if Oracle storage implementation is used
if (
self.kv_storage == "OracleKVStorage"
or self.vector_storage == "OracleVectorDBStorage"
or self.graph_storage == "OracleGraphStorage"
):
# Get parameters from environment variables or config file
dbconfig = {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
# Initialize OracleDB object
from .kg.oracle_impl import OracleDB
oracle_db = OracleDB(dbconfig)
# Check if DB tables exist, if not, tables will be created
loop = always_get_an_event_loop()
loop.run_until_complete(oracle_db.check_tables())
# Only inject db object for Oracle storage implementations
if self.kv_storage == "OracleKVStorage":
self.key_string_value_json_storage_cls.db = oracle_db
if self.vector_storage == "OracleVectorDBStorage":
self.vector_db_storage_cls.db = oracle_db
if self.graph_storage == "OracleGraphStorage":
self.graph_storage_cls.db = oracle_db
# Check if TiDB storage implementation is used
if (
self.kv_storage == "TiDBKVStorage"
or self.vector_storage == "TiDBVectorDBStorage"
or self.graph_storage == "TiDBGraphStorage"
):
# Get parameters from environment variables or config file
dbconfig = {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Initialize TiDB object
from .kg.tidb_impl import TiDB
tidb_db = TiDB(dbconfig)
# Check if DB tables exist, if not, tables will be created
loop = always_get_an_event_loop()
loop.run_until_complete(tidb_db.check_tables())
# Only inject db object for TiDB storage implementations
if self.kv_storage == "TiDBKVStorage":
self.key_string_value_json_storage_cls.db = tidb_db
if self.vector_storage == "TiDBVectorDBStorage":
self.vector_db_storage_cls.db = tidb_db
if self.graph_storage == "TiDBGraphStorage":
self.graph_storage_cls.db = tidb_db
# Check if PostgreSQL storage implementation is used
if (
self.kv_storage == "PGKVStorage"
or self.vector_storage == "PGVectorStorage"
or self.graph_storage == "PGGraphStorage"
or self.doc_status_storage == "PGDocStatusStorage"
):
# Read configuration file
config_parser = configparser.ConfigParser()
if os.path.exists("config.ini"):
config_parser.read("config.ini")
# Get parameters from environment variables or config file
dbconfig = {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
# Initialize PostgreSQLDB object
from .kg.postgres_impl import PostgreSQLDB
postgres_db = PostgreSQLDB(dbconfig)
# Initialize and check tables
loop = always_get_an_event_loop()
loop.run_until_complete(postgres_db.initdb())
# Check if DB tables exist, if not, tables will be created
loop.run_until_complete(postgres_db.check_tables())
# Only inject db object for PostgreSQL storage implementations
if self.kv_storage == "PGKVStorage":
self.key_string_value_json_storage_cls.db = postgres_db
if self.vector_storage == "PGVectorStorage":
self.vector_db_storage_cls.db = postgres_db
if self.graph_storage == "PGGraphStorage":
self.graph_storage_cls.db = postgres_db
if self.doc_status_storage == "PGDocStatusStorage":
self.doc_status_storage_cls.db = postgres_db
self.llm_response_cache = self.key_string_value_json_storage_cls( self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace( namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
@@ -664,6 +588,13 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
# What's for, Is this nessisary ? # What's for, Is this nessisary ?
if self.llm_response_cache and hasattr( if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config" self.llm_response_cache, "global_config"
@@ -677,16 +608,21 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
# self.json_doc_status_storage = self.key_string_value_json_storage_cls(
# namespace=self.namespace_prefix + "json_doc_status_storage",
# embedding_func=None,
# )
self.doc_status: DocStatusStorage = self.doc_status_storage_cls( # Collect all storage instances
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), storage_instances = [
global_config=global_config, self.full_docs,
embedding_func=None, self.text_chunks,
) self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.doc_status,
]
# Initialize database connections if needed
loop = always_get_an_event_loop()
loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial( partial(
@@ -710,8 +646,81 @@ class LightRAG:
storage_class = lazy_external_import(import_path, storage_name) storage_class = lazy_external_import(import_path, storage_name)
return storage_class return storage_class
async def _initialize_database_if_needed(self, storage_instances: list):
"""Intialize database connection and inject it to storage implementation if needed"""
from .kg.postgres_impl import PostgreSQLDB
from .kg.oracle_impl import OracleDB
from .kg.tidb_impl import TiDB
from .kg.postgres_impl import (
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from .kg.oracle_impl import (
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from .kg.tidb_impl import (
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage)
# Checking if PostgreSQL is needed
if any(
isinstance(
storage,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
)
for storage in storage_instances
):
postgres_db = PostgreSQLDB(self._get_postgres_config())
await postgres_db.initdb()
await postgres_db.check_tables()
for storage in storage_instances:
if isinstance(
storage,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
):
storage.db = postgres_db
logger.info(f"Injected postgres_db to {storage.__class__.__name__}")
# Checking if Oracle is needed
if any(
isinstance(
storage, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
)
for storage in storage_instances
):
oracle_db = OracleDB(self._get_oracle_config())
await oracle_db.check_tables()
for storage in storage_instances:
if isinstance(
storage,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
):
storage.db = oracle_db
logger.info(f"Injected oracle_db to {storage.__class__.__name__}")
# Checking if TiDB is needed
if any(
isinstance(storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage))
for storage in storage_instances
):
tidb_db = TiDB(self._get_tidb_config())
await tidb_db.check_tables()
# 注入db实例
for storage in storage_instances:
if isinstance(
storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)
):
storage.db = tidb_db
logger.info(f"Injected tidb_db to {storage.__class__.__name__}")
def set_storage_client(self, db_client): def set_storage_client(self, db_client):
# Now only tested on Oracle Database # Inject db to storage implementation (only tested on Oracle Database
# Deprecated, seting correct value to *_storage creating LightRAG insteaded
for storage in [ for storage in [
self.vector_db_storage_cls, self.vector_db_storage_cls,
self.graph_storage_cls, self.graph_storage_cls,