From 63aa4f9dfca1269f0655b6636ad77f855848a0e9 Mon Sep 17 00:00:00 2001 From: Milin Date: Wed, 5 Mar 2025 11:09:31 +0800 Subject: [PATCH] feat(api): Add user authentication functionality - Implement JWT-based user authentication logic - Add login endpoint and token validation middleware - Update API routes with authentication dependencies - Add authentication-related environment variables - Optimize requirements.txt with necessary dependencies --- env.example | 7 +++ lightrag/api/README.md | 59 +++++++++++++++--------- lightrag/api/auth.py | 43 ++++++++++++++++++ lightrag/api/lightrag_server.py | 36 ++++++++++++++- lightrag/api/requirements.txt | 12 +++++ lightrag/api/routers/document_routes.py | 5 +-- lightrag/api/routers/graph_routes.py | 4 +- lightrag/api/routers/ollama_api.py | 6 +-- lightrag/api/routers/query_routes.py | 4 +- lightrag/api/utils_api.py | 60 ++++++++++++++++++++++++- 10 files changed, 203 insertions(+), 33 deletions(-) create mode 100644 lightrag/api/auth.py diff --git a/env.example b/env.example index 99909ac6..cd92abc8 100644 --- a/env.example +++ b/env.example @@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333 ### Redis REDIS_URI=redis://localhost:6379 + +# For jwt auth +AUTH_USERNAME=admin # login name +AUTH_PASSWORD=admin123 # password +TOKEN_SECRET=your-key # JWT key +TOKEN_EXPIRE_HOURS=4 # expire duration +WHITELIST_PATHS=/login,/health # white list diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 8f61f2f6..e00fac48 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -295,26 +295,32 @@ You can not change storage implementation selection after you add documents to L ### LightRag API Server Comand Line Options -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | Server host | -| --port | 9621 | Server port | -| --working-dir | ./rag_storage | Working directory for RAG storage | -| --input-dir | ./inputs | Directory containing input documents | -| --max-async | 4 | Maximum async operations | -| --max-tokens | 32768 | Maximum token size | -| --timeout | 150 | Timeout in seconds. None for infinite timeout(not recommended) | -| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) | -| --verbose | - | Verbose debug output (True, Flase) | -| --key | None | API key for authentication. Protects lightrag server against unauthorized access | -| --ssl | False | Enable HTTPS | -| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) | -| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | -| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. | -| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. | -| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) | -| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) | -| auto-scan-at-startup | - | Scan input directory for new files and start indexing | +| Parameter | Default | Description | +|-------------------------|----------------|-----------------------------------------------------------------------------------------------------------------------------| +| --host | 0.0.0.0 | Server host | +| --port | 9621 | Server port | +| --working-dir | ./rag_storage | Working directory for RAG storage | +| --input-dir | ./inputs | Directory containing input documents | +| --max-async | 4 | Maximum async operations | +| --max-tokens | 32768 | Maximum token size | +| --timeout | 150 | Timeout in seconds. None for infinite timeout(not recommended) | +| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) | +| --verbose | - | Verbose debug output (True, Flase) | +| --key | None | API key for authentication. Protects lightrag server against unauthorized access | +| --ssl | False | Enable HTTPS | +| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) | +| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | +| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. | +| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. | +| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) | +| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) | +| --auto-scan-at-startup | - | Scan input directory for new files and start indexing | +| --auth-username | - | Enable jwt if not empty | +| --auth-password | - | Enable jwt if not empty | +| --token-secret | - | JWT key | +| --token-expire-hours | 4 | expire duration | +| --whitelist-paths | /login,/health | white list | + ### Example Usage @@ -387,6 +393,19 @@ Note: If you don't need the API functionality, you can install the base package pip install lightrag-hku ``` +## Authentication Endpoints + +### JWT Authentication Mechanism +LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required: +```bash +# For jwt auth +AUTH_USERNAME=admin # login name --auth-username +AUTH_PASSWORD=admin123 # password --auth-password +TOKEN_SECRET=your-key # JWT key --token-secret +TOKEN_EXPIRE_HOURS=4 # expire duration --token-expire-hours +WHITELIST_PATHS=/login,/health # white list --whitelist-paths +``` + ## API Endpoints All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit: diff --git a/lightrag/api/auth.py b/lightrag/api/auth.py new file mode 100644 index 00000000..04a7ee46 --- /dev/null +++ b/lightrag/api/auth.py @@ -0,0 +1,43 @@ +import os +from datetime import datetime, timedelta +import jwt +from fastapi import HTTPException, status +from pydantic import BaseModel + + +class TokenPayload(BaseModel): + sub: str + exp: datetime + + +class AuthHandler: + def __init__(self): + self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46") + self.algorithm = "HS256" + self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4)) + + def create_token(self, username: str) -> str: + expire = datetime.utcnow() + timedelta(hours=self.expire_hours) + payload = TokenPayload(sub=username, exp=expire) + return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm) + + def validate_token(self, token: str) -> str: + try: + payload = jwt.decode(token, self.secret, algorithms=[self.algorithm]) + expire_timestamp = payload["exp"] + expire_time = datetime.utcfromtimestamp(expire_timestamp) + + if datetime.utcnow() > expire_time: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token expired" + ) + return payload["sub"] + except jwt.PyJWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token" + ) + + +auth_handler = AuthHandler() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f2c437f..33f6bc61 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -5,6 +5,9 @@ LightRAG FastAPI Server from fastapi import ( FastAPI, Depends, + HTTPException, + Request, + status ) from fastapi.responses import FileResponse import asyncio @@ -25,6 +28,7 @@ from .utils_api import ( parse_args, get_default_host, display_splash_screen, + get_auth_dependency, ) from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat @@ -46,6 +50,8 @@ from lightrag.kg.shared_storage import ( initialize_pipeline_status, get_all_update_flags_status, ) +from fastapi.security import OAuth2PasswordRequestForm +from .auth import auth_handler # Load environment variables load_dotenv(override=True) @@ -373,7 +379,29 @@ def create_app(args): ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") - @app.get("/health", dependencies=[Depends(optional_api_key)]) + @app.post("/login") + async def login(form_data: OAuth2PasswordRequestForm = Depends()): + username = os.getenv("AUTH_USERNAME") + password = os.getenv("AUTH_PASSWORD") + + if not (username and password): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Authentication not configured" + ) + + if form_data.username != username or form_data.password != password: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect credentials" + ) + + return { + "access_token": auth_handler.create_token(username), + "token_type": "bearer" + } + + @app.get("/health", dependencies=[Depends(optional_api_key), Depends(get_auth_dependency())]) async def get_status(): """Get current system status""" # Get update flags status for all namespaces @@ -414,6 +442,12 @@ def create_app(args): async def webui_root(): return FileResponse(static_dir / "index.html") + @app.middleware("http") + async def debug_middleware(request: Request, call_next): + print(f"Request path: {request.url.path}") + response = await call_next(request) + return response + return app diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index 38ad569b..d44d4eb9 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -8,3 +8,15 @@ python-multipart tenacity tiktoken uvicorn +tqdm +jiter +httpcore +distro +httpx +openai +asyncpg +neo4j +pytz +python-jose[cryptography] +passlib[bcrypt] +PyJWT \ No newline at end of file diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ab5aff96..4cebd41a 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -16,10 +16,9 @@ from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from ..utils_api import get_api_key_dependency +from ..utils_api import get_api_key_dependency, get_auth_dependency - -router = APIRouter(prefix="/documents", tags=["documents"]) +router = APIRouter(prefix="/documents", tags=["documents"], dependencies=[Depends(get_auth_dependency())]) # Temporary file prefix temp_prefix = "__tmp__" diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index e6f894a2..c7a5411b 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -6,9 +6,9 @@ from typing import Optional from fastapi import APIRouter, Depends -from ..utils_api import get_api_key_dependency +from ..utils_api import get_api_key_dependency, get_auth_dependency -router = APIRouter(tags=["graph"]) +router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())]) def create_graph_routes(rag, api_key: Optional[str] = None): diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 9688d073..05222908 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, HTTPException, Request, Depends from pydantic import BaseModel from typing import List, Dict, Any, Optional import logging @@ -11,7 +11,7 @@ import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import encode_string_by_tiktoken -from ..utils_api import ollama_server_infos +from ..utils_api import ollama_server_infos, get_auth_dependency # query mode according to query prefix (bypass is not LightRAG quer mode) @@ -126,7 +126,7 @@ class OllamaAPI: self.rag = rag self.ollama_server_infos = ollama_server_infos self.top_k = top_k - self.router = APIRouter(tags=["ollama"]) + self.router = APIRouter(tags=["ollama"], dependencies=[Depends(get_auth_dependency())]) self.setup_routes() def setup_routes(self): diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 10bfe7a8..7a5bd8c3 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, HTTPException from lightrag.base import QueryParam -from ..utils_api import get_api_key_dependency +from ..utils_api import get_api_key_dependency, get_auth_dependency from pydantic import BaseModel, Field, field_validator from ascii_colors import trace_exception -router = APIRouter(tags=["query"]) +router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())]) class QueryRequest(BaseModel): diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index ed1250d4..d7622ac0 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -9,10 +9,16 @@ import sys import logging from ascii_colors import ASCIIColors from lightrag.api import __api_version__ -from fastapi import HTTPException, Security +from fastapi import ( + HTTPException, + Security, + Depends, + Request +) from dotenv import load_dotenv -from fastapi.security import APIKeyHeader +from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from starlette.status import HTTP_403_FORBIDDEN +from .auth import auth_handler # Load environment variables load_dotenv(override=True) @@ -31,6 +37,24 @@ class OllamaServerInfos: ollama_server_infos = OllamaServerInfos() +def get_auth_dependency(): + whitelist = os.getenv("WHITELIST_PATHS", "").split(",") + + async def dependency( + request: Request, + token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)) + ): + if request.url.path in whitelist: + return + + if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")): + return + + auth_handler.validate_token(token) + + return dependency + + def get_api_key_dependency(api_key: Optional[str]): """ Create an API key dependency for route protection. @@ -288,6 +312,38 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: help="Embedding binding type (default: from env or ollama)", ) + # Authentication configuration + parser.add_argument( + "--auth-username", + type=str, + default=get_env_value("AUTH_USERNAME", ""), + help="Login username (default: from env or empty)" + ) + parser.add_argument( + "--auth-password", + type=str, + default=get_env_value("AUTH_PASSWORD", ""), + help="Login password (default: from env or empty)" + ) + parser.add_argument( + "--token-secret", + type=str, + default=get_env_value("TOKEN_SECRET", ""), + help="JWT signing secret (default: from env or empty)" + ) + parser.add_argument( + "--token-expire-hours", + type=int, + default=get_env_value("TOKEN_EXPIRE_HOURS", 4, int), + help="Token validity in hours (default: from env or 4)" + ) + parser.add_argument( + "--whitelist-paths", + type=str, + default=get_env_value("WHITELIST_PATHS", "/login,/health"), + help="Comma-separated auth-exempt paths (default: from env or /login,/health)" + ) + args = parser.parse_args() # If in uvicorn mode and workers > 1, force it to 1 and log warning