refactor database connection management and improve storage lifecycle handling
update
This commit is contained in:
@@ -15,11 +15,6 @@ import logging
|
||||
import argparse
|
||||
from typing import List, Any, Literal, Optional, Dict
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
from lightrag.api import __api_version__
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import aiofiles
|
||||
@@ -36,39 +31,13 @@ import configparser
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
from lightrag.api import __api_version__
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag.utils import logger
|
||||
from .ollama_api import (
|
||||
OllamaAPI,
|
||||
)
|
||||
from .ollama_api import ollama_server_infos
|
||||
|
||||
|
||||
def get_db_type_from_storage_class(class_name: str) -> str | None:
|
||||
"""Determine database type based on storage class name"""
|
||||
if class_name.startswith("PG"):
|
||||
return "postgres"
|
||||
elif class_name.startswith("Oracle"):
|
||||
return "oracle"
|
||||
elif class_name.startswith("TiDB"):
|
||||
return "tidb"
|
||||
return None
|
||||
|
||||
|
||||
def import_db_module(db_type: str):
    """Import the client class for a database type on demand.

    The backend module is imported only when its type is requested, so
    unused backends are never loaded by this function.

    Returns ``PostgreSQLDB`` for ``"postgres"``, ``OracleDB`` for
    ``"oracle"``, ``TiDB`` for ``"tidb"``, and ``None`` for any other value.
    """
    db_class = None
    if db_type == "postgres":
        from ..kg.postgres_impl import PostgreSQLDB

        db_class = PostgreSQLDB
    elif db_type == "oracle":
        from ..kg.oracle_impl import OracleDB

        db_class = OracleDB
    elif db_type == "tidb":
        from ..kg.tidb_impl import TiDB

        db_class = TiDB
    return db_class
