Merge pull request #798 from YanSte/api_improvment

API: insert texts, improve streaming and naming
Yannick Stephan
2025-02-17 11:54:23 +01:00
committed by GitHub


@@ -13,14 +13,13 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Dict
-from pydantic import BaseModel
+from typing import List, Any, Literal, Optional, Dict
+from pydantic import BaseModel, Field, field_validator
 from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from lightrag.base import DocStatus, DocProcessingStatus
-from enum import Enum
 from pathlib import Path
 import shutil
 import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)


-# LightRAG query mode
-class SearchMode(str, Enum):
-    naive = "naive"
-    local = "local"
-    global_ = "global"
-    hybrid = "hybrid"
-    mix = "mix"
-
-
 class QueryRequest(BaseModel):
-    query: str
+    query: str = Field(
+        min_length=1,
+        description="The query text",
+    )

-    """Specifies the retrieval mode"""
-    mode: SearchMode = SearchMode.hybrid
+    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
+        default="hybrid",
+        description="Query mode",
+    )

-    """If True, enables streaming output for real-time responses."""
-    stream: Optional[bool] = None
+    only_need_context: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the retrieved context without generating a response.",
+    )

-    """If True, only returns the retrieved context without generating a response."""
-    only_need_context: Optional[bool] = None
+    only_need_prompt: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the generated prompt without producing a response.",
+    )

-    """If True, only returns the generated prompt without producing a response."""
-    only_need_prompt: Optional[bool] = None
+    response_type: Optional[str] = Field(
+        min_length=1,
+        default=None,
+        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
+    )

-    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
-    response_type: Optional[str] = None
+    top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
+    )

-    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
-    top_k: Optional[int] = None
+    max_token_for_text_unit: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allowed for each retrieved text chunk.",
+    )

-    """Maximum number of tokens allowed for each retrieved text chunk."""
-    max_token_for_text_unit: Optional[int] = None
+    max_token_for_global_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
+    )

-    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
-    max_token_for_global_context: Optional[int] = None
+    max_token_for_local_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
+    )

-    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
-    max_token_for_local_context: Optional[int] = None
+    hl_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of high-level keywords to prioritize in retrieval.",
+    )

-    """List of high-level keywords to prioritize in retrieval."""
-    hl_keywords: Optional[List[str]] = None
+    ll_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of low-level keywords to refine retrieval focus.",
+    )

-    """List of low-level keywords to refine retrieval focus."""
-    ll_keywords: Optional[List[str]] = None
+    conversation_history: Optional[List[dict[str, Any]]] = Field(
+        default=None,
+        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
+    )

-    """Stores past conversation history to maintain context.
-    Format: [{"role": "user/assistant", "content": "message"}].
-    """
-    conversation_history: Optional[List[dict[str, Any]]] = None
+    history_turns: Optional[int] = Field(
+        ge=0,
+        default=None,
+        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
+    )

-    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
-    history_turns: Optional[int] = None
+    @field_validator("query", mode="after")
+    @classmethod
+    def query_strip_after(cls, query: str) -> str:
+        return query.strip()
+
+    @field_validator("hl_keywords", mode="after")
+    @classmethod
+    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
+        if hl_keywords is None:
+            return None
+        return [keyword.strip() for keyword in hl_keywords]
+
+    @field_validator("ll_keywords", mode="after")
+    @classmethod
+    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
+        if ll_keywords is None:
+            return None
+        return [keyword.strip() for keyword in ll_keywords]
+
+    @field_validator("conversation_history", mode="after")
+    @classmethod
+    def conversation_history_role_check(
+        cls, conversation_history: List[dict[str, Any]] | None
+    ) -> List[dict[str, Any]] | None:
+        if conversation_history is None:
+            return None
+        for msg in conversation_history:
+            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
+                raise ValueError(
+                    "Each message must have a 'role' key with value 'user' or 'assistant'."
+                )
+        return conversation_history
+
+    def to_query_params(self, is_stream: bool) -> QueryParam:
+        """Converts a QueryRequest instance into a QueryParam instance."""
+        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
+        request_data = self.model_dump(exclude_none=True, exclude={"query"})
+
+        # Ensure `mode` and `stream` are set explicitly
+        param = QueryParam(**request_data)
+        param.stream = is_stream
+        return param


 class QueryResponse(BaseModel):
-    response: str
+    response: str = Field(
+        description="The generated response",
+    )


 class InsertTextRequest(BaseModel):
-    text: str
+    text: str = Field(
+        min_length=1,
+        description="The text to insert",
+    )
+
+    @field_validator("text", mode="after")
+    @classmethod
+    def strip_after(cls, text: str) -> str:
+        return text.strip()
+
+
+class InsertTextsRequest(BaseModel):
+    texts: list[str] = Field(
+        min_length=1,
+        description="The texts to insert",
+    )
+
+    @field_validator("texts", mode="after")
+    @classmethod
+    def strip_after(cls, texts: list[str]) -> list[str]:
+        return [text.strip() for text in texts]


 class InsertResponse(BaseModel):
-    status: str
-    message: str
+    status: str = Field(description="Status of the operation")
+    message: str = Field(description="Message describing the operation result")


 class DocStatusResponse(BaseModel):
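
To make the new request model concrete, here is a minimal usage sketch. It assumes the classes above are importable from the API server module (the exact import path is an assumption) and shows the whitespace-stripping validators together with the to_query_params() conversion, where streaming is decided by the endpoint rather than by a client-supplied stream field:

```python
# Hypothetical usage sketch of the new QueryRequest model.
# The import path is an assumption; adjust to where the server module lives.
from lightrag.api.lightrag_server import QueryRequest

request = QueryRequest(
    query="  What is LightRAG?  ",
    mode="hybrid",
    hl_keywords=[" retrieval ", "graph "],
)
print(request.query)        # "What is LightRAG?" (stripped by query_strip_after)
print(request.hl_keywords)  # ["retrieval", "graph"]

# The endpoint, not the client, decides streaming:
param = request.to_query_params(is_stream=False)
print(param.mode, param.stream)  # hybrid False
```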
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
     statuses: Dict[DocStatus, List[DocStatusResponse]] = {}


-def QueryRequestToQueryParams(request: QueryRequest):
-    param = QueryParam(mode=request.mode, stream=request.stream)
-    if request.only_need_context is not None:
-        param.only_need_context = request.only_need_context
-    if request.only_need_prompt is not None:
-        param.only_need_prompt = request.only_need_prompt
-    if request.response_type is not None:
-        param.response_type = request.response_type
-    if request.top_k is not None:
-        param.top_k = request.top_k
-    if request.max_token_for_text_unit is not None:
-        param.max_token_for_text_unit = request.max_token_for_text_unit
-    if request.max_token_for_global_context is not None:
-        param.max_token_for_global_context = request.max_token_for_global_context
-    if request.max_token_for_local_context is not None:
-        param.max_token_for_local_context = request.max_token_for_local_context
-    if request.hl_keywords is not None:
-        param.hl_keywords = request.hl_keywords
-    if request.ll_keywords is not None:
-        param.ll_keywords = request.ll_keywords
-    if request.conversation_history is not None:
-        param.conversation_history = request.conversation_history
-    if request.history_turns is not None:
-        param.history_turns = request.history_turns
-    return param


 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
@@ -1525,6 +1581,37 @@ def create_app(args):
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))

+    @app.post(
+        "/documents/texts",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_texts(
+        request: InsertTextsRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert texts into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextsRequest): The request body containing the texts to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation and a message describing the result.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, request.texts)
+            return InsertResponse(
+                status="success",
+                message="Text successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/texts: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.post(
         "/documents/file",
         response_model=InsertResponse,
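
A client-side sketch of the new route, for illustration. The base URL and port are assumptions for a local deployment, and the optional API key header is omitted:

```python
# Hypothetical client call to the new /documents/texts endpoint.
# Base URL and port are assumptions; adjust to your deployment.
import requests

resp = requests.post(
    "http://localhost:9621/documents/texts",
    json={"texts": ["First document.", "Second document."]},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "success", "message": "..."}
```

Indexing runs in a background task (pipeline_index_texts), so a success response only means the texts were accepted, not that they are already queryable.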
@@ -1569,7 +1656,7 @@ def create_app(args):
             raise HTTPException(status_code=500, detail=str(e))

     @app.post(
-        "/documents/batch",
+        "/documents/file_batch",
         response_model=InsertResponse,
         dependencies=[Depends(optional_api_key)],
     )
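
For the renamed batch route, a hedged upload sketch; the multipart field name "files" follows the pre-existing handler, which is an assumption since its body is unchanged and not shown in this diff:

```python
# Hypothetical multipart upload to the renamed /documents/file_batch route.
# The "files" form field name and the base URL are assumptions.
import requests

with open("a.txt", "rb") as f1, open("b.txt", "rb") as f2:
    resp = requests.post(
        "http://localhost:9621/documents/file_batch",
        files=[("files", f1), ("files", f2)],
    )
print(resp.json())
```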
@@ -1673,20 +1760,14 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query, param=QueryRequestToQueryParams(request)
+                request.query, param=request.to_query_params(False)
             )

             # If response is a string (e.g. cache hit), return directly
             if isinstance(response, str):
                 return QueryResponse(response=response)

-            # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream or hasattr(response, "__aiter__"):
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
-            elif isinstance(response, dict):
+            if isinstance(response, dict):
                 result = json.dumps(response, indent=2)
                 return QueryResponse(response=result)
             else:
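
With the in-endpoint stream collection removed, /query now always resolves to a complete response (to_query_params(False)), while streaming moves entirely to /query/stream. A minimal client sketch under the same URL assumptions as above:

```python
# Hypothetical non-streaming query; the server forces stream=False for this route.
import requests

resp = requests.post(
    "http://localhost:9621/query",
    json={"query": "What is LightRAG?", "mode": "hybrid", "top_k": 10},
)
print(resp.json()["response"])
```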
@@ -1708,11 +1789,8 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
-            params = QueryRequestToQueryParams(request)
-            params.stream = True
-
-            response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query, param=params
+            response = await rag.aquery(
+                request.query, param=request.to_query_params(True)
             )

             from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@ def create_app(args):
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                 },
             )
         except Exception as e:
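
Since /query/stream responds with application/x-ndjson (one JSON object per line, with X-Accel-Buffering disabled so Nginx does not buffer the stream), a client can consume it line by line. The per-line payload key is an assumption:

```python
# Hypothetical consumer for the NDJSON stream from /query/stream.
# Base URL and the per-line {"response": ...} shape are assumptions.
import json
import requests

with requests.post(
    "http://localhost:9621/query/stream",
    json={"query": "What is LightRAG?", "mode": "hybrid"},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk.get("response", ""), end="", flush=True)
```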