feat(api): Add user authentication functionality

- Implement JWT-based user authentication logic
- Add login endpoint and token validation middleware
- Update API routes with authentication dependencies
- Add authentication-related environment variables
- Optimize requirements.txt with necessary dependencies
This commit is contained in:
Milin
2025-03-05 11:09:31 +08:00
parent fb07bc04a0
commit 63aa4f9dfc
10 changed files with 203 additions and 33 deletions

View File

@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
### Redis ### Redis
REDIS_URI=redis://localhost:6379 REDIS_URI=redis://localhost:6379
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/login,/health # white list

View File

@@ -295,26 +295,32 @@ You can not change storage implementation selection after you add documents to L
### LightRag API Server Command Line Options ### LightRag API Server Command Line Options
| Parameter | Default | Description | | Parameter | Default | Description |
|-----------|---------|-------------| |-------------------------|----------------|-----------------------------------------------------------------------------------------------------------------------------|
| --host | 0.0.0.0 | Server host | | --host | 0.0.0.0 | Server host |
| --port | 9621 | Server port | | --port | 9621 | Server port |
| --working-dir | ./rag_storage | Working directory for RAG storage | | --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents | | --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum async operations | | --max-async | 4 | Maximum async operations |
| --max-tokens | 32768 | Maximum token size | | --max-tokens | 32768 | Maximum token size |
| --timeout | 150 | Timeout in seconds. None for infinite timeout(not recommended) | | --timeout | 150 | Timeout in seconds. None for infinite timeout(not recommended) |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) | | --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --verbose | - | Verbose debug output (True, False) | | --verbose | - | Verbose debug output (True, False) |
| --key | None | API key for authentication. Protects lightrag server against unauthorized access | | --key | None | API key for authentication. Protects lightrag server against unauthorized access |
| --ssl | False | Enable HTTPS | | --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) | | --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. | | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
| --cosine-threshold | 0.4 | The cosine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. | | --cosine-threshold | 0.4 | The cosine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) | | --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) | | --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
| auto-scan-at-startup | - | Scan input directory for new files and start indexing | | --auto-scan-at-startup | - | Scan input directory for new files and start indexing |
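| --auth-username | - | Login username. JWT authentication is enabled only when both the username and the password are set |
| --auth-password | - | Login password. JWT authentication is enabled only when both the username and the password are set |
| --token-secret | - | Secret key used to sign JWT tokens |
| --token-expire-hours | 4 | Token lifetime in hours |
| --whitelist-paths | /login,/health | Comma-separated list of paths that are exempt from JWT authentication |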
### Example Usage ### Example Usage
@@ -387,6 +393,19 @@ Note: If you don't need the API functionality, you can install the base package
pip install lightrag-hku pip install lightrag-hku
``` ```
## Authentication Endpoints
### JWT Authentication Mechanism
LightRAG API Server implements JWT-based authentication using the HS256 algorithm. Authentication is enabled when both `AUTH_USERNAME` and `AUTH_PASSWORD` are set; the related settings are supplied through the following environment variables:
```bash
# For jwt auth
AUTH_USERNAME=admin # login name --auth-username
AUTH_PASSWORD=admin123 # password --auth-password
TOKEN_SECRET=your-key # JWT key --token-secret
TOKEN_EXPIRE_HOURS=4 # expire duration --token-expire-hours
WHITELIST_PATHS=/login,/health # white list --whitelist-paths
```
## API Endpoints ## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit: All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:

43
lightrag/api/auth.py Normal file
View File

@@ -0,0 +1,43 @@
import os
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
class TokenPayload(BaseModel):
    """Claims carried by an access token: who it was issued to and when it expires."""

    # Subject claim; holds the username the token was issued for.
    sub: str
    # Expiry claim; the moment after which the token is no longer accepted.
    exp: datetime
class AuthHandler:
    """Issue and validate the HS256-signed JWT access tokens used by the API server.

    Settings are taken from the environment:

    * ``TOKEN_SECRET`` - signing key (a built-in fallback is used when unset/empty).
    * ``TOKEN_EXPIRE_HOURS`` - token lifetime in hours (default 4).

    They are resolved on every access rather than once in ``__init__``. An
    instance may be created at import time, before the importing module has
    called ``load_dotenv()``; reading eagerly would then silently ignore values
    that only exist in the ``.env`` file. Assigning to ``secret`` or
    ``expire_hours`` still works and pins the value for that instance.
    """

    # SECURITY NOTE(review): this fallback key is committed to the repository, so
    # anyone can forge tokens for a deployment that leaves TOKEN_SECRET unset.
    # It is kept only for backward compatibility; always configure TOKEN_SECRET.
    _FALLBACK_SECRET = "4f85ds4f56dsf46"
    _DEFAULT_EXPIRE_HOURS = 4

    def __init__(self):
        self.algorithm = "HS256"
        # None means "not overridden": fall through to the environment on access.
        self._secret = None
        self._expire_hours = None

    @property
    def secret(self) -> str:
        """Signing key: explicit override, else ``TOKEN_SECRET``, else the fallback."""
        if self._secret is not None:
            return self._secret
        # `or` treats an empty TOKEN_SECRET= line the same as an unset variable.
        return os.getenv("TOKEN_SECRET") or self._FALLBACK_SECRET

    @secret.setter
    def secret(self, value: str) -> None:
        self._secret = value

    @property
    def expire_hours(self) -> int:
        """Token lifetime in hours: explicit override, else ``TOKEN_EXPIRE_HOURS``, else 4.

        Raises:
            ValueError: if ``TOKEN_EXPIRE_HOURS`` is set to a non-integer value.
        """
        if self._expire_hours is not None:
            return self._expire_hours
        raw_hours = os.getenv("TOKEN_EXPIRE_HOURS")
        return int(raw_hours) if raw_hours else self._DEFAULT_EXPIRE_HOURS

    @expire_hours.setter
    def expire_hours(self, value: int) -> None:
        self._expire_hours = value

    def create_token(self, username: str) -> str:
        """Return a signed JWT whose ``sub`` is *username* and which expires after
        ``expire_hours`` hours.
        """
        # Local import keeps this block self-contained; the module only imports
        # `datetime` and `timedelta` at the top.
        from datetime import timezone

        # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
        expire = datetime.now(timezone.utc) + timedelta(hours=self.expire_hours)
        # Run the claims through the model for validation, then build the dict by
        # hand so the code does not depend on pydantic's deprecated `.dict()`.
        payload = TokenPayload(sub=username, exp=expire)
        claims = {"sub": payload.sub, "exp": payload.exp}
        return jwt.encode(claims, self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> str:
        """Verify *token* and return the username stored in its ``sub`` claim.

        Raises:
            HTTPException: 401 "Token expired" when the token's ``exp`` has passed;
                401 "Invalid token" for any other verification failure or when the
                ``exp``/``sub`` claims are missing.
        """
        try:
            # jwt.decode checks the signature and the `exp` claim itself, so no
            # manual timestamp comparison is needed afterwards.
            claims = jwt.decode(token, self.secret, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as exc:
            # Must be caught before PyJWTError, which is its base class.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token expired",
            ) from exc
        except jwt.PyJWTError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            ) from exc

        # A correctly signed token without `exp` would never expire, and one without
        # `sub` identifies nobody: reject both instead of failing with KeyError.
        if "exp" not in claims or not claims.get("sub"):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            )
        return claims["sub"]
# Shared module-level handler instance; other modules of the API package import it.
auth_handler = AuthHandler()

View File

@@ -5,6 +5,9 @@ LightRAG FastAPI Server
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Depends, Depends,
HTTPException,
Request,
status
) )
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
import asyncio import asyncio
@@ -25,6 +28,7 @@ from .utils_api import (
parse_args, parse_args,
get_default_host, get_default_host,
display_splash_screen, display_splash_screen,
get_auth_dependency,
) )
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.types import GPTKeywordExtractionFormat from lightrag.types import GPTKeywordExtractionFormat
@@ -46,6 +50,8 @@ from lightrag.kg.shared_storage import (
initialize_pipeline_status, initialize_pipeline_status,
get_all_update_flags_status, get_all_update_flags_status,
) )
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
# Load environment variables # Load environment variables
load_dotenv(override=True) load_dotenv(override=True)
@@ -373,7 +379,29 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k) ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api") app.include_router(ollama_api.router, prefix="/api")
@app.get("/health", dependencies=[Depends(optional_api_key)]) @app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange the configured username/password for a JWT access token.

    The expected credentials are read from the ``AUTH_USERNAME`` and
    ``AUTH_PASSWORD`` environment variables on every call.

    Returns:
        A dict with ``access_token`` (the signed JWT) and ``token_type`` ("bearer").

    Raises:
        HTTPException: 501 when the two variables are not both set;
            401 when the submitted credentials do not match.
    """
    import secrets  # local import: only this handler needs it

    username = os.getenv("AUTH_USERNAME")
    password = os.getenv("AUTH_PASSWORD")

    if not (username and password):
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail="Authentication not configured",
        )

    # Compare with compare_digest instead of `!=` to reduce timing side channels.
    # Encode first because compare_digest rejects non-ASCII str, and evaluate both
    # fields before combining so a wrong username does not skip the password check.
    username_ok = secrets.compare_digest(
        form_data.username.encode("utf-8"), username.encode("utf-8")
    )
    password_ok = secrets.compare_digest(
        form_data.password.encode("utf-8"), password.encode("utf-8")
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect credentials",
        )

    return {
        "access_token": auth_handler.create_token(username),
        "token_type": "bearer",
    }
@app.get("/health", dependencies=[Depends(optional_api_key), Depends(get_auth_dependency())])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
# Get update flags status for all namespaces # Get update flags status for all namespaces
@@ -414,6 +442,12 @@ def create_app(args):
async def webui_root(): async def webui_root():
return FileResponse(static_dir / "index.html") return FileResponse(static_dir / "index.html")
@app.middleware("http")
async def debug_middleware(request: Request, call_next):
    """Log the path of each incoming request at DEBUG level, then pass it on unchanged."""
    import logging  # local import: the middleware is the only user in this block

    # A logger call replaces the bare print() so the per-request trace is only
    # emitted when debug logging is enabled, instead of always writing to stdout.
    logging.getLogger(__name__).debug("Request path: %s", request.url.path)
    return await call_next(request)
return app return app

View File

@@ -8,3 +8,15 @@ python-multipart
tenacity tenacity
tiktoken tiktoken
uvicorn uvicorn
tqdm
jiter
httpcore
distro
httpx
openai
asyncpg
neo4j
pytz
python-jose[cryptography]
passlib[bcrypt]
PyJWT

View File

@@ -16,10 +16,9 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(prefix="/documents", tags=["documents"], dependencies=[Depends(get_auth_dependency())])
router = APIRouter(prefix="/documents", tags=["documents"])
# Temporary file prefix # Temporary file prefix
temp_prefix = "__tmp__" temp_prefix = "__tmp__"

View File

@@ -6,9 +6,9 @@ from typing import Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(tags=["graph"]) router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
def create_graph_routes(rag, api_key: Optional[str] = None): def create_graph_routes(rag, api_key: Optional[str] = None):

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import logging import logging
@@ -11,7 +11,7 @@ import asyncio
from ascii_colors import trace_exception from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken from lightrag.utils import encode_string_by_tiktoken
from ..utils_api import ollama_server_infos from ..utils_api import ollama_server_infos, get_auth_dependency
# query mode according to query prefix (bypass is not LightRAG query mode) # query mode according to query prefix (bypass is not LightRAG query mode)
@@ -126,7 +126,7 @@ class OllamaAPI:
self.rag = rag self.rag = rag
self.ollama_server_infos = ollama_server_infos self.ollama_server_infos = ollama_server_infos
self.top_k = top_k self.top_k = top_k
self.router = APIRouter(tags=["ollama"]) self.router = APIRouter(tags=["ollama"], dependencies=[Depends(get_auth_dependency())])
self.setup_routes() self.setup_routes()
def setup_routes(self): def setup_routes(self):

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency from ..utils_api import get_api_key_dependency, get_auth_dependency
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception from ascii_colors import trace_exception
router = APIRouter(tags=["query"]) router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
class QueryRequest(BaseModel): class QueryRequest(BaseModel):

View File

@@ -9,10 +9,16 @@ import sys
import logging import logging
from ascii_colors import ASCIIColors from ascii_colors import ASCIIColors
from lightrag.api import __api_version__ from lightrag.api import __api_version__
from fastapi import HTTPException, Security from fastapi import (
HTTPException,
Security,
Depends,
Request
)
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
# Load environment variables # Load environment variables
load_dotenv(override=True) load_dotenv(override=True)
@@ -31,6 +37,24 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos() ollama_server_infos = OllamaServerInfos()
def get_auth_dependency():
    """Build a FastAPI dependency that enforces JWT bearer authentication.

    The whitelist is read from ``WHITELIST_PATHS`` once, when this factory is
    called. It is a comma-separated list; surrounding whitespace and empty
    entries are ignored. The default ``/login,/health`` matches the default of
    the ``--whitelist-paths`` command line option.

    The returned dependency lets a request through when:

    * its path is whitelisted, or
    * ``AUTH_USERNAME`` and ``AUTH_PASSWORD`` are not both set (auth disabled), or
    * it carries a bearer token accepted by ``auth_handler.validate_token``.

    Otherwise an ``HTTPException`` with status 401 is raised.
    """
    raw_whitelist = os.getenv("WHITELIST_PATHS", "/login,/health")
    # Exact-match set; stripping makes "/login, /health" behave as intended.
    whitelist = frozenset(
        entry.strip() for entry in raw_whitelist.split(",") if entry.strip()
    )
    # auto_error=False: a missing Authorization header yields token=None here
    # instead of an automatic 401, so whitelisted paths and disabled auth still pass.
    bearer_scheme = OAuth2PasswordBearer(tokenUrl="login", auto_error=False)

    async def dependency(
        request: Request,
        token: Optional[str] = Depends(bearer_scheme),
    ):
        if request.url.path in whitelist:
            return

        # Checked per request so auth follows the current environment.
        if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
            return

        if not token:
            # Reject explicitly instead of relying on the JWT library's type check.
            raise HTTPException(
                status_code=401,
                detail="Not authenticated",
                headers={"WWW-Authenticate": "Bearer"},
            )

        auth_handler.validate_token(token)

    return dependency
def get_api_key_dependency(api_key: Optional[str]): def get_api_key_dependency(api_key: Optional[str]):
""" """
Create an API key dependency for route protection. Create an API key dependency for route protection.
@@ -288,6 +312,38 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
help="Embedding binding type (default: from env or ollama)", help="Embedding binding type (default: from env or ollama)",
) )
# Authentication configuration
parser.add_argument(
"--auth-username",
type=str,
default=get_env_value("AUTH_USERNAME", ""),
help="Login username (default: from env or empty)"
)
parser.add_argument(
"--auth-password",
type=str,
default=get_env_value("AUTH_PASSWORD", ""),
help="Login password (default: from env or empty)"
)
parser.add_argument(
"--token-secret",
type=str,
default=get_env_value("TOKEN_SECRET", ""),
help="JWT signing secret (default: from env or empty)"
)
parser.add_argument(
"--token-expire-hours",
type=int,
default=get_env_value("TOKEN_EXPIRE_HOURS", 4, int),
help="Token validity in hours (default: from env or 4)"
)
parser.add_argument(
"--whitelist-paths",
type=str,
default=get_env_value("WHITELIST_PATHS", "/login,/health"),
help="Comma-separated auth-exempt paths (default: from env or /login,/health)"
)
args = parser.parse_args() args = parser.parse_args()
# If in uvicorn mode and workers > 1, force it to 1 and log warning # If in uvicorn mode and workers > 1, force it to 1 and log warning