diff --git a/env.example b/env.example index 99909ac6..cd92abc8 100644 --- a/env.example +++ b/env.example @@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333 ### Redis REDIS_URI=redis://localhost:6379 + +# For jwt auth +AUTH_USERNAME=admin # login name +AUTH_PASSWORD=admin123 # password +TOKEN_SECRET=your-key # JWT key +TOKEN_EXPIRE_HOURS=4 # expire duration +WHITELIST_PATHS=/login,/health # white list diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 8f61f2f6..f9a2a197 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package pip install lightrag-hku ``` +## Authentication Endpoints + +### JWT Authentication Mechanism +LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required: +```bash +# For jwt auth +AUTH_USERNAME=admin # login name +AUTH_PASSWORD=admin123 # password +TOKEN_SECRET=your-key # JWT key +TOKEN_EXPIRE_HOURS=4 # expire duration +WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default. +``` + ## API Endpoints All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit: diff --git a/lightrag/api/auth.py b/lightrag/api/auth.py new file mode 100644 index 00000000..4d905de8 --- /dev/null +++ b/lightrag/api/auth.py @@ -0,0 +1,41 @@ +import os +from datetime import datetime, timedelta +import jwt +from fastapi import HTTPException, status +from pydantic import BaseModel + + +class TokenPayload(BaseModel): + sub: str + exp: datetime + + +class AuthHandler: + def __init__(self): + self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46") + self.algorithm = "HS256" + self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4)) + + def create_token(self, username: str) -> str: + expire = datetime.utcnow() + timedelta(hours=self.expire_hours) + payload = TokenPayload(sub=username, exp=expire) + return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm) + + def validate_token(self, token: str) -> str: + try: + payload = jwt.decode(token, self.secret, algorithms=[self.algorithm]) + expire_timestamp = payload["exp"] + expire_time = datetime.utcfromtimestamp(expire_timestamp) + + if datetime.utcnow() > expire_time: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired" + ) + return payload["sub"] + except jwt.PyJWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" + ) + + +auth_handler = AuthHandler() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8ad232f0..2891b542 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -2,10 +2,7 @@ LightRAG FastAPI Server """ -from fastapi import ( - FastAPI, - Depends, -) +from fastapi import FastAPI, Depends, HTTPException, status import asyncio import os import logging @@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import ( initialize_pipeline_status, get_all_update_flags_status, ) +from fastapi.security import OAuth2PasswordRequestForm +from .auth import auth_handler # Load environment variables # Updated to use the .env that is inside the current folder @@ -372,6 +371,27 @@ def create_app(args): ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") + @app.post("/login") + async def login(form_data: OAuth2PasswordRequestForm = Depends()): + username = os.getenv("AUTH_USERNAME") + password = os.getenv("AUTH_PASSWORD") + + if not (username and password): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Authentication not configured", + ) + + if form_data.username != username or form_data.password != password: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials" + ) + + return { + "access_token": auth_handler.create_token(username), + "token_type": "bearer", + } + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index 38ad569b..0e8e246b 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -1,10 +1,20 @@ aiofiles ascii_colors +asyncpg +distro fastapi +httpcore +httpx +jiter numpy +openai +passlib[bcrypt] pipmaster +PyJWT python-dotenv +python-jose[cryptography] python-multipart +pytz tenacity tiktoken uvicorn diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index d9dfe913..3e51fa4d 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -16,10 +16,13 @@ from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from ..utils_api import get_api_key_dependency +from ..utils_api import get_api_key_dependency, get_auth_dependency - -router = APIRouter(prefix="/documents", tags=["documents"]) +router = APIRouter( + prefix="/documents", + tags=["documents"], + dependencies=[Depends(get_auth_dependency())], +) # Temporary file prefix temp_prefix = "__tmp__" diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index e6f894a2..c7a5411b 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -6,9 +6,9 @@ from typing import Optional from fastapi import APIRouter, Depends -from ..utils_api import get_api_key_dependency +from ..utils_api import get_api_key_dependency, get_auth_dependency -router = APIRouter(tags=["graph"]) +router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())]) def create_graph_routes(rag, api_key: Optional[str] = None): diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 10bfe7a8..7a5bd8c3 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, HTTPException from lightrag.base import QueryParam -from ..utils_api import get_api_key_dependency +from ..utils_api import get_api_key_dependency, get_auth_dependency from pydantic import BaseModel, Field, field_validator from ascii_colors import trace_exception -router = APIRouter(tags=["query"]) +router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())]) class QueryRequest(BaseModel): diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index ed1250d4..dc467449 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -9,10 +9,11 @@ import sys import logging from ascii_colors import ASCIIColors from lightrag.api import __api_version__ -from fastapi import HTTPException, Security +from fastapi import HTTPException, Security, Depends, Request from dotenv import load_dotenv -from fastapi.security import APIKeyHeader +from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from starlette.status import HTTP_403_FORBIDDEN +from .auth import auth_handler # Load environment variables load_dotenv(override=True) @@ -31,6 +32,24 @@ class OllamaServerInfos: ollama_server_infos = OllamaServerInfos() +def get_auth_dependency(): + whitelist = os.getenv("WHITELIST_PATHS", "").split(",") + + async def dependency( + request: Request, + token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)), + ): + if request.url.path in whitelist: + return + + if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")): + return + + auth_handler.validate_token(token) + + return dependency + + def get_api_key_dependency(api_key: Optional[str]): """ Create an API key dependency for route protection.