diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 7a50a512..fba81086 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -41,25 +41,35 @@ from .ollama_api import ( OllamaAPI, ) from .ollama_api import ollama_server_infos -from ..kg.postgres_impl import ( - PostgreSQLDB, - PGKVStorage, - PGVectorStorage, - PGGraphStorage, - PGDocStatusStorage, -) -from ..kg.oracle_impl import ( - OracleDB, - OracleKVStorage, - OracleVectorDBStorage, - OracleGraphStorage, -) -from ..kg.tidb_impl import ( - TiDB, - TiDBKVStorage, - TiDBVectorDBStorage, - TiDBGraphStorage, -) + + +def get_db_type_from_storage_class(class_name: str) -> str | None: + """Determine database type based on storage class name""" + if class_name.startswith("PG"): + return "postgres" + elif class_name.startswith("Oracle"): + return "oracle" + elif class_name.startswith("TiDB"): + return "tidb" + return None + + +def import_db_module(db_type: str): + """Dynamically import database module""" + if db_type == "postgres": + from ..kg.postgres_impl import PostgreSQLDB + + return PostgreSQLDB + elif db_type == "oracle": + from ..kg.oracle_impl import OracleDB + + return OracleDB + elif db_type == "tidb": + from ..kg.tidb_impl import TiDB + + return TiDB + return None + # Load environment variables try: @@ -333,28 +343,28 @@ def parse_args() -> argparse.Namespace: default=get_env_value( "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE ), - help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", + help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})", ) parser.add_argument( "--doc-status-storage", default=get_env_value( "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE ), - help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", + help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", ) parser.add_argument( "--graph-storage", default=get_env_value( "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE ), - help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", + help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", ) parser.add_argument( "--vector-storage", default=get_env_value( "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE ), - help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", + help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", ) # Bindings configuration @@ -788,6 +798,36 @@ class InsertResponse(BaseModel): class DocStatusResponse(BaseModel): + @staticmethod + def format_datetime(dt: Any) -> Optional[str]: + """Format datetime to ISO string + + Args: + dt: Datetime object or string + + Returns: + Formatted datetime string or None + """ + if dt is None: + return None + if isinstance(dt, str): + return dt + return dt.isoformat() + + """Response model for document status + + Attributes: + id: Document identifier + content_summary: Summary of document content + content_length: Length of document content + status: Current processing status + created_at: Creation timestamp (ISO format string) + updated_at: Last update timestamp (ISO format string) + chunks_count: Number of chunks (optional) + error: Error message if any (optional) + metadata: Additional metadata (optional) + """ + id: str content_summary: str content_length: int @@ -890,72 +930,51 @@ def create_app(args): async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" # Initialize database connections - postgres_db = None - oracle_db = None - tidb_db = None + db_instances = {} # Store background tasks app.state.background_tasks = set() try: - # Check if PostgreSQL is needed - if any( - isinstance( - storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + # Check which database types are used + db_types = set() + for storage_name, storage_instance in storage_instances: + db_type = get_db_type_from_storage_class( + storage_instance.__class__.__name__ ) - for _, storage_instance in storage_instances - ): - postgres_db = PostgreSQLDB(_get_postgres_config()) - await postgres_db.initdb() - await postgres_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - ( - PGKVStorage, - PGVectorStorage, - PGGraphStorage, - PGDocStatusStorage, - ), - ): - storage_instance.db = postgres_db - logger.info(f"Injected postgres_db to {storage_name}") + if db_type: + db_types.add(db_type) - # Check if Oracle is needed - if any( - isinstance( - storage_instance, - (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), - ) - for _, storage_instance in storage_instances - ): - oracle_db = OracleDB(_get_oracle_config()) - await oracle_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), - ): - storage_instance.db = oracle_db - logger.info(f"Injected oracle_db to {storage_name}") + # Import and initialize databases as needed + for db_type in db_types: + if db_type == "postgres": + DB = import_db_module("postgres") + db = DB(_get_postgres_config()) + await db.initdb() + await db.check_tables() + db_instances["postgres"] = db + elif db_type == "oracle": + DB = import_db_module("oracle") + db = DB(_get_oracle_config()) + await db.check_tables() + db_instances["oracle"] = db + elif db_type == "tidb": + DB = import_db_module("tidb") + db = DB(_get_tidb_config()) + await db.check_tables() + db_instances["tidb"] = db - # Check if TiDB is needed - if any( - isinstance( - storage_instance, - (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), + # Inject database instances into storage classes + for storage_name, storage_instance in storage_instances: + db_type = get_db_type_from_storage_class( + storage_instance.__class__.__name__ ) - for _, storage_instance in storage_instances - ): - tidb_db = TiDB(_get_tidb_config()) - await tidb_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), - ): - storage_instance.db = tidb_db - logger.info(f"Injected tidb_db to {storage_name}") + if db_type: + if db_type not in db_instances: + error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized" + logger.error(error_msg) + raise RuntimeError(error_msg) + storage_instance.db = db_instances[db_type] + logger.info(f"Injected {db_type} db to {storage_name}") # Auto scan documents if enabled if args.auto_scan_at_startup: @@ -980,18 +999,18 @@ def create_app(args): yield finally: - # Cleanup database connections - if postgres_db and hasattr(postgres_db, "pool"): - await postgres_db.pool.close() - logger.info("Closed PostgreSQL connection pool") - - if oracle_db and hasattr(oracle_db, "pool"): - await oracle_db.pool.close() - logger.info("Closed Oracle connection pool") - - if tidb_db and hasattr(tidb_db, "pool"): - await tidb_db.pool.close() - logger.info("Closed TiDB connection pool") + # Clean up database connections + for db_type, db in db_instances.items(): + if hasattr(db, "pool"): + await db.pool.close() + # Use more accurate database name display + db_names = { + "postgres": "PostgreSQL", + "oracle": "Oracle", + "tidb": "TiDB", + } + db_name = db_names.get(db_type, db_type) + logger.info(f"Closed {db_name} database connection pool") # Initialize FastAPI app = FastAPI( @@ -1311,7 +1330,7 @@ def create_app(args): case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") - from PyPDF2 import PdfReader + from PyPDF2 import PdfReader # type: ignore from io import BytesIO pdf_file = BytesIO(file) @@ -1869,8 +1888,12 @@ def create_app(args): content_summary=doc_status.content_summary, content_length=doc_status.content_length, status=doc_status.status, - created_at=doc_status.created_at, - updated_at=doc_status.updated_at, + created_at=DocStatusResponse.format_datetime( + doc_status.created_at + ), + updated_at=DocStatusResponse.format_datetime( + doc_status.updated_at + ), chunks_count=doc_status.chunks_count, error=doc_status.error, metadata=doc_status.metadata, diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index 7b2593c0..068a84b9 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -1,6 +1,5 @@ ascii_colors fastapi -nest_asyncio numpy pipmaster python-dotenv