Merge pull request #798 from YanSte/api_improvment
API: insert texts, improved streaming, and naming cleanup
@@ -13,14 +13,13 @@ import re
from fastapi.staticfiles import StaticFiles
import logging
import argparse
from typing import List, Any, Optional, Dict
from pydantic import BaseModel
from typing import List, Any, Literal, Optional, Dict
from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG, QueryParam
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.base import DocStatus, DocProcessingStatus
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
        return any(filename.lower().endswith(ext) for ext in self.supported_extensions)


# LightRAG query mode
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"
    hybrid = "hybrid"
    mix = "mix"


class QueryRequest(BaseModel):
    query: str
    query: str = Field(
        min_length=1,
        description="The query text",
    )

    """Specifies the retrieval mode"""
    mode: SearchMode = SearchMode.hybrid
    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
        default="hybrid",
        description="Query mode",
    )

    """If True, enables streaming output for real-time responses."""
    stream: Optional[bool] = None
    only_need_context: Optional[bool] = Field(
        default=None,
        description="If True, only returns the retrieved context without generating a response.",
    )

    """If True, only returns the retrieved context without generating a response."""
    only_need_context: Optional[bool] = None
    only_need_prompt: Optional[bool] = Field(
        default=None,
        description="If True, only returns the generated prompt without producing a response.",
    )

    """If True, only returns the generated prompt without producing a response."""
    only_need_prompt: Optional[bool] = None
    response_type: Optional[str] = Field(
        min_length=1,
        default=None,
        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
    )

    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
    response_type: Optional[str] = None
    top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
    )

    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
    top_k: Optional[int] = None
    max_token_for_text_unit: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allowed for each retrieved text chunk.",
    )

    """Maximum number of tokens allowed for each retrieved text chunk."""
    max_token_for_text_unit: Optional[int] = None
    max_token_for_global_context: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
    )

    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
    max_token_for_global_context: Optional[int] = None
    max_token_for_local_context: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
    )

    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
    max_token_for_local_context: Optional[int] = None
    hl_keywords: Optional[List[str]] = Field(
        default=None,
        description="List of high-level keywords to prioritize in retrieval.",
    )

    """List of high-level keywords to prioritize in retrieval."""
    hl_keywords: Optional[List[str]] = None
    ll_keywords: Optional[List[str]] = Field(
        default=None,
        description="List of low-level keywords to refine retrieval focus.",
    )

    """List of low-level keywords to refine retrieval focus."""
    ll_keywords: Optional[List[str]] = None
    conversation_history: Optional[List[dict[str, Any]]] = Field(
        default=None,
        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
    )

    """Stores past conversation history to maintain context.
    Format: [{"role": "user/assistant", "content": "message"}].
    """
    conversation_history: Optional[List[dict[str, Any]]] = None
    history_turns: Optional[int] = Field(
        ge=0,
        default=None,
        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
    )

    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
    history_turns: Optional[int] = None
    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        return query.strip()

    @field_validator("hl_keywords", mode="after")
    @classmethod
    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
        if hl_keywords is None:
            return None
        return [keyword.strip() for keyword in hl_keywords]

    @field_validator("ll_keywords", mode="after")
    @classmethod
    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
        if ll_keywords is None:
            return None
        return [keyword.strip() for keyword in ll_keywords]

    @field_validator("conversation_history", mode="after")
    @classmethod
    def conversation_history_role_check(
        cls, conversation_history: List[dict[str, Any]] | None
    ) -> List[dict[str, Any]] | None:
        if conversation_history is None:
            return None
        for msg in conversation_history:
            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
                raise ValueError(
                    "Each message must have a 'role' key with value 'user' or 'assistant'."
                )
        return conversation_history

    def to_query_params(self, is_stream: bool) -> QueryParam:
        """Converts a QueryRequest instance into a QueryParam instance."""
        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
        request_data = self.model_dump(exclude_none=True, exclude={"query"})

        # Ensure `mode` and `stream` are set explicitly
        param = QueryParam(**request_data)
        param.stream = is_stream
        return param
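A usage sketch (not part of the diff), assuming QueryRequest is imported from this module: the after-mode validators strip whitespace, and to_query_params forwards only the fields that were explicitly set, so anything left as None keeps QueryParam's own defaults. The field values below are hypothetical.

request = QueryRequest(
    query="  What is LightRAG?  ",  # stripped to "What is LightRAG?" by query_strip_after
    mode="mix",
    top_k=10,
    conversation_history=[{"role": "user", "content": "hello"}],
)

# exclude_none=True drops every unset field, so QueryParam keeps its defaults for them;
# stream is set explicitly from the is_stream argument.
param = request.to_query_params(is_stream=False)
assert param.mode == "mix" and param.stream is False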

class QueryResponse(BaseModel):
    response: str
    response: str = Field(
        description="The generated response",
    )


class InsertTextRequest(BaseModel):
    text: str
    text: str = Field(
        min_length=1,
        description="The text to insert",
    )

    @field_validator("text", mode="after")
    @classmethod
    def strip_after(cls, text: str) -> str:
        return text.strip()


class InsertTextsRequest(BaseModel):
    texts: list[str] = Field(
        min_length=1,
        description="The texts to insert",
    )

    @field_validator("texts", mode="after")
    @classmethod
    def strip_after(cls, texts: list[str]) -> list[str]:
        return [text.strip() for text in texts]


class InsertResponse(BaseModel):
    status: str
    message: str
    status: str = Field(description="Status of the operation")
    message: str = Field(description="Message describing the operation result")
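A short sketch of the insert models above, again assuming they are imported from this module: min_length=1 on texts rejects an empty list, and the after-mode validator strips surrounding whitespace from each entry.

from pydantic import ValidationError

req = InsertTextsRequest(texts=["  first document  ", "second document"])
print(req.texts)  # ['first document', 'second document'] -- whitespace stripped

try:
    InsertTextsRequest(texts=[])  # rejected: min_length=1 requires at least one item
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # 'too_short' in Pydantic v2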

class DocStatusResponse(BaseModel):
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
    statuses: Dict[DocStatus, List[DocStatusResponse]] = {}


def QueryRequestToQueryParams(request: QueryRequest):
    param = QueryParam(mode=request.mode, stream=request.stream)
    if request.only_need_context is not None:
        param.only_need_context = request.only_need_context
    if request.only_need_prompt is not None:
        param.only_need_prompt = request.only_need_prompt
    if request.response_type is not None:
        param.response_type = request.response_type
    if request.top_k is not None:
        param.top_k = request.top_k
    if request.max_token_for_text_unit is not None:
        param.max_token_for_text_unit = request.max_token_for_text_unit
    if request.max_token_for_global_context is not None:
        param.max_token_for_global_context = request.max_token_for_global_context
    if request.max_token_for_local_context is not None:
        param.max_token_for_local_context = request.max_token_for_local_context
    if request.hl_keywords is not None:
        param.hl_keywords = request.hl_keywords
    if request.ll_keywords is not None:
        param.ll_keywords = request.ll_keywords
    if request.conversation_history is not None:
        param.conversation_history = request.conversation_history
    if request.history_turns is not None:
        param.history_turns = request.history_turns
    return param


def get_api_key_dependency(api_key: Optional[str]):
    if not api_key:
        # If no API key is configured, return a dummy dependency that always succeeds
@@ -1525,6 +1581,37 @@ def create_app(args):
            logging.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/texts",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def insert_texts(
        request: InsertTextsRequest, background_tasks: BackgroundTasks
    ):
        """
        Insert texts into the Retrieval-Augmented Generation (RAG) system.

        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.

        Args:
            request (InsertTextsRequest): The request body containing the text to be inserted.
            background_tasks: FastAPI BackgroundTasks for async processing

        Returns:
            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
        """
        try:
            background_tasks.add_task(pipeline_index_texts, request.texts)
            return InsertResponse(
                status="success",
                message="Text successfully received. Processing will continue in background.",
            )
        except Exception as e:
            logging.error(f"Error /documents/text: {str(e)}")
            logging.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
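A hedged client-side sketch for the /documents/texts endpoint above; the host, port, and X-API-Key header name are assumptions about the deployment, and indexing runs in the background task, so the response only acknowledges receipt.

import requests  # any HTTP client works

resp = requests.post(
    "http://localhost:9621/documents/texts",  # host/port are an assumption
    json={
        "texts": [
            "LightRAG indexes documents into a knowledge graph.",
            "Queries can run in naive, local, global, hybrid, or mix mode.",
        ]
    },
    headers={"X-API-Key": "your-api-key"},  # omit if no API key is configured
)
print(resp.status_code, resp.json())  # 200 {'status': 'success', 'message': ...}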

    @app.post(
        "/documents/file",
        response_model=InsertResponse,
@@ -1569,7 +1656,7 @@ def create_app(args):
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/batch",
        "/documents/file_batch",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
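A client-side sketch for the /documents/file_batch route shown above; the handler body is not part of this hunk, so the "files" form-field name, host, and port below are all assumptions.

import requests

with open("report_a.txt", "rb") as f1, open("report_b.pdf", "rb") as f2:
    resp = requests.post(
        "http://localhost:9621/documents/file_batch",  # host/port are an assumption
        files=[("files", f1), ("files", f2)],  # form-field name is an assumption
    )
print(resp.json())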
@@ -1673,20 +1760,14 @@ def create_app(args):
        """
        try:
            response = await rag.aquery(
                request.query, param=QueryRequestToQueryParams(request)
                request.query, param=request.to_query_params(False)
            )

            # If response is a string (e.g. cache hit), return directly
            if isinstance(response, str):
                return QueryResponse(response=response)

            # If it's an async generator, decide whether to stream based on stream parameter
            if request.stream or hasattr(response, "__aiter__"):
                result = ""
                async for chunk in response:
                    result += chunk
                return QueryResponse(response=result)
            elif isinstance(response, dict):
            if isinstance(response, dict):
                result = json.dumps(response, indent=2)
                return QueryResponse(response=result)
            else:
@@ -1708,11 +1789,8 @@ def create_app(args):
            StreamingResponse: A streaming response containing the RAG query results.
        """
        try:
            params = QueryRequestToQueryParams(request)

            params.stream = True
            response = await rag.aquery(  # Use aquery instead of query, and add await
                request.query, param=params
            response = await rag.aquery(
                request.query, param=request.to_query_params(True)
            )

            from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@ def create_app(args):
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Content-Type": "application/x-ndjson",
                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                },
            )
        except Exception as e:
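A hedged sketch of consuming the /query/stream NDJSON output described by the headers above; the host, port, and the "response" key in each line are assumptions about the deployment.

import json
import httpx

with httpx.stream(
    "POST",
    "http://localhost:9621/query/stream",  # host/port are an assumption
    json={"query": "What does this corpus say about graphs?", "mode": "hybrid"},
    timeout=None,  # keep the connection open for the whole generation
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk.get("response", ""), end="", flush=True)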