pre-commit run --all-files
@@ -24,22 +24,25 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.status import HTTP_403_FORBIDDEN
 
 from dotenv import load_dotenv
 
 load_dotenv()
 
+
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
     Chinese characters: approximately 1.5 tokens per character
     English characters: approximately 0.25 tokens per character
     """
     # Use regex to match Chinese and non-Chinese characters separately
-    chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
-    non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
+    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
+    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
 
     # Calculate estimated token count
     tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
 
     return int(tokens)
 
+
 # Constants for model information
 LIGHTRAG_NAME = "lightrag"
 LIGHTRAG_TAG = "latest"
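Reviewer note: the weights in estimate_tokens are easy to sanity-check on their own. A minimal standalone sketch (standard library only, not part of this commit):

import re

def estimate_tokens(text: str) -> int:
    # CJK unified ideographs weigh ~1.5 tokens each, all other characters ~0.25
    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
    return int(chinese_chars * 1.5 + non_chinese_chars * 0.25)

print(estimate_tokens("hello world"))     # 2  (11 chars * 0.25 = 2.75, truncated)
print(estimate_tokens("唐僧有几个徒弟"))   # 10 (7 chars * 1.5 = 10.5, truncated)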
@@ -48,6 +51,7 @@ LIGHTRAG_SIZE = 7365960935
 LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
 LIGHTRAG_DIGEST = "sha256:lightrag"
 
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -61,6 +65,7 @@ async def llm_model_func(
         **kwargs,
     )
 
+
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": "http://m4.lan.znipower.com:11434",
@@ -245,27 +250,32 @@ class SearchMode(str, Enum):
     hybrid = "hybrid"
     mix = "mix"
 
+
 # Ollama API compatible models
 class OllamaMessage(BaseModel):
     role: str
     content: str
     images: Optional[List[str]] = None
 
+
 class OllamaChatRequest(BaseModel):
     model: str = LIGHTRAG_MODEL
     messages: List[OllamaMessage]
     stream: bool = True  # Default to streaming mode
     options: Optional[Dict[str, Any]] = None
 
+
 class OllamaChatResponse(BaseModel):
     model: str
     created_at: str
     message: OllamaMessage
     done: bool
 
+
 class OllamaVersionResponse(BaseModel):
     version: str
 
+
 class OllamaModelDetails(BaseModel):
     parent_model: str
     format: str
@@ -274,6 +284,7 @@ class OllamaModelDetails(BaseModel):
     parameter_size: str
     quantization_level: str
 
+
 class OllamaModel(BaseModel):
     name: str
     model: str
@@ -282,9 +293,11 @@ class OllamaModel(BaseModel):
     modified_at: str
     details: OllamaModelDetails
 
+
 class OllamaTagResponse(BaseModel):
     models: List[OllamaModel]
 
+
 # Original LightRAG models
 class QueryRequest(BaseModel):
     query: str
@@ -292,9 +305,11 @@ class QueryRequest(BaseModel):
     stream: bool = False
     only_need_context: bool = False
 
+
 class QueryResponse(BaseModel):
     response: str
 
+
 class InsertTextRequest(BaseModel):
     text: str
     description: Optional[str] = None
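Reviewer note: the models above mirror Ollama's chat schema, so an incoming request body validates (and fills defaults) as in this self-contained sketch; it assumes only pydantic is installed and inlines the LIGHTRAG_MODEL constant, so it is an illustration rather than the server code:

from typing import Any, Dict, List, Optional
from pydantic import BaseModel

class OllamaMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None

class OllamaChatRequest(BaseModel):
    model: str = "lightrag:latest"  # stands in for LIGHTRAG_MODEL
    messages: List[OllamaMessage]
    stream: bool = True  # Default to streaming mode
    options: Optional[Dict[str, Any]] = None

# Nested dicts are coerced into OllamaMessage instances automatically
req = OllamaChatRequest(messages=[{"role": "user", "content": "hi"}])
print(req.model, req.stream)  # lightrag:latest True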
@@ -395,7 +410,9 @@ def create_app(args):
             embedding_dim=1024,
             max_token_size=8192,
             func=lambda texts: ollama_embedding(
-                texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434"
+                texts,
+                embed_model="bge-m3:latest",
+                host="http://m4.lan.znipower.com:11434",
             ),
         ),
     )
@@ -493,7 +510,7 @@ def create_app(args):
             # If response is a string (e.g. cache hit), return directly
             if isinstance(response, str):
                 return QueryResponse(response=response)
 
             # If it's an async generator, decide whether to stream based on stream parameter
             if request.stream:
                 result = ""
@@ -546,8 +563,8 @@ def create_app(args):
                         "Access-Control-Allow-Origin": "*",
                         "Access-Control-Allow-Methods": "POST, OPTIONS",
                         "Access-Control-Allow-Headers": "Content-Type",
-                        "X-Accel-Buffering": "no"  # Disable Nginx buffering
-                    }
+                        "X-Accel-Buffering": "no",  # Disable Nginx buffering
+                    },
                 )
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
@@ -652,29 +669,29 @@ def create_app(args):
     @app.get("/api/version")
     async def get_version():
         """Get Ollama version information"""
-        return OllamaVersionResponse(
-            version="0.5.4"
-        )
+        return OllamaVersionResponse(version="0.5.4")
 
     @app.get("/api/tags")
     async def get_tags():
         """Get available models"""
         return OllamaTagResponse(
-            models=[{
-                "name": LIGHTRAG_MODEL,
-                "model": LIGHTRAG_MODEL,
-                "size": LIGHTRAG_SIZE,
-                "digest": LIGHTRAG_DIGEST,
-                "modified_at": LIGHTRAG_CREATED_AT,
-                "details": {
-                    "parent_model": "",
-                    "format": "gguf",
-                    "family": LIGHTRAG_NAME,
-                    "families": [LIGHTRAG_NAME],
-                    "parameter_size": "13B",
-                    "quantization_level": "Q4_0"
-                }
-            }]
+            models=[
+                {
+                    "name": LIGHTRAG_MODEL,
+                    "model": LIGHTRAG_MODEL,
+                    "size": LIGHTRAG_SIZE,
+                    "digest": LIGHTRAG_DIGEST,
+                    "modified_at": LIGHTRAG_CREATED_AT,
+                    "details": {
+                        "parent_model": "",
+                        "format": "gguf",
+                        "family": LIGHTRAG_NAME,
+                        "families": [LIGHTRAG_NAME],
+                        "parameter_size": "13B",
+                        "quantization_level": "Q4_0",
+                    },
+                }
+            ]
         )
 
     def parse_query_mode(query: str) -> tuple[str, SearchMode]:
@@ -686,15 +703,15 @@ def create_app(args):
             "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
             "/naive ": SearchMode.naive,
             "/hybrid ": SearchMode.hybrid,
-            "/mix ": SearchMode.mix
+            "/mix ": SearchMode.mix,
         }
 
         for prefix, mode in mode_map.items():
             if query.startswith(prefix):
                 # After removing prefix and leading spaces
-                cleaned_query = query[len(prefix):].lstrip()
+                cleaned_query = query[len(prefix) :].lstrip()
                 return cleaned_query, mode
 
         return query, SearchMode.hybrid
 
     @app.post("/api/chat")
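Reviewer note: parse_query_mode strips a "/mode " prefix and falls back to hybrid. A standalone sketch of the same routing (the "/local " entry and SearchMode.local are assumed from context outside this hunk):

from enum import Enum

class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"  # 'global' is a Python keyword
    hybrid = "hybrid"
    mix = "mix"

def parse_query_mode(query: str):
    mode_map = {
        "/local ": SearchMode.local,
        "/global ": SearchMode.global_,
        "/naive ": SearchMode.naive,
        "/hybrid ": SearchMode.hybrid,
        "/mix ": SearchMode.mix,
    }
    for prefix, mode in mode_map.items():
        if query.startswith(prefix):
            # Remove the prefix and any leading spaces
            return query[len(prefix) :].lstrip(), mode
    return query, SearchMode.hybrid  # default when no prefix matches

print(parse_query_mode("/mix 唐僧有几个徒弟"))  # ('唐僧有几个徒弟', <SearchMode.mix: 'mix'>)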
@@ -705,32 +722,29 @@ def create_app(args):
             messages = request.messages
             if not messages:
                 raise HTTPException(status_code=400, detail="No messages provided")
 
             # Get the last message as query
             query = messages[-1].content
 
             # Parse the query mode
             cleaned_query, mode = parse_query_mode(query)
 
             # Start timing
             start_time = time.time_ns()
 
             # Count the input tokens
             prompt_tokens = estimate_tokens(cleaned_query)
 
             # Call RAG to run the query
             query_param = QueryParam(
-                mode=mode,
-                stream=request.stream,
-                only_need_context=False
+                mode=mode, stream=request.stream, only_need_context=False
             )
 
             if request.stream:
                 from fastapi.responses import StreamingResponse
 
                 response = await rag.aquery(  # Need await to get async generator
-                    cleaned_query,
-                    param=query_param
+                    cleaned_query, param=query_param
                 )
 
                 async def stream_generator():
@@ -738,33 +752,37 @@ def create_app(args):
                     first_chunk_time = None
                     last_chunk_time = None
                     total_response = ""
 
                     # Ensure response is an async generator
                     if isinstance(response, str):
                         # If it's a string, send in two parts
                         first_chunk_time = time.time_ns()
                         last_chunk_time = first_chunk_time
                         total_response = response
 
                         # First send the content
                         data = {
                             "model": LIGHTRAG_MODEL,
                             "created_at": LIGHTRAG_CREATED_AT,
                             "message": {
                                 "role": "assistant",
                                 "content": response,
-                                "images": None
+                                "images": None,
                             },
-                            "done": False
+                            "done": False,
                         }
                         yield f"{json.dumps(data, ensure_ascii=False)}\n"
 
                         # Compute the metrics
                         completion_tokens = estimate_tokens(total_response)
                         total_time = last_chunk_time - start_time  # Total time
-                        prompt_eval_time = first_chunk_time - start_time  # Time before the first response
-                        eval_time = last_chunk_time - first_chunk_time  # Time to generate the response
+                        prompt_eval_time = (
+                            first_chunk_time - start_time
+                        )  # Time before the first response
+                        eval_time = (
+                            last_chunk_time - first_chunk_time
+                        )  # Time to generate the response
 
                         # Then send the statistics
                         data = {
                             "model": LIGHTRAG_MODEL,
@@ -775,7 +793,7 @@ def create_app(args):
                             "prompt_eval_count": prompt_tokens,  # Input token count
                             "prompt_eval_duration": prompt_eval_time,  # Time before the first response
                             "eval_count": completion_tokens,  # Output token count
-                            "eval_duration": eval_time  # Time to generate the response
+                            "eval_duration": eval_time,  # Time to generate the response
                         }
                         yield f"{json.dumps(data, ensure_ascii=False)}\n"
                     else:
@@ -785,10 +803,10 @@ def create_app(args):
                             # Record the time of the first chunk
                             if first_chunk_time is None:
                                 first_chunk_time = time.time_ns()
 
                             # Update the time of the last chunk
                             last_chunk_time = time.time_ns()
 
                             # Accumulate the response content
                             total_response += chunk
                             data = {
@@ -797,18 +815,22 @@ def create_app(args):
                                 "message": {
                                     "role": "assistant",
                                     "content": chunk,
-                                    "images": None
+                                    "images": None,
                                 },
-                                "done": False
+                                "done": False,
                             }
                             yield f"{json.dumps(data, ensure_ascii=False)}\n"
 
                         # Compute the metrics
                         completion_tokens = estimate_tokens(total_response)
                         total_time = last_chunk_time - start_time  # Total time
-                        prompt_eval_time = first_chunk_time - start_time  # Time before the first response
-                        eval_time = last_chunk_time - first_chunk_time  # Time to generate the response
+                        prompt_eval_time = (
+                            first_chunk_time - start_time
+                        )  # Time before the first response
+                        eval_time = (
+                            last_chunk_time - first_chunk_time
+                        )  # Time to generate the response
 
                         # Send the completion marker with performance statistics
                         data = {
                             "model": LIGHTRAG_MODEL,
@@ -819,14 +841,14 @@ def create_app(args):
                             "prompt_eval_count": prompt_tokens,  # Input token count
                             "prompt_eval_duration": prompt_eval_time,  # Time before the first response
                             "eval_count": completion_tokens,  # Output token count
-                            "eval_duration": eval_time  # Time to generate the response
+                            "eval_duration": eval_time,  # Time to generate the response
                         }
                         yield f"{json.dumps(data, ensure_ascii=False)}\n"
                         return  # Ensure the generator ends immediately after the completion marker
                 except Exception as e:
                     logging.error(f"Error in stream_generator: {str(e)}")
                     raise
 
                 return StreamingResponse(
                     stream_generator(),
                     media_type="application/x-ndjson",
@@ -836,28 +858,25 @@ def create_app(args):
                         "Content-Type": "application/x-ndjson",
                         "Access-Control-Allow-Origin": "*",
                         "Access-Control-Allow-Methods": "POST, OPTIONS",
-                        "Access-Control-Allow-Headers": "Content-Type"
-                    }
+                        "Access-Control-Allow-Headers": "Content-Type",
+                    },
                 )
             else:
                 # Non-streaming response
                 first_chunk_time = time.time_ns()
-                response_text = await rag.aquery(
-                    cleaned_query,
-                    param=query_param
-                )
+                response_text = await rag.aquery(cleaned_query, param=query_param)
                 last_chunk_time = time.time_ns()
 
                 # Ensure the response is not empty
                 if not response_text:
                     response_text = "No response generated"
 
                 # Compute the metrics
                 completion_tokens = estimate_tokens(str(response_text))
                 total_time = last_chunk_time - start_time  # Total time
                 prompt_eval_time = first_chunk_time - start_time  # Time before the first response
                 eval_time = last_chunk_time - first_chunk_time  # Time to generate the response
 
                 # Build the response with performance statistics
                 return {
                     "model": LIGHTRAG_MODEL,
@@ -865,7 +884,7 @@ def create_app(args):
                     "message": {
                         "role": "assistant",
                         "content": str(response_text),  # Ensure conversion to string
-                        "images": None
+                        "images": None,
                     },
                     "done": True,
                     "total_duration": total_time,  # Total time
@@ -873,7 +892,7 @@ def create_app(args):
                     "prompt_eval_count": prompt_tokens,  # Input token count
                     "prompt_eval_duration": prompt_eval_time,  # Time before the first response
                     "eval_count": completion_tokens,  # Output token count
-                    "eval_duration": eval_time  # Time to generate the response
+                    "eval_duration": eval_time,  # Time to generate the response
                 }
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
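Reviewer note before the second file (the companion test script below): the /api/chat handler streams newline-delimited JSON, content chunks with "done": false followed by a single statistics record with "done": true. A minimal client-side consumption sketch (host and port are hypothetical and depend on deployment; requests assumed installed):

import json
import requests

URL = "http://localhost:9621/api/chat"  # hypothetical endpoint
payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/mix 唐僧有几个徒弟"}],
    "stream": True,
}

with requests.post(URL, json=payload, stream=True) as response:
    for line in response.iter_lines():
        if not line:
            continue
        data = json.loads(line.decode("utf-8"))
        if data.get("done"):
            # The final NDJSON record carries the performance counters
            print("\neval_count:", data.get("eval_count"))
            break
        print(data["message"]["content"], end="", flush=True)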
@@ -18,8 +18,10 @@ from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
 
+
 class OutputControl:
     """Output control class, manages the verbosity of test output"""
+
     _verbose: bool = False
 
     @classmethod
@@ -30,9 +32,11 @@ class OutputControl:
     def is_verbose(cls) -> bool:
         return cls._verbose
 
+
 @dataclass
 class TestResult:
     """Test result data class"""
+
     name: str
     success: bool
     duration: float
@@ -43,8 +47,10 @@ class TestResult:
         if not self.timestamp:
             self.timestamp = datetime.now().isoformat()
 
+
 class TestStats:
     """Test statistics"""
+
     def __init__(self):
         self.results: List[TestResult] = []
         self.start_time = datetime.now()
@@ -65,8 +71,8 @@ class TestStats:
                 "total": len(self.results),
                 "passed": sum(1 for r in self.results if r.success),
                 "failed": sum(1 for r in self.results if not r.success),
-                "total_duration": sum(r.duration for r in self.results)
-            }
+                "total_duration": sum(r.duration for r in self.results),
+            },
         }
 
         with open(path, "w", encoding="utf-8") as f:
@@ -92,6 +98,7 @@ class TestStats:
         if not result.success:
             print(f"- {result.name}: {result.error}")
 
+
 DEFAULT_CONFIG = {
     "server": {
         "host": "localhost",
@@ -99,16 +106,15 @@ DEFAULT_CONFIG = {
         "model": "lightrag:latest",
         "timeout": 30,
         "max_retries": 3,
-        "retry_delay": 1
+        "retry_delay": 1,
     },
-    "test_cases": {
-        "basic": {
-            "query": "唐僧有几个徒弟"
-        }
-    }
+    "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
 }
 
-def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
+
+def make_request(
+    url: str, data: Dict[str, Any], stream: bool = False
+) -> requests.Response:
     """Send an HTTP request with retry mechanism
     Args:
         url: Request URL
@@ -127,12 +133,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
 
     for attempt in range(max_retries):
         try:
-            response = requests.post(
-                url,
-                json=data,
-                stream=stream,
-                timeout=timeout
-            )
+            response = requests.post(url, json=data, stream=stream, timeout=timeout)
             return response
         except requests.exceptions.RequestException as e:
             if attempt == max_retries - 1:  # Last retry
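Reviewer note: isolated from the script, make_request is just a fixed-delay retry loop around requests.post. A self-contained sketch of the pattern (defaults mirror DEFAULT_CONFIG above; in the script itself these values come from CONFIG):

import time
import requests

def make_request(url, data, stream=False, max_retries=3, retry_delay=1, timeout=30):
    # POST with fixed-delay retries; re-raise only after the last attempt fails
    for attempt in range(max_retries):
        try:
            return requests.post(url, json=data, stream=stream, timeout=timeout)
        except requests.exceptions.RequestException:
            if attempt == max_retries - 1:  # Last retry: give up and propagate
                raise
            time.sleep(retry_delay)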
@@ -140,6 +141,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
             print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
             time.sleep(retry_delay)
 
+
 def load_config() -> Dict[str, Any]:
     """Load configuration file
@@ -154,6 +156,7 @@ def load_config() -> Dict[str, Any]:
             return json.load(f)
     return DEFAULT_CONFIG
 
+
 def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
     """Format and print JSON response data
     Args:
@@ -166,18 +169,19 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
     print(f"\n=== {title} ===")
     print(json.dumps(data, ensure_ascii=False, indent=indent))
 
+
 # Global configuration
 CONFIG = load_config()
 
+
 def get_base_url() -> str:
     """Return the base URL"""
     server = CONFIG["server"]
     return f"http://{server['host']}:{server['port']}/api/chat"
 
+
 def create_request_data(
-    content: str,
-    stream: bool = False,
-    model: str = None
+    content: str, stream: bool = False, model: str = None
 ) -> Dict[str, Any]:
     """Create basic request data
     Args:
@@ -189,18 +193,15 @@ def create_request_data(
     """
     return {
         "model": model or CONFIG["server"]["model"],
-        "messages": [
-            {
-                "role": "user",
-                "content": content
-            }
-        ],
-        "stream": stream
+        "messages": [{"role": "user", "content": content}],
+        "stream": stream,
     }
 
+
 # Global test statistics
 STATS = TestStats()
 
+
 def run_test(func: Callable, name: str) -> None:
     """Run a test and record the results
     Args:
@@ -217,13 +218,11 @@ def run_test(func: Callable, name: str) -> None:
         STATS.add_result(TestResult(name, False, duration, str(e)))
         raise
 
+
 def test_non_stream_chat():
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_request_data(
-        CONFIG["test_cases"]["basic"]["query"],
-        stream=False
-    )
+    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
 
     # Send request
     response = make_request(url, data)
@@ -234,10 +233,12 @@ def test_non_stream_chat():
     response_json = response.json()
 
     # Print response content
-    print_json_response({
-        "model": response_json["model"],
-        "message": response_json["message"]
-    }, "Response content")
+    print_json_response(
+        {"model": response_json["model"], "message": response_json["message"]},
+        "Response content",
+    )
 
+
 def test_stream_chat():
     """Test streaming call to /api/chat endpoint
 
@@ -257,10 +258,7 @@ def test_stream_chat():
     The last message will contain performance statistics, with done set to true.
     """
     url = get_base_url()
-    data = create_request_data(
-        CONFIG["test_cases"]["basic"]["query"],
-        stream=True
-    )
+    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
 
     # Send request and get streaming response
    response = make_request(url, data, stream=True)
@@ -273,9 +271,11 @@ def test_stream_chat():
             if line:  # Skip empty lines
                 try:
                     # Decode and parse JSON
-                    data = json.loads(line.decode('utf-8'))
+                    data = json.loads(line.decode("utf-8"))
                     if data.get("done", True):  # If it's the completion marker
-                        if "total_duration" in data:  # Final performance statistics message
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
                             # print_json_response(data, "Performance statistics")
                             break
                     else:  # Normal content message
@@ -283,7 +283,9 @@ def test_stream_chat():
                         content = message.get("content", "")
                         if content:  # Only collect non-empty content
                             output_buffer.append(content)
-                            print(content, end="", flush=True)  # Print content in real-time
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
                 except json.JSONDecodeError:
                     print("Error decoding JSON from response line")
     finally:
@@ -292,6 +294,7 @@ def test_stream_chat():
     # Print a newline
     print()
 
+
 def test_query_modes():
     """Test different query mode prefixes
 
@@ -311,8 +314,7 @@ def test_query_modes():
         if OutputControl.is_verbose():
             print(f"\n=== Testing /{mode} mode ===")
         data = create_request_data(
-            f"/{mode} {CONFIG['test_cases']['basic']['query']}",
-            stream=False
+            f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
         )
 
         # Send request
@@ -320,10 +322,10 @@ def test_query_modes():
         response_json = response.json()
 
         # Print response content
-        print_json_response({
-            "model": response_json["model"],
-            "message": response_json["message"]
-        })
+        print_json_response(
+            {"model": response_json["model"], "message": response_json["message"]}
+        )
 
+
 def create_error_test_data(error_type: str) -> Dict[str, Any]:
     """Create request data for error testing
@@ -337,33 +339,21 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
         Request dictionary containing error data
     """
     error_data = {
-        "empty_messages": {
-            "model": "lightrag:latest",
-            "messages": [],
-            "stream": True
-        },
+        "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
         "invalid_role": {
             "model": "lightrag:latest",
-            "messages": [
-                {
-                    "invalid_role": "user",
-                    "content": "Test message"
-                }
-            ],
-            "stream": True
+            "messages": [{"invalid_role": "user", "content": "Test message"}],
+            "stream": True,
         },
         "missing_content": {
             "model": "lightrag:latest",
-            "messages": [
-                {
-                    "role": "user"
-                }
-            ],
-            "stream": True
-        }
+            "messages": [{"role": "user"}],
+            "stream": True,
+        },
     }
     return error_data.get(error_type, error_data["empty_messages"])
 
+
 def test_stream_error_handling():
     """Test error handling for streaming responses
 
@@ -409,6 +399,7 @@ def test_stream_error_handling():
     print_json_response(response.json(), "Error message")
     response.close()
 
+
 def test_error_handling():
     """Test error handling for non-streaming responses
 
@@ -455,6 +446,7 @@ def test_error_handling():
     print(f"Status code: {response.status_code}")
     print_json_response(response.json(), "Error message")
 
+
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
@@ -465,9 +457,10 @@ def get_test_cases() -> Dict[str, Callable]:
         "stream": test_stream_chat,
         "modes": test_query_modes,
         "errors": test_error_handling,
-        "stream_errors": test_stream_error_handling
+        "stream_errors": test_stream_error_handling,
     }
 
+
 def create_default_config():
     """Create a default configuration file"""
     config_path = Path("config.json")
@@ -476,6 +469,7 @@ def create_default_config():
             json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
         print(f"Default configuration file created: {config_path}")
 
+
 def parse_args() -> argparse.Namespace:
     """Parse command line arguments"""
     parser = argparse.ArgumentParser(
@@ -496,38 +490,39 @@ Configuration file (config.json):
         }
     }
 }
-"""
+""",
     )
     parser.add_argument(
-        "-q", "--quiet",
+        "-q",
+        "--quiet",
         action="store_true",
-        help="Silent mode, only display test result summary"
+        help="Silent mode, only display test result summary",
     )
     parser.add_argument(
-        "-a", "--ask",
+        "-a",
+        "--ask",
         type=str,
-        help="Specify query content, which will override the query settings in the configuration file"
+        help="Specify query content, which will override the query settings in the configuration file",
    )
     parser.add_argument(
-        "--init-config",
-        action="store_true",
-        help="Create default configuration file"
+        "--init-config", action="store_true", help="Create default configuration file"
    )
     parser.add_argument(
         "--output",
         type=str,
         default="",
-        help="Test result output file path, default is not to output to a file"
+        help="Test result output file path, default is not to output to a file",
    )
     parser.add_argument(
         "--tests",
         nargs="+",
         choices=list(get_test_cases().keys()) + ["all"],
         default=["all"],
-        help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
+        help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
    )
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = parse_args()
 