Refactor authentication and whitelist handling

- Combined auth and API key dependencies
- Optimized whitelist path matching
- Added optional API key to OllamaAPI
This commit is contained in:
yangdx
2025-03-24 05:23:40 +08:00
parent 8301f0a523
commit 90ef55960d
6 changed files with 145 additions and 70 deletions

View File

@@ -2,7 +2,7 @@
LightRAG FastAPI Server LightRAG FastAPI Server
""" """
from fastapi import FastAPI, Depends, HTTPException, status from fastapi import FastAPI, Depends, HTTPException, status, Request
import asyncio import asyncio
import os import os
import logging import logging
@@ -169,7 +169,12 @@ def create_app(args):
) )
# Create the optional API key dependency # Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key) # Create a dependency that passes the request to get_api_key_dependency
async def optional_api_key_dependency(request: Request):
    """Validate the optional API key for the incoming request.

    Builds the API-key dependency for the configured ``api_key`` and
    invokes it by hand.  Because the dependency is called directly rather
    than resolved by FastAPI, its ``Security(...)`` default is never
    replaced with the real header, so the ``X-API-Key`` value must be read
    from the request and passed explicitly.  Without this, the parameter
    keeps the ``Security`` marker object as its value and every
    non-whitelisted request is rejected, even with a valid key.

    Args:
        request: The incoming request; supplies the path used for the
            whitelist check and the ``X-API-Key`` header.

    Returns:
        Whatever the underlying dependency returns (``None`` on success).

    Raises:
        HTTPException: Propagated from the underlying dependency when the
            key is missing or invalid.
    """
    api_key_dependency = get_api_key_dependency(api_key)
    # Passed by keyword so the no-op dependency (which only accepts
    # ``request`` plus ``**kwargs``) tolerates the extra argument.
    return await api_key_dependency(
        request, api_key_header_value=request.headers.get("X-API-Key")
    )
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True) Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -343,7 +348,7 @@ def create_app(args):
app.include_router(create_graph_routes(rag, api_key)) app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes # Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k) ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api") app.include_router(ollama_api.router, prefix="/api")
@app.get("/") @app.get("/")
@@ -351,7 +356,7 @@ def create_app(args):
"""Redirect root path to /webui""" """Redirect root path to /webui"""
return RedirectResponse(url="/webui") return RedirectResponse(url="/webui")
@app.get("/auth-status", dependencies=[Depends(optional_api_key)]) @app.get("/auth-status")
async def get_auth_status(): async def get_auth_status():
"""Get authentication status and guest token if auth is not configured""" """Get authentication status and guest token if auth is not configured"""
username = os.getenv("AUTH_USERNAME") username = os.getenv("AUTH_USERNAME")
@@ -379,7 +384,7 @@ def create_app(args):
"api_version": __api_version__, "api_version": __api_version__,
} }
@app.post("/login", dependencies=[Depends(optional_api_key)]) @app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()): async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME") username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD") password = os.getenv("AUTH_PASSWORD")
@@ -415,7 +420,7 @@ def create_app(args):
"api_version": __api_version__, "api_version": __api_version__,
} }
@app.get("/health", dependencies=[Depends(optional_api_key)]) @app.get("/health", dependencies=[Depends(optional_api_key_dependency)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
# Get update flags status for all namespaces # Get update flags status for all namespaces

View File

@@ -17,15 +17,13 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.api.utils_api import ( from lightrag.api.utils_api import (
get_api_key_dependency, get_combined_auth_dependency,
global_args, global_args,
get_auth_dependency,
) )
router = APIRouter( router = APIRouter(
prefix="/documents", prefix="/documents",
tags=["documents"], tags=["documents"],
dependencies=[Depends(get_auth_dependency())],
) )
# Temporary file prefix # Temporary file prefix
@@ -505,9 +503,9 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
def create_document_routes( def create_document_routes(
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
): ):
optional_api_key = get_api_key_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post("/scan", dependencies=[Depends(optional_api_key)]) @router.post("/scan", dependencies=[Depends(combined_auth)])
async def scan_for_new_documents(background_tasks: BackgroundTasks): async def scan_for_new_documents(background_tasks: BackgroundTasks):
""" """
Trigger the scanning process for new documents. Trigger the scanning process for new documents.
@@ -523,7 +521,7 @@ def create_document_routes(
background_tasks.add_task(run_scanning_process, rag, doc_manager) background_tasks.add_task(run_scanning_process, rag, doc_manager)
return {"status": "scanning_started"} return {"status": "scanning_started"}
@router.post("/upload", dependencies=[Depends(optional_api_key)]) @router.post("/upload", dependencies=[Depends(combined_auth)])
async def upload_to_input_dir( async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...) background_tasks: BackgroundTasks, file: UploadFile = File(...)
): ):
@@ -568,7 +566,7 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
"/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
) )
async def insert_text( async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks request: InsertTextRequest, background_tasks: BackgroundTasks
@@ -603,7 +601,7 @@ def create_document_routes(
@router.post( @router.post(
"/texts", "/texts",
response_model=InsertResponse, response_model=InsertResponse,
dependencies=[Depends(optional_api_key)], dependencies=[Depends(combined_auth)],
) )
async def insert_texts( async def insert_texts(
request: InsertTextsRequest, background_tasks: BackgroundTasks request: InsertTextsRequest, background_tasks: BackgroundTasks
@@ -636,7 +634,7 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
"/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
) )
async def insert_file( async def insert_file(
background_tasks: BackgroundTasks, file: UploadFile = File(...) background_tasks: BackgroundTasks, file: UploadFile = File(...)
@@ -681,7 +679,7 @@ def create_document_routes(
@router.post( @router.post(
"/file_batch", "/file_batch",
response_model=InsertResponse, response_model=InsertResponse,
dependencies=[Depends(optional_api_key)], dependencies=[Depends(combined_auth)],
) )
async def insert_batch( async def insert_batch(
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
@@ -742,7 +740,7 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.delete( @router.delete(
"", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] "", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
) )
async def clear_documents(): async def clear_documents():
""" """
@@ -771,7 +769,7 @@ def create_document_routes(
@router.get( @router.get(
"/pipeline_status", "/pipeline_status",
dependencies=[Depends(optional_api_key)], dependencies=[Depends(combined_auth)],
response_model=PipelineStatusResponse, response_model=PipelineStatusResponse,
) )
async def get_pipeline_status() -> PipelineStatusResponse: async def get_pipeline_status() -> PipelineStatusResponse:
@@ -819,7 +817,7 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("", dependencies=[Depends(optional_api_key)]) @router.get("", dependencies=[Depends(combined_auth)])
async def documents() -> DocsStatusesResponse: async def documents() -> DocsStatusesResponse:
""" """
Get the status of all documents in the system. Get the status of all documents in the system.

View File

@@ -5,15 +5,15 @@ This module contains all graph-related routes for the LightRAG API.
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency, get_auth_dependency from ..utils_api import get_combined_auth_dependency
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())]) router = APIRouter(tags=["graph"])
def create_graph_routes(rag, api_key: Optional[str] = None): def create_graph_routes(rag, api_key: Optional[str] = None):
optional_api_key = get_api_key_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) @router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
async def get_graph_labels(): async def get_graph_labels():
""" """
Get all graph labels Get all graph labels
@@ -23,7 +23,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
""" """
return await rag.get_graph_labels() return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)]) @router.get("/graphs", dependencies=[Depends(combined_auth)])
async def get_knowledge_graph( async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
): ):

