Merge branch 'main'

This commit is contained in:
Milin
2025-03-25 15:57:14 +08:00
36 changed files with 2780 additions and 914 deletions

View File

@@ -18,7 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_api_key_dependency,
get_combined_auth_dependency,
parse_args,
get_default_host,
display_splash_screen,
@@ -41,7 +41,6 @@ from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
initialize_pipeline_status,
get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
@@ -136,19 +135,28 @@ def create_app(args):
await rag.finalize_storages()
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
app_kwargs = {
"title": "LightRAG Server API",
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
+ "(With authentication)"
if api_key
else "",
version=__api_version__,
openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL
docs_url="/docs", # Explicitly set docs URL
redoc_url="/redoc", # Explicitly set redoc URL
openapi_tags=[{"name": "api"}],
lifespan=lifespan,
)
"version": __api_version__,
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
"docs_url": "/docs", # Explicitly set docs URL
"redoc_url": "/redoc", # Explicitly set redoc URL
"openapi_tags": [{"name": "api"}],
"lifespan": lifespan,
}
# Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True,
"tryItOutEnabled": True,
}
app = FastAPI(**app_kwargs)
def get_cors_origins():
"""Get allowed origins from environment variable
@@ -168,8 +176,8 @@ def create_app(args):
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -200,6 +208,7 @@ def create_app(args):
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
kwargs["temperature"] = args.temperature
return await openai_complete_if_cache(
args.llm_model,
prompt,
@@ -222,6 +231,7 @@ def create_app(args):
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
kwargs["temperature"] = args.temperature
return await azure_openai_complete_if_cache(
args.llm_model,
prompt,
@@ -302,6 +312,7 @@ def create_app(args):
},
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
)
else: # azure_openai
rag = LightRAG(
@@ -331,6 +342,7 @@ def create_app(args):
},
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
)
# Add routes
@@ -339,7 +351,7 @@ def create_app(args):
app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k)
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api")
@app.get("/")
@@ -347,7 +359,7 @@ def create_app(args):
"""Redirect root path to /webui"""
return RedirectResponse(url="/webui")
@app.get("/auth-status", dependencies=[Depends(optional_api_key)])
@app.get("/auth-status")
async def get_auth_status():
"""Get authentication status and guest token if auth is not configured"""
@@ -373,7 +385,7 @@ def create_app(args):
"api_version": __api_version__,
}
@app.post("/login", dependencies=[Depends(optional_api_key)])
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
if not auth_handler.accounts:
# Authentication not configured, return guest token
@@ -406,12 +418,9 @@ def create_app(args):
"api_version": __api_version__,
}
@app.get("/health", dependencies=[Depends(optional_api_key)])
@app.get("/health", dependencies=[Depends(combined_auth)])
async def get_status():
"""Get current system status"""
# Get update flags status for all namespaces
update_status = await get_all_update_flags_status()
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
@@ -439,7 +448,6 @@ def create_app(args):
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
},
"update_status": update_status,
"core_version": core_version,
"api_version": __api_version__,
"auth_mode": auth_mode,