Merge branch 'main' into main

This commit is contained in:
zrguo
2025-03-09 00:23:06 +08:00
committed by GitHub
36 changed files with 884 additions and 191 deletions

View File

@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
pip install lightrag-hku
```
## Authentication Endpoints
### JWT Authentication Mechanism
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
```bash
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
```
## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:

41
lightrag/api/auth.py Normal file
View File

@@ -0,0 +1,41 @@
import os
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
class TokenPayload(BaseModel):
sub: str
exp: datetime
class AuthHandler:
def __init__(self):
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
self.algorithm = "HS256"
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
def create_token(self, username: str) -> str:
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
payload = TokenPayload(sub=username, exp=expire)
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
def validate_token(self, token: str) -> str:
try:
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
expire_timestamp = payload["exp"]
expire_time = datetime.utcfromtimestamp(expire_timestamp)
if datetime.utcnow() > expire_time:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
)
return payload["sub"]
except jwt.PyJWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
auth_handler = AuthHandler()

View File

@@ -2,10 +2,7 @@
LightRAG FastAPI Server
"""
from fastapi import (
FastAPI,
Depends,
)
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
initialize_pipeline_status,
get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
# Load environment variables
# Updated to use the .env that is inside the current folder
@@ -372,6 +371,27 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authentication not configured",
)
if form_data.username != username or form_data.password != password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
return {
"access_token": auth_handler.create_token(username),
"token_type": "bearer",
}
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""

View File

@@ -1,10 +1,20 @@
aiofiles
ascii_colors
asyncpg
distro
fastapi
httpcore
httpx
jiter
numpy
openai
passlib[bcrypt]
pipmaster
PyJWT
python-dotenv
python-jose[cryptography]
python-multipart
pytz
tenacity
tiktoken
uvicorn

View File

@@ -18,8 +18,11 @@ from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.api.utils_api import get_api_key_dependency, global_args
router = APIRouter(prefix="/documents", tags=["documents"])
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(get_auth_dependency())],
)
# Temporary file prefix
temp_prefix = "__tmp__"

View File

@@ -3,12 +3,11 @@ This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional
from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(tags=["graph"])
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
def create_graph_routes(rag, api_key: Optional[str] = None):
@@ -25,23 +24,33 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str, max_depth: int = 3):
async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
):
"""
Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
return await rag.get_knowledge_graph(
node_label=label,
max_depth=max_depth,
inclusive=inclusive,
min_degree=min_degree,
)
return router

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
class QueryRequest(BaseModel):

View File

@@ -9,10 +9,11 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
from fastapi import HTTPException, Security, Depends, Request
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)
@@ -33,6 +34,24 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos()
def get_auth_dependency():
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
async def dependency(
request: Request,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
):
if request.url.path in whitelist:
return
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
return
auth_handler.validate_token(token)
return dependency
def get_api_key_dependency(api_key: Optional[str]):
"""
Create an API key dependency for route protection.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -5,8 +5,8 @@
<link rel="icon" type="image/svg+xml" href="./logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title>
<script type="module" crossorigin src="./assets/index-DbuMPJAD.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-rP-YlyR1.css">
<script type="module" crossorigin src="./assets/index-CJz72b6Q.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-CH-3l4_Z.css">
</head>
<body>
<div id="root"></div>