split lightrag_servery.py to smaller files

This commit is contained in:
yangdx
2025-02-20 03:26:39 +08:00
parent 0b795aa183
commit c0c87edc45
7 changed files with 1008 additions and 933 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,10 @@
"""
This module contains all the routers for the LightRAG API.
"""
from .document_routes import router as document_router
from .query_routes import router as query_router
from .graph_routes import router as graph_router
from .ollama_api import OllamaAPI
__all__ = ["document_router", "query_router", "graph_router", "OllamaAPI"]

View File

@@ -0,0 +1,667 @@
"""
This module contains all document-related routes for the LightRAG API.
"""
import asyncio
import logging
import os
import aiofiles
import shutil
import traceback
import pipmaster as pm
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, field_validator
from starlette.status import HTTP_403_FORBIDDEN
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency
router = APIRouter(prefix="/documents", tags=["documents"])
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = asyncio.Lock()
# Temporary file prefix
temp_prefix = "__tmp__"
class InsertTextRequest(BaseModel):
text: str = Field(
min_length=1,
description="The text to insert",
)
@field_validator("text", mode="after")
@classmethod
def strip_after(cls, text: str) -> str:
return text.strip()
class InsertTextsRequest(BaseModel):
texts: list[str] = Field(
min_length=1,
description="The texts to insert",
)
@field_validator("texts", mode="after")
@classmethod
def strip_after(cls, texts: list[str]) -> list[str]:
return [text.strip() for text in texts]
class InsertResponse(BaseModel):
status: str = Field(description="Status of the operation")
message: str = Field(description="Message describing the operation result")
class DocStatusResponse(BaseModel):
@staticmethod
def format_datetime(dt: Any) -> Optional[str]:
if dt is None:
return None
if isinstance(dt, str):
return dt
return dt.isoformat()
id: str
content_summary: str
content_length: int
status: DocStatus
created_at: str
updated_at: str
chunks_count: Optional[int] = None
error: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
class DocsStatusesResponse(BaseModel):
statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
class DocumentManager:
def __init__(
self,
input_dir: str,
supported_extensions: tuple = (
".txt",
".md",
".pdf",
".docx",
".pptx",
".xlsx",
),
):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory_for_new_files(self) -> List[Path]:
new_files = []
for ext in self.supported_extensions:
logging.info(f"Scanning for {ext} files in {self.input_dir}")
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def scan_directory(self) -> List[Path]:
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
try:
content = ""
ext = file_path.suffix.lower()
file = None
async with aiofiles.open(file_path, "rb") as f:
file = await f.read()
# Process based on file type
match ext:
case ".txt" | ".md":
content = file.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
docx_file = BytesIO(file)
doc = Document(docx_file)
content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation
from io import BytesIO
pptx_file = BytesIO(file)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case ".xlsx":
if not pm.is_installed("openpyxl"):
pm.install("openpyxl")
from openpyxl import load_workbook
from io import BytesIO
xlsx_file = BytesIO(file)
wb = load_workbook(xlsx_file)
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
content += "\n"
case _:
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
return False
# Insert into the RAG queue
if content:
await rag.apipeline_enqueue_documents(content)
logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
return True
else:
logging.error(f"No content could be extracted from file: {file_path.name}")
except Exception as e:
logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
finally:
if file_path.name.startswith(temp_prefix):
try:
file_path.unlink()
except Exception as e:
logging.error(f"Error deleting file {file_path}: {str(e)}")
return False
async def pipeline_index_file(rag, file_path: Path):
"""Index a file
Args:
rag: LightRAG instance
file_path: Path to the saved file
"""
try:
content = ""
ext = file_path.suffix.lower()
file = None
async with aiofiles.open(file_path, "rb") as f:
file = await f.read()
# Process based on file type
match ext:
case ".txt" | ".md":
content = file.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
docx_file = BytesIO(file)
doc = Document(docx_file)
content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation
from io import BytesIO
pptx_file = BytesIO(file)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case ".xlsx":
if not pm.is_installed("openpyxl"):
pm.install("openpyxl")
from openpyxl import load_workbook
from io import BytesIO
xlsx_file = BytesIO(file)
wb = load_workbook(xlsx_file)
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
content += "\n"
case _:
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
return
# Insert into the RAG queue
if content:
await rag.apipeline_enqueue_documents(content)
await rag.apipeline_process_enqueue_documents()
logging.info(f"Successfully indexed file: {file_path.name}")
else:
logging.error(f"No content could be extracted from file: {file_path.name}")
except Exception as e:
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_files(rag, file_paths: List[Path]):
if not file_paths:
return
try:
enqueued = False
if len(file_paths) == 1:
enqueued = await pipeline_enqueue_file(rag, file_paths[0])
else:
tasks = [pipeline_enqueue_file(rag, path) for path in file_paths]
enqueued = any(await asyncio.gather(*tasks))
if enqueued:
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_texts(rag, texts: List[str]):
if not texts:
return
await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents()
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
temp_path = input_dir / "temp" / unique_filename
temp_path.parent.mkdir(exist_ok=True)
with open(temp_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return temp_path
async def run_scanning_process(rag, doc_manager: DocumentManager):
"""Background task to scan and index documents"""
try:
new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files)
logging.info(f"Found {len(new_files)} new files to index.")
for file_path in new_files:
try:
async with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await pipeline_index_file(rag, file_path)
async with progress_lock:
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"] / scan_progress["total_files"]
) * 100
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
async with progress_lock:
scan_progress["is_scanning"] = False
def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None):
optional_api_key = get_api_key_dependency(api_key)
@router.post("/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(background_tasks: BackgroundTasks):
"""
Trigger the scanning process for new documents.
This endpoint initiates a background task that scans the input directory for new documents
and processes them. If a scanning process is already running, it returns a status indicating
that fact.
Args:
background_tasks (BackgroundTasks): FastAPI background tasks handler
Returns:
dict: A dictionary containing the scanning status
"""
async with progress_lock:
if scan_progress["is_scanning"]:
return {"status": "already_scanning"}
scan_progress["is_scanning"] = True
scan_progress["indexed_count"] = 0
scan_progress["progress"] = 0
background_tasks.add_task(run_scanning_process, rag, doc_manager)
return {"status": "scanning_started"}
@router.get("/scan-progress")
async def get_scan_progress():
"""
Get the current progress of the document scanning process.
Returns:
dict: A dictionary containing the current scanning progress information including:
- is_scanning: Whether a scan is currently in progress
- current_file: The file currently being processed
- indexed_count: Number of files indexed so far
- total_files: Total number of files to process
- progress: Percentage of completion
"""
async with progress_lock:
return scan_progress
@router.post("/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""
Upload a file to the input directory and index it.
This API endpoint accepts a file through an HTTP POST request, checks if the
uploaded file is of a supported type, saves it in the specified input directory,
indexes it for retrieval, and returns a success status with relevant details.
Args:
background_tasks: FastAPI BackgroundTasks for async processing
file (UploadFile): The file to be uploaded. It must have an allowed extension.
Returns:
InsertResponse: A response object containing the upload status and a message.
Raises:
HTTPException: If the file type is not supported (400) or other errors occur (500).
"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
background_tasks.add_task(pipeline_index_file, rag, file_path)
return InsertResponse(
status="success",
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks):
"""
Insert text into the RAG system.
This endpoint allows you to insert text data into the RAG system for later retrieval
and use in generating responses.
Args:
request (InsertTextRequest): The request body containing the text to be inserted.
background_tasks: FastAPI BackgroundTasks for async processing
Returns:
InsertResponse: A response object containing the status of the operation.
Raises:
HTTPException: If an error occurs during text processing (500).
"""
try:
background_tasks.add_task(pipeline_index_texts, rag, [request.text])
return InsertResponse(
status="success",
message="Text successfully received. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks):
"""
Insert multiple texts into the RAG system.
This endpoint allows you to insert multiple text entries into the RAG system
in a single request.
Args:
request (InsertTextsRequest): The request body containing the list of texts.
background_tasks: FastAPI BackgroundTasks for async processing
Returns:
InsertResponse: A response object containing the status of the operation.
Raises:
HTTPException: If an error occurs during text processing (500).
"""
try:
background_tasks.add_task(pipeline_index_texts, rag, request.texts)
return InsertResponse(
status="success",
message="Text successfully received. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
"""
Insert a file directly into the RAG system.
This endpoint accepts a file upload and processes it for inclusion in the RAG system.
The file is saved temporarily and processed in the background.
Args:
background_tasks: FastAPI BackgroundTasks for async processing
file (UploadFile): The file to be processed
Returns:
InsertResponse: A response object containing the status of the operation.
Raises:
HTTPException: If the file type is not supported (400) or other errors occur (500).
"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
temp_path = await save_temp_file(doc_manager.input_dir, file)
background_tasks.add_task(pipeline_index_file, rag, temp_path)
return InsertResponse(
status="success",
message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/file: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
"""
Process multiple files in batch mode.
This endpoint allows uploading and processing multiple files simultaneously.
It handles partial successes and provides detailed feedback about failed files.
Args:
background_tasks: FastAPI BackgroundTasks for async processing
files (List[UploadFile]): List of files to process
Returns:
InsertResponse: A response object containing:
- status: "success", "partial_success", or "failure"
- message: Detailed information about the operation results
Raises:
HTTPException: If an error occurs during processing (500).
"""
try:
inserted_count = 0
failed_files = []
temp_files = []
for file in files:
if doc_manager.is_supported_file(file.filename):
temp_files.append(await save_temp_file(doc_manager.input_dir, file))
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
if temp_files:
background_tasks.add_task(pipeline_index_files, rag, temp_files)
if inserted_count == len(files):
status = "success"
status_message = f"Successfully inserted all {inserted_count} documents"
elif inserted_count > 0:
status = "partial_success"
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
else:
status = "failure"
status_message = "No documents were successfully inserted"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(status=status, message=status_message)
except Exception as e:
logging.error(f"Error /documents/batch: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def clear_documents():
"""
Clear all documents from the RAG system.
This endpoint deletes all text chunks, entities vector database, and relationships
vector database, effectively clearing all documents from the RAG system.
Returns:
InsertResponse: A response object containing the status and message.
Raises:
HTTPException: If an error occurs during the clearing process (500).
"""
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(status="success", message="All documents cleared successfully")
except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("", dependencies=[Depends(optional_api_key)])
async def documents() -> DocsStatusesResponse:
"""
Get the status of all documents in the system.
This endpoint retrieves the current status of all documents, grouped by their
processing status (PENDING, PROCESSING, PROCESSED, FAILED).
Returns:
DocsStatusesResponse: A response object containing a dictionary where keys are
DocStatus values and values are lists of DocStatusResponse
objects representing documents in each status category.
Raises:
HTTPException: If an error occurs while retrieving document statuses (500).
"""
try:
statuses = (
DocStatus.PENDING,
DocStatus.PROCESSING,
DocStatus.PROCESSED,
DocStatus.FAILED,
)
tasks = [rag.get_docs_by_status(status) for status in statuses]
results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks)
response = DocsStatusesResponse()
for idx, result in enumerate(results):
status = statuses[idx]
for doc_id, doc_status in result.items():
if status not in response.statuses:
response.statuses[status] = []
response.statuses[status].append(
DocStatusResponse(
id=doc_id,
content_summary=doc_status.content_summary,
content_length=doc_status.content_length,
status=doc_status.status,
created_at=DocStatusResponse.format_datetime(doc_status.created_at),
updated_at=DocStatusResponse.format_datetime(doc_status.updated_at),
chunks_count=doc_status.chunks_count,
error=doc_status.error,
metadata=doc_status.metadata,
)
)
return response
except Exception as e:
logging.error(f"Error GET /documents: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
return router

View File

@@ -0,0 +1,26 @@
"""
This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from ..utils_api import get_api_key_dependency
router = APIRouter(tags=["graph"])
def create_graph_routes(rag, api_key: Optional[str] = None):
optional_api_key = get_api_key_dependency(api_key)
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
async def get_graph_labels():
"""Get all graph labels"""
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str):
"""Get knowledge graph for a specific label"""
return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
return router

View File

@@ -0,0 +1,225 @@
"""
This module contains all query-related routes for the LightRAG API.
"""
import json
import logging
import traceback
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
class QueryRequest(BaseModel):
query: str = Field(
min_length=1,
description="The query text",
)
mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
default="hybrid",
description="Query mode",
)
only_need_context: Optional[bool] = Field(
default=None,
description="If True, only returns the retrieved context without generating a response.",
)
only_need_prompt: Optional[bool] = Field(
default=None,
description="If True, only returns the generated prompt without producing a response.",
)
response_type: Optional[str] = Field(
min_length=1,
default=None,
description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
)
top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
)
max_token_for_text_unit: Optional[int] = Field(
gt=1,
default=None,
description="Maximum number of tokens allowed for each retrieved text chunk.",
)
max_token_for_global_context: Optional[int] = Field(
gt=1,
default=None,
description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
)
max_token_for_local_context: Optional[int] = Field(
gt=1,
default=None,
description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
)
hl_keywords: Optional[List[str]] = Field(
default=None,
description="List of high-level keywords to prioritize in retrieval.",
)
ll_keywords: Optional[List[str]] = Field(
default=None,
description="List of low-level keywords to refine retrieval focus.",
)
conversation_history: Optional[List[Dict[str, Any]]] = Field(
default=None,
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
)
history_turns: Optional[int] = Field(
ge=0,
default=None,
description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
)
@field_validator("query", mode="after")
@classmethod
def query_strip_after(cls, query: str) -> str:
return query.strip()
@field_validator("hl_keywords", mode="after")
@classmethod
def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
if hl_keywords is None:
return None
return [keyword.strip() for keyword in hl_keywords]
@field_validator("ll_keywords", mode="after")
@classmethod
def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
if ll_keywords is None:
return None
return [keyword.strip() for keyword in ll_keywords]
@field_validator("conversation_history", mode="after")
@classmethod
def conversation_history_role_check(
cls, conversation_history: List[Dict[str, Any]] | None
) -> List[Dict[str, Any]] | None:
if conversation_history is None:
return None
for msg in conversation_history:
if "role" not in msg or msg["role"] not in {"user", "assistant"}:
raise ValueError(
"Each message must have a 'role' key with value 'user' or 'assistant'."
)
return conversation_history
def to_query_params(self, is_stream: bool) -> "QueryParam":
"""Converts a QueryRequest instance into a QueryParam instance."""
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
request_data = self.model_dump(exclude_none=True, exclude={"query"})
# Ensure `mode` and `stream` are set explicitly
param = QueryParam(**request_data)
param.stream = is_stream
return param
class QueryResponse(BaseModel):
response: str = Field(
description="The generated response",
)
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
optional_api_key = get_api_key_dependency(api_key)
@router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
async def query_text(request: QueryRequest):
"""
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
Parameters:
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response.
Raises:
HTTPException: Raised when an error occurs during the request handling process,
with status code 500 and detail containing the exception message.
"""
try:
param = request.to_query_params(False)
if param.top_k is None:
param.top_k = top_k
response = await rag.aquery(request.query, param=param)
# If response is a string (e.g. cache hit), return directly
if isinstance(response, str):
return QueryResponse(response=response)
if isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result)
else:
return QueryResponse(response=str(response))
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
"""
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
Args:
request (QueryRequest): The request object containing the query parameters.
optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
Returns:
StreamingResponse: A streaming response containing the RAG query results.
"""
try:
param = request.to_query_params(True)
if param.top_k is None:
param.top_k = top_k
response = await rag.aquery(request.query, param=param)
from fastapi.responses import StreamingResponse
async def stream_generator():
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
else:
# If it's an async generator, send chunks one by one
try:
async for chunk in response:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx
},
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
return router

44
lightrag/api/utils_api.py Normal file
View File

@@ -0,0 +1,44 @@
"""
Utility functions for the LightRAG API.
"""
from typing import Optional
from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN
def get_api_key_dependency(api_key: Optional[str]):
"""
Create an API key dependency for route protection.
Args:
api_key (Optional[str]): The API key to validate against.
If None, no authentication is required.
Returns:
Callable: A dependency function that validates the API key.
"""
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(
api_key_header_value: Optional[str] = Security(api_key_header),
):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth