Refactor authentication logic and Swagger UI config

- Consolidate authentication dependencies
- Improve Swagger UI security parameters
This commit is contained in:
yangdx
2025-03-24 14:29:36 +08:00
parent 79bf26dfeb
commit 9e3994419e
4 changed files with 99 additions and 64 deletions

View File

@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_api_key_dependency,
get_combined_auth_dependency,
parse_args,
get_default_host,
display_splash_screen,
@@ -135,19 +136,28 @@ def create_app(args):
await rag.finalize_storages()
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
app_kwargs = {
"title": "LightRAG Server API",
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
+ "(With authentication)"
if api_key
else "",
version=__api_version__,
openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL
docs_url="/docs", # Explicitly set docs URL
redoc_url="/redoc", # Explicitly set redoc URL
openapi_tags=[{"name": "api"}],
lifespan=lifespan,
)
"version": __api_version__,
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
"docs_url": "/docs", # Explicitly set docs URL
"redoc_url": "/redoc", # Explicitly set redoc URL
"openapi_tags": [{"name": "api"}],
"lifespan": lifespan,
}
# Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True,
"tryItOutEnabled": True,
}
app = FastAPI(**app_kwargs)
def get_cors_origins():
"""Get allowed origins from environment variable
@@ -167,13 +177,8 @@ def create_app(args):
allow_headers=["*"],
)
# Create the optional API key dependency
# Create a dependency that passes the request to get_api_key_dependency
async def optional_api_key_dependency(request: Request):
# Create the dependency function with the request
api_key_dependency = get_api_key_dependency(api_key)
# Call the dependency function with the request
return await api_key_dependency(request)
# Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -419,7 +424,7 @@ def create_app(args):
"api_version": __api_version__,
}
@app.get("/health", dependencies=[Depends(optional_api_key_dependency)])
@app.get("/health", dependencies=[Depends(combined_auth)])
async def get_status():
"""Get current system status"""
username = os.getenv("AUTH_USERNAME")