diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index cc549f4b..004c2739 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -706,50 +706,44 @@ def create_app(args): try: # 确保 response 是异步生成器 if isinstance(response, str): - data = { - 'model': LIGHTRAG_MODEL, - 'created_at': LIGHTRAG_CREATED_AT, - 'message': { - 'role': 'assistant', - 'content': response - }, - 'done': True - } - yield f"data: {json.dumps(data)}\n\n" - else: - async for chunk in response: - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk - }, - "done": False - } - yield f"data: {json.dumps(data)}\n\n" + # 如果是字符串,作为单个完整响应发送 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": chunk + "content": response }, - "done": False + "done": True } - yield f"data: {json.dumps(data)}\n\n" + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + else: + # 流式响应 + async for chunk in response: + if chunk: # 只发送非空内容 + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk + }, + "done": False + } + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - # 发送完成标记 - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "" - }, - "done": True - } - yield f"data: {json.dumps(data)}\n\n" + # 发送完成标记 + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "" + }, + "done": True + } + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + return # 确保生成器在发送完成标记后立即结束 except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py new file mode 100644 index 00000000..067b8877 --- /dev/null +++ 
import requests
import json
import sseclient

# Base endpoint of the locally running LightRAG Ollama-compatible server.
# Hoisted to one constant so both tests stay in sync.
CHAT_URL = "http://localhost:9621/api/chat"

# Generous network timeout so a hung server fails the test instead of
# blocking forever (requests has no timeout by default).
REQUEST_TIMEOUT = 60


def test_non_stream_chat():
    """Exercise /api/chat in non-streaming mode and pretty-print the JSON reply."""
    payload = {
        "model": "lightrag:latest",
        "messages": [
            {
                "role": "user",
                "content": "孙悟空",
            }
        ],
        "stream": False,
    }

    response = requests.post(CHAT_URL, json=payload, timeout=REQUEST_TIMEOUT)
    # Surface HTTP errors explicitly instead of crashing inside .json().
    response.raise_for_status()

    print("\n=== 非流式调用响应 ===")
    print(json.dumps(response.json(), ensure_ascii=False, indent=2))


def test_stream_chat():
    """Exercise /api/chat in streaming (SSE) mode and print the assembled answer."""
    payload = {
        "model": "lightrag:latest",
        "messages": [
            {
                "role": "user",
                "content": "/naive 孙悟空有什么法力,性格特征是什么",
            }
        ],
        "stream": True,
    }

    # stream=True keeps the connection open so SSE events can be consumed lazily.
    response = requests.post(
        CHAT_URL, json=payload, stream=True, timeout=REQUEST_TIMEOUT
    )
    response.raise_for_status()
    client = sseclient.SSEClient(response)

    print("\n=== 流式调用响应 ===")
    output_buffer = []
    try:
        for event in client.events():
            try:
                # Distinct name so the request payload above is not shadowed.
                event_data = json.loads(event.data)
                content = event_data.get("message", {}).get("content", "")
                if content:  # collect only non-empty chunks
                    output_buffer.append(content)
                if event_data.get("done", False):  # server signalled completion
                    break
            except json.JSONDecodeError:
                print("Error decoding JSON from SSE event")
    finally:
        response.close()  # always release the streaming connection

    # Print the whole reconstructed answer in one go.
    print("".join(output_buffer))


if __name__ == "__main__":
    # Non-streaming first, then streaming.
    test_non_stream_chat()
    test_stream_chat()