Merge pull request #833 from danielaskdd/import-as-needed
Implement dynamic database module imports
This commit is contained in:
@@ -41,25 +41,35 @@ from .ollama_api import (
|
|||||||
OllamaAPI,
|
OllamaAPI,
|
||||||
)
|
)
|
||||||
from .ollama_api import ollama_server_infos
|
from .ollama_api import ollama_server_infos
|
||||||
from ..kg.postgres_impl import (
|
|
||||||
PostgreSQLDB,
|
|
||||||
PGKVStorage,
|
def get_db_type_from_storage_class(class_name: str) -> str | None:
|
||||||
PGVectorStorage,
|
"""Determine database type based on storage class name"""
|
||||||
PGGraphStorage,
|
if class_name.startswith("PG"):
|
||||||
PGDocStatusStorage,
|
return "postgres"
|
||||||
)
|
elif class_name.startswith("Oracle"):
|
||||||
from ..kg.oracle_impl import (
|
return "oracle"
|
||||||
OracleDB,
|
elif class_name.startswith("TiDB"):
|
||||||
OracleKVStorage,
|
return "tidb"
|
||||||
OracleVectorDBStorage,
|
return None
|
||||||
OracleGraphStorage,
|
|
||||||
)
|
|
||||||
from ..kg.tidb_impl import (
|
def import_db_module(db_type: str):
|
||||||
TiDB,
|
"""Dynamically import database module"""
|
||||||
TiDBKVStorage,
|
if db_type == "postgres":
|
||||||
TiDBVectorDBStorage,
|
from ..kg.postgres_impl import PostgreSQLDB
|
||||||
TiDBGraphStorage,
|
|
||||||
)
|
return PostgreSQLDB
|
||||||
|
elif db_type == "oracle":
|
||||||
|
from ..kg.oracle_impl import OracleDB
|
||||||
|
|
||||||
|
return OracleDB
|
||||||
|
elif db_type == "tidb":
|
||||||
|
from ..kg.tidb_impl import TiDB
|
||||||
|
|
||||||
|
return TiDB
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
try:
|
try:
|
||||||
@@ -333,28 +343,28 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default=get_env_value(
|
default=get_env_value(
|
||||||
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
||||||
),
|
),
|
||||||
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--doc-status-storage",
|
"--doc-status-storage",
|
||||||
default=get_env_value(
|
default=get_env_value(
|
||||||
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
||||||
),
|
),
|
||||||
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--graph-storage",
|
"--graph-storage",
|
||||||
default=get_env_value(
|
default=get_env_value(
|
||||||
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
||||||
),
|
),
|
||||||
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vector-storage",
|
"--vector-storage",
|
||||||
default=get_env_value(
|
default=get_env_value(
|
||||||
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
||||||
),
|
),
|
||||||
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bindings configuration
|
# Bindings configuration
|
||||||
@@ -788,6 +798,36 @@ class InsertResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class DocStatusResponse(BaseModel):
|
class DocStatusResponse(BaseModel):
|
||||||
|
@staticmethod
|
||||||
|
def format_datetime(dt: Any) -> Optional[str]:
|
||||||
|
"""Format datetime to ISO string
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dt: Datetime object or string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted datetime string or None
|
||||||
|
"""
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
if isinstance(dt, str):
|
||||||
|
return dt
|
||||||
|
return dt.isoformat()
|
||||||
|
|
||||||
|
"""Response model for document status
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Document identifier
|
||||||
|
content_summary: Summary of document content
|
||||||
|
content_length: Length of document content
|
||||||
|
status: Current processing status
|
||||||
|
created_at: Creation timestamp (ISO format string)
|
||||||
|
updated_at: Last update timestamp (ISO format string)
|
||||||
|
chunks_count: Number of chunks (optional)
|
||||||
|
error: Error message if any (optional)
|
||||||
|
metadata: Additional metadata (optional)
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
content_summary: str
|
content_summary: str
|
||||||
content_length: int
|
content_length: int
|
||||||
@@ -890,72 +930,51 @@ def create_app(args):
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for startup and shutdown events"""
|
"""Lifespan context manager for startup and shutdown events"""
|
||||||
# Initialize database connections
|
# Initialize database connections
|
||||||
postgres_db = None
|
db_instances = {}
|
||||||
oracle_db = None
|
|
||||||
tidb_db = None
|
|
||||||
# Store background tasks
|
# Store background tasks
|
||||||
app.state.background_tasks = set()
|
app.state.background_tasks = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if PostgreSQL is needed
|
# Check which database types are used
|
||||||
if any(
|
db_types = set()
|
||||||
isinstance(
|
|
||||||
storage_instance,
|
|
||||||
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
|
||||||
)
|
|
||||||
for _, storage_instance in storage_instances
|
|
||||||
):
|
|
||||||
postgres_db = PostgreSQLDB(_get_postgres_config())
|
|
||||||
await postgres_db.initdb()
|
|
||||||
await postgres_db.check_tables()
|
|
||||||
for storage_name, storage_instance in storage_instances:
|
for storage_name, storage_instance in storage_instances:
|
||||||
if isinstance(
|
db_type = get_db_type_from_storage_class(
|
||||||
storage_instance,
|
storage_instance.__class__.__name__
|
||||||
(
|
)
|
||||||
PGKVStorage,
|
if db_type:
|
||||||
PGVectorStorage,
|
db_types.add(db_type)
|
||||||
PGGraphStorage,
|
|
||||||
PGDocStatusStorage,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
storage_instance.db = postgres_db
|
|
||||||
logger.info(f"Injected postgres_db to {storage_name}")
|
|
||||||
|
|
||||||
# Check if Oracle is needed
|
# Import and initialize databases as needed
|
||||||
if any(
|
for db_type in db_types:
|
||||||
isinstance(
|
if db_type == "postgres":
|
||||||
storage_instance,
|
DB = import_db_module("postgres")
|
||||||
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
db = DB(_get_postgres_config())
|
||||||
)
|
await db.initdb()
|
||||||
for _, storage_instance in storage_instances
|
await db.check_tables()
|
||||||
):
|
db_instances["postgres"] = db
|
||||||
oracle_db = OracleDB(_get_oracle_config())
|
elif db_type == "oracle":
|
||||||
await oracle_db.check_tables()
|
DB = import_db_module("oracle")
|
||||||
for storage_name, storage_instance in storage_instances:
|
db = DB(_get_oracle_config())
|
||||||
if isinstance(
|
await db.check_tables()
|
||||||
storage_instance,
|
db_instances["oracle"] = db
|
||||||
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
elif db_type == "tidb":
|
||||||
):
|
DB = import_db_module("tidb")
|
||||||
storage_instance.db = oracle_db
|
db = DB(_get_tidb_config())
|
||||||
logger.info(f"Injected oracle_db to {storage_name}")
|
await db.check_tables()
|
||||||
|
db_instances["tidb"] = db
|
||||||
|
|
||||||
# Check if TiDB is needed
|
# Inject database instances into storage classes
|
||||||
if any(
|
|
||||||
isinstance(
|
|
||||||
storage_instance,
|
|
||||||
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
|
||||||
)
|
|
||||||
for _, storage_instance in storage_instances
|
|
||||||
):
|
|
||||||
tidb_db = TiDB(_get_tidb_config())
|
|
||||||
await tidb_db.check_tables()
|
|
||||||
for storage_name, storage_instance in storage_instances:
|
for storage_name, storage_instance in storage_instances:
|
||||||
if isinstance(
|
db_type = get_db_type_from_storage_class(
|
||||||
storage_instance,
|
storage_instance.__class__.__name__
|
||||||
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
)
|
||||||
):
|
if db_type:
|
||||||
storage_instance.db = tidb_db
|
if db_type not in db_instances:
|
||||||
logger.info(f"Injected tidb_db to {storage_name}")
|
error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg)
|
||||||
|
storage_instance.db = db_instances[db_type]
|
||||||
|
logger.info(f"Injected {db_type} db to {storage_name}")
|
||||||
|
|
||||||
# Auto scan documents if enabled
|
# Auto scan documents if enabled
|
||||||
if args.auto_scan_at_startup:
|
if args.auto_scan_at_startup:
|
||||||
@@ -981,17 +1000,17 @@ def create_app(args):
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up database connections
|
# Clean up database connections
|
||||||
if postgres_db and hasattr(postgres_db, "pool"):
|
for db_type, db in db_instances.items():
|
||||||
await postgres_db.pool.close()
|
if hasattr(db, "pool"):
|
||||||
logger.info("Closed PostgreSQL connection pool")
|
await db.pool.close()
|
||||||
|
# Use more accurate database name display
|
||||||
if oracle_db and hasattr(oracle_db, "pool"):
|
db_names = {
|
||||||
await oracle_db.pool.close()
|
"postgres": "PostgreSQL",
|
||||||
logger.info("Closed Oracle connection pool")
|
"oracle": "Oracle",
|
||||||
|
"tidb": "TiDB",
|
||||||
if tidb_db and hasattr(tidb_db, "pool"):
|
}
|
||||||
await tidb_db.pool.close()
|
db_name = db_names.get(db_type, db_type)
|
||||||
logger.info("Closed TiDB connection pool")
|
logger.info(f"Closed {db_name} database connection pool")
|
||||||
|
|
||||||
# Initialize FastAPI
|
# Initialize FastAPI
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -1311,7 +1330,7 @@ def create_app(args):
|
|||||||
case ".pdf":
|
case ".pdf":
|
||||||
if not pm.is_installed("pypdf2"):
|
if not pm.is_installed("pypdf2"):
|
||||||
pm.install("pypdf2")
|
pm.install("pypdf2")
|
||||||
from PyPDF2 import PdfReader
|
from PyPDF2 import PdfReader # type: ignore
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
pdf_file = BytesIO(file)
|
pdf_file = BytesIO(file)
|
||||||
@@ -1869,8 +1888,12 @@ def create_app(args):
|
|||||||
content_summary=doc_status.content_summary,
|
content_summary=doc_status.content_summary,
|
||||||
content_length=doc_status.content_length,
|
content_length=doc_status.content_length,
|
||||||
status=doc_status.status,
|
status=doc_status.status,
|
||||||
created_at=doc_status.created_at,
|
created_at=DocStatusResponse.format_datetime(
|
||||||
updated_at=doc_status.updated_at,
|
doc_status.created_at
|
||||||
|
),
|
||||||
|
updated_at=DocStatusResponse.format_datetime(
|
||||||
|
doc_status.updated_at
|
||||||
|
),
|
||||||
chunks_count=doc_status.chunks_count,
|
chunks_count=doc_status.chunks_count,
|
||||||
error=doc_status.error,
|
error=doc_status.error,
|
||||||
metadata=doc_status.metadata,
|
metadata=doc_status.metadata,
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
ascii_colors
|
ascii_colors
|
||||||
fastapi
|
fastapi
|
||||||
nest_asyncio
|
|
||||||
numpy
|
numpy
|
||||||
pipmaster
|
pipmaster
|
||||||
python-dotenv
|
python-dotenv
|
||||||
|
Reference in New Issue
Block a user