From 2c56141bfd5ab8a1f8d52b77f08dbab23a067ee2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 14 Feb 2025 12:34:26 +0800 Subject: [PATCH 01/12] Standardize variable names with other vector database implementations (without functional modifications) --- lightrag/kg/faiss_impl.py | 4 ++-- lightrag/kg/nano_vector_db_impl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0dca9e4c..b2090d78 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -27,8 +27,8 @@ class FaissVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Grab config values if available - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: raise ValueError( "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 2db8f72a..60eed3dc 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -79,8 +79,8 @@ class NanoVectorDBStorage(BaseVectorStorage): # Initialize lock only for file operations self._save_lock = asyncio.Lock() # Use global config value if specified, otherwise use default - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: raise ValueError( "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" From 258c7596e6a49eb1533c5e41280bbab89a818902 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 14 Feb 2025 12:50:43 +0800 Subject: [PATCH 02/12] fix: Improve file path handling and logging for document scanning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Convert relative paths to absolute paths • Add logging for file scanning progress • Log total number of new files found • Enhance file scanning feedback • Improve path resolution safety --- lightrag/api/lightrag_server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1aeff264..ce182bc1 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -564,6 +564,10 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() + # conver relative path to absolute path + args.working_dir = os.path.abspath(args.working_dir) + args.input_dir = os.path.abspath(args.input_dir) + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name return args @@ -595,6 +599,7 @@ class DocumentManager: """Scan input directory for new files""" new_files = [] for ext in self.supported_extensions: + logger.info(f"Scanning for {ext} files in {self.input_dir}") for file_path in self.input_dir.rglob(f"*{ext}"): if file_path not in self.indexed_files: new_files.append(file_path) @@ -1198,6 +1203,7 @@ def create_app(args): new_files = doc_manager.scan_directory_for_new_files() scan_progress["total_files"] = len(new_files) + logger.info(f"Found {len(new_files)} new files to index.") for file_path in new_files: try: with progress_lock: From f6058b79b643e8d52386f435b6d9bf4830d06038 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 14 Feb 2025 13:26:19 +0800 Subject: [PATCH 03/12] Update .env.example with absolute path placeholders --- .env.example | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index 4b64ecb4..022bd63d 100644 --- a/.env.example +++ b/.env.example @@ -12,8 +12,8 @@ # LIGHTRAG_API_KEY=your-secure-api-key-here ### Directory Configuration -# WORKING_DIR=./rag_storage -# INPUT_DIR=./inputs +# WORKING_DIR= +# INPUT_DIR= ### Logging level LOG_LEVEL=INFO From ad88ba03bf8e531e010f053ecce694d0f343f13a Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:07:38 +0800 Subject: [PATCH 04/12] docs: reorganize Ollama emulation API documentation for better readability --- lightrag/api/README.md | 110 ++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 50 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 8e5a61d5..7e4fda7e 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -74,30 +74,38 @@ LLM_MODEL=model_name_of_azure_ai LLM_BINDING_API_KEY=api_key_of_azure_ai ``` -### About Ollama API +### 3. Install Lightrag as a Linux Service -We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily. +Create a your service file `lightrag.sevice` from the sample file : `lightrag.sevice.example`. Modified the WorkingDirectoryand EexecStart in the service file: -#### Choose Query mode in chat - -A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include: - -``` -/local -/global -/hybrid -/naive -/mix -/bypass +```text +Description=LightRAG Ollama Service +WorkingDirectory= +ExecStart=/lightrag/api/lightrag-api ``` -For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。 +Modify your service startup script: `lightrag-api`. Change you python virtual environment activation command as needed: -"/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the LightRAG query results. (If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix) +```shell +#!/bin/bash + +# your python virtual environment activation +source /home/netman/lightrag-xyj/venv/bin/activate +# start lightrag api server +lightrag-server +``` + +Install LightRAG service. If your system is Ubuntu, the following commands will work: + +```shell +sudo cp lightrag.service /etc/systemd/system/ +sudo systemctl daemon-reload +sudo systemctl start lightrag.service +sudo systemctl status lightrag.service +sudo systemctl enable lightrag.service +``` -#### Connect Open WebUI to LightRAG -After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. ## Configuration @@ -378,7 +386,7 @@ curl -X DELETE "http://localhost:9621/documents" #### GET /api/version -Get Ollama version information +Get Ollama version information. ```bash curl http://localhost:9621/api/version @@ -386,7 +394,7 @@ curl http://localhost:9621/api/version #### GET /api/tags -Get Ollama available models +Get Ollama available models. ```bash curl http://localhost:9621/api/tags @@ -394,7 +402,7 @@ curl http://localhost:9621/api/tags #### POST /api/chat -Handle chat completion requests +Handle chat completion requests. Routes user queries through LightRAG by selecting query mode based on query prefix. Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to underlying LLM. ```shell curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \ @@ -403,6 +411,10 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso > For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md) +#### POST /api/generate + +Handle generate completion requests. For compatibility purpose, the request is not processed by LightRAG, and will be handled by underlying LLM model. + ### Utility Endpoints #### GET /health @@ -412,7 +424,35 @@ Check server health and configuration. curl "http://localhost:9621/health" ``` +## Ollama Emulation + +We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily. + +### Connect Open WebUI to LightRAG + +After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. You'd better install LightRAG as service for this use case. + +Open WebUI's use LLM to do the session title and session keyword generation task. So the Ollama chat chat completion API detects and forwards OpenWebUI session-related requests directly to underlying LLM. + +### Choose Query mode in chat + +A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include: + +``` +/local +/global +/hybrid +/naive +/mix +/bypass +``` + +For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。 + +"/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the chat history. If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix. + ## Development + Contribute to the project: [Guide](contributor-readme.MD) ### Running in Development Mode @@ -471,33 +511,3 @@ This intelligent caching mechanism: - This optimization significantly reduces startup time for subsequent runs - The working directory (`--working-dir`) stores the vectorized documents database -## Install Lightrag as a Linux Service - -Create a your service file `lightrag.sevice` from the sample file : `lightrag.sevice.example`. Modified the WorkingDirectoryand EexecStart in the service file: - -```text -Description=LightRAG Ollama Service -WorkingDirectory= -ExecStart=/lightrag/api/lightrag-api -``` - -Modify your service startup script: `lightrag-api`. Change you python virtual environment activation command as needed: - -```shell -#!/bin/bash - -# your python virtual environment activation -source /home/netman/lightrag-xyj/venv/bin/activate -# start lightrag api server -lightrag-server -``` - -Install LightRAG service. If your system is Ubuntu, the following commands will work: - -```shell -sudo cp lightrag.service /etc/systemd/system/ -sudo systemctl daemon-reload -sudo systemctl start lightrag.service -sudo systemctl status lightrag.service -sudo systemctl enable lightrag.service -``` From 0db0419c6dc38651589db6121245acad3df74eeb Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:08:54 +0800 Subject: [PATCH 05/12] Fix linting --- lightrag/api/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 7e4fda7e..06510618 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -510,4 +510,3 @@ This intelligent caching mechanism: - Only new documents in the input directory will be processed - This optimization significantly reduces startup time for subsequent runs - The working directory (`--working-dir`) stores the vectorized documents database - From 2985d88f976ab63b6ce31d1c9929506e37c288ae Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:39:10 +0800 Subject: [PATCH 06/12] refactor: improve CORS and streaming response headers - Add configurable CORS origins - Remove duplicate CORS headers - Add X-Accel-Buffering header - Update env example file - Clean up header configurations --- .env.example | 13 +++++++------ lightrag/api/lightrag_server.py | 16 +++++++++++----- lightrag/api/ollama_api.py | 8 ++------ 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/.env.example b/.env.example index 022bd63d..2701335a 100644 --- a/.env.example +++ b/.env.example @@ -1,12 +1,13 @@ ### Server Configuration -#HOST=0.0.0.0 -#PORT=9621 -#NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances +# HOST=0.0.0.0 +# PORT=9621 +# NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances +# CORS_ORIGINS=http://localhost:3000,http://localhost:8080 ### Optional SSL Configuration -#SSL=true -#SSL_CERTFILE=/path/to/cert.pem -#SSL_KEYFILE=/path/to/key.pem +# SSL=true +# SSL_CERTFILE=/path/to/cert.pem +# SSL_KEYFILE=/path/to/key.pem ### Security (empty for no api-key is needed) # LIGHTRAG_API_KEY=your-secure-api-key-here diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ce182bc1..19552faf 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -847,10 +847,19 @@ def create_app(args): lifespan=lifespan, ) + def get_cors_origins(): + """Get allowed origins from environment variable + Returns a list of allowed origins, defaults to ["*"] if not set + """ + origins_str = os.getenv("CORS_ORIGINS", "*") + if origins_str == "*": + return ["*"] + return [origin.strip() for origin in origins_str.split(",")] + # Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=get_cors_origins(), allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -1377,10 +1386,7 @@ def create_app(args): "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - "X-Accel-Buffering": "no", # Disable Nginx buffering + "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 }, ) except Exception as e: diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 01a883ca..94703dee 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -316,9 +316,7 @@ class OllamaAPI: "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 }, ) else: @@ -534,9 +532,7 @@ class OllamaAPI: "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 }, ) else: From 8fdbcb0d3f749741daa57dfbd346000f1b4e652f Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:46:47 +0800 Subject: [PATCH 07/12] fix: reorganize server info display and add CORS origins info MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add CORS origins display • Move API key status higher in display • Fix tree symbols for better readability • Regroup related server info • Remove redundant line breaks --- lightrag/api/lightrag_server.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 19552faf..97f1156f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -159,8 +159,12 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.host}") ASCIIColors.white(" ├─ Port: ", end="") ASCIIColors.yellow(f"{args.port}") - ASCIIColors.white(" └─ SSL Enabled: ", end="") + ASCIIColors.white(" ├─ CORS Origins: ", end="") + ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") + ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.yellow(f"{args.ssl}") + ASCIIColors.white(" └─ API Key: ", end="") + ASCIIColors.yellow("Set" if args.key else "Not Set") if args.ssl: ASCIIColors.white(" ├─ SSL Cert: ", end="") ASCIIColors.yellow(f"{args.ssl_certfile}") @@ -229,10 +233,8 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") ASCIIColors.white(" ├─ Log Level: ", end="") ASCIIColors.yellow(f"{args.log_level}") - ASCIIColors.white(" ├─ Timeout: ", end="") + ASCIIColors.white(" └─ Timeout: ", end="") ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") - ASCIIColors.white(" └─ API Key: ", end="") - ASCIIColors.yellow("Set" if args.key else "Not Set") # Server Status ASCIIColors.green("\n✨ Server starting up...\n") From 147d73bd56ed03ecc0770e8ec0c7a47305d3b0cd Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 15 Feb 2025 22:25:48 +0800 Subject: [PATCH 08/12] refactor file indexing for background async processing --- lightrag/api/lightrag_server.py | 609 ++++++++++++++++---------------- 1 file changed, 295 insertions(+), 314 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1aeff264..c51933b3 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -14,7 +14,7 @@ import re from fastapi.staticfiles import StaticFiles import logging import argparse -from typing import List, Any, Optional, Union, Dict +from typing import List, Any, Optional, Dict from pydantic import BaseModel from lightrag import LightRAG, QueryParam from lightrag.types import GPTKeywordExtractionFormat @@ -34,6 +34,9 @@ from starlette.status import HTTP_403_FORBIDDEN import pipmaster as pm from dotenv import load_dotenv import configparser +import traceback +from datetime import datetime + from lightrag.utils import logger from .ollama_api import ( OllamaAPI, @@ -645,7 +648,6 @@ class InsertTextRequest(BaseModel): class InsertResponse(BaseModel): status: str message: str - document_count: int def get_api_key_dependency(api_key: Optional[str]): @@ -675,6 +677,7 @@ def get_api_key_dependency(api_key: Optional[str]): # Global configuration global_top_k = 60 # default value +temp_prefix = "__tmp_" # prefix for temporary files def create_app(args): @@ -1116,79 +1119,122 @@ def create_app(args): ("llm_response_cache", rag.llm_response_cache), ] - async def index_file(file_path: Union[str, Path]) -> None: - """Index all files inside the folder with support for multiple file formats + async def index_file(file_path: Path, description: Optional[str] = None): + """Index a file Args: - file_path: Path to the file to be indexed (str or Path object) - - Raises: - ValueError: If file format is not supported - FileNotFoundError: If file doesn't exist + file_path: Path to the saved file + description: Optional description of the file """ - if not pm.is_installed("aiofiles"): - pm.install("aiofiles") + try: + content = "" + ext = file_path.suffix.lower() - # Convert to Path object if string - file_path = Path(file_path) + file = None + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() - # Check if file exists - if not file_path.exists(): - raise FileNotFoundError(f"File not found: {file_path}") + # Process based on file type + match ext: + case ".txt" | ".md": + content = file.decode("utf-8") + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from PyPDF2 import PdfReader + from io import BytesIO - content = "" - # Get file extension in lowercase - ext = file_path.suffix.lower() + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO - match ext: - case ".txt" | ".md": - # Text files handling - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() + docx_content = await file.read() + docx_file = BytesIO(docx_content) + doc = Document(docx_file) + content = "\n".join( + [paragraph.text for paragraph in doc.paragraphs] + ) + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation # type: ignore + from io import BytesIO - case ".pdf" | ".docx" | ".pptx" | ".xlsx": - if not pm.is_installed("docling"): - pm.install("docling") - from docling.document_converter import DocumentConverter + pptx_content = await file.read() + pptx_file = BytesIO(pptx_content) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + case _: + logging.error( + f"Unsupported file type: {file_path.name} (extension {ext})" + ) + return - async def convert_doc(): - def sync_convert(): - converter = DocumentConverter() - result = converter.convert(file_path) - return result.document.export_to_markdown() + # Add description if provided + if description: + content = f"{description}\n\n{content}" - return await asyncio.to_thread(sync_convert) + # Insert into RAG system + if content: + await rag.ainsert(content) + logging.info( + f"Successfully processed and indexed file: {file_path.name}" + ) + else: + logging.error( + f"No content could be extracted from file: {file_path.name}" + ) - content = await convert_doc() + except Exception as e: + logging.error(f"Error indexing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + finally: + if file_path.name.startswith(temp_prefix): + # Clean up the temporary file after indexing + try: + file_path.unlink() + except Exception as e: + logging.error(f"Error deleting file {file_path}: {str(e)}") - case _: - raise ValueError(f"Unsupported file format: {ext}") + async def batch_index_files(file_paths: List[Path]): + """Index multiple files - # Insert content into RAG system - if content: - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Successfully indexed file: {file_path}") - else: - logging.warning(f"No content extracted from file: {file_path}") + Args: + file_paths: Paths to the files to index + """ + for file_path in file_paths: + await index_file(file_path) - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(background_tasks: BackgroundTasks): - """Trigger the scanning process""" - global scan_progress + async def save_temp_file(file: UploadFile = File(...)) -> Path: + """Save the uploaded file to a temporary location - with progress_lock: - if scan_progress["is_scanning"]: - return {"status": "already_scanning"} + Args: + file: The uploaded file - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 + Returns: + Path: The path to the saved file + """ + # Generate unique filename to avoid conflicts + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" - # Start the scanning process in the background - background_tasks.add_task(run_scanning_process) + # Create a temporary file to save the uploaded content + temp_path = doc_manager.input_dir / "temp" / unique_filename + temp_path.parent.mkdir(exist_ok=True) - return {"status": "scanning_started"} + # Save the file + with open(temp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + return temp_path async def run_scanning_process(): """Background task to scan and index documents""" @@ -1221,6 +1267,24 @@ def create_app(args): with progress_lock: scan_progress["is_scanning"] = False + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) + async def scan_for_new_documents(background_tasks: BackgroundTasks): + """Trigger the scanning process""" + global scan_progress + + with progress_lock: + if scan_progress["is_scanning"]: + return {"status": "already_scanning"} + + scan_progress["is_scanning"] = True + scan_progress["indexed_count"] = 0 + scan_progress["progress"] = 0 + + # Start the scanning process in the background + background_tasks.add_task(run_scanning_process) + + return {"status": "scanning_started"} + @app.get("/documents/scan-progress") async def get_scan_progress(): """Get the current scanning progress""" @@ -1228,7 +1292,9 @@ def create_app(args): return scan_progress @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): + async def upload_to_input_dir( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): """ Endpoint for uploading a file to the input directory and indexing it. @@ -1237,6 +1303,7 @@ def create_app(args): indexes it for retrieval, and returns a success status with relevant details. Parameters: + background_tasks: FastAPI BackgroundTasks for async processing file (UploadFile): The file to be uploaded. It must have an allowed extension as per `doc_manager.supported_extensions`. @@ -1261,15 +1328,178 @@ def create_app(args): with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) - # Immediately index the uploaded file - await index_file(file_path) + # Add to background tasks + background_tasks.add_task(index_file, file_path) - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } + return InsertResponse( + status="success", + message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", + ) except Exception as e: + logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_text( + request: InsertTextRequest, background_tasks: BackgroundTasks + ): + """ + Insert text into the Retrieval-Augmented Generation (RAG) system. + + This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. + + Args: + request (InsertTextRequest): The request body containing the text to be inserted. + background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. + """ + try: + background_tasks.add_task(rag.ainsert, request.text) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_file( + background_tasks: BackgroundTasks, + file: UploadFile = File(...), + description: str = Form(None), + ): + """Insert a file directly into the RAG system + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file: Uploaded file + description: Optional description of the file + + Returns: + InsertResponse: Status of the insertion operation + + Raises: + HTTPException: For unsupported file types or processing errors + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + # Create a temporary file to save the uploaded content + temp_path = save_temp_file(file) + + # Add to background tasks + background_tasks.add_task(index_file, temp_path, description) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' saved successfully. Processing will continue in background.", + ) + + except Exception as e: + logging.error(f"Error /documents/file: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_batch( + background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) + ): + """Process multiple files in batch mode + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + files: List of files to process + + Returns: + InsertResponse: Status of the batch insertion operation + + Raises: + HTTPException: For processing errors + """ + try: + inserted_count = 0 + failed_files = [] + temp_files = [] + + for file in files: + if doc_manager.is_supported_file(file.filename): + # Create a temporary file to save the uploaded content + temp_files.append(save_temp_file(file)) + inserted_count += 1 + else: + failed_files.append(f"{file.filename} (unsupported type)") + + if temp_files: + background_tasks.add_task(batch_index_files, temp_files) + + # Prepare status message + if inserted_count == len(files): + status = "success" + status_message = f"Successfully inserted all {inserted_count} documents" + elif inserted_count > 0: + status = "partial_success" + status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + else: + status = "failure" + status_message = "No documents were successfully inserted" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + + return InsertResponse(status=status, message=status_message) + + except Exception as e: + logging.error(f"Error /documents/batch: {file.filename}: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def clear_documents(): + """ + Clear all documents from the LightRAG system. + + This endpoint deletes all text chunks, entities vector database, and relationships vector database, + effectively clearing all documents from the LightRAG system. + + Returns: + InsertResponse: A response object containing the status, message, and the new document count (0 in this case). + """ + try: + rag.text_chunks = [] + rag.entities_vdb = None + rag.relationships_vdb = None + return InsertResponse( + status="success", message="All documents cleared successfully" + ) + except Exception as e: + logging.error(f"Error DELETE /documents: {str(e)}") + logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @app.post( @@ -1381,255 +1611,6 @@ def create_app(args): trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - """ - Insert text into the Retrieval-Augmented Generation (RAG) system. - - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextRequest): The request body containing the text to be inserted. - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. - """ - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=1, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - """Insert a file directly into the RAG system - - Args: - file: Uploaded file - description: Optional description of the file - - Returns: - InsertResponse: Status of the insertion operation - - Raises: - HTTPException: For unsupported file types or processing errors - """ - try: - content = "" - # Get file extension in lowercase - ext = Path(file.filename).suffix.lower() - - match ext: - case ".txt" | ".md": - # Text files handling - text_content = await file.read() - content = text_content.decode("utf-8") - - case ".pdf" | ".docx" | ".pptx" | ".xlsx": - if not pm.is_installed("docling"): - pm.install("docling") - from docling.document_converter import DocumentConverter - - # Create a temporary file to save the uploaded content - temp_path = Path("temp") / file.filename - temp_path.parent.mkdir(exist_ok=True) - - # Save the uploaded file - with temp_path.open("wb") as f: - f.write(await file.read()) - - try: - - async def convert_doc(): - def sync_convert(): - converter = DocumentConverter() - result = converter.convert(str(temp_path)) - return result.document.export_to_markdown() - - return await asyncio.to_thread(sync_convert) - - content = await convert_doc() - finally: - # Clean up the temporary file - temp_path.unlink() - - # Insert content into RAG system - if content: - # Add description if provided - if description: - content = f"{description}\n\n{content}" - - await rag.ainsert(content) - logging.info(f"Successfully indexed file: {file.filename}") - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - else: - raise HTTPException( - status_code=400, - detail="No content could be extracted from the file", - ) - - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - logging.error(f"Error processing file {file.filename}: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - """Process multiple files in batch mode - - Args: - files: List of files to process - - Returns: - InsertResponse: Status of the batch insertion operation - - Raises: - HTTPException: For processing errors - """ - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = "" - ext = Path(file.filename).suffix.lower() - - match ext: - case ".txt" | ".md": - text_content = await file.read() - content = text_content.decode("utf-8") - - case ".pdf": - if not pm.is_installed("pypdf2"): - pm.install("pypdf2") - from PyPDF2 import PdfReader - from io import BytesIO - - pdf_content = await file.read() - pdf_file = BytesIO(pdf_content) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - - case ".docx": - if not pm.is_installed("docx"): - pm.install("docx") - from docx import Document - from io import BytesIO - - docx_content = await file.read() - docx_file = BytesIO(docx_content) - doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) - - case ".pptx": - if not pm.is_installed("pptx"): - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO - - pptx_content = await file.read() - pptx_file = BytesIO(pptx_content) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - - case _: - failed_files.append(f"{file.filename} (unsupported type)") - continue - - if content: - await rag.ainsert(content) - inserted_count += 1 - logging.info(f"Successfully indexed file: {file.filename}") - else: - failed_files.append(f"{file.filename} (no content extracted)") - - except UnicodeDecodeError: - failed_files.append(f"{file.filename} (encoding error)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - logging.error(f"Error processing file {file.filename}: {str(e)}") - - # Prepare status message - if inserted_count == len(files): - status = "success" - status_message = f"Successfully inserted all {inserted_count} documents" - elif inserted_count > 0: - status = "partial_success" - status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - else: - status = "failure" - status_message = "No documents were successfully inserted" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status=status, - message=status_message, - document_count=inserted_count, - ) - - except Exception as e: - logging.error(f"Batch processing error: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - """ - Clear all documents from the LightRAG system. - - This endpoint deletes all text chunks, entities vector database, and relationships vector database, - effectively clearing all documents from the LightRAG system. - - Returns: - InsertResponse: A response object containing the status, message, and the new document count (0 in this case). - """ - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - # query all graph labels @app.get("/graph/label/list") async def get_graph_labels(): From 33a4f00b1d170dccd12fcd7823287db15a389f4d Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sun, 16 Feb 2025 01:10:43 +0800 Subject: [PATCH 09/12] index multiple files concurrently --- lightrag/api/lightrag_server.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index c51933b3..f1c92adf 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1206,13 +1206,18 @@ def create_app(args): logging.error(f"Error deleting file {file_path}: {str(e)}") async def batch_index_files(file_paths: List[Path]): - """Index multiple files + """Index multiple files concurrently Args: file_paths: Paths to the files to index """ - for file_path in file_paths: - await index_file(file_path) + if not file_paths: + return + if len(file_paths) == 1: + await index_file(file_paths[0]) + else: + tasks = [index_file(path) for path in file_paths] + await asyncio.gather(*tasks) async def save_temp_file(file: UploadFile = File(...)) -> Path: """Save the uploaded file to a temporary location From bbe24ab7ce4bcbc9600a2ef2a1cf89e39f041cd1 Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sun, 16 Feb 2025 21:11:05 +0800 Subject: [PATCH 10/12] enhance query and indexing with pipeline --- lightrag/api/lightrag_server.py | 196 +++++++++++++++++++++++--------- 1 file changed, 140 insertions(+), 56 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f1c92adf..be07a0f3 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -631,9 +631,47 @@ class SearchMode(str, Enum): class QueryRequest(BaseModel): query: str + + """Specifies the retrieval mode""" mode: SearchMode = SearchMode.hybrid - stream: bool = False - only_need_context: bool = False + + """If True, enables streaming output for real-time responses.""" + stream: Optional[bool] = None + + """If True, only returns the retrieved context without generating a response.""" + only_need_context: Optional[bool] = None + + """If True, only returns the generated prompt without producing a response.""" + only_need_prompt: Optional[bool] = None + + """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" + response_type: Optional[str] = None + + """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + top_k: Optional[int] = None + + """Maximum number of tokens allowed for each retrieved text chunk.""" + max_token_for_text_unit: Optional[int] = None + + """Maximum number of tokens allocated for relationship descriptions in global retrieval.""" + max_token_for_global_context: Optional[int] = None + + """Maximum number of tokens allocated for entity descriptions in local retrieval.""" + max_token_for_local_context: Optional[int] = None + + """List of high-level keywords to prioritize in retrieval.""" + hl_keywords: Optional[List[str]] = None + + """List of low-level keywords to refine retrieval focus.""" + ll_keywords: Optional[List[str]] = None + + """Stores past conversation history to maintain context. + Format: [{"role": "user/assistant", "content": "message"}]. + """ + conversation_history: Optional[List[dict[str, Any]]] = None + + """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" + history_turns: Optional[int] = None class QueryResponse(BaseModel): @@ -642,7 +680,6 @@ class QueryResponse(BaseModel): class InsertTextRequest(BaseModel): text: str - description: Optional[str] = None class InsertResponse(BaseModel): @@ -650,6 +687,33 @@ class InsertResponse(BaseModel): message: str +def QueryRequestToQueryParams(request: QueryRequest): + param = QueryParam(mode=request.mode, stream=request.stream) + if request.only_need_context is not None: + param.only_need_context = request.only_need_context + if request.only_need_prompt is not None: + param.only_need_prompt = request.only_need_prompt + if request.response_type is not None: + param.response_type = request.response_type + if request.top_k is not None: + param.top_k = request.top_k + if request.max_token_for_text_unit is not None: + param.max_token_for_text_unit = request.max_token_for_text_unit + if request.max_token_for_global_context is not None: + param.max_token_for_global_context = request.max_token_for_global_context + if request.max_token_for_local_context is not None: + param.max_token_for_local_context = request.max_token_for_local_context + if request.hl_keywords is not None: + param.hl_keywords = request.hl_keywords + if request.ll_keywords is not None: + param.ll_keywords = request.ll_keywords + if request.conversation_history is not None: + param.conversation_history = request.conversation_history + if request.history_turns is not None: + param.history_turns = request.history_turns + return param + + def get_api_key_dependency(api_key: Optional[str]): if not api_key: # If no API key is configured, return a dummy dependency that always succeeds @@ -661,7 +725,9 @@ def get_api_key_dependency(api_key: Optional[str]): # If API key is configured, use proper authentication api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + async def api_key_auth( + api_key_header_value: Optional[str] = Security(api_key_header), + ): if not api_key_header_value: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="API Key required" @@ -1119,12 +1185,13 @@ def create_app(args): ("llm_response_cache", rag.llm_response_cache), ] - async def index_file(file_path: Path, description: Optional[str] = None): - """Index a file + async def pipeline_enqueue_file(file_path: Path) -> bool: + """Add a file to the queue for processing Args: file_path: Path to the saved file - description: Optional description of the file + Returns: + bool: True if the file was successfully enqueued, False otherwise """ try: content = "" @@ -1177,25 +1244,24 @@ def create_app(args): logging.error( f"Unsupported file type: {file_path.name} (extension {ext})" ) - return + return False - # Add description if provided - if description: - content = f"{description}\n\n{content}" - - # Insert into RAG system + # Insert into the RAG queue if content: - await rag.ainsert(content) + await rag.apipeline_enqueue_documents(content) logging.info( - f"Successfully processed and indexed file: {file_path.name}" + f"Successfully processed and enqueued file: {file_path.name}" ) + return True else: logging.error( f"No content could be extracted from file: {file_path.name}" ) except Exception as e: - logging.error(f"Error indexing file {file_path.name}: {str(e)}") + logging.error( + f"Error processing or enqueueing file {file_path.name}: {str(e)}" + ) logging.error(traceback.format_exc()) finally: if file_path.name.startswith(temp_prefix): @@ -1204,8 +1270,23 @@ def create_app(args): file_path.unlink() except Exception as e: logging.error(f"Error deleting file {file_path}: {str(e)}") + return False - async def batch_index_files(file_paths: List[Path]): + async def pipeline_index_file(file_path: Path): + """Index a file + + Args: + file_path: Path to the saved file + """ + try: + if await pipeline_enqueue_file(file_path): + await rag.apipeline_process_enqueue_documents() + + except Exception as e: + logging.error(f"Error indexing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + + async def pipeline_index_files(file_paths: List[Path]): """Index multiple files concurrently Args: @@ -1213,11 +1294,31 @@ def create_app(args): """ if not file_paths: return - if len(file_paths) == 1: - await index_file(file_paths[0]) - else: - tasks = [index_file(path) for path in file_paths] - await asyncio.gather(*tasks) + try: + enqueued = False + + if len(file_paths) == 1: + enqueued = await pipeline_enqueue_file(file_paths[0]) + else: + tasks = [pipeline_enqueue_file(path) for path in file_paths] + enqueued = any(await asyncio.gather(*tasks)) + + if enqueued: + await rag.apipeline_process_enqueue_documents() + except Exception as e: + logging.error(f"Error indexing files: {str(e)}") + logging.error(traceback.format_exc()) + + async def pipeline_index_texts(texts: List[str]): + """Index a list of texts + + Args: + texts: The texts to index + """ + if not texts: + return + await rag.apipeline_enqueue_documents(texts) + await rag.apipeline_process_enqueue_documents() async def save_temp_file(file: UploadFile = File(...)) -> Path: """Save the uploaded file to a temporary location @@ -1254,7 +1355,7 @@ def create_app(args): with progress_lock: scan_progress["current_file"] = os.path.basename(file_path) - await index_file(file_path) + await pipeline_index_file(file_path) with progress_lock: scan_progress["indexed_count"] += 1 @@ -1334,7 +1435,7 @@ def create_app(args): shutil.copyfileobj(file.file, buffer) # Add to background tasks - background_tasks.add_task(index_file, file_path) + background_tasks.add_task(pipeline_index_file, file_path) return InsertResponse( status="success", @@ -1366,7 +1467,7 @@ def create_app(args): InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. """ try: - background_tasks.add_task(rag.ainsert, request.text) + background_tasks.add_task(pipeline_index_texts, [request.text]) return InsertResponse( status="success", message="Text successfully received. Processing will continue in background.", @@ -1382,16 +1483,13 @@ def create_app(args): dependencies=[Depends(optional_api_key)], ) async def insert_file( - background_tasks: BackgroundTasks, - file: UploadFile = File(...), - description: str = Form(None), + background_tasks: BackgroundTasks, file: UploadFile = File(...) ): """Insert a file directly into the RAG system Args: background_tasks: FastAPI BackgroundTasks for async processing file: Uploaded file - description: Optional description of the file Returns: InsertResponse: Status of the insertion operation @@ -1410,7 +1508,7 @@ def create_app(args): temp_path = save_temp_file(file) # Add to background tasks - background_tasks.add_task(index_file, temp_path, description) + background_tasks.add_task(pipeline_index_file, temp_path) return InsertResponse( status="success", @@ -1456,7 +1554,7 @@ def create_app(args): failed_files.append(f"{file.filename} (unsupported type)") if temp_files: - background_tasks.add_task(batch_index_files, temp_files) + background_tasks.add_task(pipeline_index_files, temp_files) # Prepare status message if inserted_count == len(files): @@ -1515,12 +1613,7 @@ def create_app(args): Handle a POST request at the /query endpoint to process user queries using RAG capabilities. Parameters: - request (QueryRequest): A Pydantic model containing the following fields: - - query (str): The text of the user's query. - - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation. - - stream (bool): Optional. Determines if the response should be streamed. - - only_need_context (bool): Optional. If true, returns only the context without further processing. - + request (QueryRequest): The request object containing the query parameters. Returns: QueryResponse: A Pydantic model containing the result of the query processing. If a string is returned (e.g., cache hit), it's directly returned. @@ -1532,13 +1625,7 @@ def create_app(args): """ try: response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=request.stream, - only_need_context=request.only_need_context, - top_k=global_top_k, - ), + request.query, param=QueryRequestToQueryParams(request) ) # If response is a string (e.g. cache hit), return directly @@ -1546,16 +1633,16 @@ def create_app(args): return QueryResponse(response=response) # If it's an async generator, decide whether to stream based on stream parameter - if request.stream: + if request.stream or hasattr(response, "__aiter__"): result = "" async for chunk in response: result += chunk return QueryResponse(response=result) + elif isinstance(response, dict): + result = json.dumps(response, indent=2) + return QueryResponse(response=result) else: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) + return QueryResponse(response=str(response)) except Exception as e: trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) @@ -1573,14 +1660,11 @@ def create_app(args): StreamingResponse: A streaming response containing the RAG query results. """ try: + params = QueryRequestToQueryParams(request) + + params.stream = True response = await rag.aquery( # Use aquery instead of query, and add await - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - top_k=global_top_k, - ), + request.query, param=params ) from fastapi.responses import StreamingResponse From b580e473249b9ab5fabaa7eda39b9f104c390519 Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sun, 16 Feb 2025 21:20:43 +0800 Subject: [PATCH 11/12] format --- lightrag/api/lightrag_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 98990d26..a392e67a 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -3,7 +3,6 @@ from fastapi import ( HTTPException, File, UploadFile, - Form, BackgroundTasks, ) import asyncio From 893b6455068a70b0716d1db2ee97aa264be2b31c Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sun, 16 Feb 2025 21:28:58 +0800 Subject: [PATCH 12/12] unify doc status retrieval with get_docs_by_status --- lightrag/base.py | 18 ++++------------ lightrag/kg/json_doc_status_impl.py | 32 +++++------------------------ lightrag/kg/mongo_impl.py | 18 +--------------- lightrag/kg/postgres_impl.py | 18 +--------------- lightrag/lightrag.py | 10 ++++----- 5 files changed, 16 insertions(+), 80 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 3d4fc022..d9a63d26 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage): """Get counts of documents in each status""" raise NotImplementedError - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - raise NotImplementedError - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - raise NotImplementedError - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - raise NotImplementedError - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" raise NotImplementedError async def update_doc_status(self, data: dict[str, Any]) -> None: diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index fad03acc..ed79a370 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage): counts[doc["status"]] += 1 return counts - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """all documents with a specific status""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() - if v["status"] == DocStatus.FAILED - } - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PENDING - } - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processed documents""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSED - } - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSING + if v["status"] == status } async def index_done_callback(self): diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index c216e7be..f6326b76 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage): async def get_docs_by_status( self, status: DocStatus ) -> dict[str, DocProcessingStatus]: - """Get all documents by status""" + """Get all documents with a specific status""" cursor = self._data.find({"status": status.value}) result = await cursor.to_list() return { @@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage): for doc in result } - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - @dataclass class MongoGraphStorage(BaseGraphStorage): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a44aefe7..51b25385 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage): async def get_docs_by_status( self, status: DocStatus ) -> Dict[str, DocProcessingStatus]: - """Get all documents by status""" + """all documents with a specific status""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.db.workspace, "status": status} result = await self.db.query(sql, params, True) @@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage): for element in result } - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self): """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" logger.info("Doc status had been saved into postgresql db!") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 23c3df80..9909b4b7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -89,7 +89,7 @@ STORAGE_IMPLEMENTATIONS = { "PGDocStatusStorage", "MongoDocStatusStorage", ], - "required_methods": ["get_pending_docs"], + "required_methods": ["get_docs_by_status"], }, } @@ -230,7 +230,7 @@ class LightRAG: """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" working_dir: str = field( - default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}' + default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) """Directory where cache and temporary files are stored.""" @@ -715,11 +715,11 @@ class LightRAG: # 1. Get all pending, failed, and abnormally terminated processing documents. to_process_docs: dict[str, DocProcessingStatus] = {} - processing_docs = await self.doc_status.get_processing_docs() + processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING) to_process_docs.update(processing_docs) - failed_docs = await self.doc_status.get_failed_docs() + failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED) to_process_docs.update(failed_docs) - pendings_docs = await self.doc_status.get_pending_docs() + pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING) to_process_docs.update(pendings_docs) if not to_process_docs: