diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index f1c92adf..be07a0f3 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -631,9 +631,47 @@ class SearchMode(str, Enum):
 
 class QueryRequest(BaseModel):
     query: str
+
+    """Specifies the retrieval mode"""
     mode: SearchMode = SearchMode.hybrid
-    stream: bool = False
-    only_need_context: bool = False
+
+    """If True, enables streaming output for real-time responses."""
+    stream: Optional[bool] = None
+
+    """If True, only returns the retrieved context without generating a response."""
+    only_need_context: Optional[bool] = None
+
+    """If True, only returns the generated prompt without producing a response."""
+    only_need_prompt: Optional[bool] = None
+
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+    response_type: Optional[str] = None
+
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+    top_k: Optional[int] = None
+
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+    max_token_for_text_unit: Optional[int] = None
+
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+    max_token_for_global_context: Optional[int] = None
+
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+    max_token_for_local_context: Optional[int] = None
+
+    """List of high-level keywords to prioritize in retrieval."""
+    hl_keywords: Optional[List[str]] = None
+
+    """List of low-level keywords to refine retrieval focus."""
+    ll_keywords: Optional[List[str]] = None
+
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+    conversation_history: Optional[List[dict[str, Any]]] = None
+
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
+    history_turns: Optional[int] = None
 
 
 class QueryResponse(BaseModel):
@@ -642,7 +680,6 @@ class QueryResponse(BaseModel):
 
 class InsertTextRequest(BaseModel):
     text: str
-    description: Optional[str] = None
 
 
 class InsertResponse(BaseModel):
@@ -650,6 +687,33 @@ class InsertResponse(BaseModel):
     status: str
     message: str
 
+def QueryRequestToQueryParams(request: QueryRequest):
+    param = QueryParam(mode=request.mode, stream=request.stream)
+    if request.only_need_context is not None:
+        param.only_need_context = request.only_need_context
+    if request.only_need_prompt is not None:
+        param.only_need_prompt = request.only_need_prompt
+    if request.response_type is not None:
+        param.response_type = request.response_type
+    if request.top_k is not None:
+        param.top_k = request.top_k
+    if request.max_token_for_text_unit is not None:
+        param.max_token_for_text_unit = request.max_token_for_text_unit
+    if request.max_token_for_global_context is not None:
+        param.max_token_for_global_context = request.max_token_for_global_context
+    if request.max_token_for_local_context is not None:
+        param.max_token_for_local_context = request.max_token_for_local_context
+    if request.hl_keywords is not None:
+        param.hl_keywords = request.hl_keywords
+    if request.ll_keywords is not None:
+        param.ll_keywords = request.ll_keywords
+    if request.conversation_history is not None:
+        param.conversation_history = request.conversation_history
+    if request.history_turns is not None:
+        param.history_turns = request.history_turns
+    return param
+
+
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
@@ -661,7 +725,9 @@ def get_api_key_dependency(api_key: Optional[str]):
     # If API key is configured, use proper authentication
     api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
 
-    async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
+    async def api_key_auth(
+        api_key_header_value: Optional[str] = Security(api_key_header),
+    ):
         if not api_key_header_value:
             raise HTTPException(
                 status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -1119,12 +1185,13 @@ def create_app(args):
             ("llm_response_cache", rag.llm_response_cache),
         ]
 
-    async def index_file(file_path: Path, description: Optional[str] = None):
-        """Index a file
+    async def pipeline_enqueue_file(file_path: Path) -> bool:
+        """Add a file to the queue for processing
 
         Args:
             file_path: Path to the saved file
-            description: Optional description of the file
+        Returns:
+            bool: True if the file was successfully enqueued, False otherwise
         """
         try:
             content = ""
@@ -1177,25 +1244,24 @@ def create_app(args):
                 logging.error(
                     f"Unsupported file type: {file_path.name} (extension {ext})"
                 )
-                return
+                return False
 
-            # Add description if provided
-            if description:
-                content = f"{description}\n\n{content}"
-
-            # Insert into RAG system
+            # Insert into the RAG queue
             if content:
-                await rag.ainsert(content)
+                await rag.apipeline_enqueue_documents(content)
                 logging.info(
-                    f"Successfully processed and indexed file: {file_path.name}"
+                    f"Successfully processed and enqueued file: {file_path.name}"
                 )
+                return True
             else:
                 logging.error(
                     f"No content could be extracted from file: {file_path.name}"
                 )
 
         except Exception as e:
-            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(
+                f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+            )
             logging.error(traceback.format_exc())
         finally:
             if file_path.name.startswith(temp_prefix):
@@ -1204,8 +1270,23 @@ def create_app(args):
                     file_path.unlink()
                 except Exception as e:
                     logging.error(f"Error deleting file {file_path}: {str(e)}")
+        return False
 
-    async def batch_index_files(file_paths: List[Path]):
+    async def pipeline_index_file(file_path: Path):
+        """Index a file
+
+        Args:
+            file_path: Path to the saved file
+        """
+        try:
+            if await pipeline_enqueue_file(file_path):
+                await rag.apipeline_process_enqueue_documents()
+
+        except Exception as e:
+            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_files(file_paths: List[Path]):
         """Index multiple files concurrently
 
         Args:
