为Ollama API返回结果添加图像字段和性能统计信息

- 在OllamaMessage中添加images字段
- 响应消息中增加images字段
- 完成标记中添加性能统计信息
- 更新测试用例以处理性能统计
- 移除测试用例中的/naive前缀
This commit is contained in:
yangdx
2025-01-15 20:46:45 +08:00
parent 23f838ec94
commit f81b1cdf0a
2 changed files with 26 additions and 16 deletions

View File

@@ -231,6 +231,7 @@ class SearchMode(str, Enum):
class OllamaMessage(BaseModel): class OllamaMessage(BaseModel):
role: str role: str
content: str content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel): class OllamaChatRequest(BaseModel):
model: str = LIGHTRAG_MODEL model: str = LIGHTRAG_MODEL
@@ -712,7 +713,8 @@ def create_app(args):
"created_at": LIGHTRAG_CREATED_AT, "created_at": LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": response "content": response,
"images": None
}, },
"done": True "done": True
} }
@@ -726,21 +728,24 @@ def create_app(args):
"created_at": LIGHTRAG_CREATED_AT, "created_at": LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": chunk "content": chunk,
"images": None
}, },
"done": False "done": False
} }
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
# 发送完成标记 # 发送完成标记,包含性能统计信息
data = { data = {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT, "created_at": LIGHTRAG_CREATED_AT,
"message": { "done": True,
"role": "assistant", "total_duration": 0, # 由于我们没有实际统计这些指标,暂时使用默认值
"content": "" "load_duration": 0,
}, "prompt_eval_count": 0,
"done": True "prompt_eval_duration": 0,
"eval_count": 0,
"eval_duration": 0
} }
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
return # 确保生成器在发送完成标记后立即结束 return # 确保生成器在发送完成标记后立即结束
@@ -777,7 +782,8 @@ def create_app(args):
created_at=LIGHTRAG_CREATED_AT, created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage( message=OllamaMessage(
role="assistant", role="assistant",
content=str(response_text) # 确保转换为字符串 content=str(response_text), # 确保转换为字符串
images=None
), ),
done=True done=True
) )

View File

@@ -35,7 +35,7 @@ def test_stream_chat():
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
"content": "/naive 孙悟空有什么法力,性格特征是什么" "content": "孙悟空有什么法力,性格特征是什么"
} }
], ],
"stream": True "stream": True
@@ -51,12 +51,16 @@ def test_stream_chat():
for event in client.events(): for event in client.events():
try: try:
data = json.loads(event.data) data = json.loads(event.data)
message = data.get("message", {}) if data.get("done", False): # 如果是完成标记
content = message.get("content", "") if "total_duration" in data: # 最终的性能统计消息
if content: # 只收集非空内容 print("\n=== 性能统计 ===")
output_buffer.append(content) print(json.dumps(data, ensure_ascii=False, indent=2))
if data.get("done", False): # 如果收到完成标记,退出循环 break
break else: # 正常的内容消息
message = data.get("message", {})
content = message.get("content", "")
if content: # 只收集非空内容
output_buffer.append(content)
except json.JSONDecodeError: except json.JSONDecodeError:
print("Error decoding JSON from SSE event") print("Error decoding JSON from SSE event")
finally: finally: