Merge branch 'HKUDS:main' into main

This commit is contained in:
Samuel Chan
2025-01-04 18:34:24 +08:00
committed by GitHub
5 changed files with 370 additions and 45 deletions

View File

@@ -1025,6 +1025,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size | | --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-file | ./book.txt | Initial input file | | --input-file | ./book.txt | Initial input file |
| --log-level | INFO | Logging level | | --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
#### Ollama Server Options #### Ollama Server Options
@@ -1042,6 +1043,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size | | --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-file | ./book.txt | Initial input file | | --input-file | ./book.txt | Initial input file |
| --log-level | INFO | Logging level | | --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
#### OpenAI Server Options #### OpenAI Server Options
@@ -1056,6 +1058,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size | | --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-dir | ./inputs | Input directory for documents | | --input-dir | ./inputs | Input directory for documents |
| --log-level | INFO | Logging level | | --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
#### OpenAI AZURE Server Options #### OpenAI AZURE Server Options
@@ -1071,8 +1074,10 @@ Each server has its own specific configuration options:
| --input-dir | ./inputs | Input directory for documents | | --input-dir | ./inputs | Input directory for documents |
| --enable-cache | True | Enable response cache | | --enable-cache | True | Enable response cache |
| --log-level | INFO | Logging level | | --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`.
### Example Usage ### Example Usage
#### LoLLMs RAG Server #### LoLLMs RAG Server
@@ -1083,6 +1088,10 @@ lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_r
# Using specific models (ensure they are installed in your LoLLMs instance) # Using specific models (ensure they are installed in your LoLLMs instance)
lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
# Using specific models and an authentication key
lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey
``` ```
#### Ollama RAG Server #### Ollama RAG Server

View File

@@ -21,6 +21,12 @@ import inspect
import json import json
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
load_dotenv() load_dotenv()
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
@@ -93,6 +99,13 @@ def parse_args():
help="Logging level (default: INFO)", help="Logging level (default: INFO)",
) )
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args() return parser.parse_args()
@@ -155,6 +168,31 @@ class InsertResponse(BaseModel):
document_count: int document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
async def get_embedding_dim(embedding_model: str) -> int: async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model""" """Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."] test_text = ["This is a test sentence."]
@@ -168,12 +206,32 @@ def create_app(args):
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
) )
# Initialize FastAPI app # Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI( app = FastAPI(
title="LightRAG API", title="LightRAG API",
description="API for querying text using LightRAG with OpenAI integration", description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
) )
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True) Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -239,7 +297,7 @@ def create_app(args):
except Exception as e: except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}") logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan") @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(): async def scan_for_new_documents():
"""Manually trigger scanning for new documents""" """Manually trigger scanning for new documents"""
try: try:
@@ -264,7 +322,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/resetcache") @app.post("/resetcache", dependencies=[Depends(optional_api_key)])
async def reset_cache(): async def reset_cache():
"""Manually reset cache""" """Manually reset cache"""
try: try:
@@ -276,7 +334,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload") @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)): async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory""" """Upload a file to the input directory"""
try: try:
@@ -304,7 +362,9 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse) @app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest): async def query_text(request: QueryRequest):
try: try:
response = await rag.aquery( response = await rag.aquery(
@@ -319,7 +379,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream") @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = await rag.aquery( response = await rag.aquery(
@@ -345,7 +405,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse) @app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest): async def insert_text(request: InsertTextRequest):
try: try:
await rag.ainsert(request.text) await rag.ainsert(request.text)
@@ -357,7 +421,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse) @app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)): async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try: try:
content = await file.read() content = await file.read()
@@ -381,7 +449,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse) @app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)): async def insert_batch(files: List[UploadFile] = File(...)):
try: try:
inserted_count = 0 inserted_count = 0
@@ -411,7 +483,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse) @app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents(): async def clear_documents():
try: try:
rag.text_chunks = [] rag.text_chunks = []
@@ -425,7 +501,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/health") @app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
return { return {

View File

@@ -11,6 +11,13 @@ from pathlib import Path
import shutil import shutil
import aiofiles import aiofiles
from ascii_colors import trace_exception from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def parse_args(): def parse_args():
@@ -86,6 +93,13 @@ def parse_args():
help="Logging level (default: INFO)", help="Logging level (default: INFO)",
) )
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args() return parser.parse_args()
@@ -148,18 +162,63 @@ class InsertResponse(BaseModel):
document_count: int document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args): def create_app(args):
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
) )
# Initialize FastAPI app # Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI( app = FastAPI(
title="LightRAG API", title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories", description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
) )
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True) Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -209,7 +268,7 @@ def create_app(args):
except Exception as e: except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}") logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan") @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(): async def scan_for_new_documents():
"""Manually trigger scanning for new documents""" """Manually trigger scanning for new documents"""
try: try:
@@ -234,7 +293,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload") @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)): async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory""" """Upload a file to the input directory"""
try: try:
@@ -262,7 +321,9 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse) @app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest): async def query_text(request: QueryRequest):
try: try:
response = await rag.aquery( response = await rag.aquery(
@@ -284,7 +345,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream") @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = rag.query( response = rag.query(
@@ -304,7 +365,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse) @app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest): async def insert_text(request: InsertTextRequest):
try: try:
rag.insert(request.text) rag.insert(request.text)
@@ -316,7 +381,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse) @app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)): async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try: try:
content = await file.read() content = await file.read()
@@ -340,7 +409,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse) @app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)): async def insert_batch(files: List[UploadFile] = File(...)):
try: try:
inserted_count = 0 inserted_count = 0
@@ -370,7 +443,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse) @app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents(): async def clear_documents():
try: try:
rag.text_chunks = [] rag.text_chunks = []
@@ -384,7 +461,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/health") @app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
return { return {

View File

@@ -11,6 +11,13 @@ from pathlib import Path
import shutil import shutil
import aiofiles import aiofiles
from ascii_colors import trace_exception from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def parse_args(): def parse_args():
@@ -85,6 +92,12 @@ def parse_args():
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)", help="Logging level (default: INFO)",
) )
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args() return parser.parse_args()
@@ -148,18 +161,63 @@ class InsertResponse(BaseModel):
document_count: int document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args): def create_app(args):
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
) )
# Initialize FastAPI app # Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI( app = FastAPI(
title="LightRAG API", title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories", description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
) )
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True) Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -209,7 +267,7 @@ def create_app(args):
except Exception as e: except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}") logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan") @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(): async def scan_for_new_documents():
"""Manually trigger scanning for new documents""" """Manually trigger scanning for new documents"""
try: try:
@@ -234,7 +292,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload") @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)): async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory""" """Upload a file to the input directory"""
try: try:
@@ -262,7 +320,9 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse) @app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest): async def query_text(request: QueryRequest):
try: try:
response = await rag.aquery( response = await rag.aquery(
@@ -284,7 +344,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream") @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = rag.query( response = rag.query(
@@ -304,7 +364,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse) @app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest): async def insert_text(request: InsertTextRequest):
try: try:
await rag.ainsert(request.text) await rag.ainsert(request.text)
@@ -316,7 +380,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse) @app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)): async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try: try:
content = await file.read() content = await file.read()
@@ -340,7 +408,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse) @app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)): async def insert_batch(files: List[UploadFile] = File(...)):
try: try:
inserted_count = 0 inserted_count = 0
@@ -370,7 +442,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse) @app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents(): async def clear_documents():
try: try:
rag.text_chunks = [] rag.text_chunks = []
@@ -384,7 +460,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/health") @app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
return { return {

View File

@@ -14,6 +14,14 @@ import aiofiles
from ascii_colors import trace_exception from ascii_colors import trace_exception
import nest_asyncio import nest_asyncio
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
# Apply nest_asyncio to solve event loop issues # Apply nest_asyncio to solve event loop issues
nest_asyncio.apply() nest_asyncio.apply()
@@ -75,6 +83,13 @@ def parse_args():
help="Logging level (default: INFO)", help="Logging level (default: INFO)",
) )
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args() return parser.parse_args()
@@ -137,6 +152,31 @@ class InsertResponse(BaseModel):
document_count: int document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
async def get_embedding_dim(embedding_model: str) -> int: async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model""" """Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."] test_text = ["This is a test sentence."]
@@ -150,10 +190,39 @@ def create_app(args):
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
) )
# Initialize FastAPI app # Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI( app = FastAPI(
title="LightRAG API", title="LightRAG API",
description="API for querying text using LightRAG with OpenAI integration", description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
) )
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
@@ -213,7 +282,7 @@ def create_app(args):
except Exception as e: except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}") logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan") @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(): async def scan_for_new_documents():
"""Manually trigger scanning for new documents""" """Manually trigger scanning for new documents"""
try: try:
@@ -238,7 +307,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload") @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)): async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory""" """Upload a file to the input directory"""
try: try:
@@ -266,7 +335,9 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse) @app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest): async def query_text(request: QueryRequest):
try: try:
response = await rag.aquery( response = await rag.aquery(
@@ -288,7 +359,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream") @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = rag.query( response = rag.query(
@@ -308,7 +379,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse) @app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest): async def insert_text(request: InsertTextRequest):
try: try:
rag.insert(request.text) rag.insert(request.text)
@@ -320,7 +395,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse) @app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)): async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try: try:
content = await file.read() content = await file.read()
@@ -344,7 +423,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse) @app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)): async def insert_batch(files: List[UploadFile] = File(...)):
try: try:
inserted_count = 0 inserted_count = 0
@@ -374,7 +457,11 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse) @app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents(): async def clear_documents():
try: try:
rag.text_chunks = [] rag.text_chunks = []
@@ -388,7 +475,7 @@ def create_app(args):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/health") @app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
return { return {