Temporary save
```diff
@@ -472,10 +472,25 @@ def create_app(args):
             )
 
             if request.stream:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                from fastapi.responses import StreamingResponse
+                import json
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield f"data: {json.dumps({'response': chunk})}\n\n"
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="text/event-stream",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "text/event-stream",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type"
+                    }
+                )
             else:
                 return QueryResponse(response=response)
         except Exception as e:
```
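With this change, `/query` with `stream: true` emits Server-Sent Events: one `data: {json}` line per chunk, each terminated by a blank line. A minimal sketch of a client consuming that stream; the host, port, and payload fields are assumptions for illustration, not part of this commit:

```python
import asyncio
import json

import httpx  # third-party HTTP client, used here only for illustration


async def consume_query_stream() -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        # Endpoint URL and payload fields are hypothetical.
        async with client.stream(
            "POST",
            "http://localhost:9621/query",
            json={"query": "What is LightRAG?", "mode": "hybrid", "stream": True},
        ) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    chunk = json.loads(line[len("data: "):])
                    print(chunk["response"], end="", flush=True)


asyncio.run(consume_query_stream())
```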
```diff
@@ -484,7 +499,7 @@ def create_app(args):
     @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
     async def query_text_stream(request: QueryRequest):
         try:
-            response = rag.query(
+            response = await rag.aquery(  # use aquery instead of query, and add await
                 request.query,
                 param=QueryParam(
                     mode=request.mode,
@@ -493,11 +508,24 @@ def create_app(args):
                 ),
             )
 
+            from fastapi.responses import StreamingResponse
+
             async def stream_generator():
                 async for chunk in response:
-                    yield chunk
+                    yield f"data: {chunk}\n\n"
 
-            return stream_generator()
+            return StreamingResponse(
+                stream_generator(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "text/event-stream",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Methods": "POST, OPTIONS",
+                    "Access-Control-Allow-Headers": "Content-Type"
+                }
+            )
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
```
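Two fixes interact in `/query/stream`: `rag.query(...)` becomes `await rag.aquery(...)`, which with `stream=True` resolves to an async generator rather than a string, and that generator is wrapped in a `StreamingResponse` instead of being returned bare (a bare generator would not produce a valid streaming HTTP response). A self-contained toy of the same pattern; all names below are illustrative, not LightRAG's actual API:

```python
import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


# Stand-in for rag.aquery(..., stream=True): awaiting it returns the
# async generator itself, which the handler then iterates chunk by chunk.
async def fake_aquery(question: str):
    async def gen():
        for token in ("a ", "streamed ", "answer"):
            await asyncio.sleep(0)  # pretend each token arrives asynchronously
            yield token
    return gen()


@app.post("/demo/stream")
async def demo_stream():
    response = await fake_aquery("hi")

    async def stream_generator():
        async for chunk in response:
            yield f"data: {chunk}\n\n"  # SSE framing: data line + blank line

    # StreamingResponse flushes each chunk as it is produced; returning
    # stream_generator() directly would not stream anything to the client.
    return StreamingResponse(stream_generator(), media_type="text/event-stream")
```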
```diff
@@ -659,20 +687,48 @@ def create_app(args):
             cleaned_query, mode = parse_query_mode(query)
 
             # Run the query through RAG
-            query_param = QueryParam(
-                mode=mode,  # use the parsed mode; defaults to hybrid if no prefix is given
-                stream=request.stream,
-                only_need_context=False
-            )
-
             if request.stream:
-                response = await rag.aquery(
+                from fastapi.responses import StreamingResponse
+                import json
+
+                response = await rag.aquery(  # await is needed to get the async generator
                     cleaned_query,
-                    param=query_param
+                    param=QueryParam(
+                        mode=mode,
+                        stream=True,
+                        only_need_context=False
+                    ),
                 )
 
                 async def stream_generator():
                     try:
-                        async for chunk in response:
-                            yield {
+                        # make sure response is an async generator
+                        if isinstance(response, str):
+                            data = {
+                                'model': LIGHTRAG_MODEL,
+                                'created_at': LIGHTRAG_CREATED_AT,
+                                'message': {
+                                    'role': 'assistant',
+                                    'content': response
+                                },
+                                'done': True
+                            }
+                            yield f"data: {json.dumps(data)}\n\n"
+                        else:
+                            async for chunk in response:
+                                data = {
+                                    "model": LIGHTRAG_MODEL,
+                                    "created_at": LIGHTRAG_CREATED_AT,
+                                    "message": {
@@ -681,7 +737,10 @@ def create_app(args):
                                     },
                                     "done": False
                                 }
-                            yield {
+                                yield f"data: {json.dumps(data)}\n\n"
+
+                            # send the completion marker
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "message": {
@@ -690,30 +749,41 @@ def create_app(args):
                                 },
                                 "done": True
                             }
+                            yield f"data: {json.dumps(data)}\n\n"
                     except Exception as e:
                         logging.error(f"Error in stream_generator: {str(e)}")
                         raise
 
-                from fastapi.responses import StreamingResponse
-                import json
-
                 return StreamingResponse(
-                    (f"data: {json.dumps(chunk)}\n\n" async for chunk in stream_generator()),
-                    media_type="text/event-stream"
+                    stream_generator(),
+                    media_type="text/event-stream",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "text/event-stream",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type"
+                    }
+                )
             else:
-                response = await rag.aquery(
+                # non-streaming response
+                response_text = await rag.aquery(
                     cleaned_query,
-                    param=query_param
+                    param=QueryParam(
+                        mode=mode,
+                        stream=False,
+                        only_need_context=False
+                    ),
                 )
 
+                # make sure the response is not empty
+                if not response_text:
+                    response_text = "No response generated"
+
+                # build and return the response
                 return OllamaChatResponse(
                     model=LIGHTRAG_MODEL,
                     created_at=LIGHTRAG_CREATED_AT,
                     message=OllamaMessage(
                         role="assistant",
-                        content=response
+                        content=str(response_text)  # make sure it is converted to a string
                     ),
                     done=True
                 )
```
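Each event on the Ollama-compatible chat stream now carries a JSON body with a `done` flag, and a final `done: true` marker closes the stream. A hedged sketch of how a client might reassemble the full answer; the endpoint path, port, and payload are assumptions, not part of this commit:

```python
import asyncio
import json

import httpx  # illustration only


async def read_chat_stream() -> str:
    parts: list[str] = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:9621/api/chat",  # hypothetical path/port
            json={
                "model": "lightrag:latest",  # assumed model name
                "messages": [{"role": "user", "content": "What is LightRAG?"}],
                "stream": True,
            },
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                event = json.loads(line[len("data: "):])
                if event.get("done"):
                    break  # completion marker: the stream is finished
                parts.append(event["message"]["content"])
    return "".join(parts)


print(asyncio.run(read_chat_stream()))
```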
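For reference, the non-streaming branch returns `OllamaChatResponse`; the models are defined elsewhere in the file, but their shape, inferred from the fields used in this diff, is roughly:

```python
from pydantic import BaseModel


class OllamaMessage(BaseModel):
    role: str
    content: str


class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool
```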