refactor: move database connection pool initialization to lifespan of FastAPI
- Add proper database connection lifecycle management
- Add connection pool cleanup in the FastAPI lifespan
This commit is contained in:
@@ -33,14 +33,39 @@ from contextlib import asynccontextmanager
|
|||||||
from starlette.status import HTTP_403_FORBIDDEN
|
from starlette.status import HTTP_403_FORBIDDEN
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
import configparser
|
||||||
|
from lightrag.utils import logger
|
||||||
from .ollama_api import (
|
from .ollama_api import (
|
||||||
OllamaAPI,
|
OllamaAPI,
|
||||||
)
|
)
|
||||||
from .ollama_api import ollama_server_infos
|
from .ollama_api import ollama_server_infos
|
||||||
|
from ..kg.postgres_impl import (
|
||||||
|
PostgreSQLDB,
|
||||||
|
PGKVStorage,
|
||||||
|
PGVectorStorage,
|
||||||
|
PGGraphStorage,
|
||||||
|
PGDocStatusStorage,
|
||||||
|
)
|
||||||
|
from ..kg.oracle_impl import (
|
||||||
|
OracleDB,
|
||||||
|
OracleKVStorage,
|
||||||
|
OracleVectorDBStorage,
|
||||||
|
OracleGraphStorage,
|
||||||
|
)
|
||||||
|
from ..kg.tidb_impl import (
|
||||||
|
TiDB,
|
||||||
|
TiDBKVStorage,
|
||||||
|
TiDBVectorDBStorage,
|
||||||
|
TiDBGraphStorage,
|
||||||
|
)
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
# Initialize config parser
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read("config.ini")
|
||||||
|
|
||||||
|
|
||||||
class RAGStorageConfig:
|
class RAGStorageConfig:
|
||||||
"""存储配置类,支持通过环境变量和命令行参数修改默认值"""
|
"""存储配置类,支持通过环境变量和命令行参数修改默认值"""
|
||||||
@@ -714,7 +739,68 @@ def create_app(args):
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for startup and shutdown events"""
|
"""Lifespan context manager for startup and shutdown events"""
|
||||||
# Startup logic
|
# Initialize database connections
|
||||||
|
postgres_db = None
|
||||||
|
oracle_db = None
|
||||||
|
tidb_db = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if PostgreSQL is needed
|
||||||
|
if any(
|
||||||
|
isinstance(
|
||||||
|
storage_instance,
|
||||||
|
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||||
|
)
|
||||||
|
for _, storage_instance in storage_instances
|
||||||
|
):
|
||||||
|
postgres_db = PostgreSQLDB(_get_postgres_config())
|
||||||
|
await postgres_db.initdb()
|
||||||
|
await postgres_db.check_tables()
|
||||||
|
for storage_name, storage_instance in storage_instances:
|
||||||
|
if isinstance(
|
||||||
|
storage_instance,
|
||||||
|
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||||
|
):
|
||||||
|
storage_instance.db = postgres_db
|
||||||
|
logger.info(f"Injected postgres_db to {storage_name}")
|
||||||
|
|
||||||
|
# Check if Oracle is needed
|
||||||
|
if any(
|
||||||
|
isinstance(
|
||||||
|
storage_instance,
|
||||||
|
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
||||||
|
)
|
||||||
|
for _, storage_instance in storage_instances
|
||||||
|
):
|
||||||
|
oracle_db = OracleDB(_get_oracle_config())
|
||||||
|
await oracle_db.check_tables()
|
||||||
|
for storage_name, storage_instance in storage_instances:
|
||||||
|
if isinstance(
|
||||||
|
storage_instance,
|
||||||
|
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
||||||
|
):
|
||||||
|
storage_instance.db = oracle_db
|
||||||
|
logger.info(f"Injected oracle_db to {storage_name}")
|
||||||
|
|
||||||
|
# Check if TiDB is needed
|
||||||
|
if any(
|
||||||
|
isinstance(
|
||||||
|
storage_instance,
|
||||||
|
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
||||||
|
)
|
||||||
|
for _, storage_instance in storage_instances
|
||||||
|
):
|
||||||
|
tidb_db = TiDB(_get_tidb_config())
|
||||||
|
await tidb_db.check_tables()
|
||||||
|
for storage_name, storage_instance in storage_instances:
|
||||||
|
if isinstance(
|
||||||
|
storage_instance,
|
||||||
|
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
||||||
|
):
|
||||||
|
storage_instance.db = tidb_db
|
||||||
|
logger.info(f"Injected tidb_db to {storage_name}")
|
||||||
|
|
||||||
|
# Auto scan documents if enabled
|
||||||
if args.auto_scan_at_startup:
|
if args.auto_scan_at_startup:
|
||||||
try:
|
try:
|
||||||
new_files = doc_manager.scan_directory_for_new_files()
|
new_files = doc_manager.scan_directory_for_new_files()
|
||||||
@@ -730,9 +816,22 @@ def create_app(args):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during startup indexing: {str(e)}")
|
logging.error(f"Error during startup indexing: {str(e)}")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
# Cleanup logic (if needed)
|
|
||||||
pass
|
finally:
|
||||||
|
# Cleanup database connections
|
||||||
|
if postgres_db and hasattr(postgres_db, "pool"):
|
||||||
|
await postgres_db.pool.close()
|
||||||
|
logger.info("Closed PostgreSQL connection pool")
|
||||||
|
|
||||||
|
if oracle_db and hasattr(oracle_db, "pool"):
|
||||||
|
await oracle_db.pool.close()
|
||||||
|
logger.info("Closed Oracle connection pool")
|
||||||
|
|
||||||
|
if tidb_db and hasattr(tidb_db, "pool"):
|
||||||
|
await tidb_db.pool.close()
|
||||||
|
logger.info("Closed TiDB connection pool")
|
||||||
|
|
||||||
# Initialize FastAPI
|
# Initialize FastAPI
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -755,6 +854,92 @@ def create_app(args):
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Database configuration functions
|
||||||
|
def _get_postgres_config():
    """Assemble the PostgreSQL connection settings.

    Each setting is read from its ``POSTGRES_*`` environment variable when
    that variable is set. Otherwise the value comes from the ``[postgres]``
    section of the module-level ``config`` parser, and finally from the
    fallback listed below.

    Returns:
        dict: the keys ``host``, ``port``, ``user``, ``password``,
        ``database`` and ``workspace``.
    """
    # (result key / config option, environment variable, fallback)
    settings = (
        ("host", "POSTGRES_HOST", "localhost"),
        ("port", "POSTGRES_PORT", 5432),
        ("user", "POSTGRES_USER", None),
        ("password", "POSTGRES_PASSWORD", None),
        ("database", "POSTGRES_DATABASE", None),
        ("workspace", "POSTGRES_WORKSPACE", "default"),
    )
    resolved = {}
    for option, env_name, fallback in settings:
        # The config-file value is always looked up first and then used as
        # the default for the environment lookup.
        from_file = config.get("postgres", option, fallback=fallback)
        resolved[option] = os.environ.get(env_name, from_file)
    return resolved
|
||||||
|
|
||||||
|
def _get_oracle_config():
    """Assemble the Oracle connection settings.

    Each setting is read from its ``ORACLE_*`` environment variable when
    that variable is set. Otherwise the value comes from the ``[oracle]``
    section of the module-level ``config`` parser, and finally from the
    fallback listed below.

    Returns:
        dict: the keys ``user``, ``password``, ``dsn``, ``config_dir``,
        ``wallet_location``, ``wallet_password`` and ``workspace``.
    """
    # (result key / config option, environment variable, fallback)
    settings = (
        ("user", "ORACLE_USER", None),
        ("password", "ORACLE_PASSWORD", None),
        ("dsn", "ORACLE_DSN", None),
        ("config_dir", "ORACLE_CONFIG_DIR", None),
        ("wallet_location", "ORACLE_WALLET_LOCATION", None),
        ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
        ("workspace", "ORACLE_WORKSPACE", "default"),
    )
    resolved = {}
    for option, env_name, fallback in settings:
        # The config-file value is always looked up first and then used as
        # the default for the environment lookup.
        from_file = config.get("oracle", option, fallback=fallback)
        resolved[option] = os.environ.get(env_name, from_file)
    return resolved
|
||||||
|
|
||||||
|
def _get_tidb_config():
    """Assemble the TiDB connection settings.

    Each setting is read from its ``TIDB_*`` environment variable when that
    variable is set. Otherwise the value comes from the ``[tidb]`` section
    of the module-level ``config`` parser, and finally from the fallback
    listed below.

    Returns:
        dict: the keys ``host``, ``port``, ``user``, ``password``,
        ``database`` and ``workspace``.
    """
    # (result key / config option, environment variable, fallback)
    settings = (
        ("host", "TIDB_HOST", "localhost"),
        ("port", "TIDB_PORT", 4000),
        ("user", "TIDB_USER", None),
        ("password", "TIDB_PASSWORD", None),
        ("database", "TIDB_DATABASE", None),
        ("workspace", "TIDB_WORKSPACE", "default"),
    )
    resolved = {}
    for option, env_name, fallback in settings:
        # The config-file value is always looked up first and then used as
        # the default for the environment lookup.
        from_file = config.get("tidb", option, fallback=fallback)
        resolved[option] = os.environ.get(env_name, from_file)
    return resolved
|
||||||
|
|
||||||
# Create the optional API key dependency
|
# Create the optional API key dependency
|
||||||
optional_api_key = get_api_key_dependency(api_key)
|
optional_api_key = get_api_key_dependency(api_key)
|
||||||
|
|
||||||
@@ -921,6 +1106,17 @@ def create_app(args):
|
|||||||
namespace_prefix=args.namespace_prefix,
|
namespace_prefix=args.namespace_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Collect all storage instances
|
||||||
|
storage_instances = [
|
||||||
|
("full_docs", rag.full_docs),
|
||||||
|
("text_chunks", rag.text_chunks),
|
||||||
|
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
|
||||||
|
("entities_vdb", rag.entities_vdb),
|
||||||
|
("relationships_vdb", rag.relationships_vdb),
|
||||||
|
("chunks_vdb", rag.chunks_vdb),
|
||||||
|
("doc_status", rag.doc_status),
|
||||||
|
]
|
||||||
|
|
||||||
async def index_file(file_path: Union[str, Path]) -> None:
|
async def index_file(file_path: Union[str, Path]) -> None:
|
||||||
"""Index all files inside the folder with support for multiple file formats
|
"""Index all files inside the folder with support for multiple file formats
|
||||||
|
|
||||||
|
@@ -355,91 +355,6 @@ class LightRAG:
|
|||||||
list[dict[str, Any]],
|
list[dict[str, Any]],
|
||||||
] = chunking_by_token_size
|
] = chunking_by_token_size
|
||||||
|
|
||||||
def _get_postgres_config(self):
    """Return the PostgreSQL connection settings as a dict.

    For every key, the ``POSTGRES_*`` environment variable wins when set.
    Otherwise the ``postgres`` section of the module-level ``config``
    object is consulted (presumably a configparser-style parser; it is
    defined outside this block), with the fallback given in the table.

    Returns:
        dict: the keys ``host``, ``port``, ``user``, ``password``,
        ``database`` and ``workspace``.
    """
    # (key, environment variable, fallback when neither source has a value)
    spec = (
        ("host", "POSTGRES_HOST", "localhost"),
        ("port", "POSTGRES_PORT", 5432),
        ("user", "POSTGRES_USER", None),
        ("password", "POSTGRES_PASSWORD", None),
        ("database", "POSTGRES_DATABASE", None),
        ("workspace", "POSTGRES_WORKSPACE", "default"),
    )
    return {
        key: os.environ.get(env_var, config.get("postgres", key, fallback=default))
        for key, env_var, default in spec
    }
|
|
||||||
|
|
||||||
def _get_oracle_config(self):
    """Return the Oracle connection settings as a dict.

    For every key, the ``ORACLE_*`` environment variable wins when set.
    Otherwise the ``oracle`` section of the module-level ``config`` object
    is consulted (presumably a configparser-style parser; it is defined
    outside this block), with the fallback given in the table.

    Returns:
        dict: the keys ``user``, ``password``, ``dsn``, ``config_dir``,
        ``wallet_location``, ``wallet_password`` and ``workspace``.
    """
    # (key, environment variable, fallback when neither source has a value)
    spec = (
        ("user", "ORACLE_USER", None),
        ("password", "ORACLE_PASSWORD", None),
        ("dsn", "ORACLE_DSN", None),
        ("config_dir", "ORACLE_CONFIG_DIR", None),
        ("wallet_location", "ORACLE_WALLET_LOCATION", None),
        ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
        ("workspace", "ORACLE_WORKSPACE", "default"),
    )
    return {
        key: os.environ.get(env_var, config.get("oracle", key, fallback=default))
        for key, env_var, default in spec
    }
|
|
||||||
|
|
||||||
def _get_tidb_config(self):
    """Return the TiDB connection settings as a dict.

    For every key, the ``TIDB_*`` environment variable wins when set.
    Otherwise the ``tidb`` section of the module-level ``config`` object is
    consulted (presumably a configparser-style parser; it is defined
    outside this block), with the fallback given in the table.

    Returns:
        dict: the keys ``host``, ``port``, ``user``, ``password``,
        ``database`` and ``workspace``.
    """
    # (key, environment variable, fallback when neither source has a value)
    spec = (
        ("host", "TIDB_HOST", "localhost"),
        ("port", "TIDB_PORT", 4000),
        ("user", "TIDB_USER", None),
        ("password", "TIDB_PASSWORD", None),
        ("database", "TIDB_DATABASE", None),
        ("workspace", "TIDB_WORKSPACE", "default"),
    )
    return {
        key: os.environ.get(env_var, config.get("tidb", key, fallback=default))
        for key, env_var, default in spec
    }
|
|
||||||
|
|
||||||
def verify_storage_implementation(
|
def verify_storage_implementation(
|
||||||
self, storage_type: str, storage_name: str
|
self, storage_type: str, storage_name: str
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -609,20 +524,6 @@ class LightRAG:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Collect all storage instances with their names
|
|
||||||
storage_instances = [
|
|
||||||
("full_docs", self.full_docs),
|
|
||||||
("text_chunks", self.text_chunks),
|
|
||||||
("chunk_entity_relation_graph", self.chunk_entity_relation_graph),
|
|
||||||
("entities_vdb", self.entities_vdb),
|
|
||||||
("relationships_vdb", self.relationships_vdb),
|
|
||||||
("chunks_vdb", self.chunks_vdb),
|
|
||||||
("doc_status", self.doc_status),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Initialize database connections if needed
|
|
||||||
loop = always_get_an_event_loop()
|
|
||||||
loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
|
|
||||||
|
|
||||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||||
partial(
|
partial(
|
||||||
@@ -646,78 +547,6 @@ class LightRAG:
|
|||||||
storage_class = lazy_external_import(import_path, storage_name)
|
storage_class = lazy_external_import(import_path, storage_name)
|
||||||
return storage_class
|
return storage_class
|
||||||
|
|
||||||
async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]):
    """Initialize database clients and inject them into the storages that need one.

    For each backend (PostgreSQL, Oracle, TiDB), a client is created only
    when at least one of the given storage instances belongs to that
    backend. The client is then assigned to the ``db`` attribute of every
    matching storage instance.

    Args:
        storage_instances: ``(name, instance)`` pairs; the name is only
            used in the log message.
    """
    # Backend modules are imported here, inside the method, rather than at
    # module level.
    from .kg.postgres_impl import (
        PostgreSQLDB,
        PGKVStorage,
        PGVectorStorage,
        PGGraphStorage,
        PGDocStatusStorage,
    )
    from .kg.oracle_impl import (
        OracleDB,
        OracleKVStorage,
        OracleVectorDBStorage,
        OracleGraphStorage,
    )
    from .kg.tidb_impl import (
        TiDB,
        TiDBKVStorage,
        TiDBVectorDBStorage,
        TiDBGraphStorage,
    )

    pg_types = (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage)
    oracle_types = (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
    tidb_types = (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)

    # PostgreSQL: the only backend that calls initdb() before check_tables().
    if any(isinstance(inst, pg_types) for _, inst in storage_instances):
        postgres_db = PostgreSQLDB(self._get_postgres_config())
        await postgres_db.initdb()
        await postgres_db.check_tables()
        for name, inst in storage_instances:
            if isinstance(inst, pg_types):
                inst.db = postgres_db
                logger.info(f"Injected postgres_db to {name}")

    # Oracle
    if any(isinstance(inst, oracle_types) for _, inst in storage_instances):
        oracle_db = OracleDB(self._get_oracle_config())
        await oracle_db.check_tables()
        for name, inst in storage_instances:
            if isinstance(inst, oracle_types):
                inst.db = oracle_db
                logger.info(f"Injected oracle_db to {name}")

    # TiDB
    if any(isinstance(inst, tidb_types) for _, inst in storage_instances):
        tidb_db = TiDB(self._get_tidb_config())
        await tidb_db.check_tables()
        # Inject the db instance
        for name, inst in storage_instances:
            if isinstance(inst, tidb_types):
                inst.db = tidb_db
                logger.info(f"Injected tidb_db to {name}")
|
|
||||||
|
|
||||||
def set_storage_client(self, db_client):
|
def set_storage_client(self, db_client):
|
||||||
# Inject db to storage implementation (only tested on Oracle Database
|
# Inject db to storage implementation (only tested on Oracle Database
|
||||||
# Deprecated, seting correct value to *_storage creating LightRAG insteaded
|
# Deprecated, seting correct value to *_storage creating LightRAG insteaded
|
||||||
|
Reference in New Issue
Block a user