diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py
index 4e83acb0..cc549f4b 100644
--- a/lightrag/api/lightrag_ollama.py
+++ b/lightrag/api/lightrag_ollama.py
@@ -472,10 +472,25 @@ def create_app(args):
             )
 
             if request.stream:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                from fastapi.responses import StreamingResponse
+                import json
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield f"data: {json.dumps({'response': chunk})}\n\n"
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="text/event-stream",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "text/event-stream",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type"
+                    }
+                )
             else:
                 return QueryResponse(response=response)
         except Exception as e:
@@ -484,7 +499,7 @@
     @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
     async def query_text_stream(request: QueryRequest):
         try:
-            response = rag.query(
+            response = await rag.aquery(  # use aquery instead of query, and await it
                 request.query,
                 param=QueryParam(
                     mode=request.mode,
@@ -493,11 +508,24 @@
                 ),
             )
 
+            from fastapi.responses import StreamingResponse
+
             async def stream_generator():
                 async for chunk in response:
-                    yield chunk
+                    yield f"data: {chunk}\n\n"
 
-            return stream_generator()
+            return StreamingResponse(
+                stream_generator(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "text/event-stream",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Methods": "POST, OPTIONS",
+                    "Access-Control-Allow-Headers": "Content-Type"
+                }
+            )
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
 
@@ -659,20 +687,48 @@ def create_app(args):
             cleaned_query, mode = parse_query_mode(query)
 
             # run the query through RAG
+            query_param = QueryParam(
+                mode=mode,  # use the parsed mode; defaults to hybrid when no prefix is given
+                stream=request.stream,
+                only_need_context=False
+            )
+
             if request.stream:
-                response = await rag.aquery(
+                from fastapi.responses import StreamingResponse
+                import json
+
+                response = await rag.aquery(  # await is needed to obtain the async generator
                     cleaned_query,
-                    param=QueryParam(
-                        mode=mode,
-                        stream=True,
-                        only_need_context=False
-                    ),
+                    param=query_param
                 )
 
                 async def stream_generator():
                     try:
-                        async for chunk in response:
-                            yield {
+                        # response may be a plain string or an async generator
+                        if isinstance(response, str):
+                            data = {
+                                'model': LIGHTRAG_MODEL,
+                                'created_at': LIGHTRAG_CREATED_AT,
+                                'message': {
+                                    'role': 'assistant',
+                                    'content': response
+                                },
+                                'done': True
+                            }
+                            yield f"data: {json.dumps(data)}\n\n"
+                        else:
+                            async for chunk in response:
+                                data = {
+                                    "model": LIGHTRAG_MODEL,
+                                    "created_at": LIGHTRAG_CREATED_AT,
+                                    "message": {
+                                        "role": "assistant",
+                                        "content": chunk
+                                    },
+                                    "done": False
+                                }
+                                yield f"data: {json.dumps(data)}\n\n"
+                            data = {
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
                                 "message": {
@@ -681,7 +737,10 @@ def create_app(args):
                                 },
                                 "done": False
                             }
-                            yield {
+                            yield f"data: {json.dumps(data)}\n\n"
+
+                            # send the completion marker
+                            data = {
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
                                 "message": {
@@ -690,30 +749,41 @@ def create_app(args):
                                 },
                                 "done": True
                             }
+                            yield f"data: {json.dumps(data)}\n\n"
                     except Exception as e:
                         logging.error(f"Error in stream_generator: {str(e)}")
                         raise
-                from fastapi.responses import StreamingResponse
-                import json
+
                 return StreamingResponse(
-                    (f"data: {json.dumps(chunk)}\n\n" async for chunk in stream_generator()),
-                    media_type="text/event-stream"
+                    stream_generator(),
+                    media_type="text/event-stream",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "text/event-stream",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type"
+                    }
                 )
             else:
-                response = await rag.aquery(
+                # non-streaming response
+                response_text = await rag.aquery(
                     cleaned_query,
-                    param=QueryParam(
-                        mode=mode,
-                        stream=False,
-                        only_need_context=False
-                    ),
+                    param=query_param
                 )
+
+                # make sure the response is not empty
+                if not response_text:
+                    response_text = "No response generated"
+
+                # build and return the response
                 return OllamaChatResponse(
                     model=LIGHTRAG_MODEL,
                     created_at=LIGHTRAG_CREATED_AT,
                     message=OllamaMessage(
                         role="assistant",
-                        content=response
+                        content=str(response_text)  # make sure it is a string
                     ),
                     done=True
                 )
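
Note: the sketch below is a minimal client for exercising the new SSE endpoints; it is illustrative only and not part of the diff. It assumes the API server listens on http://localhost:9621 (adjust host and port to your deployment), uses httpx as the HTTP client, and the payload field names follow QueryRequest as used in the handlers above.

import asyncio

import httpx  # assumed HTTP client; any SSE-capable client works


async def consume_stream() -> None:
    # Field names follow QueryRequest: query, mode, stream.
    payload = {"query": "What is LightRAG?", "mode": "hybrid", "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:9621/query/stream", json=payload
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                # /query/stream frames carry raw text ("data: <chunk>");
                # /query and /api/chat frames carry JSON and would need json.loads().
                if line.startswith("data: "):
                    print(line[len("data: "):], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(consume_stream())

Per the SSE convention, frames are separated by a blank line, which aiter_lines() surfaces as empty strings; the startswith filter skips them.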