Merge branch 'main' into main
This commit is contained in:
@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
|
||||
pip install lightrag-hku
|
||||
```
|
||||
|
||||
## Authentication Endpoints
|
||||
|
||||
### JWT Authentication Mechanism
|
||||
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
|
||||
```bash
|
||||
# For jwt auth
|
||||
AUTH_USERNAME=admin # login name
|
||||
AUTH_PASSWORD=admin123 # password
|
||||
TOKEN_SECRET=your-key # JWT key
|
||||
TOKEN_EXPIRE_HOURS=4 # expire duration
|
||||
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
|
||||
|
41
lightrag/api/auth.py
Normal file
41
lightrag/api/auth.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str
|
||||
exp: datetime
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
def __init__(self):
|
||||
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
|
||||
self.algorithm = "HS256"
|
||||
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
|
||||
|
||||
def create_token(self, username: str) -> str:
|
||||
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
|
||||
payload = TokenPayload(sub=username, exp=expire)
|
||||
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
|
||||
|
||||
def validate_token(self, token: str) -> str:
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||
expire_timestamp = payload["exp"]
|
||||
expire_time = datetime.utcfromtimestamp(expire_timestamp)
|
||||
|
||||
if datetime.utcnow() > expire_time:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
)
|
||||
return payload["sub"]
|
||||
except jwt.PyJWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
|
||||
auth_handler = AuthHandler()
|
@@ -2,10 +2,7 @@
|
||||
LightRAG FastAPI Server
|
||||
"""
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Depends,
|
||||
)
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
|
||||
initialize_pipeline_status,
|
||||
get_all_update_flags_status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from .auth import auth_handler
|
||||
|
||||
# Load environment variables
|
||||
# Updated to use the .env that is inside the current folder
|
||||
@@ -372,6 +371,27 @@ def create_app(args):
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||
app.include_router(ollama_api.router, prefix="/api")
|
||||
|
||||
@app.post("/login")
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username = os.getenv("AUTH_USERNAME")
|
||||
password = os.getenv("AUTH_PASSWORD")
|
||||
|
||||
if not (username and password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
if form_data.username != username or form_data.password != password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": auth_handler.create_token(username),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
|
@@ -1,10 +1,20 @@
|
||||
aiofiles
|
||||
ascii_colors
|
||||
asyncpg
|
||||
distro
|
||||
fastapi
|
||||
httpcore
|
||||
httpx
|
||||
jiter
|
||||
numpy
|
||||
openai
|
||||
passlib[bcrypt]
|
||||
pipmaster
|
||||
PyJWT
|
||||
python-dotenv
|
||||
python-jose[cryptography]
|
||||
python-multipart
|
||||
pytz
|
||||
tenacity
|
||||
tiktoken
|
||||
uvicorn
|
||||
|
@@ -18,8 +18,11 @@ from lightrag import LightRAG
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.api.utils_api import get_api_key_dependency, global_args
|
||||
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
dependencies=[Depends(get_auth_dependency())],
|
||||
)
|
||||
|
||||
# Temporary file prefix
|
||||
temp_prefix = "__tmp__"
|
||||
|
@@ -3,12 +3,11 @@ This module contains all graph-related routes for the LightRAG API.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from ..utils_api import get_api_key_dependency
|
||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||
|
||||
router = APIRouter(tags=["graph"])
|
||||
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
|
||||
|
||||
|
||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
@@ -25,23 +24,33 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
return await rag.get_graph_labels()
|
||||
|
||||
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
||||
async def get_knowledge_graph(label: str, max_depth: int = 3):
|
||||
async def get_knowledge_graph(
|
||||
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||
):
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. Label matching nodes take precedence
|
||||
2. Followed by nodes directly connected to the matching nodes
|
||||
3. Finally, the degree of the nodes
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
|
||||
|
||||
Args:
|
||||
label (str): Label to get knowledge graph for
|
||||
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
|
||||
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
|
||||
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Knowledge graph for label
|
||||
"""
|
||||
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
|
||||
return await rag.get_knowledge_graph(
|
||||
node_label=label,
|
||||
max_depth=max_depth,
|
||||
inclusive=inclusive,
|
||||
min_degree=min_degree,
|
||||
)
|
||||
|
||||
return router
|
||||
|
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from lightrag.base import QueryParam
|
||||
from ..utils_api import get_api_key_dependency
|
||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from ascii_colors import trace_exception
|
||||
|
||||
router = APIRouter(tags=["query"])
|
||||
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
|
@@ -9,10 +9,11 @@ import sys
|
||||
import logging
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import HTTPException, Security, Depends, Request
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from .auth import auth_handler
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
@@ -33,6 +34,24 @@ class OllamaServerInfos:
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
|
||||
def get_auth_dependency():
|
||||
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
|
||||
|
||||
async def dependency(
|
||||
request: Request,
|
||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
||||
):
|
||||
if request.url.path in whitelist:
|
||||
return
|
||||
|
||||
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
|
||||
return
|
||||
|
||||
auth_handler.validate_token(token)
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
"""
|
||||
Create an API key dependency for route protection.
|
||||
|
1
lightrag/api/webui/assets/index-CH-3l4_Z.css
Normal file
1
lightrag/api/webui/assets/index-CH-3l4_Z.css
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -5,8 +5,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="./logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Lightrag</title>
|
||||
<script type="module" crossorigin src="./assets/index-DbuMPJAD.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-rP-YlyR1.css">
|
||||
<script type="module" crossorigin src="./assets/index-CJz72b6Q.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-CH-3l4_Z.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
Reference in New Issue
Block a user