Fix linting
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
LightRAG FastAPI Server
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, status, Request
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
@@ -18,7 +18,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.api.utils_api import (
|
||||
get_api_key_dependency,
|
||||
get_combined_auth_dependency,
|
||||
parse_args,
|
||||
get_default_host,
|
||||
@@ -149,14 +148,14 @@ def create_app(args):
|
||||
"openapi_tags": [{"name": "api"}],
|
||||
"lifespan": lifespan,
|
||||
}
|
||||
|
||||
|
||||
# Configure Swagger UI parameters
|
||||
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
||||
app_kwargs["swagger_ui_parameters"] = {
|
||||
"persistAuthorization": True,
|
||||
"tryItOutEnabled": True,
|
||||
}
|
||||
|
||||
|
||||
app = FastAPI(**app_kwargs)
|
||||
|
||||
def get_cors_origins():
|
||||
|
@@ -808,14 +808,14 @@ def create_document_routes(
|
||||
|
||||
# Get update flags status for all namespaces
|
||||
update_status = await get_all_update_flags_status()
|
||||
|
||||
|
||||
# Convert MutableBoolean objects to regular boolean values
|
||||
processed_update_status = {}
|
||||
for namespace, flags in update_status.items():
|
||||
processed_flags = []
|
||||
for flag in flags:
|
||||
# Handle both multiprocess and single process cases
|
||||
if hasattr(flag, 'value'):
|
||||
if hasattr(flag, "value"):
|
||||
processed_flags.append(bool(flag.value))
|
||||
else:
|
||||
processed_flags.append(bool(flag))
|
||||
|
@@ -135,9 +135,7 @@ class OllamaAPI:
|
||||
# Create combined auth dependency for Ollama API routes
|
||||
combined_auth = get_combined_auth_dependency(self.api_key)
|
||||
|
||||
@self.router.get(
|
||||
"/version", dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
||||
async def get_version():
|
||||
"""Get Ollama version information"""
|
||||
return OllamaVersionResponse(version="0.5.4")
|
||||
@@ -165,9 +163,7 @@ class OllamaAPI:
|
||||
]
|
||||
)
|
||||
|
||||
@self.router.post(
|
||||
"/generate", dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
@self.router.post("/generate", dependencies=[Depends(combined_auth)])
|
||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||
"""Handle generate completion requests acting as an Ollama model
|
||||
For compatibility purpose, the request is not processed by LightRAG,
|
||||
|
@@ -9,7 +9,7 @@ import sys
|
||||
import logging
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__
|
||||
from fastapi import HTTPException, Security, Depends, Request, status
|
||||
from fastapi import HTTPException, Security, Request, status
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
@@ -72,27 +72,25 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
|
||||
# Only calculate api_key_configured as it depends on the function parameter
|
||||
api_key_configured = bool(api_key)
|
||||
|
||||
|
||||
# Create security dependencies with proper descriptions for Swagger UI
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="login",
|
||||
auto_error=False,
|
||||
description="OAuth2 Password Authentication"
|
||||
tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
|
||||
)
|
||||
|
||||
|
||||
# If API key is configured, create an API key header security
|
||||
api_key_header = None
|
||||
if api_key_configured:
|
||||
api_key_header = APIKeyHeader(
|
||||
name="X-API-Key",
|
||||
auto_error=False,
|
||||
description="API Key Authentication"
|
||||
name="X-API-Key", auto_error=False, description="API Key Authentication"
|
||||
)
|
||||
|
||||
async def combined_dependency(
|
||||
request: Request,
|
||||
token: str = Security(oauth2_scheme),
|
||||
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
|
||||
api_key_header_value: Optional[str] = None
|
||||
if api_key_header is None
|
||||
else Security(api_key_header),
|
||||
):
|
||||
# 1. Check if path is in whitelist
|
||||
path = request.url.path
|
||||
@@ -106,11 +104,15 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
is_special_endpoint = path == "/health" or path.startswith("/api/")
|
||||
if is_special_endpoint and not api_key_configured:
|
||||
return # Special endpoint and no API key configured, allow access
|
||||
|
||||
|
||||
# 3. Validate API key
|
||||
if api_key_configured and api_key_header_value and api_key_header_value == api_key:
|
||||
if (
|
||||
api_key_configured
|
||||
and api_key_header_value
|
||||
and api_key_header_value == api_key
|
||||
):
|
||||
return # API key validation successful
|
||||
|
||||
|
||||
# 4. Validate token
|
||||
if token:
|
||||
try:
|
||||
@@ -121,34 +123,33 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
# Accept non-guest token if auth is configured
|
||||
if auth_configured and token_info.get("role") != "guest":
|
||||
return
|
||||
|
||||
|
||||
# Token validation failed, immediately return 401 error
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token. Please login again."
|
||||
detail="Invalid token. Please login again.",
|
||||
)
|
||||
except HTTPException as e:
|
||||
# If already a 401 error, re-raise it
|
||||
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
raise
|
||||
# For other exceptions, continue processing
|
||||
|
||||
|
||||
# If token exists but validation failed (didn't return above), return 401
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token. Please login again."
|
||||
detail="Invalid token. Please login again.",
|
||||
)
|
||||
|
||||
|
||||
# 5. No token and API key validation failed, return 403 error
|
||||
if api_key_configured:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required or login authentication required."
|
||||
detail="API Key required or login authentication required.",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Login authentication required."
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Login authentication required."
|
||||
)
|
||||
|
||||
return combined_dependency
|
||||
@@ -183,7 +184,9 @@ def get_api_key_dependency(api_key: Optional[str]):
|
||||
|
||||
async def api_key_auth(
|
||||
request: Request,
|
||||
api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
|
||||
api_key_header_value: Optional[str] = Security(
|
||||
api_key_header, description="API Key for authentication"
|
||||
),
|
||||
):
|
||||
# Check if request path is in whitelist
|
||||
path = request.url.path
|
||||
|
Reference in New Issue
Block a user