Merge pull request #791 from ArnoChenFx/refactor-server

Refactor File Indexing for Background Asynchronous Processing
This commit is contained in:
zrguo
2025-02-16 22:07:30 +08:00
committed by GitHub
6 changed files with 448 additions and 443 deletions

View File

@@ -3,7 +3,6 @@ from fastapi import (
HTTPException,
File,
UploadFile,
Form,
BackgroundTasks,
)
import asyncio
@@ -14,7 +13,7 @@ import re
from fastapi.staticfiles import StaticFiles
import logging
import argparse
from typing import List, Any, Optional, Union, Dict
from typing import List, Any, Optional, Dict
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.types import GPTKeywordExtractionFormat
@@ -34,6 +33,9 @@ from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
from dotenv import load_dotenv
import configparser
import traceback
from datetime import datetime
from lightrag.utils import logger
from .ollama_api import (
OllamaAPI,
@@ -635,9 +637,47 @@ class SearchMode(str, Enum):
class QueryRequest(BaseModel):
query: str
"""Specifies the retrieval mode"""
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
"""If True, enables streaming output for real-time responses."""
stream: Optional[bool] = None
"""If True, only returns the retrieved context without generating a response."""
only_need_context: Optional[bool] = None
"""If True, only returns the generated prompt without producing a response."""
only_need_prompt: Optional[bool] = None
"""Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
response_type: Optional[str] = None
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
top_k: Optional[int] = None
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_text_unit: Optional[int] = None
"""Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_global_context: Optional[int] = None
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
max_token_for_local_context: Optional[int] = None
"""List of high-level keywords to prioritize in retrieval."""
hl_keywords: Optional[List[str]] = None
"""List of low-level keywords to refine retrieval focus."""
ll_keywords: Optional[List[str]] = None
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
conversation_history: Optional[List[dict[str, Any]]] = None
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
history_turns: Optional[int] = None
class QueryResponse(BaseModel):
@@ -646,13 +686,38 @@ class QueryResponse(BaseModel):
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def QueryRequestToQueryParams(request: QueryRequest):
param = QueryParam(mode=request.mode, stream=request.stream)
if request.only_need_context is not None:
param.only_need_context = request.only_need_context
if request.only_need_prompt is not None:
param.only_need_prompt = request.only_need_prompt
if request.response_type is not None:
param.response_type = request.response_type
if request.top_k is not None:
param.top_k = request.top_k
if request.max_token_for_text_unit is not None:
param.max_token_for_text_unit = request.max_token_for_text_unit
if request.max_token_for_global_context is not None:
param.max_token_for_global_context = request.max_token_for_global_context
if request.max_token_for_local_context is not None:
param.max_token_for_local_context = request.max_token_for_local_context
if request.hl_keywords is not None:
param.hl_keywords = request.hl_keywords
if request.ll_keywords is not None:
param.ll_keywords = request.ll_keywords
if request.conversation_history is not None:
param.conversation_history = request.conversation_history
if request.history_turns is not None:
param.history_turns = request.history_turns
return param
def get_api_key_dependency(api_key: Optional[str]):
@@ -666,7 +731,9 @@ def get_api_key_dependency(api_key: Optional[str]):
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
async def api_key_auth(
api_key_header_value: Optional[str] = Security(api_key_header),
):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -682,6 +749,7 @@ def get_api_key_dependency(api_key: Optional[str]):
# Global configuration
global_top_k = 60 # default value
temp_prefix = "__tmp_" # prefix for temporary files
def create_app(args):
@@ -1132,61 +1200,194 @@ def create_app(args):
("llm_response_cache", rag.llm_response_cache),
]
async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats
async def pipeline_enqueue_file(file_path: Path) -> bool:
"""Add a file to the queue for processing
Args:
file_path: Path to the file to be indexed (str or Path object)
Raises:
ValueError: If file format is not supported
FileNotFoundError: If file doesn't exist
file_path: Path to the saved file
Returns:
bool: True if the file was successfully enqueued, False otherwise
"""
if not pm.is_installed("aiofiles"):
pm.install("aiofiles")
try:
content = ""
ext = file_path.suffix.lower()
# Convert to Path object if string
file_path = Path(file_path)
file = None
async with aiofiles.open(file_path, "rb") as f:
file = await f.read()
# Check if file exists
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Process based on file type
match ext:
case ".txt" | ".md":
content = file.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader
from io import BytesIO
content = ""
# Get file extension in lowercase
ext = file_path.suffix.lower()
pdf_file = BytesIO(file)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
match ext:
case ".txt" | ".md":
# Text files handling
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
docx_content = await file.read()
docx_file = BytesIO(docx_content)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO
case ".pdf" | ".docx" | ".pptx" | ".xlsx":
if not pm.is_installed("docling"):
pm.install("docling")
from docling.document_converter import DocumentConverter
pptx_content = await file.read()
pptx_file = BytesIO(pptx_content)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return False
async def convert_doc():
def sync_convert():
converter = DocumentConverter()
result = converter.convert(file_path)
return result.document.export_to_markdown()
# Insert into the RAG queue
if content:
await rag.apipeline_enqueue_documents(content)
logging.info(
f"Successfully processed and enqueued file: {file_path.name}"
)
return True
else:
logging.error(
f"No content could be extracted from file: {file_path.name}"
)
return await asyncio.to_thread(sync_convert)
except Exception as e:
logging.error(
f"Error processing or enqueueing file {file_path.name}: {str(e)}"
)
logging.error(traceback.format_exc())
finally:
if file_path.name.startswith(temp_prefix):
# Clean up the temporary file after indexing
try:
file_path.unlink()
except Exception as e:
logging.error(f"Error deleting file {file_path}: {str(e)}")
return False
content = await convert_doc()
async def pipeline_index_file(file_path: Path):
"""Index a file
case _:
raise ValueError(f"Unsupported file format: {ext}")
Args:
file_path: Path to the saved file
"""
try:
if await pipeline_enqueue_file(file_path):
await rag.apipeline_process_enqueue_documents()
# Insert content into RAG system
if content:
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Successfully indexed file: {file_path}")
else:
logging.warning(f"No content extracted from file: {file_path}")
except Exception as e:
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_files(file_paths: List[Path]):
"""Index multiple files concurrently
Args:
file_paths: Paths to the files to index
"""
if not file_paths:
return
try:
enqueued = False
if len(file_paths) == 1:
enqueued = await pipeline_enqueue_file(file_paths[0])
else:
tasks = [pipeline_enqueue_file(path) for path in file_paths]
enqueued = any(await asyncio.gather(*tasks))
if enqueued:
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_texts(texts: List[str]):
"""Index a list of texts
Args:
texts: The texts to index
"""
if not texts:
return
await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents()
async def save_temp_file(file: UploadFile = File(...)) -> Path:
"""Save the uploaded file to a temporary location
Args:
file: The uploaded file
Returns:
Path: The path to the saved file
"""
# Generate unique filename to avoid conflicts
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
# Create a temporary file to save the uploaded content
temp_path = doc_manager.input_dir / "temp" / unique_filename
temp_path.parent.mkdir(exist_ok=True)
# Save the file
with open(temp_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return temp_path
async def run_scanning_process():
"""Background task to scan and index documents"""
global scan_progress
try:
new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files)
logger.info(f"Found {len(new_files)} new files to index.")
for file_path in new_files:
try:
with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await pipeline_index_file(file_path)
with progress_lock:
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"]
/ scan_progress["total_files"]
) * 100
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
with progress_lock:
scan_progress["is_scanning"] = False
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents(background_tasks: BackgroundTasks):
@@ -1206,38 +1407,6 @@ def create_app(args):
return {"status": "scanning_started"}
async def run_scanning_process():
"""Background task to scan and index documents"""
global scan_progress
try:
new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files)
logger.info(f"Found {len(new_files)} new files to index.")
for file_path in new_files:
try:
with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await index_file(file_path)
with progress_lock:
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"]
/ scan_progress["total_files"]
) * 100
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
with progress_lock:
scan_progress["is_scanning"] = False
@app.get("/documents/scan-progress")
async def get_scan_progress():
"""Get the current scanning progress"""
@@ -1245,7 +1414,9 @@ def create_app(args):
return scan_progress
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""
Endpoint for uploading a file to the input directory and indexing it.
@@ -1254,6 +1425,7 @@ def create_app(args):
indexes it for retrieval, and returns a success status with relevant details.
Parameters:
background_tasks: FastAPI BackgroundTasks for async processing
file (UploadFile): The file to be uploaded. It must have an allowed extension as per
`doc_manager.supported_extensions`.
@@ -1278,15 +1450,175 @@ def create_app(args):
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
await index_file(file_path)
# Add to background tasks
background_tasks.add_task(pipeline_index_file, file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
return InsertResponse(
status="success",
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks
):
"""
Insert text into the Retrieval-Augmented Generation (RAG) system.
This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
Args:
request (InsertTextRequest): The request body containing the text to be inserted.
background_tasks: FastAPI BackgroundTasks for async processing
Returns:
InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
"""
try:
background_tasks.add_task(pipeline_index_texts, [request.text])
return InsertResponse(
status="success",
message="Text successfully received. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""Insert a file directly into the RAG system
Args:
background_tasks: FastAPI BackgroundTasks for async processing
file: Uploaded file
Returns:
InsertResponse: Status of the insertion operation
Raises:
HTTPException: For unsupported file types or processing errors
"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
# Create a temporary file to save the uploaded content
temp_path = save_temp_file(file)
# Add to background tasks
background_tasks.add_task(pipeline_index_file, temp_path)
return InsertResponse(
status="success",
message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/file: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
):
"""Process multiple files in batch mode
Args:
background_tasks: FastAPI BackgroundTasks for async processing
files: List of files to process
Returns:
InsertResponse: Status of the batch insertion operation
Raises:
HTTPException: For processing errors
"""
try:
inserted_count = 0
failed_files = []
temp_files = []
for file in files:
if doc_manager.is_supported_file(file.filename):
# Create a temporary file to save the uploaded content
temp_files.append(save_temp_file(file))
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
if temp_files:
background_tasks.add_task(pipeline_index_files, temp_files)
# Prepare status message
if inserted_count == len(files):
status = "success"
status_message = f"Successfully inserted all {inserted_count} documents"
elif inserted_count > 0:
status = "partial_success"
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
else:
status = "failure"
status_message = "No documents were successfully inserted"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(status=status, message=status_message)
except Exception as e:
logging.error(f"Error /documents/batch: {file.filename}: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
"""
Clear all documents from the LightRAG system.
This endpoint deletes all text chunks, entities vector database, and relationships vector database,
effectively clearing all documents from the LightRAG system.
Returns:
InsertResponse: A response object containing the status, message, and the new document count (0 in this case).
"""
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success", message="All documents cleared successfully"
)
except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post(
@@ -1297,12 +1629,7 @@ def create_app(args):
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
Parameters:
request (QueryRequest): A Pydantic model containing the following fields:
- query (str): The text of the user's query.
- mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
- stream (bool): Optional. Determines if the response should be streamed.
- only_need_context (bool): Optional. If true, returns only the context without further processing.
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If a string is returned (e.g., cache hit), it's directly returned.
@@ -1314,13 +1641,7 @@ def create_app(args):
"""
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
top_k=global_top_k,
),
request.query, param=QueryRequestToQueryParams(request)
)
# If response is a string (e.g. cache hit), return directly
@@ -1328,16 +1649,16 @@ def create_app(args):
return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter
if request.stream:
if request.stream or hasattr(response, "__aiter__"):
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result)
else:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
return QueryResponse(response=str(response))
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@@ -1355,14 +1676,11 @@ def create_app(args):
StreamingResponse: A streaming response containing the RAG query results.
"""
try:
params = QueryRequestToQueryParams(request)
params.stream = True
response = await rag.aquery( # Use aquery instead of query, and add await
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
top_k=global_top_k,
),
request.query, param=params
)
from fastapi.responses import StreamingResponse
@@ -1395,255 +1713,6 @@ def create_app(args):
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
"""
Insert text into the Retrieval-Augmented Generation (RAG) system.
This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
Args:
request (InsertTextRequest): The request body containing the text to be inserted.
Returns:
InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
"""
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=1,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
"""Insert a file directly into the RAG system
Args:
file: Uploaded file
description: Optional description of the file
Returns:
InsertResponse: Status of the insertion operation
Raises:
HTTPException: For unsupported file types or processing errors
"""
try:
content = ""
# Get file extension in lowercase
ext = Path(file.filename).suffix.lower()
match ext:
case ".txt" | ".md":
# Text files handling
text_content = await file.read()
content = text_content.decode("utf-8")
case ".pdf" | ".docx" | ".pptx" | ".xlsx":
if not pm.is_installed("docling"):
pm.install("docling")
from docling.document_converter import DocumentConverter
# Create a temporary file to save the uploaded content
temp_path = Path("temp") / file.filename
temp_path.parent.mkdir(exist_ok=True)
# Save the uploaded file
with temp_path.open("wb") as f:
f.write(await file.read())
try:
async def convert_doc():
def sync_convert():
converter = DocumentConverter()
result = converter.convert(str(temp_path))
return result.document.export_to_markdown()
return await asyncio.to_thread(sync_convert)
content = await convert_doc()
finally:
# Clean up the temporary file
temp_path.unlink()
# Insert content into RAG system
if content:
# Add description if provided
if description:
content = f"{description}\n\n{content}"
await rag.ainsert(content)
logging.info(f"Successfully indexed file: {file.filename}")
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
else:
raise HTTPException(
status_code=400,
detail="No content could be extracted from the file",
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
logging.error(f"Error processing file {file.filename}: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
"""Process multiple files in batch mode
Args:
files: List of files to process
Returns:
InsertResponse: Status of the batch insertion operation
Raises:
HTTPException: For processing errors
"""
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = ""
ext = Path(file.filename).suffix.lower()
match ext:
case ".txt" | ".md":
text_content = await file.read()
content = text_content.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader
from io import BytesIO
pdf_content = await file.read()
pdf_file = BytesIO(pdf_content)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
docx_content = await file.read()
docx_file = BytesIO(docx_content)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO
pptx_content = await file.read()
pptx_file = BytesIO(pptx_content)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
failed_files.append(f"{file.filename} (unsupported type)")
continue
if content:
await rag.ainsert(content)
inserted_count += 1
logging.info(f"Successfully indexed file: {file.filename}")
else:
failed_files.append(f"{file.filename} (no content extracted)")
except UnicodeDecodeError:
failed_files.append(f"{file.filename} (encoding error)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
logging.error(f"Error processing file {file.filename}: {str(e)}")
# Prepare status message
if inserted_count == len(files):
status = "success"
status_message = f"Successfully inserted all {inserted_count} documents"
elif inserted_count > 0:
status = "partial_success"
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
else:
status = "failure"
status_message = "No documents were successfully inserted"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status=status,
message=status_message,
document_count=inserted_count,
)
except Exception as e:
logging.error(f"Batch processing error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
"""
Clear all documents from the LightRAG system.
This endpoint deletes all text chunks, entities vector database, and relationships vector database,
effectively clearing all documents from the LightRAG system.
Returns:
InsertResponse: A response object containing the status, message, and the new document count (0 in this case).
"""
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# query all graph labels
@app.get("/graph/label/list")
async def get_graph_labels():

View File

@@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
"""Get counts of documents in each status"""
raise NotImplementedError
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
raise NotImplementedError
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
raise NotImplementedError
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
raise NotImplementedError
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:

View File

@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
counts[doc["status"]] += 1
return counts
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""all documents with a specific status"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.FAILED
}
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PENDING
}
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSED
}
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSING
if v["status"] == status
}
async def index_done_callback(self):

View File

@@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents by status"""
"""Get all documents with a specific status"""
cursor = self._data.find({"status": status.value})
result = await cursor.to_list()
return {
@@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
for doc in result
}
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
@dataclass
class MongoGraphStorage(BaseGraphStorage):

View File

@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_docs_by_status(
self, status: DocStatus
) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status"""
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status}
result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
for element in result
}
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!")

View File

@@ -89,7 +89,7 @@ STORAGE_IMPLEMENTATIONS = {
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_pending_docs"],
"required_methods": ["get_docs_by_status"],
},
}
@@ -230,7 +230,7 @@ class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation."""
working_dir: str = field(
default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
"""Directory where cache and temporary files are stored."""
@@ -715,11 +715,11 @@ class LightRAG:
# 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {}
processing_docs = await self.doc_status.get_processing_docs()
processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs()
failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs()
pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
to_process_docs.update(pendings_docs)
if not to_process_docs: