Merge branch 'main'
This commit is contained in:
@@ -18,7 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.api.utils_api import (
|
||||
get_api_key_dependency,
|
||||
get_combined_auth_dependency,
|
||||
parse_args,
|
||||
get_default_host,
|
||||
display_splash_screen,
|
||||
@@ -41,7 +41,6 @@ from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_pipeline_status_lock,
|
||||
initialize_pipeline_status,
|
||||
get_all_update_flags_status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from .auth import auth_handler
|
||||
@@ -136,19 +135,28 @@ def create_app(args):
|
||||
await rag.finalize_storages()
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="LightRAG API",
|
||||
description="API for querying text using LightRAG with separate storage and input directories"
|
||||
app_kwargs = {
|
||||
"title": "LightRAG Server API",
|
||||
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
|
||||
+ "(With authentication)"
|
||||
if api_key
|
||||
else "",
|
||||
version=__api_version__,
|
||||
openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL
|
||||
docs_url="/docs", # Explicitly set docs URL
|
||||
redoc_url="/redoc", # Explicitly set redoc URL
|
||||
openapi_tags=[{"name": "api"}],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
"version": __api_version__,
|
||||
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
|
||||
"docs_url": "/docs", # Explicitly set docs URL
|
||||
"redoc_url": "/redoc", # Explicitly set redoc URL
|
||||
"openapi_tags": [{"name": "api"}],
|
||||
"lifespan": lifespan,
|
||||
}
|
||||
|
||||
# Configure Swagger UI parameters
|
||||
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
||||
app_kwargs["swagger_ui_parameters"] = {
|
||||
"persistAuthorization": True,
|
||||
"tryItOutEnabled": True,
|
||||
}
|
||||
|
||||
app = FastAPI(**app_kwargs)
|
||||
|
||||
def get_cors_origins():
|
||||
"""Get allowed origins from environment variable
|
||||
@@ -168,8 +176,8 @@ def create_app(args):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
# Create combined auth dependency for all endpoints
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
@@ -200,6 +208,7 @@ def create_app(args):
|
||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
kwargs["temperature"] = args.temperature
|
||||
return await openai_complete_if_cache(
|
||||
args.llm_model,
|
||||
prompt,
|
||||
@@ -222,6 +231,7 @@ def create_app(args):
|
||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
kwargs["temperature"] = args.temperature
|
||||
return await azure_openai_complete_if_cache(
|
||||
args.llm_model,
|
||||
prompt,
|
||||
@@ -302,6 +312,7 @@ def create_app(args):
|
||||
},
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
)
|
||||
else: # azure_openai
|
||||
rag = LightRAG(
|
||||
@@ -331,6 +342,7 @@ def create_app(args):
|
||||
},
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
)
|
||||
|
||||
# Add routes
|
||||
@@ -339,7 +351,7 @@ def create_app(args):
|
||||
app.include_router(create_graph_routes(rag, api_key))
|
||||
|
||||
# Add Ollama API routes
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
|
||||
app.include_router(ollama_api.router, prefix="/api")
|
||||
|
||||
@app.get("/")
|
||||
@@ -347,7 +359,7 @@ def create_app(args):
|
||||
"""Redirect root path to /webui"""
|
||||
return RedirectResponse(url="/webui")
|
||||
|
||||
@app.get("/auth-status", dependencies=[Depends(optional_api_key)])
|
||||
@app.get("/auth-status")
|
||||
async def get_auth_status():
|
||||
"""Get authentication status and guest token if auth is not configured"""
|
||||
|
||||
@@ -373,7 +385,7 @@ def create_app(args):
|
||||
"api_version": __api_version__,
|
||||
}
|
||||
|
||||
@app.post("/login", dependencies=[Depends(optional_api_key)])
|
||||
@app.post("/login")
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
if not auth_handler.accounts:
|
||||
# Authentication not configured, return guest token
|
||||
@@ -406,12 +418,9 @@ def create_app(args):
|
||||
"api_version": __api_version__,
|
||||
}
|
||||
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
@app.get("/health", dependencies=[Depends(combined_auth)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
# Get update flags status for all namespaces
|
||||
update_status = await get_all_update_flags_status()
|
||||
|
||||
username = os.getenv("AUTH_USERNAME")
|
||||
password = os.getenv("AUTH_PASSWORD")
|
||||
if not (username and password):
|
||||
@@ -439,7 +448,6 @@ def create_app(args):
|
||||
"vector_storage": args.vector_storage,
|
||||
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
||||
},
|
||||
"update_status": update_status,
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"auth_mode": auth_mode,
|
||||
|
Reference in New Issue
Block a user