Temporary save

yangdx
2025-01-15 19:32:03 +08:00
parent 828af49d6b
commit f15f97a51d


@@ -472,10 +472,25 @@ def create_app(args):
             )
             if request.stream:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                from fastapi.responses import StreamingResponse
+                import json
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield f"data: {json.dumps({'response': chunk})}\n\n"
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="text/event-stream",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "text/event-stream",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type"
+                    }
+                )
             else:
                 return QueryResponse(response=response)
         except Exception as e:
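
For reference, a minimal client sketch for the streaming branch above. It assumes the handler is mounted at POST /query, that the server listens on localhost:9621, and that the request body fields match QueryRequest; none of these appear in this diff, so adjust to your deployment.

import asyncio
import json

import httpx


async def main() -> None:
    payload = {"query": "What is LightRAG?", "mode": "hybrid", "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        # assumed endpoint path and port; the route decorator is outside this diff
        async with client.stream("POST", "http://localhost:9621/query", json=payload) as resp:
            async for line in resp.aiter_lines():
                # each SSE frame looks like: data: {"response": "..."}
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):])["response"], end="", flush=True)


asyncio.run(main())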
@@ -484,7 +499,7 @@ def create_app(args):
@app.post("/query/stream", dependencies=[Depends(optional_api_key)]) @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = rag.query( response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await
request.query, request.query,
param=QueryParam( param=QueryParam(
mode=request.mode, mode=request.mode,
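
The change above (rag.query to await rag.aquery) matters because with stream=True the awaited call hands back an async generator rather than a finished string. A sketch of that calling convention, not LightRAG's actual implementation:

from typing import AsyncIterator, Union


async def aquery_sketch(query: str, stream: bool = False) -> Union[str, AsyncIterator[str]]:
    # hypothetical stand-in for rag.aquery, for illustration only
    async def gen() -> AsyncIterator[str]:
        for token in ("partial ", "answer"):
            yield token

    if stream:
        # the caller must first await the coroutine, then iterate the result:
        #   chunks = await aquery_sketch(q, stream=True)
        #   async for chunk in chunks: ...
        return gen()
    return "partial answer"

This also explains the isinstance(response, str) guard added to the Ollama chat handler further down: depending on the stream flag, the same call can produce either shape.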
@@ -493,11 +508,24 @@ def create_app(args):
                 ),
             )
+            from fastapi.responses import StreamingResponse
+
             async def stream_generator():
                 async for chunk in response:
-                    yield chunk
-            return stream_generator()
+                    yield f"data: {chunk}\n\n"
+
+            return StreamingResponse(
+                stream_generator(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "text/event-stream",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Methods": "POST, OPTIONS",
+                    "Access-Control-Allow-Headers": "Content-Type"
+                }
+            )
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
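
With the endpoint now returning a real StreamingResponse, it can be consumed like this; host and port are assumptions, and note that frames on /query/stream carry the raw chunk text after "data: " rather than JSON:

import httpx

with httpx.stream(
    "POST",
    "http://localhost:9621/query/stream",
    json={"query": "What is LightRAG?", "mode": "hybrid"},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)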
@@ -659,20 +687,48 @@ def create_app(args):
         cleaned_query, mode = parse_query_mode(query)
 
         # Call RAG to run the query
+        query_param = QueryParam(
+            mode=mode,  # use the parsed mode; falls back to the default hybrid when no prefix is given
+            stream=request.stream,
+            only_need_context=False
+        )
+
         if request.stream:
-            response = await rag.aquery(
+            from fastapi.responses import StreamingResponse
+            import json
+
+            response = await rag.aquery(  # await is needed to obtain the async generator
                 cleaned_query,
-                param=QueryParam(
-                    mode=mode,
-                    stream=True,
-                    only_need_context=False
-                ),
+                param=query_param
             )
 
             async def stream_generator():
                 try:
-                    async for chunk in response:
-                        yield {
+                    # make sure response is handled as an async generator
+                    if isinstance(response, str):
+                        data = {
+                            'model': LIGHTRAG_MODEL,
+                            'created_at': LIGHTRAG_CREATED_AT,
+                            'message': {
+                                'role': 'assistant',
+                                'content': response
+                            },
+                            'done': True
+                        }
+                        yield f"data: {json.dumps(data)}\n\n"
+                    else:
+                        async for chunk in response:
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": chunk
+                                },
+                                "done": False
+                            }
+                            yield f"data: {json.dumps(data)}\n\n"
+                        data = {
                             "model": LIGHTRAG_MODEL,
                             "created_at": LIGHTRAG_CREATED_AT,
                             "message": {
@@ -681,7 +737,10 @@ def create_app(args):
                             },
                             "done": False
                         }
-                    yield {
+                        yield f"data: {json.dumps(data)}\n\n"
+                        # send the completion marker
+                        data = {
                             "model": LIGHTRAG_MODEL,
                             "created_at": LIGHTRAG_CREATED_AT,
                             "message": {
@@ -690,30 +749,41 @@ def create_app(args):
                             },
                             "done": True
                         }
+                        yield f"data: {json.dumps(data)}\n\n"
                 except Exception as e:
                     logging.error(f"Error in stream_generator: {str(e)}")
                     raise
 
-            from fastapi.responses import StreamingResponse
-            import json
             return StreamingResponse(
-                (f"data: {json.dumps(chunk)}\n\n" async for chunk in stream_generator()),
-                media_type="text/event-stream"
+                stream_generator(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "text/event-stream",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Methods": "POST, OPTIONS",
+                    "Access-Control-Allow-Headers": "Content-Type"
+                }
+            )
         else:
-            response = await rag.aquery(
+            # non-streaming response
+            response_text = await rag.aquery(
                 cleaned_query,
-                param=QueryParam(
-                    mode=mode,
-                    stream=False,
-                    only_need_context=False
-                ),
+                param=query_param
             )
+
+            # make sure the response is not empty
+            if not response_text:
+                response_text = "No response generated"
+
+            # build and return the response
             return OllamaChatResponse(
                 model=LIGHTRAG_MODEL,
                 created_at=LIGHTRAG_CREATED_AT,
                 message=OllamaMessage(
                     role="assistant",
-                    content=response
+                    content=str(response_text)  # make sure the value is a string
                 ),
                 done=True
             )
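
A minimal sketch of consuming the Ollama-compatible stream above. The /api/chat path and the lightrag:latest model name are assumptions (the route decorator and the value of LIGHTRAG_MODEL are outside this diff), and the /hybrid prefix relies on parse_query_mode stripping mode prefixes as the surrounding code suggests. Note that in this temporary-save state stream_generator builds one more done: False frame from the last chunk after the loop, so a client may see the final chunk repeated before the completion marker.

import json

import httpx

messages = [{"role": "user", "content": "/hybrid What is LightRAG?"}]
with httpx.stream(
    "POST",
    "http://localhost:9621/api/chat",
    json={"model": "lightrag:latest", "messages": messages, "stream": True},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue  # skip blank separator lines between SSE frames
        frame = json.loads(line[len("data: "):])
        if frame["done"]:
            break
        print(frame["message"]["content"], end="", flush=True)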