refactor: move database connection pool initialization into the FastAPI lifespan handler

- Add proper database connection lifecycle management
- Add connection pool cleanup in FastAPI lifespan
This commit is contained in:
yangdx
2025-02-13 01:11:09 +08:00
parent 7c7cac1cfd
commit 4c39cf399d
2 changed files with 213 additions and 188 deletions

View File

@@ -33,14 +33,39 @@ from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
from dotenv import load_dotenv
import configparser
from lightrag.utils import logger
from .ollama_api import (
OllamaAPI,
)
from .ollama_api import ollama_server_infos
from ..kg.postgres_impl import (
PostgreSQLDB,
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from ..kg.oracle_impl import (
OracleDB,
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from ..kg.tidb_impl import (
TiDB,
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage,
)
# Load environment variables from a .env file. override=True lets values from
# the .env file replace variables that are already set in the process
# environment.
load_dotenv(override=True)
# Initialize the module-level config parser and read "config.ini" (path is
# relative to the current working directory). ConfigParser.read() silently
# skips files it cannot open, so `config` may be empty; readers of `config`
# are expected to pass `fallback=` when looking values up.
config = configparser.ConfigParser()
config.read("config.ini")
class RAGStorageConfig:
"""存储配置类,支持通过环境变量和命令行参数修改默认值"""
@@ -714,25 +739,99 @@ def create_app(args):
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events.

    Startup (everything before ``yield``):
      1. For each database backend (PostgreSQL, Oracle, TiDB) that at least
         one entry of ``storage_instances`` belongs to, create a single
         shared client, run its table checks, and assign that client to the
         ``db`` attribute of every matching storage instance.
      2. If ``args.auto_scan_at_startup`` is set, index every new file
         reported by ``doc_manager.scan_directory_for_new_files()``.
         Indexing errors are logged and do not abort startup.

    Shutdown (after ``yield``, or when startup itself raises): close the
    connection pool of every client that was created.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used by the body.
    """
    # Shared clients; each stays None unless its backend is actually in use.
    postgres_db = None
    oracle_db = None
    tidb_db = None

    postgres_types = (
        PGKVStorage,
        PGVectorStorage,
        PGGraphStorage,
        PGDocStatusStorage,
    )
    oracle_types = (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
    tidb_types = (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)

    def _uses_backend(storage_types):
        """Return True if any collected storage is one of ``storage_types``."""
        return any(
            isinstance(storage_instance, storage_types)
            for _, storage_instance in storage_instances
        )

    def _inject(db, db_name, storage_types):
        """Point every storage of ``storage_types`` at the shared ``db``."""
        for storage_name, storage_instance in storage_instances:
            if isinstance(storage_instance, storage_types):
                storage_instance.db = db
                logger.info(f"Injected {db_name} to {storage_name}")

    try:
        # Check if PostgreSQL is needed
        if _uses_backend(postgres_types):
            postgres_db = PostgreSQLDB(_get_postgres_config())
            await postgres_db.initdb()
            await postgres_db.check_tables()
            _inject(postgres_db, "postgres_db", postgres_types)

        # Check if Oracle is needed
        if _uses_backend(oracle_types):
            oracle_db = OracleDB(_get_oracle_config())
            await oracle_db.check_tables()
            _inject(oracle_db, "oracle_db", oracle_types)

        # Check if TiDB is needed
        if _uses_backend(tidb_types):
            tidb_db = TiDB(_get_tidb_config())
            await tidb_db.check_tables()
            _inject(tidb_db, "tidb_db", tidb_types)

        # Auto scan documents if enabled
        if args.auto_scan_at_startup:
            try:
                new_files = doc_manager.scan_directory_for_new_files()
                for file_path in new_files:
                    try:
                        await index_file(file_path)
                    except Exception as e:
                        # One bad file must not stop the remaining files.
                        trace_exception(e)
                        logging.error(f"Error indexing file {file_path}: {str(e)}")

                ASCIIColors.info(
                    f"Indexed {len(new_files)} documents from {args.input_dir}"
                )
            except Exception as e:
                logging.error(f"Error during startup indexing: {str(e)}")

        yield

    finally:
        # Cleanup database connections. Each close is isolated so that a
        # failure on one backend neither skips the others nor masks an
        # exception that is already propagating out of the try block.
        for db, label in (
            (postgres_db, "PostgreSQL"),
            (oracle_db, "Oracle"),
            (tidb_db, "TiDB"),
        ):
            # getattr on None yields None, so unused backends fall through;
            # a client whose `pool` was never populated is skipped as well.
            pool = getattr(db, "pool", None)
            if pool is None:
                continue
            try:
                await pool.close()
                logger.info(f"Closed {label} connection pool")
            except Exception as e:
                logger.error(f"Error closing {label} connection pool: {str(e)}")
# Initialize FastAPI
app = FastAPI(
@@ -755,6 +854,92 @@ def create_app(args):
allow_headers=["*"],
)
# Database configuration functions
def _get_postgres_config():
    """Build the PostgreSQL connection settings dictionary.

    Each setting is resolved in this order: its environment variable, then
    the ``[postgres]`` section of the module-level ``config`` parser, then
    the default listed below.

    Returns:
        dict: Keys ``host``, ``port``, ``user``, ``password``, ``database``
        and ``workspace``. Values coming from the environment or the config
        file are strings; the built-in ``port`` default is the int ``5432``.
    """
    # (result key, environment variable, default when neither source has it)
    settings = (
        ("host", "POSTGRES_HOST", "localhost"),
        ("port", "POSTGRES_PORT", 5432),
        ("user", "POSTGRES_USER", None),
        ("password", "POSTGRES_PASSWORD", None),
        ("database", "POSTGRES_DATABASE", None),
        ("workspace", "POSTGRES_WORKSPACE", "default"),
    )
    resolved = {}
    for key, env_name, default in settings:
        from_file = config.get("postgres", key, fallback=default)
        resolved[key] = os.environ.get(env_name, from_file)
    return resolved
def _get_oracle_config():
    """Build the Oracle connection settings dictionary.

    Each setting is resolved in this order: its environment variable, then
    the ``[oracle]`` section of the module-level ``config`` parser, then the
    default listed below.

    Returns:
        dict: Keys ``user``, ``password``, ``dsn``, ``config_dir``,
        ``wallet_location``, ``wallet_password`` and ``workspace``. Every
        setting except ``workspace`` defaults to ``None``.
    """
    # (result key, environment variable, default when neither source has it)
    settings = (
        ("user", "ORACLE_USER", None),
        ("password", "ORACLE_PASSWORD", None),
        ("dsn", "ORACLE_DSN", None),
        ("config_dir", "ORACLE_CONFIG_DIR", None),
        ("wallet_location", "ORACLE_WALLET_LOCATION", None),
        ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
        ("workspace", "ORACLE_WORKSPACE", "default"),
    )
    resolved = {}
    for key, env_name, default in settings:
        from_file = config.get("oracle", key, fallback=default)
        resolved[key] = os.environ.get(env_name, from_file)
    return resolved
def _get_tidb_config():
    """Build the TiDB connection settings dictionary.

    Each setting is resolved in this order: its environment variable, then
    the ``[tidb]`` section of the module-level ``config`` parser, then the
    default listed below.

    Returns:
        dict: Keys ``host``, ``port``, ``user``, ``password``, ``database``
        and ``workspace``. Values coming from the environment or the config
        file are strings; the built-in ``port`` default is the int ``4000``.
    """
    # (result key, environment variable, default when neither source has it)
    settings = (
        ("host", "TIDB_HOST", "localhost"),
        ("port", "TIDB_PORT", 4000),
        ("user", "TIDB_USER", None),
        ("password", "TIDB_PASSWORD", None),
        ("database", "TIDB_DATABASE", None),
        ("workspace", "TIDB_WORKSPACE", "default"),
    )
    resolved = {}
    for key, env_name, default in settings:
        from_file = config.get("tidb", key, fallback=default)
        resolved[key] = os.environ.get(env_name, from_file)
    return resolved
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
@@ -921,6 +1106,17 @@ def create_app(args):
namespace_prefix=args.namespace_prefix,
)
# Collect all storage instances as (name, instance) pairs. The lifespan
# handler iterates this list to detect which database backends are in use
# and to set the shared database client on each matching instance's `db`
# attribute; in the code visible here the name is only used in its log
# messages.
storage_instances = [
("full_docs", rag.full_docs),
("text_chunks", rag.text_chunks),
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
("entities_vdb", rag.entities_vdb),
("relationships_vdb", rag.relationships_vdb),
("chunks_vdb", rag.chunks_vdb),
("doc_status", rag.doc_status),
]
async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats