enhance query and indexing with pipeline
@@ -631,9 +631,47 @@ class SearchMode(str, Enum):
 
 
 class QueryRequest(BaseModel):
     query: str
+
+    """Specifies the retrieval mode"""
     mode: SearchMode = SearchMode.hybrid
-    stream: bool = False
-    only_need_context: bool = False
+
+    """If True, enables streaming output for real-time responses."""
+    stream: Optional[bool] = None
+
+    """If True, only returns the retrieved context without generating a response."""
+    only_need_context: Optional[bool] = None
+
+    """If True, only returns the generated prompt without producing a response."""
+    only_need_prompt: Optional[bool] = None
+
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+    response_type: Optional[str] = None
+
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+    top_k: Optional[int] = None
+
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+    max_token_for_text_unit: Optional[int] = None
+
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+    max_token_for_global_context: Optional[int] = None
+
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+    max_token_for_local_context: Optional[int] = None
+
+    """List of high-level keywords to prioritize in retrieval."""
+    hl_keywords: Optional[List[str]] = None
+
+    """List of low-level keywords to refine retrieval focus."""
+    ll_keywords: Optional[List[str]] = None
+
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+    conversation_history: Optional[List[dict[str, Any]]] = None
+
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
+    history_turns: Optional[int] = None
 
 
 class QueryResponse(BaseModel):
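Note: with the fields above switched to Optional and defaulting to None, a client only sends the parameters it wants to override. A minimal sketch of building such a request (the field names come from the model above; the values are made up):

    request = QueryRequest(
        query="What does the ingestion pipeline do?",
        mode=SearchMode.hybrid,      # default mode, shown for clarity
        only_need_context=True,      # skip answer generation, return retrieved context only
        top_k=10,                    # override the server-side default
    )
    # stream, response_type, hl_keywords, etc. stay None and are simply not applied.
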
@@ -642,7 +680,6 @@ class QueryResponse(BaseModel):
 
 class InsertTextRequest(BaseModel):
     text: str
-    description: Optional[str] = None
 
 
 class InsertResponse(BaseModel):
@@ -650,6 +687,33 @@ class InsertResponse(BaseModel):
     message: str
 
 
+def QueryRequestToQueryParams(request: QueryRequest):
+    param = QueryParam(mode=request.mode, stream=request.stream)
+    if request.only_need_context is not None:
+        param.only_need_context = request.only_need_context
+    if request.only_need_prompt is not None:
+        param.only_need_prompt = request.only_need_prompt
+    if request.response_type is not None:
+        param.response_type = request.response_type
+    if request.top_k is not None:
+        param.top_k = request.top_k
+    if request.max_token_for_text_unit is not None:
+        param.max_token_for_text_unit = request.max_token_for_text_unit
+    if request.max_token_for_global_context is not None:
+        param.max_token_for_global_context = request.max_token_for_global_context
+    if request.max_token_for_local_context is not None:
+        param.max_token_for_local_context = request.max_token_for_local_context
+    if request.hl_keywords is not None:
+        param.hl_keywords = request.hl_keywords
+    if request.ll_keywords is not None:
+        param.ll_keywords = request.ll_keywords
+    if request.conversation_history is not None:
+        param.conversation_history = request.conversation_history
+    if request.history_turns is not None:
+        param.history_turns = request.history_turns
+    return param
+
+
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
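Note: QueryRequestToQueryParams only copies fields that were explicitly set, so anything left as None falls back to the defaults defined on QueryParam. A small illustrative sketch (uses only names already present in this module):

    req = QueryRequest(query="What is LightRAG?", mode=SearchMode.local, top_k=20)
    param = QueryRequestToQueryParams(req)
    # param.mode and param.top_k come from the request; every other attribute keeps
    # QueryParam's default, because its `if ... is not None` branch above is skipped.
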
@@ -661,7 +725,9 @@ def get_api_key_dependency(api_key: Optional[str]):
     # If API key is configured, use proper authentication
     api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
 
-    async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
+    async def api_key_auth(
+        api_key_header_value: Optional[str] = Security(api_key_header),
+    ):
         if not api_key_header_value:
             raise HTTPException(
                 status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -1119,12 +1185,13 @@ def create_app(args):
             ("llm_response_cache", rag.llm_response_cache),
         ]
 
-    async def index_file(file_path: Path, description: Optional[str] = None):
-        """Index a file
+    async def pipeline_enqueue_file(file_path: Path) -> bool:
+        """Add a file to the queue for processing
 
         Args:
             file_path: Path to the saved file
-            description: Optional description of the file
+        Returns:
+            bool: True if the file was successfully enqueued, False otherwise
         """
         try:
             content = ""
@@ -1177,25 +1244,24 @@ def create_app(args):
                 logging.error(
                     f"Unsupported file type: {file_path.name} (extension {ext})"
                 )
-                return
+                return False
 
-            # Add description if provided
-            if description:
-                content = f"{description}\n\n{content}"
-
-            # Insert into RAG system
+            # Insert into the RAG queue
             if content:
-                await rag.ainsert(content)
+                await rag.apipeline_enqueue_documents(content)
                 logging.info(
-                    f"Successfully processed and indexed file: {file_path.name}"
+                    f"Successfully processed and enqueued file: {file_path.name}"
                 )
+                return True
             else:
                 logging.error(
                     f"No content could be extracted from file: {file_path.name}"
                 )
 
         except Exception as e:
-            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(
+                f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+            )
             logging.error(traceback.format_exc())
         finally:
             if file_path.name.startswith(temp_prefix):
@@ -1204,8 +1270,23 @@ def create_app(args):
                     file_path.unlink()
                 except Exception as e:
                     logging.error(f"Error deleting file {file_path}: {str(e)}")
+        return False
 
-    async def batch_index_files(file_paths: List[Path]):
+    async def pipeline_index_file(file_path: Path):
+        """Index a file
+
+        Args:
+            file_path: Path to the saved file
+        """
+        try:
+            if await pipeline_enqueue_file(file_path):
+                await rag.apipeline_process_enqueue_documents()
+
+        except Exception as e:
+            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_files(file_paths: List[Path]):
         """Index multiple files concurrently
 
         Args:
@@ -1213,11 +1294,31 @@ def create_app(args):
         """
         if not file_paths:
             return
-        if len(file_paths) == 1:
-            await index_file(file_paths[0])
-        else:
-            tasks = [index_file(path) for path in file_paths]
-            await asyncio.gather(*tasks)
+        try:
+            enqueued = False
+
+            if len(file_paths) == 1:
+                enqueued = await pipeline_enqueue_file(file_paths[0])
+            else:
+                tasks = [pipeline_enqueue_file(path) for path in file_paths]
+                enqueued = any(await asyncio.gather(*tasks))
+
+            if enqueued:
+                await rag.apipeline_process_enqueue_documents()
+        except Exception as e:
+            logging.error(f"Error indexing files: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_texts(texts: List[str]):
+        """Index a list of texts
+
+        Args:
+            texts: The texts to index
+        """
+        if not texts:
+            return
+        await rag.apipeline_enqueue_documents(texts)
+        await rag.apipeline_process_enqueue_documents()
 
     async def save_temp_file(file: UploadFile = File(...)) -> Path:
         """Save the uploaded file to a temporary location
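Note: each indexing helper above follows the same two-stage pipeline: enqueue first, then drain the queue. A hedged sketch of that flow, assuming `rag` is the LightRAG instance created in create_app (the method names are the ones used in this diff):

    # Stage 1: put the extracted content into the document queue.
    await rag.apipeline_enqueue_documents(content)
    # Stage 2: process everything currently enqueued (the actual indexing work).
    await rag.apipeline_process_enqueue_documents()
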
@@ -1254,7 +1355,7 @@ def create_app(args):
             with progress_lock:
                 scan_progress["current_file"] = os.path.basename(file_path)
 
-            await index_file(file_path)
+            await pipeline_index_file(file_path)
 
             with progress_lock:
                 scan_progress["indexed_count"] += 1
@@ -1334,7 +1435,7 @@ def create_app(args):
                 shutil.copyfileobj(file.file, buffer)
 
             # Add to background tasks
-            background_tasks.add_task(index_file, file_path)
+            background_tasks.add_task(pipeline_index_file, file_path)
 
             return InsertResponse(
                 status="success",
@@ -1366,7 +1467,7 @@ def create_app(args):
             InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
         """
         try:
-            background_tasks.add_task(rag.ainsert, request.text)
+            background_tasks.add_task(pipeline_index_texts, [request.text])
             return InsertResponse(
                 status="success",
                 message="Text successfully received. Processing will continue in background.",
@@ -1382,16 +1483,13 @@ def create_app(args):
         dependencies=[Depends(optional_api_key)],
     )
     async def insert_file(
-        background_tasks: BackgroundTasks,
-        file: UploadFile = File(...),
-        description: str = Form(None),
+        background_tasks: BackgroundTasks, file: UploadFile = File(...)
     ):
         """Insert a file directly into the RAG system
 
         Args:
             background_tasks: FastAPI BackgroundTasks for async processing
             file: Uploaded file
-            description: Optional description of the file
 
         Returns:
             InsertResponse: Status of the insertion operation
@@ -1410,7 +1508,7 @@ def create_app(args):
             temp_path = save_temp_file(file)
 
             # Add to background tasks
-            background_tasks.add_task(index_file, temp_path, description)
+            background_tasks.add_task(pipeline_index_file, temp_path)
 
             return InsertResponse(
                 status="success",
@@ -1456,7 +1554,7 @@ def create_app(args):
                     failed_files.append(f"{file.filename} (unsupported type)")
 
             if temp_files:
-                background_tasks.add_task(batch_index_files, temp_files)
+                background_tasks.add_task(pipeline_index_files, temp_files)
 
             # Prepare status message
             if inserted_count == len(files):
@@ -1515,12 +1613,7 @@ def create_app(args):
         Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
 
         Parameters:
-            request (QueryRequest): A Pydantic model containing the following fields:
-                - query (str): The text of the user's query.
-                - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
-                - stream (bool): Optional. Determines if the response should be streamed.
-                - only_need_context (bool): Optional. If true, returns only the context without further processing.
+            request (QueryRequest): The request object containing the query parameters.
 
         Returns:
             QueryResponse: A Pydantic model containing the result of the query processing.
             If a string is returned (e.g., cache hit), it's directly returned.
@@ -1532,13 +1625,7 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=request.stream,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=QueryRequestToQueryParams(request)
             )
 
             # If response is a string (e.g. cache hit), return directly
@@ -1546,16 +1633,16 @@ def create_app(args):
                 return QueryResponse(response=response)
 
             # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream:
+            if request.stream or hasattr(response, "__aiter__"):
                 result = ""
                 async for chunk in response:
                     result += chunk
                 return QueryResponse(response=result)
+            elif isinstance(response, dict):
+                result = json.dumps(response, indent=2)
+                return QueryResponse(response=result)
             else:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                return QueryResponse(response=str(response))
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
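Note: because the handler now builds its QueryParam through QueryRequestToQueryParams, the new optional fields pass straight through the request body. A hedged client-side example (URL, port, and API key header value are assumptions, not part of this diff):

    import requests  # plain HTTP client used only for illustration

    resp = requests.post(
        "http://localhost:9621/query",                     # host/port assumed
        headers={"X-API-Key": "<key, if configured>"},
        json={"query": "Summarize the corpus", "mode": "hybrid",
              "only_need_context": True, "top_k": 10},
    )
    print(resp.json()["response"])
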
@@ -1573,14 +1660,11 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
+            params = QueryRequestToQueryParams(request)
+
+            params.stream = True
             response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=True,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=params
             )
 
             from fastapi.responses import StreamingResponse
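Note: the streaming handler forces params.stream = True after converting the request, so the body does not need to set "stream" itself. A minimal consumption sketch (the endpoint path and port are assumptions about the surrounding code, not shown in this diff):

    import requests

    with requests.post(
        "http://localhost:9621/query/stream",   # path/port assumed
        json={"query": "Summarize the corpus", "mode": "hybrid"},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if line:
                print(line.decode("utf-8"))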