diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 59bffa47..756148fe 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -733,6 +733,14 @@ class QueryRequest(BaseModel): ) return conversation_history + def to_query_params(self, is_stream: bool) -> QueryParam: + """Converts a QueryRequest instance into a QueryParam instance.""" + # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically + request_data = self.model_dump(exclude_none=True, exclude={"query"}) + + # Ensure `mode` and `stream` are set explicitly + return QueryParam(**request_data, stream=is_stream) + class QueryResponse(BaseModel): response: str = Field( @@ -769,15 +777,6 @@ class InsertResponse(BaseModel): message: str = Field(description="Message describing the operation result") -def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool) -> QueryParam: - """Converts a QueryRequest instance into a QueryParam instance.""" - # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically - request_data = request.model_dump(exclude_none=True, exclude={"query"}) - - # Ensure `mode` and `stream` are set explicitly - return QueryParam(**request_data, stream=is_stream) - - def get_api_key_dependency(api_key: Optional[str]): if not api_key: # If no API key is configured, return a dummy dependency that always succeeds @@ -1730,7 +1729,7 @@ def create_app(args): """ try: response = await rag.aquery( - request.query, param=QueryRequestToQueryParams(request, False) + request.query, param=request.to_query_params(False) ) # If response is a string (e.g. cache hit), return directly @@ -1759,10 +1758,8 @@ def create_app(args): StreamingResponse: A streaming response containing the RAG query results. """ try: - params = QueryRequestToQueryParams(request, True) - - response = await rag.aquery( # Use aquery instead of query, and add await - request.query, param=params + response = await rag.aquery( + request.query, param=request.to_query_params(True) ) from fastapi.responses import StreamingResponse