enhance query and indexing with pipeline

This commit is contained in:
ArnoChen
2025-02-16 21:11:05 +08:00
parent 33a4f00b1d
commit bbe24ab7ce

View File

@@ -631,9 +631,47 @@ class SearchMode(str, Enum):
class QueryRequest(BaseModel):
query: str
"""Specifies the retrieval mode"""
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
"""If True, enables streaming output for real-time responses."""
stream: Optional[bool] = None
"""If True, only returns the retrieved context without generating a response."""
only_need_context: Optional[bool] = None
"""If True, only returns the generated prompt without producing a response."""
only_need_prompt: Optional[bool] = None
"""Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
response_type: Optional[str] = None
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
top_k: Optional[int] = None
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_text_unit: Optional[int] = None
"""Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_global_context: Optional[int] = None
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
max_token_for_local_context: Optional[int] = None
"""List of high-level keywords to prioritize in retrieval."""
hl_keywords: Optional[List[str]] = None
"""List of low-level keywords to refine retrieval focus."""
ll_keywords: Optional[List[str]] = None
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
conversation_history: Optional[List[dict[str, Any]]] = None
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
history_turns: Optional[int] = None
class QueryResponse(BaseModel):
@@ -642,7 +680,6 @@ class QueryResponse(BaseModel):
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
@@ -650,6 +687,33 @@ class InsertResponse(BaseModel):
message: str
def QueryRequestToQueryParams(request: QueryRequest):
param = QueryParam(mode=request.mode, stream=request.stream)
if request.only_need_context is not None:
param.only_need_context = request.only_need_context
if request.only_need_prompt is not None:
param.only_need_prompt = request.only_need_prompt
if request.response_type is not None:
param.response_type = request.response_type
if request.top_k is not None:
param.top_k = request.top_k
if request.max_token_for_text_unit is not None:
param.max_token_for_text_unit = request.max_token_for_text_unit
if request.max_token_for_global_context is not None:
param.max_token_for_global_context = request.max_token_for_global_context
if request.max_token_for_local_context is not None:
param.max_token_for_local_context = request.max_token_for_local_context
if request.hl_keywords is not None:
param.hl_keywords = request.hl_keywords
if request.ll_keywords is not None:
param.ll_keywords = request.ll_keywords
if request.conversation_history is not None:
param.conversation_history = request.conversation_history
if request.history_turns is not None:
param.history_turns = request.history_turns
return param
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
@@ -661,7 +725,9 @@ def get_api_key_dependency(api_key: Optional[str]):
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
async def api_key_auth(
api_key_header_value: Optional[str] = Security(api_key_header),
):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -1119,12 +1185,13 @@ def create_app(args):
("llm_response_cache", rag.llm_response_cache),
]
async def index_file(file_path: Path, description: Optional[str] = None):
"""Index a file
async def pipeline_enqueue_file(file_path: Path) -> bool:
"""Add a file to the queue for processing
Args:
file_path: Path to the saved file
description: Optional description of the file
Returns:
bool: True if the file was successfully enqueued, False otherwise
"""
try:
content = ""
@@ -1177,25 +1244,24 @@ def create_app(args):
logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return
return False
# Add description if provided
if description:
content = f"{description}\n\n{content}"
# Insert into RAG system
# Insert into the RAG queue
if content:
await rag.ainsert(content)
await rag.apipeline_enqueue_documents(content)
logging.info(
f"Successfully processed and indexed file: {file_path.name}"
f"Successfully processed and enqueued file: {file_path.name}"
)
return True
else:
logging.error(
f"No content could be extracted from file: {file_path.name}"
)
except Exception as e:
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(
f"Error processing or enqueueing file {file_path.name}: {str(e)}"
)
logging.error(traceback.format_exc())
finally:
if file_path.name.startswith(temp_prefix):
@@ -1204,8 +1270,23 @@ def create_app(args):
file_path.unlink()
except Exception as e:
logging.error(f"Error deleting file {file_path}: {str(e)}")
return False
async def batch_index_files(file_paths: List[Path]):
async def pipeline_index_file(file_path: Path):
"""Index a file
Args:
file_path: Path to the saved file
"""
try:
if await pipeline_enqueue_file(file_path):
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_files(file_paths: List[Path]):
"""Index multiple files concurrently
Args:
@@ -1213,11 +1294,31 @@ def create_app(args):
"""
if not file_paths:
return
if len(file_paths) == 1:
await index_file(file_paths[0])
else:
tasks = [index_file(path) for path in file_paths]
await asyncio.gather(*tasks)
try:
enqueued = False
if len(file_paths) == 1:
enqueued = await pipeline_enqueue_file(file_paths[0])
else:
tasks = [pipeline_enqueue_file(path) for path in file_paths]
enqueued = any(await asyncio.gather(*tasks))
if enqueued:
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_texts(texts: List[str]):
"""Index a list of texts
Args:
texts: The texts to index
"""
if not texts:
return
await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents()
async def save_temp_file(file: UploadFile = File(...)) -> Path:
"""Save the uploaded file to a temporary location
@@ -1254,7 +1355,7 @@ def create_app(args):
with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await index_file(file_path)
await pipeline_index_file(file_path)
with progress_lock:
scan_progress["indexed_count"] += 1
@@ -1334,7 +1435,7 @@ def create_app(args):
shutil.copyfileobj(file.file, buffer)
# Add to background tasks
background_tasks.add_task(index_file, file_path)
background_tasks.add_task(pipeline_index_file, file_path)
return InsertResponse(
status="success",
@@ -1366,7 +1467,7 @@ def create_app(args):
InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
"""
try:
background_tasks.add_task(rag.ainsert, request.text)
background_tasks.add_task(pipeline_index_texts, [request.text])
return InsertResponse(
status="success",
message="Text successfully received. Processing will continue in background.",
@@ -1382,16 +1483,13 @@ def create_app(args):
dependencies=[Depends(optional_api_key)],
)
async def insert_file(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
description: str = Form(None),
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""Insert a file directly into the RAG system
Args:
background_tasks: FastAPI BackgroundTasks for async processing
file: Uploaded file
description: Optional description of the file
Returns:
InsertResponse: Status of the insertion operation
@@ -1410,7 +1508,7 @@ def create_app(args):
temp_path = save_temp_file(file)
# Add to background tasks
background_tasks.add_task(index_file, temp_path, description)
background_tasks.add_task(pipeline_index_file, temp_path)
return InsertResponse(
status="success",
@@ -1456,7 +1554,7 @@ def create_app(args):
failed_files.append(f"{file.filename} (unsupported type)")
if temp_files:
background_tasks.add_task(batch_index_files, temp_files)
background_tasks.add_task(pipeline_index_files, temp_files)
# Prepare status message
if inserted_count == len(files):
@@ -1515,12 +1613,7 @@ def create_app(args):
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
Parameters:
request (QueryRequest): A Pydantic model containing the following fields:
- query (str): The text of the user's query.
- mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
- stream (bool): Optional. Determines if the response should be streamed.
- only_need_context (bool): Optional. If true, returns only the context without further processing.
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If a string is returned (e.g., cache hit), it's directly returned.
@@ -1532,13 +1625,7 @@ def create_app(args):
"""
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
top_k=global_top_k,
),
request.query, param=QueryRequestToQueryParams(request)
)
# If response is a string (e.g. cache hit), return directly
@@ -1546,16 +1633,16 @@ def create_app(args):
return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter
if request.stream:
if request.stream or hasattr(response, "__aiter__"):
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result)
else:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
return QueryResponse(response=str(response))
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@@ -1573,14 +1660,11 @@ def create_app(args):
StreamingResponse: A streaming response containing the RAG query results.
"""
try:
params = QueryRequestToQueryParams(request)
params.stream = True
response = await rag.aquery( # Use aquery instead of query, and add await
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
top_k=global_top_k,
),
request.query, param=params
)
from fastapi.responses import StreamingResponse