feat(api): Add user authentication functionality
- Implement JWT-based user authentication logic - Add login endpoint and token validation middleware - Update API routes with authentication dependencies - Add authentication-related environment variables - Optimize requirements.txt with necessary dependencies
This commit is contained in:
@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
|
|||||||
|
|
||||||
### Redis
|
### Redis
|
||||||
REDIS_URI=redis://localhost:6379
|
REDIS_URI=redis://localhost:6379
|
||||||
|
|
||||||
|
# For jwt auth
|
||||||
|
AUTH_USERNAME=admin # login name
|
||||||
|
AUTH_PASSWORD=admin123 # password
|
||||||
|
TOKEN_SECRET=your-key # JWT key
|
||||||
|
TOKEN_EXPIRE_HOURS=4 # expire duration
|
||||||
|
WHITELIST_PATHS=/login,/health # white list
|
||||||
|
@@ -296,7 +296,7 @@ You can not change storage implementation selection after you add documents to L
|
|||||||
### LightRag API Server Comand Line Options
|
### LightRag API Server Comand Line Options
|
||||||
|
|
||||||
| Parameter | Default | Description |
|
| Parameter | Default | Description |
|
||||||
|-----------|---------|-------------|
|
|-------------------------|----------------|-----------------------------------------------------------------------------------------------------------------------------|
|
||||||
| --host | 0.0.0.0 | Server host |
|
| --host | 0.0.0.0 | Server host |
|
||||||
| --port | 9621 | Server port |
|
| --port | 9621 | Server port |
|
||||||
| --working-dir | ./rag_storage | Working directory for RAG storage |
|
| --working-dir | ./rag_storage | Working directory for RAG storage |
|
||||||
@@ -314,7 +314,13 @@ You can not change storage implementation selection after you add documents to L
|
|||||||
| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
|
| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
|
||||||
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
|
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
|
||||||
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
|
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
|
||||||
| auto-scan-at-startup | - | Scan input directory for new files and start indexing |
|
| --auto-scan-at-startup | - | Scan input directory for new files and start indexing |
|
||||||
|
| --auth-username | - | Enable jwt if not empty |
|
||||||
|
| --auth-password | - | Enable jwt if not empty |
|
||||||
|
| --token-secret | - | JWT key |
|
||||||
|
| --token-expire-hours | 4 | expire duration |
|
||||||
|
| --whitelist-paths | /login,/health | white list |
|
||||||
|
|
||||||
|
|
||||||
### Example Usage
|
### Example Usage
|
||||||
|
|
||||||
@@ -387,6 +393,19 @@ Note: If you don't need the API functionality, you can install the base package
|
|||||||
pip install lightrag-hku
|
pip install lightrag-hku
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Authentication Endpoints
|
||||||
|
|
||||||
|
### JWT Authentication Mechanism
|
||||||
|
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
|
||||||
|
```bash
|
||||||
|
# For jwt auth
|
||||||
|
AUTH_USERNAME=admin # login name --auth-username
|
||||||
|
AUTH_PASSWORD=admin123 # password --auth-password
|
||||||
|
TOKEN_SECRET=your-key # JWT key --token-secret
|
||||||
|
TOKEN_EXPIRE_HOURS=4 # expire duration --token-expire-hours
|
||||||
|
WHITELIST_PATHS=/login,/health # white list --whitelist-paths
|
||||||
|
```
|
||||||
|
|
||||||
## API Endpoints
|
## API Endpoints
|
||||||
|
|
||||||
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
|
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
|
||||||
|
43
lightrag/api/auth.py
Normal file
43
lightrag/api/auth.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import jwt
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
sub: str
|
||||||
|
exp: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class AuthHandler:
|
||||||
|
def __init__(self):
|
||||||
|
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
|
||||||
|
self.algorithm = "HS256"
|
||||||
|
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
|
||||||
|
|
||||||
|
def create_token(self, username: str) -> str:
|
||||||
|
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
|
||||||
|
payload = TokenPayload(sub=username, exp=expire)
|
||||||
|
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
|
||||||
|
|
||||||
|
def validate_token(self, token: str) -> str:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||||
|
expire_timestamp = payload["exp"]
|
||||||
|
expire_time = datetime.utcfromtimestamp(expire_timestamp)
|
||||||
|
|
||||||
|
if datetime.utcnow() > expire_time:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token expired"
|
||||||
|
)
|
||||||
|
return payload["sub"]
|
||||||
|
except jwt.PyJWTError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
auth_handler = AuthHandler()
|
@@ -5,6 +5,9 @@ LightRAG FastAPI Server
|
|||||||
from fastapi import (
|
from fastapi import (
|
||||||
FastAPI,
|
FastAPI,
|
||||||
Depends,
|
Depends,
|
||||||
|
HTTPException,
|
||||||
|
Request,
|
||||||
|
status
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -25,6 +28,7 @@ from .utils_api import (
|
|||||||
parse_args,
|
parse_args,
|
||||||
get_default_host,
|
get_default_host,
|
||||||
display_splash_screen,
|
display_splash_screen,
|
||||||
|
get_auth_dependency,
|
||||||
)
|
)
|
||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
from lightrag.types import GPTKeywordExtractionFormat
|
from lightrag.types import GPTKeywordExtractionFormat
|
||||||
@@ -46,6 +50,8 @@ from lightrag.kg.shared_storage import (
|
|||||||
initialize_pipeline_status,
|
initialize_pipeline_status,
|
||||||
get_all_update_flags_status,
|
get_all_update_flags_status,
|
||||||
)
|
)
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from .auth import auth_handler
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
@@ -373,7 +379,29 @@ def create_app(args):
|
|||||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||||
app.include_router(ollama_api.router, prefix="/api")
|
app.include_router(ollama_api.router, prefix="/api")
|
||||||
|
|
||||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
@app.post("/login")
|
||||||
|
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||||
|
username = os.getenv("AUTH_USERNAME")
|
||||||
|
password = os.getenv("AUTH_PASSWORD")
|
||||||
|
|
||||||
|
if not (username and password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Authentication not configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
if form_data.username != username or form_data.password != password:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": auth_handler.create_token(username),
|
||||||
|
"token_type": "bearer"
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/health", dependencies=[Depends(optional_api_key), Depends(get_auth_dependency())])
|
||||||
async def get_status():
|
async def get_status():
|
||||||
"""Get current system status"""
|
"""Get current system status"""
|
||||||
# Get update flags status for all namespaces
|
# Get update flags status for all namespaces
|
||||||
@@ -414,6 +442,12 @@ def create_app(args):
|
|||||||
async def webui_root():
|
async def webui_root():
|
||||||
return FileResponse(static_dir / "index.html")
|
return FileResponse(static_dir / "index.html")
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def debug_middleware(request: Request, call_next):
|
||||||
|
print(f"Request path: {request.url.path}")
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@@ -8,3 +8,15 @@ python-multipart
|
|||||||
tenacity
|
tenacity
|
||||||
tiktoken
|
tiktoken
|
||||||
uvicorn
|
uvicorn
|
||||||
|
tqdm
|
||||||
|
jiter
|
||||||
|
httpcore
|
||||||
|
distro
|
||||||
|
httpx
|
||||||
|
openai
|
||||||
|
asyncpg
|
||||||
|
neo4j
|
||||||
|
pytz
|
||||||
|
python-jose[cryptography]
|
||||||
|
passlib[bcrypt]
|
||||||
|
PyJWT
|
@@ -16,10 +16,9 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
|
|
||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
from lightrag.base import DocProcessingStatus, DocStatus
|
from lightrag.base import DocProcessingStatus, DocStatus
|
||||||
from ..utils_api import get_api_key_dependency
|
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/documents", tags=["documents"], dependencies=[Depends(get_auth_dependency())])
|
||||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
|
||||||
|
|
||||||
# Temporary file prefix
|
# Temporary file prefix
|
||||||
temp_prefix = "__tmp__"
|
temp_prefix = "__tmp__"
|
||||||
|
@@ -6,9 +6,9 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from ..utils_api import get_api_key_dependency
|
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||||
|
|
||||||
router = APIRouter(tags=["graph"])
|
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
|
||||||
|
|
||||||
|
|
||||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
import logging
|
import logging
|
||||||
@@ -11,7 +11,7 @@ import asyncio
|
|||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.utils import encode_string_by_tiktoken
|
from lightrag.utils import encode_string_by_tiktoken
|
||||||
from ..utils_api import ollama_server_infos
|
from ..utils_api import ollama_server_infos, get_auth_dependency
|
||||||
|
|
||||||
|
|
||||||
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
||||||
@@ -126,7 +126,7 @@ class OllamaAPI:
|
|||||||
self.rag = rag
|
self.rag = rag
|
||||||
self.ollama_server_infos = ollama_server_infos
|
self.ollama_server_infos = ollama_server_infos
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.router = APIRouter(tags=["ollama"])
|
self.router = APIRouter(tags=["ollama"], dependencies=[Depends(get_auth_dependency())])
|
||||||
self.setup_routes()
|
self.setup_routes()
|
||||||
|
|
||||||
def setup_routes(self):
|
def setup_routes(self):
|
||||||
|
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from lightrag.base import QueryParam
|
from lightrag.base import QueryParam
|
||||||
from ..utils_api import get_api_key_dependency
|
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
|
|
||||||
router = APIRouter(tags=["query"])
|
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
class QueryRequest(BaseModel):
|
||||||
|
@@ -9,10 +9,16 @@ import sys
|
|||||||
import logging
|
import logging
|
||||||
from ascii_colors import ASCIIColors
|
from ascii_colors import ASCIIColors
|
||||||
from lightrag.api import __api_version__
|
from lightrag.api import __api_version__
|
||||||
from fastapi import HTTPException, Security
|
from fastapi import (
|
||||||
|
HTTPException,
|
||||||
|
Security,
|
||||||
|
Depends,
|
||||||
|
Request
|
||||||
|
)
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||||
from starlette.status import HTTP_403_FORBIDDEN
|
from starlette.status import HTTP_403_FORBIDDEN
|
||||||
|
from .auth import auth_handler
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
@@ -31,6 +37,24 @@ class OllamaServerInfos:
|
|||||||
ollama_server_infos = OllamaServerInfos()
|
ollama_server_infos = OllamaServerInfos()
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_dependency():
|
||||||
|
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
request: Request,
|
||||||
|
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False))
|
||||||
|
):
|
||||||
|
if request.url.path in whitelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
|
||||||
|
return
|
||||||
|
|
||||||
|
auth_handler.validate_token(token)
|
||||||
|
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_dependency(api_key: Optional[str]):
|
def get_api_key_dependency(api_key: Optional[str]):
|
||||||
"""
|
"""
|
||||||
Create an API key dependency for route protection.
|
Create an API key dependency for route protection.
|
||||||
@@ -288,6 +312,38 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
|||||||
help="Embedding binding type (default: from env or ollama)",
|
help="Embedding binding type (default: from env or ollama)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Authentication configuration
|
||||||
|
parser.add_argument(
|
||||||
|
"--auth-username",
|
||||||
|
type=str,
|
||||||
|
default=get_env_value("AUTH_USERNAME", ""),
|
||||||
|
help="Login username (default: from env or empty)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--auth-password",
|
||||||
|
type=str,
|
||||||
|
default=get_env_value("AUTH_PASSWORD", ""),
|
||||||
|
help="Login password (default: from env or empty)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token-secret",
|
||||||
|
type=str,
|
||||||
|
default=get_env_value("TOKEN_SECRET", ""),
|
||||||
|
help="JWT signing secret (default: from env or empty)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token-expire-hours",
|
||||||
|
type=int,
|
||||||
|
default=get_env_value("TOKEN_EXPIRE_HOURS", 4, int),
|
||||||
|
help="Token validity in hours (default: from env or 4)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--whitelist-paths",
|
||||||
|
type=str,
|
||||||
|
default=get_env_value("WHITELIST_PATHS", "/login,/health"),
|
||||||
|
help="Comma-separated auth-exempt paths (default: from env or /login,/health)"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
||||||
|
Reference in New Issue
Block a user