diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8b35e34b..e2160b12 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -13,13 +13,12 @@ import re from fastapi.staticfiles import StaticFiles import logging import argparse -from typing import List, Any, Optional, Dict +from typing import List, Any, Literal, Optional, Dict from pydantic import BaseModel from lightrag import LightRAG, QueryParam from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc -from enum import Enum from pathlib import Path import shutil import aiofiles @@ -626,23 +625,11 @@ class DocumentManager: return any(filename.lower().endswith(ext) for ext in self.supported_extensions) -# LightRAG query mode -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - mix = "mix" - - class QueryRequest(BaseModel): query: str """Specifies the retrieval mode""" - mode: SearchMode = SearchMode.hybrid - - """If True, enables streaming output for real-time responses.""" - stream: Optional[bool] = None + mode: Literal["local", "global", "hybrid", "naive", "mix"] = "hybrid" """If True, only returns the retrieved context without generating a response.""" only_need_context: Optional[bool] = None @@ -688,13 +675,18 @@ class InsertTextRequest(BaseModel): text: str +class InsertTextsRequest(BaseModel): + texts: list[str] + + class InsertResponse(BaseModel): status: str message: str -def QueryRequestToQueryParams(request: QueryRequest): - param = QueryParam(mode=request.mode, stream=request.stream) +def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool): + param = QueryParam(mode=request.mode, stream=is_stream) + if request.only_need_context is not None: param.only_need_context = request.only_need_context if request.only_need_prompt is not None: @@ -1523,7 +1515,7 @@ def create_app(args): logging.error(f"Error /documents/text: {str(e)}") logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - + @app.post( "/documents/file", response_model=InsertResponse, @@ -1653,9 +1645,7 @@ def create_app(args): raise HTTPException(status_code=500, detail=str(e)) @app.post( - "/query", - response_model=QueryResponse, - dependencies=[Depends(optional_api_key)] + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] ) async def query_text(request: QueryRequest): """ @@ -1674,7 +1664,7 @@ def create_app(args): """ try: response = await rag.aquery( - request.query, param=QueryRequestToQueryParams(request) + request.query, param=QueryRequestToQueryParams(request, False) ) # If response is a string (e.g. cache hit), return directly @@ -1703,9 +1693,8 @@ def create_app(args): StreamingResponse: A streaming response containing the RAG query results. """ try: - params = QueryRequestToQueryParams(request) + params = QueryRequestToQueryParams(request, True) - params.stream = True response = await rag.aquery( # Use aquery instead of query, and add await request.query, param=params )