Fix linting

This commit is contained in:
yangdx
2025-03-24 14:30:17 +08:00
parent 9e3994419e
commit d05cf286f4
4 changed files with 32 additions and 34 deletions

View File

@@ -2,7 +2,7 @@
LightRAG FastAPI Server
"""
from fastapi import FastAPI, Depends, HTTPException, status, Request
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
@@ -18,7 +18,6 @@ from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_api_key_dependency,
get_combined_auth_dependency,
parse_args,
get_default_host,
@@ -149,14 +148,14 @@ def create_app(args):
"openapi_tags": [{"name": "api"}],
"lifespan": lifespan,
}
# Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True,
"tryItOutEnabled": True,
}
app = FastAPI(**app_kwargs)
def get_cors_origins():

View File

@@ -808,14 +808,14 @@ def create_document_routes(
# Get update flags status for all namespaces
update_status = await get_all_update_flags_status()
# Convert MutableBoolean objects to regular boolean values
processed_update_status = {}
for namespace, flags in update_status.items():
processed_flags = []
for flag in flags:
# Handle both multiprocess and single process cases
if hasattr(flag, 'value'):
if hasattr(flag, "value"):
processed_flags.append(bool(flag.value))
else:
processed_flags.append(bool(flag))

View File

@@ -135,9 +135,7 @@ class OllamaAPI:
# Create combined auth dependency for Ollama API routes
combined_auth = get_combined_auth_dependency(self.api_key)
@self.router.get(
"/version", dependencies=[Depends(combined_auth)]
)
@self.router.get("/version", dependencies=[Depends(combined_auth)])
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
@@ -165,9 +163,7 @@ class OllamaAPI:
]
)
@self.router.post(
"/generate", dependencies=[Depends(combined_auth)]
)
@self.router.post("/generate", dependencies=[Depends(combined_auth)])
async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG,

View File

@@ -9,7 +9,7 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security, Depends, Request, status
from fastapi import HTTPException, Security, Request, status
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
@@ -72,27 +72,25 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
# Only calculate api_key_configured as it depends on the function parameter
api_key_configured = bool(api_key)
# Create security dependencies with proper descriptions for Swagger UI
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="login",
auto_error=False,
description="OAuth2 Password Authentication"
tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
)
# If API key is configured, create an API key header security
api_key_header = None
if api_key_configured:
api_key_header = APIKeyHeader(
name="X-API-Key",
auto_error=False,
description="API Key Authentication"
name="X-API-Key", auto_error=False, description="API Key Authentication"
)
async def combined_dependency(
request: Request,
token: str = Security(oauth2_scheme),
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
api_key_header_value: Optional[str] = None
if api_key_header is None
else Security(api_key_header),
):
# 1. Check if path is in whitelist
path = request.url.path
@@ -106,11 +104,15 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
is_special_endpoint = path == "/health" or path.startswith("/api/")
if is_special_endpoint and not api_key_configured:
return # Special endpoint and no API key configured, allow access
# 3. Validate API key
if api_key_configured and api_key_header_value and api_key_header_value == api_key:
if (
api_key_configured
and api_key_header_value
and api_key_header_value == api_key
):
return # API key validation successful
# 4. Validate token
if token:
try:
@@ -121,34 +123,33 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
# Accept non-guest token if auth is configured
if auth_configured and token_info.get("role") != "guest":
return
# Token validation failed, immediately return 401 error
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token. Please login again."
detail="Invalid token. Please login again.",
)
except HTTPException as e:
# If already a 401 error, re-raise it
if e.status_code == status.HTTP_401_UNAUTHORIZED:
raise
# For other exceptions, continue processing
# If token exists but validation failed (didn't return above), return 401
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token. Please login again."
detail="Invalid token. Please login again.",
)
# 5. No token and API key validation failed, return 403 error
if api_key_configured:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="API Key required or login authentication required."
detail="API Key required or login authentication required.",
)
else:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Login authentication required."
status_code=HTTP_403_FORBIDDEN, detail="Login authentication required."
)
return combined_dependency
@@ -183,7 +184,9 @@ def get_api_key_dependency(api_key: Optional[str]):
async def api_key_auth(
request: Request,
api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
api_key_header_value: Optional[str] = Security(
api_key_header, description="API Key for authentication"
),
):
# Check if request path is in whitelist
path = request.url.path