Refactor /api/generate: use llm_model_func directly

yangdx
2025-01-24 18:39:43 +08:00
parent b94cae9990
commit 2c8885792c


@@ -1272,23 +1272,17 @@ def create_app(args):
 # Calculate the number of input tokens
 prompt_tokens = estimate_tokens(cleaned_query)
-# Call RAG to run the query
-query_param = QueryParam(
-    mode=mode,
-    stream=request.stream,
-    only_need_context=False
-)
-# If there is a system prompt, update rag's llm_model_kwargs
+# Use llm_model_func directly for the query
 if request.system:
     rag.llm_model_kwargs["system_prompt"] = request.system
 if request.stream:
     from fastapi.responses import StreamingResponse
-    response = await rag.aquery(
+    response = await rag.llm_model_func(
         cleaned_query,
-        param=query_param
+        stream=True,
+        **rag.llm_model_kwargs
     )
     async def stream_generator():
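
For context, a minimal sketch of the streaming branch after this change, assuming llm_model_func returns an async iterator of text chunks when stream=True. fake_llm_model_func and the NDJSON line shape below are illustrative stand-ins, not the repository's exact stream_generator:

import asyncio
import json


async def fake_llm_model_func(prompt: str, stream: bool = False, **kwargs):
    # Stand-in for rag.llm_model_func: echoes the prompt back in two chunks.
    async def chunks():
        half = len(prompt) // 2
        for piece in (prompt[:half], prompt[half:]):
            yield piece
    return chunks() if stream else prompt


async def stream_generator(prompt: str, **llm_model_kwargs):
    # Forward each chunk as one NDJSON line, roughly the incremental shape
    # an Ollama-style /api/generate client reads.
    response = await fake_llm_model_func(prompt, stream=True, **llm_model_kwargs)
    async for chunk in response:
        yield json.dumps({"response": chunk, "done": False}) + "\n"
    yield json.dumps({"response": "", "done": True}) + "\n"


async def main():
    async for line in stream_generator("hello world", system_prompt="Be brief."):
        print(line, end="")


asyncio.run(main())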
@@ -1383,7 +1377,11 @@ def create_app(args):
     )
 else:
     first_chunk_time = time.time_ns()
-    response_text = await rag.aquery(cleaned_query, param=query_param)
+    response_text = await rag.llm_model_func(
+        cleaned_query,
+        stream=False,
+        **rag.llm_model_kwargs
+    )
     last_chunk_time = time.time_ns()
     if not response_text:
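
A companion sketch of the non-streaming branch under the same assumption: llm_model_func returns the full completion string when stream=False. estimate_tokens and fake_llm_model_func are simplified stand-ins for illustration, not the project's model binding:

import asyncio
import time


def estimate_tokens(text: str) -> int:
    # Crude whitespace-based token estimate, for illustration only.
    return max(1, len(text.split()))


async def fake_llm_model_func(prompt: str, stream: bool = False, **kwargs):
    # Stand-in for rag.llm_model_func in non-streaming mode.
    return f"echo: {prompt}"


async def generate_once(cleaned_query: str, llm_model_kwargs: dict) -> dict:
    prompt_tokens = estimate_tokens(cleaned_query)
    first_chunk_time = time.time_ns()
    response_text = await fake_llm_model_func(
        cleaned_query,
        stream=False,
        **llm_model_kwargs,
    )
    last_chunk_time = time.time_ns()
    return {
        "response": response_text,
        "prompt_tokens": prompt_tokens,
        "duration_ns": last_chunk_time - first_chunk_time,
    }


print(asyncio.run(generate_once("hello world", {"system_prompt": "Be brief."})))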