Added servers protection using an API key to restrict access to only authenticated entities.
This commit is contained in:
@@ -1025,6 +1025,7 @@ Each server has its own specific configuration options:
|
||||
| --max-embed-tokens | 8192 | Maximum embedding token size |
|
||||
| --input-file | ./book.txt | Initial input file |
|
||||
| --log-level | INFO | Logging level |
|
||||
| --key | none | Access Key to protect the lightrag service |
|
||||
|
||||
#### Ollama Server Options
|
||||
|
||||
@@ -1042,6 +1043,7 @@ Each server has its own specific configuration options:
|
||||
| --max-embed-tokens | 8192 | Maximum embedding token size |
|
||||
| --input-file | ./book.txt | Initial input file |
|
||||
| --log-level | INFO | Logging level |
|
||||
| --key | none | Access Key to protect the lightrag service |
|
||||
|
||||
#### OpenAI Server Options
|
||||
|
||||
@@ -1056,6 +1058,7 @@ Each server has its own specific configuration options:
|
||||
| --max-embed-tokens | 8192 | Maximum embedding token size |
|
||||
| --input-dir | ./inputs | Input directory for documents |
|
||||
| --log-level | INFO | Logging level |
|
||||
| --key | none | Access Key to protect the lightrag service |
|
||||
|
||||
#### OpenAI AZURE Server Options
|
||||
|
||||
@@ -1071,8 +1074,10 @@ Each server has its own specific configuration options:
|
||||
| --input-dir | ./inputs | Input directory for documents |
|
||||
| --enable-cache | True | Enable response cache |
|
||||
| --log-level | INFO | Logging level |
|
||||
| --key | none | Access Key to protect the lightrag service |
|
||||
|
||||
|
||||
For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`.
|
||||
### Example Usage
|
||||
|
||||
#### LoLLMs RAG Server
|
||||
|
@@ -20,6 +20,19 @@ from dotenv import load_dotenv
|
||||
import inspect
|
||||
import json
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from fastapi import HTTPException
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -93,6 +106,9 @@ def parse_args():
|
||||
help="Logging level (default: INFO)",
|
||||
)
|
||||
|
||||
parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None)
|
||||
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -154,6 +170,31 @@ class InsertResponse(BaseModel):
|
||||
message: str
|
||||
document_count: int
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
if not api_key:
|
||||
# If no API key is configured, return a dummy dependency that always succeeds
|
||||
async def no_auth():
|
||||
return None
|
||||
return no_auth
|
||||
|
||||
# If API key is configured, use proper authentication
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
|
||||
if not api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required"
|
||||
)
|
||||
if api_key_header_value != api_key:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API Key"
|
||||
)
|
||||
return api_key_header_value
|
||||
|
||||
return api_key_auth
|
||||
|
||||
|
||||
async def get_embedding_dim(embedding_model: str) -> int:
|
||||
"""Get embedding dimensions for the specified model"""
|
||||
@@ -168,11 +209,29 @@ def create_app(args):
|
||||
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
|
||||
)
|
||||
|
||||
# Initialize FastAPI app
|
||||
|
||||
# Check if API key is provided either through env var or args
|
||||
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="LightRAG API",
|
||||
description="API for querying text using LightRAG with OpenAI integration",
|
||||
description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "",
|
||||
version="1.0.0",
|
||||
openapi_tags=[{"name": "api"}]
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
@@ -239,7 +298,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
logging.error(f"Error during startup indexing: {str(e)}")
|
||||
|
||||
@app.post("/documents/scan")
|
||||
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
|
||||
async def scan_for_new_documents():
|
||||
"""Manually trigger scanning for new documents"""
|
||||
try:
|
||||
@@ -264,7 +323,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/resetcache")
|
||||
@app.post("/resetcache", dependencies=[Depends(optional_api_key)])
|
||||
async def reset_cache():
|
||||
"""Manually reset cache"""
|
||||
try:
|
||||
@@ -276,7 +335,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/upload")
|
||||
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
|
||||
async def upload_to_input_dir(file: UploadFile = File(...)):
|
||||
"""Upload a file to the input directory"""
|
||||
try:
|
||||
@@ -304,7 +363,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query", response_model=QueryResponse)
|
||||
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def query_text(request: QueryRequest):
|
||||
try:
|
||||
response = await rag.aquery(
|
||||
@@ -319,7 +378,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query/stream")
|
||||
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||
async def query_text_stream(request: QueryRequest):
|
||||
try:
|
||||
response = await rag.aquery(
|
||||
@@ -345,7 +404,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/text", response_model=InsertResponse)
|
||||
@app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_text(request: InsertTextRequest):
|
||||
try:
|
||||
await rag.ainsert(request.text)
|
||||
@@ -357,7 +416,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/file", response_model=InsertResponse)
|
||||
@app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
|
||||
try:
|
||||
content = await file.read()
|
||||
@@ -381,7 +440,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/batch", response_model=InsertResponse)
|
||||
@app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_batch(files: List[UploadFile] = File(...)):
|
||||
try:
|
||||
inserted_count = 0
|
||||
@@ -411,7 +470,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/documents", response_model=InsertResponse)
|
||||
@app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def clear_documents():
|
||||
try:
|
||||
rag.text_chunks = []
|
||||
@@ -425,7 +484,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
return {
|
||||
|
@@ -11,7 +11,19 @@ from pathlib import Path
|
||||
import shutil
|
||||
import aiofiles
|
||||
from ascii_colors import trace_exception
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from fastapi import HTTPException
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -86,6 +98,9 @@ def parse_args():
|
||||
help="Logging level (default: INFO)",
|
||||
)
|
||||
|
||||
parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None)
|
||||
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -147,6 +162,31 @@ class InsertResponse(BaseModel):
|
||||
message: str
|
||||
document_count: int
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
if not api_key:
|
||||
# If no API key is configured, return a dummy dependency that always succeeds
|
||||
async def no_auth():
|
||||
return None
|
||||
return no_auth
|
||||
|
||||
# If API key is configured, use proper authentication
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
|
||||
if not api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required"
|
||||
)
|
||||
if api_key_header_value != api_key:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API Key"
|
||||
)
|
||||
return api_key_header_value
|
||||
|
||||
return api_key_auth
|
||||
|
||||
|
||||
def create_app(args):
|
||||
# Setup logging
|
||||
@@ -154,11 +194,28 @@ def create_app(args):
|
||||
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
|
||||
)
|
||||
|
||||
# Initialize FastAPI app
|
||||
# Check if API key is provided either through env var or args
|
||||
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="LightRAG API",
|
||||
description="API for querying text using LightRAG with separate storage and input directories",
|
||||
description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "",
|
||||
version="1.0.0",
|
||||
openapi_tags=[{"name": "api"}]
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
@@ -209,7 +266,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
logging.error(f"Error during startup indexing: {str(e)}")
|
||||
|
||||
@app.post("/documents/scan")
|
||||
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
|
||||
async def scan_for_new_documents():
|
||||
"""Manually trigger scanning for new documents"""
|
||||
try:
|
||||
@@ -234,7 +291,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/upload")
|
||||
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
|
||||
async def upload_to_input_dir(file: UploadFile = File(...)):
|
||||
"""Upload a file to the input directory"""
|
||||
try:
|
||||
@@ -262,7 +319,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query", response_model=QueryResponse)
|
||||
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def query_text(request: QueryRequest):
|
||||
try:
|
||||
response = await rag.aquery(
|
||||
@@ -284,7 +341,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query/stream")
|
||||
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||
async def query_text_stream(request: QueryRequest):
|
||||
try:
|
||||
response = rag.query(
|
||||
@@ -304,7 +361,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/text", response_model=InsertResponse)
|
||||
@app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_text(request: InsertTextRequest):
|
||||
try:
|
||||
rag.insert(request.text)
|
||||
@@ -316,7 +373,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/file", response_model=InsertResponse)
|
||||
@app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
|
||||
try:
|
||||
content = await file.read()
|
||||
@@ -340,7 +397,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/batch", response_model=InsertResponse)
|
||||
@app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_batch(files: List[UploadFile] = File(...)):
|
||||
try:
|
||||
inserted_count = 0
|
||||
@@ -370,7 +427,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/documents", response_model=InsertResponse)
|
||||
@app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def clear_documents():
|
||||
try:
|
||||
rag.text_chunks = []
|
||||
@@ -384,7 +441,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
return {
|
||||
|
@@ -11,6 +11,19 @@ from pathlib import Path
|
||||
import shutil
|
||||
import aiofiles
|
||||
from ascii_colors import trace_exception
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -85,6 +98,7 @@ def parse_args():
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Logging level (default: INFO)",
|
||||
)
|
||||
parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -147,6 +161,31 @@ class InsertResponse(BaseModel):
|
||||
message: str
|
||||
document_count: int
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
if not api_key:
|
||||
# If no API key is configured, return a dummy dependency that always succeeds
|
||||
async def no_auth():
|
||||
return None
|
||||
return no_auth
|
||||
|
||||
# If API key is configured, use proper authentication
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
|
||||
if not api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required"
|
||||
)
|
||||
if api_key_header_value != api_key:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API Key"
|
||||
)
|
||||
return api_key_header_value
|
||||
|
||||
return api_key_auth
|
||||
|
||||
|
||||
def create_app(args):
|
||||
# Setup logging
|
||||
@@ -154,11 +193,29 @@ def create_app(args):
|
||||
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
|
||||
)
|
||||
|
||||
# Initialize FastAPI app
|
||||
# Check if API key is provided either through env var or args
|
||||
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="LightRAG API",
|
||||
description="API for querying text using LightRAG with separate storage and input directories",
|
||||
description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "",
|
||||
version="1.0.0",
|
||||
openapi_tags=[{"name": "api"}]
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
@@ -209,7 +266,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
logging.error(f"Error during startup indexing: {str(e)}")
|
||||
|
||||
@app.post("/documents/scan")
|
||||
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
|
||||
async def scan_for_new_documents():
|
||||
"""Manually trigger scanning for new documents"""
|
||||
try:
|
||||
@@ -234,7 +291,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/upload")
|
||||
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
|
||||
async def upload_to_input_dir(file: UploadFile = File(...)):
|
||||
"""Upload a file to the input directory"""
|
||||
try:
|
||||
@@ -262,7 +319,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query", response_model=QueryResponse)
|
||||
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def query_text(request: QueryRequest):
|
||||
try:
|
||||
response = await rag.aquery(
|
||||
@@ -284,7 +341,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query/stream")
|
||||
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||
async def query_text_stream(request: QueryRequest):
|
||||
try:
|
||||
response = rag.query(
|
||||
@@ -304,7 +361,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/text", response_model=InsertResponse)
|
||||
@app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_text(request: InsertTextRequest):
|
||||
try:
|
||||
await rag.ainsert(request.text)
|
||||
@@ -316,7 +373,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/file", response_model=InsertResponse)
|
||||
@app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
|
||||
try:
|
||||
content = await file.read()
|
||||
@@ -340,7 +397,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/batch", response_model=InsertResponse)
|
||||
@app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_batch(files: List[UploadFile] = File(...)):
|
||||
try:
|
||||
inserted_count = 0
|
||||
@@ -370,7 +427,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/documents", response_model=InsertResponse)
|
||||
@app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def clear_documents():
|
||||
try:
|
||||
rag.text_chunks = []
|
||||
@@ -384,7 +441,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
return {
|
||||
|
@@ -14,6 +14,20 @@ import aiofiles
|
||||
from ascii_colors import trace_exception
|
||||
import nest_asyncio
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from fastapi import HTTPException
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
@@ -75,6 +89,9 @@ def parse_args():
|
||||
help="Logging level (default: INFO)",
|
||||
)
|
||||
|
||||
parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None)
|
||||
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -136,6 +153,31 @@ class InsertResponse(BaseModel):
|
||||
message: str
|
||||
document_count: int
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
if not api_key:
|
||||
# If no API key is configured, return a dummy dependency that always succeeds
|
||||
async def no_auth():
|
||||
return None
|
||||
return no_auth
|
||||
|
||||
# If API key is configured, use proper authentication
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
|
||||
if not api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required"
|
||||
)
|
||||
if api_key_header_value != api_key:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API Key"
|
||||
)
|
||||
return api_key_header_value
|
||||
|
||||
return api_key_auth
|
||||
|
||||
|
||||
async def get_embedding_dim(embedding_model: str) -> int:
|
||||
"""Get embedding dimensions for the specified model"""
|
||||
@@ -150,10 +192,37 @@ def create_app(args):
|
||||
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
|
||||
)
|
||||
|
||||
# Initialize FastAPI app
|
||||
|
||||
# Check if API key is provided either through env var or args
|
||||
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="LightRAG API",
|
||||
description="API for querying text using LightRAG with OpenAI integration",
|
||||
description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "",
|
||||
version="1.0.0",
|
||||
openapi_tags=[{"name": "api"}]
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
@@ -213,7 +282,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
logging.error(f"Error during startup indexing: {str(e)}")
|
||||
|
||||
@app.post("/documents/scan")
|
||||
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
|
||||
async def scan_for_new_documents():
|
||||
"""Manually trigger scanning for new documents"""
|
||||
try:
|
||||
@@ -238,7 +307,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/upload")
|
||||
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
|
||||
async def upload_to_input_dir(file: UploadFile = File(...)):
|
||||
"""Upload a file to the input directory"""
|
||||
try:
|
||||
@@ -266,7 +335,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query", response_model=QueryResponse)
|
||||
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def query_text(request: QueryRequest):
|
||||
try:
|
||||
response = await rag.aquery(
|
||||
@@ -288,7 +357,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/query/stream")
|
||||
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||
async def query_text_stream(request: QueryRequest):
|
||||
try:
|
||||
response = rag.query(
|
||||
@@ -308,7 +377,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/text", response_model=InsertResponse)
|
||||
@app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_text(request: InsertTextRequest):
|
||||
try:
|
||||
rag.insert(request.text)
|
||||
@@ -320,7 +389,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/file", response_model=InsertResponse)
|
||||
@app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
|
||||
try:
|
||||
content = await file.read()
|
||||
@@ -344,7 +413,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/documents/batch", response_model=InsertResponse)
|
||||
@app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_batch(files: List[UploadFile] = File(...)):
|
||||
try:
|
||||
inserted_count = 0
|
||||
@@ -374,7 +443,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/documents", response_model=InsertResponse)
|
||||
@app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def clear_documents():
|
||||
try:
|
||||
rag.text_chunks = []
|
||||
@@ -388,7 +457,7 @@ def create_app(args):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
return {
|
||||
|
Reference in New Issue
Block a user