From 75ee4592b8172aa2ec22794efa7c47aa1aafe0b1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 18 Feb 2025 16:14:11 +0800 Subject: [PATCH] refactor: Implement dynamic database module imports - Consolidate database instance management - Improve database management and error handling - Enhance error handling and logging --- lightrag/api/lightrag_server.py | 170 ++++++++++++++------------------ 1 file changed, 74 insertions(+), 96 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 7a50a512..661e25d0 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -41,25 +41,28 @@ from .ollama_api import ( OllamaAPI, ) from .ollama_api import ollama_server_infos -from ..kg.postgres_impl import ( - PostgreSQLDB, - PGKVStorage, - PGVectorStorage, - PGGraphStorage, - PGDocStatusStorage, -) -from ..kg.oracle_impl import ( - OracleDB, - OracleKVStorage, - OracleVectorDBStorage, - OracleGraphStorage, -) -from ..kg.tidb_impl import ( - TiDB, - TiDBKVStorage, - TiDBVectorDBStorage, - TiDBGraphStorage, -) +def get_db_type_from_storage_class(class_name: str) -> str | None: + """Determine database type based on storage class name""" + if class_name.startswith("PG"): + return "postgres" + elif class_name.startswith("Oracle"): + return "oracle" + elif class_name.startswith("TiDB"): + return "tidb" + return None + +def import_db_module(db_type: str): + """Dynamically import database module""" + if db_type == "postgres": + from ..kg.postgres_impl import PostgreSQLDB + return PostgreSQLDB + elif db_type == "oracle": + from ..kg.oracle_impl import OracleDB + return OracleDB + elif db_type == "tidb": + from ..kg.tidb_impl import TiDB + return TiDB + return None # Load environment variables try: @@ -333,28 +336,28 @@ def parse_args() -> argparse.Namespace: default=get_env_value( "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE ), - help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", + help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})", ) parser.add_argument( "--doc-status-storage", default=get_env_value( "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE ), - help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", + help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", ) parser.add_argument( "--graph-storage", default=get_env_value( "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE ), - help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", + help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", ) parser.add_argument( "--vector-storage", default=get_env_value( "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE ), - help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", + help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", ) # Bindings configuration @@ -890,72 +893,47 @@ def create_app(args): async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" # Initialize database connections - postgres_db = None - oracle_db = None - tidb_db = None + db_instances = {} # Store background tasks app.state.background_tasks = set() try: - # Check if PostgreSQL is needed - if any( - isinstance( - storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), - ) - for _, storage_instance in storage_instances - ): - postgres_db = PostgreSQLDB(_get_postgres_config()) - await postgres_db.initdb() - await postgres_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - ( - PGKVStorage, - PGVectorStorage, - PGGraphStorage, - PGDocStatusStorage, - ), - ): - storage_instance.db = postgres_db - logger.info(f"Injected postgres_db to {storage_name}") + # Check which database types are used + db_types = set() + for storage_name, storage_instance in storage_instances: + db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__) + if db_type: + db_types.add(db_type) - # Check if Oracle is needed - if any( - isinstance( - storage_instance, - (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), - ) - for _, storage_instance in storage_instances - ): - oracle_db = OracleDB(_get_oracle_config()) - await oracle_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), - ): - storage_instance.db = oracle_db - logger.info(f"Injected oracle_db to {storage_name}") + # Import and initialize databases as needed + for db_type in db_types: + if db_type == "postgres": + DB = import_db_module("postgres") + db = DB(_get_postgres_config()) + await db.initdb() + await db.check_tables() + db_instances["postgres"] = db + elif db_type == "oracle": + DB = import_db_module("oracle") + db = DB(_get_oracle_config()) + await db.check_tables() + db_instances["oracle"] = db + elif db_type == "tidb": + DB = import_db_module("tidb") + db = DB(_get_tidb_config()) + await db.check_tables() + db_instances["tidb"] = db - # Check if TiDB is needed - if any( - isinstance( - storage_instance, - (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), - ) - for _, storage_instance in storage_instances - ): - tidb_db = TiDB(_get_tidb_config()) - await tidb_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), - ): - storage_instance.db = tidb_db - logger.info(f"Injected tidb_db to {storage_name}") + # Inject database instances into storage classes + for storage_name, storage_instance in storage_instances: + db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__) + if db_type: + if db_type not in db_instances: + error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized" + logger.error(error_msg) + raise RuntimeError(error_msg) + storage_instance.db = db_instances[db_type] + logger.info(f"Injected {db_type} db to {storage_name}") # Auto scan documents if enabled if args.auto_scan_at_startup: @@ -980,18 +958,18 @@ def create_app(args): yield finally: - # Cleanup database connections - if postgres_db and hasattr(postgres_db, "pool"): - await postgres_db.pool.close() - logger.info("Closed PostgreSQL connection pool") - - if oracle_db and hasattr(oracle_db, "pool"): - await oracle_db.pool.close() - logger.info("Closed Oracle connection pool") - - if tidb_db and hasattr(tidb_db, "pool"): - await tidb_db.pool.close() - logger.info("Closed TiDB connection pool") + # Clean up database connections + for db_type, db in db_instances.items(): + if hasattr(db, "pool"): + await db.pool.close() + # Use more accurate database name display + db_names = { + "postgres": "PostgreSQL", + "oracle": "Oracle", + "tidb": "TiDB" + } + db_name = db_names.get(db_type, db_type) + logger.info(f"Closed {db_name} database connection pool") # Initialize FastAPI app = FastAPI( @@ -1311,7 +1289,7 @@ def create_app(args): case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") - from PyPDF2 import PdfReader + from PyPDF2 import PdfReader # type: ignore from io import BytesIO pdf_file = BytesIO(file)