diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 97f1156f..a392e67a 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -3,7 +3,6 @@ from fastapi import (
     HTTPException,
     File,
     UploadFile,
-    Form,
    BackgroundTasks,
 )
 import asyncio
@@ -14,7 +13,7 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Union, Dict
+from typing import List, Any, Optional, Dict
 from pydantic import BaseModel
 from lightrag import LightRAG, QueryParam
 from lightrag.types import GPTKeywordExtractionFormat
@@ -34,6 +33,9 @@ from starlette.status import HTTP_403_FORBIDDEN
 import pipmaster as pm
 from dotenv import load_dotenv
 import configparser
+import traceback
+from datetime import datetime
+
 from lightrag.utils import logger
 from .ollama_api import (
     OllamaAPI,
@@ -635,9 +637,47 @@ class SearchMode(str, Enum):
 
 class QueryRequest(BaseModel):
     query: str
+
     mode: SearchMode = SearchMode.hybrid
-    stream: bool = False
-    only_need_context: bool = False
+    """Specifies the retrieval mode."""
+
+    stream: Optional[bool] = None
+    """If True, enables streaming output for real-time responses."""
+
+    only_need_context: Optional[bool] = None
+    """If True, only returns the retrieved context without generating a response."""
+
+    only_need_prompt: Optional[bool] = None
+    """If True, only returns the generated prompt without producing a response."""
+
+    response_type: Optional[str] = None
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+
+    top_k: Optional[int] = None
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+
+    max_token_for_text_unit: Optional[int] = None
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+
+    max_token_for_global_context: Optional[int] = None
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+
+    max_token_for_local_context: Optional[int] = None
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+
+    hl_keywords: Optional[List[str]] = None
+    """List of high-level keywords to prioritize in retrieval."""
+
+    ll_keywords: Optional[List[str]] = None
+    """List of low-level keywords to refine retrieval focus."""
+
+    conversation_history: Optional[List[dict[str, Any]]] = None
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+
+    history_turns: Optional[int] = None
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
 
 
 class QueryResponse(BaseModel):
@@ -646,13 +686,38 @@ class QueryResponse(BaseModel):
 
 class InsertTextRequest(BaseModel):
     text: str
-    description: Optional[str] = None
 
 
 class InsertResponse(BaseModel):
     status: str
     message: str
-    document_count: int
+
+
+def QueryRequestToQueryParams(request: QueryRequest):
+    param = QueryParam(mode=request.mode, stream=request.stream)
+    if request.only_need_context is not None:
+        param.only_need_context = request.only_need_context
+    if request.only_need_prompt is not None:
+        param.only_need_prompt = request.only_need_prompt
+    if request.response_type is not None:
+        param.response_type = request.response_type
+    if request.top_k is not None:
+        param.top_k = request.top_k
+    if request.max_token_for_text_unit is not None:
+        param.max_token_for_text_unit = request.max_token_for_text_unit
+    if request.max_token_for_global_context is not None:
+        param.max_token_for_global_context = request.max_token_for_global_context
+    if request.max_token_for_local_context is not None:
+        param.max_token_for_local_context = request.max_token_for_local_context
+    if request.hl_keywords is not None:
+        param.hl_keywords = request.hl_keywords
+    if request.ll_keywords is not None:
+        param.ll_keywords = request.ll_keywords
+    if request.conversation_history is not None:
+        param.conversation_history = request.conversation_history
+    if request.history_turns is not None:
+        param.history_turns = request.history_turns
+    return param
 
 
 def get_api_key_dependency(api_key: Optional[str]):
@@ -666,7 +731,9 @@ def get_api_key_dependency(api_key: Optional[str]):
     # If API key is configured, use proper authentication
     api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
 
-    async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
+    async def api_key_auth(
+        api_key_header_value: Optional[str] = Security(api_key_header),
+    ):
         if not api_key_header_value:
             raise HTTPException(
                 status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -682,6 +749,7 @@ def get_api_key_dependency(api_key: Optional[str]):
 
 # Global configuration
 global_top_k = 60  # default value
+temp_prefix = "__tmp_"  # prefix for temporary files
 
 
 def create_app(args):
@@ -1132,61 +1200,194 @@ def create_app(args):
             ("llm_response_cache", rag.llm_response_cache),
         ]
 
-    async def index_file(file_path: Union[str, Path]) -> None:
-        """Index all files inside the folder with support for multiple file formats
+    async def pipeline_enqueue_file(file_path: Path) -> bool:
+        """Add a file to the queue for processing
 
         Args:
-            file_path: Path to the file to be indexed (str or Path object)
-
-        Raises:
-            ValueError: If file format is not supported
-            FileNotFoundError: If file doesn't exist
+            file_path: Path to the saved file
+        Returns:
+            bool: True if the file was successfully enqueued, False otherwise
         """
-        if not pm.is_installed("aiofiles"):
-            pm.install("aiofiles")
+        try:
+            content = ""
+            ext = file_path.suffix.lower()
 
-        # Convert to Path object if string
-        file_path = Path(file_path)
+            file = None
+            async with aiofiles.open(file_path, "rb") as f:
+                file = await f.read()
 
-        # Check if file exists
-        if not file_path.exists():
-            raise FileNotFoundError(f"File not found: {file_path}")
+            # Process based on file type
+            match ext:
+                case ".txt" | ".md":
+                    content = file.decode("utf-8")
+                case ".pdf":
+                    if not 
pm.is_installed("pypdf2"):
+                        pm.install("pypdf2")
+                    from PyPDF2 import PdfReader
+                    from io import BytesIO
 
-        content = ""
-        # Get file extension in lowercase
-        ext = file_path.suffix.lower()
+                    pdf_file = BytesIO(file)
+                    reader = PdfReader(pdf_file)
+                    for page in reader.pages:
+                        content += page.extract_text() + "\n"
+                case ".docx":
+                    if not pm.is_installed("docx"):
+                        pm.install("docx")
+                    from docx import Document
+                    from io import BytesIO
 
-        match ext:
-            case ".txt" | ".md":
-                # Text files handling
-                async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
-                    content = await f.read()
+                    # `file` already holds the raw bytes read above; it has no read()
+                    docx_file = BytesIO(file)
+                    doc = Document(docx_file)
+                    content = "\n".join(
+                        [paragraph.text for paragraph in doc.paragraphs]
+                    )
+                case ".pptx":
+                    if not pm.is_installed("pptx"):
+                        pm.install("pptx")
+                    from pptx import Presentation  # type: ignore
+                    from io import BytesIO
 
-            case ".pdf" | ".docx" | ".pptx" | ".xlsx":
-                if not pm.is_installed("docling"):
-                    pm.install("docling")
-                from docling.document_converter import DocumentConverter
+                    # `file` already holds the raw bytes read above; it has no read()
+                    pptx_file = BytesIO(file)
+                    prs = Presentation(pptx_file)
+                    for slide in prs.slides:
+                        for shape in slide.shapes:
+                            if hasattr(shape, "text"):
+                                content += shape.text + "\n"
+                case _:
+                    logging.error(
+                        f"Unsupported file type: {file_path.name} (extension {ext})"
+                    )
+                    return False
 
-                async def convert_doc():
-                    def sync_convert():
-                        converter = DocumentConverter()
-                        result = converter.convert(file_path)
-                        return result.document.export_to_markdown()
+            # Insert into the RAG queue
+            if content:
+                await rag.apipeline_enqueue_documents(content)
+                logging.info(
+                    f"Successfully processed and enqueued file: {file_path.name}"
+                )
+                return True
+            else:
+                logging.error(
+                    f"No content could be extracted from file: {file_path.name}"
+                )
 
-                    return await asyncio.to_thread(sync_convert)
+        except Exception as e:
+            logging.error(
+                f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+            )
+            logging.error(traceback.format_exc())
+        finally:
+            if file_path.name.startswith(temp_prefix):
+                # Clean up the temporary file after indexing
+                try:
+                    file_path.unlink()
+                except Exception as e:
+                    logging.error(f"Error deleting file {file_path}: {str(e)}")
+        return False
 
-                content = await convert_doc()
+    async def pipeline_index_file(file_path: Path):
+        """Index a file
 
-            case _:
-                raise ValueError(f"Unsupported file format: {ext}")
+        Args:
+            file_path: Path to the saved file
+        """
+        try:
+            if await pipeline_enqueue_file(file_path):
+                await rag.apipeline_process_enqueue_documents()
 
-        # Insert content into RAG system
-        if content:
-            await rag.ainsert(content)
-            doc_manager.mark_as_indexed(file_path)
-            logging.info(f"Successfully indexed file: {file_path}")
-        else:
-            logging.warning(f"No content extracted from file: {file_path}")
+        except Exception as e:
+            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_files(file_paths: List[Path]):
+        """Index multiple files concurrently
+
+        Args:
+            file_paths: Paths to the files to index
+        """
+        if not file_paths:
+            return
+        try:
+            enqueued = False
+
+            if len(file_paths) == 1:
+                enqueued = await pipeline_enqueue_file(file_paths[0])
+            else:
+                tasks = [pipeline_enqueue_file(path) for path in file_paths]
+                enqueued = any(await asyncio.gather(*tasks))
+
+            if enqueued:
+                await rag.apipeline_process_enqueue_documents()
+        except Exception as e:
+            logging.error(f"Error indexing files: {str(e)}")
+            
logging.error(traceback.format_exc()) + + async def pipeline_index_texts(texts: List[str]): + """Index a list of texts + + Args: + texts: The texts to index + """ + if not texts: + return + await rag.apipeline_enqueue_documents(texts) + await rag.apipeline_process_enqueue_documents() + + async def save_temp_file(file: UploadFile = File(...)) -> Path: + """Save the uploaded file to a temporary location + + Args: + file: The uploaded file + + Returns: + Path: The path to the saved file + """ + # Generate unique filename to avoid conflicts + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" + + # Create a temporary file to save the uploaded content + temp_path = doc_manager.input_dir / "temp" / unique_filename + temp_path.parent.mkdir(exist_ok=True) + + # Save the file + with open(temp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + return temp_path + + async def run_scanning_process(): + """Background task to scan and index documents""" + global scan_progress + + try: + new_files = doc_manager.scan_directory_for_new_files() + scan_progress["total_files"] = len(new_files) + + logger.info(f"Found {len(new_files)} new files to index.") + for file_path in new_files: + try: + with progress_lock: + scan_progress["current_file"] = os.path.basename(file_path) + + await pipeline_index_file(file_path) + + with progress_lock: + scan_progress["indexed_count"] += 1 + scan_progress["progress"] = ( + scan_progress["indexed_count"] + / scan_progress["total_files"] + ) * 100 + + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + except Exception as e: + logging.error(f"Error during scanning process: {str(e)}") + finally: + with progress_lock: + scan_progress["is_scanning"] = False @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) async def scan_for_new_documents(background_tasks: BackgroundTasks): @@ -1206,38 +1407,6 @@ def create_app(args): return {"status": "scanning_started"} - async def run_scanning_process(): - """Background task to scan and index documents""" - global scan_progress - - try: - new_files = doc_manager.scan_directory_for_new_files() - scan_progress["total_files"] = len(new_files) - - logger.info(f"Found {len(new_files)} new files to index.") - for file_path in new_files: - try: - with progress_lock: - scan_progress["current_file"] = os.path.basename(file_path) - - await index_file(file_path) - - with progress_lock: - scan_progress["indexed_count"] += 1 - scan_progress["progress"] = ( - scan_progress["indexed_count"] - / scan_progress["total_files"] - ) * 100 - - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - except Exception as e: - logging.error(f"Error during scanning process: {str(e)}") - finally: - with progress_lock: - scan_progress["is_scanning"] = False - @app.get("/documents/scan-progress") async def get_scan_progress(): """Get the current scanning progress""" @@ -1245,7 +1414,9 @@ def create_app(args): return scan_progress @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): + async def upload_to_input_dir( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): """ Endpoint for uploading a file to the input directory and indexing it. @@ -1254,6 +1425,7 @@ def create_app(args): indexes it for retrieval, and returns a success status with relevant details. 
Parameters:
+            background_tasks: FastAPI BackgroundTasks for async processing
             file (UploadFile): The file to be uploaded. It must have an allowed extension as per
                 `doc_manager.supported_extensions`.
 
@@ -1278,15 +1450,175 @@ def create_app(args):
             with open(file_path, "wb") as buffer:
                 shutil.copyfileobj(file.file, buffer)
 
-            # Immediately index the uploaded file
-            await index_file(file_path)
+            # Add to background tasks
+            background_tasks.add_task(pipeline_index_file, file_path)
 
-            return {
-                "status": "success",
-                "message": f"File uploaded and indexed: {file.filename}",
-                "total_documents": len(doc_manager.indexed_files),
-            }
+            return InsertResponse(
+                status="success",
+                message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
+            )
         except Exception as e:
+            logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post(
+        "/documents/text",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_text(
+        request: InsertTextRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert text into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextRequest): The request body containing the text to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation and a message.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, [request.text])
+            return InsertResponse(
+                status="success",
+                message="Text successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/text: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post(
+        "/documents/file",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_file(
+        background_tasks: BackgroundTasks, file: UploadFile = File(...)
+    ):
+        """Insert a file directly into the RAG system
+
+        Args:
+            background_tasks: FastAPI BackgroundTasks for async processing
+            file: Uploaded file
+
+        Returns:
+            InsertResponse: Status of the insertion operation
+
+        Raises:
+            HTTPException: For unsupported file types or processing errors
+        """
+        try:
+            if not doc_manager.is_supported_file(file.filename):
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
+                )
+
+            # Create a temporary file to save the uploaded content
+            temp_path = await save_temp_file(file)
+
+            # Add to background tasks
+            background_tasks.add_task(pipeline_index_file, temp_path)
+
+            return InsertResponse(
+                status="success",
+                message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
+            )
+
+        except Exception as e:
+            logging.error(f"Error /documents/file: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post(
+        "/documents/batch",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_batch(
+        background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
+    ):
+        """Process multiple files in batch mode
+
+        Args:
+            background_tasks: FastAPI BackgroundTasks for async processing
+            files: List of files to process
+
+        Returns:
+            InsertResponse: Status of the batch insertion operation
+
+        Raises:
+            HTTPException: For processing errors
+        """
+        try:
+            inserted_count = 0
+            failed_files = []
+            temp_files = []
+
+            for file in files:
+                if doc_manager.is_supported_file(file.filename):
+                    # Create a temporary file to save the uploaded content
+                    temp_files.append(await save_temp_file(file))
+                    inserted_count += 1
+                else:
+                    failed_files.append(f"{file.filename} (unsupported type)")
+
+            if temp_files:
+                background_tasks.add_task(pipeline_index_files, temp_files)
+
+            # Prepare status message
+            if inserted_count == len(files):
+                status = "success"
+                status_message = f"Successfully inserted all {inserted_count} documents"
+            elif inserted_count > 0:
+                status = "partial_success"
+                status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
+                if failed_files:
+                    status_message += f". Failed files: {', '.join(failed_files)}"
+            else:
+                status = "failure"
+                status_message = "No documents were successfully inserted"
+                if failed_files:
+                    status_message += f". Failed files: {', '.join(failed_files)}"
+
+            return InsertResponse(status=status, message=status_message)
+
+        except Exception as e:
+            # Don't reference the loop variable `file` here; it may be unbound
+            logging.error(f"Error /documents/batch: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.delete(
+        "/documents",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def clear_documents():
+        """
+        Clear all documents from the LightRAG system.
+
+        This endpoint deletes all text chunks, entities vector database, and relationships vector database,
+        effectively clearing all documents from the LightRAG system.
+
+        Returns:
+            InsertResponse: A response object containing the status and message of the operation.
+        """
+        try:
+            rag.text_chunks = []
+            rag.entities_vdb = None
+            rag.relationships_vdb = None
+            return InsertResponse(
+                status="success", message="All documents cleared successfully"
+            )
+        except Exception as e:
+            logging.error(f"Error DELETE /documents: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
 
     @app.post(
@@ -1297,12 +1629,7 @@ def create_app(args):
         Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
 
         Parameters:
-            request (QueryRequest): A Pydantic model containing the following fields:
-                - query (str): The text of the user's query.
-                - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
-                - stream (bool): Optional. Determines if the response should be streamed.
-                - only_need_context (bool): Optional. If true, returns only the context without further processing.
-
+            request (QueryRequest): The request object containing the query parameters.
         Returns:
             QueryResponse: A Pydantic model containing the result of the query processing.
                        If a string is returned (e.g., cache hit), it's directly returned.
@@ -1314,13 +1641,7 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=request.stream,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=QueryRequestToQueryParams(request)
             )
 
             # If response is a string (e.g. 
cache hit), return directly
@@ -1328,16 +1649,16 @@
             return QueryResponse(response=response)
 
             # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream:
+            if request.stream or hasattr(response, "__aiter__"):
                 result = ""
                 async for chunk in response:
                     result += chunk
                 return QueryResponse(response=result)
+            elif isinstance(response, dict):
+                result = json.dumps(response, indent=2)
+                return QueryResponse(response=result)
             else:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                return QueryResponse(response=str(response))
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
@@ -1355,14 +1676,11 @@
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
+            params = QueryRequestToQueryParams(request)
+            # This endpoint always streams, regardless of the request's stream flag
+            params.stream = True
             response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=True,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=params
             )
 
             from fastapi.responses import StreamingResponse
@@ -1395,255 +1713,6 @@
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
 
-    @app.post(
-        "/documents/text",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_text(request: InsertTextRequest):
-        """
-        Insert text into the Retrieval-Augmented Generation (RAG) system.
-
-        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
-
-        Args:
-            request (InsertTextRequest): The request body containing the text to be inserted.
-
-        Returns:
-            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
- """ - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=1, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - """Insert a file directly into the RAG system - - Args: - file: Uploaded file - description: Optional description of the file - - Returns: - InsertResponse: Status of the insertion operation - - Raises: - HTTPException: For unsupported file types or processing errors - """ - try: - content = "" - # Get file extension in lowercase - ext = Path(file.filename).suffix.lower() - - match ext: - case ".txt" | ".md": - # Text files handling - text_content = await file.read() - content = text_content.decode("utf-8") - - case ".pdf" | ".docx" | ".pptx" | ".xlsx": - if not pm.is_installed("docling"): - pm.install("docling") - from docling.document_converter import DocumentConverter - - # Create a temporary file to save the uploaded content - temp_path = Path("temp") / file.filename - temp_path.parent.mkdir(exist_ok=True) - - # Save the uploaded file - with temp_path.open("wb") as f: - f.write(await file.read()) - - try: - - async def convert_doc(): - def sync_convert(): - converter = DocumentConverter() - result = converter.convert(str(temp_path)) - return result.document.export_to_markdown() - - return await asyncio.to_thread(sync_convert) - - content = await convert_doc() - finally: - # Clean up the temporary file - temp_path.unlink() - - # Insert content into RAG system - if content: - # Add description if provided - if description: - content = f"{description}\n\n{content}" - - await rag.ainsert(content) - logging.info(f"Successfully indexed file: {file.filename}") - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - else: - raise HTTPException( - status_code=400, - detail="No content could be extracted from the file", - ) - - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - logging.error(f"Error processing file {file.filename}: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - """Process multiple files in batch mode - - Args: - files: List of files to process - - Returns: - InsertResponse: Status of the batch insertion operation - - Raises: - HTTPException: For processing errors - """ - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = "" - ext = Path(file.filename).suffix.lower() - - match ext: - case ".txt" | ".md": - text_content = await file.read() - content = text_content.decode("utf-8") - - case ".pdf": - if not pm.is_installed("pypdf2"): - pm.install("pypdf2") - from PyPDF2 import PdfReader - from io import BytesIO - - pdf_content = await file.read() - pdf_file = BytesIO(pdf_content) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - - case ".docx": - if not pm.is_installed("docx"): - pm.install("docx") - from docx import Document - from io import BytesIO - - docx_content = await file.read() - docx_file = 
BytesIO(docx_content) - doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) - - case ".pptx": - if not pm.is_installed("pptx"): - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO - - pptx_content = await file.read() - pptx_file = BytesIO(pptx_content) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - - case _: - failed_files.append(f"{file.filename} (unsupported type)") - continue - - if content: - await rag.ainsert(content) - inserted_count += 1 - logging.info(f"Successfully indexed file: {file.filename}") - else: - failed_files.append(f"{file.filename} (no content extracted)") - - except UnicodeDecodeError: - failed_files.append(f"{file.filename} (encoding error)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - logging.error(f"Error processing file {file.filename}: {str(e)}") - - # Prepare status message - if inserted_count == len(files): - status = "success" - status_message = f"Successfully inserted all {inserted_count} documents" - elif inserted_count > 0: - status = "partial_success" - status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - else: - status = "failure" - status_message = "No documents were successfully inserted" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status=status, - message=status_message, - document_count=inserted_count, - ) - - except Exception as e: - logging.error(f"Batch processing error: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - """ - Clear all documents from the LightRAG system. - - This endpoint deletes all text chunks, entities vector database, and relationships vector database, - effectively clearing all documents from the LightRAG system. - - Returns: - InsertResponse: A response object containing the status, message, and the new document count (0 in this case). 
-        """
-        try:
-            rag.text_chunks = []
-            rag.entities_vdb = None
-            rag.relationships_vdb = None
-            return InsertResponse(
-                status="success",
-                message="All documents cleared successfully",
-                document_count=0,
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
     # query all graph labels
     @app.get("/graph/label/list")
     async def get_graph_labels():
diff --git a/lightrag/base.py b/lightrag/base.py
index 3d4fc022..d9a63d26 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
         """Get counts of documents in each status"""
         raise NotImplementedError
 
-    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        raise NotImplementedError
-
-    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        raise NotImplementedError
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        raise NotImplementedError
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all procesed documents"""
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """Get all documents with a specific status"""
         raise NotImplementedError
 
     async def update_doc_status(self, data: dict[str, Any]) -> None:
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index fad03acc..ed79a370 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
                 counts[doc["status"]] += 1
         return counts
 
-    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """Get all documents with a specific status"""
         return {
             k: DocProcessingStatus(**v)
             for k, v in self._data.items()
-            if v["status"] == DocStatus.FAILED
-        }
-
-    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return {
-            k: DocProcessingStatus(**v)
-            for k, v in self._data.items()
-            if v["status"] == DocStatus.PENDING
-        }
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processed documents"""
-        return {
-            k: DocProcessingStatus(**v)
-            for k, v in self._data.items()
-            if v["status"] == DocStatus.PROCESSED
-        }
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        return {
-            k: DocProcessingStatus(**v)
-            for k, v in self._data.items()
-            if v["status"] == DocStatus.PROCESSING
+            if v["status"] == status
         }
 
     async def index_done_callback(self):
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index c216e7be..f6326b76 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
-        """Get all documents by status"""
+        """Get all documents with a specific status"""
         cursor = self._data.find({"status": status.value})
         result = await cursor.to_list()
         return {
@@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
             for doc in result
         }
 
-    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        return await self.get_docs_by_status(DocStatus.FAILED)
-
-    async def 
get_pending_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return await self.get_docs_by_status(DocStatus.PENDING)
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSING)
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all procesed documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSED)
-
 
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index a44aefe7..51b25385 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> Dict[str, DocProcessingStatus]:
-        """Get all documents by status"""
+        """Get all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
         params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
             for element in result
         }
 
-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        return await self.get_docs_by_status(DocStatus.FAILED)
-
-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return await self.get_docs_by_status(DocStatus.PENDING)
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSING)
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all procesed documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSED)
-
     async def index_done_callback(self):
         """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
         logger.info("Doc status had been saved into postgresql db!")
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 23c3df80..9909b4b7 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -89,7 +89,7 @@ STORAGE_IMPLEMENTATIONS = {
         "PGDocStatusStorage",
         "MongoDocStatusStorage",
     ],
-        "required_methods": ["get_pending_docs"],
+        "required_methods": ["get_docs_by_status"],
     },
 }
@@ -230,7 +230,7 @@ class LightRAG:
     """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
 
     working_dir: str = field(
-        default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
+        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
     """Directory where cache and temporary files are stored."""
 
@@ -715,11 +715,11 @@ class LightRAG:
 
         # 1. Get all pending, failed, and abnormally terminated processing documents.
         to_process_docs: dict[str, DocProcessingStatus] = {}
-        processing_docs = await self.doc_status.get_processing_docs()
+        processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
         to_process_docs.update(processing_docs)
-        failed_docs = await self.doc_status.get_failed_docs()
+        failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
         to_process_docs.update(failed_docs)
-        pendings_docs = await self.doc_status.get_pending_docs()
+        pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
         to_process_docs.update(pendings_docs)
 
         if not to_process_docs:
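
With this refactor, the four per-status getters (get_failed_docs, get_pending_docs, get_processing_docs, get_processed_docs) collapse into a single get_docs_by_status(status) that every DocStatusStorage backend (JSON, MongoDB, PostgreSQL) implements, and apipeline_process_enqueue_documents drains the PROCESSING, FAILED, and PENDING queues through it. A minimal sketch of the resulting call pattern, assuming DocStatus, DocProcessingStatus, and DocStatusStorage are importable from lightrag.base as the diff suggests (the helper name collect_unfinished is illustrative, not part of the change):

    from lightrag.base import DocProcessingStatus, DocStatus, DocStatusStorage


    async def collect_unfinished(
        doc_status: DocStatusStorage,
    ) -> dict[str, DocProcessingStatus]:
        # One method now covers all three queues that previously required
        # get_processing_docs / get_failed_docs / get_pending_docs.
        to_process: dict[str, DocProcessingStatus] = {}
        for status in (DocStatus.PROCESSING, DocStatus.FAILED, DocStatus.PENDING):
            to_process.update(await doc_status.get_docs_by_status(status))
        return to_process

Because required_methods for DOC_STATUS_STORAGE now lists only get_docs_by_status, a custom backend needs to implement just this one query method to satisfy the storage-implementation check.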