Refactor /api/generate: use llm_model_func directly
@@ -1272,23 +1272,17 @@ def create_app(args):
 # Count the number of input tokens
 prompt_tokens = estimate_tokens(cleaned_query)

-# Run the query through RAG
-query_param = QueryParam(
-    mode=mode,
-    stream=request.stream,
-    only_need_context=False
-)
-
-# If there is a system prompt, update rag's llm_model_kwargs
+# Query using llm_model_func directly
 if request.system:
     rag.llm_model_kwargs["system_prompt"] = request.system

 if request.stream:
     from fastapi.responses import StreamingResponse

-    response = await rag.aquery(
+    response = await rag.llm_model_func(
         cleaned_query,
-        param=query_param
+        stream=True,
+        **rag.llm_model_kwargs
     )

     async def stream_generator():
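In the streaming branch, `response` is then consumed by `stream_generator()` (the diff continues below). As a rough sketch of what such a generator might look like, assuming `llm_model_func(..., stream=True)` returns an async iterator of text chunks and that the endpoint emits Ollama-style NDJSON records (the field names, helper signature, and `model` handling here are illustrative, not taken from this commit):

import json
import time
from typing import AsyncIterator


async def stream_generator(chunks: AsyncIterator[str], model: str):
    # Wrap streamed text chunks into Ollama-style NDJSON records.
    async for chunk in chunks:
        yield json.dumps({
            "model": model,
            "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "response": chunk,   # partial completion text
            "done": False,
        }, ensure_ascii=False) + "\n"
    # Final record marking the end of the stream.
    yield json.dumps({
        "model": model,
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "response": "",
        "done": True,
    }, ensure_ascii=False) + "\n"


# Usage in the handler (sketch):
#   StreamingResponse(stream_generator(response, request.model),
#                     media_type="application/x-ndjson")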
@@ -1383,7 +1377,11 @@ def create_app(args):
     )
 else:
     first_chunk_time = time.time_ns()
-    response_text = await rag.aquery(cleaned_query, param=query_param)
+    response_text = await rag.llm_model_func(
+        cleaned_query,
+        stream=False,
+        **rag.llm_model_kwargs
+    )
     last_chunk_time = time.time_ns()

     if not response_text:
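Passing `**rag.llm_model_kwargs` (which now carries `system_prompt` whenever the request provides one) relies on the bound model function accepting those keyword arguments. A minimal sketch of a compatible binding, assuming the usual LightRAG convention of `llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs)` and the `openai_complete_if_cache` helper from the examples of that era (import path, model name, and working directory are assumptions, not part of this commit):

from lightrag import LightRAG
from lightrag.llm import openai_complete_if_cache


async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    # Forward the optional system prompt and any extra kwargs (e.g. stream=True/False).
    return await openai_complete_if_cache(
        "gpt-4o-mini",              # hypothetical model name
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


rag = LightRAG(
    working_dir="./rag_storage",    # hypothetical path
    llm_model_func=llm_model_func,
)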