Merge pull request #1000 from lcjqyml/feat_login-jwt

feat(api): Add user authentication functionality
This commit is contained in:
zrguo
2025-03-07 12:24:46 +08:00
committed by GitHub
9 changed files with 126 additions and 13 deletions

View File

@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
### Redis
REDIS_URI=redis://localhost:6379
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/login,/health # white list

View File

@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
pip install lightrag-hku
```
## Authentication Endpoints
### JWT Authentication Mechanism
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
```bash
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
```
## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:

41
lightrag/api/auth.py Normal file
View File

@@ -0,0 +1,41 @@
import os
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
class TokenPayload(BaseModel):
sub: str
exp: datetime
class AuthHandler:
def __init__(self):
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
self.algorithm = "HS256"
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
def create_token(self, username: str) -> str:
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
payload = TokenPayload(sub=username, exp=expire)
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
def validate_token(self, token: str) -> str:
try:
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
expire_timestamp = payload["exp"]
expire_time = datetime.utcfromtimestamp(expire_timestamp)
if datetime.utcnow() > expire_time:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
)
return payload["sub"]
except jwt.PyJWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
auth_handler = AuthHandler()

View File

@@ -2,10 +2,7 @@
LightRAG FastAPI Server
"""
from fastapi import (
FastAPI,
Depends,
)
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
initialize_pipeline_status,
get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
# Load environment variables
# Updated to use the .env that is inside the current folder
@@ -372,6 +371,27 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authentication not configured",
)
if form_data.username != username or form_data.password != password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
return {
"access_token": auth_handler.create_token(username),
"token_type": "bearer",
}
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""

View File

@@ -1,10 +1,20 @@
aiofiles
ascii_colors
asyncpg
distro
fastapi
httpcore
httpx
jiter
numpy
openai
passlib[bcrypt]
pipmaster
PyJWT
python-dotenv
python-jose[cryptography]
python-multipart
pytz
tenacity
tiktoken
uvicorn

View File

@@ -16,10 +16,13 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(prefix="/documents", tags=["documents"])
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(get_auth_dependency())],
)
# Temporary file prefix
temp_prefix = "__tmp__"

View File

@@ -6,9 +6,9 @@ from typing import Optional
from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(tags=["graph"])
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
def create_graph_routes(rag, api_key: Optional[str] = None):

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
class QueryRequest(BaseModel):

View File

@@ -9,10 +9,11 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
from fastapi import HTTPException, Security, Depends, Request
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)
@@ -31,6 +32,24 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos()
def get_auth_dependency():
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
async def dependency(
request: Request,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
):
if request.url.path in whitelist:
return
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
return
auth_handler.validate_token(token)
return dependency
def get_api_key_dependency(api_key: Optional[str]):
"""
Create an API key dependency for route protection.