From c0c87edc451068a023ec9d891aff12ef42596ba2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 03:26:39 +0800 Subject: [PATCH] split lightrag_servery.py to smaller files --- lightrag/api/lightrag_server.py | 969 +---------------------- lightrag/api/routers/__init__.py | 10 + lightrag/api/routers/document_routes.py | 667 ++++++++++++++++ lightrag/api/routers/graph_routes.py | 26 + lightrag/api/{ => routers}/ollama_api.py | 0 lightrag/api/routers/query_routes.py | 225 ++++++ lightrag/api/utils_api.py | 44 + 7 files changed, 1008 insertions(+), 933 deletions(-) create mode 100644 lightrag/api/routers/__init__.py create mode 100644 lightrag/api/routers/document_routes.py create mode 100644 lightrag/api/routers/graph_routes.py rename lightrag/api/{ => routers}/ollama_api.py (100%) create mode 100644 lightrag/api/routers/query_routes.py create mode 100644 lightrag/api/utils_api.py diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0cf1d01e..f7f70c62 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1,44 +1,43 @@ +""" +LightRAG FastAPI Server +""" + from fastapi import ( FastAPI, HTTPException, - File, - UploadFile, - BackgroundTasks, + Depends, ) import asyncio import threading import os -import json -import re from fastapi.staticfiles import StaticFiles import logging import argparse -from typing import List, Any, Literal, Optional, Dict -from pydantic import BaseModel, Field, field_validator +from typing import Optional, Dict from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception, ASCIIColors +import configparser +from ascii_colors import ASCIIColors import sys -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager -from starlette.status import HTTP_403_FORBIDDEN -import pipmaster as pm from dotenv import load_dotenv -import configparser -import traceback -from datetime import datetime -from lightrag import LightRAG, QueryParam -from lightrag.base import DocProcessingStatus, DocStatus +from .utils_api import get_api_key_dependency + +from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc from lightrag.utils import logger -from .ollama_api import OllamaAPI, ollama_server_infos +from .routers.document_routes import ( + DocumentManager, + create_document_routes, + run_scanning_process, +) +from .routers.query_routes import create_query_routes +from .routers.graph_routes import create_graph_routes +from .routers.ollama_api import OllamaAPI, ollama_server_infos # Load environment variables try: @@ -50,6 +49,9 @@ except Exception as e: config = configparser.ConfigParser() config.read("config.ini") +# Global configuration +global_top_k = 60 # default value + class DefaultRAGStorageConfig: KV_STORAGE = "JsonKVStorage" @@ -70,22 +72,6 @@ scan_progress: Dict = { # Lock for thread-safe operations progress_lock = threading.Lock() - -def estimate_tokens(text: str) -> int: - """Estimate the number of tokens in text - Chinese characters: approximately 1.5 tokens per character - English characters: approximately 0.25 tokens per character - """ - # Use regex to match Chinese and non-Chinese characters separately - chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) - non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) - - # Calculate 
estimated token count - tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 - - return int(tokens) - - def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), @@ -98,17 +84,17 @@ def get_default_host(binding_type: str) -> str: ) # fallback to ollama if unknown -def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any: +def get_env_value(env_key: str, default: any, value_type: type = str) -> any: """ Get value from environment variable with type conversion Args: env_key (str): Environment variable key - default (Any): Default value if env variable is not set + default (any): Default value if env variable is not set value_type (type): Type to convert the value to Returns: - Any: Converted value from environment or default + any: Converted value from environment or default """ value = os.getenv(env_key) if value is None: @@ -557,7 +543,7 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() - # conver relative path to absolute path + # convert relative path to absolute path args.working_dir = os.path.abspath(args.working_dir) args.input_dir = os.path.abspath(args.input_dir) @@ -566,293 +552,16 @@ def parse_args() -> argparse.Namespace: return args -class DocumentManager: - """Handles document operations and tracking""" - - def __init__( - self, - input_dir: str, - supported_extensions: tuple = ( - ".txt", - ".md", - ".pdf", - ".docx", - ".pptx", - ".xlsx", - ), - ): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory_for_new_files(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - logger.info(f"Scanning for {ext} files in {self.input_dir}") - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -class QueryRequest(BaseModel): - query: str = Field( - min_length=1, - description="The query text", - ) - - mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( - default="hybrid", - description="Query mode", - ) - - only_need_context: Optional[bool] = Field( - default=None, - description="If True, only returns the retrieved context without generating a response.", - ) - - only_need_prompt: Optional[bool] = Field( - default=None, - description="If True, only returns the generated prompt without producing a response.", - ) - - response_type: Optional[str] = Field( - min_length=1, - default=None, - description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", - ) - - top_k: Optional[int] = Field( - ge=1, - default=None, - description="Number of top items to retrieve. 
Represents entities in 'local' mode and relationships in 'global' mode.", - ) - - max_token_for_text_unit: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allowed for each retrieved text chunk.", - ) - - max_token_for_global_context: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", - ) - - max_token_for_local_context: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allocated for entity descriptions in local retrieval.", - ) - - hl_keywords: Optional[List[str]] = Field( - default=None, - description="List of high-level keywords to prioritize in retrieval.", - ) - - ll_keywords: Optional[List[str]] = Field( - default=None, - description="List of low-level keywords to refine retrieval focus.", - ) - - conversation_history: Optional[List[dict[str, Any]]] = Field( - default=None, - description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", - ) - - history_turns: Optional[int] = Field( - ge=0, - default=None, - description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", - ) - - @field_validator("query", mode="after") - @classmethod - def query_strip_after(cls, query: str) -> str: - return query.strip() - - @field_validator("hl_keywords", mode="after") - @classmethod - def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: - if hl_keywords is None: - return None - return [keyword.strip() for keyword in hl_keywords] - - @field_validator("ll_keywords", mode="after") - @classmethod - def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: - if ll_keywords is None: - return None - return [keyword.strip() for keyword in ll_keywords] - - @field_validator("conversation_history", mode="after") - @classmethod - def conversation_history_role_check( - cls, conversation_history: List[dict[str, Any]] | None - ) -> List[dict[str, Any]] | None: - if conversation_history is None: - return None - for msg in conversation_history: - if "role" not in msg or msg["role"] not in {"user", "assistant"}: - raise ValueError( - "Each message must have a 'role' key with value 'user' or 'assistant'." 
- ) - return conversation_history - - def to_query_params(self, is_stream: bool) -> QueryParam: - """Converts a QueryRequest instance into a QueryParam instance.""" - # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically - request_data = self.model_dump(exclude_none=True, exclude={"query"}) - - # Ensure `mode` and `stream` are set explicitly - param = QueryParam(**request_data) - param.stream = is_stream - return param - - -class QueryResponse(BaseModel): - response: str = Field( - description="The generated response", - ) - - -class InsertTextRequest(BaseModel): - text: str = Field( - min_length=1, - description="The text to insert", - ) - - @field_validator("text", mode="after") - @classmethod - def strip_after(cls, text: str) -> str: - return text.strip() - - -class InsertTextsRequest(BaseModel): - texts: list[str] = Field( - min_length=1, - description="The texts to insert", - ) - - @field_validator("texts", mode="after") - @classmethod - def strip_after(cls, texts: list[str]) -> list[str]: - return [text.strip() for text in texts] - - -class InsertResponse(BaseModel): - status: str = Field(description="Status of the operation") - message: str = Field(description="Message describing the operation result") - - -class DocStatusResponse(BaseModel): - @staticmethod - def format_datetime(dt: Any) -> Optional[str]: - """Format datetime to ISO string - - Args: - dt: Datetime object or string - - Returns: - Formatted datetime string or None - """ - if dt is None: - return None - if isinstance(dt, str): - return dt - return dt.isoformat() - - """Response model for document status - - Attributes: - id: Document identifier - content_summary: Summary of document content - content_length: Length of document content - status: Current processing status - created_at: Creation timestamp (ISO format string) - updated_at: Last update timestamp (ISO format string) - chunks_count: Number of chunks (optional) - error: Error message if any (optional) - metadata: Additional metadata (optional) - """ - - id: str - content_summary: str - content_length: int - status: DocStatus - created_at: str - updated_at: str - chunks_count: Optional[int] = None - error: Optional[str] = None - metadata: Optional[dict[str, Any]] = None - - -class DocsStatusesResponse(BaseModel): - statuses: Dict[DocStatus, List[DocStatusResponse]] = {} - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth( - api_key_header_value: Optional[str] = Security(api_key_header), - ): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -# Global configuration -global_top_k = 60 # default value -temp_prefix = "__tmp_" # prefix for temporary files - - def create_app(args): + # Set global top_k + global global_top_k + global_top_k = args.top_k # save top_k from args + # Initialize verbose debug setting from lightrag.utils import set_verbose_debug set_verbose_debug(args.verbose) - global global_top_k - global_top_k = args.top_k # save top_k from args - # Verify that bindings 
are correctly setup if args.llm_binding not in [ "lollms", @@ -914,7 +623,7 @@ def create_app(args): scan_progress["indexed_count"] = 0 scan_progress["progress"] = 0 # Create background task - task = asyncio.create_task(run_scanning_process()) + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) app.state.background_tasks.add(task) task.add_done_callback(app.state.background_tasks.discard) ASCIIColors.info( @@ -922,7 +631,7 @@ def create_app(args): ) else: ASCIIColors.info( - "Skip document scanning(anohter scanning is active)" + "Skip document scanning(another scanning is active)" ) yield @@ -1130,621 +839,15 @@ def create_app(args): auto_manage_storages_states=False, ) - async def pipeline_enqueue_file(file_path: Path) -> bool: - """Add a file to the queue for processing - - Args: - file_path: Path to the saved file - Returns: - bool: True if the file was successfully enqueued, False otherwise - """ - try: - content = "" - ext = file_path.suffix.lower() - - file = None - async with aiofiles.open(file_path, "rb") as f: - file = await f.read() - - # Process based on file type - match ext: - case ".txt" | ".md": - content = file.decode("utf-8") - case ".pdf": - if not pm.is_installed("pypdf2"): - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO - - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - case ".docx": - if not pm.is_installed("docx"): - pm.install("docx") - from docx import Document - from io import BytesIO - - docx_file = BytesIO(file) - doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) - case ".pptx": - if not pm.is_installed("pptx"): - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO - - pptx_file = BytesIO(file) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - case ".xlsx": - if not pm.is_installed("openpyxl"): - pm.install("openpyxl") - from openpyxl import load_workbook # type: ignore - from io import BytesIO - - xlsx_file = BytesIO(file) - wb = load_workbook(xlsx_file) - for sheet in wb: - content += f"Sheet: {sheet.title}\n" - for row in sheet.iter_rows(values_only=True): - content += ( - "\t".join( - str(cell) if cell is not None else "" - for cell in row - ) - + "\n" - ) - content += "\n" - case _: - logging.error( - f"Unsupported file type: {file_path.name} (extension {ext})" - ) - return False - - # Insert into the RAG queue - if content: - await rag.apipeline_enqueue_documents(content) - logging.info( - f"Successfully fetched and enqueued file: {file_path.name}" - ) - return True - else: - logging.error( - f"No content could be extracted from file: {file_path.name}" - ) - - except Exception as e: - logging.error( - f"Error processing or enqueueing file {file_path.name}: {str(e)}" - ) - logging.error(traceback.format_exc()) - finally: - if file_path.name.startswith(temp_prefix): - # Clean up the temporary file after indexing - try: - file_path.unlink() - except Exception as e: - logging.error(f"Error deleting file {file_path}: {str(e)}") - return False - - async def pipeline_index_file(file_path: Path): - """Index a file - - Args: - file_path: Path to the saved file - """ - try: - if await pipeline_enqueue_file(file_path): - await rag.apipeline_process_enqueue_documents() - - except Exception as e: - logging.error(f"Error indexing file 
{file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) - - async def pipeline_index_files(file_paths: List[Path]): - """Index multiple files concurrently - - Args: - file_paths: Paths to the files to index - """ - if not file_paths: - return - try: - enqueued = False - - if len(file_paths) == 1: - enqueued = await pipeline_enqueue_file(file_paths[0]) - else: - tasks = [pipeline_enqueue_file(path) for path in file_paths] - enqueued = any(await asyncio.gather(*tasks)) - - if enqueued: - await rag.apipeline_process_enqueue_documents() - except Exception as e: - logging.error(f"Error indexing files: {str(e)}") - logging.error(traceback.format_exc()) - - async def pipeline_index_texts(texts: List[str]): - """Index a list of texts - - Args: - texts: The texts to index - """ - if not texts: - return - await rag.apipeline_enqueue_documents(texts) - await rag.apipeline_process_enqueue_documents() - - async def save_temp_file(file: UploadFile = File(...)) -> Path: - """Save the uploaded file to a temporary location - - Args: - file: The uploaded file - - Returns: - Path: The path to the saved file - """ - # Generate unique filename to avoid conflicts - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" - - # Create a temporary file to save the uploaded content - temp_path = doc_manager.input_dir / "temp" / unique_filename - temp_path.parent.mkdir(exist_ok=True) - - # Save the file - with open(temp_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - return temp_path - - async def run_scanning_process(): - """Background task to scan and index documents""" - global scan_progress - - try: - new_files = doc_manager.scan_directory_for_new_files() - scan_progress["total_files"] = len(new_files) - - logger.info(f"Found {len(new_files)} new files to index.") - for file_path in new_files: - try: - with progress_lock: - scan_progress["current_file"] = os.path.basename(file_path) - - await pipeline_index_file(file_path) - - with progress_lock: - scan_progress["indexed_count"] += 1 - scan_progress["progress"] = ( - scan_progress["indexed_count"] - / scan_progress["total_files"] - ) * 100 - - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - except Exception as e: - logging.error(f"Error during scanning process: {str(e)}") - finally: - with progress_lock: - scan_progress["is_scanning"] = False - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(background_tasks: BackgroundTasks): - """Trigger the scanning process""" - global scan_progress - - with progress_lock: - if scan_progress["is_scanning"]: - return {"status": "already_scanning"} - - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - - # Start the scanning process in the background - background_tasks.add_task(run_scanning_process) - - return {"status": "scanning_started"} - - @app.get("/documents/scan-progress") - async def get_scan_progress(): - """Get the current scanning progress""" - with progress_lock: - return scan_progress - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) - ): - """ - Endpoint for uploading a file to the input directory and indexing it. 
- - This API endpoint accepts a file through an HTTP POST request, checks if the - uploaded file is of a supported type, saves it in the specified input directory, - indexes it for retrieval, and returns a success status with relevant details. - - Parameters: - background_tasks: FastAPI BackgroundTasks for async processing - file (UploadFile): The file to be uploaded. It must have an allowed extension as per - `doc_manager.supported_extensions`. - - Returns: - dict: A dictionary containing the upload status ("success"), - a message detailing the operation result, and - the total number of indexed documents. - - Raises: - HTTPException: If the file type is not supported, it raises a 400 Bad Request error. - If any other exception occurs during the file handling or indexing, - it raises a 500 Internal Server Error with details about the exception. - """ - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Add to background tasks - background_tasks.add_task(pipeline_index_file, file_path) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text( - request: InsertTextRequest, background_tasks: BackgroundTasks - ): - """ - Insert text into the Retrieval-Augmented Generation (RAG) system. - - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextRequest): The request body containing the text to be inserted. - background_tasks: FastAPI BackgroundTasks for async processing - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. - """ - try: - background_tasks.add_task(pipeline_index_texts, [request.text]) - return InsertResponse( - status="success", - message="Text successfully received. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/texts", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_texts( - request: InsertTextsRequest, background_tasks: BackgroundTasks - ): - """ - Insert texts into the Retrieval-Augmented Generation (RAG) system. - - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextsRequest): The request body containing the text to be inserted. - background_tasks: FastAPI BackgroundTasks for async processing - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. 
- """ - try: - background_tasks.add_task(pipeline_index_texts, request.texts) - return InsertResponse( - status="success", - message="Text successfully received. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file( - background_tasks: BackgroundTasks, file: UploadFile = File(...) - ): - """Insert a file directly into the RAG system - - Args: - background_tasks: FastAPI BackgroundTasks for async processing - file: Uploaded file - - Returns: - InsertResponse: Status of the insertion operation - - Raises: - HTTPException: For unsupported file types or processing errors - """ - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - # Create a temporary file to save the uploaded content - temp_path = save_temp_file(file) - - # Add to background tasks - background_tasks.add_task(pipeline_index_file, temp_path) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' saved successfully. Processing will continue in background.", - ) - - except Exception as e: - logging.error(f"Error /documents/file: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file_batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch( - background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) - ): - """Process multiple files in batch mode - - Args: - background_tasks: FastAPI BackgroundTasks for async processing - files: List of files to process - - Returns: - InsertResponse: Status of the batch insertion operation - - Raises: - HTTPException: For processing errors - """ - try: - inserted_count = 0 - failed_files = [] - temp_files = [] - - for file in files: - if doc_manager.is_supported_file(file.filename): - # Create a temporary file to save the uploaded content - temp_files.append(save_temp_file(file)) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - - if temp_files: - background_tasks.add_task(pipeline_index_files, temp_files) - - # Prepare status message - if inserted_count == len(files): - status = "success" - status_message = f"Successfully inserted all {inserted_count} documents" - elif inserted_count > 0: - status = "partial_success" - status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - else: - status = "failure" - status_message = "No documents were successfully inserted" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse(status=status, message=status_message) - - except Exception as e: - logging.error(f"Error /documents/batch: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - """ - Clear all documents from the LightRAG system. 
- - This endpoint deletes all text chunks, entities vector database, and relationships vector database, - effectively clearing all documents from the LightRAG system. - - Returns: - InsertResponse: A response object containing the status, message, and the new document count (0 in this case). - """ - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", message="All documents cleared successfully" - ) - except Exception as e: - logging.error(f"Error DELETE /documents: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - """ - Handle a POST request at the /query endpoint to process user queries using RAG capabilities. - - Parameters: - request (QueryRequest): The request object containing the query parameters. - Returns: - QueryResponse: A Pydantic model containing the result of the query processing. - If a string is returned (e.g., cache hit), it's directly returned. - Otherwise, an async generator may be used to build the response. - - Raises: - HTTPException: Raised when an error occurs during the request handling process, - with status code 500 and detail containing the exception message. - """ - try: - response = await rag.aquery( - request.query, param=request.to_query_params(False) - ) - - # If response is a string (e.g. cache hit), return directly - if isinstance(response, str): - return QueryResponse(response=response) - - if isinstance(response, dict): - result = json.dumps(response, indent=2) - return QueryResponse(response=result) - else: - return QueryResponse(response=str(response)) - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - """ - This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. - - Args: - request (QueryRequest): The request object containing the query parameters. - optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None. - - Returns: - StreamingResponse: A streaming response containing the RAG query results. 
- """ - try: - response = await rag.aquery( - request.query, param=request.to_query_params(True) - ) - - from fastapi.responses import StreamingResponse - - async def stream_generator(): - if isinstance(response, str): - # If it's a string, send it all at once - yield f"{json.dumps({'response': response})}\n" - else: - # If it's an async generator, send chunks one by one - try: - async for chunk in response: - if chunk: # Only send non-empty content - yield f"{json.dumps({'response': chunk})}\n" - except Exception as e: - logging.error(f"Streaming error: {str(e)}") - yield f"{json.dumps({'error': str(e)})}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx - }, - ) - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - # query all graph labels - @app.get("/graph/label/list") - async def get_graph_labels(): - return await rag.get_graph_labels() - - # query all graph - @app.get("/graphs") - async def get_knowledge_graph(label: str): - return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + # Add routes + app.include_router(create_document_routes(rag, doc_manager, api_key)) + app.include_router(create_query_routes(rag, api_key, args.top_k)) + app.include_router(create_graph_routes(rag, api_key)) # Add Ollama API routes ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") - @app.get("/documents", dependencies=[Depends(optional_api_key)]) - async def documents() -> DocsStatusesResponse: - """ - Get documents statuses - Returns: - DocsStatusesResponse: A response object containing a dictionary where keys are DocStatus - and values are lists of DocStatusResponse objects representing documents in each status category. - """ - try: - statuses = ( - DocStatus.PENDING, - DocStatus.PROCESSING, - DocStatus.PROCESSED, - DocStatus.FAILED, - ) - - tasks = [rag.get_docs_by_status(status) for status in statuses] - results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) - - response = DocsStatusesResponse() - - for idx, result in enumerate(results): - status = statuses[idx] - for doc_id, doc_status in result.items(): - if status not in response.statuses: - response.statuses[status] = [] - response.statuses[status].append( - DocStatusResponse( - id=doc_id, - content_summary=doc_status.content_summary, - content_length=doc_status.content_length, - status=doc_status.status, - created_at=DocStatusResponse.format_datetime( - doc_status.created_at - ), - updated_at=DocStatusResponse.format_datetime( - doc_status.updated_at - ), - chunks_count=doc_status.chunks_count, - error=doc_status.error, - metadata=doc_status.metadata, - ) - ) - return response - except Exception as e: - logging.error(f"Error GET /documents: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" diff --git a/lightrag/api/routers/__init__.py b/lightrag/api/routers/__init__.py new file mode 100644 index 00000000..b71f204e --- /dev/null +++ b/lightrag/api/routers/__init__.py @@ -0,0 +1,10 @@ +""" +This module contains all the routers for the LightRAG API. 
+""" + +from .document_routes import router as document_router +from .query_routes import router as query_router +from .graph_routes import router as graph_router +from .ollama_api import OllamaAPI + +__all__ = ["document_router", "query_router", "graph_router", "OllamaAPI"] diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py new file mode 100644 index 00000000..48401658 --- /dev/null +++ b/lightrag/api/routers/document_routes.py @@ -0,0 +1,667 @@ +""" +This module contains all document-related routes for the LightRAG API. +""" + +import asyncio +import logging +import os +import aiofiles +import shutil +import traceback +import pipmaster as pm +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any + +from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile +from fastapi.security import APIKeyHeader +from pydantic import BaseModel, Field, field_validator +from starlette.status import HTTP_403_FORBIDDEN + +from lightrag.base import DocProcessingStatus, DocStatus +from ..utils_api import get_api_key_dependency + + +router = APIRouter(prefix="/documents", tags=["documents"]) + +# Global progress tracker +scan_progress: Dict = { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, +} + +# Lock for thread-safe operations +progress_lock = asyncio.Lock() + +# Temporary file prefix +temp_prefix = "__tmp__" + +class InsertTextRequest(BaseModel): + text: str = Field( + min_length=1, + description="The text to insert", + ) + + @field_validator("text", mode="after") + @classmethod + def strip_after(cls, text: str) -> str: + return text.strip() + +class InsertTextsRequest(BaseModel): + texts: list[str] = Field( + min_length=1, + description="The texts to insert", + ) + + @field_validator("texts", mode="after") + @classmethod + def strip_after(cls, texts: list[str]) -> list[str]: + return [text.strip() for text in texts] + +class InsertResponse(BaseModel): + status: str = Field(description="Status of the operation") + message: str = Field(description="Message describing the operation result") + +class DocStatusResponse(BaseModel): + @staticmethod + def format_datetime(dt: Any) -> Optional[str]: + if dt is None: + return None + if isinstance(dt, str): + return dt + return dt.isoformat() + + id: str + content_summary: str + content_length: int + status: DocStatus + created_at: str + updated_at: str + chunks_count: Optional[int] = None + error: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + +class DocsStatusesResponse(BaseModel): + statuses: Dict[DocStatus, List[DocStatusResponse]] = {} + +class DocumentManager: + def __init__( + self, + input_dir: str, + supported_extensions: tuple = ( + ".txt", + ".md", + ".pdf", + ".docx", + ".pptx", + ".xlsx", + ), + ): + self.input_dir = Path(input_dir) + self.supported_extensions = supported_extensions + self.indexed_files = set() + + # Create input directory if it doesn't exist + self.input_dir.mkdir(parents=True, exist_ok=True) + + def scan_directory_for_new_files(self) -> List[Path]: + new_files = [] + for ext in self.supported_extensions: + logging.info(f"Scanning for {ext} files in {self.input_dir}") + for file_path in self.input_dir.rglob(f"*{ext}"): + if file_path not in self.indexed_files: + new_files.append(file_path) + return new_files + + def scan_directory(self) -> List[Path]: + new_files = [] + for ext in self.supported_extensions: + for file_path in 
self.input_dir.rglob(f"*{ext}"): + new_files.append(file_path) + return new_files + + def mark_as_indexed(self, file_path: Path): + self.indexed_files.add(file_path) + + def is_supported_file(self, filename: str) -> bool: + return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + +async def pipeline_enqueue_file(rag, file_path: Path) -> bool: + try: + content = "" + ext = file_path.suffix.lower() + + file = None + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() + + # Process based on file type + match ext: + case ".txt" | ".md": + content = file.decode("utf-8") + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + case ".xlsx": + if not pm.is_installed("openpyxl"): + pm.install("openpyxl") + from openpyxl import load_workbook + from io import BytesIO + + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" + content += "\n" + case _: + logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") + return False + + # Insert into the RAG queue + if content: + await rag.apipeline_enqueue_documents(content) + logging.info(f"Successfully fetched and enqueued file: {file_path.name}") + return True + else: + logging.error(f"No content could be extracted from file: {file_path.name}") + + except Exception as e: + logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + finally: + if file_path.name.startswith(temp_prefix): + try: + file_path.unlink() + except Exception as e: + logging.error(f"Error deleting file {file_path}: {str(e)}") + return False + +async def pipeline_index_file(rag, file_path: Path): + """Index a file + + Args: + rag: LightRAG instance + file_path: Path to the saved file + """ + try: + content = "" + ext = file_path.suffix.lower() + + file = None + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() + + # Process based on file type + match ext: + case ".txt" | ".md": + content = file.decode("utf-8") + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + case ".pptx": + if not pm.is_installed("pptx"): + 
pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + case ".xlsx": + if not pm.is_installed("openpyxl"): + pm.install("openpyxl") + from openpyxl import load_workbook + from io import BytesIO + + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" + content += "\n" + case _: + logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") + return + + # Insert into the RAG queue + if content: + await rag.apipeline_enqueue_documents(content) + await rag.apipeline_process_enqueue_documents() + logging.info(f"Successfully indexed file: {file_path.name}") + else: + logging.error(f"No content could be extracted from file: {file_path.name}") + + except Exception as e: + logging.error(f"Error indexing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + +async def pipeline_index_files(rag, file_paths: List[Path]): + if not file_paths: + return + try: + enqueued = False + if len(file_paths) == 1: + enqueued = await pipeline_enqueue_file(rag, file_paths[0]) + else: + tasks = [pipeline_enqueue_file(rag, path) for path in file_paths] + enqueued = any(await asyncio.gather(*tasks)) + + if enqueued: + await rag.apipeline_process_enqueue_documents() + except Exception as e: + logging.error(f"Error indexing files: {str(e)}") + logging.error(traceback.format_exc()) + +async def pipeline_index_texts(rag, texts: List[str]): + if not texts: + return + await rag.apipeline_enqueue_documents(texts) + await rag.apipeline_process_enqueue_documents() + +async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" + temp_path = input_dir / "temp" / unique_filename + temp_path.parent.mkdir(exist_ok=True) + with open(temp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + return temp_path + +async def run_scanning_process(rag, doc_manager: DocumentManager): + """Background task to scan and index documents""" + try: + new_files = doc_manager.scan_directory_for_new_files() + scan_progress["total_files"] = len(new_files) + + logging.info(f"Found {len(new_files)} new files to index.") + for file_path in new_files: + try: + async with progress_lock: + scan_progress["current_file"] = os.path.basename(file_path) + + await pipeline_index_file(rag, file_path) + + async with progress_lock: + scan_progress["indexed_count"] += 1 + scan_progress["progress"] = ( + scan_progress["indexed_count"] / scan_progress["total_files"] + ) * 100 + + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + except Exception as e: + logging.error(f"Error during scanning process: {str(e)}") + finally: + async with progress_lock: + scan_progress["is_scanning"] = False + +def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None): + optional_api_key = get_api_key_dependency(api_key) + + @router.post("/scan", dependencies=[Depends(optional_api_key)]) + async def scan_for_new_documents(background_tasks: BackgroundTasks): + """ + Trigger the scanning process for new documents. 
+ + This endpoint initiates a background task that scans the input directory for new documents + and processes them. If a scanning process is already running, it returns a status indicating + that fact. + + Args: + background_tasks (BackgroundTasks): FastAPI background tasks handler + + Returns: + dict: A dictionary containing the scanning status + """ + async with progress_lock: + if scan_progress["is_scanning"]: + return {"status": "already_scanning"} + + scan_progress["is_scanning"] = True + scan_progress["indexed_count"] = 0 + scan_progress["progress"] = 0 + + background_tasks.add_task(run_scanning_process, rag, doc_manager) + return {"status": "scanning_started"} + + @router.get("/scan-progress") + async def get_scan_progress(): + """ + Get the current progress of the document scanning process. + + Returns: + dict: A dictionary containing the current scanning progress information including: + - is_scanning: Whether a scan is currently in progress + - current_file: The file currently being processed + - indexed_count: Number of files indexed so far + - total_files: Total number of files to process + - progress: Percentage of completion + """ + async with progress_lock: + return scan_progress + + @router.post("/upload", dependencies=[Depends(optional_api_key)]) + async def upload_to_input_dir( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): + """ + Upload a file to the input directory and index it. + + This API endpoint accepts a file through an HTTP POST request, checks if the + uploaded file is of a supported type, saves it in the specified input directory, + indexes it for retrieval, and returns a success status with relevant details. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file (UploadFile): The file to be uploaded. It must have an allowed extension. + + Returns: + InsertResponse: A response object containing the upload status and a message. + + Raises: + HTTPException: If the file type is not supported (400) or other errors occur (500). + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + file_path = doc_manager.input_dir / file.filename + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + background_tasks.add_task(pipeline_index_file, rag, file_path) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks): + """ + Insert text into the RAG system. + + This endpoint allows you to insert text data into the RAG system for later retrieval + and use in generating responses. + + Args: + request (InsertTextRequest): The request body containing the text to be inserted. + background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If an error occurs during text processing (500). 
+ """ + try: + background_tasks.add_task(pipeline_index_texts, rag, [request.text]) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks): + """ + Insert multiple texts into the RAG system. + + This endpoint allows you to insert multiple text entries into the RAG system + in a single request. + + Args: + request (InsertTextsRequest): The request body containing the list of texts. + background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If an error occurs during text processing (500). + """ + try: + background_tasks.add_task(pipeline_index_texts, rag, request.texts) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)): + """ + Insert a file directly into the RAG system. + + This endpoint accepts a file upload and processes it for inclusion in the RAG system. + The file is saved temporarily and processed in the background. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file (UploadFile): The file to be processed + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If the file type is not supported (400) or other errors occur (500). + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + temp_path = await save_temp_file(doc_manager.input_dir, file) + background_tasks.add_task(pipeline_index_file, rag, temp_path) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' saved successfully. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/file: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)): + """ + Process multiple files in batch mode. + + This endpoint allows uploading and processing multiple files simultaneously. + It handles partial successes and provides detailed feedback about failed files. 
+ + Args: + background_tasks: FastAPI BackgroundTasks for async processing + files (List[UploadFile]): List of files to process + + Returns: + InsertResponse: A response object containing: + - status: "success", "partial_success", or "failure" + - message: Detailed information about the operation results + + Raises: + HTTPException: If an error occurs during processing (500). + """ + try: + inserted_count = 0 + failed_files = [] + temp_files = [] + + for file in files: + if doc_manager.is_supported_file(file.filename): + temp_files.append(await save_temp_file(doc_manager.input_dir, file)) + inserted_count += 1 + else: + failed_files.append(f"{file.filename} (unsupported type)") + + if temp_files: + background_tasks.add_task(pipeline_index_files, rag, temp_files) + + if inserted_count == len(files): + status = "success" + status_message = f"Successfully inserted all {inserted_count} documents" + elif inserted_count > 0: + status = "partial_success" + status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + else: + status = "failure" + status_message = "No documents were successfully inserted" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + + return InsertResponse(status=status, message=status_message) + except Exception as e: + logging.error(f"Error /documents/batch: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def clear_documents(): + """ + Clear all documents from the RAG system. + + This endpoint deletes all text chunks, entities vector database, and relationships + vector database, effectively clearing all documents from the RAG system. + + Returns: + InsertResponse: A response object containing the status and message. + + Raises: + HTTPException: If an error occurs during the clearing process (500). + """ + try: + rag.text_chunks = [] + rag.entities_vdb = None + rag.relationships_vdb = None + return InsertResponse(status="success", message="All documents cleared successfully") + except Exception as e: + logging.error(f"Error DELETE /documents: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("", dependencies=[Depends(optional_api_key)]) + async def documents() -> DocsStatusesResponse: + """ + Get the status of all documents in the system. + + This endpoint retrieves the current status of all documents, grouped by their + processing status (PENDING, PROCESSING, PROCESSED, FAILED). + + Returns: + DocsStatusesResponse: A response object containing a dictionary where keys are + DocStatus values and values are lists of DocStatusResponse + objects representing documents in each status category. + + Raises: + HTTPException: If an error occurs while retrieving document statuses (500). 
+ """ + try: + statuses = ( + DocStatus.PENDING, + DocStatus.PROCESSING, + DocStatus.PROCESSED, + DocStatus.FAILED, + ) + + tasks = [rag.get_docs_by_status(status) for status in statuses] + results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) + + response = DocsStatusesResponse() + + for idx, result in enumerate(results): + status = statuses[idx] + for doc_id, doc_status in result.items(): + if status not in response.statuses: + response.statuses[status] = [] + response.statuses[status].append( + DocStatusResponse( + id=doc_id, + content_summary=doc_status.content_summary, + content_length=doc_status.content_length, + status=doc_status.status, + created_at=DocStatusResponse.format_datetime(doc_status.created_at), + updated_at=DocStatusResponse.format_datetime(doc_status.updated_at), + chunks_count=doc_status.chunks_count, + error=doc_status.error, + metadata=doc_status.metadata, + ) + ) + return response + except Exception as e: + logging.error(f"Error GET /documents: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py new file mode 100644 index 00000000..1d08d9ae --- /dev/null +++ b/lightrag/api/routers/graph_routes.py @@ -0,0 +1,26 @@ +""" +This module contains all graph-related routes for the LightRAG API. +""" + +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException + +from ..utils_api import get_api_key_dependency + +router = APIRouter(tags=["graph"]) + +def create_graph_routes(rag, api_key: Optional[str] = None): + optional_api_key = get_api_key_dependency(api_key) + + @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) + async def get_graph_labels(): + """Get all graph labels""" + return await rag.get_graph_labels() + + @router.get("/graphs", dependencies=[Depends(optional_api_key)]) + async def get_knowledge_graph(label: str): + """Get knowledge graph for a specific label""" + return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + + return router diff --git a/lightrag/api/ollama_api.py b/lightrag/api/routers/ollama_api.py similarity index 100% rename from lightrag/api/ollama_api.py rename to lightrag/api/routers/ollama_api.py diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py new file mode 100644 index 00000000..389d127d --- /dev/null +++ b/lightrag/api/routers/query_routes.py @@ -0,0 +1,225 @@ +""" +This module contains all query-related routes for the LightRAG API. 
+""" + +import json +import logging +import traceback +from typing import Any, Dict, List, Literal, Optional + +from fastapi import APIRouter, Depends, HTTPException +from lightrag.base import QueryParam +from ..utils_api import get_api_key_dependency +from pydantic import BaseModel, Field, field_validator + +from ascii_colors import trace_exception + +router = APIRouter(tags=["query"]) + +class QueryRequest(BaseModel): + query: str = Field( + min_length=1, + description="The query text", + ) + + mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( + default="hybrid", + description="Query mode", + ) + + only_need_context: Optional[bool] = Field( + default=None, + description="If True, only returns the retrieved context without generating a response.", + ) + + only_need_prompt: Optional[bool] = Field( + default=None, + description="If True, only returns the generated prompt without producing a response.", + ) + + response_type: Optional[str] = Field( + min_length=1, + default=None, + description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", + ) + + top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", + ) + + max_token_for_text_unit: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allowed for each retrieved text chunk.", + ) + + max_token_for_global_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", + ) + + max_token_for_local_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for entity descriptions in local retrieval.", + ) + + hl_keywords: Optional[List[str]] = Field( + default=None, + description="List of high-level keywords to prioritize in retrieval.", + ) + + ll_keywords: Optional[List[str]] = Field( + default=None, + description="List of low-level keywords to refine retrieval focus.", + ) + + conversation_history: Optional[List[Dict[str, Any]]] = Field( + default=None, + description="Stores past conversation history to maintain context. 
+    )
+
+    history_turns: Optional[int] = Field(
+        ge=0,
+        default=None,
+        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
+    )
+
+    @field_validator("query", mode="after")
+    @classmethod
+    def query_strip_after(cls, query: str) -> str:
+        return query.strip()
+
+    @field_validator("hl_keywords", mode="after")
+    @classmethod
+    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
+        if hl_keywords is None:
+            return None
+        return [keyword.strip() for keyword in hl_keywords]
+
+    @field_validator("ll_keywords", mode="after")
+    @classmethod
+    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
+        if ll_keywords is None:
+            return None
+        return [keyword.strip() for keyword in ll_keywords]
+
+    @field_validator("conversation_history", mode="after")
+    @classmethod
+    def conversation_history_role_check(
+        cls, conversation_history: List[Dict[str, Any]] | None
+    ) -> List[Dict[str, Any]] | None:
+        if conversation_history is None:
+            return None
+        for msg in conversation_history:
+            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
+                raise ValueError(
+                    "Each message must have a 'role' key with value 'user' or 'assistant'."
+                )
+        return conversation_history
+
+    def to_query_params(self, is_stream: bool) -> "QueryParam":
+        """Converts a QueryRequest instance into a QueryParam instance."""
+        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
+        request_data = self.model_dump(exclude_none=True, exclude={"query"})
+
+        # `stream` is not part of the request model, so set it explicitly
+        param = QueryParam(**request_data)
+        param.stream = is_stream
+        return param
+
+class QueryResponse(BaseModel):
+    response: str = Field(
+        description="The generated response",
+    )
+
+def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
+    optional_api_key = get_api_key_dependency(api_key)
+
+    @router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
+    async def query_text(request: QueryRequest):
+        """
+        Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
+
+        Parameters:
+            request (QueryRequest): The request object containing the query parameters.
+        Returns:
+            QueryResponse: A Pydantic model containing the result of the query processing.
+                           String results (e.g. from a cache hit) are returned directly;
+                           dict results are serialized to JSON before being returned.
+
+        Raises:
+            HTTPException: Raised when an error occurs during the request handling process,
+                           with status code 500 and detail containing the exception message.
+        """
+        try:
+            param = request.to_query_params(False)
+            if param.top_k is None:
+                param.top_k = top_k
+            response = await rag.aquery(request.query, param=param)
+
+            # If response is a string (e.g. cache hit), return directly
+            if isinstance(response, str):
+                return QueryResponse(response=response)
+
+            if isinstance(response, dict):
+                result = json.dumps(response, indent=2)
+                return QueryResponse(response=result)
+            else:
+                return QueryResponse(response=str(response))
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @router.post("/query/stream", dependencies=[Depends(optional_api_key)])
+    async def query_text_stream(request: QueryRequest):
+        """
+        This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
+
+        Args:
+            request (QueryRequest): The request object containing the query parameters.
+                API-key checks, when enabled, are enforced by the route dependency rather than a function argument.
+
+        Returns:
+            StreamingResponse: A streaming response containing the RAG query results.
+        """
+        try:
+            param = request.to_query_params(True)
+            if param.top_k is None:
+                param.top_k = top_k
+            response = await rag.aquery(request.query, param=param)
+
+            from fastapi.responses import StreamingResponse
+
+            async def stream_generator():
+                if isinstance(response, str):
+                    # If it's a string, send it all at once
+                    yield f"{json.dumps({'response': response})}\n"
+                else:
+                    # If it's an async generator, send chunks one by one
+                    try:
+                        async for chunk in response:
+                            if chunk:  # Only send non-empty content
+                                yield f"{json.dumps({'response': chunk})}\n"
+                    except Exception as e:
+                        logging.error(f"Streaming error: {str(e)}")
+                        yield f"{json.dumps({'error': str(e)})}\n"
+
+            return StreamingResponse(
+                stream_generator(),
+                media_type="application/x-ndjson",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "application/x-ndjson",
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
+                },
+            )
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    return router
diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
new file mode 100644
index 00000000..c5f1034a
--- /dev/null
+++ b/lightrag/api/utils_api.py
@@ -0,0 +1,44 @@
+"""
+Utility functions for the LightRAG API.
+"""
+
+from typing import Optional
+from fastapi import HTTPException, Security
+from fastapi.security import APIKeyHeader
+from starlette.status import HTTP_403_FORBIDDEN
+
+def get_api_key_dependency(api_key: Optional[str]):
+    """
+    Create an API key dependency for route protection.
+
+    Args:
+        api_key (Optional[str]): The API key to validate against.
+            If None, no authentication is required.
+
+    Returns:
+        Callable: A dependency function that validates the API key.
+    """
+    if not api_key:
+        # If no API key is configured, return a dummy dependency that always succeeds
+        async def no_auth():
+            return None
+
+        return no_auth
+
+    # If API key is configured, use proper authentication
+    api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+    async def api_key_auth(
+        api_key_header_value: Optional[str] = Security(api_key_header),
+    ):
+        if not api_key_header_value:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="API Key required"
+            )
+        if api_key_header_value != api_key:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
+            )
+        return api_key_header_value
+
+    return api_key_auth
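
For reviewers, a minimal client-side sketch of how the relocated /query/stream endpoint can be consumed. It is illustrative only and not part of the patch: the host/port and the third-party `requests` package are assumptions, while the `X-API-Key` header and the NDJSON payload shape ({'response': ...} / {'error': ...}) follow the code above.

    # Illustrative client for POST /query/stream (not part of this patch).
    # Assumptions: server reachable at http://localhost:9621 (adjust to your
    # deployment), `requests` installed, and an API key configured on the
    # server (drop the header otherwise).
    import json
    import requests

    resp = requests.post(
        "http://localhost:9621/query/stream",
        json={"query": "What is LightRAG?", "mode": "hybrid"},
        headers={"X-API-Key": "your-api-key"},
        stream=True,
        timeout=120,
    )
    resp.raise_for_status()

    # The endpoint emits NDJSON: one JSON object per line, carrying either
    # a 'response' chunk or an 'error' message.
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        data = json.loads(line)
        if "error" in data:
            raise RuntimeError(data["error"])
        print(data["response"], end="", flush=True)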