Merge pull request #1015 from HKUDS/main

Update dev
This commit is contained in:
zrguo
2025-03-07 12:49:30 +08:00
committed by GitHub
25 changed files with 382 additions and 148 deletions

View File

@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
### Redis
REDIS_URI=redis://localhost:6379
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT signing secret; replace with a long random value
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/login,/health # white list

View File

@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
pip install lightrag-hku
```
## Authentication Endpoints
### JWT Authentication Mechanism
LightRAG API Server implements JWT-based authentication using the HS256 algorithm. To enable secure access control, set the following environment variables:
```bash
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
```
## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:

41
lightrag/api/auth.py Normal file
View File

@@ -0,0 +1,41 @@
import logging
import os
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
class TokenPayload(BaseModel):
    """Claims placed in the JWT issued by ``AuthHandler.create_token``."""

    # Subject claim: the username the token was issued for.
    sub: str
    # Expiry claim: the instant after which the token is no longer accepted.
    exp: datetime
class AuthHandler:
    """Issue and validate HS256-signed JWT access tokens.

    Configuration is read from the environment when the handler is built:

    * ``TOKEN_SECRET`` - signing key; a built-in fallback is used (with a
      warning) when it is unset or empty.
    * ``TOKEN_EXPIRE_HOURS`` - token lifetime in hours, default ``4``.

    NOTE(review): the module-level ``auth_handler`` instance is created at
    import time, so values coming from a ``.env`` file are only seen if that
    file was loaded before this module is imported - confirm the import order.
    """

    # Kept only for backward compatibility with deployments that never set
    # TOKEN_SECRET. It is public knowledge, so tokens signed with it can be forged.
    _DEFAULT_SECRET = "4f85ds4f56dsf46"

    def __init__(self):
        self.secret = os.getenv("TOKEN_SECRET") or self._DEFAULT_SECRET
        if self.secret == self._DEFAULT_SECRET:
            logging.getLogger(__name__).warning(
                "TOKEN_SECRET is not set; falling back to the built-in default "
                "signing key. Set TOKEN_SECRET to a long random value."
            )
        self.algorithm = "HS256"
        self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", "4"))

    def create_token(self, username: str) -> str:
        """Return a signed JWT for ``username``.

        Args:
            username: Value stored in the ``sub`` claim.

        Returns:
            The encoded token, expiring ``self.expire_hours`` hours from now.
        """
        # Timezone-aware UTC avoids the deprecated, naive datetime.utcnow().
        expire = datetime.now(timezone.utc) + timedelta(hours=self.expire_hours)
        payload = TokenPayload(sub=username, exp=expire)
        # Build the claims dict explicitly instead of calling the pydantic
        # serializer, whose name differs between pydantic major versions.
        claims = {"sub": payload.sub, "exp": payload.exp}
        return jwt.encode(claims, self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> str:
        """Verify ``token`` and return the username stored in its ``sub`` claim.

        Args:
            token: Encoded JWT; ``None`` or an empty string is rejected.

        Returns:
            The ``sub`` claim of a correctly signed, unexpired token.

        Raises:
            HTTPException: 401 with detail ``"Token expired"`` when the token is
                past its ``exp`` claim, or ``"Invalid token"`` for any other
                failure (missing token, bad signature, malformed token, missing
                ``exp`` or ``sub`` claim).
        """
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            )

        try:
            # jwt.decode verifies both the signature and the exp claim, so no
            # manual expiry comparison is needed afterwards.
            payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as err:
            # Must be caught before PyJWTError, of which it is a subclass.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
            ) from err
        except jwt.PyJWTError as err:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            ) from err

        # A validly signed token lacking exp/sub is rejected with 401 rather
        # than surfacing as a KeyError (HTTP 500).
        subject = payload.get("sub")
        if "exp" not in payload or not isinstance(subject, str) or not subject:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            )
        return subject
auth_handler = AuthHandler()

View File

@@ -2,10 +2,7 @@
LightRAG FastAPI Server
"""
from fastapi import (
FastAPI,
Depends,
)
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
initialize_pipeline_status,
get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
# Load environment variables
# Updated to use the .env that is inside the current folder
@@ -372,6 +371,27 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange the configured username/password for a bearer token.

    The expected credentials come from the ``AUTH_USERNAME`` and
    ``AUTH_PASSWORD`` environment variables.

    Returns:
        A dict with ``access_token`` (JWT from ``auth_handler``) and
        ``token_type`` set to ``"bearer"``.

    Raises:
        HTTPException: 501 when either environment variable is unset or empty;
            401 when the submitted credentials do not match.
    """
    # Local import: the enclosing module's import block is not touched here.
    import hmac

    username = os.getenv("AUTH_USERNAME")
    password = os.getenv("AUTH_PASSWORD")
    if not (username and password):
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail="Authentication not configured",
        )

    # Constant-time comparison so response timing does not leak how much of a
    # guess was correct. Both checks always run (no short-circuit), and values
    # are encoded to bytes because compare_digest rejects non-ASCII str input.
    username_ok = hmac.compare_digest(
        form_data.username.encode("utf-8"), username.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        form_data.password.encode("utf-8"), password.encode("utf-8")
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
        )

    return {
        "access_token": auth_handler.create_token(username),
        "token_type": "bearer",
    }
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""

View File

@@ -1,10 +1,20 @@
aiofiles
ascii_colors
asyncpg
distro
fastapi
httpcore
httpx
jiter
numpy
openai
passlib[bcrypt]
pipmaster
PyJWT
python-dotenv
python-jose[cryptography]
python-multipart
pytz
tenacity
tiktoken
uvicorn

View File

@@ -16,10 +16,13 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(prefix="/documents", tags=["documents"])
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(get_auth_dependency())],
)
# Temporary file prefix
temp_prefix = "__tmp__"

View File

@@ -3,12 +3,11 @@ This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional
from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(tags=["graph"])
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
def create_graph_routes(rag, api_key: Optional[str] = None):
@@ -25,23 +24,33 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str, max_depth: int = 3):
async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
):
"""
Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
inclusive (bool, optional): If True, match nodes whose label contains `label`; otherwise require an exact match. Defaults to False.
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
return await rag.get_knowledge_graph(
node_label=label,
max_depth=max_depth,
inclusive=inclusive,
min_degree=min_degree,
)
return router

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency
from ..utils_api import get_api_key_dependency, get_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
class QueryRequest(BaseModel):

View File

@@ -9,10 +9,11 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
from fastapi import HTTPException, Security, Depends, Request
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)
@@ -31,6 +32,24 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos()
def get_auth_dependency():
    """Build a FastAPI dependency that enforces JWT authentication.

    The returned coroutine lets a request through without a token when either:

    * the request path is whitelisted - the documented defaults plus the
      comma-separated entries of ``WHITELIST_PATHS``; or
    * ``AUTH_USERNAME`` / ``AUTH_PASSWORD`` are not both set, i.e.
      authentication is not configured.

    Otherwise the bearer token is passed to ``auth_handler.validate_token``,
    and whatever that call raises propagates to the client.

    Returns:
        An async callable suitable for ``Depends(...)``.
    """
    # Paths documented as whitelisted by default.
    default_whitelist = ("/login", "/health", "/docs", "/redoc", "/openapi.json")
    # Strip whitespace and drop empty entries so values such as
    # "/a, /b" or an unset variable do not produce unmatched paths.
    configured = (
        entry.strip() for entry in os.getenv("WHITELIST_PATHS", "").split(",")
    )
    whitelist = frozenset(default_whitelist) | {path for path in configured if path}

    async def dependency(
        request: Request,
        # auto_error=False: a missing Authorization header yields None instead
        # of an immediate error, so the bypass rules below get to run first.
        token: Optional[str] = Depends(
            OAuth2PasswordBearer(tokenUrl="login", auto_error=False)
        ),
    ):
        if request.url.path in whitelist:
            return

        if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
            return

        auth_handler.validate_token(token)

    return dependency
def get_api_key_dependency(api_key: Optional[str]):
"""
Create an API key dependency for route protection.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -5,8 +5,8 @@
<link rel="icon" type="image/svg+xml" href="./logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title>
<script type="module" crossorigin src="./assets/index-DbuMPJAD.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-rP-YlyR1.css">
<script type="module" crossorigin src="./assets/index-CJz72b6Q.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-CH-3l4_Z.css">
</head>
<body>
<div id="root"></div>

View File

@@ -204,7 +204,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
self, node_label: str, max_depth: int = 3
) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""

View File

@@ -232,19 +232,26 @@ class NetworkXStorage(BaseGraphStorage):
return sorted(list(labels))
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
min_degree: Minimum degree of nodes to include. Defaults to 0
inclusive: Do an inclusive search if true
Returns:
KnowledgeGraph object containing nodes and edges
@@ -255,6 +262,10 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
@@ -262,10 +273,15 @@ class NetworkXStorage(BaseGraphStorage):
graph.copy()
) # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id (partial match)
# Find nodes with matching node id (exact match unless `inclusive` is set)
nodes_to_explore = []
for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching
node_str = str(n)
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
if not nodes_to_explore:
@@ -277,26 +293,37 @@ class NetworkXStorage(BaseGraphStorage):
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Get start nodes and direct connected nodes
if nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
@@ -356,7 +383,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
type="DIRECTED",
source=str(source),
target=str(target),
properties=edge_data,

View File

@@ -504,11 +504,39 @@ class LightRAG:
return text
async def get_knowledge_graph(
self, node_label: str, max_depth: int
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
) -> KnowledgeGraph:
return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label=node_label, max_depth=max_depth
)
"""Get knowledge graph for a given label
Args:
node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges
"""
# get params supported by get_knowledge_graph of specified storage
import inspect
storage_params = inspect.signature(
self.chunk_entity_relation_graph.get_knowledge_graph
).parameters
kwargs = {"node_label": node_label, "max_depth": max_depth}
if "min_degree" in storage_params and min_degree > 0:
kwargs["min_degree"] = min_degree
if "inclusive" in storage_params:
kwargs["inclusive"] = inclusive
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]

View File

@@ -161,8 +161,12 @@ axiosInstance.interceptors.response.use(
)
// API methods
export const queryGraphs = async (label: string, maxDepth: number): Promise<LightragGraphType> => {
const response = await axiosInstance.get(`/graphs?label=${label}&max_depth=${maxDepth}`)
export const queryGraphs = async (
label: string,
maxDepth: number,
minDegree: number
): Promise<LightragGraphType> => {
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
return response.data
}

View File

@@ -40,18 +40,21 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
const focusedEdge = useGraphStore.use.focusedEdge()
/**
* When component mount
* => load the graph
* When component mount or maxIterations changes
* => load the graph and apply layout
*/
useEffect(() => {
// Create & load the graph
const graph = lightrageGraph()
loadGraph(graph)
if (!(graph as any).__force_applied) {
assignLayout()
Object.assign(graph, { __force_applied: true })
}
}, [assignLayout, loadGraph, lightrageGraph, maxIterations])
/**
* When component mount
* => register events
*/
useEffect(() => {
const { setFocusedNode, setSelectedNode, setFocusedEdge, setSelectedEdge, clearSelection } =
useGraphStore.getState()
@@ -87,7 +90,7 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
},
clickStage: () => clearSelection()
})
}, [assignLayout, loadGraph, registerEvents, lightrageGraph])
}, [registerEvents])
/**
* When component mount or hovered node change

View File

@@ -90,9 +90,12 @@ const LabeledNumberInput = ({
{label}
</label>
<Input
value={currentValue || ''}
type="number"
value={currentValue === null ? '' : currentValue}
onChange={onValueChange}
className="h-6 w-full min-w-0"
className="h-6 w-full min-w-0 pr-1"
min={min}
max={max}
onBlur={onBlur}
onKeyDown={(e) => {
if (e.key === 'Enter') {
@@ -119,6 +122,7 @@ export default function Settings() {
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
const graphMinDegree = useSettingsStore.use.graphMinDegree()
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
@@ -177,6 +181,11 @@ export default function Settings() {
useSettingsStore.setState({ graphQueryMaxDepth: depth })
}, [])
const setGraphMinDegree = useCallback((degree: number) => {
if (degree < 0) return
useSettingsStore.setState({ graphMinDegree: degree })
}, [])
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
if (iterations < 1) return
useSettingsStore.setState({ graphLayoutMaxIterations: iterations })
@@ -266,6 +275,12 @@ export default function Settings() {
value={graphQueryMaxDepth}
onEditFinished={setGraphQueryMaxDepth}
/>
<LabeledNumberInput
label="Minimum Degree"
min={0}
value={graphMinDegree}
onEditFinished={setGraphMinDegree}
/>
<LabeledNumberInput
label="Max Layout Iterations"
min={1}

View File

@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
<input
type={type}
className={cn(
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm',
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
className
)}
ref={ref}

View File

@@ -50,11 +50,11 @@ export type NodeType = {
}
export type EdgeType = { label: string }
const fetchGraph = async (label: string, maxDepth: number) => {
const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
let rawData: any = null
try {
rawData = await queryGraphs(label, maxDepth)
rawData = await queryGraphs(label, maxDepth, minDegree)
} catch (e) {
useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!')
return null
@@ -161,13 +161,14 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
return graph
}
const lastQueryLabel = { label: '', maxQueryDepth: 0 }
const lastQueryLabel = { label: '', maxQueryDepth: 0, minDegree: 0 }
const useLightrangeGraph = () => {
const queryLabel = useSettingsStore.use.queryLabel()
const rawGraph = useGraphStore.use.rawGraph()
const sigmaGraph = useGraphStore.use.sigmaGraph()
const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
const minDegree = useSettingsStore.use.graphMinDegree()
const getNode = useCallback(
(nodeId: string) => {
@@ -185,13 +186,16 @@ const useLightrangeGraph = () => {
useEffect(() => {
if (queryLabel) {
if (lastQueryLabel.label !== queryLabel || lastQueryLabel.maxQueryDepth !== maxQueryDepth) {
if (lastQueryLabel.label !== queryLabel ||
lastQueryLabel.maxQueryDepth !== maxQueryDepth ||
lastQueryLabel.minDegree !== minDegree) {
lastQueryLabel.label = queryLabel
lastQueryLabel.maxQueryDepth = maxQueryDepth
lastQueryLabel.minDegree = minDegree
const state = useGraphStore.getState()
state.reset()
fetchGraph(queryLabel, maxQueryDepth).then((data) => {
fetchGraph(queryLabel, maxQueryDepth, minDegree).then((data) => {
// console.debug('Query label: ' + queryLabel)
state.setSigmaGraph(createSigmaGraph(data))
data?.buildDynamicMap()
@@ -203,7 +207,7 @@ const useLightrangeGraph = () => {
state.reset()
state.setSigmaGraph(new DirectedGraph())
}
}, [queryLabel, maxQueryDepth])
}, [queryLabel, maxQueryDepth, minDegree])
const lightrageGraph = useCallback(() => {
if (sigmaGraph) {

View File

@@ -22,6 +22,9 @@ interface SettingsState {
graphQueryMaxDepth: number
setGraphQueryMaxDepth: (depth: number) => void
graphMinDegree: number
setGraphMinDegree: (degree: number) => void
graphLayoutMaxIterations: number
setGraphLayoutMaxIterations: (iterations: number) => void
@@ -66,6 +69,7 @@ const useSettingsStoreBase = create<SettingsState>()(
enableEdgeEvents: false,
graphQueryMaxDepth: 3,
graphMinDegree: 0,
graphLayoutMaxIterations: 10,
queryLabel: defaultQueryLabel,
@@ -107,6 +111,8 @@ const useSettingsStoreBase = create<SettingsState>()(
setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
setApiKey: (apiKey: string | null) => set({ apiKey }),

View File

@@ -1 +1,11 @@
/// <reference types="vite/client" />
interface ImportMetaEnv {
readonly VITE_API_PROXY: string
readonly VITE_API_ENDPOINTS: string
readonly VITE_BACKEND_URL: string
}
interface ImportMeta {
readonly env: ImportMetaEnv
}

View File

@@ -26,5 +26,5 @@
"@/*": ["./src/*"]
}
},
"include": ["src"]
"include": ["src", "vite.config.ts"]
}

View File

@@ -14,6 +14,21 @@ export default defineConfig({
},
base: './',
build: {
outDir: path.resolve(__dirname, '../lightrag/api/webui')
outDir: path.resolve(__dirname, '../lightrag/api/webui'),
emptyOutDir: true
},
server: {
proxy: import.meta.env.VITE_API_PROXY === 'true' && import.meta.env.VITE_API_ENDPOINTS ?
Object.fromEntries(
import.meta.env.VITE_API_ENDPOINTS.split(',').map(endpoint => [
endpoint,
{
target: import.meta.env.VITE_BACKEND_URL || 'http://localhost:9621',
changeOrigin: true,
rewrite: endpoint === '/api' ?
(path) => path.replace(/^\/api/, '') : undefined
}
])
) : {}
}
})