Refactor authentication logic and Swagger UI config

- Consolidate authentication dependencies
- Improve Swagger UI security parameters
This commit is contained in:
yangdx
2025-03-24 14:29:36 +08:00
parent 79bf26dfeb
commit 9e3994419e
4 changed files with 99 additions and 64 deletions

View File

@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
from dotenv import load_dotenv from dotenv import load_dotenv
from lightrag.api.utils_api import ( from lightrag.api.utils_api import (
get_api_key_dependency, get_api_key_dependency,
get_combined_auth_dependency,
parse_args, parse_args,
get_default_host, get_default_host,
display_splash_screen, display_splash_screen,
@@ -135,19 +136,28 @@ def create_app(args):
await rag.finalize_storages() await rag.finalize_storages()
# Initialize FastAPI # Initialize FastAPI
app = FastAPI( app_kwargs = {
title="LightRAG API", "title": "LightRAG Server API",
description="API for querying text using LightRAG with separate storage and input directories" "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
+ "(With authentication)" + "(With authentication)"
if api_key if api_key
else "", else "",
version=__api_version__, "version": __api_version__,
openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
docs_url="/docs", # Explicitly set docs URL "docs_url": "/docs", # Explicitly set docs URL
redoc_url="/redoc", # Explicitly set redoc URL "redoc_url": "/redoc", # Explicitly set redoc URL
openapi_tags=[{"name": "api"}], "openapi_tags": [{"name": "api"}],
lifespan=lifespan, "lifespan": lifespan,
) }
# Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True,
"tryItOutEnabled": True,
}
app = FastAPI(**app_kwargs)
def get_cors_origins(): def get_cors_origins():
"""Get allowed origins from environment variable """Get allowed origins from environment variable
@@ -167,13 +177,8 @@ def create_app(args):
allow_headers=["*"], allow_headers=["*"],
) )
# Create the optional API key dependency # Create combined auth dependency for all endpoints
# Create a dependency that passes the request to get_api_key_dependency combined_auth = get_combined_auth_dependency(api_key)
async def optional_api_key_dependency(request: Request):
# Create the dependency function with the request
api_key_dependency = get_api_key_dependency(api_key)
# Call the dependency function with the request
return await api_key_dependency(request)
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True) Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -419,7 +424,7 @@ def create_app(args):
"api_version": __api_version__, "api_version": __api_version__,
} }
@app.get("/health", dependencies=[Depends(optional_api_key_dependency)]) @app.get("/health", dependencies=[Depends(combined_auth)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
username = os.getenv("AUTH_USERNAME") username = os.getenv("AUTH_USERNAME")

View File

@@ -505,6 +505,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
def create_document_routes( def create_document_routes(
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
): ):
# Create combined auth dependency for document routes
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post("/scan", dependencies=[Depends(combined_auth)]) @router.post("/scan", dependencies=[Depends(combined_auth)])

View File

@@ -11,7 +11,7 @@ import asyncio
from ascii_colors import trace_exception from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken from lightrag.utils import encode_string_by_tiktoken
from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
from fastapi import Depends from fastapi import Depends
@@ -132,21 +132,17 @@ class OllamaAPI:
self.setup_routes() self.setup_routes()
def setup_routes(self): def setup_routes(self):
# Create a dependency that passes the request to get_api_key_dependency # Create combined auth dependency for Ollama API routes
async def optional_api_key_dependency(request: Request): combined_auth = get_combined_auth_dependency(self.api_key)
# Create the dependency function with the request
api_key_dependency = get_api_key_dependency(self.api_key)
# Call the dependency function with the request
return await api_key_dependency(request)
@self.router.get( @self.router.get(
"/version", dependencies=[Depends(optional_api_key_dependency)] "/version", dependencies=[Depends(combined_auth)]
) )
async def get_version(): async def get_version():
"""Get Ollama version information""" """Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4") return OllamaVersionResponse(version="0.5.4")
@self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)]) @self.router.get("/tags", dependencies=[Depends(combined_auth)])
async def get_tags(): async def get_tags():
"""Return available models acting as an Ollama server""" """Return available models acting as an Ollama server"""
return OllamaTagResponse( return OllamaTagResponse(
@@ -170,7 +166,7 @@ class OllamaAPI:
) )
@self.router.post( @self.router.post(
"/generate", dependencies=[Depends(optional_api_key_dependency)] "/generate", dependencies=[Depends(combined_auth)]
) )
async def generate(raw_request: Request, request: OllamaGenerateRequest): async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model """Handle generate completion requests acting as an Ollama model
@@ -337,7 +333,7 @@ class OllamaAPI:
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)]) @self.router.post("/chat", dependencies=[Depends(combined_auth)])
async def chat(raw_request: Request, request: OllamaChatRequest): async def chat(raw_request: Request, request: OllamaChatRequest):
"""Process chat completion requests acting as an Ollama model """Process chat completion requests acting as an Ollama model
Routes user queries through LightRAG by selecting query mode based on prefix indicators. Routes user queries through LightRAG by selecting query mode based on prefix indicators.

