diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 97b3f5a5..0839c1f8 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -33,14 +33,39 @@ from contextlib import asynccontextmanager from starlette.status import HTTP_403_FORBIDDEN import pipmaster as pm from dotenv import load_dotenv +import configparser +from lightrag.utils import logger from .ollama_api import ( OllamaAPI, ) from .ollama_api import ollama_server_infos +from ..kg.postgres_impl import ( + PostgreSQLDB, + PGKVStorage, + PGVectorStorage, + PGGraphStorage, + PGDocStatusStorage, +) +from ..kg.oracle_impl import ( + OracleDB, + OracleKVStorage, + OracleVectorDBStorage, + OracleGraphStorage, +) +from ..kg.tidb_impl import ( + TiDB, + TiDBKVStorage, + TiDBVectorDBStorage, + TiDBGraphStorage, +) # Load environment variables load_dotenv(override=True) +# Initialize config parser +config = configparser.ConfigParser() +config.read("config.ini") + class RAGStorageConfig: """存储配置类,支持通过环境变量和命令行参数修改默认值""" @@ -714,25 +739,99 @@ def create_app(args): @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" - # Startup logic - if args.auto_scan_at_startup: - try: - new_files = doc_manager.scan_directory_for_new_files() - for file_path in new_files: - try: - await index_file(file_path) - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") + # Initialize database connections + postgres_db = None + oracle_db = None + tidb_db = None - ASCIIColors.info( - f"Indexed {len(new_files)} documents from {args.input_dir}" + try: + # Check if PostgreSQL is needed + if any( + isinstance( + storage_instance, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), ) - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - yield - # Cleanup logic (if needed) - pass + for _, storage_instance in storage_instances + ): + postgres_db = PostgreSQLDB(_get_postgres_config()) + await postgres_db.initdb() + await postgres_db.check_tables() + for storage_name, storage_instance in storage_instances: + if isinstance( + storage_instance, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ): + storage_instance.db = postgres_db + logger.info(f"Injected postgres_db to {storage_name}") + + # Check if Oracle is needed + if any( + isinstance( + storage_instance, + (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), + ) + for _, storage_instance in storage_instances + ): + oracle_db = OracleDB(_get_oracle_config()) + await oracle_db.check_tables() + for storage_name, storage_instance in storage_instances: + if isinstance( + storage_instance, + (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), + ): + storage_instance.db = oracle_db + logger.info(f"Injected oracle_db to {storage_name}") + + # Check if TiDB is needed + if any( + isinstance( + storage_instance, + (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), + ) + for _, storage_instance in storage_instances + ): + tidb_db = TiDB(_get_tidb_config()) + await tidb_db.check_tables() + for storage_name, storage_instance in storage_instances: + if isinstance( + storage_instance, + (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), + ): + storage_instance.db = tidb_db + logger.info(f"Injected tidb_db to {storage_name}") + + # Auto scan documents if enabled + if args.auto_scan_at_startup: + try: + new_files = doc_manager.scan_directory_for_new_files() + for file_path in new_files: + try: + await index_file(file_path) + except Exception as e: + trace_exception(e) + logging.error(f"Error indexing file {file_path}: {str(e)}") + + ASCIIColors.info( + f"Indexed {len(new_files)} documents from {args.input_dir}" + ) + except Exception as e: + logging.error(f"Error during startup indexing: {str(e)}") + + yield + + finally: + # Cleanup database connections + if postgres_db and hasattr(postgres_db, "pool"): + await postgres_db.pool.close() + logger.info("Closed PostgreSQL connection pool") + + if oracle_db and hasattr(oracle_db, "pool"): + await oracle_db.pool.close() + logger.info("Closed Oracle connection pool") + + if tidb_db and hasattr(tidb_db, "pool"): + await tidb_db.pool.close() + logger.info("Closed TiDB connection pool") # Initialize FastAPI app = FastAPI( @@ -755,6 +854,92 @@ def create_app(args): allow_headers=["*"], ) + # Database configuration functions + def _get_postgres_config(): + return { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback=None) + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback=None), + ), + "workspace": os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + } + + def _get_oracle_config(): + return { + "user": os.environ.get( + "ORACLE_USER", + config.get("oracle", "user", fallback=None), + ), + "password": os.environ.get( + "ORACLE_PASSWORD", + config.get("oracle", "password", fallback=None), + ), + "dsn": os.environ.get( + "ORACLE_DSN", + config.get("oracle", "dsn", fallback=None), + ), + "config_dir": os.environ.get( + "ORACLE_CONFIG_DIR", + config.get("oracle", "config_dir", fallback=None), + ), + "wallet_location": os.environ.get( + "ORACLE_WALLET_LOCATION", + config.get("oracle", "wallet_location", fallback=None), + ), + "wallet_password": os.environ.get( + "ORACLE_WALLET_PASSWORD", + config.get("oracle", "wallet_password", fallback=None), + ), + "workspace": os.environ.get( + "ORACLE_WORKSPACE", + config.get("oracle", "workspace", fallback="default"), + ), + } + + def _get_tidb_config(): + return { + "host": os.environ.get( + "TIDB_HOST", + config.get("tidb", "host", fallback="localhost"), + ), + "port": os.environ.get( + "TIDB_PORT", config.get("tidb", "port", fallback=4000) + ), + "user": os.environ.get( + "TIDB_USER", + config.get("tidb", "user", fallback=None), + ), + "password": os.environ.get( + "TIDB_PASSWORD", + config.get("tidb", "password", fallback=None), + ), + "database": os.environ.get( + "TIDB_DATABASE", + config.get("tidb", "database", fallback=None), + ), + "workspace": os.environ.get( + "TIDB_WORKSPACE", + config.get("tidb", "workspace", fallback="default"), + ), + } + # Create the optional API key dependency optional_api_key = get_api_key_dependency(api_key) @@ -921,6 +1106,17 @@ def create_app(args): namespace_prefix=args.namespace_prefix, ) + # Collect all storage instances + storage_instances = [ + ("full_docs", rag.full_docs), + ("text_chunks", rag.text_chunks), + ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph), + ("entities_vdb", rag.entities_vdb), + ("relationships_vdb", rag.relationships_vdb), + ("chunks_vdb", rag.chunks_vdb), + ("doc_status", rag.doc_status), + ] + async def index_file(file_path: Union[str, Path]) -> None: """Index all files inside the folder with support for multiple file formats diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index dcd829eb..e6217572 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -355,91 +355,6 @@ class LightRAG: list[dict[str, Any]], ] = chunking_by_token_size - def _get_postgres_config(self): - return { - "host": os.environ.get( - "POSTGRES_HOST", - config.get("postgres", "host", fallback="localhost"), - ), - "port": os.environ.get( - "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) - ), - "user": os.environ.get( - "POSTGRES_USER", config.get("postgres", "user", fallback=None) - ), - "password": os.environ.get( - "POSTGRES_PASSWORD", - config.get("postgres", "password", fallback=None), - ), - "database": os.environ.get( - "POSTGRES_DATABASE", - config.get("postgres", "database", fallback=None), - ), - "workspace": os.environ.get( - "POSTGRES_WORKSPACE", - config.get("postgres", "workspace", fallback="default"), - ), - } - - def _get_oracle_config(self): - return { - "user": os.environ.get( - "ORACLE_USER", - config.get("oracle", "user", fallback=None), - ), - "password": os.environ.get( - "ORACLE_PASSWORD", - config.get("oracle", "password", fallback=None), - ), - "dsn": os.environ.get( - "ORACLE_DSN", - config.get("oracle", "dsn", fallback=None), - ), - "config_dir": os.environ.get( - "ORACLE_CONFIG_DIR", - config.get("oracle", "config_dir", fallback=None), - ), - "wallet_location": os.environ.get( - "ORACLE_WALLET_LOCATION", - config.get("oracle", "wallet_location", fallback=None), - ), - "wallet_password": os.environ.get( - "ORACLE_WALLET_PASSWORD", - config.get("oracle", "wallet_password", fallback=None), - ), - "workspace": os.environ.get( - "ORACLE_WORKSPACE", - config.get("oracle", "workspace", fallback="default"), - ), - } - - def _get_tidb_config(self): - return { - "host": os.environ.get( - "TIDB_HOST", - config.get("tidb", "host", fallback="localhost"), - ), - "port": os.environ.get( - "TIDB_PORT", config.get("tidb", "port", fallback=4000) - ), - "user": os.environ.get( - "TIDB_USER", - config.get("tidb", "user", fallback=None), - ), - "password": os.environ.get( - "TIDB_PASSWORD", - config.get("tidb", "password", fallback=None), - ), - "database": os.environ.get( - "TIDB_DATABASE", - config.get("tidb", "database", fallback=None), - ), - "workspace": os.environ.get( - "TIDB_WORKSPACE", - config.get("tidb", "workspace", fallback="default"), - ), - } - def verify_storage_implementation( self, storage_type: str, storage_name: str ) -> None: @@ -609,20 +524,6 @@ class LightRAG: ) - # Collect all storage instances with their names - storage_instances = [ - ("full_docs", self.full_docs), - ("text_chunks", self.text_chunks), - ("chunk_entity_relation_graph", self.chunk_entity_relation_graph), - ("entities_vdb", self.entities_vdb), - ("relationships_vdb", self.relationships_vdb), - ("chunks_vdb", self.chunks_vdb), - ("doc_status", self.doc_status), - ] - - # Initialize database connections if needed - loop = always_get_an_event_loop() - loop.run_until_complete(self._initialize_database_if_needed(storage_instances)) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -646,78 +547,6 @@ class LightRAG: storage_class = lazy_external_import(import_path, storage_name) return storage_class - async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]): - """Intialize database connection and inject it to storage implementation if needed""" - from .kg.postgres_impl import PostgreSQLDB - from .kg.oracle_impl import OracleDB - from .kg.tidb_impl import TiDB - from .kg.postgres_impl import ( - PGKVStorage, - PGVectorStorage, - PGGraphStorage, - PGDocStatusStorage, - ) - from .kg.oracle_impl import ( - OracleKVStorage, - OracleVectorDBStorage, - OracleGraphStorage, - ) - from .kg.tidb_impl import ( - TiDBKVStorage, - TiDBVectorDBStorage, - TiDBGraphStorage) - - # Checking if PostgreSQL is needed - if any( - isinstance( - storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), - ) - for _, storage_instance in storage_instances - ): - postgres_db = PostgreSQLDB(self._get_postgres_config()) - await postgres_db.initdb() - await postgres_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), - ): - storage_instance.db = postgres_db - logger.info(f"Injected postgres_db to {storage_name}") - - # Checking if Oracle is needed - if any( - isinstance( - storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage) - ) - for _, storage_instance in storage_instances - ): - oracle_db = OracleDB(self._get_oracle_config()) - await oracle_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), - ): - storage_instance.db = oracle_db - logger.info(f"Injected oracle_db to {storage_name}") - - # Checking if TiDB is needed - if any( - isinstance(storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)) - for _, storage_instance in storage_instances - ): - tidb_db = TiDB(self._get_tidb_config()) - await tidb_db.check_tables() - # 注入db实例 - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage) - ): - storage_instance.db = tidb_db - logger.info(f"Injected tidb_db to {storage_name}") - def set_storage_client(self, db_client): # Inject db to storage implementation (only tested on Oracle Database # Deprecated, seting correct value to *_storage creating LightRAG insteaded