View File

@@ -11,7 +11,8 @@ import asyncio
from ascii_colors import trace_exception from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken from lightrag.utils import encode_string_by_tiktoken
from lightrag.api.utils_api import ollama_server_infos from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency
from fastapi import Depends
# query mode according to query prefix (bypass is not a LightRAG query mode) # query mode according to query prefix (bypass is not a LightRAG query mode)
@@ -122,20 +123,30 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
class OllamaAPI: class OllamaAPI:
def __init__(self, rag: LightRAG, top_k: int = 60): def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
self.rag = rag self.rag = rag
self.ollama_server_infos = ollama_server_infos self.ollama_server_infos = ollama_server_infos
self.top_k = top_k self.top_k = top_k
self.api_key = api_key
self.router = APIRouter(tags=["ollama"]) self.router = APIRouter(tags=["ollama"])
self.setup_routes() self.setup_routes()
def setup_routes(self): def setup_routes(self):
@self.router.get("/version") # Create a dependency that passes the request to get_api_key_dependency
async def optional_api_key_dependency(request: Request):
    """Validate the optional API key for an Ollama-compatible route.

    Builds the API-key dependency for ``self.api_key`` and invokes it by
    hand.  Because the dependency is called directly rather than resolved
    by FastAPI, its ``Security(...)`` default is never replaced with the
    real header, so the ``X-API-Key`` value must be read from the request
    and passed explicitly.  Without this, the parameter keeps the
    ``Security`` marker object as its value and every non-whitelisted
    request is rejected, even with a valid key.

    Args:
        request: The incoming request; supplies the path used for the
            whitelist check and the ``X-API-Key`` header.

    Returns:
        Whatever the underlying dependency returns (``None`` on success).

    Raises:
        HTTPException: Propagated from the underlying dependency when the
            key is missing or invalid.
    """
    api_key_dependency = get_api_key_dependency(self.api_key)
    # Passed by keyword so the no-op dependency (which only accepts
    # ``request`` plus ``**kwargs``) tolerates the extra argument.
    return await api_key_dependency(
        request, api_key_header_value=request.headers.get("X-API-Key")
    )
