Fix linting
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
LightRAG FastAPI Server
|
LightRAG FastAPI Server
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import FastAPI, Depends, HTTPException, status, Request
|
from fastapi import FastAPI, Depends, HTTPException, status
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
@@ -18,7 +18,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from lightrag.api.utils_api import (
|
from lightrag.api.utils_api import (
|
||||||
get_api_key_dependency,
|
|
||||||
get_combined_auth_dependency,
|
get_combined_auth_dependency,
|
||||||
parse_args,
|
parse_args,
|
||||||
get_default_host,
|
get_default_host,
|
||||||
@@ -149,14 +148,14 @@ def create_app(args):
|
|||||||
"openapi_tags": [{"name": "api"}],
|
"openapi_tags": [{"name": "api"}],
|
||||||
"lifespan": lifespan,
|
"lifespan": lifespan,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Configure Swagger UI parameters
|
# Configure Swagger UI parameters
|
||||||
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
||||||
app_kwargs["swagger_ui_parameters"] = {
|
app_kwargs["swagger_ui_parameters"] = {
|
||||||
"persistAuthorization": True,
|
"persistAuthorization": True,
|
||||||
"tryItOutEnabled": True,
|
"tryItOutEnabled": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
app = FastAPI(**app_kwargs)
|
app = FastAPI(**app_kwargs)
|
||||||
|
|
||||||
def get_cors_origins():
|
def get_cors_origins():
|
||||||
|
@@ -808,14 +808,14 @@ def create_document_routes(
|
|||||||
|
|
||||||
# Get update flags status for all namespaces
|
# Get update flags status for all namespaces
|
||||||
update_status = await get_all_update_flags_status()
|
update_status = await get_all_update_flags_status()
|
||||||
|
|
||||||
# Convert MutableBoolean objects to regular boolean values
|
# Convert MutableBoolean objects to regular boolean values
|
||||||
processed_update_status = {}
|
processed_update_status = {}
|
||||||
for namespace, flags in update_status.items():
|
for namespace, flags in update_status.items():
|
||||||
processed_flags = []
|
processed_flags = []
|
||||||
for flag in flags:
|
for flag in flags:
|
||||||
# Handle both multiprocess and single process cases
|
# Handle both multiprocess and single process cases
|
||||||
if hasattr(flag, 'value'):
|
if hasattr(flag, "value"):
|
||||||
processed_flags.append(bool(flag.value))
|
processed_flags.append(bool(flag.value))
|
||||||
else:
|
else:
|
||||||
processed_flags.append(bool(flag))
|
processed_flags.append(bool(flag))
|
||||||
|
@@ -135,9 +135,7 @@ class OllamaAPI:
|
|||||||
# Create combined auth dependency for Ollama API routes
|
# Create combined auth dependency for Ollama API routes
|
||||||
combined_auth = get_combined_auth_dependency(self.api_key)
|
combined_auth = get_combined_auth_dependency(self.api_key)
|
||||||
|
|
||||||
@self.router.get(
|
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
||||||
"/version", dependencies=[Depends(combined_auth)]
|
|
||||||
)
|
|
||||||
async def get_version():
|
async def get_version():
|
||||||
"""Get Ollama version information"""
|
"""Get Ollama version information"""
|
||||||
return OllamaVersionResponse(version="0.5.4")
|
return OllamaVersionResponse(version="0.5.4")
|
||||||
@@ -165,9 +163,7 @@ class OllamaAPI:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@self.router.post(
|
@self.router.post("/generate", dependencies=[Depends(combined_auth)])
|
||||||
"/generate", dependencies=[Depends(combined_auth)]
|
|
||||||
)
|
|
||||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||||
"""Handle generate completion requests acting as an Ollama model
|
"""Handle generate completion requests acting as an Ollama model
|
||||||
For compatibility purpose, the request is not processed by LightRAG,
|
For compatibility purpose, the request is not processed by LightRAG,
|
||||||
|
@@ -9,7 +9,7 @@ import sys
|
|||||||
import logging
|
import logging
|
||||||
from ascii_colors import ASCIIColors
|
from ascii_colors import ASCIIColors
|
||||||
from lightrag.api import __api_version__
|
from lightrag.api import __api_version__
|
||||||
from fastapi import HTTPException, Security, Depends, Request, status
|
from fastapi import HTTPException, Security, Request, status
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||||
from starlette.status import HTTP_403_FORBIDDEN
|
from starlette.status import HTTP_403_FORBIDDEN
|
||||||
@@ -72,27 +72,25 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|||||||
|
|
||||||
# Only calculate api_key_configured as it depends on the function parameter
|
# Only calculate api_key_configured as it depends on the function parameter
|
||||||
api_key_configured = bool(api_key)
|
api_key_configured = bool(api_key)
|
||||||
|
|
||||||
# Create security dependencies with proper descriptions for Swagger UI
|
# Create security dependencies with proper descriptions for Swagger UI
|
||||||
oauth2_scheme = OAuth2PasswordBearer(
|
oauth2_scheme = OAuth2PasswordBearer(
|
||||||
tokenUrl="login",
|
tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
|
||||||
auto_error=False,
|
|
||||||
description="OAuth2 Password Authentication"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# If API key is configured, create an API key header security
|
# If API key is configured, create an API key header security
|
||||||
api_key_header = None
|
api_key_header = None
|
||||||
if api_key_configured:
|
if api_key_configured:
|
||||||
api_key_header = APIKeyHeader(
|
api_key_header = APIKeyHeader(
|
||||||
name="X-API-Key",
|
name="X-API-Key", auto_error=False, description="API Key Authentication"
|
||||||
auto_error=False,
|
|
||||||
description="API Key Authentication"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def combined_dependency(
|
async def combined_dependency(
|
||||||
request: Request,
|
request: Request,
|
||||||
token: str = Security(oauth2_scheme),
|
token: str = Security(oauth2_scheme),
|
||||||
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
|
api_key_header_value: Optional[str] = None
|
||||||
|
if api_key_header is None
|
||||||
|
else Security(api_key_header),
|
||||||
):
|
):
|
||||||
# 1. Check if path is in whitelist
|
# 1. Check if path is in whitelist
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
@@ -106,11 +104,15 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|||||||
is_special_endpoint = path == "/health" or path.startswith("/api/")
|
is_special_endpoint = path == "/health" or path.startswith("/api/")
|
||||||
if is_special_endpoint and not api_key_configured:
|
if is_special_endpoint and not api_key_configured:
|
||||||
return # Special endpoint and no API key configured, allow access
|
return # Special endpoint and no API key configured, allow access
|
||||||
|
|
||||||
# 3. Validate API key
|
# 3. Validate API key
|
||||||
if api_key_configured and api_key_header_value and api_key_header_value == api_key:
|
if (
|
||||||
|
api_key_configured
|
||||||
|
and api_key_header_value
|
||||||
|
and api_key_header_value == api_key
|
||||||
|
):
|
||||||
return # API key validation successful
|
return # API key validation successful
|
||||||
|
|
||||||
# 4. Validate token
|
# 4. Validate token
|
||||||
if token:
|
if token:
|
||||||
try:
|
try:
|
||||||
@@ -121,34 +123,33 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|||||||
# Accept non-guest token if auth is configured
|
# Accept non-guest token if auth is configured
|
||||||
if auth_configured and token_info.get("role") != "guest":
|
if auth_configured and token_info.get("role") != "guest":
|
||||||
return
|
return
|
||||||
|
|
||||||
# Token validation failed, immediately return 401 error
|
# Token validation failed, immediately return 401 error
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid token. Please login again."
|
detail="Invalid token. Please login again.",
|
||||||
)
|
)
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
# If already a 401 error, re-raise it
|
# If already a 401 error, re-raise it
|
||||||
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||||
raise
|
raise
|
||||||
# For other exceptions, continue processing
|
# For other exceptions, continue processing
|
||||||
|
|
||||||
# If token exists but validation failed (didn't return above), return 401
|
# If token exists but validation failed (didn't return above), return 401
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid token. Please login again."
|
detail="Invalid token. Please login again.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. No token and API key validation failed, return 403 error
|
# 5. No token and API key validation failed, return 403 error
|
||||||
if api_key_configured:
|
if api_key_configured:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTP_403_FORBIDDEN,
|
status_code=HTTP_403_FORBIDDEN,
|
||||||
detail="API Key required or login authentication required."
|
detail="API Key required or login authentication required.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTP_403_FORBIDDEN,
|
status_code=HTTP_403_FORBIDDEN, detail="Login authentication required."
|
||||||
detail="Login authentication required."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return combined_dependency
|
return combined_dependency
|
||||||
@@ -183,7 +184,9 @@ def get_api_key_dependency(api_key: Optional[str]):
|
|||||||
|
|
||||||
async def api_key_auth(
|
async def api_key_auth(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
|
api_key_header_value: Optional[str] = Security(
|
||||||
|
api_key_header, description="API Key for authentication"
|
||||||
|
),
|
||||||
):
|
):
|
||||||
# Check if request path is in whitelist
|
# Check if request path is in whitelist
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
Reference in New Issue
Block a user