Refactor authentication and whitelist handling

- Combined auth and API key dependencies
- Optimized whitelist path matching
- Added optional API key to OllamaAPI
This commit is contained in:
yangdx
2025-03-24 05:23:40 +08:00
parent 8301f0a523
commit 90ef55960d
6 changed files with 145 additions and 70 deletions

View File

@@ -2,7 +2,7 @@
LightRAG FastAPI Server
"""
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi import FastAPI, Depends, HTTPException, status, Request
import asyncio
import os
import logging
@@ -169,7 +169,12 @@ def create_app(args):
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create a dependency that passes the request to get_api_key_dependency
async def optional_api_key_dependency(request: Request):
# Create the dependency function with the request
api_key_dependency = get_api_key_dependency(api_key)
# Call the dependency function with the request
return await api_key_dependency(request)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -343,7 +348,7 @@ def create_app(args):
app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k)
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api")
@app.get("/")
@@ -351,7 +356,7 @@ def create_app(args):
"""Redirect root path to /webui"""
return RedirectResponse(url="/webui")
@app.get("/auth-status", dependencies=[Depends(optional_api_key)])
@app.get("/auth-status")
async def get_auth_status():
"""Get authentication status and guest token if auth is not configured"""
username = os.getenv("AUTH_USERNAME")
@@ -379,7 +384,7 @@ def create_app(args):
"api_version": __api_version__,
}
@app.post("/login", dependencies=[Depends(optional_api_key)])
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
@@ -415,7 +420,7 @@ def create_app(args):
"api_version": __api_version__,
}
@app.get("/health", dependencies=[Depends(optional_api_key)])
@app.get("/health", dependencies=[Depends(optional_api_key_dependency)])
async def get_status():
"""Get current system status"""
# Get update flags status for all namespaces