From c0c87edc451068a023ec9d891aff12ef42596ba2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 03:26:39 +0800 Subject: [PATCH 01/14] split lightrag_servery.py to smaller files --- lightrag/api/lightrag_server.py | 969 +---------------------- lightrag/api/routers/__init__.py | 10 + lightrag/api/routers/document_routes.py | 667 ++++++++++++++++ lightrag/api/routers/graph_routes.py | 26 + lightrag/api/{ => routers}/ollama_api.py | 0 lightrag/api/routers/query_routes.py | 225 ++++++ lightrag/api/utils_api.py | 44 + 7 files changed, 1008 insertions(+), 933 deletions(-) create mode 100644 lightrag/api/routers/__init__.py create mode 100644 lightrag/api/routers/document_routes.py create mode 100644 lightrag/api/routers/graph_routes.py rename lightrag/api/{ => routers}/ollama_api.py (100%) create mode 100644 lightrag/api/routers/query_routes.py create mode 100644 lightrag/api/utils_api.py diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0cf1d01e..f7f70c62 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1,44 +1,43 @@ +""" +LightRAG FastAPI Server +""" + from fastapi import ( FastAPI, HTTPException, - File, - UploadFile, - BackgroundTasks, + Depends, ) import asyncio import threading import os -import json -import re from fastapi.staticfiles import StaticFiles import logging import argparse -from typing import List, Any, Literal, Optional, Dict -from pydantic import BaseModel, Field, field_validator +from typing import Optional, Dict from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception, ASCIIColors +import configparser +from ascii_colors import ASCIIColors import sys -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager -from starlette.status import HTTP_403_FORBIDDEN -import pipmaster as pm from dotenv import load_dotenv -import configparser -import traceback -from datetime import datetime -from lightrag import LightRAG, QueryParam -from lightrag.base import DocProcessingStatus, DocStatus +from .utils_api import get_api_key_dependency + +from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc from lightrag.utils import logger -from .ollama_api import OllamaAPI, ollama_server_infos +from .routers.document_routes import ( + DocumentManager, + create_document_routes, + run_scanning_process, +) +from .routers.query_routes import create_query_routes +from .routers.graph_routes import create_graph_routes +from .routers.ollama_api import OllamaAPI, ollama_server_infos # Load environment variables try: @@ -50,6 +49,9 @@ except Exception as e: config = configparser.ConfigParser() config.read("config.ini") +# Global configuration +global_top_k = 60 # default value + class DefaultRAGStorageConfig: KV_STORAGE = "JsonKVStorage" @@ -70,22 +72,6 @@ scan_progress: Dict = { # Lock for thread-safe operations progress_lock = threading.Lock() - -def estimate_tokens(text: str) -> int: - """Estimate the number of tokens in text - Chinese characters: approximately 1.5 tokens per character - English characters: approximately 0.25 tokens per character - """ - # Use regex to match Chinese and non-Chinese characters separately - chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) - non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) - - # Calculate estimated token count - tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 - - return int(tokens) - - def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), @@ -98,17 +84,17 @@ def get_default_host(binding_type: str) -> str: ) # fallback to ollama if unknown -def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any: +def get_env_value(env_key: str, default: any, value_type: type = str) -> any: """ Get value from environment variable with type conversion Args: env_key (str): Environment variable key - default (Any): Default value if env variable is not set + default (any): Default value if env variable is not set value_type (type): Type to convert the value to Returns: - Any: Converted value from environment or default + any: Converted value from environment or default """ value = os.getenv(env_key) if value is None: @@ -557,7 +543,7 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() - # conver relative path to absolute path + # convert relative path to absolute path args.working_dir = os.path.abspath(args.working_dir) args.input_dir = os.path.abspath(args.input_dir) @@ -566,293 +552,16 @@ def parse_args() -> argparse.Namespace: return args -class DocumentManager: - """Handles document operations and tracking""" - - def __init__( - self, - input_dir: str, - supported_extensions: tuple = ( - ".txt", - ".md", - ".pdf", - ".docx", - ".pptx", - ".xlsx", - ), - ): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory_for_new_files(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - logger.info(f"Scanning for {ext} files in {self.input_dir}") - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -class QueryRequest(BaseModel): - query: str = Field( - min_length=1, - description="The query text", - ) - - mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( - default="hybrid", - description="Query mode", - ) - - only_need_context: Optional[bool] = Field( - default=None, - description="If True, only returns the retrieved context without generating a response.", - ) - - only_need_prompt: Optional[bool] = Field( - default=None, - description="If True, only returns the generated prompt without producing a response.", - ) - - response_type: Optional[str] = Field( - min_length=1, - default=None, - description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", - ) - - top_k: Optional[int] = Field( - ge=1, - default=None, - description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", - ) - - max_token_for_text_unit: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allowed for each retrieved text chunk.", - ) - - max_token_for_global_context: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", - ) - - max_token_for_local_context: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allocated for entity descriptions in local retrieval.", - ) - - hl_keywords: Optional[List[str]] = Field( - default=None, - description="List of high-level keywords to prioritize in retrieval.", - ) - - ll_keywords: Optional[List[str]] = Field( - default=None, - description="List of low-level keywords to refine retrieval focus.", - ) - - conversation_history: Optional[List[dict[str, Any]]] = Field( - default=None, - description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", - ) - - history_turns: Optional[int] = Field( - ge=0, - default=None, - description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", - ) - - @field_validator("query", mode="after") - @classmethod - def query_strip_after(cls, query: str) -> str: - return query.strip() - - @field_validator("hl_keywords", mode="after") - @classmethod - def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: - if hl_keywords is None: - return None - return [keyword.strip() for keyword in hl_keywords] - - @field_validator("ll_keywords", mode="after") - @classmethod - def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: - if ll_keywords is None: - return None - return [keyword.strip() for keyword in ll_keywords] - - @field_validator("conversation_history", mode="after") - @classmethod - def conversation_history_role_check( - cls, conversation_history: List[dict[str, Any]] | None - ) -> List[dict[str, Any]] | None: - if conversation_history is None: - return None - for msg in conversation_history: - if "role" not in msg or msg["role"] not in {"user", "assistant"}: - raise ValueError( - "Each message must have a 'role' key with value 'user' or 'assistant'." - ) - return conversation_history - - def to_query_params(self, is_stream: bool) -> QueryParam: - """Converts a QueryRequest instance into a QueryParam instance.""" - # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically - request_data = self.model_dump(exclude_none=True, exclude={"query"}) - - # Ensure `mode` and `stream` are set explicitly - param = QueryParam(**request_data) - param.stream = is_stream - return param - - -class QueryResponse(BaseModel): - response: str = Field( - description="The generated response", - ) - - -class InsertTextRequest(BaseModel): - text: str = Field( - min_length=1, - description="The text to insert", - ) - - @field_validator("text", mode="after") - @classmethod - def strip_after(cls, text: str) -> str: - return text.strip() - - -class InsertTextsRequest(BaseModel): - texts: list[str] = Field( - min_length=1, - description="The texts to insert", - ) - - @field_validator("texts", mode="after") - @classmethod - def strip_after(cls, texts: list[str]) -> list[str]: - return [text.strip() for text in texts] - - -class InsertResponse(BaseModel): - status: str = Field(description="Status of the operation") - message: str = Field(description="Message describing the operation result") - - -class DocStatusResponse(BaseModel): - @staticmethod - def format_datetime(dt: Any) -> Optional[str]: - """Format datetime to ISO string - - Args: - dt: Datetime object or string - - Returns: - Formatted datetime string or None - """ - if dt is None: - return None - if isinstance(dt, str): - return dt - return dt.isoformat() - - """Response model for document status - - Attributes: - id: Document identifier - content_summary: Summary of document content - content_length: Length of document content - status: Current processing status - created_at: Creation timestamp (ISO format string) - updated_at: Last update timestamp (ISO format string) - chunks_count: Number of chunks (optional) - error: Error message if any (optional) - metadata: Additional metadata (optional) - """ - - id: str - content_summary: str - content_length: int - status: DocStatus - created_at: str - updated_at: str - chunks_count: Optional[int] = None - error: Optional[str] = None - metadata: Optional[dict[str, Any]] = None - - -class DocsStatusesResponse(BaseModel): - statuses: Dict[DocStatus, List[DocStatusResponse]] = {} - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth( - api_key_header_value: Optional[str] = Security(api_key_header), - ): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -# Global configuration -global_top_k = 60 # default value -temp_prefix = "__tmp_" # prefix for temporary files - - def create_app(args): + # Set global top_k + global global_top_k + global_top_k = args.top_k # save top_k from args + # Initialize verbose debug setting from lightrag.utils import set_verbose_debug set_verbose_debug(args.verbose) - global global_top_k - global_top_k = args.top_k # save top_k from args - # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", @@ -914,7 +623,7 @@ def create_app(args): scan_progress["indexed_count"] = 0 scan_progress["progress"] = 0 # Create background task - task = asyncio.create_task(run_scanning_process()) + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) app.state.background_tasks.add(task) task.add_done_callback(app.state.background_tasks.discard) ASCIIColors.info( @@ -922,7 +631,7 @@ def create_app(args): ) else: ASCIIColors.info( - "Skip document scanning(anohter scanning is active)" + "Skip document scanning(another scanning is active)" ) yield @@ -1130,621 +839,15 @@ def create_app(args): auto_manage_storages_states=False, ) - async def pipeline_enqueue_file(file_path: Path) -> bool: - """Add a file to the queue for processing - - Args: - file_path: Path to the saved file - Returns: - bool: True if the file was successfully enqueued, False otherwise - """ - try: - content = "" - ext = file_path.suffix.lower() - - file = None - async with aiofiles.open(file_path, "rb") as f: - file = await f.read() - - # Process based on file type - match ext: - case ".txt" | ".md": - content = file.decode("utf-8") - case ".pdf": - if not pm.is_installed("pypdf2"): - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO - - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - case ".docx": - if not pm.is_installed("docx"): - pm.install("docx") - from docx import Document - from io import BytesIO - - docx_file = BytesIO(file) - doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) - case ".pptx": - if not pm.is_installed("pptx"): - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO - - pptx_file = BytesIO(file) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - case ".xlsx": - if not pm.is_installed("openpyxl"): - pm.install("openpyxl") - from openpyxl import load_workbook # type: ignore - from io import BytesIO - - xlsx_file = BytesIO(file) - wb = load_workbook(xlsx_file) - for sheet in wb: - content += f"Sheet: {sheet.title}\n" - for row in sheet.iter_rows(values_only=True): - content += ( - "\t".join( - str(cell) if cell is not None else "" - for cell in row - ) - + "\n" - ) - content += "\n" - case _: - logging.error( - f"Unsupported file type: {file_path.name} (extension {ext})" - ) - return False - - # Insert into the RAG queue - if content: - await rag.apipeline_enqueue_documents(content) - logging.info( - f"Successfully fetched and enqueued file: {file_path.name}" - ) - return True - else: - logging.error( - f"No content could be extracted from file: {file_path.name}" - ) - - except Exception as e: - logging.error( - f"Error processing or enqueueing file {file_path.name}: {str(e)}" - ) - logging.error(traceback.format_exc()) - finally: - if file_path.name.startswith(temp_prefix): - # Clean up the temporary file after indexing - try: - file_path.unlink() - except Exception as e: - logging.error(f"Error deleting file {file_path}: {str(e)}") - return False - - async def pipeline_index_file(file_path: Path): - """Index a file - - Args: - file_path: Path to the saved file - """ - try: - if await pipeline_enqueue_file(file_path): - await rag.apipeline_process_enqueue_documents() - - except Exception as e: - logging.error(f"Error indexing file {file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) - - async def pipeline_index_files(file_paths: List[Path]): - """Index multiple files concurrently - - Args: - file_paths: Paths to the files to index - """ - if not file_paths: - return - try: - enqueued = False - - if len(file_paths) == 1: - enqueued = await pipeline_enqueue_file(file_paths[0]) - else: - tasks = [pipeline_enqueue_file(path) for path in file_paths] - enqueued = any(await asyncio.gather(*tasks)) - - if enqueued: - await rag.apipeline_process_enqueue_documents() - except Exception as e: - logging.error(f"Error indexing files: {str(e)}") - logging.error(traceback.format_exc()) - - async def pipeline_index_texts(texts: List[str]): - """Index a list of texts - - Args: - texts: The texts to index - """ - if not texts: - return - await rag.apipeline_enqueue_documents(texts) - await rag.apipeline_process_enqueue_documents() - - async def save_temp_file(file: UploadFile = File(...)) -> Path: - """Save the uploaded file to a temporary location - - Args: - file: The uploaded file - - Returns: - Path: The path to the saved file - """ - # Generate unique filename to avoid conflicts - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" - - # Create a temporary file to save the uploaded content - temp_path = doc_manager.input_dir / "temp" / unique_filename - temp_path.parent.mkdir(exist_ok=True) - - # Save the file - with open(temp_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - return temp_path - - async def run_scanning_process(): - """Background task to scan and index documents""" - global scan_progress - - try: - new_files = doc_manager.scan_directory_for_new_files() - scan_progress["total_files"] = len(new_files) - - logger.info(f"Found {len(new_files)} new files to index.") - for file_path in new_files: - try: - with progress_lock: - scan_progress["current_file"] = os.path.basename(file_path) - - await pipeline_index_file(file_path) - - with progress_lock: - scan_progress["indexed_count"] += 1 - scan_progress["progress"] = ( - scan_progress["indexed_count"] - / scan_progress["total_files"] - ) * 100 - - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - except Exception as e: - logging.error(f"Error during scanning process: {str(e)}") - finally: - with progress_lock: - scan_progress["is_scanning"] = False - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(background_tasks: BackgroundTasks): - """Trigger the scanning process""" - global scan_progress - - with progress_lock: - if scan_progress["is_scanning"]: - return {"status": "already_scanning"} - - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - - # Start the scanning process in the background - background_tasks.add_task(run_scanning_process) - - return {"status": "scanning_started"} - - @app.get("/documents/scan-progress") - async def get_scan_progress(): - """Get the current scanning progress""" - with progress_lock: - return scan_progress - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) - ): - """ - Endpoint for uploading a file to the input directory and indexing it. - - This API endpoint accepts a file through an HTTP POST request, checks if the - uploaded file is of a supported type, saves it in the specified input directory, - indexes it for retrieval, and returns a success status with relevant details. - - Parameters: - background_tasks: FastAPI BackgroundTasks for async processing - file (UploadFile): The file to be uploaded. It must have an allowed extension as per - `doc_manager.supported_extensions`. - - Returns: - dict: A dictionary containing the upload status ("success"), - a message detailing the operation result, and - the total number of indexed documents. - - Raises: - HTTPException: If the file type is not supported, it raises a 400 Bad Request error. - If any other exception occurs during the file handling or indexing, - it raises a 500 Internal Server Error with details about the exception. - """ - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Add to background tasks - background_tasks.add_task(pipeline_index_file, file_path) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text( - request: InsertTextRequest, background_tasks: BackgroundTasks - ): - """ - Insert text into the Retrieval-Augmented Generation (RAG) system. - - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextRequest): The request body containing the text to be inserted. - background_tasks: FastAPI BackgroundTasks for async processing - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. - """ - try: - background_tasks.add_task(pipeline_index_texts, [request.text]) - return InsertResponse( - status="success", - message="Text successfully received. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/texts", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_texts( - request: InsertTextsRequest, background_tasks: BackgroundTasks - ): - """ - Insert texts into the Retrieval-Augmented Generation (RAG) system. - - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextsRequest): The request body containing the text to be inserted. - background_tasks: FastAPI BackgroundTasks for async processing - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. - """ - try: - background_tasks.add_task(pipeline_index_texts, request.texts) - return InsertResponse( - status="success", - message="Text successfully received. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file( - background_tasks: BackgroundTasks, file: UploadFile = File(...) - ): - """Insert a file directly into the RAG system - - Args: - background_tasks: FastAPI BackgroundTasks for async processing - file: Uploaded file - - Returns: - InsertResponse: Status of the insertion operation - - Raises: - HTTPException: For unsupported file types or processing errors - """ - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - # Create a temporary file to save the uploaded content - temp_path = save_temp_file(file) - - # Add to background tasks - background_tasks.add_task(pipeline_index_file, temp_path) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' saved successfully. Processing will continue in background.", - ) - - except Exception as e: - logging.error(f"Error /documents/file: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file_batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch( - background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) - ): - """Process multiple files in batch mode - - Args: - background_tasks: FastAPI BackgroundTasks for async processing - files: List of files to process - - Returns: - InsertResponse: Status of the batch insertion operation - - Raises: - HTTPException: For processing errors - """ - try: - inserted_count = 0 - failed_files = [] - temp_files = [] - - for file in files: - if doc_manager.is_supported_file(file.filename): - # Create a temporary file to save the uploaded content - temp_files.append(save_temp_file(file)) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - - if temp_files: - background_tasks.add_task(pipeline_index_files, temp_files) - - # Prepare status message - if inserted_count == len(files): - status = "success" - status_message = f"Successfully inserted all {inserted_count} documents" - elif inserted_count > 0: - status = "partial_success" - status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - else: - status = "failure" - status_message = "No documents were successfully inserted" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse(status=status, message=status_message) - - except Exception as e: - logging.error(f"Error /documents/batch: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - """ - Clear all documents from the LightRAG system. - - This endpoint deletes all text chunks, entities vector database, and relationships vector database, - effectively clearing all documents from the LightRAG system. - - Returns: - InsertResponse: A response object containing the status, message, and the new document count (0 in this case). - """ - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", message="All documents cleared successfully" - ) - except Exception as e: - logging.error(f"Error DELETE /documents: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - """ - Handle a POST request at the /query endpoint to process user queries using RAG capabilities. - - Parameters: - request (QueryRequest): The request object containing the query parameters. - Returns: - QueryResponse: A Pydantic model containing the result of the query processing. - If a string is returned (e.g., cache hit), it's directly returned. - Otherwise, an async generator may be used to build the response. - - Raises: - HTTPException: Raised when an error occurs during the request handling process, - with status code 500 and detail containing the exception message. - """ - try: - response = await rag.aquery( - request.query, param=request.to_query_params(False) - ) - - # If response is a string (e.g. cache hit), return directly - if isinstance(response, str): - return QueryResponse(response=response) - - if isinstance(response, dict): - result = json.dumps(response, indent=2) - return QueryResponse(response=result) - else: - return QueryResponse(response=str(response)) - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - """ - This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. - - Args: - request (QueryRequest): The request object containing the query parameters. - optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None. - - Returns: - StreamingResponse: A streaming response containing the RAG query results. - """ - try: - response = await rag.aquery( - request.query, param=request.to_query_params(True) - ) - - from fastapi.responses import StreamingResponse - - async def stream_generator(): - if isinstance(response, str): - # If it's a string, send it all at once - yield f"{json.dumps({'response': response})}\n" - else: - # If it's an async generator, send chunks one by one - try: - async for chunk in response: - if chunk: # Only send non-empty content - yield f"{json.dumps({'response': chunk})}\n" - except Exception as e: - logging.error(f"Streaming error: {str(e)}") - yield f"{json.dumps({'error': str(e)})}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx - }, - ) - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - # query all graph labels - @app.get("/graph/label/list") - async def get_graph_labels(): - return await rag.get_graph_labels() - - # query all graph - @app.get("/graphs") - async def get_knowledge_graph(label: str): - return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + # Add routes + app.include_router(create_document_routes(rag, doc_manager, api_key)) + app.include_router(create_query_routes(rag, api_key, args.top_k)) + app.include_router(create_graph_routes(rag, api_key)) # Add Ollama API routes ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") - @app.get("/documents", dependencies=[Depends(optional_api_key)]) - async def documents() -> DocsStatusesResponse: - """ - Get documents statuses - Returns: - DocsStatusesResponse: A response object containing a dictionary where keys are DocStatus - and values are lists of DocStatusResponse objects representing documents in each status category. - """ - try: - statuses = ( - DocStatus.PENDING, - DocStatus.PROCESSING, - DocStatus.PROCESSED, - DocStatus.FAILED, - ) - - tasks = [rag.get_docs_by_status(status) for status in statuses] - results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) - - response = DocsStatusesResponse() - - for idx, result in enumerate(results): - status = statuses[idx] - for doc_id, doc_status in result.items(): - if status not in response.statuses: - response.statuses[status] = [] - response.statuses[status].append( - DocStatusResponse( - id=doc_id, - content_summary=doc_status.content_summary, - content_length=doc_status.content_length, - status=doc_status.status, - created_at=DocStatusResponse.format_datetime( - doc_status.created_at - ), - updated_at=DocStatusResponse.format_datetime( - doc_status.updated_at - ), - chunks_count=doc_status.chunks_count, - error=doc_status.error, - metadata=doc_status.metadata, - ) - ) - return response - except Exception as e: - logging.error(f"Error GET /documents: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" diff --git a/lightrag/api/routers/__init__.py b/lightrag/api/routers/__init__.py new file mode 100644 index 00000000..b71f204e --- /dev/null +++ b/lightrag/api/routers/__init__.py @@ -0,0 +1,10 @@ +""" +This module contains all the routers for the LightRAG API. +""" + +from .document_routes import router as document_router +from .query_routes import router as query_router +from .graph_routes import router as graph_router +from .ollama_api import OllamaAPI + +__all__ = ["document_router", "query_router", "graph_router", "OllamaAPI"] diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py new file mode 100644 index 00000000..48401658 --- /dev/null +++ b/lightrag/api/routers/document_routes.py @@ -0,0 +1,667 @@ +""" +This module contains all document-related routes for the LightRAG API. +""" + +import asyncio +import logging +import os +import aiofiles +import shutil +import traceback +import pipmaster as pm +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any + +from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile +from fastapi.security import APIKeyHeader +from pydantic import BaseModel, Field, field_validator +from starlette.status import HTTP_403_FORBIDDEN + +from lightrag.base import DocProcessingStatus, DocStatus +from ..utils_api import get_api_key_dependency + + +router = APIRouter(prefix="/documents", tags=["documents"]) + +# Global progress tracker +scan_progress: Dict = { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, +} + +# Lock for thread-safe operations +progress_lock = asyncio.Lock() + +# Temporary file prefix +temp_prefix = "__tmp__" + +class InsertTextRequest(BaseModel): + text: str = Field( + min_length=1, + description="The text to insert", + ) + + @field_validator("text", mode="after") + @classmethod + def strip_after(cls, text: str) -> str: + return text.strip() + +class InsertTextsRequest(BaseModel): + texts: list[str] = Field( + min_length=1, + description="The texts to insert", + ) + + @field_validator("texts", mode="after") + @classmethod + def strip_after(cls, texts: list[str]) -> list[str]: + return [text.strip() for text in texts] + +class InsertResponse(BaseModel): + status: str = Field(description="Status of the operation") + message: str = Field(description="Message describing the operation result") + +class DocStatusResponse(BaseModel): + @staticmethod + def format_datetime(dt: Any) -> Optional[str]: + if dt is None: + return None + if isinstance(dt, str): + return dt + return dt.isoformat() + + id: str + content_summary: str + content_length: int + status: DocStatus + created_at: str + updated_at: str + chunks_count: Optional[int] = None + error: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + +class DocsStatusesResponse(BaseModel): + statuses: Dict[DocStatus, List[DocStatusResponse]] = {} + +class DocumentManager: + def __init__( + self, + input_dir: str, + supported_extensions: tuple = ( + ".txt", + ".md", + ".pdf", + ".docx", + ".pptx", + ".xlsx", + ), + ): + self.input_dir = Path(input_dir) + self.supported_extensions = supported_extensions + self.indexed_files = set() + + # Create input directory if it doesn't exist + self.input_dir.mkdir(parents=True, exist_ok=True) + + def scan_directory_for_new_files(self) -> List[Path]: + new_files = [] + for ext in self.supported_extensions: + logging.info(f"Scanning for {ext} files in {self.input_dir}") + for file_path in self.input_dir.rglob(f"*{ext}"): + if file_path not in self.indexed_files: + new_files.append(file_path) + return new_files + + def scan_directory(self) -> List[Path]: + new_files = [] + for ext in self.supported_extensions: + for file_path in self.input_dir.rglob(f"*{ext}"): + new_files.append(file_path) + return new_files + + def mark_as_indexed(self, file_path: Path): + self.indexed_files.add(file_path) + + def is_supported_file(self, filename: str) -> bool: + return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + +async def pipeline_enqueue_file(rag, file_path: Path) -> bool: + try: + content = "" + ext = file_path.suffix.lower() + + file = None + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() + + # Process based on file type + match ext: + case ".txt" | ".md": + content = file.decode("utf-8") + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + case ".xlsx": + if not pm.is_installed("openpyxl"): + pm.install("openpyxl") + from openpyxl import load_workbook + from io import BytesIO + + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" + content += "\n" + case _: + logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") + return False + + # Insert into the RAG queue + if content: + await rag.apipeline_enqueue_documents(content) + logging.info(f"Successfully fetched and enqueued file: {file_path.name}") + return True + else: + logging.error(f"No content could be extracted from file: {file_path.name}") + + except Exception as e: + logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + finally: + if file_path.name.startswith(temp_prefix): + try: + file_path.unlink() + except Exception as e: + logging.error(f"Error deleting file {file_path}: {str(e)}") + return False + +async def pipeline_index_file(rag, file_path: Path): + """Index a file + + Args: + rag: LightRAG instance + file_path: Path to the saved file + """ + try: + content = "" + ext = file_path.suffix.lower() + + file = None + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() + + # Process based on file type + match ext: + case ".txt" | ".md": + content = file.decode("utf-8") + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + case ".xlsx": + if not pm.is_installed("openpyxl"): + pm.install("openpyxl") + from openpyxl import load_workbook + from io import BytesIO + + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" + content += "\n" + case _: + logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") + return + + # Insert into the RAG queue + if content: + await rag.apipeline_enqueue_documents(content) + await rag.apipeline_process_enqueue_documents() + logging.info(f"Successfully indexed file: {file_path.name}") + else: + logging.error(f"No content could be extracted from file: {file_path.name}") + + except Exception as e: + logging.error(f"Error indexing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + +async def pipeline_index_files(rag, file_paths: List[Path]): + if not file_paths: + return + try: + enqueued = False + if len(file_paths) == 1: + enqueued = await pipeline_enqueue_file(rag, file_paths[0]) + else: + tasks = [pipeline_enqueue_file(rag, path) for path in file_paths] + enqueued = any(await asyncio.gather(*tasks)) + + if enqueued: + await rag.apipeline_process_enqueue_documents() + except Exception as e: + logging.error(f"Error indexing files: {str(e)}") + logging.error(traceback.format_exc()) + +async def pipeline_index_texts(rag, texts: List[str]): + if not texts: + return + await rag.apipeline_enqueue_documents(texts) + await rag.apipeline_process_enqueue_documents() + +async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" + temp_path = input_dir / "temp" / unique_filename + temp_path.parent.mkdir(exist_ok=True) + with open(temp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + return temp_path + +async def run_scanning_process(rag, doc_manager: DocumentManager): + """Background task to scan and index documents""" + try: + new_files = doc_manager.scan_directory_for_new_files() + scan_progress["total_files"] = len(new_files) + + logging.info(f"Found {len(new_files)} new files to index.") + for file_path in new_files: + try: + async with progress_lock: + scan_progress["current_file"] = os.path.basename(file_path) + + await pipeline_index_file(rag, file_path) + + async with progress_lock: + scan_progress["indexed_count"] += 1 + scan_progress["progress"] = ( + scan_progress["indexed_count"] / scan_progress["total_files"] + ) * 100 + + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + except Exception as e: + logging.error(f"Error during scanning process: {str(e)}") + finally: + async with progress_lock: + scan_progress["is_scanning"] = False + +def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None): + optional_api_key = get_api_key_dependency(api_key) + + @router.post("/scan", dependencies=[Depends(optional_api_key)]) + async def scan_for_new_documents(background_tasks: BackgroundTasks): + """ + Trigger the scanning process for new documents. + + This endpoint initiates a background task that scans the input directory for new documents + and processes them. If a scanning process is already running, it returns a status indicating + that fact. + + Args: + background_tasks (BackgroundTasks): FastAPI background tasks handler + + Returns: + dict: A dictionary containing the scanning status + """ + async with progress_lock: + if scan_progress["is_scanning"]: + return {"status": "already_scanning"} + + scan_progress["is_scanning"] = True + scan_progress["indexed_count"] = 0 + scan_progress["progress"] = 0 + + background_tasks.add_task(run_scanning_process, rag, doc_manager) + return {"status": "scanning_started"} + + @router.get("/scan-progress") + async def get_scan_progress(): + """ + Get the current progress of the document scanning process. + + Returns: + dict: A dictionary containing the current scanning progress information including: + - is_scanning: Whether a scan is currently in progress + - current_file: The file currently being processed + - indexed_count: Number of files indexed so far + - total_files: Total number of files to process + - progress: Percentage of completion + """ + async with progress_lock: + return scan_progress + + @router.post("/upload", dependencies=[Depends(optional_api_key)]) + async def upload_to_input_dir( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): + """ + Upload a file to the input directory and index it. + + This API endpoint accepts a file through an HTTP POST request, checks if the + uploaded file is of a supported type, saves it in the specified input directory, + indexes it for retrieval, and returns a success status with relevant details. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file (UploadFile): The file to be uploaded. It must have an allowed extension. + + Returns: + InsertResponse: A response object containing the upload status and a message. + + Raises: + HTTPException: If the file type is not supported (400) or other errors occur (500). + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + file_path = doc_manager.input_dir / file.filename + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + background_tasks.add_task(pipeline_index_file, rag, file_path) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks): + """ + Insert text into the RAG system. + + This endpoint allows you to insert text data into the RAG system for later retrieval + and use in generating responses. + + Args: + request (InsertTextRequest): The request body containing the text to be inserted. + background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If an error occurs during text processing (500). + """ + try: + background_tasks.add_task(pipeline_index_texts, rag, [request.text]) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks): + """ + Insert multiple texts into the RAG system. + + This endpoint allows you to insert multiple text entries into the RAG system + in a single request. + + Args: + request (InsertTextsRequest): The request body containing the list of texts. + background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If an error occurs during text processing (500). + """ + try: + background_tasks.add_task(pipeline_index_texts, rag, request.texts) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)): + """ + Insert a file directly into the RAG system. + + This endpoint accepts a file upload and processes it for inclusion in the RAG system. + The file is saved temporarily and processed in the background. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file (UploadFile): The file to be processed + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If the file type is not supported (400) or other errors occur (500). + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + temp_path = await save_temp_file(doc_manager.input_dir, file) + background_tasks.add_task(pipeline_index_file, rag, temp_path) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' saved successfully. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/file: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)): + """ + Process multiple files in batch mode. + + This endpoint allows uploading and processing multiple files simultaneously. + It handles partial successes and provides detailed feedback about failed files. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + files (List[UploadFile]): List of files to process + + Returns: + InsertResponse: A response object containing: + - status: "success", "partial_success", or "failure" + - message: Detailed information about the operation results + + Raises: + HTTPException: If an error occurs during processing (500). + """ + try: + inserted_count = 0 + failed_files = [] + temp_files = [] + + for file in files: + if doc_manager.is_supported_file(file.filename): + temp_files.append(await save_temp_file(doc_manager.input_dir, file)) + inserted_count += 1 + else: + failed_files.append(f"{file.filename} (unsupported type)") + + if temp_files: + background_tasks.add_task(pipeline_index_files, rag, temp_files) + + if inserted_count == len(files): + status = "success" + status_message = f"Successfully inserted all {inserted_count} documents" + elif inserted_count > 0: + status = "partial_success" + status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + else: + status = "failure" + status_message = "No documents were successfully inserted" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + + return InsertResponse(status=status, message=status_message) + except Exception as e: + logging.error(f"Error /documents/batch: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + async def clear_documents(): + """ + Clear all documents from the RAG system. + + This endpoint deletes all text chunks, entities vector database, and relationships + vector database, effectively clearing all documents from the RAG system. + + Returns: + InsertResponse: A response object containing the status and message. + + Raises: + HTTPException: If an error occurs during the clearing process (500). + """ + try: + rag.text_chunks = [] + rag.entities_vdb = None + rag.relationships_vdb = None + return InsertResponse(status="success", message="All documents cleared successfully") + except Exception as e: + logging.error(f"Error DELETE /documents: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("", dependencies=[Depends(optional_api_key)]) + async def documents() -> DocsStatusesResponse: + """ + Get the status of all documents in the system. + + This endpoint retrieves the current status of all documents, grouped by their + processing status (PENDING, PROCESSING, PROCESSED, FAILED). + + Returns: + DocsStatusesResponse: A response object containing a dictionary where keys are + DocStatus values and values are lists of DocStatusResponse + objects representing documents in each status category. + + Raises: + HTTPException: If an error occurs while retrieving document statuses (500). + """ + try: + statuses = ( + DocStatus.PENDING, + DocStatus.PROCESSING, + DocStatus.PROCESSED, + DocStatus.FAILED, + ) + + tasks = [rag.get_docs_by_status(status) for status in statuses] + results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) + + response = DocsStatusesResponse() + + for idx, result in enumerate(results): + status = statuses[idx] + for doc_id, doc_status in result.items(): + if status not in response.statuses: + response.statuses[status] = [] + response.statuses[status].append( + DocStatusResponse( + id=doc_id, + content_summary=doc_status.content_summary, + content_length=doc_status.content_length, + status=doc_status.status, + created_at=DocStatusResponse.format_datetime(doc_status.created_at), + updated_at=DocStatusResponse.format_datetime(doc_status.updated_at), + chunks_count=doc_status.chunks_count, + error=doc_status.error, + metadata=doc_status.metadata, + ) + ) + return response + except Exception as e: + logging.error(f"Error GET /documents: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py new file mode 100644 index 00000000..1d08d9ae --- /dev/null +++ b/lightrag/api/routers/graph_routes.py @@ -0,0 +1,26 @@ +""" +This module contains all graph-related routes for the LightRAG API. +""" + +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException + +from ..utils_api import get_api_key_dependency + +router = APIRouter(tags=["graph"]) + +def create_graph_routes(rag, api_key: Optional[str] = None): + optional_api_key = get_api_key_dependency(api_key) + + @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) + async def get_graph_labels(): + """Get all graph labels""" + return await rag.get_graph_labels() + + @router.get("/graphs", dependencies=[Depends(optional_api_key)]) + async def get_knowledge_graph(label: str): + """Get knowledge graph for a specific label""" + return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + + return router diff --git a/lightrag/api/ollama_api.py b/lightrag/api/routers/ollama_api.py similarity index 100% rename from lightrag/api/ollama_api.py rename to lightrag/api/routers/ollama_api.py diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py new file mode 100644 index 00000000..389d127d --- /dev/null +++ b/lightrag/api/routers/query_routes.py @@ -0,0 +1,225 @@ +""" +This module contains all query-related routes for the LightRAG API. +""" + +import json +import logging +import traceback +from typing import Any, Dict, List, Literal, Optional + +from fastapi import APIRouter, Depends, HTTPException +from lightrag.base import QueryParam +from ..utils_api import get_api_key_dependency +from pydantic import BaseModel, Field, field_validator + +from ascii_colors import trace_exception + +router = APIRouter(tags=["query"]) + +class QueryRequest(BaseModel): + query: str = Field( + min_length=1, + description="The query text", + ) + + mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( + default="hybrid", + description="Query mode", + ) + + only_need_context: Optional[bool] = Field( + default=None, + description="If True, only returns the retrieved context without generating a response.", + ) + + only_need_prompt: Optional[bool] = Field( + default=None, + description="If True, only returns the generated prompt without producing a response.", + ) + + response_type: Optional[str] = Field( + min_length=1, + default=None, + description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", + ) + + top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", + ) + + max_token_for_text_unit: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allowed for each retrieved text chunk.", + ) + + max_token_for_global_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", + ) + + max_token_for_local_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for entity descriptions in local retrieval.", + ) + + hl_keywords: Optional[List[str]] = Field( + default=None, + description="List of high-level keywords to prioritize in retrieval.", + ) + + ll_keywords: Optional[List[str]] = Field( + default=None, + description="List of low-level keywords to refine retrieval focus.", + ) + + conversation_history: Optional[List[Dict[str, Any]]] = Field( + default=None, + description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", + ) + + history_turns: Optional[int] = Field( + ge=0, + default=None, + description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", + ) + + @field_validator("query", mode="after") + @classmethod + def query_strip_after(cls, query: str) -> str: + return query.strip() + + @field_validator("hl_keywords", mode="after") + @classmethod + def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: + if hl_keywords is None: + return None + return [keyword.strip() for keyword in hl_keywords] + + @field_validator("ll_keywords", mode="after") + @classmethod + def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: + if ll_keywords is None: + return None + return [keyword.strip() for keyword in ll_keywords] + + @field_validator("conversation_history", mode="after") + @classmethod + def conversation_history_role_check( + cls, conversation_history: List[Dict[str, Any]] | None + ) -> List[Dict[str, Any]] | None: + if conversation_history is None: + return None + for msg in conversation_history: + if "role" not in msg or msg["role"] not in {"user", "assistant"}: + raise ValueError( + "Each message must have a 'role' key with value 'user' or 'assistant'." + ) + return conversation_history + + def to_query_params(self, is_stream: bool) -> "QueryParam": + """Converts a QueryRequest instance into a QueryParam instance.""" + # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically + request_data = self.model_dump(exclude_none=True, exclude={"query"}) + + # Ensure `mode` and `stream` are set explicitly + param = QueryParam(**request_data) + param.stream = is_stream + return param + +class QueryResponse(BaseModel): + response: str = Field( + description="The generated response", + ) + +def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): + optional_api_key = get_api_key_dependency(api_key) + + @router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) + async def query_text(request: QueryRequest): + """ + Handle a POST request at the /query endpoint to process user queries using RAG capabilities. + + Parameters: + request (QueryRequest): The request object containing the query parameters. + Returns: + QueryResponse: A Pydantic model containing the result of the query processing. + If a string is returned (e.g., cache hit), it's directly returned. + Otherwise, an async generator may be used to build the response. + + Raises: + HTTPException: Raised when an error occurs during the request handling process, + with status code 500 and detail containing the exception message. + """ + try: + param = request.to_query_params(False) + if param.top_k is None: + param.top_k = top_k + response = await rag.aquery(request.query, param=param) + + # If response is a string (e.g. cache hit), return directly + if isinstance(response, str): + return QueryResponse(response=response) + + if isinstance(response, dict): + result = json.dumps(response, indent=2) + return QueryResponse(response=result) + else: + return QueryResponse(response=str(response)) + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/query/stream", dependencies=[Depends(optional_api_key)]) + async def query_text_stream(request: QueryRequest): + """ + This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. + + Args: + request (QueryRequest): The request object containing the query parameters. + optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None. + + Returns: + StreamingResponse: A streaming response containing the RAG query results. + """ + try: + param = request.to_query_params(True) + if param.top_k is None: + param.top_k = top_k + response = await rag.aquery(request.query, param=param) + + from fastapi.responses import StreamingResponse + + async def stream_generator(): + if isinstance(response, str): + # If it's a string, send it all at once + yield f"{json.dumps({'response': response})}\n" + else: + # If it's an async generator, send chunks one by one + try: + async for chunk in response: + if chunk: # Only send non-empty content + yield f"{json.dumps({'response': chunk})}\n" + except Exception as e: + logging.error(f"Streaming error: {str(e)}") + yield f"{json.dumps({'error': str(e)})}\n" + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx + }, + ) + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py new file mode 100644 index 00000000..c5f1034a --- /dev/null +++ b/lightrag/api/utils_api.py @@ -0,0 +1,44 @@ +""" +Utility functions for the LightRAG API. +""" + +from typing import Optional +from fastapi import HTTPException, Security +from fastapi.security import APIKeyHeader +from starlette.status import HTTP_403_FORBIDDEN + +def get_api_key_dependency(api_key: Optional[str]): + """ + Create an API key dependency for route protection. + + Args: + api_key (Optional[str]): The API key to validate against. + If None, no authentication is required. + + Returns: + Callable: A dependency function that validates the API key. + """ + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth( + api_key_header_value: Optional[str] = Security(api_key_header), + ): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth From a64ba7b4da5d0934af44dd9b670ee380eb657881 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 03:45:23 +0800 Subject: [PATCH 02/14] Update dependencies and clean up requirements.txt - Added 'future' package - Removed 'networkx' package - Cleaned up commented sections - Maintained core dependencies - Simplified requirements structure --- requirements.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 03d93aa3..1c7c16cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,7 @@ +future aiohttp configparser -# database packages -networkx - # Basic modules numpy pipmaster From 36de7e3197ed881a0e410fcacb8b88a0ec1dd608 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 03:46:36 +0800 Subject: [PATCH 03/14] Add test_* pattern to .gitignore for unit test files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index e8130e18..9cb71979 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,6 @@ dickens/ book.txt lightrag-dev/ gui/ + +# unit-test files +test_* \ No newline at end of file From 3c080a9ebfaf1927394981f555c65c72a4698bc3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 04:04:54 +0800 Subject: [PATCH 04/14] Enhance webui mounting with root endpoint and directory check. - Added FileResponse for webui root endpoint - Enabled directory check in StaticFiles mount - Improved webui static file handling - Ensured webui directory existence - Simplified webui access with root endpoint --- lightrag/api/lightrag_server.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f7f70c62..4083e790 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -7,6 +7,7 @@ from fastapi import ( HTTPException, Depends, ) +from fastapi.responses import FileResponse import asyncio import threading import os @@ -875,7 +876,11 @@ def create_app(args): # Webui mount webui/index.html static_dir = Path(__file__).parent / "webui" static_dir.mkdir(exist_ok=True) - app.mount("/webui", StaticFiles(directory=static_dir, html=True), name="webui") + app.mount("/webui", StaticFiles(directory=static_dir, html=True, check_dir=True), name="webui") + + @app.get("/webui/") + async def webui_root(): + return FileResponse(static_dir / "index.html") return app From f776db07799abc4897e2c7122dcc67c6aaf67e0f Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 04:09:02 +0800 Subject: [PATCH 05/14] Improved document status retrieval with content fallback. - Added content fallback to content_summary - Handled missing fields gracefully - Made data copy to avoid modification - Added error logging for missing fields - Improved code readability and robustness --- lightrag/kg/json_doc_status_impl.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 1a05abc2..76b7158b 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -48,11 +48,20 @@ class JsonDocStatusStorage(DocStatusStorage): self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == status.value - } + result = {} + for k, v in self._data.items(): + if v["status"] == status.value: + try: + # Make a copy of the data to avoid modifying the original + data = v.copy() + # If content is missing, use content_summary as content + if "content" not in data and "content_summary" in data: + data["content"] = data["content_summary"] + result[k] = DocProcessingStatus(**data) + except KeyError as e: + logger.error(f"Missing required field for document {k}: {e}") + continue + return result async def index_done_callback(self) -> None: write_json(self._data, self._file_name) From a8abcf14acef2d3758af0d64f10c56027b2638b0 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 04:12:21 +0800 Subject: [PATCH 06/14] Fit linting --- .gitignore | 2 +- lightrag/api/lightrag_server.py | 16 ++-- lightrag/api/routers/document_routes.py | 97 +++++++++++++++++++------ lightrag/api/routers/graph_routes.py | 3 +- lightrag/api/routers/query_routes.py | 8 +- lightrag/api/utils_api.py | 1 + requirements.txt | 2 +- 7 files changed, 98 insertions(+), 31 deletions(-) diff --git a/.gitignore b/.gitignore index 9cb71979..3eb55bd3 100644 --- a/.gitignore +++ b/.gitignore @@ -62,4 +62,4 @@ lightrag-dev/ gui/ # unit-test files -test_* \ No newline at end of file +test_* diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 4083e790..eff927c2 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -4,7 +4,6 @@ LightRAG FastAPI Server from fastapi import ( FastAPI, - HTTPException, Depends, ) from fastapi.responses import FileResponse @@ -14,7 +13,7 @@ import os from fastapi.staticfiles import StaticFiles import logging import argparse -from typing import Optional, Dict +from typing import Dict from pathlib import Path import configparser from ascii_colors import ASCIIColors @@ -73,6 +72,7 @@ scan_progress: Dict = { # Lock for thread-safe operations progress_lock = threading.Lock() + def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), @@ -624,7 +624,9 @@ def create_app(args): scan_progress["indexed_count"] = 0 scan_progress["progress"] = 0 # Create background task - task = asyncio.create_task(run_scanning_process(rag, doc_manager)) + task = asyncio.create_task( + run_scanning_process(rag, doc_manager) + ) app.state.background_tasks.add(task) task.add_done_callback(app.state.background_tasks.discard) ASCIIColors.info( @@ -876,8 +878,12 @@ def create_app(args): # Webui mount webui/index.html static_dir = Path(__file__).parent / "webui" static_dir.mkdir(exist_ok=True) - app.mount("/webui", StaticFiles(directory=static_dir, html=True, check_dir=True), name="webui") - + app.mount( + "/webui", + StaticFiles(directory=static_dir, html=True, check_dir=True), + name="webui", + ) + @app.get("/webui/") async def webui_root(): return FileResponse(static_dir / "index.html") diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 48401658..ba4f8b0a 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -14,9 +14,7 @@ from pathlib import Path from typing import Dict, List, Optional, Any from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile -from fastapi.security import APIKeyHeader from pydantic import BaseModel, Field, field_validator -from starlette.status import HTTP_403_FORBIDDEN from lightrag.base import DocProcessingStatus, DocStatus from ..utils_api import get_api_key_dependency @@ -39,6 +37,7 @@ progress_lock = asyncio.Lock() # Temporary file prefix temp_prefix = "__tmp__" + class InsertTextRequest(BaseModel): text: str = Field( min_length=1, @@ -50,6 +49,7 @@ class InsertTextRequest(BaseModel): def strip_after(cls, text: str) -> str: return text.strip() + class InsertTextsRequest(BaseModel): texts: list[str] = Field( min_length=1, @@ -61,10 +61,12 @@ class InsertTextsRequest(BaseModel): def strip_after(cls, texts: list[str]) -> list[str]: return [text.strip() for text in texts] + class InsertResponse(BaseModel): status: str = Field(description="Status of the operation") message: str = Field(description="Message describing the operation result") + class DocStatusResponse(BaseModel): @staticmethod def format_datetime(dt: Any) -> Optional[str]: @@ -84,9 +86,11 @@ class DocStatusResponse(BaseModel): error: Optional[str] = None metadata: Optional[dict[str, Any]] = None + class DocsStatusesResponse(BaseModel): statuses: Dict[DocStatus, List[DocStatusResponse]] = {} + class DocumentManager: def __init__( self, @@ -129,6 +133,7 @@ class DocumentManager: def is_supported_file(self, filename: str) -> bool: return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + async def pipeline_enqueue_file(rag, file_path: Path) -> bool: try: content = "" @@ -145,7 +150,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool: case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore + from PyPDF2 import PdfReader # type: ignore from io import BytesIO pdf_file = BytesIO(file) @@ -184,10 +189,17 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool: for sheet in wb: content += f"Sheet: {sheet.title}\n" for row in sheet.iter_rows(values_only=True): - content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" + content += ( + "\t".join( + str(cell) if cell is not None else "" for cell in row + ) + + "\n" + ) content += "\n" case _: - logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") + logging.error( + f"Unsupported file type: {file_path.name} (extension {ext})" + ) return False # Insert into the RAG queue @@ -209,6 +221,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool: logging.error(f"Error deleting file {file_path}: {str(e)}") return False + async def pipeline_index_file(rag, file_path: Path): """Index a file @@ -231,7 +244,7 @@ async def pipeline_index_file(rag, file_path: Path): case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore + from PyPDF2 import PdfReader # type: ignore from io import BytesIO pdf_file = BytesIO(file) @@ -270,10 +283,17 @@ async def pipeline_index_file(rag, file_path: Path): for sheet in wb: content += f"Sheet: {sheet.title}\n" for row in sheet.iter_rows(values_only=True): - content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" + content += ( + "\t".join( + str(cell) if cell is not None else "" for cell in row + ) + + "\n" + ) content += "\n" case _: - logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") + logging.error( + f"Unsupported file type: {file_path.name} (extension {ext})" + ) return # Insert into the RAG queue @@ -288,6 +308,7 @@ async def pipeline_index_file(rag, file_path: Path): logging.error(f"Error indexing file {file_path.name}: {str(e)}") logging.error(traceback.format_exc()) + async def pipeline_index_files(rag, file_paths: List[Path]): if not file_paths: return @@ -305,12 +326,14 @@ async def pipeline_index_files(rag, file_paths: List[Path]): logging.error(f"Error indexing files: {str(e)}") logging.error(traceback.format_exc()) + async def pipeline_index_texts(rag, texts: List[str]): if not texts: return await rag.apipeline_enqueue_documents(texts) await rag.apipeline_process_enqueue_documents() + async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" @@ -320,6 +343,7 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: shutil.copyfileobj(file.file, buffer) return temp_path + async def run_scanning_process(rag, doc_manager: DocumentManager): """Background task to scan and index documents""" try: @@ -349,7 +373,10 @@ async def run_scanning_process(rag, doc_manager: DocumentManager): async with progress_lock: scan_progress["is_scanning"] = False -def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None): + +def create_document_routes( + rag, doc_manager: DocumentManager, api_key: Optional[str] = None +): optional_api_key = get_api_key_dependency(api_key) @router.post("/scan", dependencies=[Depends(optional_api_key)]) @@ -437,8 +464,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) - async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks): + @router.post( + "/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + ) + async def insert_text( + request: InsertTextRequest, background_tasks: BackgroundTasks + ): """ Insert text into the RAG system. @@ -466,8 +497,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) - async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks): + @router.post( + "/texts", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_texts( + request: InsertTextsRequest, background_tasks: BackgroundTasks + ): """ Insert multiple texts into the RAG system. @@ -495,8 +532,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) - async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)): + @router.post( + "/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + ) + async def insert_file( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): """ Insert a file directly into the RAG system. @@ -532,8 +573,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) - async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)): + @router.post( + "/file_batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_batch( + background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) + ): """ Process multiple files in batch mode. @@ -587,7 +634,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @router.delete( + "", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + ) async def clear_documents(): """ Clear all documents from the RAG system. @@ -605,7 +654,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ rag.text_chunks = [] rag.entities_vdb = None rag.relationships_vdb = None - return InsertResponse(status="success", message="All documents cleared successfully") + return InsertResponse( + status="success", message="All documents cleared successfully" + ) except Exception as e: logging.error(f"Error DELETE /documents: {str(e)}") logging.error(traceback.format_exc()) @@ -651,8 +702,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[ content_summary=doc_status.content_summary, content_length=doc_status.content_length, status=doc_status.status, - created_at=DocStatusResponse.format_datetime(doc_status.created_at), - updated_at=DocStatusResponse.format_datetime(doc_status.updated_at), + created_at=DocStatusResponse.format_datetime( + doc_status.created_at + ), + updated_at=DocStatusResponse.format_datetime( + doc_status.updated_at + ), chunks_count=doc_status.chunks_count, error=doc_status.error, metadata=doc_status.metadata, diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index 1d08d9ae..bfdb838c 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -4,12 +4,13 @@ This module contains all graph-related routes for the LightRAG API. from typing import Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from ..utils_api import get_api_key_dependency router = APIRouter(tags=["graph"]) + def create_graph_routes(rag, api_key: Optional[str] = None): optional_api_key = get_api_key_dependency(api_key) diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 389d127d..b86c170e 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -4,7 +4,6 @@ This module contains all query-related routes for the LightRAG API. import json import logging -import traceback from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, HTTPException @@ -16,6 +15,7 @@ from ascii_colors import trace_exception router = APIRouter(tags=["query"]) + class QueryRequest(BaseModel): query: str = Field( min_length=1, @@ -131,15 +131,19 @@ class QueryRequest(BaseModel): param.stream = is_stream return param + class QueryResponse(BaseModel): response: str = Field( description="The generated response", ) + def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): optional_api_key = get_api_key_dependency(api_key) - @router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) + @router.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) async def query_text(request: QueryRequest): """ Handle a POST request at the /query endpoint to process user queries using RAG capabilities. diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index c5f1034a..aef3c128 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -7,6 +7,7 @@ from fastapi import HTTPException, Security from fastapi.security import APIKeyHeader from starlette.status import HTTP_403_FORBIDDEN + def get_api_key_dependency(api_key: Optional[str]): """ Create an API key dependency for route protection. diff --git a/requirements.txt b/requirements.txt index 1c7c16cd..a1a1157e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -future aiohttp configparser +future # Basic modules numpy From 5d884f6d3e88aa1e80ed480e257e156bb7b20191 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 12:12:20 +0800 Subject: [PATCH 07/14] refactor: centralize configuration and utility functions - Move parse_args and display_splash_screen functions from lightrag_server.py to utils_api.py - Move OllamaServerInfos class and instance from ollama_api.py to utils_api.py --- lightrag/api/lightrag_server.py | 499 +--------------------------- lightrag/api/routers/ollama_api.py | 20 +- lightrag/api/utils_api.py | 509 +++++++++++++++++++++++++++++ 3 files changed, 517 insertions(+), 511 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index eff927c2..f6cee412 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -12,17 +12,20 @@ import threading import os from fastapi.staticfiles import StaticFiles import logging -import argparse from typing import Dict from pathlib import Path import configparser from ascii_colors import ASCIIColors -import sys from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from dotenv import load_dotenv -from .utils_api import get_api_key_dependency +from .utils_api import ( + get_api_key_dependency, + parse_args, + get_default_host, + display_splash_screen, +) from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat @@ -37,7 +40,7 @@ from .routers.document_routes import ( ) from .routers.query_routes import create_query_routes from .routers.graph_routes import create_graph_routes -from .routers.ollama_api import OllamaAPI, ollama_server_infos +from .routers.ollama_api import OllamaAPI # Load environment variables try: @@ -52,14 +55,6 @@ config.read("config.ini") # Global configuration global_top_k = 60 # default value - -class DefaultRAGStorageConfig: - KV_STORAGE = "JsonKVStorage" - VECTOR_STORAGE = "NanoVectorDBStorage" - GRAPH_STORAGE = "NetworkXStorage" - DOC_STATUS_STORAGE = "JsonDocStatusStorage" - - # Global progress tracker scan_progress: Dict = { "is_scanning": False, @@ -73,486 +68,6 @@ scan_progress: Dict = { progress_lock = threading.Lock() -def get_default_host(binding_type: str) -> str: - default_hosts = { - "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), - "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"), - "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"), - "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"), - } - return default_hosts.get( - binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") - ) # fallback to ollama if unknown - - -def get_env_value(env_key: str, default: any, value_type: type = str) -> any: - """ - Get value from environment variable with type conversion - - Args: - env_key (str): Environment variable key - default (any): Default value if env variable is not set - value_type (type): Type to convert the value to - - Returns: - any: Converted value from environment or default - """ - value = os.getenv(env_key) - if value is None: - return default - - if value_type is bool: - return value.lower() in ("true", "1", "yes", "t", "on") - try: - return value_type(value) - except ValueError: - return default - - -def display_splash_screen(args: argparse.Namespace) -> None: - """ - Display a colorful splash screen showing LightRAG server configuration - - Args: - args: Parsed command line arguments - """ - # Banner - ASCIIColors.cyan(f""" - ╔══════════════════════════════════════════════════════════════╗ - ║ 🚀 LightRAG Server v{__api_version__} ║ - ║ Fast, Lightweight RAG Server Implementation ║ - ╚══════════════════════════════════════════════════════════════╝ - """) - - # Server Configuration - ASCIIColors.magenta("\n📡 Server Configuration:") - ASCIIColors.white(" ├─ Host: ", end="") - ASCIIColors.yellow(f"{args.host}") - ASCIIColors.white(" ├─ Port: ", end="") - ASCIIColors.yellow(f"{args.port}") - ASCIIColors.white(" ├─ CORS Origins: ", end="") - ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") - ASCIIColors.white(" ├─ SSL Enabled: ", end="") - ASCIIColors.yellow(f"{args.ssl}") - ASCIIColors.white(" └─ API Key: ", end="") - ASCIIColors.yellow("Set" if args.key else "Not Set") - if args.ssl: - ASCIIColors.white(" ├─ SSL Cert: ", end="") - ASCIIColors.yellow(f"{args.ssl_certfile}") - ASCIIColors.white(" └─ SSL Key: ", end="") - ASCIIColors.yellow(f"{args.ssl_keyfile}") - - # Directory Configuration - ASCIIColors.magenta("\n📂 Directory Configuration:") - ASCIIColors.white(" ├─ Working Directory: ", end="") - ASCIIColors.yellow(f"{args.working_dir}") - ASCIIColors.white(" └─ Input Directory: ", end="") - ASCIIColors.yellow(f"{args.input_dir}") - - # LLM Configuration - ASCIIColors.magenta("\n🤖 LLM Configuration:") - ASCIIColors.white(" ├─ Binding: ", end="") - ASCIIColors.yellow(f"{args.llm_binding}") - ASCIIColors.white(" ├─ Host: ", end="") - ASCIIColors.yellow(f"{args.llm_binding_host}") - ASCIIColors.white(" └─ Model: ", end="") - ASCIIColors.yellow(f"{args.llm_model}") - - # Embedding Configuration - ASCIIColors.magenta("\n📊 Embedding Configuration:") - ASCIIColors.white(" ├─ Binding: ", end="") - ASCIIColors.yellow(f"{args.embedding_binding}") - ASCIIColors.white(" ├─ Host: ", end="") - ASCIIColors.yellow(f"{args.embedding_binding_host}") - ASCIIColors.white(" ├─ Model: ", end="") - ASCIIColors.yellow(f"{args.embedding_model}") - ASCIIColors.white(" └─ Dimensions: ", end="") - ASCIIColors.yellow(f"{args.embedding_dim}") - - # RAG Configuration - ASCIIColors.magenta("\n⚙️ RAG Configuration:") - ASCIIColors.white(" ├─ Max Async Operations: ", end="") - ASCIIColors.yellow(f"{args.max_async}") - ASCIIColors.white(" ├─ Max Tokens: ", end="") - ASCIIColors.yellow(f"{args.max_tokens}") - ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") - ASCIIColors.yellow(f"{args.max_embed_tokens}") - ASCIIColors.white(" ├─ Chunk Size: ", end="") - ASCIIColors.yellow(f"{args.chunk_size}") - ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="") - ASCIIColors.yellow(f"{args.chunk_overlap_size}") - ASCIIColors.white(" ├─ History Turns: ", end="") - ASCIIColors.yellow(f"{args.history_turns}") - ASCIIColors.white(" ├─ Cosine Threshold: ", end="") - ASCIIColors.yellow(f"{args.cosine_threshold}") - ASCIIColors.white(" └─ Top-K: ", end="") - ASCIIColors.yellow(f"{args.top_k}") - - # System Configuration - ASCIIColors.magenta("\n💾 Storage Configuration:") - ASCIIColors.white(" ├─ KV Storage: ", end="") - ASCIIColors.yellow(f"{args.kv_storage}") - ASCIIColors.white(" ├─ Vector Storage: ", end="") - ASCIIColors.yellow(f"{args.vector_storage}") - ASCIIColors.white(" ├─ Graph Storage: ", end="") - ASCIIColors.yellow(f"{args.graph_storage}") - ASCIIColors.white(" └─ Document Status Storage: ", end="") - ASCIIColors.yellow(f"{args.doc_status_storage}") - - ASCIIColors.magenta("\n🛠️ System Configuration:") - ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") - ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") - ASCIIColors.white(" ├─ Log Level: ", end="") - ASCIIColors.yellow(f"{args.log_level}") - ASCIIColors.white(" ├─ Verbose Debug: ", end="") - ASCIIColors.yellow(f"{args.verbose}") - ASCIIColors.white(" └─ Timeout: ", end="") - ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") - - # Server Status - ASCIIColors.green("\n✨ Server starting up...\n") - - # Server Access Information - protocol = "https" if args.ssl else "http" - if args.host == "0.0.0.0": - ASCIIColors.magenta("\n🌐 Server Access Information:") - ASCIIColors.white(" ├─ Local Access: ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}") - ASCIIColors.white(" ├─ Remote Access: ", end="") - ASCIIColors.yellow(f"{protocol}://:{args.port}") - ASCIIColors.white(" ├─ API Documentation (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs") - ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc") - ASCIIColors.white(" └─ WebUI (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui") - - ASCIIColors.yellow("\n📝 Note:") - ASCIIColors.white(""" Since the server is running on 0.0.0.0: - - Use 'localhost' or '127.0.0.1' for local access - - Use your machine's IP address for remote access - - To find your IP address: - • Windows: Run 'ipconfig' in terminal - • Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal - """) - else: - base_url = f"{protocol}://{args.host}:{args.port}" - ASCIIColors.magenta("\n🌐 Server Access Information:") - ASCIIColors.white(" ├─ Base URL: ", end="") - ASCIIColors.yellow(f"{base_url}") - ASCIIColors.white(" ├─ API Documentation: ", end="") - ASCIIColors.yellow(f"{base_url}/docs") - ASCIIColors.white(" └─ Alternative Documentation: ", end="") - ASCIIColors.yellow(f"{base_url}/redoc") - - # Usage Examples - ASCIIColors.magenta("\n📚 Quick Start Guide:") - ASCIIColors.cyan(""" - 1. Access the Swagger UI: - Open your browser and navigate to the API documentation URL above - - 2. API Authentication:""") - if args.key: - ASCIIColors.cyan(""" Add the following header to your requests: - X-API-Key: - """) - else: - ASCIIColors.cyan(" No authentication required\n") - - ASCIIColors.cyan(""" 3. Basic Operations: - - POST /upload_document: Upload new documents to RAG - - POST /query: Query your document collection - - GET /collections: List available collections - - 4. Monitor the server: - - Check server logs for detailed operation information - - Use healthcheck endpoint: GET /health - """) - - # Security Notice - if args.key: - ASCIIColors.yellow("\n⚠️ Security Notice:") - ASCIIColors.white(""" API Key authentication is enabled. - Make sure to include the X-API-Key header in all your requests. - """) - - ASCIIColors.green("Server is ready to accept connections! 🚀\n") - - # Ensure splash output flush to system log - sys.stdout.flush() - - -def parse_args() -> argparse.Namespace: - """ - Parse command line arguments with environment variable fallback - - Returns: - argparse.Namespace: Parsed arguments - """ - - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) - - parser.add_argument( - "--kv-storage", - default=get_env_value( - "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE - ), - help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})", - ) - parser.add_argument( - "--doc-status-storage", - default=get_env_value( - "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE - ), - help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", - ) - parser.add_argument( - "--graph-storage", - default=get_env_value( - "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE - ), - help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", - ) - parser.add_argument( - "--vector-storage", - default=get_env_value( - "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE - ), - help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", - ) - - # Bindings configuration - parser.add_argument( - "--llm-binding", - default=get_env_value("LLM_BINDING", "ollama"), - help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", - ) - parser.add_argument( - "--embedding-binding", - default=get_env_value("EMBEDDING_BINDING", "ollama"), - help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", - ) - - # Server configuration - parser.add_argument( - "--host", - default=get_env_value("HOST", "0.0.0.0"), - help="Server host (default: from env or 0.0.0.0)", - ) - parser.add_argument( - "--port", - type=int, - default=get_env_value("PORT", 9621, int), - help="Server port (default: from env or 9621)", - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default=get_env_value("WORKING_DIR", "./rag_storage"), - help="Working directory for RAG storage (default: from env or ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default=get_env_value("INPUT_DIR", "./inputs"), - help="Directory containing input documents (default: from env or ./inputs)", - ) - - # LLM Model configuration - parser.add_argument( - "--llm-binding-host", - default=get_env_value("LLM_BINDING_HOST", None), - help="LLM server host URL. If not provided, defaults based on llm-binding:\n" - + "- ollama: http://localhost:11434\n" - + "- lollms: http://localhost:9600\n" - + "- openai: https://api.openai.com/v1", - ) - - default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None) - - parser.add_argument( - "--llm-binding-api-key", - default=default_llm_api_key, - help="llm server API key (default: from env or empty string)", - ) - - parser.add_argument( - "--llm-model", - default=get_env_value("LLM_MODEL", "mistral-nemo:latest"), - help="LLM model name (default: from env or mistral-nemo:latest)", - ) - - # Embedding model configuration - parser.add_argument( - "--embedding-binding-host", - default=get_env_value("EMBEDDING_BINDING_HOST", None), - help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n" - + "- ollama: http://localhost:11434\n" - + "- lollms: http://localhost:9600\n" - + "- openai: https://api.openai.com/v1", - ) - - default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") - parser.add_argument( - "--embedding-binding-api-key", - default=default_embedding_api_key, - help="embedding server API key (default: from env or empty string)", - ) - - parser.add_argument( - "--embedding-model", - default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"), - help="Embedding model name (default: from env or bge-m3:latest)", - ) - - parser.add_argument( - "--chunk_size", - default=get_env_value("CHUNK_SIZE", 1200), - help="chunk chunk size default 1200", - ) - - parser.add_argument( - "--chunk_overlap_size", - default=get_env_value("CHUNK_OVERLAP_SIZE", 100), - help="chunk overlap size default 100", - ) - - def timeout_type(value): - if value is None or value == "None": - return None - return int(value) - - parser.add_argument( - "--timeout", - default=get_env_value("TIMEOUT", None, timeout_type), - type=timeout_type, - help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", - ) - - # RAG configuration - parser.add_argument( - "--max-async", - type=int, - default=get_env_value("MAX_ASYNC", 4, int), - help="Maximum async operations (default: from env or 4)", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=get_env_value("MAX_TOKENS", 32768, int), - help="Maximum token size (default: from env or 32768)", - ) - parser.add_argument( - "--embedding-dim", - type=int, - default=get_env_value("EMBEDDING_DIM", 1024, int), - help="Embedding dimensions (default: from env or 1024)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=get_env_value("MAX_EMBED_TOKENS", 8192, int), - help="Maximum embedding token size (default: from env or 8192)", - ) - - # Logging configuration - parser.add_argument( - "--log-level", - default=get_env_value("LOG_LEVEL", "INFO"), - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: from env or INFO)", - ) - - parser.add_argument( - "--key", - type=str, - default=get_env_value("LIGHTRAG_API_KEY", None), - help="API key for authentication. This protects lightrag server against unauthorized access", - ) - - # Optional https parameters - parser.add_argument( - "--ssl", - action="store_true", - default=get_env_value("SSL", False, bool), - help="Enable HTTPS (default: from env or False)", - ) - parser.add_argument( - "--ssl-certfile", - default=get_env_value("SSL_CERTFILE", None), - help="Path to SSL certificate file (required if --ssl is enabled)", - ) - parser.add_argument( - "--ssl-keyfile", - default=get_env_value("SSL_KEYFILE", None), - help="Path to SSL private key file (required if --ssl is enabled)", - ) - parser.add_argument( - "--auto-scan-at-startup", - action="store_true", - default=False, - help="Enable automatic scanning when the program starts", - ) - - parser.add_argument( - "--history-turns", - type=int, - default=get_env_value("HISTORY_TURNS", 3, int), - help="Number of conversation history turns to include (default: from env or 3)", - ) - - # Search parameters - parser.add_argument( - "--top-k", - type=int, - default=get_env_value("TOP_K", 60, int), - help="Number of most similar results to return (default: from env or 60)", - ) - parser.add_argument( - "--cosine-threshold", - type=float, - default=get_env_value("COSINE_THRESHOLD", 0.2, float), - help="Cosine similarity threshold (default: from env or 0.4)", - ) - - # Ollama model name - parser.add_argument( - "--simulated-model-name", - type=str, - default=get_env_value( - "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL - ), - help="Number of conversation history turns to include (default: from env or 3)", - ) - - # Namespace - parser.add_argument( - "--namespace-prefix", - type=str, - default=get_env_value("NAMESPACE_PREFIX", ""), - help="Prefix of the namespace", - ) - - parser.add_argument( - "--verbose", - type=bool, - default=get_env_value("VERBOSE", False, bool), - help="Verbose debug output(default: from env or false)", - ) - - args = parser.parse_args() - - # convert relative path to absolute path - args.working_dir = os.path.abspath(args.working_dir) - args.input_dir = os.path.abspath(args.input_dir) - - ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name - - return args - - def create_app(args): # Set global top_k global global_top_k diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 7d9fe3b9..9be8067e 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -5,31 +5,13 @@ import logging import time import json import re -import os from enum import Enum from fastapi.responses import StreamingResponse import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import encode_string_by_tiktoken -from dotenv import load_dotenv - - -# Load environment variables -load_dotenv(override=True) - - -class OllamaServerInfos: - # Constants for emulated Ollama model information - LIGHTRAG_NAME = "lightrag" - LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") - LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" - LIGHTRAG_SIZE = 7365960935 # it's a dummy value - LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" - LIGHTRAG_DIGEST = "sha256:lightrag" - - -ollama_server_infos = OllamaServerInfos() +from ..utils_api import ollama_server_infos # query mode according to query prefix (bypass is not LightRAG quer mode) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index aef3c128..a24e731e 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -2,11 +2,33 @@ Utility functions for the LightRAG API. """ +import os +import argparse from typing import Optional +import sys +from ascii_colors import ASCIIColors +from lightrag.api import __api_version__ from fastapi import HTTPException, Security +from dotenv import load_dotenv from fastapi.security import APIKeyHeader from starlette.status import HTTP_403_FORBIDDEN +# Load environment variables +load_dotenv(override=True) + + +class OllamaServerInfos: + # Constants for emulated Ollama model information + LIGHTRAG_NAME = "lightrag" + LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") + LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" + LIGHTRAG_SIZE = 7365960935 # it's a dummy value + LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" + LIGHTRAG_DIGEST = "sha256:lightrag" + + +ollama_server_infos = OllamaServerInfos() + def get_api_key_dependency(api_key: Optional[str]): """ @@ -43,3 +65,490 @@ def get_api_key_dependency(api_key: Optional[str]): return api_key_header_value return api_key_auth + + +class DefaultRAGStorageConfig: + KV_STORAGE = "JsonKVStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + GRAPH_STORAGE = "NetworkXStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" + + +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), + "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"), + "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"), + "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"), + } + return default_hosts.get( + binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") + ) # fallback to ollama if unknown + + +def get_env_value(env_key: str, default: any, value_type: type = str) -> any: + """ + Get value from environment variable with type conversion + + Args: + env_key (str): Environment variable key + default (any): Default value if env variable is not set + value_type (type): Type to convert the value to + + Returns: + any: Converted value from environment or default + """ + value = os.getenv(env_key) + if value is None: + return default + + if value_type is bool: + return value.lower() in ("true", "1", "yes", "t", "on") + try: + return value_type(value) + except ValueError: + return default + + +def parse_args() -> argparse.Namespace: + """ + Parse command line arguments with environment variable fallback + + Returns: + argparse.Namespace: Parsed arguments + """ + + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) + + parser.add_argument( + "--kv-storage", + default=get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ), + help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})", + ) + parser.add_argument( + "--doc-status-storage", + default=get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ), + help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", + ) + parser.add_argument( + "--graph-storage", + default=get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ), + help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", + ) + parser.add_argument( + "--vector-storage", + default=get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ), + help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", + ) + + # Bindings configuration + parser.add_argument( + "--llm-binding", + default=get_env_value("LLM_BINDING", "ollama"), + help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", + ) + parser.add_argument( + "--embedding-binding", + default=get_env_value("EMBEDDING_BINDING", "ollama"), + help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", + ) + + # Server configuration + parser.add_argument( + "--host", + default=get_env_value("HOST", "0.0.0.0"), + help="Server host (default: from env or 0.0.0.0)", + ) + parser.add_argument( + "--port", + type=int, + default=get_env_value("PORT", 9621, int), + help="Server port (default: from env or 9621)", + ) + + # Directory configuration + parser.add_argument( + "--working-dir", + default=get_env_value("WORKING_DIR", "./rag_storage"), + help="Working directory for RAG storage (default: from env or ./rag_storage)", + ) + parser.add_argument( + "--input-dir", + default=get_env_value("INPUT_DIR", "./inputs"), + help="Directory containing input documents (default: from env or ./inputs)", + ) + + # LLM Model configuration + parser.add_argument( + "--llm-binding-host", + default=get_env_value("LLM_BINDING_HOST", None), + help="LLM server host URL. If not provided, defaults based on llm-binding:\n" + + "- ollama: http://localhost:11434\n" + + "- lollms: http://localhost:9600\n" + + "- openai: https://api.openai.com/v1", + ) + + default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None) + + parser.add_argument( + "--llm-binding-api-key", + default=default_llm_api_key, + help="llm server API key (default: from env or empty string)", + ) + + parser.add_argument( + "--llm-model", + default=get_env_value("LLM_MODEL", "mistral-nemo:latest"), + help="LLM model name (default: from env or mistral-nemo:latest)", + ) + + # Embedding model configuration + parser.add_argument( + "--embedding-binding-host", + default=get_env_value("EMBEDDING_BINDING_HOST", None), + help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n" + + "- ollama: http://localhost:11434\n" + + "- lollms: http://localhost:9600\n" + + "- openai: https://api.openai.com/v1", + ) + + default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") + parser.add_argument( + "--embedding-binding-api-key", + default=default_embedding_api_key, + help="embedding server API key (default: from env or empty string)", + ) + + parser.add_argument( + "--embedding-model", + default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"), + help="Embedding model name (default: from env or bge-m3:latest)", + ) + + parser.add_argument( + "--chunk_size", + default=get_env_value("CHUNK_SIZE", 1200), + help="chunk chunk size default 1200", + ) + + parser.add_argument( + "--chunk_overlap_size", + default=get_env_value("CHUNK_OVERLAP_SIZE", 100), + help="chunk overlap size default 100", + ) + + def timeout_type(value): + if value is None or value == "None": + return None + return int(value) + + parser.add_argument( + "--timeout", + default=get_env_value("TIMEOUT", None, timeout_type), + type=timeout_type, + help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", + ) + + # RAG configuration + parser.add_argument( + "--max-async", + type=int, + default=get_env_value("MAX_ASYNC", 4, int), + help="Maximum async operations (default: from env or 4)", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=get_env_value("MAX_TOKENS", 32768, int), + help="Maximum token size (default: from env or 32768)", + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=get_env_value("EMBEDDING_DIM", 1024, int), + help="Embedding dimensions (default: from env or 1024)", + ) + parser.add_argument( + "--max-embed-tokens", + type=int, + default=get_env_value("MAX_EMBED_TOKENS", 8192, int), + help="Maximum embedding token size (default: from env or 8192)", + ) + + # Logging configuration + parser.add_argument( + "--log-level", + default=get_env_value("LOG_LEVEL", "INFO"), + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level (default: from env or INFO)", + ) + + parser.add_argument( + "--key", + type=str, + default=get_env_value("LIGHTRAG_API_KEY", None), + help="API key for authentication. This protects lightrag server against unauthorized access", + ) + + # Optional https parameters + parser.add_argument( + "--ssl", + action="store_true", + default=get_env_value("SSL", False, bool), + help="Enable HTTPS (default: from env or False)", + ) + parser.add_argument( + "--ssl-certfile", + default=get_env_value("SSL_CERTFILE", None), + help="Path to SSL certificate file (required if --ssl is enabled)", + ) + parser.add_argument( + "--ssl-keyfile", + default=get_env_value("SSL_KEYFILE", None), + help="Path to SSL private key file (required if --ssl is enabled)", + ) + parser.add_argument( + "--auto-scan-at-startup", + action="store_true", + default=False, + help="Enable automatic scanning when the program starts", + ) + + parser.add_argument( + "--history-turns", + type=int, + default=get_env_value("HISTORY_TURNS", 3, int), + help="Number of conversation history turns to include (default: from env or 3)", + ) + + # Search parameters + parser.add_argument( + "--top-k", + type=int, + default=get_env_value("TOP_K", 60, int), + help="Number of most similar results to return (default: from env or 60)", + ) + parser.add_argument( + "--cosine-threshold", + type=float, + default=get_env_value("COSINE_THRESHOLD", 0.2, float), + help="Cosine similarity threshold (default: from env or 0.4)", + ) + + # Ollama model name + parser.add_argument( + "--simulated-model-name", + type=str, + default=get_env_value( + "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL + ), + help="Number of conversation history turns to include (default: from env or 3)", + ) + + # Namespace + parser.add_argument( + "--namespace-prefix", + type=str, + default=get_env_value("NAMESPACE_PREFIX", ""), + help="Prefix of the namespace", + ) + + parser.add_argument( + "--verbose", + type=bool, + default=get_env_value("VERBOSE", False, bool), + help="Verbose debug output(default: from env or false)", + ) + + args = parser.parse_args() + + # convert relative path to absolute path + args.working_dir = os.path.abspath(args.working_dir) + args.input_dir = os.path.abspath(args.input_dir) + + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name + + return args + + +def display_splash_screen(args: argparse.Namespace) -> None: + """ + Display a colorful splash screen showing LightRAG server configuration + + Args: + args: Parsed command line arguments + """ + # Banner + ASCIIColors.cyan(f""" + ╔══════════════════════════════════════════════════════════════╗ + ║ 🚀 LightRAG Server v{__api_version__} ║ + ║ Fast, Lightweight RAG Server Implementation ║ + ╚══════════════════════════════════════════════════════════════╝ + """) + + # Server Configuration + ASCIIColors.magenta("\n📡 Server Configuration:") + ASCIIColors.white(" ├─ Host: ", end="") + ASCIIColors.yellow(f"{args.host}") + ASCIIColors.white(" ├─ Port: ", end="") + ASCIIColors.yellow(f"{args.port}") + ASCIIColors.white(" ├─ CORS Origins: ", end="") + ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") + ASCIIColors.white(" ├─ SSL Enabled: ", end="") + ASCIIColors.yellow(f"{args.ssl}") + ASCIIColors.white(" └─ API Key: ", end="") + ASCIIColors.yellow("Set" if args.key else "Not Set") + if args.ssl: + ASCIIColors.white(" ├─ SSL Cert: ", end="") + ASCIIColors.yellow(f"{args.ssl_certfile}") + ASCIIColors.white(" └─ SSL Key: ", end="") + ASCIIColors.yellow(f"{args.ssl_keyfile}") + + # Directory Configuration + ASCIIColors.magenta("\n📂 Directory Configuration:") + ASCIIColors.white(" ├─ Working Directory: ", end="") + ASCIIColors.yellow(f"{args.working_dir}") + ASCIIColors.white(" └─ Input Directory: ", end="") + ASCIIColors.yellow(f"{args.input_dir}") + + # LLM Configuration + ASCIIColors.magenta("\n🤖 LLM Configuration:") + ASCIIColors.white(" ├─ Binding: ", end="") + ASCIIColors.yellow(f"{args.llm_binding}") + ASCIIColors.white(" ├─ Host: ", end="") + ASCIIColors.yellow(f"{args.llm_binding_host}") + ASCIIColors.white(" └─ Model: ", end="") + ASCIIColors.yellow(f"{args.llm_model}") + + # Embedding Configuration + ASCIIColors.magenta("\n📊 Embedding Configuration:") + ASCIIColors.white(" ├─ Binding: ", end="") + ASCIIColors.yellow(f"{args.embedding_binding}") + ASCIIColors.white(" ├─ Host: ", end="") + ASCIIColors.yellow(f"{args.embedding_binding_host}") + ASCIIColors.white(" ├─ Model: ", end="") + ASCIIColors.yellow(f"{args.embedding_model}") + ASCIIColors.white(" └─ Dimensions: ", end="") + ASCIIColors.yellow(f"{args.embedding_dim}") + + # RAG Configuration + ASCIIColors.magenta("\n⚙️ RAG Configuration:") + ASCIIColors.white(" ├─ Max Async Operations: ", end="") + ASCIIColors.yellow(f"{args.max_async}") + ASCIIColors.white(" ├─ Max Tokens: ", end="") + ASCIIColors.yellow(f"{args.max_tokens}") + ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") + ASCIIColors.yellow(f"{args.max_embed_tokens}") + ASCIIColors.white(" ├─ Chunk Size: ", end="") + ASCIIColors.yellow(f"{args.chunk_size}") + ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="") + ASCIIColors.yellow(f"{args.chunk_overlap_size}") + ASCIIColors.white(" ├─ History Turns: ", end="") + ASCIIColors.yellow(f"{args.history_turns}") + ASCIIColors.white(" ├─ Cosine Threshold: ", end="") + ASCIIColors.yellow(f"{args.cosine_threshold}") + ASCIIColors.white(" └─ Top-K: ", end="") + ASCIIColors.yellow(f"{args.top_k}") + + # System Configuration + ASCIIColors.magenta("\n💾 Storage Configuration:") + ASCIIColors.white(" ├─ KV Storage: ", end="") + ASCIIColors.yellow(f"{args.kv_storage}") + ASCIIColors.white(" ├─ Vector Storage: ", end="") + ASCIIColors.yellow(f"{args.vector_storage}") + ASCIIColors.white(" ├─ Graph Storage: ", end="") + ASCIIColors.yellow(f"{args.graph_storage}") + ASCIIColors.white(" └─ Document Status Storage: ", end="") + ASCIIColors.yellow(f"{args.doc_status_storage}") + + ASCIIColors.magenta("\n🛠️ System Configuration:") + ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") + ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") + ASCIIColors.white(" ├─ Log Level: ", end="") + ASCIIColors.yellow(f"{args.log_level}") + ASCIIColors.white(" ├─ Verbose Debug: ", end="") + ASCIIColors.yellow(f"{args.verbose}") + ASCIIColors.white(" └─ Timeout: ", end="") + ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") + + # Server Status + ASCIIColors.green("\n✨ Server starting up...\n") + + # Server Access Information + protocol = "https" if args.ssl else "http" + if args.host == "0.0.0.0": + ASCIIColors.magenta("\n🌐 Server Access Information:") + ASCIIColors.white(" ├─ Local Access: ", end="") + ASCIIColors.yellow(f"{protocol}://localhost:{args.port}") + ASCIIColors.white(" ├─ Remote Access: ", end="") + ASCIIColors.yellow(f"{protocol}://:{args.port}") + ASCIIColors.white(" ├─ API Documentation (local): ", end="") + ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs") + ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="") + ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc") + ASCIIColors.white(" └─ WebUI (local): ", end="") + ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui") + + ASCIIColors.yellow("\n📝 Note:") + ASCIIColors.white(""" Since the server is running on 0.0.0.0: + - Use 'localhost' or '127.0.0.1' for local access + - Use your machine's IP address for remote access + - To find your IP address: + • Windows: Run 'ipconfig' in terminal + • Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal + """) + else: + base_url = f"{protocol}://{args.host}:{args.port}" + ASCIIColors.magenta("\n🌐 Server Access Information:") + ASCIIColors.white(" ├─ Base URL: ", end="") + ASCIIColors.yellow(f"{base_url}") + ASCIIColors.white(" ├─ API Documentation: ", end="") + ASCIIColors.yellow(f"{base_url}/docs") + ASCIIColors.white(" └─ Alternative Documentation: ", end="") + ASCIIColors.yellow(f"{base_url}/redoc") + + # Usage Examples + ASCIIColors.magenta("\n📚 Quick Start Guide:") + ASCIIColors.cyan(""" + 1. Access the Swagger UI: + Open your browser and navigate to the API documentation URL above + + 2. API Authentication:""") + if args.key: + ASCIIColors.cyan(""" Add the following header to your requests: + X-API-Key: + """) + else: + ASCIIColors.cyan(" No authentication required\n") + + ASCIIColors.cyan(""" 3. Basic Operations: + - POST /upload_document: Upload new documents to RAG + - POST /query: Query your document collection + - GET /collections: List available collections + + 4. Monitor the server: + - Check server logs for detailed operation information + - Use healthcheck endpoint: GET /health + """) + + # Security Notice + if args.key: + ASCIIColors.yellow("\n⚠️ Security Notice:") + ASCIIColors.white(""" API Key authentication is enabled. + Make sure to include the X-API-Key header in all your requests. + """) + + ASCIIColors.green("Server is ready to accept connections! 🚀\n") + + # Ensure splash output flush to system log + sys.stdout.flush() From 57cdab2b2bd8e5c39202d2618ff67e54556dfa1a Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 12:33:43 +0800 Subject: [PATCH 08/14] Add tags to OllamaAPI router --- lightrag/api/routers/ollama_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 9be8067e..ea9662ad 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -126,7 +126,7 @@ class OllamaAPI: self.rag = rag self.ollama_server_infos = ollama_server_infos self.top_k = top_k - self.router = APIRouter() + self.router = APIRouter(tags=["Ollama"]) self.setup_routes() def setup_routes(self): From 62e1fe5df2d362cad1a8b1d2a820c40ec4720357 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 14:23:33 +0800 Subject: [PATCH 09/14] Change Ollama API router tag to lowercase --- lightrag/api/routers/ollama_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index ea9662ad..9688d073 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -126,7 +126,7 @@ class OllamaAPI: self.rag = rag self.ollama_server_infos = ollama_server_infos self.top_k = top_k - self.router = APIRouter(tags=["Ollama"]) + self.router = APIRouter(tags=["ollama"]) self.setup_routes() def setup_routes(self): From 82a4cb3e7915816f50e6a653edb62385b5275316 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 14:30:41 +0800 Subject: [PATCH 10/14] Fix refactoring error on document handling - Fix refactoring error on pipeline_index_file - Delete unsed func: scan_directory - Add type hints of rag for better maintainability - Refine comments for better understanding --- lightrag/api/routers/document_routes.py | 173 +++++++++++------------- 1 file changed, 82 insertions(+), 91 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ba4f8b0a..383c762c 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Any from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from pydantic import BaseModel, Field, field_validator +from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus from ..utils_api import get_api_key_dependency @@ -76,6 +77,20 @@ class DocStatusResponse(BaseModel): return dt return dt.isoformat() + """Response model for document status + + Attributes: + id: Document identifier + content_summary: Summary of document content + content_length: Length of document content + status: Current processing status + created_at: Creation timestamp (ISO format string) + updated_at: Last update timestamp (ISO format string) + chunks_count: Number of chunks (optional) + error: Error message if any (optional) + metadata: Additional metadata (optional) + """ + id: str content_summary: str content_length: int @@ -112,6 +127,7 @@ class DocumentManager: self.input_dir.mkdir(parents=True, exist_ok=True) def scan_directory_for_new_files(self) -> List[Path]: + """Scan input directory for new files""" new_files = [] for ext in self.supported_extensions: logging.info(f"Scanning for {ext} files in {self.input_dir}") @@ -120,12 +136,12 @@ class DocumentManager: new_files.append(file_path) return new_files - def scan_directory(self) -> List[Path]: - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - new_files.append(file_path) - return new_files + # def scan_directory(self) -> List[Path]: + # new_files = [] + # for ext in self.supported_extensions: + # for file_path in self.input_dir.rglob(f"*{ext}"): + # new_files.append(file_path) + # return new_files def mark_as_indexed(self, file_path: Path): self.indexed_files.add(file_path) @@ -134,7 +150,16 @@ class DocumentManager: return any(filename.lower().endswith(ext) for ext in self.supported_extensions) -async def pipeline_enqueue_file(rag, file_path: Path) -> bool: +async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: + """Add a file to the queue for processing + + Args: + rag: LightRAG instance + file_path: Path to the saved file + Returns: + bool: True if the file was successfully enqueued, False otherwise + """ + try: content = "" ext = file_path.suffix.lower() @@ -165,7 +190,9 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool: docx_file = BytesIO(file) doc = Document(docx_file) - content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + content = "\n".join( + [paragraph.text for paragraph in doc.paragraphs] + ) case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") @@ -205,13 +232,19 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool: # Insert into the RAG queue if content: await rag.apipeline_enqueue_documents(content) - logging.info(f"Successfully fetched and enqueued file: {file_path.name}") + logging.info( + f"Successfully fetched and enqueued file: {file_path.name}" + ) return True else: - logging.error(f"No content could be extracted from file: {file_path.name}") + logging.error( + f"No content could be extracted from file: {file_path.name}" + ) except Exception as e: - logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") + logging.error( + f"Error processing or enqueueing file {file_path.name}: {str(e)}" + ) logging.error(traceback.format_exc()) finally: if file_path.name.startswith(temp_prefix): @@ -222,7 +255,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool: return False -async def pipeline_index_file(rag, file_path: Path): +async def pipeline_index_file(rag: LightRAG, file_path: Path): """Index a file Args: @@ -230,90 +263,26 @@ async def pipeline_index_file(rag, file_path: Path): file_path: Path to the saved file """ try: - content = "" - ext = file_path.suffix.lower() - - file = None - async with aiofiles.open(file_path, "rb") as f: - file = await f.read() - - # Process based on file type - match ext: - case ".txt" | ".md": - content = file.decode("utf-8") - case ".pdf": - if not pm.is_installed("pypdf2"): - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO - - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - case ".docx": - if not pm.is_installed("docx"): - pm.install("docx") - from docx import Document - from io import BytesIO - - docx_file = BytesIO(file) - doc = Document(docx_file) - content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) - case ".pptx": - if not pm.is_installed("pptx"): - pm.install("pptx") - from pptx import Presentation - from io import BytesIO - - pptx_file = BytesIO(file) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - case ".xlsx": - if not pm.is_installed("openpyxl"): - pm.install("openpyxl") - from openpyxl import load_workbook - from io import BytesIO - - xlsx_file = BytesIO(file) - wb = load_workbook(xlsx_file) - for sheet in wb: - content += f"Sheet: {sheet.title}\n" - for row in sheet.iter_rows(values_only=True): - content += ( - "\t".join( - str(cell) if cell is not None else "" for cell in row - ) - + "\n" - ) - content += "\n" - case _: - logging.error( - f"Unsupported file type: {file_path.name} (extension {ext})" - ) - return - - # Insert into the RAG queue - if content: - await rag.apipeline_enqueue_documents(content) + if await pipeline_enqueue_file(file_path): await rag.apipeline_process_enqueue_documents() - logging.info(f"Successfully indexed file: {file_path.name}") - else: - logging.error(f"No content could be extracted from file: {file_path.name}") except Exception as e: logging.error(f"Error indexing file {file_path.name}: {str(e)}") logging.error(traceback.format_exc()) -async def pipeline_index_files(rag, file_paths: List[Path]): +async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): + """Index multiple files concurrently + + Args: + rag: LightRAG instance + file_paths: Paths to the files to index + """ if not file_paths: return try: enqueued = False + if len(file_paths) == 1: enqueued = await pipeline_enqueue_file(rag, file_paths[0]) else: @@ -327,7 +296,13 @@ async def pipeline_index_files(rag, file_paths: List[Path]): logging.error(traceback.format_exc()) -async def pipeline_index_texts(rag, texts: List[str]): +async def pipeline_index_texts(rag: LightRAG, texts: List[str]): + """Index a list of texts + + Args: + rag: LightRAG instance + texts: The texts to index + """ if not texts: return await rag.apipeline_enqueue_documents(texts) @@ -335,16 +310,29 @@ async def pipeline_index_texts(rag, texts: List[str]): async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: + """Save the uploaded file to a temporary location + + Args: + file: The uploaded file + + Returns: + Path: The path to the saved file + """ + # Generate unique filename to avoid conflicts timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" + + # Create a temporary file to save the uploaded content temp_path = input_dir / "temp" / unique_filename temp_path.parent.mkdir(exist_ok=True) + + # Save the file with open(temp_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) return temp_path -async def run_scanning_process(rag, doc_manager: DocumentManager): +async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" try: new_files = doc_manager.scan_directory_for_new_files() @@ -375,7 +363,7 @@ async def run_scanning_process(rag, doc_manager: DocumentManager): def create_document_routes( - rag, doc_manager: DocumentManager, api_key: Optional[str] = None + rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None ): optional_api_key = get_api_key_dependency(api_key) @@ -388,9 +376,6 @@ def create_document_routes( and processes them. If a scanning process is already running, it returns a status indicating that fact. - Args: - background_tasks (BackgroundTasks): FastAPI background tasks handler - Returns: dict: A dictionary containing the scanning status """ @@ -402,6 +387,7 @@ def create_document_routes( scan_progress["indexed_count"] = 0 scan_progress["progress"] = 0 + # Start the scanning process in the background background_tasks.add_task(run_scanning_process, rag, doc_manager) return {"status": "scanning_started"} @@ -453,6 +439,7 @@ def create_document_routes( with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) + # Add to background tasks background_tasks.add_task(pipeline_index_file, rag, file_path) return InsertResponse( @@ -562,6 +549,8 @@ def create_document_routes( ) temp_path = await save_temp_file(doc_manager.input_dir, file) + + # Add to background tasks background_tasks.add_task(pipeline_index_file, rag, temp_path) return InsertResponse( @@ -606,6 +595,7 @@ def create_document_routes( for file in files: if doc_manager.is_supported_file(file.filename): + # Create a temporary file to save the uploaded content temp_files.append(await save_temp_file(doc_manager.input_dir, file)) inserted_count += 1 else: @@ -614,6 +604,7 @@ def create_document_routes( if temp_files: background_tasks.add_task(pipeline_index_files, rag, temp_files) + # Prepare status message if inserted_count == len(files): status = "success" status_message = f"Successfully inserted all {inserted_count} documents" From f52b9929bb39b069c9a9cffc9659863c26a89022 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 20 Feb 2025 14:48:54 +0800 Subject: [PATCH 11/14] fix: add missing rag parameter in pipeline_enqueue_file call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add rag param to function call • Fix argument mismatch error • Ensure proper pipeline execution --- lightrag/api/routers/document_routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 383c762c..242f37b9 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -263,7 +263,7 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path): file_path: Path to the saved file """ try: - if await pipeline_enqueue_file(file_path): + if await pipeline_enqueue_file(rag, file_path): await rag.apipeline_process_enqueue_documents() except Exception as e: From 17496783841bd9b5daf78adc2537a8af091bf3e8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 21 Feb 2025 12:16:04 +0800 Subject: [PATCH 12/14] Fix linting --- lightrag/api/routers/document_routes.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 242f37b9..c17ccd88 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -190,9 +190,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: docx_file = BytesIO(file) doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") @@ -232,19 +230,13 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: # Insert into the RAG queue if content: await rag.apipeline_enqueue_documents(content) - logging.info( - f"Successfully fetched and enqueued file: {file_path.name}" - ) + logging.info(f"Successfully fetched and enqueued file: {file_path.name}") return True else: - logging.error( - f"No content could be extracted from file: {file_path.name}" - ) + logging.error(f"No content could be extracted from file: {file_path.name}") except Exception as e: - logging.error( - f"Error processing or enqueueing file {file_path.name}: {str(e)}" - ) + logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") logging.error(traceback.format_exc()) finally: if file_path.name.startswith(temp_prefix): From cff229a806143e8756b6ce2c658a521c12cdfb71 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 21 Feb 2025 14:46:27 +0800 Subject: [PATCH 13/14] fix: respect user-specified log level in set_logger Previously, the set_logger function would always set the log level to DEBUG, overriding any user-specified log level. --- lightrag/lightrag.py | 3 +-- lightrag/utils.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index db61788a..e5c67df4 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -263,9 +263,8 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - logger.setLevel(self.log_level) os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) - set_logger(self.log_file_path) + set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") if not os.path.exists(self.working_dir): diff --git a/lightrag/utils.py b/lightrag/utils.py index d402d14c..ae7e8dce 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -57,11 +57,17 @@ logger = logging.getLogger("lightrag") logging.getLogger("httpx").setLevel(logging.WARNING) -def set_logger(log_file: str): - logger.setLevel(logging.DEBUG) +def set_logger(log_file: str, level: int = logging.DEBUG): + """Set up file logging with the specified level. + + Args: + log_file: Path to the log file + level: Logging level (e.g. logging.DEBUG, logging.INFO) + """ + logger.setLevel(level) file_handler = logging.FileHandler(log_file, encoding="utf-8") - file_handler.setLevel(logging.DEBUG) + file_handler.setLevel(level) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" From b144e0c3b0298243c6673e524df414057a6fb9d5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 21 Feb 2025 21:07:37 +0800 Subject: [PATCH 14/14] Sync modifications from main branch --- lightrag/api/routers/document_routes.py | 67 ++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index c17ccd88..25ca24e4 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -117,6 +117,37 @@ class DocumentManager: ".docx", ".pptx", ".xlsx", + ".rtf", # Rich Text Format + ".odt", # OpenDocument Text + ".tex", # LaTeX + ".epub", # Electronic Publication + ".html", # HyperText Markup Language + ".htm", # HyperText Markup Language + ".csv", # Comma-Separated Values + ".json", # JavaScript Object Notation + ".xml", # eXtensible Markup Language + ".yaml", # YAML Ain't Markup Language + ".yml", # YAML + ".log", # Log files + ".conf", # Configuration files + ".ini", # Initialization files + ".properties", # Java properties files + ".sql", # SQL scripts + ".bat", # Batch files + ".sh", # Shell scripts + ".c", # C source code + ".cpp", # C++ source code + ".py", # Python source code + ".java", # Java source code + ".js", # JavaScript source code + ".ts", # TypeScript source code + ".swift", # Swift source code + ".go", # Go source code + ".rb", # Ruby source code + ".php", # PHP source code + ".css", # Cascading Style Sheets + ".scss", # Sassy CSS + ".less", # LESS CSS ), ): self.input_dir = Path(input_dir) @@ -170,7 +201,41 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: # Process based on file type match ext: - case ".txt" | ".md": + case ( + ".txt" + | ".md" + | ".html" + | ".htm" + | ".tex" + | ".json" + | ".xml" + | ".yaml" + | ".yml" + | ".rtf" + | ".odt" + | ".epub" + | ".csv" + | ".log" + | ".conf" + | ".ini" + | ".properties" + | ".sql" + | ".bat" + | ".sh" + | ".c" + | ".cpp" + | ".py" + | ".java" + | ".js" + | ".ts" + | ".swift" + | ".go" + | ".rb" + | ".php" + | ".css" + | ".scss" + | ".less" + ): content = file.decode("utf-8") case ".pdf": if not pm.is_installed("pypdf2"):