Refactor authentication logic and Swagger UI config
- Consolidate authentication dependencies - Improve Swagger UI security parameters
This commit is contained in:
@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from lightrag.api.utils_api import (
|
from lightrag.api.utils_api import (
|
||||||
get_api_key_dependency,
|
get_api_key_dependency,
|
||||||
|
get_combined_auth_dependency,
|
||||||
parse_args,
|
parse_args,
|
||||||
get_default_host,
|
get_default_host,
|
||||||
display_splash_screen,
|
display_splash_screen,
|
||||||
@@ -135,19 +136,28 @@ def create_app(args):
|
|||||||
await rag.finalize_storages()
|
await rag.finalize_storages()
|
||||||
|
|
||||||
# Initialize FastAPI
|
# Initialize FastAPI
|
||||||
app = FastAPI(
|
app_kwargs = {
|
||||||
title="LightRAG API",
|
"title": "LightRAG Server API",
|
||||||
description="API for querying text using LightRAG with separate storage and input directories"
|
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
|
||||||
+ "(With authentication)"
|
+ "(With authentication)"
|
||||||
if api_key
|
if api_key
|
||||||
else "",
|
else "",
|
||||||
version=__api_version__,
|
"version": __api_version__,
|
||||||
openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL
|
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
|
||||||
docs_url="/docs", # Explicitly set docs URL
|
"docs_url": "/docs", # Explicitly set docs URL
|
||||||
redoc_url="/redoc", # Explicitly set redoc URL
|
"redoc_url": "/redoc", # Explicitly set redoc URL
|
||||||
openapi_tags=[{"name": "api"}],
|
"openapi_tags": [{"name": "api"}],
|
||||||
lifespan=lifespan,
|
"lifespan": lifespan,
|
||||||
)
|
}
|
||||||
|
|
||||||
|
# Configure Swagger UI parameters
|
||||||
|
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
||||||
|
app_kwargs["swagger_ui_parameters"] = {
|
||||||
|
"persistAuthorization": True,
|
||||||
|
"tryItOutEnabled": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
app = FastAPI(**app_kwargs)
|
||||||
|
|
||||||
def get_cors_origins():
|
def get_cors_origins():
|
||||||
"""Get allowed origins from environment variable
|
"""Get allowed origins from environment variable
|
||||||
@@ -167,13 +177,8 @@ def create_app(args):
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the optional API key dependency
|
# Create combined auth dependency for all endpoints
|
||||||
# Create a dependency that passes the request to get_api_key_dependency
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
async def optional_api_key_dependency(request: Request):
|
|
||||||
# Create the dependency function with the request
|
|
||||||
api_key_dependency = get_api_key_dependency(api_key)
|
|
||||||
# Call the dependency function with the request
|
|
||||||
return await api_key_dependency(request)
|
|
||||||
|
|
||||||
# Create working directory if it doesn't exist
|
# Create working directory if it doesn't exist
|
||||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||||
@@ -419,7 +424,7 @@ def create_app(args):
|
|||||||
"api_version": __api_version__,
|
"api_version": __api_version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.get("/health", dependencies=[Depends(optional_api_key_dependency)])
|
@app.get("/health", dependencies=[Depends(combined_auth)])
|
||||||
async def get_status():
|
async def get_status():
|
||||||
"""Get current system status"""
|
"""Get current system status"""
|
||||||
username = os.getenv("AUTH_USERNAME")
|
username = os.getenv("AUTH_USERNAME")
|
||||||
|
@@ -505,6 +505,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
|||||||
def create_document_routes(
|
def create_document_routes(
|
||||||
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
||||||
):
|
):
|
||||||
|
# Create combined auth dependency for document routes
|
||||||
combined_auth = get_combined_auth_dependency(api_key)
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
|
|
||||||
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
||||||
|
@@ -11,7 +11,7 @@ import asyncio
|
|||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.utils import encode_string_by_tiktoken
|
from lightrag.utils import encode_string_by_tiktoken
|
||||||
from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency
|
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
|
|
||||||
@@ -132,21 +132,17 @@ class OllamaAPI:
|
|||||||
self.setup_routes()
|
self.setup_routes()
|
||||||
|
|
||||||
def setup_routes(self):
|
def setup_routes(self):
|
||||||
# Create a dependency that passes the request to get_api_key_dependency
|
# Create combined auth dependency for Ollama API routes
|
||||||
async def optional_api_key_dependency(request: Request):
|
combined_auth = get_combined_auth_dependency(self.api_key)
|
||||||
# Create the dependency function with the request
|
|
||||||
api_key_dependency = get_api_key_dependency(self.api_key)
|
|
||||||
# Call the dependency function with the request
|
|
||||||
return await api_key_dependency(request)
|
|
||||||
|
|
||||||
@self.router.get(
|
@self.router.get(
|
||||||
"/version", dependencies=[Depends(optional_api_key_dependency)]
|
"/version", dependencies=[Depends(combined_auth)]
|
||||||
)
|
)
|
||||||
async def get_version():
|
async def get_version():
|
||||||
"""Get Ollama version information"""
|
"""Get Ollama version information"""
|
||||||
return OllamaVersionResponse(version="0.5.4")
|
return OllamaVersionResponse(version="0.5.4")
|
||||||
|
|
||||||
@self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)])
|
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
||||||
async def get_tags():
|
async def get_tags():
|
||||||
"""Return available models acting as an Ollama server"""
|
"""Return available models acting as an Ollama server"""
|
||||||
return OllamaTagResponse(
|
return OllamaTagResponse(
|
||||||
@@ -170,7 +166,7 @@ class OllamaAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@self.router.post(
|
@self.router.post(
|
||||||
"/generate", dependencies=[Depends(optional_api_key_dependency)]
|
"/generate", dependencies=[Depends(combined_auth)]
|
||||||
)
|
)
|
||||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||||
"""Handle generate completion requests acting as an Ollama model
|
"""Handle generate completion requests acting as an Ollama model
|
||||||
@@ -337,7 +333,7 @@ class OllamaAPI:
|
|||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)])
|
@self.router.post("/chat", dependencies=[Depends(combined_auth)])
|
||||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||||
"""Process chat completion requests acting as an Ollama model
|
"""Process chat completion requests acting as an Ollama model
|
||||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
||||||
|
@@ -58,29 +58,43 @@ ollama_server_infos = OllamaServerInfos()
|
|||||||
|
|
||||||
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Create a combined authentication dependency that implements OR logic (pass through any authentication method)
|
Create a combined authentication dependency that implements authentication logic
|
||||||
|
based on API key, OAuth2 token, and whitelist paths.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key (Optional[str]): API key for validation
|
api_key (Optional[str]): API key for validation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Callable: A dependency function that implements OR authentication logic
|
Callable: A dependency function that implements the authentication logic
|
||||||
"""
|
"""
|
||||||
# Use global whitelist_patterns and auth_configured variables
|
# Use global whitelist_patterns and auth_configured variables
|
||||||
# whitelist_patterns and auth_configured are already initialized at module level
|
# whitelist_patterns and auth_configured are already initialized at module level
|
||||||
|
|
||||||
# Only calculate api_key_configured as it depends on the function parameter
|
# Only calculate api_key_configured as it depends on the function parameter
|
||||||
api_key_configured = bool(api_key)
|
api_key_configured = bool(api_key)
|
||||||
|
|
||||||
|
# Create security dependencies with proper descriptions for Swagger UI
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(
|
||||||
|
tokenUrl="login",
|
||||||
|
auto_error=False,
|
||||||
|
description="OAuth2 Password Authentication"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If API key is configured, create an API key header security
|
||||||
|
api_key_header = None
|
||||||
|
if api_key_configured:
|
||||||
|
api_key_header = APIKeyHeader(
|
||||||
|
name="X-API-Key",
|
||||||
|
auto_error=False,
|
||||||
|
description="API Key Authentication"
|
||||||
|
)
|
||||||
|
|
||||||
async def combined_dependency(
|
async def combined_dependency(
|
||||||
request: Request,
|
request: Request,
|
||||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
token: str = Security(oauth2_scheme),
|
||||||
|
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
|
||||||
):
|
):
|
||||||
# If both authentication methods are not configured, allow access
|
# 1. Check if path is in whitelist
|
||||||
if not auth_configured and not api_key_configured:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if request path is in whitelist
|
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
for pattern, is_prefix in whitelist_patterns:
|
for pattern, is_prefix in whitelist_patterns:
|
||||||
if (is_prefix and path.startswith(pattern)) or (
|
if (is_prefix and path.startswith(pattern)) or (
|
||||||
@@ -88,35 +102,54 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|||||||
):
|
):
|
||||||
return # Whitelist path, allow access
|
return # Whitelist path, allow access
|
||||||
|
|
||||||
# Access with token
|
# 2. Check for special endpoints (/health and Ollama API)
|
||||||
|
is_special_endpoint = path == "/health" or path.startswith("/api/")
|
||||||
|
if is_special_endpoint and not api_key_configured:
|
||||||
|
return # Special endpoint and no API key configured, allow access
|
||||||
|
|
||||||
|
# 3. Validate API key
|
||||||
|
if api_key_configured and api_key_header_value and api_key_header_value == api_key:
|
||||||
|
return # API key validation successful
|
||||||
|
|
||||||
|
# 4. Validate token
|
||||||
if token:
|
if token:
|
||||||
token_info = auth_handler.validate_token(token)
|
try:
|
||||||
if auth_configured:
|
token_info = auth_handler.validate_token(token)
|
||||||
if token_info.get("role") != "guest" or not api_key_configured:
|
# Accept guest token if no auth is configured
|
||||||
return # Password authentication successful
|
if not auth_configured and token_info.get("role") == "guest":
|
||||||
else:
|
return
|
||||||
if token_info.get("role") == "guest":
|
# Accept non-guest token if auth is configured
|
||||||
return # Guest authentication successful
|
if auth_configured and token_info.get("role") != "guest":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Token validation failed, immediately return 401 error
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token. Please login again."
|
||||||
|
)
|
||||||
|
except HTTPException as e:
|
||||||
|
# If already a 401 error, re-raise it
|
||||||
|
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||||
|
raise
|
||||||
|
# For other exceptions, continue processing
|
||||||
|
|
||||||
|
# If token exists but validation failed (didn't return above), return 401
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token. Please login again."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try API key authentication (if configured)
|
# 5. No token and API key validation failed, return 403 error
|
||||||
if api_key_configured:
|
if api_key_configured:
|
||||||
api_key_header = request.headers.get("X-API-Key")
|
raise HTTPException(
|
||||||
if api_key_header and api_key_header == api_key:
|
status_code=HTTP_403_FORBIDDEN,
|
||||||
return # API key authentication successful
|
detail="API Key required or login authentication required."
|
||||||
else:
|
)
|
||||||
if auth_configured:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTP_403_FORBIDDEN,
|
status_code=HTTP_403_FORBIDDEN,
|
||||||
detail="API Key required or use password authentication.",
|
detail="Login authentication required."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTP_403_FORBIDDEN,
|
|
||||||
detail="API Key required or use guest authentication.",
|
|
||||||
)
|
|
||||||
|
|
||||||
return combined_dependency
|
return combined_dependency
|
||||||
|
|
||||||
@@ -145,12 +178,12 @@ def get_api_key_dependency(api_key: Optional[str]):
|
|||||||
|
|
||||||
return no_auth
|
return no_auth
|
||||||
|
|
||||||
# If API key is configured, use proper authentication
|
# If API key is configured, use proper authentication with Security for Swagger UI
|
||||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
async def api_key_auth(
|
async def api_key_auth(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_header_value: Optional[str] = Security(api_key_header),
|
api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
|
||||||
):
|
):
|
||||||
# Check if request path is in whitelist
|
# Check if request path is in whitelist
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
Reference in New Issue
Block a user