From 90ef55960d1457819f06a93fd8ab6800a614df22 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 24 Mar 2025 05:23:40 +0800 Subject: [PATCH] Refactor authentication and whitelist handling - Combined auth and API key dependencies - Optimized whitelist path matching - Added optional API key to OllamaAPI --- lightrag/api/lightrag_server.py | 17 ++-- lightrag/api/routers/document_routes.py | 24 ++--- lightrag/api/routers/graph_routes.py | 10 +- lightrag/api/routers/ollama_api.py | 25 +++-- lightrag/api/routers/query_routes.py | 10 +- lightrag/api/utils_api.py | 129 +++++++++++++++++------- 6 files changed, 145 insertions(+), 70 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index de9e8714..2ebf0ca4 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -2,7 +2,7 @@ LightRAG FastAPI Server """ -from fastapi import FastAPI, Depends, HTTPException, status +from fastapi import FastAPI, Depends, HTTPException, status, Request import asyncio import os import logging @@ -169,7 +169,12 @@ def create_app(args): ) # Create the optional API key dependency - optional_api_key = get_api_key_dependency(api_key) + # Create a dependency that passes the request to get_api_key_dependency + async def optional_api_key_dependency(request: Request): + # Create the dependency function with the request + api_key_dependency = get_api_key_dependency(api_key) + # Call the dependency function with the request + return await api_key_dependency(request) # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -343,7 +348,7 @@ def create_app(args): app.include_router(create_graph_routes(rag, api_key)) # Add Ollama API routes - ollama_api = OllamaAPI(rag, top_k=args.top_k) + ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) app.include_router(ollama_api.router, prefix="/api") @app.get("/") @@ -351,7 +356,7 @@ def create_app(args): """Redirect root path to /webui""" return RedirectResponse(url="/webui") - @app.get("/auth-status", dependencies=[Depends(optional_api_key)]) + @app.get("/auth-status") async def get_auth_status(): """Get authentication status and guest token if auth is not configured""" username = os.getenv("AUTH_USERNAME") @@ -379,7 +384,7 @@ def create_app(args): "api_version": __api_version__, } - @app.post("/login", dependencies=[Depends(optional_api_key)]) + @app.post("/login") async def login(form_data: OAuth2PasswordRequestForm = Depends()): username = os.getenv("AUTH_USERNAME") password = os.getenv("AUTH_PASSWORD") @@ -415,7 +420,7 @@ def create_app(args): "api_version": __api_version__, } - @app.get("/health", dependencies=[Depends(optional_api_key)]) + @app.get("/health", dependencies=[Depends(optional_api_key_dependency)]) async def get_status(): """Get current system status""" # Get update flags status for all namespaces diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 48bc1243..9de306be 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -17,15 +17,13 @@ from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus from lightrag.api.utils_api import ( - get_api_key_dependency, + get_combined_auth_dependency, global_args, - get_auth_dependency, ) router = APIRouter( prefix="/documents", tags=["documents"], - dependencies=[Depends(get_auth_dependency())], ) # Temporary file prefix @@ -505,9 +503,9 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): def create_document_routes( rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None ): - optional_api_key = get_api_key_dependency(api_key) + combined_auth = get_combined_auth_dependency(api_key) - @router.post("/scan", dependencies=[Depends(optional_api_key)]) + @router.post("/scan", dependencies=[Depends(combined_auth)]) async def scan_for_new_documents(background_tasks: BackgroundTasks): """ Trigger the scanning process for new documents. @@ -523,7 +521,7 @@ def create_document_routes( background_tasks.add_task(run_scanning_process, rag, doc_manager) return {"status": "scanning_started"} - @router.post("/upload", dependencies=[Depends(optional_api_key)]) + @router.post("/upload", dependencies=[Depends(combined_auth)]) async def upload_to_input_dir( background_tasks: BackgroundTasks, file: UploadFile = File(...) ): @@ -568,7 +566,7 @@ def create_document_routes( raise HTTPException(status_code=500, detail=str(e)) @router.post( - "/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def insert_text( request: InsertTextRequest, background_tasks: BackgroundTasks @@ -603,7 +601,7 @@ def create_document_routes( @router.post( "/texts", response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], + dependencies=[Depends(combined_auth)], ) async def insert_texts( request: InsertTextsRequest, background_tasks: BackgroundTasks @@ -636,7 +634,7 @@ def create_document_routes( raise HTTPException(status_code=500, detail=str(e)) @router.post( - "/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def insert_file( background_tasks: BackgroundTasks, file: UploadFile = File(...) @@ -681,7 +679,7 @@ def create_document_routes( @router.post( "/file_batch", response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], + dependencies=[Depends(combined_auth)], ) async def insert_batch( background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) @@ -742,7 +740,7 @@ def create_document_routes( raise HTTPException(status_code=500, detail=str(e)) @router.delete( - "", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + "", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def clear_documents(): """ @@ -771,7 +769,7 @@ def create_document_routes( @router.get( "/pipeline_status", - dependencies=[Depends(optional_api_key)], + dependencies=[Depends(combined_auth)], response_model=PipelineStatusResponse, ) async def get_pipeline_status() -> PipelineStatusResponse: @@ -819,7 +817,7 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.get("", dependencies=[Depends(optional_api_key)]) + @router.get("", dependencies=[Depends(combined_auth)]) async def documents() -> DocsStatusesResponse: """ Get the status of all documents in the system. diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index 95802185..f9d77ff6 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -5,15 +5,15 @@ This module contains all graph-related routes for the LightRAG API. from typing import Optional from fastapi import APIRouter, Depends -from ..utils_api import get_api_key_dependency, get_auth_dependency +from ..utils_api import get_combined_auth_dependency -router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())]) +router = APIRouter(tags=["graph"]) def create_graph_routes(rag, api_key: Optional[str] = None): - optional_api_key = get_api_key_dependency(api_key) + combined_auth = get_combined_auth_dependency(api_key) - @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) + @router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) async def get_graph_labels(): """ Get all graph labels @@ -23,7 +23,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None): """ return await rag.get_graph_labels() - @router.get("/graphs", dependencies=[Depends(optional_api_key)]) + @router.get("/graphs", dependencies=[Depends(combined_auth)]) async def get_knowledge_graph( label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False ): diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 37d7354e..a574ead8 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -11,7 +11,8 @@ import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import encode_string_by_tiktoken -from lightrag.api.utils_api import ollama_server_infos +from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency +from fastapi import Depends # query mode according to query prefix (bypass is not LightRAG quer mode) @@ -122,20 +123,30 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]: class OllamaAPI: - def __init__(self, rag: LightRAG, top_k: int = 60): + def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None): self.rag = rag self.ollama_server_infos = ollama_server_infos self.top_k = top_k + self.api_key = api_key self.router = APIRouter(tags=["ollama"]) self.setup_routes() def setup_routes(self): - @self.router.get("/version") + # Create a dependency that passes the request to get_api_key_dependency + async def optional_api_key_dependency(request: Request): + # Create the dependency function with the request + api_key_dependency = get_api_key_dependency(self.api_key) + # Call the dependency function with the request + return await api_key_dependency(request) + + @self.router.get( + "/version", dependencies=[Depends(optional_api_key_dependency)] + ) async def get_version(): """Get Ollama version information""" return OllamaVersionResponse(version="0.5.4") - @self.router.get("/tags") + @self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)]) async def get_tags(): """Return available models acting as an Ollama server""" return OllamaTagResponse( @@ -158,7 +169,9 @@ class OllamaAPI: ] ) - @self.router.post("/generate") + @self.router.post( + "/generate", dependencies=[Depends(optional_api_key_dependency)] + ) async def generate(raw_request: Request, request: OllamaGenerateRequest): """Handle generate completion requests acting as an Ollama model For compatibility purpose, the request is not processed by LightRAG, @@ -324,7 +337,7 @@ class OllamaAPI: trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) - @self.router.post("/chat") + @self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)]) async def chat(raw_request: Request, request: OllamaChatRequest): """Process chat completion requests acting as an Ollama model Routes user queries through LightRAG by selecting query mode based on prefix indicators. diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 7a5bd8c3..c9648356 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, HTTPException from lightrag.base import QueryParam -from ..utils_api import get_api_key_dependency, get_auth_dependency +from ..utils_api import get_combined_auth_dependency from pydantic import BaseModel, Field, field_validator from ascii_colors import trace_exception -router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())]) +router = APIRouter(tags=["query"]) class QueryRequest(BaseModel): @@ -139,10 +139,10 @@ class QueryResponse(BaseModel): def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): - optional_api_key = get_api_key_dependency(api_key) + combined_auth = get_combined_auth_dependency(api_key) @router.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + "/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)] ) async def query_text(request: QueryRequest): """ @@ -176,7 +176,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) - @router.post("/query/stream", dependencies=[Depends(optional_api_key)]) + @router.post("/query/stream", dependencies=[Depends(combined_auth)]) async def query_text_stream(request: QueryRequest): """ This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index a762b28b..7b43feb8 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -4,7 +4,7 @@ Utility functions for the LightRAG API. import os import argparse -from typing import Optional +from typing import Optional, List, Tuple import sys import logging from ascii_colors import ASCIIColors @@ -21,6 +21,29 @@ load_dotenv() global_args = {"main_args": None} +# Get whitelist paths from environment variable, only once during initialization +default_whitelist = "/health,/api/*" +whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",") + +# Pre-compile path matching patterns +whitelist_patterns: List[Tuple[str, bool]] = [] +for path in whitelist_paths: + path = path.strip() + if path: + # If path ends with /*, match all paths with that prefix + if path.endswith("/*"): + prefix = path[:-2] + whitelist_patterns.append((prefix, True)) # (prefix, is_prefix_match) + else: + whitelist_patterns.append( + (path, False) + ) # (exact_path, is_prefix_match) + +# Global authentication configuration +auth_username = os.getenv("AUTH_USERNAME") +auth_password = os.getenv("AUTH_PASSWORD") +auth_configured = bool(auth_username and auth_password) + class OllamaServerInfos: # Constants for emulated Ollama model information @@ -35,49 +58,69 @@ class OllamaServerInfos: ollama_server_infos = OllamaServerInfos() -def get_auth_dependency(): - # Set default whitelist paths - whitelist = os.getenv("WHITELIST_PATHS", "/login,/health").split(",") +def get_combined_auth_dependency(api_key: Optional[str] = None): + """ + Create a combined authentication dependency that implements OR logic (pass through any authentication method) - async def dependency( + Args: + api_key (Optional[str]): API key for validation + + Returns: + Callable: A dependency function that implements OR authentication logic + """ + # Use global whitelist_patterns and auth_configured variables + # whitelist_patterns and auth_configured are already initialized at module level + + # Only calculate api_key_configured as it depends on the function parameter + api_key_configured = bool(api_key) + + async def combined_dependency( request: Request, token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)), ): - # Check if authentication is configured - auth_configured = bool( - os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD") - ) - - # If authentication is not configured, skip all validation - if not auth_configured: + # If both authentication methods are not configured, allow access + if not auth_configured and not api_key_configured: return - # For configured auth, allow whitelist paths without token - if request.url.path in whitelist: - return + # Check if request path is in whitelist + path = request.url.path + for pattern, is_prefix in whitelist_patterns: + if (is_prefix and path.startswith(pattern)) or ( + not is_prefix and path == pattern + ): + return # Whitelist path, allow access - # Require token for all other paths when auth is configured - if not token: + # Access with token + if token: + token_info = auth_handler.validate_token(token) + if auth_configured: + if token_info.get("role") != "guest" or not api_key_configured: + return # Password authentication successful + else: + if token_info.get("role") == "guest": + return # Guest authentication successful raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required" ) + + # Try API key authentication (if configured) + if api_key_configured: + api_key_header = request.headers.get("X-API-Key") + if api_key_header and api_key_header == api_key: + return # API key authentication successful + else: + if auth_configured: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required or use password authentication.", + ) + else: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required or use guest authentication.", + ) - try: - token_info = auth_handler.validate_token(token) - # Reject guest tokens when authentication is configured - if token_info.get("role") == "guest": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication required. Guest access not allowed when authentication is configured.", - ) - except Exception: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" - ) - - return - - return dependency + return combined_dependency def get_api_key_dependency(api_key: Optional[str]): @@ -91,9 +134,15 @@ def get_api_key_dependency(api_key: Optional[str]): Returns: Callable: A dependency function that validates the API key. """ - if not api_key: + # Use global whitelist_patterns and auth_configured variables + # whitelist_patterns and auth_configured are already initialized at module level + + # Only calculate api_key_configured as it depends on the function parameter + api_key_configured = bool(api_key) + + if not api_key_configured: # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): + async def no_auth(request: Request = None, **kwargs): return None return no_auth @@ -102,8 +151,18 @@ def get_api_key_dependency(api_key: Optional[str]): api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) async def api_key_auth( + request: Request, api_key_header_value: Optional[str] = Security(api_key_header), ): + # Check if request path is in whitelist + path = request.url.path + for pattern, is_prefix in whitelist_patterns: + if (is_prefix and path.startswith(pattern)) or ( + not is_prefix and path == pattern + ): + return # Whitelist path, allow access + + # Non-whitelist path, validate API key if not api_key_header_value: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="API Key required"