Merge pull request #798 from YanSte/api_improvment

API: insert-texts endpoint, streaming improvements, and naming cleanup
Yannick Stephan, 2025-02-17 11:54:23 +01:00, committed by GitHub


@@ -13,14 +13,13 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Dict
+from typing import List, Any, Literal, Optional, Dict
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, field_validator
 from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from lightrag.base import DocStatus, DocProcessingStatus
-from enum import Enum
 from pathlib import Path
 import shutil
 import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)

-# LightRAG query mode
-class SearchMode(str, Enum):
-    naive = "naive"
-    local = "local"
-    global_ = "global"
-    hybrid = "hybrid"
-    mix = "mix"
 class QueryRequest(BaseModel):
-    query: str
+    query: str = Field(
+        min_length=1,
+        description="The query text",
+    )

-    """Specifies the retrieval mode"""
-    mode: SearchMode = SearchMode.hybrid
+    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
+        default="hybrid",
+        description="Query mode",
+    )

-    """If True, enables streaming output for real-time responses."""
-    stream: Optional[bool] = None

-    """If True, only returns the retrieved context without generating a response."""
-    only_need_context: Optional[bool] = None
+    only_need_context: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the retrieved context without generating a response.",
+    )

-    """If True, only returns the generated prompt without producing a response."""
-    only_need_prompt: Optional[bool] = None
+    only_need_prompt: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the generated prompt without producing a response.",
+    )

-    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
-    response_type: Optional[str] = None
+    response_type: Optional[str] = Field(
+        min_length=1,
+        default=None,
+        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
+    )

-    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
-    top_k: Optional[int] = None
+    top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
+    )

-    """Maximum number of tokens allowed for each retrieved text chunk."""
-    max_token_for_text_unit: Optional[int] = None
+    max_token_for_text_unit: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allowed for each retrieved text chunk.",
+    )

-    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
-    max_token_for_global_context: Optional[int] = None
+    max_token_for_global_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
+    )

-    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
-    max_token_for_local_context: Optional[int] = None
+    max_token_for_local_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
+    )

-    """List of high-level keywords to prioritize in retrieval."""
-    hl_keywords: Optional[List[str]] = None
+    hl_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of high-level keywords to prioritize in retrieval.",
+    )

-    """List of low-level keywords to refine retrieval focus."""
-    ll_keywords: Optional[List[str]] = None
+    ll_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of low-level keywords to refine retrieval focus.",
+    )

-    """Stores past conversation history to maintain context.
-    Format: [{"role": "user/assistant", "content": "message"}].
-    """
-    conversation_history: Optional[List[dict[str, Any]]] = None
+    conversation_history: Optional[List[dict[str, Any]]] = Field(
+        default=None,
+        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
+    )

-    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
-    history_turns: Optional[int] = None
+    history_turns: Optional[int] = Field(
+        ge=0,
+        default=None,
+        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
+    )

+    @field_validator("query", mode="after")
+    @classmethod
+    def query_strip_after(cls, query: str) -> str:
+        return query.strip()

+    @field_validator("hl_keywords", mode="after")
+    @classmethod
+    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
+        if hl_keywords is None:
+            return None
+        return [keyword.strip() for keyword in hl_keywords]

+    @field_validator("ll_keywords", mode="after")
+    @classmethod
+    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
+        if ll_keywords is None:
+            return None
+        return [keyword.strip() for keyword in ll_keywords]

+    @field_validator("conversation_history", mode="after")
+    @classmethod
+    def conversation_history_role_check(
+        cls, conversation_history: List[dict[str, Any]] | None
+    ) -> List[dict[str, Any]] | None:
+        if conversation_history is None:
+            return None
+        for msg in conversation_history:
+            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
+                raise ValueError(
+                    "Each message must have a 'role' key with value 'user' or 'assistant'."
+                )
+        return conversation_history

+    def to_query_params(self, is_stream: bool) -> QueryParam:
+        """Converts a QueryRequest instance into a QueryParam instance."""
+        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
+        request_data = self.model_dump(exclude_none=True, exclude={"query"})
+
+        # Ensure `mode` and `stream` are set explicitly
+        param = QueryParam(**request_data)
+        param.stream = is_stream
+        return param
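
For reference, a minimal sketch of how the new model is meant to be used. The values below are illustrative, not from this PR, and it assumes QueryParam accepts these field names as keyword arguments, as to_query_params implies:

# Illustrative only: the validators strip whitespace, and to_query_params
# forwards every non-None field to QueryParam, setting `stream` explicitly.
request = QueryRequest(
    query="  What is LightRAG?  ",  # stripped to "What is LightRAG?" by query_strip_after
    mode="mix",
    top_k=10,
    hl_keywords=["  graph rag  "],  # each keyword stripped by hl_keywords_strip_after
)
param = request.to_query_params(is_stream=False)
assert param.stream is False  # decided by the caller, never by the request body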
 class QueryResponse(BaseModel):
-    response: str
+    response: str = Field(
+        description="The generated response",
+    )
 class InsertTextRequest(BaseModel):
-    text: str
+    text: str = Field(
+        min_length=1,
+        description="The text to insert",
+    )

+    @field_validator("text", mode="after")
+    @classmethod
+    def strip_after(cls, text: str) -> str:
+        return text.strip()
+class InsertTextsRequest(BaseModel):
+    texts: list[str] = Field(
+        min_length=1,
+        description="The texts to insert",
+    )

+    @field_validator("texts", mode="after")
+    @classmethod
+    def strip_after(cls, texts: list[str]) -> list[str]:
+        return [text.strip() for text in texts]
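
A quick sketch of the new list model's behavior, assuming Pydantic v2 semantics where min_length on a list[str] field constrains the length of the list itself (sample values are illustrative):

# Illustrative only: the validator strips each entry, and min_length=1
# rejects an empty list at validation time.
req = InsertTextsRequest(texts=["  First document.  ", "Second document."])
print(req.texts)  # ['First document.', 'Second document.']
InsertTextsRequest(texts=[])  # raises pydantic.ValidationError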
 class InsertResponse(BaseModel):
-    status: str
-    message: str
+    status: str = Field(description="Status of the operation")
+    message: str = Field(description="Message describing the operation result")

 class DocStatusResponse(BaseModel):
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
     statuses: Dict[DocStatus, List[DocStatusResponse]] = {}

-def QueryRequestToQueryParams(request: QueryRequest):
-    param = QueryParam(mode=request.mode, stream=request.stream)
-    if request.only_need_context is not None:
-        param.only_need_context = request.only_need_context
-    if request.only_need_prompt is not None:
-        param.only_need_prompt = request.only_need_prompt
-    if request.response_type is not None:
-        param.response_type = request.response_type
-    if request.top_k is not None:
-        param.top_k = request.top_k
-    if request.max_token_for_text_unit is not None:
-        param.max_token_for_text_unit = request.max_token_for_text_unit
-    if request.max_token_for_global_context is not None:
-        param.max_token_for_global_context = request.max_token_for_global_context
-    if request.max_token_for_local_context is not None:
-        param.max_token_for_local_context = request.max_token_for_local_context
-    if request.hl_keywords is not None:
-        param.hl_keywords = request.hl_keywords
-    if request.ll_keywords is not None:
-        param.ll_keywords = request.ll_keywords
-    if request.conversation_history is not None:
-        param.conversation_history = request.conversation_history
-    if request.history_turns is not None:
-        param.history_turns = request.history_turns
-    return param
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
@@ -1525,6 +1581,37 @@ def create_app(args):
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))

+    @app.post(
+        "/documents/texts",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_texts(
+        request: InsertTextsRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert texts into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextsRequest): The request body containing the texts to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing.
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation and a message.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, request.texts)
+            return InsertResponse(
+                status="success",
+                message="Texts successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/texts: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
     @app.post(
         "/documents/file",
         response_model=InsertResponse,
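
A minimal client-side sketch of the new endpoint. The host and port are assumptions (9621 is the server's usual default; adjust to your deployment), and an API key header would be needed if one is configured:

# Hypothetical call to the new /documents/texts endpoint.
import httpx

resp = httpx.post(
    "http://localhost:9621/documents/texts",
    json={"texts": ["First document.", "Second document."]},
)
print(resp.json())  # e.g. {'status': 'success', 'message': 'Texts successfully received. ...'}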
@@ -1569,7 +1656,7 @@ def create_app(args):
             raise HTTPException(status_code=500, detail=str(e))

     @app.post(
-        "/documents/batch",
+        "/documents/file_batch",
         response_model=InsertResponse,
         dependencies=[Depends(optional_api_key)],
     )
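
Only the route string changed here, so existing clients just need the new path. A hypothetical upload against the renamed route; the multipart field name "files" is an assumption, since the handler signature sits outside this hunk:

# Hypothetical batch upload (route was previously /documents/batch).
import httpx

with open("doc1.txt", "rb") as f1, open("doc2.txt", "rb") as f2:
    resp = httpx.post(
        "http://localhost:9621/documents/file_batch",
        files=[("files", f1), ("files", f2)],
    )
print(resp.json())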
@@ -1673,20 +1760,14 @@ def create_app(args):
""" """
try: try:
response = await rag.aquery( response = await rag.aquery(
request.query, param=QueryRequestToQueryParams(request) request.query, param=request.to_query_params(False)
) )
# If response is a string (e.g. cache hit), return directly # If response is a string (e.g. cache hit), return directly
if isinstance(response, str): if isinstance(response, str):
return QueryResponse(response=response) return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter if isinstance(response, dict):
if request.stream or hasattr(response, "__aiter__"):
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
elif isinstance(response, dict):
result = json.dumps(response, indent=2) result = json.dumps(response, indent=2)
return QueryResponse(response=result) return QueryResponse(response=result)
else: else:
@@ -1708,11 +1789,8 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
-            params = QueryRequestToQueryParams(request)
-
-            params.stream = True
-            response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query, param=params
-            )
+            response = await rag.aquery(
+                request.query, param=request.to_query_params(True)
+            )

             from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@ def create_app(args):
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
"Connection": "keep-alive", "Connection": "keep-alive",
"Content-Type": "application/x-ndjson", "Content-Type": "application/x-ndjson",
"X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx
}, },
) )
except Exception as e: except Exception as e:
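
To close the loop on the streaming change, a sketch of a client consuming the NDJSON stream. The /query/stream path and the per-line {"response": ...} shape are assumptions inferred from the Content-Type above; verify both against the server's stream generator:

# Hypothetical NDJSON consumer for the streaming query endpoint.
import json

import httpx

with httpx.stream(
    "POST",
    "http://localhost:9621/query/stream",
    json={"query": "What is LightRAG?", "mode": "hybrid"},
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line:  # skip blank keep-alive lines
            chunk = json.loads(line)  # one JSON object per NDJSON line
            print(chunk.get("response", ""), end="", flush=True)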