refactor: move database connection pool initialization to lifespan of FastAPI

- Add proper database connection lifecycle management
- Add connection pool cleanup in FastAPI lifespan
This commit is contained in:
yangdx
2025-02-13 01:11:09 +08:00
parent 7c7cac1cfd
commit 4c39cf399d
2 changed files with 213 additions and 188 deletions

View File

@@ -33,14 +33,39 @@ from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm import pipmaster as pm
from dotenv import load_dotenv from dotenv import load_dotenv
import configparser
from lightrag.utils import logger
from .ollama_api import ( from .ollama_api import (
OllamaAPI, OllamaAPI,
) )
from .ollama_api import ollama_server_infos from .ollama_api import ollama_server_infos
from ..kg.postgres_impl import (
PostgreSQLDB,
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from ..kg.oracle_impl import (
OracleDB,
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from ..kg.tidb_impl import (
TiDB,
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage,
)
# Load environment variables # Load environment variables
load_dotenv(override=True) load_dotenv(override=True)
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
class RAGStorageConfig: class RAGStorageConfig:
"""存储配置类,支持通过环境变量和命令行参数修改默认值""" """存储配置类,支持通过环境变量和命令行参数修改默认值"""
@@ -714,7 +739,68 @@ def create_app(args):
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events""" """Lifespan context manager for startup and shutdown events"""
# Startup logic # Initialize database connections
postgres_db = None
oracle_db = None
tidb_db = None
try:
# Check if PostgreSQL is needed
if any(
isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
)
for _, storage_instance in storage_instances
):
postgres_db = PostgreSQLDB(_get_postgres_config())
await postgres_db.initdb()
await postgres_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
):
storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}")
# Check if Oracle is needed
if any(
isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
)
for _, storage_instance in storage_instances
):
oracle_db = OracleDB(_get_oracle_config())
await oracle_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
):
storage_instance.db = oracle_db
logger.info(f"Injected oracle_db to {storage_name}")
# Check if TiDB is needed
if any(
isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
)
for _, storage_instance in storage_instances
):
tidb_db = TiDB(_get_tidb_config())
await tidb_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
):
storage_instance.db = tidb_db
logger.info(f"Injected tidb_db to {storage_name}")
# Auto scan documents if enabled
if args.auto_scan_at_startup: if args.auto_scan_at_startup:
try: try:
new_files = doc_manager.scan_directory_for_new_files() new_files = doc_manager.scan_directory_for_new_files()
@@ -730,9 +816,22 @@ def create_app(args):
) )
except Exception as e: except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}") logging.error(f"Error during startup indexing: {str(e)}")
yield yield
# Cleanup logic (if needed)
pass finally:
# Cleanup database connections
if postgres_db and hasattr(postgres_db, "pool"):
await postgres_db.pool.close()
logger.info("Closed PostgreSQL connection pool")
if oracle_db and hasattr(oracle_db, "pool"):
await oracle_db.pool.close()
logger.info("Closed Oracle connection pool")
if tidb_db and hasattr(tidb_db, "pool"):
await tidb_db.pool.close()
logger.info("Closed TiDB connection pool")
# Initialize FastAPI # Initialize FastAPI
app = FastAPI( app = FastAPI(
@@ -755,6 +854,92 @@ def create_app(args):
allow_headers=["*"], allow_headers=["*"],
) )
# Database configuration functions
def _get_postgres_config():
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config():
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config():
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Create the optional API key dependency # Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key) optional_api_key = get_api_key_dependency(api_key)
@@ -921,6 +1106,17 @@ def create_app(args):
namespace_prefix=args.namespace_prefix, namespace_prefix=args.namespace_prefix,
) )
# Collect all storage instances
storage_instances = [
("full_docs", rag.full_docs),
("text_chunks", rag.text_chunks),
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
("entities_vdb", rag.entities_vdb),
("relationships_vdb", rag.relationships_vdb),
("chunks_vdb", rag.chunks_vdb),
("doc_status", rag.doc_status),
]
async def index_file(file_path: Union[str, Path]) -> None: async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats """Index all files inside the folder with support for multiple file formats

View File

@@ -355,91 +355,6 @@ class LightRAG:
list[dict[str, Any]], list[dict[str, Any]],
] = chunking_by_token_size ] = chunking_by_token_size
def _get_postgres_config(self):
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config(self):
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config(self):
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
def verify_storage_implementation( def verify_storage_implementation(
self, storage_type: str, storage_name: str self, storage_type: str, storage_name: str
) -> None: ) -> None:
@@ -609,20 +524,6 @@ class LightRAG:
) )
# Collect all storage instances with their names
storage_instances = [
("full_docs", self.full_docs),
("text_chunks", self.text_chunks),
("chunk_entity_relation_graph", self.chunk_entity_relation_graph),
("entities_vdb", self.entities_vdb),
("relationships_vdb", self.relationships_vdb),
("chunks_vdb", self.chunks_vdb),
("doc_status", self.doc_status),
]
# Initialize database connections if needed
loop = always_get_an_event_loop()
loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial( partial(
@@ -646,78 +547,6 @@ class LightRAG:
storage_class = lazy_external_import(import_path, storage_name) storage_class = lazy_external_import(import_path, storage_name)
return storage_class return storage_class
async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]):
"""Intialize database connection and inject it to storage implementation if needed"""
from .kg.postgres_impl import PostgreSQLDB
from .kg.oracle_impl import OracleDB
from .kg.tidb_impl import TiDB
from .kg.postgres_impl import (
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from .kg.oracle_impl import (
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from .kg.tidb_impl import (
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage)
# Checking if PostgreSQL is needed
if any(
isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
)
for _, storage_instance in storage_instances
):
postgres_db = PostgreSQLDB(self._get_postgres_config())
await postgres_db.initdb()
await postgres_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
):
storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}")
# Checking if Oracle is needed
if any(
isinstance(
storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
)
for _, storage_instance in storage_instances
):
oracle_db = OracleDB(self._get_oracle_config())
await oracle_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
):
storage_instance.db = oracle_db
logger.info(f"Injected oracle_db to {storage_name}")
# Checking if TiDB is needed
if any(
isinstance(storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage))
for _, storage_instance in storage_instances
):
tidb_db = TiDB(self._get_tidb_config())
await tidb_db.check_tables()
# 注入db实例
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)
):
storage_instance.db = tidb_db
logger.info(f"Injected tidb_db to {storage_name}")
def set_storage_client(self, db_client): def set_storage_client(self, db_client):
# Inject db to storage implementation (only tested on Oracle Database # Inject db to storage implementation (only tested on Oracle Database
# Deprecated, seting correct value to *_storage creating LightRAG insteaded # Deprecated, seting correct value to *_storage creating LightRAG insteaded