diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 3928d065..7a50a512 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -13,14 +13,13 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Dict
-from pydantic import BaseModel
+from typing import List, Any, Literal, Optional, Dict
+from pydantic import BaseModel, Field, field_validator
 from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from lightrag.base import DocStatus, DocProcessingStatus
-from enum import Enum
 from pathlib import Path
 import shutil
 import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
 
 
-# LightRAG query mode
-class SearchMode(str, Enum):
-    naive = "naive"
-    local = "local"
-    global_ = "global"
-    hybrid = "hybrid"
-    mix = "mix"
-
-
 class QueryRequest(BaseModel):
-    query: str
+    query: str = Field(
+        min_length=1,
+        description="The query text",
+    )
 
-    """Specifies the retrieval mode"""
-    mode: SearchMode = SearchMode.hybrid
+    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
+        default="hybrid",
+        description="Query mode",
+    )
 
-    """If True, enables streaming output for real-time responses."""
-    stream: Optional[bool] = None
+    only_need_context: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the retrieved context without generating a response.",
+    )
 
-    """If True, only returns the retrieved context without generating a response."""
-    only_need_context: Optional[bool] = None
+    only_need_prompt: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the generated prompt without producing a response.",
+    )
 
-    """If True, only returns the generated prompt without producing a response."""
-    only_need_prompt: Optional[bool] = None
+    response_type: Optional[str] = Field(
+        min_length=1,
+        default=None,
+        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
+    )
 
-    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
-    response_type: Optional[str] = None
+    top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
+    )
 
-    """Number of top items to retrieve.
-    Represents entities in 'local' mode and relationships in 'global' mode."""
-    top_k: Optional[int] = None
+    max_token_for_text_unit: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allowed for each retrieved text chunk.",
+    )
 
-    """Maximum number of tokens allowed for each retrieved text chunk."""
-    max_token_for_text_unit: Optional[int] = None
+    max_token_for_global_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
+    )
 
-    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
-    max_token_for_global_context: Optional[int] = None
+    max_token_for_local_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
+    )
 
-    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
-    max_token_for_local_context: Optional[int] = None
+    hl_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of high-level keywords to prioritize in retrieval.",
+    )
 
-    """List of high-level keywords to prioritize in retrieval."""
-    hl_keywords: Optional[List[str]] = None
+    ll_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of low-level keywords to refine retrieval focus.",
+    )
 
-    """List of low-level keywords to refine retrieval focus."""
-    ll_keywords: Optional[List[str]] = None
+    conversation_history: Optional[List[dict[str, Any]]] = Field(
+        default=None,
+        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
+    )
 
-    """Stores past conversation history to maintain context.
-    Format: [{"role": "user/assistant", "content": "message"}].
-    """
-    conversation_history: Optional[List[dict[str, Any]]] = None
+    history_turns: Optional[int] = Field(
+        ge=0,
+        default=None,
+        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
+    )
 
-    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
-    history_turns: Optional[int] = None
+    @field_validator("query", mode="after")
+    @classmethod
+    def query_strip_after(cls, query: str) -> str:
+        return query.strip()
+
+    @field_validator("hl_keywords", mode="after")
+    @classmethod
+    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
+        if hl_keywords is None:
+            return None
+        return [keyword.strip() for keyword in hl_keywords]
+
+    @field_validator("ll_keywords", mode="after")
+    @classmethod
+    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
+        if ll_keywords is None:
+            return None
+        return [keyword.strip() for keyword in ll_keywords]
+
+    @field_validator("conversation_history", mode="after")
+    @classmethod
+    def conversation_history_role_check(
+        cls, conversation_history: List[dict[str, Any]] | None
+    ) -> List[dict[str, Any]] | None:
+        if conversation_history is None:
+            return None
+        for msg in conversation_history:
+            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
+                raise ValueError(
+                    "Each message must have a 'role' key with value 'user' or 'assistant'."
+                )
+        return conversation_history
+
+    def to_query_params(self, is_stream: bool) -> QueryParam:
+        """Converts a QueryRequest instance into a QueryParam instance."""
+        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
+        request_data = self.model_dump(exclude_none=True, exclude={"query"})
+
+        # Ensure `mode` and `stream` are set explicitly
+        param = QueryParam(**request_data)
+        param.stream = is_stream
+        return param
 
 
 class QueryResponse(BaseModel):
-    response: str
+    response: str = Field(
+        description="The generated response",
+    )
 
 
 class InsertTextRequest(BaseModel):
-    text: str
+    text: str = Field(
+        min_length=1,
+        description="The text to insert",
+    )
+
+    @field_validator("text", mode="after")
+    @classmethod
+    def strip_after(cls, text: str) -> str:
+        return text.strip()
+
+
+class InsertTextsRequest(BaseModel):
+    texts: list[str] = Field(
+        min_length=1,
+        description="The texts to insert",
+    )
+
+    @field_validator("texts", mode="after")
+    @classmethod
+    def strip_after(cls, texts: list[str]) -> list[str]:
+        return [text.strip() for text in texts]
 
 
 class InsertResponse(BaseModel):
-    status: str
-    message: str
+    status: str = Field(description="Status of the operation")
+    message: str = Field(description="Message describing the operation result")
 
 
 class DocStatusResponse(BaseModel):
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
     statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
 
 
-def QueryRequestToQueryParams(request: QueryRequest):
-    param = QueryParam(mode=request.mode, stream=request.stream)
-    if request.only_need_context is not None:
-        param.only_need_context = request.only_need_context
-    if request.only_need_prompt is not None:
-        param.only_need_prompt = request.only_need_prompt
-    if request.response_type is not None:
-        param.response_type = request.response_type
-    if request.top_k is not None:
-        param.top_k = request.top_k
-    if request.max_token_for_text_unit is not None:
-        param.max_token_for_text_unit = request.max_token_for_text_unit
-    if request.max_token_for_global_context is not None:
-        param.max_token_for_global_context = request.max_token_for_global_context
-    if request.max_token_for_local_context is not None:
-        param.max_token_for_local_context = request.max_token_for_local_context
-    if request.hl_keywords is not None:
-        param.hl_keywords = request.hl_keywords
-    if request.ll_keywords is not None:
-        param.ll_keywords = request.ll_keywords
-    if request.conversation_history is not None:
-        param.conversation_history = request.conversation_history
-    if request.history_turns is not None:
-        param.history_turns = request.history_turns
-    return param
-
-
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
@@ -1525,6 +1581,37 @@ def create_app(args):
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
+    @app.post(
+        "/documents/texts",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_texts(
+        request: InsertTextsRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert texts into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextsRequest): The request body containing the texts to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing.
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation and a message.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, request.texts)
+            return InsertResponse(
+                status="success",
+                message="Text successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/texts: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.post(
         "/documents/file",
         response_model=InsertResponse,
@@ -1569,7 +1656,7 @@ def create_app(args):
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.post(
-        "/documents/batch",
+        "/documents/file_batch",
        response_model=InsertResponse,
         dependencies=[Depends(optional_api_key)],
     )
@@ -1673,20 +1760,14 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query, param=QueryRequestToQueryParams(request)
+                request.query, param=request.to_query_params(False)
             )
 
             # If response is a string (e.g. cache hit), return directly
             if isinstance(response, str):
                 return QueryResponse(response=response)
 
-            # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream or hasattr(response, "__aiter__"):
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
-            elif isinstance(response, dict):
+            if isinstance(response, dict):
                 result = json.dumps(response, indent=2)
                 return QueryResponse(response=result)
             else:
@@ -1708,11 +1789,8 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
-            params = QueryRequestToQueryParams(request)
-
-            params.stream = True
-            response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query, param=params
+            response = await rag.aquery(
+                request.query, param=request.to_query_params(True)
             )
 
             from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@ def create_app(args):
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                 },
             )
         except Exception as e:
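The refactor in brief: `stream` is gone from `QueryRequest`, so streaming is now decided by the endpoint (`/query` builds params with `to_query_params(False)`, `/query/stream` with `to_query_params(True)`), and input checking moved into the Pydantic model itself. A minimal client-side sketch of the resulting behavior, assuming the server listens on http://localhost:9621 with no API key configured; the address and payloads are illustrative, not taken from this diff:

```python
# Client-side sketch of the endpoints touched by this diff.
# Assumptions not established by the diff: the server address and the
# absence of an API key; payload contents are made up for illustration.
import requests

BASE = "http://localhost:9621"  # assumed host/port

# New endpoint: insert a batch of texts. InsertTextsRequest strips each
# entry and rejects an empty list (min_length=1 on `texts`).
resp = requests.post(
    f"{BASE}/documents/texts",
    json={"texts": ["  LightRAG indexes documents for graph-based retrieval.  "]},
)
resp.raise_for_status()
print(resp.json())  # {"status": "success", "message": "Text successfully received. ..."}

# /query is validated by the refactored QueryRequest: `mode` must be one of
# "local", "global", "hybrid", "naive", "mix"; `query` is stripped and must
# be non-empty; `top_k` must be >= 1.
resp = requests.post(
    f"{BASE}/query",
    json={"query": "  What does LightRAG store?  ", "mode": "hybrid", "top_k": 10},
)
resp.raise_for_status()
print(resp.json()["response"])

# Malformed input now fails fast with a 422 validation error instead of
# reaching the RAG pipeline, e.g. a conversation_history entry whose role
# is neither "user" nor "assistant".
resp = requests.post(
    f"{BASE}/query",
    json={
        "query": "hello",
        "conversation_history": [{"role": "system", "content": "hi"}],
    },
)
print(resp.status_code)  # 422
```

Because `to_query_params` forwards only non-None fields via `model_dump(exclude_none=True)`, any field omitted from the request falls back to its `QueryParam` default rather than being passed explicitly.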