Enhance querying and indexing with a processing pipeline
@@ -631,9 +631,47 @@ class SearchMode(str, Enum):
 
 class QueryRequest(BaseModel):
     query: str
 
+    """Specifies the retrieval mode"""
     mode: SearchMode = SearchMode.hybrid
-    stream: bool = False
-    only_need_context: bool = False
+
+    """If True, enables streaming output for real-time responses."""
+    stream: Optional[bool] = None
+
+    """If True, only returns the retrieved context without generating a response."""
+    only_need_context: Optional[bool] = None
+
+    """If True, only returns the generated prompt without producing a response."""
+    only_need_prompt: Optional[bool] = None
+
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+    response_type: Optional[str] = None
+
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+    top_k: Optional[int] = None
+
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+    max_token_for_text_unit: Optional[int] = None
+
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+    max_token_for_global_context: Optional[int] = None
+
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+    max_token_for_local_context: Optional[int] = None
+
+    """List of high-level keywords to prioritize in retrieval."""
+    hl_keywords: Optional[List[str]] = None
+
+    """List of low-level keywords to refine retrieval focus."""
+    ll_keywords: Optional[List[str]] = None
+
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+    conversation_history: Optional[List[dict[str, Any]]] = None
+
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
+    history_turns: Optional[int] = None
 
 
 class QueryResponse(BaseModel):
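
With the expanded model, everything except `query` stays optional, so existing clients keep working while new ones can tune retrieval per request. As an illustration, a call exercising several of the new fields might look like the sketch below (the base URL, port, API key value, and payload values are assumptions for the example; `response` is the field defined on `QueryResponse`):

```python
# Illustrative client call against the /query endpoint.
import requests

payload = {
    "query": "How do the characters relate to each other?",
    "mode": "hybrid",                    # a SearchMode value
    "response_type": "Bullet Points",
    "top_k": 20,
    "hl_keywords": ["relationships"],
    "conversation_history": [
        {"role": "user", "content": "Who is the protagonist?"},
        {"role": "assistant", "content": "The story follows ..."},
    ],
    "history_turns": 1,
}
resp = requests.post(
    "http://localhost:9621/query",       # host/port assumed for the example
    json=payload,
    headers={"X-API-Key": "my-secret-key"},
)
print(resp.json()["response"])
```
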
@@ -642,7 +680,6 @@ class QueryResponse(BaseModel):
 
 
 class InsertTextRequest(BaseModel):
     text: str
-    description: Optional[str] = None
 
 
 class InsertResponse(BaseModel):
@@ -650,6 +687,33 @@ class InsertResponse(BaseModel):
     message: str
 
 
+def QueryRequestToQueryParams(request: QueryRequest):
+    param = QueryParam(mode=request.mode, stream=request.stream)
+    if request.only_need_context is not None:
+        param.only_need_context = request.only_need_context
+    if request.only_need_prompt is not None:
+        param.only_need_prompt = request.only_need_prompt
+    if request.response_type is not None:
+        param.response_type = request.response_type
+    if request.top_k is not None:
+        param.top_k = request.top_k
+    if request.max_token_for_text_unit is not None:
+        param.max_token_for_text_unit = request.max_token_for_text_unit
+    if request.max_token_for_global_context is not None:
+        param.max_token_for_global_context = request.max_token_for_global_context
+    if request.max_token_for_local_context is not None:
+        param.max_token_for_local_context = request.max_token_for_local_context
+    if request.hl_keywords is not None:
+        param.hl_keywords = request.hl_keywords
+    if request.ll_keywords is not None:
+        param.ll_keywords = request.ll_keywords
+    if request.conversation_history is not None:
+        param.conversation_history = request.conversation_history
+    if request.history_turns is not None:
+        param.history_turns = request.history_turns
+    return param
+
+
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
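
`QueryRequestToQueryParams` treats `None` as "not provided", so `QueryParam`'s own defaults apply unless the client set a field explicitly. A generic, self-contained sketch of that sentinel pattern (the `QueryParam` dataclass here is a stand-in, not the project's class):

```python
# Generic sketch of the None-sentinel override pattern used by the converter.
from dataclasses import dataclass
from typing import Optional

@dataclass
class QueryParam:            # stand-in with illustrative defaults
    mode: str = "hybrid"
    stream: Optional[bool] = None
    top_k: int = 60
    response_type: str = "Multiple Paragraphs"

def apply_overrides(param: QueryParam, request_values: dict) -> QueryParam:
    # Copy only the values the client actually provided (non-None),
    # leaving the defaults intact otherwise.
    for name, value in request_values.items():
        if value is not None and hasattr(param, name):
            setattr(param, name, value)
    return param

param = apply_overrides(QueryParam(), {"top_k": 10, "response_type": None})
assert param.top_k == 10 and param.response_type == "Multiple Paragraphs"
```
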
@@ -661,7 +725,9 @@ def get_api_key_dependency(api_key: Optional[str]):
     # If API key is configured, use proper authentication
     api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
 
-    async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
+    async def api_key_auth(
+        api_key_header_value: Optional[str] = Security(api_key_header),
+    ):
         if not api_key_header_value:
             raise HTTPException(
                 status_code=HTTP_403_FORBIDDEN, detail="API Key required"
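
The signature change swaps `str | None` for `Optional[str]`, which keeps the module importable on Python versions before 3.10. For context, a minimal self-contained reproduction of the optional API-key dependency might look like this (the invalid-key branch and the `/health` route are illustrative assumptions, not the server's exact wiring):

```python
# Minimal sketch of the optional X-API-Key dependency pattern.
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN

def get_api_key_dependency(api_key: Optional[str]):
    if not api_key:
        # No key configured: a no-op dependency that always allows access.
        async def no_auth():
            return None
        return no_auth

    api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

    async def api_key_auth(
        api_key_header_value: Optional[str] = Security(api_key_header),
    ):
        if not api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="API Key required"
            )
        if api_key_header_value != api_key:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
            )
        return api_key_header_value

    return api_key_auth

app = FastAPI()
optional_api_key = get_api_key_dependency("my-secret-key")

@app.get("/health")
async def health(_: Optional[str] = Depends(optional_api_key)):
    return {"status": "ok"}
```
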
@@ -1119,12 +1185,13 @@ def create_app(args):
             ("llm_response_cache", rag.llm_response_cache),
         ]
 
-    async def index_file(file_path: Path, description: Optional[str] = None):
-        """Index a file
+    async def pipeline_enqueue_file(file_path: Path) -> bool:
+        """Add a file to the queue for processing
 
         Args:
             file_path: Path to the saved file
-            description: Optional description of the file
+        Returns:
+            bool: True if the file was successfully enqueued, False otherwise
         """
         try:
             content = ""
@@ -1177,25 +1244,24 @@ def create_app(args):
                 logging.error(
                     f"Unsupported file type: {file_path.name} (extension {ext})"
                 )
-                return
-
-            # Add description if provided
-            if description:
-                content = f"{description}\n\n{content}"
+                return False
 
-            # Insert into RAG system
+            # Insert into the RAG queue
             if content:
-                await rag.ainsert(content)
+                await rag.apipeline_enqueue_documents(content)
                 logging.info(
-                    f"Successfully processed and indexed file: {file_path.name}"
+                    f"Successfully processed and enqueued file: {file_path.name}"
                 )
+                return True
             else:
                 logging.error(
                     f"No content could be extracted from file: {file_path.name}"
                 )
 
         except Exception as e:
-            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(
+                f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+            )
             logging.error(traceback.format_exc())
         finally:
             if file_path.name.startswith(temp_prefix):
@@ -1204,8 +1270,23 @@ def create_app(args):
                     file_path.unlink()
                 except Exception as e:
                     logging.error(f"Error deleting file {file_path}: {str(e)}")
+        return False
 
-    async def batch_index_files(file_paths: List[Path]):
+    async def pipeline_index_file(file_path: Path):
+        """Index a file
+
+        Args:
+            file_path: Path to the saved file
+        """
+        try:
+            if await pipeline_enqueue_file(file_path):
+                await rag.apipeline_process_enqueue_documents()
+
+        except Exception as e:
+            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_files(file_paths: List[Path]):
         """Index multiple files concurrently
 
         Args:
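
The refactor splits indexing into two phases: `pipeline_enqueue_file` only extracts content and enqueues it (reporting success as a bool), while `apipeline_process_enqueue_documents` drains the queue. A toy sketch of that shape, with stand-in names and no LightRAG dependency:

```python
# Toy sketch of the enqueue/process split: enqueueing is cheap and only
# records work; the heavy processing happens once per drain of the queue.
import asyncio

queue: list[str] = []

async def enqueue_document(content: str) -> bool:
    if not content.strip():
        return False          # mirrors "return False" for empty/unsupported files
    queue.append(content)
    return True

async def process_enqueued_documents() -> None:
    while queue:
        doc = queue.pop(0)
        # ... chunk, embed, and index `doc` here ...
        print(f"indexed {len(doc)} chars")

async def index_one(content: str) -> None:
    if await enqueue_document(content):
        await process_enqueued_documents()

asyncio.run(index_one("some document text"))
```
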
@@ -1213,11 +1294,31 @@ def create_app(args):
         """
         if not file_paths:
             return
         try:
+            enqueued = False
+
             if len(file_paths) == 1:
-                await index_file(file_paths[0])
+                enqueued = await pipeline_enqueue_file(file_paths[0])
             else:
-                tasks = [index_file(path) for path in file_paths]
-                await asyncio.gather(*tasks)
+                tasks = [pipeline_enqueue_file(path) for path in file_paths]
+                enqueued = any(await asyncio.gather(*tasks))
+
+            if enqueued:
+                await rag.apipeline_process_enqueue_documents()
         except Exception as e:
             logging.error(f"Error indexing files: {str(e)}")
             logging.error(traceback.format_exc())
+
+    async def pipeline_index_texts(texts: List[str]):
+        """Index a list of texts
+
+        Args:
+            texts: The texts to index
+        """
+        if not texts:
+            return
+        await rag.apipeline_enqueue_documents(texts)
+        await rag.apipeline_process_enqueue_documents()
 
     async def save_temp_file(file: UploadFile = File(...)) -> Path:
         """Save the uploaded file to a temporary location
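
`pipeline_index_files` fans the enqueue step out with `asyncio.gather` and triggers a single processing pass only if at least one file was enqueued; the `any()` over the gathered booleans is the gate. A minimal sketch of that fan-out pattern (the `enqueue` stub and its file-type rule are illustrative):

```python
# Sketch of the fan-out/any() gate used by pipeline_index_files.
import asyncio

async def enqueue(path: str) -> bool:
    # Pretend only ".txt" files enqueue successfully.
    return path.endswith(".txt")

async def index_files(paths: list[str]) -> None:
    results = await asyncio.gather(*(enqueue(p) for p in paths))
    if any(results):
        print("processing the queue once for all enqueued files")

asyncio.run(index_files(["a.txt", "b.pdf", "c.txt"]))
```
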
@@ -1254,7 +1355,7 @@ def create_app(args):
             with progress_lock:
                 scan_progress["current_file"] = os.path.basename(file_path)
 
-            await index_file(file_path)
+            await pipeline_index_file(file_path)
 
             with progress_lock:
                 scan_progress["indexed_count"] += 1
@@ -1334,7 +1435,7 @@ def create_app(args):
                 shutil.copyfileobj(file.file, buffer)
 
             # Add to background tasks
-            background_tasks.add_task(index_file, file_path)
+            background_tasks.add_task(pipeline_index_file, file_path)
 
             return InsertResponse(
                 status="success",
@@ -1366,7 +1467,7 @@ def create_app(args):
             InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
         """
         try:
-            background_tasks.add_task(rag.ainsert, request.text)
+            background_tasks.add_task(pipeline_index_texts, [request.text])
             return InsertResponse(
                 status="success",
                 message="Text successfully received. Processing will continue in background.",
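
The text-insert endpoint now defers work to `pipeline_index_texts` via FastAPI's `BackgroundTasks`, so the HTTP response returns before indexing finishes. A minimal reproduction of the pattern (the route path and the stand-in indexing function are assumptions):

```python
# Minimal sketch of deferring indexing to a FastAPI background task.
import asyncio

from fastapi import BackgroundTasks, FastAPI
from pydantic import BaseModel

app = FastAPI()

class InsertTextRequest(BaseModel):
    text: str

async def pipeline_index_texts(texts: list[str]) -> None:
    await asyncio.sleep(0)   # stand-in for enqueue + process
    print(f"indexed {len(texts)} text(s)")

@app.post("/documents/text")  # path assumed for the example
async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks):
    # Schedule the work; the response is sent before the task runs.
    background_tasks.add_task(pipeline_index_texts, [request.text])
    return {
        "status": "success",
        "message": "Text successfully received. Processing will continue in background.",
    }
```
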
@@ -1382,16 +1483,13 @@ def create_app(args):
         dependencies=[Depends(optional_api_key)],
     )
     async def insert_file(
-        background_tasks: BackgroundTasks,
-        file: UploadFile = File(...),
-        description: str = Form(None),
+        background_tasks: BackgroundTasks, file: UploadFile = File(...)
     ):
         """Insert a file directly into the RAG system
 
         Args:
             background_tasks: FastAPI BackgroundTasks for async processing
             file: Uploaded file
-            description: Optional description of the file
 
         Returns:
             InsertResponse: Status of the insertion operation
@@ -1410,7 +1508,7 @@ def create_app(args):
             temp_path = save_temp_file(file)
 
             # Add to background tasks
-            background_tasks.add_task(index_file, temp_path, description)
+            background_tasks.add_task(pipeline_index_file, temp_path)
 
             return InsertResponse(
                 status="success",
@@ -1456,7 +1554,7 @@ def create_app(args):
                     failed_files.append(f"{file.filename} (unsupported type)")
 
             if temp_files:
-                background_tasks.add_task(batch_index_files, temp_files)
+                background_tasks.add_task(pipeline_index_files, temp_files)
 
             # Prepare status message
             if inserted_count == len(files):
@@ -1515,12 +1613,7 @@ def create_app(args):
         Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
 
         Parameters:
-            request (QueryRequest): A Pydantic model containing the following fields:
-                - query (str): The text of the user's query.
-                - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
-                - stream (bool): Optional. Determines if the response should be streamed.
-                - only_need_context (bool): Optional. If true, returns only the context without further processing.
+            request (QueryRequest): The request object containing the query parameters.
 
         Returns:
             QueryResponse: A Pydantic model containing the result of the query processing.
                 If a string is returned (e.g., cache hit), it's directly returned.
@@ -1532,13 +1625,7 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=request.stream,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=QueryRequestToQueryParams(request)
             )
 
             # If response is a string (e.g. cache hit), return directly
@@ -1546,16 +1633,16 @@ def create_app(args):
                 return QueryResponse(response=response)
 
             # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream:
+            if request.stream or hasattr(response, "__aiter__"):
                 result = ""
                 async for chunk in response:
                     result += chunk
                 return QueryResponse(response=result)
+            elif isinstance(response, dict):
+                result = json.dumps(response, indent=2)
+                return QueryResponse(response=result)
             else:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                return QueryResponse(response=str(response))
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
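
The handler now dispatches on the shape of `rag.aquery`'s result: strings pass through, async generators are drained even when the client did not request streaming, dicts are serialized as JSON, and anything else is stringified. The same dispatch, extracted into a self-contained sketch:

```python
# Sketch of the response-shape dispatch in the /query handler.
import asyncio
import json
from typing import Any, AsyncIterator

async def normalize_response(response: Any) -> str:
    if isinstance(response, str):           # e.g. cache hit
        return response
    if hasattr(response, "__aiter__"):      # async generator: drain it
        result = ""
        async for chunk in response:
            result += chunk
        return result
    if isinstance(response, dict):          # structured payload (e.g. context-only)
        return json.dumps(response, indent=2)
    return str(response)                    # last-resort fallback

async def fake_stream() -> AsyncIterator[str]:
    yield "Hello, "
    yield "world"

print(asyncio.run(normalize_response(fake_stream())))  # -> "Hello, world"
```
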
@@ -1573,14 +1660,11 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
+            params = QueryRequestToQueryParams(request)
+
+            params.stream = True
             response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=True,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=params
             )
 
             from fastapi.responses import StreamingResponse
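
Because the stream endpoint reuses `QueryRequestToQueryParams` and only overrides `stream`, every other client-supplied parameter flows through unchanged. A possible client for the streaming endpoint, assuming newline-delimited JSON chunks of the form `{"response": ...}` (the framing, URL, port, and key are assumptions for the example):

```python
# Illustrative client for the /query/stream endpoint.
import json

import requests

with requests.post(
    "http://localhost:9621/query/stream",   # host/port assumed
    json={"query": "Summarize the document", "mode": "hybrid"},
    headers={"X-API-Key": "my-secret-key"},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line:
            chunk = json.loads(line)
            print(chunk.get("response", ""), end="", flush=True)
```
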