Merge pull request #833 from danielaskdd/import-as-needed

Implement dynamic database module imports
This commit is contained in:
Yannick Stephan
2025-02-18 15:13:33 +01:00
committed by GitHub
2 changed files with 119 additions and 97 deletions

View File

@@ -41,25 +41,35 @@ from .ollama_api import (
OllamaAPI,
)
from .ollama_api import ollama_server_infos
from ..kg.postgres_impl import (
PostgreSQLDB,
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from ..kg.oracle_impl import (
OracleDB,
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from ..kg.tidb_impl import (
TiDB,
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage,
)
def get_db_type_from_storage_class(class_name: str) -> str | None:
"""Determine database type based on storage class name"""
if class_name.startswith("PG"):
return "postgres"
elif class_name.startswith("Oracle"):
return "oracle"
elif class_name.startswith("TiDB"):
return "tidb"
return None
def import_db_module(db_type: str):
"""Dynamically import database module"""
if db_type == "postgres":
from ..kg.postgres_impl import PostgreSQLDB
return PostgreSQLDB
elif db_type == "oracle":
from ..kg.oracle_impl import OracleDB
return OracleDB
elif db_type == "tidb":
from ..kg.tidb_impl import TiDB
return TiDB
return None
# Load environment variables
try:
@@ -333,28 +343,28 @@ def parse_args() -> argparse.Namespace:
default=get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
),
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
)
parser.add_argument(
"--doc-status-storage",
default=get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
),
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
)
parser.add_argument(
"--graph-storage",
default=get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
),
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
)
parser.add_argument(
"--vector-storage",
default=get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
),
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
)
# Bindings configuration
@@ -788,6 +798,36 @@ class InsertResponse(BaseModel):
class DocStatusResponse(BaseModel):
@staticmethod
def format_datetime(dt: Any) -> Optional[str]:
"""Format datetime to ISO string
Args:
dt: Datetime object or string
Returns:
Formatted datetime string or None
"""
if dt is None:
return None
if isinstance(dt, str):
return dt
return dt.isoformat()
"""Response model for document status
Attributes:
id: Document identifier
content_summary: Summary of document content
content_length: Length of document content
status: Current processing status
created_at: Creation timestamp (ISO format string)
updated_at: Last update timestamp (ISO format string)
chunks_count: Number of chunks (optional)
error: Error message if any (optional)
metadata: Additional metadata (optional)
"""
id: str
content_summary: str
content_length: int
@@ -890,72 +930,51 @@ def create_app(args):
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Initialize database connections
postgres_db = None
oracle_db = None
tidb_db = None
db_instances = {}
# Store background tasks
app.state.background_tasks = set()
try:
# Check if PostgreSQL is needed
if any(
isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
)
for _, storage_instance in storage_instances
):
postgres_db = PostgreSQLDB(_get_postgres_config())
await postgres_db.initdb()
await postgres_db.check_tables()
# Check which database types are used
db_types = set()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
),
):
storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}")
db_type = get_db_type_from_storage_class(
storage_instance.__class__.__name__
)
if db_type:
db_types.add(db_type)
# Check if Oracle is needed
if any(
isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
)
for _, storage_instance in storage_instances
):
oracle_db = OracleDB(_get_oracle_config())
await oracle_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
):
storage_instance.db = oracle_db
logger.info(f"Injected oracle_db to {storage_name}")
# Import and initialize databases as needed
for db_type in db_types:
if db_type == "postgres":
DB = import_db_module("postgres")
db = DB(_get_postgres_config())
await db.initdb()
await db.check_tables()
db_instances["postgres"] = db
elif db_type == "oracle":
DB = import_db_module("oracle")
db = DB(_get_oracle_config())
await db.check_tables()
db_instances["oracle"] = db
elif db_type == "tidb":
DB = import_db_module("tidb")
db = DB(_get_tidb_config())
await db.check_tables()
db_instances["tidb"] = db
# Check if TiDB is needed
if any(
isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
)
for _, storage_instance in storage_instances
):
tidb_db = TiDB(_get_tidb_config())
await tidb_db.check_tables()
# Inject database instances into storage classes
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
):
storage_instance.db = tidb_db
logger.info(f"Injected tidb_db to {storage_name}")
db_type = get_db_type_from_storage_class(
storage_instance.__class__.__name__
)
if db_type:
if db_type not in db_instances:
error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
logger.error(error_msg)
raise RuntimeError(error_msg)
storage_instance.db = db_instances[db_type]
logger.info(f"Injected {db_type} db to {storage_name}")
# Auto scan documents if enabled
if args.auto_scan_at_startup:
@@ -981,17 +1000,17 @@ def create_app(args):
finally:
# Clean up database connections
if postgres_db and hasattr(postgres_db, "pool"):
await postgres_db.pool.close()
logger.info("Closed PostgreSQL connection pool")
if oracle_db and hasattr(oracle_db, "pool"):
await oracle_db.pool.close()
logger.info("Closed Oracle connection pool")
if tidb_db and hasattr(tidb_db, "pool"):
await tidb_db.pool.close()
logger.info("Closed TiDB connection pool")
for db_type, db in db_instances.items():
if hasattr(db, "pool"):
await db.pool.close()
# Use more accurate database name display
db_names = {
"postgres": "PostgreSQL",
"oracle": "Oracle",
"tidb": "TiDB",
}
db_name = db_names.get(db_type, db_type)
logger.info(f"Closed {db_name} database connection pool")
# Initialize FastAPI
app = FastAPI(
@@ -1311,7 +1330,7 @@ def create_app(args):
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file)
@@ -1869,8 +1888,12 @@ def create_app(args):
content_summary=doc_status.content_summary,
content_length=doc_status.content_length,
status=doc_status.status,
created_at=doc_status.created_at,
updated_at=doc_status.updated_at,
created_at=DocStatusResponse.format_datetime(
doc_status.created_at
),
updated_at=DocStatusResponse.format_datetime(
doc_status.updated_at
),
chunks_count=doc_status.chunks_count,
error=doc_status.error,
metadata=doc_status.metadata,

View File

@@ -1,6 +1,5 @@
ascii_colors
fastapi
nest_asyncio
numpy
pipmaster
python-dotenv