cleaned code

This commit is contained in:
Yannick Stephan
2025-02-16 20:03:52 +01:00
parent 1fdcd93e84
commit 0b78787b25

View File

@@ -733,6 +733,14 @@ class QueryRequest(BaseModel):
) )
return conversation_history return conversation_history
def to_query_params(self, is_stream: bool) -> QueryParam:
"""Converts a QueryRequest instance into a QueryParam instance."""
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
request_data = self.model_dump(exclude_none=True, exclude={"query"})
# Ensure `mode` and `stream` are set explicitly
return QueryParam(**request_data, stream=is_stream)
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
response: str = Field( response: str = Field(
@@ -769,15 +777,6 @@ class InsertResponse(BaseModel):
message: str = Field(description="Message describing the operation result") message: str = Field(description="Message describing the operation result")
def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool) -> QueryParam:
"""Converts a QueryRequest instance into a QueryParam instance."""
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
request_data = request.model_dump(exclude_none=True, exclude={"query"})
# Ensure `mode` and `stream` are set explicitly
return QueryParam(**request_data, stream=is_stream)
def get_api_key_dependency(api_key: Optional[str]): def get_api_key_dependency(api_key: Optional[str]):
if not api_key: if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds # If no API key is configured, return a dummy dependency that always succeeds
@@ -1730,7 +1729,7 @@ def create_app(args):
""" """
try: try:
response = await rag.aquery( response = await rag.aquery(
request.query, param=QueryRequestToQueryParams(request, False) request.query, param=request.to_query_params(False)
) )
# If response is a string (e.g. cache hit), return directly # If response is a string (e.g. cache hit), return directly
@@ -1759,10 +1758,8 @@ def create_app(args):
StreamingResponse: A streaming response containing the RAG query results. StreamingResponse: A streaming response containing the RAG query results.
""" """
try: try:
params = QueryRequestToQueryParams(request, True) response = await rag.aquery(
request.query, param=request.to_query_params(True)
response = await rag.aquery( # Use aquery instead of query, and add await
request.query, param=params
) )
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse