refactor: move database connection pool initialization to lifespan of FastAPI
- Add proper database connection lifecycle management - Add connection pool cleanup in FastAPI lifespan
This commit is contained in:
@@ -355,91 +355,6 @@ class LightRAG:
|
||||
list[dict[str, Any]],
|
||||
] = chunking_by_token_size
|
||||
|
||||
def _get_postgres_config(self):
|
||||
return {
|
||||
"host": os.environ.get(
|
||||
"POSTGRES_HOST",
|
||||
config.get("postgres", "host", fallback="localhost"),
|
||||
),
|
||||
"port": os.environ.get(
|
||||
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
|
||||
),
|
||||
"user": os.environ.get(
|
||||
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"POSTGRES_PASSWORD",
|
||||
config.get("postgres", "password", fallback=None),
|
||||
),
|
||||
"database": os.environ.get(
|
||||
"POSTGRES_DATABASE",
|
||||
config.get("postgres", "database", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"POSTGRES_WORKSPACE",
|
||||
config.get("postgres", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
def _get_oracle_config(self):
|
||||
return {
|
||||
"user": os.environ.get(
|
||||
"ORACLE_USER",
|
||||
config.get("oracle", "user", fallback=None),
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"ORACLE_PASSWORD",
|
||||
config.get("oracle", "password", fallback=None),
|
||||
),
|
||||
"dsn": os.environ.get(
|
||||
"ORACLE_DSN",
|
||||
config.get("oracle", "dsn", fallback=None),
|
||||
),
|
||||
"config_dir": os.environ.get(
|
||||
"ORACLE_CONFIG_DIR",
|
||||
config.get("oracle", "config_dir", fallback=None),
|
||||
),
|
||||
"wallet_location": os.environ.get(
|
||||
"ORACLE_WALLET_LOCATION",
|
||||
config.get("oracle", "wallet_location", fallback=None),
|
||||
),
|
||||
"wallet_password": os.environ.get(
|
||||
"ORACLE_WALLET_PASSWORD",
|
||||
config.get("oracle", "wallet_password", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"ORACLE_WORKSPACE",
|
||||
config.get("oracle", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
def _get_tidb_config(self):
|
||||
return {
|
||||
"host": os.environ.get(
|
||||
"TIDB_HOST",
|
||||
config.get("tidb", "host", fallback="localhost"),
|
||||
),
|
||||
"port": os.environ.get(
|
||||
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
|
||||
),
|
||||
"user": os.environ.get(
|
||||
"TIDB_USER",
|
||||
config.get("tidb", "user", fallback=None),
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"TIDB_PASSWORD",
|
||||
config.get("tidb", "password", fallback=None),
|
||||
),
|
||||
"database": os.environ.get(
|
||||
"TIDB_DATABASE",
|
||||
config.get("tidb", "database", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"TIDB_WORKSPACE",
|
||||
config.get("tidb", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
def verify_storage_implementation(
|
||||
self, storage_type: str, storage_name: str
|
||||
) -> None:
|
||||
@@ -609,20 +524,6 @@ class LightRAG:
|
||||
)
|
||||
|
||||
|
||||
# Collect all storage instances with their names
|
||||
storage_instances = [
|
||||
("full_docs", self.full_docs),
|
||||
("text_chunks", self.text_chunks),
|
||||
("chunk_entity_relation_graph", self.chunk_entity_relation_graph),
|
||||
("entities_vdb", self.entities_vdb),
|
||||
("relationships_vdb", self.relationships_vdb),
|
||||
("chunks_vdb", self.chunks_vdb),
|
||||
("doc_status", self.doc_status),
|
||||
]
|
||||
|
||||
# Initialize database connections if needed
|
||||
loop = always_get_an_event_loop()
|
||||
loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
|
||||
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
partial(
|
||||
@@ -646,78 +547,6 @@ class LightRAG:
|
||||
storage_class = lazy_external_import(import_path, storage_name)
|
||||
return storage_class
|
||||
|
||||
async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]):
|
||||
"""Intialize database connection and inject it to storage implementation if needed"""
|
||||
from .kg.postgres_impl import PostgreSQLDB
|
||||
from .kg.oracle_impl import OracleDB
|
||||
from .kg.tidb_impl import TiDB
|
||||
from .kg.postgres_impl import (
|
||||
PGKVStorage,
|
||||
PGVectorStorage,
|
||||
PGGraphStorage,
|
||||
PGDocStatusStorage,
|
||||
)
|
||||
from .kg.oracle_impl import (
|
||||
OracleKVStorage,
|
||||
OracleVectorDBStorage,
|
||||
OracleGraphStorage,
|
||||
)
|
||||
from .kg.tidb_impl import (
|
||||
TiDBKVStorage,
|
||||
TiDBVectorDBStorage,
|
||||
TiDBGraphStorage)
|
||||
|
||||
# Checking if PostgreSQL is needed
|
||||
if any(
|
||||
isinstance(
|
||||
storage_instance,
|
||||
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||
)
|
||||
for _, storage_instance in storage_instances
|
||||
):
|
||||
postgres_db = PostgreSQLDB(self._get_postgres_config())
|
||||
await postgres_db.initdb()
|
||||
await postgres_db.check_tables()
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
if isinstance(
|
||||
storage_instance,
|
||||
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||
):
|
||||
storage_instance.db = postgres_db
|
||||
logger.info(f"Injected postgres_db to {storage_name}")
|
||||
|
||||
# Checking if Oracle is needed
|
||||
if any(
|
||||
isinstance(
|
||||
storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
|
||||
)
|
||||
for _, storage_instance in storage_instances
|
||||
):
|
||||
oracle_db = OracleDB(self._get_oracle_config())
|
||||
await oracle_db.check_tables()
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
if isinstance(
|
||||
storage_instance,
|
||||
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
||||
):
|
||||
storage_instance.db = oracle_db
|
||||
logger.info(f"Injected oracle_db to {storage_name}")
|
||||
|
||||
# Checking if TiDB is needed
|
||||
if any(
|
||||
isinstance(storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage))
|
||||
for _, storage_instance in storage_instances
|
||||
):
|
||||
tidb_db = TiDB(self._get_tidb_config())
|
||||
await tidb_db.check_tables()
|
||||
# 注入db实例
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
if isinstance(
|
||||
storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)
|
||||
):
|
||||
storage_instance.db = tidb_db
|
||||
logger.info(f"Injected tidb_db to {storage_name}")
|
||||
|
||||
def set_storage_client(self, db_client):
|
||||
# Inject db to storage implementation (only tested on Oracle Database
|
||||
# Deprecated, seting correct value to *_storage creating LightRAG insteaded
|
||||
|
Reference in New Issue
Block a user