From 2c8885792c2bb8f0d1e645d19343d6016b34d4ae Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 24 Jan 2025 18:39:43 +0800
Subject: [PATCH] Refactor /api/generate: use llm_model_func directly
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lightrag/api/lightrag_server.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index d417d732..36617947 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -1272,23 +1272,17 @@ def create_app(args):
         # Estimate the number of input tokens
         prompt_tokens = estimate_tokens(cleaned_query)
 
-        # Run the query through RAG
-        query_param = QueryParam(
-            mode=mode,
-            stream=request.stream,
-            only_need_context=False
-        )
-
-        # If a system prompt is given, update rag's llm_model_kwargs
+        # Use llm_model_func directly for the query
         if request.system:
             rag.llm_model_kwargs["system_prompt"] = request.system
 
         if request.stream:
             from fastapi.responses import StreamingResponse
 
-            response = await rag.aquery(
-                cleaned_query,
-                param=query_param
+            response = await rag.llm_model_func(
+                cleaned_query,
+                stream=True,
+                **rag.llm_model_kwargs
             )
 
             async def stream_generator():
@@ -1383,7 +1377,11 @@ def create_app(args):
                 )
             else:
                 first_chunk_time = time.time_ns()
-                response_text = await rag.aquery(cleaned_query, param=query_param)
+                response_text = await rag.llm_model_func(
+                    cleaned_query,
+                    stream=False,
+                    **rag.llm_model_kwargs
+                )
                 last_chunk_time = time.time_ns()
 
                 if not response_text:
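
Note (commentary, not part of the patch): this change routes /api/generate
straight to the underlying LLM, dropping QueryParam and RAG retrieval for this
endpoint, so the parsed mode no longer affects generation here. Below is a
minimal sketch of the calling convention the refactor relies on, with a
hypothetical stand-in for llm_model_func (the real one is whatever the user
wired into LightRAG): when stream=False the awaited call resolves to a str,
and when stream=True it resolves to an async iterator of text chunks that
stream_generator can consume.

# Sketch only: a stand-in llm_model_func illustrating the assumed contract.
# The echo model below is hypothetical, not LightRAG's API.
import asyncio
from typing import AsyncIterator, Optional, Union

async def llm_model_func(
    prompt: str,
    stream: bool = False,
    system_prompt: Optional[str] = None,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    # Stand-in model: echo the prompt back word by word.
    words = prompt.split()
    if not stream:
        # Non-streaming: resolve to the full text, as the else-branch expects.
        return " ".join(words)

    async def gen() -> AsyncIterator[str]:
        # Streaming: yield chunks, as the request.stream branch iterates them.
        for w in words:
            yield w + " "

    return gen()

async def main() -> None:
    # Mirrors the non-streaming branch of the patch.
    text = await llm_model_func("hello from lightrag", stream=False)
    print(text)
    # Mirrors the request.stream branch of the patch.
    chunks = await llm_model_func("hello from lightrag", stream=True)
    async for chunk in chunks:
        print(chunk, end="", flush=True)

asyncio.run(main())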