Merge pull request #798 from YanSte/api_improvment
API: insert-texts endpoint, streaming improvements, and naming fixes
@@ -13,14 +13,13 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Dict
-from pydantic import BaseModel
+from typing import List, Any, Literal, Optional, Dict
+from pydantic import BaseModel, Field, field_validator
 from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from lightrag.base import DocStatus, DocProcessingStatus
-from enum import Enum
 from pathlib import Path
 import shutil
 import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
 
 
-# LightRAG query mode
-class SearchMode(str, Enum):
-    naive = "naive"
-    local = "local"
-    global_ = "global"
-    hybrid = "hybrid"
-    mix = "mix"
 
 
 class QueryRequest(BaseModel):
-    query: str
+    query: str = Field(
+        min_length=1,
+        description="The query text",
+    )
 
-    """Specifies the retrieval mode"""
-    mode: SearchMode = SearchMode.hybrid
+    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
+        default="hybrid",
+        description="Query mode",
+    )
 
-    """If True, enables streaming output for real-time responses."""
-    stream: Optional[bool] = None
+    only_need_context: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the retrieved context without generating a response.",
+    )
 
-    """If True, only returns the retrieved context without generating a response."""
-    only_need_context: Optional[bool] = None
+    only_need_prompt: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the generated prompt without producing a response.",
+    )
 
-    """If True, only returns the generated prompt without producing a response."""
-    only_need_prompt: Optional[bool] = None
+    response_type: Optional[str] = Field(
+        min_length=1,
+        default=None,
+        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
+    )
 
-    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
-    response_type: Optional[str] = None
+    top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
+    )
 
-    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
-    top_k: Optional[int] = None
+    max_token_for_text_unit: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allowed for each retrieved text chunk.",
+    )
 
-    """Maximum number of tokens allowed for each retrieved text chunk."""
-    max_token_for_text_unit: Optional[int] = None
+    max_token_for_global_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
+    )
 
-    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
-    max_token_for_global_context: Optional[int] = None
+    max_token_for_local_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
+    )
 
-    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
-    max_token_for_local_context: Optional[int] = None
+    hl_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of high-level keywords to prioritize in retrieval.",
+    )
 
-    """List of high-level keywords to prioritize in retrieval."""
-    hl_keywords: Optional[List[str]] = None
+    ll_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of low-level keywords to refine retrieval focus.",
+    )
 
-    """List of low-level keywords to refine retrieval focus."""
-    ll_keywords: Optional[List[str]] = None
+    conversation_history: Optional[List[dict[str, Any]]] = Field(
+        default=None,
+        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
+    )
 
-    """Stores past conversation history to maintain context.
-    Format: [{"role": "user/assistant", "content": "message"}].
-    """
-    conversation_history: Optional[List[dict[str, Any]]] = None
+    history_turns: Optional[int] = Field(
+        ge=0,
+        default=None,
+        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
+    )
 
-    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
-    history_turns: Optional[int] = None
+    @field_validator("query", mode="after")
+    @classmethod
+    def query_strip_after(cls, query: str) -> str:
+        return query.strip()
+
+    @field_validator("hl_keywords", mode="after")
+    @classmethod
+    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
+        if hl_keywords is None:
+            return None
+        return [keyword.strip() for keyword in hl_keywords]
+
+    @field_validator("ll_keywords", mode="after")
+    @classmethod
+    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
+        if ll_keywords is None:
+            return None
+        return [keyword.strip() for keyword in ll_keywords]
+
+    @field_validator("conversation_history", mode="after")
+    @classmethod
+    def conversation_history_role_check(
+        cls, conversation_history: List[dict[str, Any]] | None
+    ) -> List[dict[str, Any]] | None:
+        if conversation_history is None:
+            return None
+        for msg in conversation_history:
+            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
+                raise ValueError(
+                    "Each message must have a 'role' key with value 'user' or 'assistant'."
+                )
+        return conversation_history
+
+    def to_query_params(self, is_stream: bool) -> QueryParam:
+        """Converts a QueryRequest instance into a QueryParam instance."""
+        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
+        request_data = self.model_dump(exclude_none=True, exclude={"query"})
+
+        # Ensure `mode` and `stream` are set explicitly
+        param = QueryParam(**request_data)
+        param.stream = is_stream
+        return param
 
 
 class QueryResponse(BaseModel):
-    response: str
+    response: str = Field(
+        description="The generated response",
+    )
 
 
 class InsertTextRequest(BaseModel):
-    text: str
+    text: str = Field(
+        min_length=1,
+        description="The text to insert",
+    )
 
+    @field_validator("text", mode="after")
+    @classmethod
+    def strip_after(cls, text: str) -> str:
+        return text.strip()
+
+
+class InsertTextsRequest(BaseModel):
+    texts: list[str] = Field(
+        min_length=1,
+        description="The texts to insert",
+    )
+
+    @field_validator("texts", mode="after")
+    @classmethod
+    def strip_after(cls, texts: list[str]) -> list[str]:
+        return [text.strip() for text in texts]
 
 
 class InsertResponse(BaseModel):
-    status: str
-    message: str
+    status: str = Field(description="Status of the operation")
+    message: str = Field(description="Message describing the operation result")
 
 
 class DocStatusResponse(BaseModel):
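
For reference, a minimal sketch of how the reworked QueryRequest behaves, assuming the model is importable as defined above (names and values here are purely illustrative):

    from pydantic import ValidationError

    req = QueryRequest(query="  What is LightRAG?  ", mode="mix", top_k=10)
    assert req.query == "What is LightRAG?"  # query_strip_after trims whitespace

    # conversation_history_role_check rejects messages without a valid role
    try:
        QueryRequest(query="hi", conversation_history=[{"content": "no role"}])
    except ValidationError as e:
        print(e)

    # to_query_params drops unset (None) fields and forces the stream flag
    param = req.to_query_params(is_stream=False)

Note that the old `stream` field is gone from the model; streaming is now decided by the endpoint, which passes `is_stream` explicitly.
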
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
     statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
 
 
-def QueryRequestToQueryParams(request: QueryRequest):
-    param = QueryParam(mode=request.mode, stream=request.stream)
-    if request.only_need_context is not None:
-        param.only_need_context = request.only_need_context
-    if request.only_need_prompt is not None:
-        param.only_need_prompt = request.only_need_prompt
-    if request.response_type is not None:
-        param.response_type = request.response_type
-    if request.top_k is not None:
-        param.top_k = request.top_k
-    if request.max_token_for_text_unit is not None:
-        param.max_token_for_text_unit = request.max_token_for_text_unit
-    if request.max_token_for_global_context is not None:
-        param.max_token_for_global_context = request.max_token_for_global_context
-    if request.max_token_for_local_context is not None:
-        param.max_token_for_local_context = request.max_token_for_local_context
-    if request.hl_keywords is not None:
-        param.hl_keywords = request.hl_keywords
-    if request.ll_keywords is not None:
-        param.ll_keywords = request.ll_keywords
-    if request.conversation_history is not None:
-        param.conversation_history = request.conversation_history
-    if request.history_turns is not None:
-        param.history_turns = request.history_turns
-    return param
-
-
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
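
The removed QueryRequestToQueryParams helper copied each non-None field by hand; the new QueryRequest.to_query_params method gets the same effect in one step via Pydantic. A small sketch of the mechanism (values illustrative):

    req = QueryRequest(query="What is LightRAG?", mode="hybrid", top_k=10)
    data = req.model_dump(exclude_none=True, exclude={"query"})
    # data == {"mode": "hybrid", "top_k": 10} -- every None field is dropped,
    # so QueryParam(**data) keeps its own defaults for anything unset.
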
@@ -1525,6 +1581,37 @@ def create_app(args):
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
+    @app.post(
+        "/documents/texts",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_texts(
+        request: InsertTextsRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert texts into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextsRequest): The request body containing the text to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, request.texts)
+            return InsertResponse(
+                status="success",
+                message="Text successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/text: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.post(
         "/documents/file",
         response_model=InsertResponse,
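
A hedged usage sketch for the new /documents/texts route; the base URL and port are assumptions about a local deployment, not part of this diff:

    import requests

    resp = requests.post(
        "http://localhost:9621/documents/texts",  # assumed local server address
        json={"texts": ["First document.", "Second document."]},
    )
    print(resp.json())
    # expected shape: {"status": "success", "message": "Text successfully received. ..."}

The endpoint returns immediately; indexing runs in a FastAPI background task, so success here means the texts were accepted, not that they are already queryable.
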
@@ -1569,7 +1656,7 @@ def create_app(args):
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.post(
-        "/documents/batch",
+        "/documents/file_batch",
         response_model=InsertResponse,
         dependencies=[Depends(optional_api_key)],
     )
@@ -1673,20 +1760,14 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query, param=QueryRequestToQueryParams(request)
+                request.query, param=request.to_query_params(False)
             )
 
             # If response is a string (e.g. cache hit), return directly
             if isinstance(response, str):
                 return QueryResponse(response=response)
 
-            # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream or hasattr(response, "__aiter__"):
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
-            elif isinstance(response, dict):
+            if isinstance(response, dict):
                 result = json.dumps(response, indent=2)
                 return QueryResponse(response=result)
             else:
@@ -1708,11 +1789,8 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
-            params = QueryRequestToQueryParams(request)
-            params.stream = True
-            response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query, param=params
+            response = await rag.aquery(
+                request.query, param=request.to_query_params(True)
             )
 
             from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@ def create_app(args):
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                 },
             )
         except Exception as e:
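
Since the endpoint streams application/x-ndjson, a client reads one JSON object per line. A minimal consumer sketch; the /query/stream path, the host, and the "response" chunk key are assumptions not shown in this diff:

    import json
    import requests

    with requests.post(
        "http://localhost:9621/query/stream",  # assumed route and host
        json={"query": "What is LightRAG?", "mode": "hybrid"},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line:
                continue  # skip keep-alive blank lines
            chunk = json.loads(line)  # one JSON object per NDJSON line
            print(chunk.get("response", ""), end="", flush=True)
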