|
||||
from .ollama_api import OllamaAPI, ollama_server_infos
|
||||
|
||||
|
||||
# Load environment variables
|
||||
@@ -929,52 +898,12 @@ def create_app(args):
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events"""
|
||||
# Initialize database connections
|
||||
db_instances = {}
|
||||
# Store background tasks
|
||||
app.state.background_tasks = set()
|
||||
|
||||
try:
|
||||
# Check which database types are used
|
||||
db_types = set()
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
db_type = get_db_type_from_storage_class(
|
||||
storage_instance.__class__.__name__
|
||||
)
|
||||
if db_type:
|
||||
db_types.add(db_type)
|
||||
|
||||
# Import and initialize databases as needed
|
||||
for db_type in db_types:
|
||||
if db_type == "postgres":
|
||||
DB = import_db_module("postgres")
|
||||
db = DB(_get_postgres_config())
|
||||
await db.initdb()
|
||||
await db.check_tables()
|
||||
db_instances["postgres"] = db
|
||||
elif db_type == "oracle":
|
||||
DB = import_db_module("oracle")
|
||||
db = DB(_get_oracle_config())
|
||||
await db.check_tables()
|
||||
db_instances["oracle"] = db
|
||||
elif db_type == "tidb":
|
||||
DB = import_db_module("tidb")
|
||||
db = DB(_get_tidb_config())
|
||||
await db.check_tables()
|
||||
db_instances["tidb"] = db
|
||||
|
||||
# Inject database instances into storage classes
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
db_type = get_db_type_from_storage_class(
|
||||
storage_instance.__class__.__name__
|
||||
)
|
||||
if db_type:
|
||||
if db_type not in db_instances:
|
||||
error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
storage_instance.db = db_instances[db_type]
|
||||
logger.info(f"Injected {db_type} db to {storage_name}")
|
||||
# Initialize database connections
|
||||
await rag.initialize_storages()
|
||||
|
||||
# Auto scan documents if enabled
|
||||
if args.auto_scan_at_startup:
|
||||
@@ -1000,17 +929,7 @@ def create_app(args):
|
||||
|
||||
finally:
|
||||
# Clean up database connections
|
||||
for db_type, db in db_instances.items():
|
||||
if hasattr(db, "pool"):
|
||||
await db.pool.close()
|
||||
# Use more accurate database name display
|
||||
db_names = {
|
||||
"postgres": "PostgreSQL",
|
||||
"oracle": "Oracle",
|
||||
"tidb": "TiDB",
|
||||
}
|
||||
db_name = db_names.get(db_type, db_type)
|
||||
logger.info(f"Closed {db_name} database connection pool")
|
||||
await rag.finalize_storages()
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
@@ -1042,92 +961,6 @@ def create_app(args):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Database configuration functions
|
||||
def _get_postgres_config():
    """Build the PostgreSQL connection settings dictionary.

    For each setting the environment variable wins; otherwise the value
    comes from the ``[postgres]`` section of ``config``, and finally from
    the built-in default (``None`` where no default is given).
    """

    def _setting(env_var, option, default=None):
        # The config-file lookup is evaluated first and used as the
        # environment lookup's default.
        return os.environ.get(
            env_var, config.get("postgres", option, fallback=default)
        )

    return {
        "host": _setting("POSTGRES_HOST", "host", "localhost"),
        # NOTE(review): the port is a str when read from the environment but
        # the int 5432 when the default applies; confirm the consumer accepts
        # both.
        "port": _setting("POSTGRES_PORT", "port", 5432),
        "user": _setting("POSTGRES_USER", "user"),
        "password": _setting("POSTGRES_PASSWORD", "password"),
        "database": _setting("POSTGRES_DATABASE", "database"),
        "workspace": _setting("POSTGRES_WORKSPACE", "workspace", "default"),
    }
|
||||
|
||||
def _get_oracle_config():
    """Build the Oracle connection settings dictionary.

    For each setting the environment variable wins; otherwise the value
    comes from the ``[oracle]`` section of ``config``, and finally from the
    listed default.
    """
    # (result key and config option, environment variable, default)
    settings = (
        ("user", "ORACLE_USER", None),
        ("password", "ORACLE_PASSWORD", None),
        ("dsn", "ORACLE_DSN", None),
        ("config_dir", "ORACLE_CONFIG_DIR", None),
        ("wallet_location", "ORACLE_WALLET_LOCATION", None),
        ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
        ("workspace", "ORACLE_WORKSPACE", "default"),
    )
    return {
        option: os.environ.get(
            env_var, config.get("oracle", option, fallback=default)
        )
        for option, env_var, default in settings
    }
|
||||
|
||||
def _get_tidb_config():
    """Build the TiDB connection settings dictionary.

    For each setting the environment variable wins; otherwise the value
    comes from the ``[tidb]`` section of ``config``, and finally from the
    listed default.
    """
    # (result key and config option, environment variable, default)
    settings = (
        ("host", "TIDB_HOST", "localhost"),
        # NOTE(review): the port is a str when read from the environment but
        # the int 4000 when the default applies; confirm the consumer accepts
        # both.
        ("port", "TIDB_PORT", 4000),
        ("user", "TIDB_USER", None),
        ("password", "TIDB_PASSWORD", None),
        ("database", "TIDB_DATABASE", None),
        ("workspace", "TIDB_WORKSPACE", "default"),
    )
    return {
        option: os.environ.get(
            env_var, config.get("tidb", option, fallback=default)
        )
        for option, env_var, default in settings
    }
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
@@ -1262,6 +1095,7 @@ def create_app(args):
|
||||
},
|
||||
log_level=args.log_level,
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
is_managed_by_server=True,
|
||||
)
|
||||
else:
|
||||
rag = LightRAG(
|
||||
@@ -1293,20 +1127,9 @@ def create_app(args):
|
||||
},
|
||||
log_level=args.log_level,
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
is_managed_by_server=True,
|
||||
)
|
||||
|
||||
# Collect all storage instances
|
||||
storage_instances = [
|
||||
("full_docs", rag.full_docs),
|
||||
("text_chunks", rag.text_chunks),
|
||||
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
|
||||
("entities_vdb", rag.entities_vdb),
|
||||
("relationships_vdb", rag.relationships_vdb),
|
||||
("chunks_vdb", rag.chunks_vdb),
|
||||
("doc_status", rag.doc_status),
|
||||
("llm_response_cache", rag.llm_response_cache),
|
||||
]
|
||||
|
||||
async def pipeline_enqueue_file(file_path: Path) -> bool:
|
||||
"""Add a file to the queue for processing
|
||||
|
||||
|
Reference in New Issue
Block a user