feat(api): Add user authentication functionality

- Implement JWT-based user authentication logic
- Add login endpoint and token validation middleware
- Update API routes with authentication dependencies
- Add authentication-related environment variables
- Optimize requirements.txt with necessary dependencies
This commit is contained in:
Milin
2025-03-05 11:09:31 +08:00
parent fb07bc04a0
commit 63aa4f9dfc
10 changed files with 203 additions and 33 deletions

View File

@@ -5,6 +5,9 @@ LightRAG FastAPI Server
from fastapi import (
FastAPI,
Depends,
HTTPException,
Request,
status
)
from fastapi.responses import FileResponse
import asyncio
@@ -25,6 +28,7 @@ from .utils_api import (
parse_args,
get_default_host,
display_splash_screen,
get_auth_dependency,
)
from lightrag import LightRAG
from lightrag.types import GPTKeywordExtractionFormat
@@ -46,6 +50,8 @@ from lightrag.kg.shared_storage import (
initialize_pipeline_status,
get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)
@@ -373,7 +379,29 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.get("/health", dependencies=[Depends(optional_api_key)])
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authentication not configured"
)
if form_data.username != username or form_data.password != password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect credentials"
)
return {
"access_token": auth_handler.create_token(username),
"token_type": "bearer"
}
@app.get("/health", dependencies=[Depends(optional_api_key), Depends(get_auth_dependency())])
async def get_status():
"""Get current system status"""
# Get update flags status for all namespaces
@@ -414,6 +442,12 @@ def create_app(args):
async def webui_root():
return FileResponse(static_dir / "index.html")
@app.middleware("http")
async def debug_middleware(request: Request, call_next):
print(f"Request path: {request.url.path}")
response = await call_next(request)
return response
return app