View File

@@ -58,29 +58,43 @@ ollama_server_infos = OllamaServerInfos()
def get_combined_auth_dependency(api_key: Optional[str] = None): def get_combined_auth_dependency(api_key: Optional[str] = None):
""" """
Create a combined authentication dependency that implements OR logic (pass through any authentication method) Create a combined authentication dependency that implements authentication logic
based on API key, OAuth2 token, and whitelist paths.
Args: Args:
api_key (Optional[str]): API key for validation api_key (Optional[str]): API key for validation
Returns: Returns:
Callable: A dependency function that implements OR authentication logic Callable: A dependency function that implements the authentication logic
""" """
# Use global whitelist_patterns and auth_configured variables # Use global whitelist_patterns and auth_configured variables
# whitelist_patterns and auth_configured are already initialized at module level # whitelist_patterns and auth_configured are already initialized at module level
# Only calculate api_key_configured as it depends on the function parameter # Only calculate api_key_configured as it depends on the function parameter
api_key_configured = bool(api_key) api_key_configured = bool(api_key)
# Create security dependencies with proper descriptions for Swagger UI
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="login",
auto_error=False,
description="OAuth2 Password Authentication"
)
# If API key is configured, create an API key header security
api_key_header = None
if api_key_configured:
api_key_header = APIKeyHeader(
name="X-API-Key",
auto_error=False,
description="API Key Authentication"
)
async def combined_dependency( async def combined_dependency(
request: Request, request: Request,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)), token: str = Security(oauth2_scheme),
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
): ):
# If both authentication methods are not configured, allow access # 1. Check if path is in whitelist
if not auth_configured and not api_key_configured:
return
# Check if request path is in whitelist
path = request.url.path path = request.url.path
for pattern, is_prefix in whitelist_patterns: for pattern, is_prefix in whitelist_patterns:
if (is_prefix and path.startswith(pattern)) or ( if (is_prefix and path.startswith(pattern)) or (
@@ -88,35 +102,54 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
): ):
return # Whitelist path, allow access return # Whitelist path, allow access
# Access with token # 2. Check for special endpoints (/health and Ollama API)
is_special_endpoint = path == "/health" or path.startswith("/api/")
if is_special_endpoint and not api_key_configured:
return # Special endpoint and no API key configured, allow access
# 3. Validate API key
if api_key_configured and api_key_header_value and api_key_header_value == api_key:
return # API key validation successful
# 4. Validate token
if token: if token:
token_info = auth_handler.validate_token(token) try:
if auth_configured: token_info = auth_handler.validate_token(token)
if token_info.get("role") != "guest" or not api_key_configured: # Accept guest token if no auth is configured
return # Password authentication successful if not auth_configured and token_info.get("role") == "guest":
else: return
if token_info.get("role") == "guest": # Accept non-guest token if auth is configured
return # Guest authentication successful if auth_configured and token_info.get("role") != "guest":
return
# Token validation failed, immediately return 401 error
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token. Please login again."
)
except HTTPException as e:
# If already a 401 error, re-raise it
if e.status_code == status.HTTP_401_UNAUTHORIZED:
raise
# For other exceptions, continue processing
# If token exists but validation failed (didn't return above), return 401
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required" status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token. Please login again."
) )
# Try API key authentication (if configured) # 5. No token and API key validation failed, return 403 error
if api_key_configured: if api_key_configured:
api_key_header = request.headers.get("X-API-Key") raise HTTPException(
if api_key_header and api_key_header == api_key: status_code=HTTP_403_FORBIDDEN,
return # API key authentication successful detail="API Key required or login authentication required."
else: )
if auth_configured: else:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_403_FORBIDDEN,
detail="API Key required or use password authentication.", detail="Login authentication required."
) )
else:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="API Key required or use guest authentication.",
)
return combined_dependency return combined_dependency
@@ -145,12 +178,12 @@ def get_api_key_dependency(api_key: Optional[str]):
return no_auth return no_auth
# If API key is configured, use proper authentication # If API key is configured, use proper authentication with Security for Swagger UI
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth( async def api_key_auth(
request: Request, request: Request,
api_key_header_value: Optional[str] = Security(api_key_header), api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
): ):
# Check if request path is in whitelist # Check if request path is in whitelist
path = request.url.path path = request.url.path