为Ollama API添加性能统计功能

- 新增token估算函数
- 记录流式响应时间
- 计算输入输出token数
- 统计响应生成时间
- 返回详细的性能指标
This commit is contained in:
yangdx
2025-01-16 19:42:34 +08:00
parent 9c69438c3e
commit 95ff048a9e

View File

@@ -3,6 +3,8 @@ from pydantic import BaseModel
import logging
import argparse
import json
import time
import re
from typing import List, Dict, Any, Optional
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, ollama_embedding
@@ -24,6 +26,20 @@ from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
load_dotenv()
def estimate_tokens(text: str) -> int:
    """Estimate the number of LLM tokens in *text*.

    Heuristic: each CJK Unified Ideograph (U+4E00..U+9FFF) counts as
    ~1.5 tokens; every other character counts as ~0.25 tokens (the
    common "about 4 characters per token" rule of thumb for English).

    Args:
        text: Input string; may freely mix Chinese and non-Chinese text.

    Returns:
        The estimated token count, truncated to an int (>= 0).
    """
    # Count CJK characters with one regex scan; everything else is
    # simply the remainder of the string, so a second scan with the
    # negated character class is unnecessary.
    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    non_chinese_chars = len(text) - chinese_chars
    # Weighted sum of both character classes, truncated toward zero.
    return int(chinese_chars * 1.5 + non_chinese_chars * 0.25)
# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
@@ -690,6 +706,12 @@ def create_app(args):
# 解析查询模式 # 解析查询模式
cleaned_query, mode = parse_query_mode(query) cleaned_query, mode = parse_query_mode(query)
# 开始计时
start_time = time.time_ns()
# 计算输入token数量
prompt_tokens = estimate_tokens(cleaned_query)
# 调用RAG进行查询 # 调用RAG进行查询
query_param = QueryParam( query_param = QueryParam(
mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid
@@ -707,9 +729,17 @@ def create_app(args):
async def stream_generator(): async def stream_generator():
try: try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# 确保 response 是异步生成器 # 确保 response 是异步生成器
if isinstance(response, str): if isinstance(response, str):
# 如果是字符串,分两次发送 # 如果是字符串,分两次发送
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
# 第一次发送查询内容 # 第一次发送查询内容
data = { data = {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
@@ -723,23 +753,38 @@ def create_app(args):
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
# 计算各项指标
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time # 总时间
prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间
eval_time = last_chunk_time - first_chunk_time # 生成响应的时间
# 第二次发送统计信息 # 第二次发送统计信息
data = { data = {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT, "created_at": LIGHTRAG_CREATED_AT,
"done": True, "done": True,
"total_duration": 1, "total_duration": total_time, # 总时间
"load_duration": 1, "load_duration": 0, # 加载时间为0
"prompt_eval_count": 999, "prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": 1, "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": 999, "eval_count": completion_tokens, # 输出token数
"eval_duration": 1 "eval_duration": eval_time # 生成响应的时间
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
else: else:
# 流式响应 # 流式响应
async for chunk in response: async for chunk in response:
if chunk: # 只发送非空内容 if chunk: # 只发送非空内容
# 记录第一个chunk的时间
if first_chunk_time is None:
first_chunk_time = time.time_ns()
# 更新最后一个chunk的时间
last_chunk_time = time.time_ns()
# 累积响应内容
total_response += chunk
data = { data = {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT, "created_at": LIGHTRAG_CREATED_AT,
@@ -752,17 +797,23 @@ def create_app(args):
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
# 计算各项指标
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time # 总时间
prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间
eval_time = last_chunk_time - first_chunk_time # 生成响应的时间
# 发送完成标记,包含性能统计信息 # 发送完成标记,包含性能统计信息
data = { data = {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT, "created_at": LIGHTRAG_CREATED_AT,
"done": True, "done": True,
"total_duration": 1, # 由于我们没有实际统计这些指标,暂时使用默认值 "total_duration": total_time, # 总时间
"load_duration": 1, "load_duration": 0, # 加载时间为0
"prompt_eval_count": 999, "prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": 1, "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": 999, "eval_count": completion_tokens, # 输出token数
"eval_duration": 1 "eval_duration": eval_time # 生成响应的时间
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
return # 确保生成器在发送完成标记后立即结束 return # 确保生成器在发送完成标记后立即结束
@@ -784,15 +835,23 @@ def create_app(args):
) )
else: else:
# 非流式响应 # 非流式响应
first_chunk_time = time.time_ns()
response_text = await rag.aquery( response_text = await rag.aquery(
cleaned_query, cleaned_query,
param=query_param param=query_param
) )
last_chunk_time = time.time_ns()
# 确保响应不为空 # 确保响应不为空
if not response_text: if not response_text:
response_text = "No response generated" response_text = "No response generated"
# 计算各项指标
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time # 总时间
prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间
eval_time = last_chunk_time - first_chunk_time # 生成响应的时间
# 构造响应,包含性能统计信息 # 构造响应,包含性能统计信息
return { return {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
@@ -803,12 +862,12 @@ def create_app(args):
"images": None "images": None
}, },
"done": True, "done": True,
"total_duration": 1, # 由于我们没有实际统计这些指标,暂时使用默认值 "total_duration": total_time, # 总时间
"load_duration": 1, "load_duration": 0, # 加载时间为0
"prompt_eval_count": 999, "prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": 1, "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": 999, "eval_count": completion_tokens, # 输出token数
"eval_duration": 1 "eval_duration": eval_time # 生成响应的时间
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))