cleaned code
This commit is contained in:
@@ -733,6 +733,14 @@ class QueryRequest(BaseModel):
|
|||||||
)
|
)
|
||||||
return conversation_history
|
return conversation_history
|
||||||
|
|
||||||
|
def to_query_params(self, is_stream: bool) -> QueryParam:
|
||||||
|
"""Converts a QueryRequest instance into a QueryParam instance."""
|
||||||
|
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
|
||||||
|
request_data = self.model_dump(exclude_none=True, exclude={"query"})
|
||||||
|
|
||||||
|
# Ensure `mode` and `stream` are set explicitly
|
||||||
|
return QueryParam(**request_data, stream=is_stream)
|
||||||
|
|
||||||
|
|
||||||
class QueryResponse(BaseModel):
|
class QueryResponse(BaseModel):
|
||||||
response: str = Field(
|
response: str = Field(
|
||||||
@@ -769,15 +777,6 @@ class InsertResponse(BaseModel):
|
|||||||
message: str = Field(description="Message describing the operation result")
|
message: str = Field(description="Message describing the operation result")
|
||||||
|
|
||||||
|
|
||||||
def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool) -> QueryParam:
|
|
||||||
"""Converts a QueryRequest instance into a QueryParam instance."""
|
|
||||||
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
|
|
||||||
request_data = request.model_dump(exclude_none=True, exclude={"query"})
|
|
||||||
|
|
||||||
# Ensure `mode` and `stream` are set explicitly
|
|
||||||
return QueryParam(**request_data, stream=is_stream)
|
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_dependency(api_key: Optional[str]):
|
def get_api_key_dependency(api_key: Optional[str]):
|
||||||
if not api_key:
|
if not api_key:
|
||||||
# If no API key is configured, return a dummy dependency that always succeeds
|
# If no API key is configured, return a dummy dependency that always succeeds
|
||||||
@@ -1730,7 +1729,7 @@ def create_app(args):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await rag.aquery(
|
response = await rag.aquery(
|
||||||
request.query, param=QueryRequestToQueryParams(request, False)
|
request.query, param=request.to_query_params(False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# If response is a string (e.g. cache hit), return directly
|
# If response is a string (e.g. cache hit), return directly
|
||||||
@@ -1759,10 +1758,8 @@ def create_app(args):
|
|||||||
StreamingResponse: A streaming response containing the RAG query results.
|
StreamingResponse: A streaming response containing the RAG query results.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
params = QueryRequestToQueryParams(request, True)
|
response = await rag.aquery(
|
||||||
|
request.query, param=request.to_query_params(True)
|
||||||
response = await rag.aquery( # Use aquery instead of query, and add await
|
|
||||||
request.query, param=params
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
Reference in New Issue
Block a user