Fix linting

This commit is contained in:
yangdx
2025-03-24 14:30:17 +08:00
parent 9e3994419e
commit d05cf286f4
4 changed files with 32 additions and 34 deletions

View File

@@ -2,7 +2,7 @@
LightRAG FastAPI Server LightRAG FastAPI Server
""" """
from fastapi import FastAPI, Depends, HTTPException, status, Request from fastapi import FastAPI, Depends, HTTPException, status
import asyncio import asyncio
import os import os
import logging import logging
@@ -18,7 +18,6 @@ from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dotenv import load_dotenv from dotenv import load_dotenv
from lightrag.api.utils_api import ( from lightrag.api.utils_api import (
get_api_key_dependency,
get_combined_auth_dependency, get_combined_auth_dependency,
parse_args, parse_args,
get_default_host, get_default_host,
@@ -149,14 +148,14 @@ def create_app(args):
"openapi_tags": [{"name": "api"}], "openapi_tags": [{"name": "api"}],
"lifespan": lifespan, "lifespan": lifespan,
} }
# Configure Swagger UI parameters # Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience # Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = { app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True, "persistAuthorization": True,
"tryItOutEnabled": True, "tryItOutEnabled": True,
} }
app = FastAPI(**app_kwargs) app = FastAPI(**app_kwargs)
def get_cors_origins(): def get_cors_origins():

View File

@@ -808,14 +808,14 @@ def create_document_routes(
# Get update flags status for all namespaces # Get update flags status for all namespaces
update_status = await get_all_update_flags_status() update_status = await get_all_update_flags_status()
# Convert MutableBoolean objects to regular boolean values # Convert MutableBoolean objects to regular boolean values
processed_update_status = {} processed_update_status = {}
for namespace, flags in update_status.items(): for namespace, flags in update_status.items():
processed_flags = [] processed_flags = []
for flag in flags: for flag in flags:
# Handle both multiprocess and single process cases # Handle both multiprocess and single process cases
if hasattr(flag, 'value'): if hasattr(flag, "value"):
processed_flags.append(bool(flag.value)) processed_flags.append(bool(flag.value))
else: else:
processed_flags.append(bool(flag)) processed_flags.append(bool(flag))

View File

@@ -135,9 +135,7 @@ class OllamaAPI:
# Create combined auth dependency for Ollama API routes # Create combined auth dependency for Ollama API routes
combined_auth = get_combined_auth_dependency(self.api_key) combined_auth = get_combined_auth_dependency(self.api_key)
@self.router.get( @self.router.get("/version", dependencies=[Depends(combined_auth)])
"/version", dependencies=[Depends(combined_auth)]
)
async def get_version(): async def get_version():
"""Get Ollama version information""" """Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4") return OllamaVersionResponse(version="0.5.4")
@@ -165,9 +163,7 @@ class OllamaAPI:
] ]
) )
@self.router.post( @self.router.post("/generate", dependencies=[Depends(combined_auth)])
"/generate", dependencies=[Depends(combined_auth)]
)
async def generate(raw_request: Request, request: OllamaGenerateRequest): async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model """Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG, For compatibility purpose, the request is not processed by LightRAG,

View File

@@ -9,7 +9,7 @@ import sys
import logging import logging
from ascii_colors import ASCIIColors from ascii_colors import ASCIIColors
from lightrag.api import __api_version__ from lightrag.api import __api_version__
from fastapi import HTTPException, Security, Depends, Request, status from fastapi import HTTPException, Security, Request, status
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
@@ -72,27 +72,25 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
# Only calculate api_key_configured as it depends on the function parameter # Only calculate api_key_configured as it depends on the function parameter
api_key_configured = bool(api_key) api_key_configured = bool(api_key)
# Create security dependencies with proper descriptions for Swagger UI # Create security dependencies with proper descriptions for Swagger UI
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="login", tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
auto_error=False,
description="OAuth2 Password Authentication"
) )
# If API key is configured, create an API key header security # If API key is configured, create an API key header security
api_key_header = None api_key_header = None
if api_key_configured: if api_key_configured:
api_key_header = APIKeyHeader( api_key_header = APIKeyHeader(
name="X-API-Key", name="X-API-Key", auto_error=False, description="API Key Authentication"
auto_error=False,
description="API Key Authentication"
) )
async def combined_dependency( async def combined_dependency(
request: Request, request: Request,
token: str = Security(oauth2_scheme), token: str = Security(oauth2_scheme),
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header), api_key_header_value: Optional[str] = None
if api_key_header is None
else Security(api_key_header),
): ):
# 1. Check if path is in whitelist # 1. Check if path is in whitelist
path = request.url.path path = request.url.path
@@ -106,11 +104,15 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
is_special_endpoint = path == "/health" or path.startswith("/api/") is_special_endpoint = path == "/health" or path.startswith("/api/")
if is_special_endpoint and not api_key_configured: if is_special_endpoint and not api_key_configured:
return # Special endpoint and no API key configured, allow access return # Special endpoint and no API key configured, allow access
# 3. Validate API key # 3. Validate API key
if api_key_configured and api_key_header_value and api_key_header_value == api_key: if (
api_key_configured
and api_key_header_value
and api_key_header_value == api_key
):
return # API key validation successful return # API key validation successful
# 4. Validate token # 4. Validate token
if token: if token:
try: try:
@@ -121,34 +123,33 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
# Accept non-guest token if auth is configured # Accept non-guest token if auth is configured
if auth_configured and token_info.get("role") != "guest": if auth_configured and token_info.get("role") != "guest":
return return
# Token validation failed, immediately return 401 error # Token validation failed, immediately return 401 error
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token. Please login again." detail="Invalid token. Please login again.",
) )
except HTTPException as e: except HTTPException as e:
# If already a 401 error, re-raise it # If already a 401 error, re-raise it
if e.status_code == status.HTTP_401_UNAUTHORIZED: if e.status_code == status.HTTP_401_UNAUTHORIZED:
raise raise
# For other exceptions, continue processing # For other exceptions, continue processing
# If token exists but validation failed (didn't return above), return 401 # If token exists but validation failed (didn't return above), return 401
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token. Please login again." detail="Invalid token. Please login again.",
) )
# 5. No token and API key validation failed, return 403 error # 5. No token and API key validation failed, return 403 error
if api_key_configured: if api_key_configured:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_403_FORBIDDEN,
detail="API Key required or login authentication required." detail="API Key required or login authentication required.",
) )
else: else:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, status_code=HTTP_403_FORBIDDEN, detail="Login authentication required."
detail="Login authentication required."
) )
return combined_dependency return combined_dependency
@@ -183,7 +184,9 @@ def get_api_key_dependency(api_key: Optional[str]):
async def api_key_auth( async def api_key_auth(
request: Request, request: Request,
api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"), api_key_header_value: Optional[str] = Security(
api_key_header, description="API Key for authentication"
),
): ):
# Check if request path is in whitelist # Check if request path is in whitelist
path = request.url.path path = request.url.path