Refactor /api/generate: use llm_model_func directly
@@ -1272,23 +1272,17 @@ def create_app(args):
         # Calculate the number of input tokens
         prompt_tokens = estimate_tokens(cleaned_query)

-        # Run the query through RAG
-        query_param = QueryParam(
-            mode=mode,
-            stream=request.stream,
-            only_need_context=False
-        )
-
-        # If there is a system prompt, update rag's llm_model_kwargs
+        # Query llm_model_func directly
         if request.system:
             rag.llm_model_kwargs["system_prompt"] = request.system

         if request.stream:
             from fastapi.responses import StreamingResponse

-            response = await rag.aquery(
+            response = await rag.llm_model_func(
                 cleaned_query,
-                param=query_param
+                stream=True,
+                **rag.llm_model_kwargs
             )

             async def stream_generator():
@@ -1383,7 +1377,11 @@ def create_app(args):
             )
         else:
             first_chunk_time = time.time_ns()
-            response_text = await rag.aquery(cleaned_query, param=query_param)
+            response_text = await rag.llm_model_func(
+                cleaned_query,
+                stream=False,
+                **rag.llm_model_kwargs
+            )
             last_chunk_time = time.time_ns()

             if not response_text:
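The refactor relies on rag.llm_model_func accepting the cleaned prompt, a stream flag, and the keyword arguments stored in rag.llm_model_kwargs (including the system_prompt set above), and on it returning an async iterator of text chunks when stream=True and a complete string when stream=False. The following is a minimal, self-contained sketch of that assumed contract; fake_llm_model_func and the simplified driver are hypothetical stand-ins, not part of this commit or of LightRAG's API.

import asyncio
from typing import AsyncIterator, Optional, Union

# Hypothetical stand-in for rag.llm_model_func; only the call pattern used in
# the diff (prompt, stream=..., plus **rag.llm_model_kwargs such as
# system_prompt) is taken from the commit, the body is purely illustrative.
async def fake_llm_model_func(
    prompt: str,
    stream: bool = False,
    system_prompt: Optional[str] = None,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    answer = f"[system={system_prompt!r}] echo: {prompt}"
    if not stream:
        # Non-streaming: return the full completion as a single string.
        return answer

    async def chunk_iter() -> AsyncIterator[str]:
        # Streaming: yield the completion piece by piece.
        for word in answer.split():
            yield word + " "

    return chunk_iter()


async def main() -> None:
    llm_model_kwargs = {"system_prompt": "be brief"}

    # Mirrors the non-streaming branch in the second hunk.
    text = await fake_llm_model_func("hello", stream=False, **llm_model_kwargs)
    print(text)

    # Mirrors the streaming branch; stream_generator() would iterate this.
    chunks = await fake_llm_model_func("hello", stream=True, **llm_model_kwargs)
    async for chunk in chunks:
        print(chunk, end="")
    print()


asyncio.run(main())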