Merge remote-tracking branch 'origin/main'

This commit is contained in:
Samuel Chan
2025-01-04 18:35:39 +08:00
5 changed files with 370 additions and 45 deletions

View File

@@ -1025,6 +1025,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-file | ./book.txt | Initial input file |
| --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
#### Ollama Server Options
@@ -1042,6 +1043,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-file | ./book.txt | Initial input file |
| --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
#### OpenAI Server Options
@@ -1056,6 +1058,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-dir | ./inputs | Input directory for documents |
| --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
#### OpenAI AZURE Server Options
@@ -1071,8 +1074,10 @@ Each server has its own specific configuration options:
| --input-dir | ./inputs | Input directory for documents |
| --enable-cache | True | Enable response cache |
| --log-level | INFO | Logging level |
| --key | none | Access Key to protect the lightrag service |
For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`.
### Example Usage
#### LoLLMs RAG Server
@@ -1083,6 +1088,10 @@ lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_r
# Using specific models (ensure they are installed in your LoLLMs instance)
lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
# Using specific models and an authentication key
lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey
```
#### Ollama RAG Server

View File

@@ -21,6 +21,12 @@ import inspect
import json
from fastapi.responses import StreamingResponse
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
load_dotenv()
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
@@ -93,6 +99,13 @@ def parse_args():
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
@@ -155,6 +168,31 @@ class InsertResponse(BaseModel):
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."]
@@ -168,12 +206,32 @@ def create_app(args):
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Initialize FastAPI app
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with OpenAI integration",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -239,7 +297,7 @@ def create_app(args):
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
@@ -264,7 +322,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/resetcache")
@app.post("/resetcache", dependencies=[Depends(optional_api_key)])
async def reset_cache():
"""Manually reset cache"""
try:
@@ -276,7 +334,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload")
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
@@ -304,7 +362,9 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse)
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
@@ -319,7 +379,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream")
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = await rag.aquery(
@@ -345,7 +405,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse)
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
@@ -357,7 +421,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse)
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
@@ -381,7 +449,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse)
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
@@ -411,7 +483,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse)
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
@@ -425,7 +501,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {

View File

@@ -11,6 +11,13 @@ from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def parse_args():
@@ -86,6 +93,13 @@ def parse_args():
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
@@ -148,18 +162,63 @@ class InsertResponse(BaseModel):
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args):
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Initialize FastAPI app
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -209,7 +268,7 @@ def create_app(args):
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
@@ -234,7 +293,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload")
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
@@ -262,7 +321,9 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse)
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
@@ -284,7 +345,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream")
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
@@ -304,7 +365,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse)
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
rag.insert(request.text)
@@ -316,7 +381,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse)
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
@@ -340,7 +409,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse)
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
@@ -370,7 +443,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse)
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
@@ -384,7 +461,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {

View File

@@ -11,6 +11,13 @@ from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def parse_args():
@@ -85,6 +92,12 @@ def parse_args():
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
@@ -148,18 +161,63 @@ class InsertResponse(BaseModel):
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args):
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Initialize FastAPI app
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -209,7 +267,7 @@ def create_app(args):
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
@@ -234,7 +292,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload")
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
@@ -262,7 +320,9 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse)
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
@@ -284,7 +344,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream")
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
@@ -304,7 +364,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse)
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
@@ -316,7 +380,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse)
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
@@ -340,7 +408,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse)
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
@@ -370,7 +442,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse)
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
@@ -384,7 +460,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {

View File

@@ -14,6 +14,14 @@ import aiofiles
from ascii_colors import trace_exception
import nest_asyncio
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
@@ -75,6 +83,13 @@ def parse_args():
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
@@ -137,6 +152,31 @@ class InsertResponse(BaseModel):
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."]
@@ -150,10 +190,39 @@ def create_app(args):
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Initialize FastAPI app
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with OpenAI integration",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create working directory if it doesn't exist
@@ -213,7 +282,7 @@ def create_app(args):
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
@@ -238,7 +307,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload")
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
@@ -266,7 +335,9 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query", response_model=QueryResponse)
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
@@ -288,7 +359,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream")
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
@@ -308,7 +379,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/text", response_model=InsertResponse)
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
rag.insert(request.text)
@@ -320,7 +395,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse)
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
@@ -344,7 +423,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/batch", response_model=InsertResponse)
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
@@ -374,7 +457,11 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/documents", response_model=InsertResponse)
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
@@ -388,7 +475,7 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {