From 9e3994419e2b5624f577389ea44c52fc18fe4b00 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 24 Mar 2025 14:29:36 +0800 Subject: [PATCH] Refactor authentication logic and Swagger UI config - Consolidate authentication dependencies - Improve Swagger UI security parameters --- lightrag/api/lightrag_server.py | 41 +++++----- lightrag/api/routers/document_routes.py | 1 + lightrag/api/routers/ollama_api.py | 18 ++--- lightrag/api/utils_api.py | 103 ++++++++++++++++-------- 4 files changed, 99 insertions(+), 64 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 503fb3a8..ccc4a93a 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -19,6 +19,7 @@ from contextlib import asynccontextmanager from dotenv import load_dotenv from lightrag.api.utils_api import ( get_api_key_dependency, + get_combined_auth_dependency, parse_args, get_default_host, display_splash_screen, @@ -135,19 +136,28 @@ def create_app(args): await rag.finalize_storages() # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" + app_kwargs = { + "title": "LightRAG Server API", + "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)" if api_key else "", - version=__api_version__, - openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL - docs_url="/docs", # Explicitly set docs URL - redoc_url="/redoc", # Explicitly set redoc URL - openapi_tags=[{"name": "api"}], - lifespan=lifespan, - ) + "version": __api_version__, + "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL + "docs_url": "/docs", # Explicitly set docs URL + "redoc_url": "/redoc", # Explicitly set redoc URL + "openapi_tags": [{"name": "api"}], + "lifespan": lifespan, + } + + # Configure Swagger UI parameters + # Enable persistAuthorization and tryItOutEnabled for better user experience + app_kwargs["swagger_ui_parameters"] = { + "persistAuthorization": True, + "tryItOutEnabled": True, + } + + app = FastAPI(**app_kwargs) def get_cors_origins(): """Get allowed origins from environment variable @@ -167,13 +177,8 @@ def create_app(args): allow_headers=["*"], ) - # Create the optional API key dependency - # Create a dependency that passes the request to get_api_key_dependency - async def optional_api_key_dependency(request: Request): - # Create the dependency function with the request - api_key_dependency = get_api_key_dependency(api_key) - # Call the dependency function with the request - return await api_key_dependency(request) + # Create combined auth dependency for all endpoints + combined_auth = get_combined_auth_dependency(api_key) # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -419,7 +424,7 @@ def create_app(args): "api_version": __api_version__, } - @app.get("/health", dependencies=[Depends(optional_api_key_dependency)]) + @app.get("/health", dependencies=[Depends(combined_auth)]) async def get_status(): """Get current system status""" username = os.getenv("AUTH_USERNAME") diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index fc5f2c13..42a8c102 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -505,6 +505,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): def create_document_routes( rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None ): + # Create combined auth dependency for document routes combined_auth = get_combined_auth_dependency(api_key) @router.post("/scan", dependencies=[Depends(combined_auth)]) diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index a574ead8..15d1859f 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -11,7 +11,7 @@ import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import encode_string_by_tiktoken -from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency +from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency from fastapi import Depends @@ -132,21 +132,17 @@ class OllamaAPI: self.setup_routes() def setup_routes(self): - # Create a dependency that passes the request to get_api_key_dependency - async def optional_api_key_dependency(request: Request): - # Create the dependency function with the request - api_key_dependency = get_api_key_dependency(self.api_key) - # Call the dependency function with the request - return await api_key_dependency(request) + # Create combined auth dependency for Ollama API routes + combined_auth = get_combined_auth_dependency(self.api_key) @self.router.get( - "/version", dependencies=[Depends(optional_api_key_dependency)] + "/version", dependencies=[Depends(combined_auth)] ) async def get_version(): """Get Ollama version information""" return OllamaVersionResponse(version="0.5.4") - @self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)]) + @self.router.get("/tags", dependencies=[Depends(combined_auth)]) async def get_tags(): """Return available models acting as an Ollama server""" return OllamaTagResponse( @@ -170,7 +166,7 @@ class OllamaAPI: ) @self.router.post( - "/generate", dependencies=[Depends(optional_api_key_dependency)] + "/generate", dependencies=[Depends(combined_auth)] ) async def generate(raw_request: Request, request: OllamaGenerateRequest): """Handle generate completion requests acting as an Ollama model @@ -337,7 +333,7 @@ class OllamaAPI: trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) - @self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)]) + @self.router.post("/chat", dependencies=[Depends(combined_auth)]) async def chat(raw_request: Request, request: OllamaChatRequest): """Process chat completion requests acting as an Ollama model Routes user queries through LightRAG by selecting query mode based on prefix indicators. diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 9c2f2bb6..a72aa8e3 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -58,29 +58,43 @@ ollama_server_infos = OllamaServerInfos() def get_combined_auth_dependency(api_key: Optional[str] = None): """ - Create a combined authentication dependency that implements OR logic (pass through any authentication method) + Create a combined authentication dependency that implements authentication logic + based on API key, OAuth2 token, and whitelist paths. Args: api_key (Optional[str]): API key for validation Returns: - Callable: A dependency function that implements OR authentication logic + Callable: A dependency function that implements the authentication logic """ # Use global whitelist_patterns and auth_configured variables # whitelist_patterns and auth_configured are already initialized at module level # Only calculate api_key_configured as it depends on the function parameter api_key_configured = bool(api_key) + + # Create security dependencies with proper descriptions for Swagger UI + oauth2_scheme = OAuth2PasswordBearer( + tokenUrl="login", + auto_error=False, + description="OAuth2 Password Authentication" + ) + + # If API key is configured, create an API key header security + api_key_header = None + if api_key_configured: + api_key_header = APIKeyHeader( + name="X-API-Key", + auto_error=False, + description="API Key Authentication" + ) async def combined_dependency( request: Request, - token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)), + token: str = Security(oauth2_scheme), + api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header), ): - # If both authentication methods are not configured, allow access - if not auth_configured and not api_key_configured: - return - - # Check if request path is in whitelist + # 1. Check if path is in whitelist path = request.url.path for pattern, is_prefix in whitelist_patterns: if (is_prefix and path.startswith(pattern)) or ( @@ -88,35 +102,54 @@ def get_combined_auth_dependency(api_key: Optional[str] = None): ): return # Whitelist path, allow access - # Access with token + # 2. Check for special endpoints (/health and Ollama API) + is_special_endpoint = path == "/health" or path.startswith("/api/") + if is_special_endpoint and not api_key_configured: + return # Special endpoint and no API key configured, allow access + + # 3. Validate API key + if api_key_configured and api_key_header_value and api_key_header_value == api_key: + return # API key validation successful + + # 4. Validate token if token: - token_info = auth_handler.validate_token(token) - if auth_configured: - if token_info.get("role") != "guest" or not api_key_configured: - return # Password authentication successful - else: - if token_info.get("role") == "guest": - return # Guest authentication successful + try: + token_info = auth_handler.validate_token(token) + # Accept guest token if no auth is configured + if not auth_configured and token_info.get("role") == "guest": + return + # Accept non-guest token if auth is configured + if auth_configured and token_info.get("role") != "guest": + return + + # Token validation failed, immediately return 401 error + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token. Please login again." + ) + except HTTPException as e: + # If already a 401 error, re-raise it + if e.status_code == status.HTTP_401_UNAUTHORIZED: + raise + # For other exceptions, continue processing + + # If token exists but validation failed (didn't return above), return 401 raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required" + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token. Please login again." ) - - # Try API key authentication (if configured) + + # 5. No token and API key validation failed, return 403 error if api_key_configured: - api_key_header = request.headers.get("X-API-Key") - if api_key_header and api_key_header == api_key: - return # API key authentication successful - else: - if auth_configured: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="API Key required or use password authentication.", - ) - else: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="API Key required or use guest authentication.", - ) + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required or login authentication required." + ) + else: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Login authentication required." + ) return combined_dependency @@ -145,12 +178,12 @@ def get_api_key_dependency(api_key: Optional[str]): return no_auth - # If API key is configured, use proper authentication + # If API key is configured, use proper authentication with Security for Swagger UI api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) async def api_key_auth( request: Request, - api_key_header_value: Optional[str] = Security(api_key_header), + api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"), ): # Check if request path is in whitelist path = request.url.path