Refactor authentication and whitelist handling
- Combined auth and API key dependencies - Optimized whitelist path matching - Added optional API key to OllamaAPI
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
LightRAG FastAPI Server
|
LightRAG FastAPI Server
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import FastAPI, Depends, HTTPException, status
|
from fastapi import FastAPI, Depends, HTTPException, status, Request
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
@@ -169,7 +169,12 @@ def create_app(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create the optional API key dependency
|
# Create the optional API key dependency
|
||||||
optional_api_key = get_api_key_dependency(api_key)
|
# Create a dependency that passes the request to get_api_key_dependency
|
||||||
|
async def optional_api_key_dependency(request: Request):
|
||||||
|
# Create the dependency function with the request
|
||||||
|
api_key_dependency = get_api_key_dependency(api_key)
|
||||||
|
# Call the dependency function with the request
|
||||||
|
return await api_key_dependency(request)
|
||||||
|
|
||||||
# Create working directory if it doesn't exist
|
# Create working directory if it doesn't exist
|
||||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||||
@@ -343,7 +348,7 @@ def create_app(args):
|
|||||||
app.include_router(create_graph_routes(rag, api_key))
|
app.include_router(create_graph_routes(rag, api_key))
|
||||||
|
|
||||||
# Add Ollama API routes
|
# Add Ollama API routes
|
||||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
|
||||||
app.include_router(ollama_api.router, prefix="/api")
|
app.include_router(ollama_api.router, prefix="/api")
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@@ -351,7 +356,7 @@ def create_app(args):
|
|||||||
"""Redirect root path to /webui"""
|
"""Redirect root path to /webui"""
|
||||||
return RedirectResponse(url="/webui")
|
return RedirectResponse(url="/webui")
|
||||||
|
|
||||||
@app.get("/auth-status", dependencies=[Depends(optional_api_key)])
|
@app.get("/auth-status")
|
||||||
async def get_auth_status():
|
async def get_auth_status():
|
||||||
"""Get authentication status and guest token if auth is not configured"""
|
"""Get authentication status and guest token if auth is not configured"""
|
||||||
username = os.getenv("AUTH_USERNAME")
|
username = os.getenv("AUTH_USERNAME")
|
||||||
@@ -379,7 +384,7 @@ def create_app(args):
|
|||||||
"api_version": __api_version__,
|
"api_version": __api_version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.post("/login", dependencies=[Depends(optional_api_key)])
|
@app.post("/login")
|
||||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||||
username = os.getenv("AUTH_USERNAME")
|
username = os.getenv("AUTH_USERNAME")
|
||||||
password = os.getenv("AUTH_PASSWORD")
|
password = os.getenv("AUTH_PASSWORD")
|
||||||
@@ -415,7 +420,7 @@ def create_app(args):
|
|||||||
"api_version": __api_version__,
|
"api_version": __api_version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
@app.get("/health", dependencies=[Depends(optional_api_key_dependency)])
|
||||||
async def get_status():
|
async def get_status():
|
||||||
"""Get current system status"""
|
"""Get current system status"""
|
||||||
# Get update flags status for all namespaces
|
# Get update flags status for all namespaces
|
||||||
|
@@ -17,15 +17,13 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
from lightrag.base import DocProcessingStatus, DocStatus
|
from lightrag.base import DocProcessingStatus, DocStatus
|
||||||
from lightrag.api.utils_api import (
|
from lightrag.api.utils_api import (
|
||||||
get_api_key_dependency,
|
get_combined_auth_dependency,
|
||||||
global_args,
|
global_args,
|
||||||
get_auth_dependency,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/documents",
|
prefix="/documents",
|
||||||
tags=["documents"],
|
tags=["documents"],
|
||||||
dependencies=[Depends(get_auth_dependency())],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Temporary file prefix
|
# Temporary file prefix
|
||||||
@@ -505,9 +503,9 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
|||||||
def create_document_routes(
|
def create_document_routes(
|
||||||
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
||||||
):
|
):
|
||||||
optional_api_key = get_api_key_dependency(api_key)
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
|
|
||||||
@router.post("/scan", dependencies=[Depends(optional_api_key)])
|
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
||||||
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
||||||
"""
|
"""
|
||||||
Trigger the scanning process for new documents.
|
Trigger the scanning process for new documents.
|
||||||
@@ -523,7 +521,7 @@ def create_document_routes(
|
|||||||
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
||||||
return {"status": "scanning_started"}
|
return {"status": "scanning_started"}
|
||||||
|
|
||||||
@router.post("/upload", dependencies=[Depends(optional_api_key)])
|
@router.post("/upload", dependencies=[Depends(combined_auth)])
|
||||||
async def upload_to_input_dir(
|
async def upload_to_input_dir(
|
||||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||||
):
|
):
|
||||||
@@ -568,7 +566,7 @@ def create_document_routes(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||||
)
|
)
|
||||||
async def insert_text(
|
async def insert_text(
|
||||||
request: InsertTextRequest, background_tasks: BackgroundTasks
|
request: InsertTextRequest, background_tasks: BackgroundTasks
|
||||||
@@ -603,7 +601,7 @@ def create_document_routes(
|
|||||||
@router.post(
|
@router.post(
|
||||||
"/texts",
|
"/texts",
|
||||||
response_model=InsertResponse,
|
response_model=InsertResponse,
|
||||||
dependencies=[Depends(optional_api_key)],
|
dependencies=[Depends(combined_auth)],
|
||||||
)
|
)
|
||||||
async def insert_texts(
|
async def insert_texts(
|
||||||
request: InsertTextsRequest, background_tasks: BackgroundTasks
|
request: InsertTextsRequest, background_tasks: BackgroundTasks
|
||||||
@@ -636,7 +634,7 @@ def create_document_routes(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||||
)
|
)
|
||||||
async def insert_file(
|
async def insert_file(
|
||||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||||
@@ -681,7 +679,7 @@ def create_document_routes(
|
|||||||
@router.post(
|
@router.post(
|
||||||
"/file_batch",
|
"/file_batch",
|
||||||
response_model=InsertResponse,
|
response_model=InsertResponse,
|
||||||
dependencies=[Depends(optional_api_key)],
|
dependencies=[Depends(combined_auth)],
|
||||||
)
|
)
|
||||||
async def insert_batch(
|
async def insert_batch(
|
||||||
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
|
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
|
||||||
@@ -742,7 +740,7 @@ def create_document_routes(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
"", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||||
)
|
)
|
||||||
async def clear_documents():
|
async def clear_documents():
|
||||||
"""
|
"""
|
||||||
@@ -771,7 +769,7 @@ def create_document_routes(
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/pipeline_status",
|
"/pipeline_status",
|
||||||
dependencies=[Depends(optional_api_key)],
|
dependencies=[Depends(combined_auth)],
|
||||||
response_model=PipelineStatusResponse,
|
response_model=PipelineStatusResponse,
|
||||||
)
|
)
|
||||||
async def get_pipeline_status() -> PipelineStatusResponse:
|
async def get_pipeline_status() -> PipelineStatusResponse:
|
||||||
@@ -819,7 +817,7 @@ def create_document_routes(
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.get("", dependencies=[Depends(optional_api_key)])
|
@router.get("", dependencies=[Depends(combined_auth)])
|
||||||
async def documents() -> DocsStatusesResponse:
|
async def documents() -> DocsStatusesResponse:
|
||||||
"""
|
"""
|
||||||
Get the status of all documents in the system.
|
Get the status of all documents in the system.
|
||||||
|
@@ -5,15 +5,15 @@ This module contains all graph-related routes for the LightRAG API.
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
from ..utils_api import get_combined_auth_dependency
|
||||||
|
|
||||||
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
|
router = APIRouter(tags=["graph"])
|
||||||
|
|
||||||
|
|
||||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||||
optional_api_key = get_api_key_dependency(api_key)
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
|
|
||||||
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
|
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
|
||||||
async def get_graph_labels():
|
async def get_graph_labels():
|
||||||
"""
|
"""
|
||||||
Get all graph labels
|
Get all graph labels
|
||||||
@@ -23,7 +23,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||||||
"""
|
"""
|
||||||
return await rag.get_graph_labels()
|
return await rag.get_graph_labels()
|
||||||
|
|
||||||
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||||
):
|
):
|
||||||
|
@@ -11,7 +11,8 @@ import asyncio
|
|||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.utils import encode_string_by_tiktoken
|
from lightrag.utils import encode_string_by_tiktoken
|
||||||
from lightrag.api.utils_api import ollama_server_infos
|
from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
|
||||||
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
||||||
@@ -122,20 +123,30 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
|
|||||||
|
|
||||||
|
|
||||||
class OllamaAPI:
|
class OllamaAPI:
|
||||||
def __init__(self, rag: LightRAG, top_k: int = 60):
|
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
|
||||||
self.rag = rag
|
self.rag = rag
|
||||||
self.ollama_server_infos = ollama_server_infos
|
self.ollama_server_infos = ollama_server_infos
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
self.api_key = api_key
|
||||||
self.router = APIRouter(tags=["ollama"])
|
self.router = APIRouter(tags=["ollama"])
|
||||||
self.setup_routes()
|
self.setup_routes()
|
||||||
|
|
||||||
def setup_routes(self):
|
def setup_routes(self):
|
||||||
@self.router.get("/version")
|
# Create a dependency that passes the request to get_api_key_dependency
|
||||||
|
async def optional_api_key_dependency(request: Request):
|
||||||
|
# Create the dependency function with the request
|
||||||
|
api_key_dependency = get_api_key_dependency(self.api_key)
|
||||||
|
# Call the dependency function with the request
|
||||||
|
return await api_key_dependency(request)
|
||||||
|
|
||||||
|
@self.router.get(
|
||||||
|
"/version", dependencies=[Depends(optional_api_key_dependency)]
|
||||||
|
)
|
||||||
async def get_version():
|
async def get_version():
|
||||||
"""Get Ollama version information"""
|
"""Get Ollama version information"""
|
||||||
return OllamaVersionResponse(version="0.5.4")
|
return OllamaVersionResponse(version="0.5.4")
|
||||||
|
|
||||||
@self.router.get("/tags")
|
@self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)])
|
||||||
async def get_tags():
|
async def get_tags():
|
||||||
"""Return available models acting as an Ollama server"""
|
"""Return available models acting as an Ollama server"""
|
||||||
return OllamaTagResponse(
|
return OllamaTagResponse(
|
||||||
@@ -158,7 +169,9 @@ class OllamaAPI:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@self.router.post("/generate")
|
@self.router.post(
|
||||||
|
"/generate", dependencies=[Depends(optional_api_key_dependency)]
|
||||||
|
)
|
||||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||||
"""Handle generate completion requests acting as an Ollama model
|
"""Handle generate completion requests acting as an Ollama model
|
||||||
For compatibility purpose, the request is not processed by LightRAG,
|
For compatibility purpose, the request is not processed by LightRAG,
|
||||||
@@ -324,7 +337,7 @@ class OllamaAPI:
|
|||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@self.router.post("/chat")
|
@self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)])
|
||||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||||
"""Process chat completion requests acting as an Ollama model
|
"""Process chat completion requests acting as an Ollama model
|
||||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
||||||
|
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from lightrag.base import QueryParam
|
from lightrag.base import QueryParam
|
||||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
from ..utils_api import get_combined_auth_dependency
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
|
|
||||||
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
|
router = APIRouter(tags=["query"])
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
class QueryRequest(BaseModel):
|
||||||
@@ -139,10 +139,10 @@ class QueryResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
optional_api_key = get_api_key_dependency(api_key)
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
|
"/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)]
|
||||||
)
|
)
|
||||||
async def query_text(request: QueryRequest):
|
async def query_text(request: QueryRequest):
|
||||||
"""
|
"""
|
||||||
@@ -176,7 +176,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
@router.post("/query/stream", dependencies=[Depends(combined_auth)])
|
||||||
async def query_text_stream(request: QueryRequest):
|
async def query_text_stream(request: QueryRequest):
|
||||||
"""
|
"""
|
||||||
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
|
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
|
||||||
|
@@ -4,7 +4,7 @@ Utility functions for the LightRAG API.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Optional
|
from typing import Optional, List, Tuple
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
from ascii_colors import ASCIIColors
|
from ascii_colors import ASCIIColors
|
||||||
@@ -21,6 +21,29 @@ load_dotenv()
|
|||||||
|
|
||||||
global_args = {"main_args": None}
|
global_args = {"main_args": None}
|
||||||
|
|
||||||
|
# Get whitelist paths from environment variable, only once during initialization
|
||||||
|
default_whitelist = "/health,/api/*"
|
||||||
|
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
|
||||||
|
|
||||||
|
# Pre-compile path matching patterns
|
||||||
|
whitelist_patterns: List[Tuple[str, bool]] = []
|
||||||
|
for path in whitelist_paths:
|
||||||
|
path = path.strip()
|
||||||
|
if path:
|
||||||
|
# If path ends with /*, match all paths with that prefix
|
||||||
|
if path.endswith("/*"):
|
||||||
|
prefix = path[:-2]
|
||||||
|
whitelist_patterns.append((prefix, True)) # (prefix, is_prefix_match)
|
||||||
|
else:
|
||||||
|
whitelist_patterns.append(
|
||||||
|
(path, False)
|
||||||
|
) # (exact_path, is_prefix_match)
|
||||||
|
|
||||||
|
# Global authentication configuration
|
||||||
|
auth_username = os.getenv("AUTH_USERNAME")
|
||||||
|
auth_password = os.getenv("AUTH_PASSWORD")
|
||||||
|
auth_configured = bool(auth_username and auth_password)
|
||||||
|
|
||||||
|
|
||||||
class OllamaServerInfos:
|
class OllamaServerInfos:
|
||||||
# Constants for emulated Ollama model information
|
# Constants for emulated Ollama model information
|
||||||
@@ -35,49 +58,69 @@ class OllamaServerInfos:
|
|||||||
ollama_server_infos = OllamaServerInfos()
|
ollama_server_infos = OllamaServerInfos()
|
||||||
|
|
||||||
|
|
||||||
def get_auth_dependency():
|
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||||
# Set default whitelist paths
|
"""
|
||||||
whitelist = os.getenv("WHITELIST_PATHS", "/login,/health").split(",")
|
Create a combined authentication dependency that implements OR logic (pass through any authentication method)
|
||||||
|
|
||||||
async def dependency(
|
Args:
|
||||||
|
api_key (Optional[str]): API key for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: A dependency function that implements OR authentication logic
|
||||||
|
"""
|
||||||
|
# Use global whitelist_patterns and auth_configured variables
|
||||||
|
# whitelist_patterns and auth_configured are already initialized at module level
|
||||||
|
|
||||||
|
# Only calculate api_key_configured as it depends on the function parameter
|
||||||
|
api_key_configured = bool(api_key)
|
||||||
|
|
||||||
|
async def combined_dependency(
|
||||||
request: Request,
|
request: Request,
|
||||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
||||||
):
|
):
|
||||||
# Check if authentication is configured
|
# If both authentication methods are not configured, allow access
|
||||||
auth_configured = bool(
|
if not auth_configured and not api_key_configured:
|
||||||
os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")
|
|
||||||
)
|
|
||||||
|
|
||||||
# If authentication is not configured, skip all validation
|
|
||||||
if not auth_configured:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# For configured auth, allow whitelist paths without token
|
# Check if request path is in whitelist
|
||||||
if request.url.path in whitelist:
|
path = request.url.path
|
||||||
return
|
for pattern, is_prefix in whitelist_patterns:
|
||||||
|
if (is_prefix and path.startswith(pattern)) or (
|
||||||
|
not is_prefix and path == pattern
|
||||||
|
):
|
||||||
|
return # Whitelist path, allow access
|
||||||
|
|
||||||
# Require token for all other paths when auth is configured
|
# Access with token
|
||||||
if not token:
|
if token:
|
||||||
|
token_info = auth_handler.validate_token(token)
|
||||||
|
if auth_configured:
|
||||||
|
if token_info.get("role") != "guest" or not api_key_configured:
|
||||||
|
return # Password authentication successful
|
||||||
|
else:
|
||||||
|
if token_info.get("role") == "guest":
|
||||||
|
return # Guest authentication successful
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# Try API key authentication (if configured)
|
||||||
token_info = auth_handler.validate_token(token)
|
if api_key_configured:
|
||||||
# Reject guest tokens when authentication is configured
|
api_key_header = request.headers.get("X-API-Key")
|
||||||
if token_info.get("role") == "guest":
|
if api_key_header and api_key_header == api_key:
|
||||||
|
return # API key authentication successful
|
||||||
|
else:
|
||||||
|
if auth_configured:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=HTTP_403_FORBIDDEN,
|
||||||
detail="Authentication required. Guest access not allowed when authentication is configured.",
|
detail="API Key required or use password authentication.",
|
||||||
)
|
)
|
||||||
except Exception:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
status_code=HTTP_403_FORBIDDEN,
|
||||||
|
detail="API Key required or use guest authentication.",
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return combined_dependency
|
||||||
|
|
||||||
return dependency
|
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_dependency(api_key: Optional[str]):
|
def get_api_key_dependency(api_key: Optional[str]):
|
||||||
@@ -91,9 +134,15 @@ def get_api_key_dependency(api_key: Optional[str]):
|
|||||||
Returns:
|
Returns:
|
||||||
Callable: A dependency function that validates the API key.
|
Callable: A dependency function that validates the API key.
|
||||||
"""
|
"""
|
||||||
if not api_key:
|
# Use global whitelist_patterns and auth_configured variables
|
||||||
|
# whitelist_patterns and auth_configured are already initialized at module level
|
||||||
|
|
||||||
|
# Only calculate api_key_configured as it depends on the function parameter
|
||||||
|
api_key_configured = bool(api_key)
|
||||||
|
|
||||||
|
if not api_key_configured:
|
||||||
# If no API key is configured, return a dummy dependency that always succeeds
|
# If no API key is configured, return a dummy dependency that always succeeds
|
||||||
async def no_auth():
|
async def no_auth(request: Request = None, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return no_auth
|
return no_auth
|
||||||
@@ -102,8 +151,18 @@ def get_api_key_dependency(api_key: Optional[str]):
|
|||||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
async def api_key_auth(
|
async def api_key_auth(
|
||||||
|
request: Request,
|
||||||
api_key_header_value: Optional[str] = Security(api_key_header),
|
api_key_header_value: Optional[str] = Security(api_key_header),
|
||||||
):
|
):
|
||||||
|
# Check if request path is in whitelist
|
||||||
|
path = request.url.path
|
||||||
|
for pattern, is_prefix in whitelist_patterns:
|
||||||
|
if (is_prefix and path.startswith(pattern)) or (
|
||||||
|
not is_prefix and path == pattern
|
||||||
|
):
|
||||||
|
return # Whitelist path, allow access
|
||||||
|
|
||||||
|
# Non-whitelist path, validate API key
|
||||||
if not api_key_header_value:
|
if not api_key_header_value:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
|
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
|
||||||
|
Reference in New Issue
Block a user