@@ -1213,11 +1294,31 @@ def create_app(args):
         """
         if not file_paths:
             return
-        if len(file_paths) == 1:
-            await index_file(file_paths[0])
-        else:
-            tasks = [index_file(path) for path in file_paths]
-            await asyncio.gather(*tasks)
+        try:
+            enqueued = False
+
+            if len(file_paths) == 1:
+                enqueued = await pipeline_enqueue_file(file_paths[0])
+            else:
+                tasks = [pipeline_enqueue_file(path) for path in file_paths]
+                enqueued = any(await asyncio.gather(*tasks))
+
+            if enqueued:
+                await rag.apipeline_process_enqueue_documents()
+        except Exception as e:
+            logging.error(f"Error indexing files: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_texts(texts: List[str]):
+        """Index a list of texts
+
+        Args:
+            texts: The texts to index
+        """
+        if not texts:
+            return
+        await rag.apipeline_enqueue_documents(texts)
+        await rag.apipeline_process_enqueue_documents()
 
     async def save_temp_file(file: UploadFile = File(...)) -> Path:
         """Save the uploaded file to a temporary location
@@ -1254,7 +1355,7 @@ def create_app(args):
                     with progress_lock:
                         scan_progress["current_file"] = os.path.basename(file_path)
 
-                    await index_file(file_path)
+                    await pipeline_index_file(file_path)
 
                     with progress_lock:
                         scan_progress["indexed_count"] += 1
@@ -1334,7 +1435,7 @@ def create_app(args):
                shutil.copyfileobj(file.file, buffer)
 
            # Add to background tasks
-           background_tasks.add_task(index_file, file_path)
+           background_tasks.add_task(pipeline_index_file, file_path)
 
            return InsertResponse(
                status="success",
@@ -1366,7 +1467,7 @@ def create_app(args):
            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
        """
        try:
-           background_tasks.add_task(rag.ainsert, request.text)
+           background_tasks.add_task(pipeline_index_texts, [request.text])
            return InsertResponse(
                status="success",
                message="Text successfully received. Processing will continue in background.",
@@ -1382,16 +1483,13 @@ def create_app(args):
        dependencies=[Depends(optional_api_key)],
    )
    async def insert_file(
-       background_tasks: BackgroundTasks,
-       file: UploadFile = File(...),
-       description: str = Form(None),
+       background_tasks: BackgroundTasks, file: UploadFile = File(...)
    ):
        """Insert a file directly into the RAG system
 
        Args:
            background_tasks: FastAPI BackgroundTasks for async processing
            file: Uploaded file
-           description: Optional description of the file
 
        Returns:
            InsertResponse: Status of the insertion operation
@@ -1410,7 +1508,7 @@ def create_app(args):
            temp_path = save_temp_file(file)
 
            # Add to background tasks
-           background_tasks.add_task(index_file, temp_path, description)
+           background_tasks.add_task(pipeline_index_file, temp_path)
 
            return InsertResponse(
                status="success",
@@ -1456,7 +1554,7 @@ def create_app(args):
                    failed_files.append(f"{file.filename} (unsupported type)")
 
            if temp_files:
-               background_tasks.add_task(batch_index_files, temp_files)
+               background_tasks.add_task(pipeline_index_files, temp_files)
 
            # Prepare status message
            if inserted_count == len(files):
@@ -1515,12 +1613,7 @@ def create_app(args):
        Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
 
        Parameters:
-           request (QueryRequest): A Pydantic model containing the following fields:
-               - query (str): The text of the user's query.
-               - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
-               - stream (bool): Optional. Determines if the response should be streamed.
-               - only_need_context (bool): Optional. If true, returns only the context without further processing.
-
+           request (QueryRequest): The request object containing the query parameters.
        Returns:
            QueryResponse: A Pydantic model containing the result of the query processing.
                       If a string is returned (e.g., cache hit), it's directly returned.
@@ -1532,13 +1625,7 @@ def create_app(args):
        """
        try:
            response = await rag.aquery(
-               request.query,
-               param=QueryParam(
-                   mode=request.mode,
-                   stream=request.stream,
-                   only_need_context=request.only_need_context,
-                   top_k=global_top_k,
-               ),
+               request.query, param=QueryRequestToQueryParams(request)
            )
 
            # If response is a string (e.g. cache hit), return directly
@@ -1546,16 +1633,16 @@ def create_app(args):
                return QueryResponse(response=response)
 
            # If it's an async generator, decide whether to stream based on stream parameter
-           if request.stream:
+           if request.stream or hasattr(response, "__aiter__"):
                result = ""
                async for chunk in response:
                    result += chunk
                return QueryResponse(response=result)
+           elif isinstance(response, dict):
+               result = json.dumps(response, indent=2)
+               return QueryResponse(response=result)
            else:
-               result = ""
-               async for chunk in response:
-                   result += chunk
-               return QueryResponse(response=result)
+               return QueryResponse(response=str(response))
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e))
@@ -1573,14 +1660,11 @@ def create_app(args):
            StreamingResponse: A streaming response containing the RAG query results.
        """
        try:
+           params = QueryRequestToQueryParams(request)
+
+           params.stream = True
            response = await rag.aquery(  # Use aquery instead of query, and add await
-               request.query,
-               param=QueryParam(
-                   mode=request.mode,
-                   stream=True,
-                   only_need_context=request.only_need_context,
-                   top_k=global_top_k,
-               ),
+               request.query, param=params
            )
 
            from fastapi.responses import StreamingResponse