From f441a454537b05182baa0c8584d48940fa455d9d Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 22:15:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=B5=8B=E8=AF=95=E7=94=A8?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + test_lightrag_ollama_chat.py | 614 ++++++++++++++++++++++++++++++++--- 2 files changed, 561 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index 5749adb5..8ac420b1 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ +test_results.json diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index f8947585..ae32a3b8 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -1,96 +1,602 @@ +""" +LightRAG Ollama 兼容接口测试脚本 + +这个脚本测试 LightRAG 的 Ollama 兼容接口,包括: +1. 基本功能测试(流式和非流式响应) +2. 查询模式测试(local、global、naive、hybrid) +3. 错误处理测试(包括流式和非流式场景) + +所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。 +""" + import requests import json -import sseclient +import argparse +import os +import time +from typing import Dict, Any, Optional, List, Callable, Tuple +from dataclasses import dataclass, asdict +from datetime import datetime +from pathlib import Path -def test_non_stream_chat(): - """测试非流式调用 /api/chat 接口""" - url = "http://localhost:9621/api/chat" +class OutputControl: + """输出控制类,管理测试输出的详细程度""" + _verbose: bool = False + + @classmethod + def set_verbose(cls, verbose: bool) -> None: + """设置输出详细程度 + + Args: + verbose: True 为详细模式,False 为静默模式 + """ + cls._verbose = verbose + + @classmethod + def is_verbose(cls) -> bool: + """获取当前输出模式 + + Returns: + 当前是否为详细模式 + """ + return cls._verbose + +@dataclass +class TestResult: + """测试结果数据类""" + name: str + success: bool + duration: float + error: Optional[str] = None + timestamp: str = "" - # 构造请求数据 - data = { + def __post_init__(self): + """初始化后设置时间戳""" + if not self.timestamp: + self.timestamp = datetime.now().isoformat() + +class TestStats: + """测试统计信息""" + def __init__(self): + self.results: List[TestResult] = [] + self.start_time = datetime.now() + + def add_result(self, result: TestResult): + """添加测试结果""" + self.results.append(result) + + def export_results(self, path: str = "test_results.json"): + """导出测试结果到 JSON 文件 + + Args: + path: 输出文件路径 + """ + results_data = { + "start_time": self.start_time.isoformat(), + "end_time": datetime.now().isoformat(), + "results": [asdict(r) for r in self.results], + "summary": { + "total": len(self.results), + "passed": sum(1 for r in self.results if r.success), + "failed": sum(1 for r in self.results if not r.success), + "total_duration": sum(r.duration for r in self.results) + } + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(results_data, f, ensure_ascii=False, indent=2) + print(f"\n测试结果已保存到: {path}") + + def print_summary(self): + """打印测试统计摘要""" + total = len(self.results) + passed = sum(1 for r in self.results if r.success) + failed = total - passed + duration = sum(r.duration for r in self.results) + + print("\n=== 测试结果摘要 ===") + print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"总用时: {duration:.2f}秒") + print(f"总计: {total} 个测试") + print(f"通过: {passed} 个") + print(f"失败: {failed} 个") + + if failed > 0: + print("\n失败的测试:") + for result in self.results: + if not result.success: + print(f"- {result.name}: {result.error}") + +# 默认配置 +DEFAULT_CONFIG = { + "server": { + "host": "localhost", + "port": 9621, "model": "lightrag:latest", + "timeout": 30, # 请求超时时间(秒) + "max_retries": 3, # 最大重试次数 + "retry_delay": 1 # 重试间隔(秒) + }, + "test_cases": { + "basic": { + "query": "孙悟空", + "stream_query": "孙悟空有什么法力,性格特征是什么" + } + } +} + +def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response: + """发送 HTTP 请求,支持重试机制 + + Args: + url: 请求 URL + data: 请求数据 + stream: 是否使用流式响应 + + Returns: + requests.Response 对象 + + Raises: + requests.exceptions.RequestException: 请求失败且重试次数用完 + """ + server_config = CONFIG["server"] + max_retries = server_config["max_retries"] + retry_delay = server_config["retry_delay"] + timeout = server_config["timeout"] + + for attempt in range(max_retries): + try: + response = requests.post( + url, + json=data, + stream=stream, + timeout=timeout + ) + return response + except requests.exceptions.RequestException as e: + if attempt == max_retries - 1: # 最后一次重试 + raise + print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}") + time.sleep(retry_delay) + +def load_config() -> Dict[str, Any]: + """加载配置文件 + + 首先尝试从当前目录的 config.json 加载, + 如果不存在则使用默认配置 + + Returns: + 配置字典 + """ + config_path = Path("config.json") + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + return DEFAULT_CONFIG + +def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: + """格式化打印 JSON 响应数据 + + Args: + data: 要打印的数据字典 + title: 打印的标题 + indent: JSON 缩进空格数 + """ + if OutputControl.is_verbose(): + if title: + print(f"\n=== {title} ===") + print(json.dumps(data, ensure_ascii=False, indent=indent)) + +# 全局配置 +CONFIG = load_config() + +def get_base_url() -> str: + """返回基础 URL""" + server = CONFIG["server"] + return f"http://{server['host']}:{server['port']}/api/chat" + +def create_request_data( + content: str, + stream: bool = False, + model: str = None +) -> Dict[str, Any]: + """创建基本的请求数据 + + Args: + content: 用户消息内容 + stream: 是否使用流式响应 + model: 模型名称 + + Returns: + 包含完整请求数据的字典 + """ + return { + "model": model or CONFIG["server"]["model"], "messages": [ { "role": "user", - "content": "孙悟空" + "content": content } ], - "stream": False + "stream": stream } + +# 全局测试统计 +STATS = TestStats() + +def run_test(func: Callable, name: str) -> None: + """运行测试并记录结果 + + Args: + func: 测试函数 + name: 测试名称 + """ + start_time = time.time() + try: + func() + duration = time.time() - start_time + STATS.add_result(TestResult(name, True, duration)) + except Exception as e: + duration = time.time() - start_time + STATS.add_result(TestResult(name, False, duration, str(e))) + raise + +def test_non_stream_chat(): + """测试非流式调用 /api/chat 接口""" + url = get_base_url() + data = create_request_data( + CONFIG["test_cases"]["basic"]["query"], + stream=False + ) # 发送请求 - response = requests.post(url, json=data) + response = make_request(url, data) # 打印响应 - print("\n=== 非流式调用响应 ===") + if OutputControl.is_verbose(): + print("\n=== 非流式调用响应 ===") response_json = response.json() - # 打印消息内容 - print("=== 响应内容 ===") - print(json.dumps({ + # 打印响应内容 + print_json_response({ "model": response_json["model"], "message": response_json["message"] - }, ensure_ascii=False, indent=2)) + }, "响应内容") # 打印性能统计 - print("\n=== 性能统计 ===") - stats = { + print_json_response({ "total_duration": response_json["total_duration"], "load_duration": response_json["load_duration"], "prompt_eval_count": response_json["prompt_eval_count"], "prompt_eval_duration": response_json["prompt_eval_duration"], "eval_count": response_json["eval_count"], "eval_duration": response_json["eval_duration"] - } - print(json.dumps(stats, ensure_ascii=False, indent=2)) + }, "性能统计") def test_stream_chat(): - """测试流式调用 /api/chat 接口""" - url = "http://localhost:9621/api/chat" + """测试流式调用 /api/chat 接口 - # 构造请求数据 - data = { + 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。 + 响应格式: + { "model": "lightrag:latest", - "messages": [ - { - "role": "user", - "content": "孙悟空有什么法力,性格特征是什么" - } - ], - "stream": True + "created_at": "2024-01-15T00:00:00Z", + "message": { + "role": "assistant", + "content": "部分响应内容", + "images": null + }, + "done": false } - # 发送请求并获取 SSE 流 - response = requests.post(url, json=data, stream=True) - client = sseclient.SSEClient(response) + 最后一条消息会包含性能统计信息,done 为 true。 + """ + url = get_base_url() + data = create_request_data( + CONFIG["test_cases"]["basic"]["stream_query"], + stream=True + ) - print("\n=== 流式调用响应 ===") + # 发送请求并获取流式响应 + response = make_request(url, data, stream=True) + + if OutputControl.is_verbose(): + print("\n=== 流式调用响应 ===") output_buffer = [] try: - for event in client.events(): - try: - data = json.loads(event.data) - if data.get("done", True): # 如果是完成标记 - if "total_duration" in data: # 最终的性能统计消息 - print("\n=== 性能统计 ===") - print(json.dumps(data, ensure_ascii=False, indent=2)) - break - else: # 正常的内容消息 - message = data.get("message", {}) - content = message.get("content", "") - if content: # 只收集非空内容 - output_buffer.append(content) - except json.JSONDecodeError: - print("Error decoding JSON from SSE event") + for line in response.iter_lines(): + if line: # 跳过空行 + try: + # 解码并解析 JSON + data = json.loads(line.decode('utf-8')) + if data.get("done", True): # 如果是完成标记 + if "total_duration" in data: # 最终的性能统计消息 + print_json_response(data, "性能统计") + break + else: # 正常的内容消息 + message = data.get("message", {}) + content = message.get("content", "") + if content: # 只收集非空内容 + output_buffer.append(content) + print(content, end="", flush=True) # 实时打印内容 + except json.JSONDecodeError: + print("Error decoding JSON from response line") finally: response.close() # 确保关闭响应连接 - # 一次性打印所有收集到的内容 - print("".join(output_buffer)) + # 打印一个换行 + print() + +def test_query_modes(): + """测试不同的查询模式前缀 + + 支持的查询模式: + - /local: 本地检索模式,只在相关度高的文档中搜索 + - /global: 全局检索模式,在所有文档中搜索 + - /naive: 朴素模式,不使用任何优化策略 + - /hybrid: 混合模式(默认),结合多种策略 + + 每个模式都会返回相同格式的响应,但检索策略不同。 + """ + url = get_base_url() + modes = ["local", "global", "naive", "hybrid"] # 支持的查询模式 + + for mode in modes: + if OutputControl.is_verbose(): + print(f"\n=== 测试 /{mode} 模式 ===") + data = create_request_data( + f"/{mode} 孙悟空的特点", + stream=False + ) + + # 发送请求 + response = make_request(url, data) + response_json = response.json() + + # 打印响应内容 + print_json_response({ + "model": response_json["model"], + "message": response_json["message"] + }) + +def create_error_test_data(error_type: str) -> Dict[str, Any]: + """创建用于错误测试的请求数据 + + Args: + error_type: 错误类型,支持: + - empty_messages: 空消息列表 + - invalid_role: 无效的角色字段 + - missing_content: 缺少内容字段 + + Returns: + 包含错误数据的请求字典 + """ + error_data = { + "empty_messages": { + "model": "lightrag:latest", + "messages": [], + "stream": True + }, + "invalid_role": { + "model": "lightrag:latest", + "messages": [ + { + "invalid_role": "user", + "content": "测试消息" + } + ], + "stream": True + }, + "missing_content": { + "model": "lightrag:latest", + "messages": [ + { + "role": "user" + } + ], + "stream": True + } + } + return error_data.get(error_type, error_data["empty_messages"]) + +def test_stream_error_handling(): + """测试流式响应的错误处理 + + 测试场景: + 1. 空消息列表 + 2. 消息格式错误(缺少必需字段) + + 错误响应会立即返回,不会建立流式连接。 + 状态码应该是 4xx,并返回详细的错误信息。 + """ + url = get_base_url() + + if OutputControl.is_verbose(): + print("\n=== 测试流式响应错误处理 ===") + + # 测试空消息列表 + if OutputControl.is_verbose(): + print("\n--- 测试空消息列表(流式)---") + data = create_error_test_data("empty_messages") + response = make_request(url, data, stream=True) + print(f"状态码: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "错误信息") + response.close() + + # 测试无效角色字段 + if OutputControl.is_verbose(): + print("\n--- 测试无效角色字段(流式)---") + data = create_error_test_data("invalid_role") + response = make_request(url, data, stream=True) + print(f"状态码: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "错误信息") + response.close() + + # 测试缺少内容字段 + if OutputControl.is_verbose(): + print("\n--- 测试缺少内容字段(流式)---") + data = create_error_test_data("missing_content") + response = make_request(url, data, stream=True) + print(f"状态码: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "错误信息") + response.close() + +def test_error_handling(): + """测试非流式响应的错误处理 + + 测试场景: + 1. 空消息列表 + 2. 消息格式错误(缺少必需字段) + + 错误响应格式: + { + "detail": "错误描述" + } + + 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。 + """ + url = get_base_url() + + if OutputControl.is_verbose(): + print("\n=== 测试错误处理 ===") + + # 测试空消息列表 + if OutputControl.is_verbose(): + print("\n--- 测试空消息列表 ---") + data = create_error_test_data("empty_messages") + data["stream"] = False # 修改为非流式模式 + response = make_request(url, data) + print(f"状态码: {response.status_code}") + print_json_response(response.json(), "错误信息") + + # 测试无效角色字段 + if OutputControl.is_verbose(): + print("\n--- 测试无效角色字段 ---") + data = create_error_test_data("invalid_role") + data["stream"] = False # 修改为非流式模式 + response = make_request(url, data) + print(f"状态码: {response.status_code}") + print_json_response(response.json(), "错误信息") + + # 测试缺少内容字段 + if OutputControl.is_verbose(): + print("\n--- 测试缺少内容字段 ---") + data = create_error_test_data("missing_content") + data["stream"] = False # 修改为非流式模式 + response = make_request(url, data) + print(f"状态码: {response.status_code}") + print_json_response(response.json(), "错误信息") + +def get_test_cases() -> Dict[str, Callable]: + """获取所有可用的测试用例 + + Returns: + 测试名称到测试函数的映射字典 + """ + return { + "non_stream": test_non_stream_chat, + "stream": test_stream_chat, + "modes": test_query_modes, + "errors": test_error_handling, + "stream_errors": test_stream_error_handling + } + +def create_default_config(): + """创建默认配置文件""" + config_path = Path("config.json") + if not config_path.exists(): + with open(config_path, "w", encoding="utf-8") as f: + json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) + print(f"已创建默认配置文件: {config_path}") + +def parse_args() -> argparse.Namespace: + """解析命令行参数""" + parser = argparse.ArgumentParser( + description="LightRAG Ollama 兼容接口测试", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +配置文件 (config.json): + { + "server": { + "host": "localhost", # 服务器地址 + "port": 9621, # 服务器端口 + "model": "lightrag:latest" # 默认模型名称 + }, + "test_cases": { + "basic": { + "query": "测试查询", # 基本查询文本 + "stream_query": "流式查询" # 流式查询文本 + } + } + } +""" + ) + parser.add_argument( + "--tests", + nargs="+", + choices=list(get_test_cases().keys()) + ["all"], + default=["all"], + help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试" + ) + parser.add_argument( + "--init-config", + action="store_true", + help="创建默认配置文件" + ) + parser.add_argument( + "--output", + type=str, + default="test_results.json", + help="测试结果输出文件路径" + ) + parser.add_argument( + "-q", "--quiet", + action="store_true", + help="静默模式,只显示测试结果摘要" + ) + return parser.parse_args() if __name__ == "__main__": - # 先测试非流式调用 - test_non_stream_chat() + args = parse_args() - # 再测试流式调用 - test_stream_chat() + # 设置输出模式 + OutputControl.set_verbose(not args.quiet) + + # 如果指定了创建配置文件 + if args.init_config: + create_default_config() + exit(0) + + test_cases = get_test_cases() + + try: + if "all" in args.tests: + # 运行所有测试 + if OutputControl.is_verbose(): + print("\n【基本功能测试】") + run_test(test_non_stream_chat, "非流式调用测试") + run_test(test_stream_chat, "流式调用测试") + + if OutputControl.is_verbose(): + print("\n【查询模式测试】") + run_test(test_query_modes, "查询模式测试") + + if OutputControl.is_verbose(): + print("\n【错误处理测试】") + run_test(test_error_handling, "错误处理测试") + run_test(test_stream_error_handling, "流式错误处理测试") + else: + # 运行指定的测试 + for test_name in args.tests: + if OutputControl.is_verbose(): + print(f"\n【运行测试: {test_name}】") + run_test(test_cases[test_name], test_name) + except Exception as e: + print(f"\n发生错误: {str(e)}") + finally: + # 打印并导出测试统计 + STATS.print_summary() + STATS.export_results(args.output)