refactor: Implement dynamic database module imports

- Consolidate database instance management
- Improve database management and error handling
- Enhance error handling and logging
This commit is contained in:
yangdx
2025-02-18 16:14:11 +08:00
parent 8ab369c2af
commit 75ee4592b8

View File

@@ -41,25 +41,28 @@ from .ollama_api import (
OllamaAPI, OllamaAPI,
) )
from .ollama_api import ollama_server_infos from .ollama_api import ollama_server_infos
from ..kg.postgres_impl import ( def get_db_type_from_storage_class(class_name: str) -> str | None:
PostgreSQLDB, """Determine database type based on storage class name"""
PGKVStorage, if class_name.startswith("PG"):
PGVectorStorage, return "postgres"
PGGraphStorage, elif class_name.startswith("Oracle"):
PGDocStatusStorage, return "oracle"
) elif class_name.startswith("TiDB"):
from ..kg.oracle_impl import ( return "tidb"
OracleDB, return None
OracleKVStorage,
OracleVectorDBStorage, def import_db_module(db_type: str):
OracleGraphStorage, """Dynamically import database module"""
) if db_type == "postgres":
from ..kg.tidb_impl import ( from ..kg.postgres_impl import PostgreSQLDB
TiDB, return PostgreSQLDB
TiDBKVStorage, elif db_type == "oracle":
TiDBVectorDBStorage, from ..kg.oracle_impl import OracleDB
TiDBGraphStorage, return OracleDB
) elif db_type == "tidb":
from ..kg.tidb_impl import TiDB
return TiDB
return None
# Load environment variables # Load environment variables
try: try:
@@ -333,28 +336,28 @@ def parse_args() -> argparse.Namespace:
default=get_env_value( default=get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
), ),
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
) )
parser.add_argument( parser.add_argument(
"--doc-status-storage", "--doc-status-storage",
default=get_env_value( default=get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
), ),
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
) )
parser.add_argument( parser.add_argument(
"--graph-storage", "--graph-storage",
default=get_env_value( default=get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
), ),
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
) )
parser.add_argument( parser.add_argument(
"--vector-storage", "--vector-storage",
default=get_env_value( default=get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
), ),
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
) )
# Bindings configuration # Bindings configuration
@@ -890,72 +893,47 @@ def create_app(args):
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events""" """Lifespan context manager for startup and shutdown events"""
# Initialize database connections # Initialize database connections
postgres_db = None db_instances = {}
oracle_db = None
tidb_db = None
# Store background tasks # Store background tasks
app.state.background_tasks = set() app.state.background_tasks = set()
try: try:
# Check if PostgreSQL is needed # Check which database types are used
if any( db_types = set()
isinstance( for storage_name, storage_instance in storage_instances:
storage_instance, db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__)
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), if db_type:
) db_types.add(db_type)
for _, storage_instance in storage_instances
):
postgres_db = PostgreSQLDB(_get_postgres_config())
await postgres_db.initdb()
await postgres_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
),
):
storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}")
# Check if Oracle is needed # Import and initialize databases as needed
if any( for db_type in db_types:
isinstance( if db_type == "postgres":
storage_instance, DB = import_db_module("postgres")
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), db = DB(_get_postgres_config())
) await db.initdb()
for _, storage_instance in storage_instances await db.check_tables()
): db_instances["postgres"] = db
oracle_db = OracleDB(_get_oracle_config()) elif db_type == "oracle":
await oracle_db.check_tables() DB = import_db_module("oracle")
for storage_name, storage_instance in storage_instances: db = DB(_get_oracle_config())
if isinstance( await db.check_tables()
storage_instance, db_instances["oracle"] = db
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), elif db_type == "tidb":
): DB = import_db_module("tidb")
storage_instance.db = oracle_db db = DB(_get_tidb_config())
logger.info(f"Injected oracle_db to {storage_name}") await db.check_tables()
db_instances["tidb"] = db
# Check if TiDB is needed # Inject database instances into storage classes
if any( for storage_name, storage_instance in storage_instances:
isinstance( db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__)
storage_instance, if db_type:
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), if db_type not in db_instances:
) error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
for _, storage_instance in storage_instances logger.error(error_msg)
): raise RuntimeError(error_msg)
tidb_db = TiDB(_get_tidb_config()) storage_instance.db = db_instances[db_type]
await tidb_db.check_tables() logger.info(f"Injected {db_type} db to {storage_name}")
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
):
storage_instance.db = tidb_db
logger.info(f"Injected tidb_db to {storage_name}")
# Auto scan documents if enabled # Auto scan documents if enabled
if args.auto_scan_at_startup: if args.auto_scan_at_startup:
@@ -980,18 +958,18 @@ def create_app(args):
yield yield
finally: finally:
# Cleanup database connections # Clean up database connections
if postgres_db and hasattr(postgres_db, "pool"): for db_type, db in db_instances.items():
await postgres_db.pool.close() if hasattr(db, "pool"):
logger.info("Closed PostgreSQL connection pool") await db.pool.close()
# Use more accurate database name display
if oracle_db and hasattr(oracle_db, "pool"): db_names = {
await oracle_db.pool.close() "postgres": "PostgreSQL",
logger.info("Closed Oracle connection pool") "oracle": "Oracle",
"tidb": "TiDB"
if tidb_db and hasattr(tidb_db, "pool"): }
await tidb_db.pool.close() db_name = db_names.get(db_type, db_type)
logger.info("Closed TiDB connection pool") logger.info(f"Closed {db_name} database connection pool")
# Initialize FastAPI # Initialize FastAPI
app = FastAPI( app = FastAPI(
@@ -1311,7 +1289,7 @@ def create_app(args):
case ".pdf": case ".pdf":
if not pm.is_installed("pypdf2"): if not pm.is_installed("pypdf2"):
pm.install("pypdf2") pm.install("pypdf2")
from PyPDF2 import PdfReader from PyPDF2 import PdfReader # type: ignore
from io import BytesIO from io import BytesIO
pdf_file = BytesIO(file) pdf_file = BytesIO(file)