Merge pull request #791 from ArnoChenFx/refactor-server

Refactor File Indexing for Background Asynchronous Processing
Authored by zrguo on 2025-02-16 22:07:30 +08:00, committed by GitHub.
6 changed files with 448 additions and 443 deletions

View File

@@ -3,7 +3,6 @@ from fastapi import (
HTTPException, HTTPException,
File, File,
UploadFile, UploadFile,
Form,
BackgroundTasks, BackgroundTasks,
) )
import asyncio import asyncio
@@ -14,7 +13,7 @@ import re
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
import logging import logging
import argparse import argparse
from typing import List, Any, Optional, Union, Dict from typing import List, Any, Optional, Dict
from pydantic import BaseModel from pydantic import BaseModel
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.types import GPTKeywordExtractionFormat from lightrag.types import GPTKeywordExtractionFormat
@@ -34,6 +33,9 @@ from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm import pipmaster as pm
from dotenv import load_dotenv from dotenv import load_dotenv
import configparser import configparser
import traceback
from datetime import datetime
from lightrag.utils import logger from lightrag.utils import logger
from .ollama_api import ( from .ollama_api import (
OllamaAPI, OllamaAPI,
@@ -635,9 +637,47 @@ class SearchMode(str, Enum):
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str query: str
"""Specifies the retrieval mode"""
mode: SearchMode = SearchMode.hybrid mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False """If True, enables streaming output for real-time responses."""
stream: Optional[bool] = None
"""If True, only returns the retrieved context without generating a response."""
only_need_context: Optional[bool] = None
"""If True, only returns the generated prompt without producing a response."""
only_need_prompt: Optional[bool] = None
"""Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
response_type: Optional[str] = None
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
top_k: Optional[int] = None
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_text_unit: Optional[int] = None
"""Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_global_context: Optional[int] = None
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
max_token_for_local_context: Optional[int] = None
"""List of high-level keywords to prioritize in retrieval."""
hl_keywords: Optional[List[str]] = None
"""List of low-level keywords to refine retrieval focus."""
ll_keywords: Optional[List[str]] = None
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
conversation_history: Optional[List[dict[str, Any]]] = None
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
history_turns: Optional[int] = None
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
@@ -646,13 +686,38 @@ class QueryResponse(BaseModel):
class InsertTextRequest(BaseModel): class InsertTextRequest(BaseModel):
text: str text: str
description: Optional[str] = None
class InsertResponse(BaseModel): class InsertResponse(BaseModel):
status: str status: str
message: str message: str
document_count: int
def QueryRequestToQueryParams(request: QueryRequest):
param = QueryParam(mode=request.mode, stream=request.stream)
if request.only_need_context is not None:
param.only_need_context = request.only_need_context
if request.only_need_prompt is not None:
param.only_need_prompt = request.only_need_prompt
if request.response_type is not None:
param.response_type = request.response_type
if request.top_k is not None:
param.top_k = request.top_k
if request.max_token_for_text_unit is not None:
param.max_token_for_text_unit = request.max_token_for_text_unit
if request.max_token_for_global_context is not None:
param.max_token_for_global_context = request.max_token_for_global_context
if request.max_token_for_local_context is not None:
param.max_token_for_local_context = request.max_token_for_local_context
if request.hl_keywords is not None:
param.hl_keywords = request.hl_keywords
if request.ll_keywords is not None:
param.ll_keywords = request.ll_keywords
if request.conversation_history is not None:
param.conversation_history = request.conversation_history
if request.history_turns is not None:
param.history_turns = request.history_turns
return param
def get_api_key_dependency(api_key: Optional[str]): def get_api_key_dependency(api_key: Optional[str]):
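Note on the converter above: QueryRequestToQueryParams only copies fields the client actually set, so anything left as None keeps the QueryParam default (stream is the one field passed through as-is). A minimal sketch of that behaviour, meant to be read inside the server module where QueryRequest, SearchMode and QueryRequestToQueryParams are defined; the field values are illustrative:

# Illustration of the None-means-default mapping in QueryRequestToQueryParams();
# values below are made up for the example.
request = QueryRequest(
    query="What entities are related to LightRAG?",
    mode=SearchMode.hybrid,
    top_k=20,                       # explicitly set -> copied onto QueryParam
    response_type="Bullet Points",  # explicitly set -> copied onto QueryParam
    # hl_keywords, conversation_history, ... left as None -> not copied
)

param = QueryRequestToQueryParams(request)
# param.top_k == 20 and param.response_type == "Bullet Points", while the
# fields left as None fall back to QueryParam's own defaults.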
@@ -666,7 +731,9 @@ def get_api_key_dependency(api_key: Optional[str]):
# If API key is configured, use proper authentication # If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): async def api_key_auth(
api_key_header_value: Optional[str] = Security(api_key_header),
):
if not api_key_header_value: if not api_key_header_value:
raise HTTPException( raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required" status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -682,6 +749,7 @@ def get_api_key_dependency(api_key: Optional[str]):
# Global configuration # Global configuration
global_top_k = 60 # default value global_top_k = 60 # default value
temp_prefix = "__tmp_" # prefix for temporary files
def create_app(args): def create_app(args):
@@ -1132,61 +1200,194 @@ def create_app(args):
("llm_response_cache", rag.llm_response_cache), ("llm_response_cache", rag.llm_response_cache),
] ]
async def index_file(file_path: Union[str, Path]) -> None: async def pipeline_enqueue_file(file_path: Path) -> bool:
"""Index all files inside the folder with support for multiple file formats """Add a file to the queue for processing
Args: Args:
file_path: Path to the file to be indexed (str or Path object) file_path: Path to the saved file
Returns:
Raises: bool: True if the file was successfully enqueued, False otherwise
ValueError: If file format is not supported
FileNotFoundError: If file doesn't exist
""" """
if not pm.is_installed("aiofiles"): try:
pm.install("aiofiles") content = ""
ext = file_path.suffix.lower()
# Convert to Path object if string file = None
file_path = Path(file_path) async with aiofiles.open(file_path, "rb") as f:
file = await f.read()
# Check if file exists # Process based on file type
if not file_path.exists(): match ext:
raise FileNotFoundError(f"File not found: {file_path}") case ".txt" | ".md":
content = file.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader
from io import BytesIO
content = "" pdf_file = BytesIO(file)
# Get file extension in lowercase reader = PdfReader(pdf_file)
ext = file_path.suffix.lower() for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
match ext: docx_content = await file.read()
case ".txt" | ".md": docx_file = BytesIO(docx_content)
# Text files handling doc = Document(docx_file)
async with aiofiles.open(file_path, "r", encoding="utf-8") as f: content = "\n".join(
content = await f.read() [paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO
case ".pdf" | ".docx" | ".pptx" | ".xlsx": pptx_content = await file.read()
if not pm.is_installed("docling"): pptx_file = BytesIO(pptx_content)
pm.install("docling") prs = Presentation(pptx_file)
from docling.document_converter import DocumentConverter for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return False
async def convert_doc(): # Insert into the RAG queue
def sync_convert(): if content:
converter = DocumentConverter() await rag.apipeline_enqueue_documents(content)
result = converter.convert(file_path) logging.info(
return result.document.export_to_markdown() f"Successfully processed and enqueued file: {file_path.name}"
)
return True
else:
logging.error(
f"No content could be extracted from file: {file_path.name}"
)
return await asyncio.to_thread(sync_convert) except Exception as e:
logging.error(
f"Error processing or enqueueing file {file_path.name}: {str(e)}"
)
logging.error(traceback.format_exc())
finally:
if file_path.name.startswith(temp_prefix):
# Clean up the temporary file after indexing
try:
file_path.unlink()
except Exception as e:
logging.error(f"Error deleting file {file_path}: {str(e)}")
return False
content = await convert_doc() async def pipeline_index_file(file_path: Path):
"""Index a file
case _: Args:
raise ValueError(f"Unsupported file format: {ext}") file_path: Path to the saved file
"""
try:
if await pipeline_enqueue_file(file_path):
await rag.apipeline_process_enqueue_documents()
# Insert content into RAG system except Exception as e:
if content: logging.error(f"Error indexing file {file_path.name}: {str(e)}")
await rag.ainsert(content) logging.error(traceback.format_exc())
doc_manager.mark_as_indexed(file_path)
logging.info(f"Successfully indexed file: {file_path}") async def pipeline_index_files(file_paths: List[Path]):
else: """Index multiple files concurrently
logging.warning(f"No content extracted from file: {file_path}")
Args:
file_paths: Paths to the files to index
"""
if not file_paths:
return
try:
enqueued = False
if len(file_paths) == 1:
enqueued = await pipeline_enqueue_file(file_paths[0])
else:
tasks = [pipeline_enqueue_file(path) for path in file_paths]
enqueued = any(await asyncio.gather(*tasks))
if enqueued:
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_texts(texts: List[str]):
"""Index a list of texts
Args:
texts: The texts to index
"""
if not texts:
return
await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents()
async def save_temp_file(file: UploadFile = File(...)) -> Path:
"""Save the uploaded file to a temporary location
Args:
file: The uploaded file
Returns:
Path: The path to the saved file
"""
# Generate unique filename to avoid conflicts
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
# Create a temporary file to save the uploaded content
temp_path = doc_manager.input_dir / "temp" / unique_filename
temp_path.parent.mkdir(exist_ok=True)
# Save the file
with open(temp_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return temp_path
async def run_scanning_process():
"""Background task to scan and index documents"""
global scan_progress
try:
new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files)
logger.info(f"Found {len(new_files)} new files to index.")
for file_path in new_files:
try:
with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await pipeline_index_file(file_path)
with progress_lock:
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"]
/ scan_progress["total_files"]
) * 100
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
with progress_lock:
scan_progress["is_scanning"] = False
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(background_tasks: BackgroundTasks): async def scan_for_new_documents(background_tasks: BackgroundTasks):
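Every insertion path introduced above follows the same two-stage pattern: extract text and enqueue it cheaply, then drain the queue once. A hedged sketch of the pattern in isolation (the wrapper function is illustrative; the two rag coroutines are the ones this refactor builds on):

# Illustrative wrapper around the two-stage pipeline used by the new
# insertion paths; `rag` is an initialized LightRAG instance.
async def ingest_texts(rag, texts: list[str]) -> None:
    if not texts:
        return
    # Stage 1: enqueue the raw documents (fast, returns quickly).
    await rag.apipeline_enqueue_documents(texts)
    # Stage 2: process everything currently queued - chunking, entity and
    # relationship extraction, and index updates happen here.
    await rag.apipeline_process_enqueue_documents()

Because the endpoints hand this work to FastAPI's BackgroundTasks, the HTTP response goes out before stage 2 runs.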
@@ -1206,38 +1407,6 @@ def create_app(args):
return {"status": "scanning_started"} return {"status": "scanning_started"}
async def run_scanning_process():
"""Background task to scan and index documents"""
global scan_progress
try:
new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files)
logger.info(f"Found {len(new_files)} new files to index.")
for file_path in new_files:
try:
with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await index_file(file_path)
with progress_lock:
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"]
/ scan_progress["total_files"]
) * 100
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
with progress_lock:
scan_progress["is_scanning"] = False
@app.get("/documents/scan-progress") @app.get("/documents/scan-progress")
async def get_scan_progress(): async def get_scan_progress():
"""Get the current scanning progress""" """Get the current scanning progress"""
@@ -1245,7 +1414,9 @@ def create_app(args):
return scan_progress return scan_progress
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)): async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
""" """
Endpoint for uploading a file to the input directory and indexing it. Endpoint for uploading a file to the input directory and indexing it.
@@ -1254,6 +1425,7 @@ def create_app(args):
indexes it for retrieval, and returns a success status with relevant details. indexes it for retrieval, and returns a success status with relevant details.
Parameters: Parameters:
background_tasks: FastAPI BackgroundTasks for async processing
file (UploadFile): The file to be uploaded. It must have an allowed extension as per file (UploadFile): The file to be uploaded. It must have an allowed extension as per
`doc_manager.supported_extensions`. `doc_manager.supported_extensions`.
@@ -1278,15 +1450,175 @@ def create_app(args):
with open(file_path, "wb") as buffer: with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer) shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file # Add to background tasks
await index_file(file_path) background_tasks.add_task(pipeline_index_file, file_path)
return { return InsertResponse(
"status": "success", status="success",
"message": f"File uploaded and indexed: {file.filename}", message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
"total_documents": len(doc_manager.indexed_files), )
}
except Exception as e: except Exception as e:
logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
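Since indexing now happens in the background, the upload endpoint returns as soon as the file is written to the input directory. An example client call (host, port and API key are placeholders; the JSON mirrors the InsertResponse returned above):

import requests

# Placeholder host/port and API key; adjust to your deployment.
with open("report.pdf", "rb") as fh:
    resp = requests.post(
        "http://localhost:9621/documents/upload",
        headers={"X-API-Key": "your-api-key"},  # only needed if a key is configured
        files={"file": ("report.pdf", fh, "application/pdf")},
    )
print(resp.json())
# {"status": "success",
#  "message": "File 'report.pdf' uploaded successfully. Processing will continue in background."}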
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks
):
"""
Insert text into the Retrieval-Augmented Generation (RAG) system.
This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
Args:
request (InsertTextRequest): The request body containing the text to be inserted.
background_tasks: FastAPI BackgroundTasks for async processing
Returns:
InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
"""
try:
background_tasks.add_task(pipeline_index_texts, [request.text])
return InsertResponse(
status="success",
message="Text successfully received. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""Insert a file directly into the RAG system
Args:
background_tasks: FastAPI BackgroundTasks for async processing
file: Uploaded file
Returns:
InsertResponse: Status of the insertion operation
Raises:
HTTPException: For unsupported file types or processing errors
"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
# Create a temporary file to save the uploaded content
temp_path = save_temp_file(file)
# Add to background tasks
background_tasks.add_task(pipeline_index_file, temp_path)
return InsertResponse(
status="success",
message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/file: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
):
"""Process multiple files in batch mode
Args:
background_tasks: FastAPI BackgroundTasks for async processing
files: List of files to process
Returns:
InsertResponse: Status of the batch insertion operation
Raises:
HTTPException: For processing errors
"""
try:
inserted_count = 0
failed_files = []
temp_files = []
for file in files:
if doc_manager.is_supported_file(file.filename):
# Create a temporary file to save the uploaded content
temp_files.append(save_temp_file(file))
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
if temp_files:
background_tasks.add_task(pipeline_index_files, temp_files)
# Prepare status message
if inserted_count == len(files):
status = "success"
status_message = f"Successfully inserted all {inserted_count} documents"
elif inserted_count > 0:
status = "partial_success"
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
else:
status = "failure"
status_message = "No documents were successfully inserted"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(status=status, message=status_message)
except Exception as e:
logging.error(f"Error /documents/batch: {file.filename}: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
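A corresponding client call for the batch endpoint, sending several files under the repeated "files" form field (host, port and key are again placeholders):

import requests

paths = ["a.txt", "b.pdf", "unsupported.xyz"]
# requests accepts a list of (field_name, file_info) tuples for repeated fields.
files = [("files", (p, open(p, "rb"))) for p in paths]

resp = requests.post(
    "http://localhost:9621/documents/batch",
    headers={"X-API-Key": "your-api-key"},
    files=files,
)
print(resp.json())
# e.g. {"status": "partial_success",
#       "message": "Successfully inserted 2 out of 3 documents. Failed files: unsupported.xyz (unsupported type)"}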
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
"""
Clear all documents from the LightRAG system.
This endpoint deletes all text chunks, entities vector database, and relationships vector database,
effectively clearing all documents from the LightRAG system.
Returns:
InsertResponse: A response object containing the status, message, and the new document count (0 in this case).
"""
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success", message="All documents cleared successfully"
)
except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post( @app.post(
@@ -1297,12 +1629,7 @@ def create_app(args):
Handle a POST request at the /query endpoint to process user queries using RAG capabilities. Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
Parameters: Parameters:
request (QueryRequest): A Pydantic model containing the following fields: request (QueryRequest): The request object containing the query parameters.
- query (str): The text of the user's query.
- mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
- stream (bool): Optional. Determines if the response should be streamed.
- only_need_context (bool): Optional. If true, returns only the context without further processing.
Returns: Returns:
QueryResponse: A Pydantic model containing the result of the query processing. QueryResponse: A Pydantic model containing the result of the query processing.
If a string is returned (e.g., cache hit), it's directly returned. If a string is returned (e.g., cache hit), it's directly returned.
@@ -1314,13 +1641,7 @@ def create_app(args):
""" """
try: try:
response = await rag.aquery( response = await rag.aquery(
request.query, request.query, param=QueryRequestToQueryParams(request)
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
top_k=global_top_k,
),
) )
# If response is a string (e.g. cache hit), return directly # If response is a string (e.g. cache hit), return directly
@@ -1328,16 +1649,16 @@ def create_app(args):
return QueryResponse(response=response) return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter # If it's an async generator, decide whether to stream based on stream parameter
if request.stream: if request.stream or hasattr(response, "__aiter__"):
result = "" result = ""
async for chunk in response: async for chunk in response:
result += chunk result += chunk
return QueryResponse(response=result) return QueryResponse(response=result)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result)
else: else:
result = "" return QueryResponse(response=str(response))
async for chunk in response:
result += chunk
return QueryResponse(response=result)
except Exception as e: except Exception as e:
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
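With QueryRequestToQueryParams in place, the /query body accepts any of the new optional fields directly; unset fields keep their server-side defaults. An example request (placeholders as above):

import requests

resp = requests.post(
    "http://localhost:9621/query",
    headers={"X-API-Key": "your-api-key"},
    json={
        "query": "Summarize the indexed documents",
        "mode": "hybrid",
        "response_type": "Bullet Points",
        "top_k": 20,
        "conversation_history": [
            {"role": "user", "content": "What is LightRAG?"},
            {"role": "assistant", "content": "A simple and fast RAG framework."},
        ],
        "history_turns": 1,
    },
)
print(resp.json()["response"])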
@@ -1355,14 +1676,11 @@ def create_app(args):
StreamingResponse: A streaming response containing the RAG query results. StreamingResponse: A streaming response containing the RAG query results.
""" """
try: try:
params = QueryRequestToQueryParams(request)
params.stream = True
response = await rag.aquery( # Use aquery instead of query, and add await response = await rag.aquery( # Use aquery instead of query, and add await
request.query, request.query, param=params
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
top_k=global_top_k,
),
) )
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@@ -1395,255 +1713,6 @@ def create_app(args):
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
"""
Insert text into the Retrieval-Augmented Generation (RAG) system.
This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
Args:
request (InsertTextRequest): The request body containing the text to be inserted.
Returns:
InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
"""
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=1,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
"""Insert a file directly into the RAG system
Args:
file: Uploaded file
description: Optional description of the file
Returns:
InsertResponse: Status of the insertion operation
Raises:
HTTPException: For unsupported file types or processing errors
"""
try:
content = ""
# Get file extension in lowercase
ext = Path(file.filename).suffix.lower()
match ext:
case ".txt" | ".md":
# Text files handling
text_content = await file.read()
content = text_content.decode("utf-8")
case ".pdf" | ".docx" | ".pptx" | ".xlsx":
if not pm.is_installed("docling"):
pm.install("docling")
from docling.document_converter import DocumentConverter
# Create a temporary file to save the uploaded content
temp_path = Path("temp") / file.filename
temp_path.parent.mkdir(exist_ok=True)
# Save the uploaded file
with temp_path.open("wb") as f:
f.write(await file.read())
try:
async def convert_doc():
def sync_convert():
converter = DocumentConverter()
result = converter.convert(str(temp_path))
return result.document.export_to_markdown()
return await asyncio.to_thread(sync_convert)
content = await convert_doc()
finally:
# Clean up the temporary file
temp_path.unlink()
# Insert content into RAG system
if content:
# Add description if provided
if description:
content = f"{description}\n\n{content}"
await rag.ainsert(content)
logging.info(f"Successfully indexed file: {file.filename}")
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
else:
raise HTTPException(
status_code=400,
detail="No content could be extracted from the file",
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
logging.error(f"Error processing file {file.filename}: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
"""Process multiple files in batch mode
Args:
files: List of files to process
Returns:
InsertResponse: Status of the batch insertion operation
Raises:
HTTPException: For processing errors
"""
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = ""
ext = Path(file.filename).suffix.lower()
match ext:
case ".txt" | ".md":
text_content = await file.read()
content = text_content.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader
from io import BytesIO
pdf_content = await file.read()
pdf_file = BytesIO(pdf_content)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
docx_content = await file.read()
docx_file = BytesIO(docx_content)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO
pptx_content = await file.read()
pptx_file = BytesIO(pptx_content)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
failed_files.append(f"{file.filename} (unsupported type)")
continue
if content:
await rag.ainsert(content)
inserted_count += 1
logging.info(f"Successfully indexed file: {file.filename}")
else:
failed_files.append(f"{file.filename} (no content extracted)")
except UnicodeDecodeError:
failed_files.append(f"{file.filename} (encoding error)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
logging.error(f"Error processing file {file.filename}: {str(e)}")
# Prepare status message
if inserted_count == len(files):
status = "success"
status_message = f"Successfully inserted all {inserted_count} documents"
elif inserted_count > 0:
status = "partial_success"
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
else:
status = "failure"
status_message = "No documents were successfully inserted"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status=status,
message=status_message,
document_count=inserted_count,
)
except Exception as e:
logging.error(f"Batch processing error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
"""
Clear all documents from the LightRAG system.
This endpoint deletes all text chunks, entities vector database, and relationships vector database,
effectively clearing all documents from the LightRAG system.
Returns:
InsertResponse: A response object containing the status, message, and the new document count (0 in this case).
"""
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# query all graph labels # query all graph labels
@app.get("/graph/label/list") @app.get("/graph/label/list")
async def get_graph_labels(): async def get_graph_labels():

View File

@@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
"""Get counts of documents in each status""" """Get counts of documents in each status"""
raise NotImplementedError raise NotImplementedError
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: async def get_docs_by_status(
"""Get all failed documents""" self, status: DocStatus
raise NotImplementedError ) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
raise NotImplementedError
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
raise NotImplementedError
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
raise NotImplementedError raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None: async def update_doc_status(self, data: dict[str, Any]) -> None:

View File

@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
counts[doc["status"]] += 1 counts[doc["status"]] += 1
return counts return counts
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: async def get_docs_by_status(
"""Get all failed documents""" self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""all documents with a specific status"""
return { return {
k: DocProcessingStatus(**v) k: DocProcessingStatus(**v)
for k, v in self._data.items() for k, v in self._data.items()
if v["status"] == DocStatus.FAILED if v["status"] == status
}
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PENDING
}
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSED
}
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSING
} }
async def index_done_callback(self): async def index_done_callback(self):

View File

@@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus
) -> dict[str, DocProcessingStatus]: ) -> dict[str, DocProcessingStatus]:
"""Get all documents by status""" """Get all documents with a specific status"""
cursor = self._data.find({"status": status.value}) cursor = self._data.find({"status": status.value})
result = await cursor.to_list() result = await cursor.to_list()
return { return {
@@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
for doc in result for doc in result
} }
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
@dataclass @dataclass
class MongoGraphStorage(BaseGraphStorage): class MongoGraphStorage(BaseGraphStorage):

View File

@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus
) -> Dict[str, DocProcessingStatus]: ) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status""" """all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
for element in result for element in result
} }
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self): async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!") logger.info("Doc status had been saved into postgresql db!")

View File

@@ -89,7 +89,7 @@ STORAGE_IMPLEMENTATIONS = {
"PGDocStatusStorage", "PGDocStatusStorage",
"MongoDocStatusStorage", "MongoDocStatusStorage",
], ],
"required_methods": ["get_pending_docs"], "required_methods": ["get_docs_by_status"],
}, },
} }
@@ -230,7 +230,7 @@ class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation.""" """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
working_dir: str = field( working_dir: str = field(
default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}' default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
) )
"""Directory where cache and temporary files are stored.""" """Directory where cache and temporary files are stored."""
@@ -715,11 +715,11 @@ class LightRAG:
# 1. Get all pending, failed, and abnormally terminated processing documents. # 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {} to_process_docs: dict[str, DocProcessingStatus] = {}
processing_docs = await self.doc_status.get_processing_docs() processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
to_process_docs.update(processing_docs) to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs() failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
to_process_docs.update(failed_docs) to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs() pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
to_process_docs.update(pendings_docs) to_process_docs.update(pendings_docs)
if not to_process_docs: if not to_process_docs:
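For external code that relied on the removed per-status helpers, the migration is mechanical; a hedged sketch mirroring the LightRAG change above (the import path for DocStatus is assumed to be lightrag.base):

from lightrag.base import DocStatus  # import path assumed

async def collect_docs_to_process(doc_status):
    """Gather processing, failed and pending docs with the consolidated
    get_docs_by_status() API, replacing the removed get_processing_docs(),
    get_failed_docs() and get_pending_docs() helpers."""
    to_process = {}
    to_process.update(await doc_status.get_docs_by_status(DocStatus.PROCESSING))
    to_process.update(await doc_status.get_docs_by_status(DocStatus.FAILED))
    to_process.update(await doc_status.get_docs_by_status(DocStatus.PENDING))
    return to_process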