@self.router.get(
"/version", dependencies=[Depends(optional_api_key_dependency)]
)
async def get_version(): async def get_version():
"""Get Ollama version information""" """Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4") return OllamaVersionResponse(version="0.5.4")
@self.router.get("/tags") @self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)])
async def get_tags(): async def get_tags():
"""Return available models acting as an Ollama server""" """Return available models acting as an Ollama server"""
return OllamaTagResponse( return OllamaTagResponse(
@@ -158,7 +169,9 @@ class OllamaAPI:
] ]
) )
@self.router.post("/generate") @self.router.post(
"/generate", dependencies=[Depends(optional_api_key_dependency)]
)
async def generate(raw_request: Request, request: OllamaGenerateRequest): async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model """Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG, For compatibility purpose, the request is not processed by LightRAG,
@@ -324,7 +337,7 @@ class OllamaAPI:
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@self.router.post("/chat") @self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)])
async def chat(raw_request: Request, request: OllamaChatRequest): async def chat(raw_request: Request, request: OllamaChatRequest):
"""Process chat completion requests acting as an Ollama model """Process chat completion requests acting as an Ollama model
Routes user queries through LightRAG by selecting query mode based on prefix indicators. Routes user queries through LightRAG by selecting query mode based on prefix indicators.

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency, get_auth_dependency from ..utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception from ascii_colors import trace_exception
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())]) router = APIRouter(tags=["query"])
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
@@ -139,10 +139,10 @@ class QueryResponse(BaseModel):
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
optional_api_key = get_api_key_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post( @router.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] "/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)]
) )
async def query_text(request: QueryRequest): async def query_text(request: QueryRequest):
""" """
@@ -176,7 +176,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/query/stream", dependencies=[Depends(optional_api_key)]) @router.post("/query/stream", dependencies=[Depends(combined_auth)])
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
""" """
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.

View File

@@ -4,7 +4,7 @@ Utility functions for the LightRAG API.
import os import os
import argparse import argparse
from typing import Optional from typing import Optional, List, Tuple
import sys import sys
import logging import logging
from ascii_colors import ASCIIColors from ascii_colors import ASCIIColors
@@ -21,6 +21,29 @@ load_dotenv()
global_args = {"main_args": None} global_args = {"main_args": None}
# Get whitelist paths from environment variable, only once during initialization
default_whitelist = "/health,/api/*"
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")


def build_whitelist_patterns(paths: List[str]) -> List[Tuple[str, bool]]:
    """Turn raw whitelist entries into ``(pattern, is_prefix_match)`` tuples.

    An entry ending in ``/*`` whitelists the directory itself and everything
    beneath it.  It is emitted as a prefix pattern that keeps the trailing
    slash, plus an exact pattern for the bare directory, so ``/api/*``
    matches ``/api`` and ``/api/chat`` but not an unrelated sibling such as
    ``/apikeys``.  Any other entry is an exact match.  Blank entries are
    ignored.

    Args:
        paths: Raw entries, typically the comma-split ``WHITELIST_PATHS``
            value; surrounding whitespace is stripped.

    Returns:
        Patterns in input order.  Consumers test
        ``path.startswith(pattern)`` when the flag is true and
        ``path == pattern`` otherwise.
    """
    patterns: List[Tuple[str, bool]] = []
    for raw_entry in paths:
        entry = raw_entry.strip()
        if not entry:
            continue
        if entry.endswith("/*"):
            base = entry[:-2]
            # Keep the slash in the prefix so only real sub-paths match.
            patterns.append((base + "/", True))
            if base:
                # The directory itself, without a trailing slash.
                patterns.append((base, False))
        else:
            patterns.append((entry, False))
    return patterns


# Pre-compile path matching patterns
whitelist_patterns: List[Tuple[str, bool]] = build_whitelist_patterns(whitelist_paths)

# Global authentication configuration
auth_username = os.getenv("AUTH_USERNAME")
auth_password = os.getenv("AUTH_PASSWORD")
auth_configured = bool(auth_username and auth_password)
class OllamaServerInfos: class OllamaServerInfos:
# Constants for emulated Ollama model information # Constants for emulated Ollama model information
@@ -35,49 +58,69 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos() ollama_server_infos = OllamaServerInfos()
def get_auth_dependency(): def get_combined_auth_dependency(api_key: Optional[str] = None):
# Set default whitelist paths """
whitelist = os.getenv("WHITELIST_PATHS", "/login,/health").split(",") Create a combined authentication dependency that implements OR logic (pass through any authentication method)
async def dependency( Args:
api_key (Optional[str]): API key for validation
Returns:
Callable: A dependency function that implements OR authentication logic
"""
# Use global whitelist_patterns and auth_configured variables
# whitelist_patterns and auth_configured are already initialized at module level
# Only calculate api_key_configured as it depends on the function parameter
api_key_configured = bool(api_key)
async def combined_dependency(
request: Request, request: Request,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)), token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
): ):
# Check if authentication is configured # If both authentication methods are not configured, allow access
auth_configured = bool( if not auth_configured and not api_key_configured:
os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")
)
# If authentication is not configured, skip all validation
if not auth_configured:
return return
# For configured auth, allow whitelist paths without token # Check if request path is in whitelist
if request.url.path in whitelist: path = request.url.path
return for pattern, is_prefix in whitelist_patterns:
if (is_prefix and path.startswith(pattern)) or (
not is_prefix and path == pattern
):
return # Whitelist path, allow access
# Require token for all other paths when auth is configured # Access with token
if not token: if token:
token_info = auth_handler.validate_token(token)
if auth_configured:
if token_info.get("role") != "guest" or not api_key_configured:
return # Password authentication successful
else:
if token_info.get("role") == "guest":
return # Guest authentication successful
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required" status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
) )
# Try API key authentication (if configured)
if api_key_configured:
api_key_header = request.headers.get("X-API-Key")
if api_key_header and api_key_header == api_key:
return # API key authentication successful
else:
if auth_configured:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="API Key required or use password authentication.",
)
else:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="API Key required or use guest authentication.",
)
try: return combined_dependency
token_info = auth_handler.validate_token(token)
# Reject guest tokens when authentication is configured
if token_info.get("role") == "guest":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required. Guest access not allowed when authentication is configured.",
)
except Exception:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
return
return dependency
def get_api_key_dependency(api_key: Optional[str]): def get_api_key_dependency(api_key: Optional[str]):
@@ -91,9 +134,15 @@ def get_api_key_dependency(api_key: Optional[str]):
Returns: Returns:
Callable: A dependency function that validates the API key. Callable: A dependency function that validates the API key.
""" """
if not api_key: # Use global whitelist_patterns and auth_configured variables
# whitelist_patterns and auth_configured are already initialized at module level
# Only calculate api_key_configured as it depends on the function parameter
api_key_configured = bool(api_key)
if not api_key_configured:
# If no API key is configured, return a dummy dependency that always succeeds # If no API key is configured, return a dummy dependency that always succeeds
async def no_auth(): async def no_auth(request: Request = None, **kwargs):
return None return None
return no_auth return no_auth
@@ -102,8 +151,18 @@ def get_api_key_dependency(api_key: Optional[str]):
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth( async def api_key_auth(
request: Request,
api_key_header_value: Optional[str] = Security(api_key_header), api_key_header_value: Optional[str] = Security(api_key_header),
): ):
# Check if request path is in whitelist
path = request.url.path
for pattern, is_prefix in whitelist_patterns:
if (is_prefix and path.startswith(pattern)) or (
not is_prefix and path == pattern
):
return # Whitelist path, allow access
# Non-whitelist path, validate API key
if not api_key_header_value: if not api_key_header_value:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required" status_code=HTTP_403_FORBIDDEN, detail="API Key required"