diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 82329806..bb9b1ac5 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -27,15 +27,15 @@ from dotenv import load_dotenv load_dotenv() def estimate_tokens(text: str) -> int: - """估算文本的token数量 - 中文每字约1.5个token - 英文每字约0.25个token + """Estimate the number of tokens in text + Chinese characters: approximately 1.5 tokens per character + English characters: approximately 0.25 tokens per character """ - # 使用正则表达式分别匹配中文字符和非中文字符 + # Use regex to match Chinese and non-Chinese characters separately chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text)) - # 计算估算的token数量 + # Calculate estimated token count tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 return int(tokens) @@ -241,7 +241,7 @@ class DocumentManager: class SearchMode(str, Enum): naive = "naive" local = "local" - global_ = "global" # 使用 global_ 因为 global 是 Python 保留关键字,但枚举值会转换为字符串 "global" + global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global" hybrid = "hybrid" mix = "mix" @@ -254,7 +254,7 @@ class OllamaMessage(BaseModel): class OllamaChatRequest(BaseModel): model: str = LIGHTRAG_MODEL messages: List[OllamaMessage] - stream: bool = True # 默认为流式模式 + stream: bool = True # Default to streaming mode options: Optional[Dict[str, Any]] = None class OllamaChatResponse(BaseModel): @@ -490,11 +490,11 @@ def create_app(args): ), ) - # 如果响应是字符串(比如命中缓存),直接返回 + # If response is a string (e.g. cache hit), return directly if isinstance(response, str): return QueryResponse(response=response) - # 如果是异步生成器,根据stream参数决定是否流式返回 + # If it's an async generator, decide whether to stream based on stream parameter if request.stream: result = "" async for chunk in response: @@ -511,7 +511,7 @@ def create_app(args): @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: - response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await + response = await rag.aquery( # Use aquery instead of query, and add await request.query, param=QueryParam( mode=request.mode, @@ -691,7 +691,7 @@ def create_app(args): for prefix, mode in mode_map.items(): if query.startswith(prefix): - # 移除前缀后,清理开头的额外空格 + # After removing prefix an leading spaces cleaned_query = query[len(prefix):].lstrip() return cleaned_query, mode @@ -699,17 +699,14 @@ def create_app(args): @app.post("/api/chat") async def chat(raw_request: Request, request: OllamaChatRequest): - # # 打印原始请求数据 - # body = await raw_request.body() - # logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}") """Handle chat completion requests""" try: - # 获取所有消息内容 + # Get all messages messages = request.messages if not messages: raise HTTPException(status_code=400, detail="No messages provided") - # 获取最后一条消息作为查询 + # Get the last message as query query = messages[-1].content # 解析查询模式 @@ -723,7 +720,7 @@ def create_app(args): # 调用RAG进行查询 query_param = QueryParam( - mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid + mode=mode, stream=request.stream, only_need_context=False ) @@ -731,7 +728,7 @@ def create_app(args): if request.stream: from fastapi.responses import StreamingResponse - response = await rag.aquery( # 需要 await 来获取异步生成器 + response = await rag.aquery( # Need await to get async generator cleaned_query, param=query_param ) @@ -742,9 +739,9 @@ def create_app(args): last_chunk_time = None total_response = "" - # 确保 response 是异步生成器 + # Ensure response is an async generator if isinstance(response, str): - # 如果是字符串,分两次发送 + # If it's a string, send in two parts first_chunk_time = time.time_ns() last_chunk_time = first_chunk_time total_response = response diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index aab059aa..44b4fa42 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -1,12 +1,12 @@ """ -LightRAG Ollama 兼容接口测试脚本 +LightRAG Ollama Compatibility Interface Test Script -这个脚本测试 LightRAG 的 Ollama 兼容接口,包括: -1. 基本功能测试(流式和非流式响应) -2. 查询模式测试(local、global、naive、hybrid) -3. 错误处理测试(包括流式和非流式场景) +This script tests the LightRAG's Ollama compatibility interface, including: +1. Basic functionality tests (streaming and non-streaming responses) +2. Query mode tests (local, global, naive, hybrid) +3. Error handling tests (including streaming and non-streaming scenarios) -所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。 +All responses use the JSON Lines format, complying with the Ollama API specification. """ import requests @@ -24,20 +24,10 @@ class OutputControl: @classmethod def set_verbose(cls, verbose: bool) -> None: - """设置输出详细程度 - - Args: - verbose: True 为详细模式,False 为静默模式 - """ cls._verbose = verbose @classmethod def is_verbose(cls) -> bool: - """获取当前输出模式 - - Returns: - 当前是否为详细模式 - """ return cls._verbose @dataclass @@ -48,9 +38,8 @@ class TestResult: duration: float error: Optional[str] = None timestamp: str = "" - + def __post_init__(self): - """初始化后设置时间戳""" if not self.timestamp: self.timestamp = datetime.now().isoformat() @@ -59,14 +48,13 @@ class TestStats: def __init__(self): self.results: List[TestResult] = [] self.start_time = datetime.now() - + def add_result(self, result: TestResult): - """添加测试结果""" self.results.append(result) - + def export_results(self, path: str = "test_results.json"): """导出测试结果到 JSON 文件 - + Args: path: 输出文件路径 """ @@ -81,25 +69,24 @@ class TestStats: "total_duration": sum(r.duration for r in self.results) } } - + with open(path, "w", encoding="utf-8") as f: json.dump(results_data, f, ensure_ascii=False, indent=2) print(f"\n测试结果已保存到: {path}") - + def print_summary(self): - """打印测试统计摘要""" total = len(self.results) passed = sum(1 for r in self.results if r.success) failed = total - passed duration = sum(r.duration for r in self.results) - + print("\n=== 测试结果摘要 ===") print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") print(f"总用时: {duration:.2f}秒") print(f"总计: {total} 个测试") print(f"通过: {passed} 个") print(f"失败: {failed} 个") - + if failed > 0: print("\n失败的测试:") for result in self.results: @@ -125,15 +112,15 @@ DEFAULT_CONFIG = { def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response: """发送 HTTP 请求,支持重试机制 - + Args: url: 请求 URL data: 请求数据 stream: 是否使用流式响应 - + Returns: - requests.Response 对象 - + requests.Response: 对象 + Raises: requests.exceptions.RequestException: 请求失败且重试次数用完 """ @@ -141,7 +128,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques max_retries = server_config["max_retries"] retry_delay = server_config["retry_delay"] timeout = server_config["timeout"] - + for attempt in range(max_retries): try: response = requests.post( @@ -159,10 +146,10 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques def load_config() -> Dict[str, Any]: """加载配置文件 - + 首先尝试从当前目录的 config.json 加载, 如果不存在则使用默认配置 - + Returns: 配置字典 """ @@ -174,7 +161,7 @@ def load_config() -> Dict[str, Any]: def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: """格式化打印 JSON 响应数据 - + Args: data: 要打印的数据字典 title: 打印的标题 @@ -199,12 +186,12 @@ def create_request_data( model: str = None ) -> Dict[str, Any]: """创建基本的请求数据 - + Args: content: 用户消息内容 stream: 是否使用流式响应 model: 模型名称 - + Returns: 包含完整请求数据的字典 """ @@ -224,7 +211,7 @@ STATS = TestStats() def run_test(func: Callable, name: str) -> None: """运行测试并记录结果 - + Args: func: 测试函数 name: 测试名称 @@ -246,21 +233,21 @@ def test_non_stream_chat(): CONFIG["test_cases"]["basic"]["query"], stream=False ) - + # 发送请求 response = make_request(url, data) - + # 打印响应 if OutputControl.is_verbose(): print("\n=== 非流式调用响应 ===") response_json = response.json() - + # 打印响应内容 print_json_response({ "model": response_json["model"], "message": response_json["message"] }, "响应内容") - + # # 打印性能统计 # print_json_response({ # "total_duration": response_json["total_duration"], @@ -273,7 +260,7 @@ def test_non_stream_chat(): def test_stream_chat(): """测试流式调用 /api/chat 接口 - + 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。 响应格式: { @@ -286,7 +273,7 @@ def test_stream_chat(): }, "done": false } - + 最后一条消息会包含性能统计信息,done 为 true。 """ url = get_base_url() @@ -294,10 +281,10 @@ def test_stream_chat(): CONFIG["test_cases"]["basic"]["query"], stream=True ) - + # 发送请求并获取流式响应 response = make_request(url, data, stream=True) - + if OutputControl.is_verbose(): print("\n=== 流式调用响应 ===") output_buffer = [] @@ -321,24 +308,24 @@ def test_stream_chat(): print("Error decoding JSON from response line") finally: response.close() # 确保关闭响应连接 - + # 打印一个换行 print() def test_query_modes(): """测试不同的查询模式前缀 - + 支持的查询模式: - /local: 本地检索模式,只在相关度高的文档中搜索 - /global: 全局检索模式,在所有文档中搜索 - /naive: 朴素模式,不使用任何优化策略 - /hybrid: 混合模式(默认),结合多种策略 - + 每个模式都会返回相同格式的响应,但检索策略不同。 """ url = get_base_url() modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式 - + for mode in modes: if OutputControl.is_verbose(): print(f"\n=== 测试 /{mode} 模式 ===") @@ -346,11 +333,11 @@ def test_query_modes(): f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) - + # 发送请求 response = make_request(url, data) response_json = response.json() - + # 打印响应内容 print_json_response({ "model": response_json["model"], @@ -359,13 +346,13 @@ def test_query_modes(): def create_error_test_data(error_type: str) -> Dict[str, Any]: """创建用于错误测试的请求数据 - + Args: error_type: 错误类型,支持: - empty_messages: 空消息列表 - invalid_role: 无效的角色字段 - missing_content: 缺少内容字段 - + Returns: 包含错误数据的请求字典 """ @@ -399,19 +386,19 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: def test_stream_error_handling(): """测试流式响应的错误处理 - + 测试场景: 1. 空消息列表 2. 消息格式错误(缺少必需字段) - + 错误响应会立即返回,不会建立流式连接。 状态码应该是 4xx,并返回详细的错误信息。 """ url = get_base_url() - + if OutputControl.is_verbose(): print("\n=== 测试流式响应错误处理 ===") - + # 测试空消息列表 if OutputControl.is_verbose(): print("\n--- 测试空消息列表(流式)---") @@ -421,7 +408,7 @@ def test_stream_error_handling(): if response.status_code != 200: print_json_response(response.json(), "错误信息") response.close() - + # 测试无效角色字段 if OutputControl.is_verbose(): print("\n--- 测试无效角色字段(流式)---") @@ -444,23 +431,23 @@ def test_stream_error_handling(): def test_error_handling(): """测试非流式响应的错误处理 - + 测试场景: 1. 空消息列表 2. 消息格式错误(缺少必需字段) - + 错误响应格式: { "detail": "错误描述" } - + 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。 """ url = get_base_url() - + if OutputControl.is_verbose(): print("\n=== 测试错误处理 ===") - + # 测试空消息列表 if OutputControl.is_verbose(): print("\n--- 测试空消息列表 ---") @@ -469,7 +456,7 @@ def test_error_handling(): response = make_request(url, data) print(f"状态码: {response.status_code}") print_json_response(response.json(), "错误信息") - + # 测试无效角色字段 if OutputControl.is_verbose(): print("\n--- 测试无效角色字段 ---") @@ -490,7 +477,7 @@ def test_error_handling(): def get_test_cases() -> Dict[str, Callable]: """获取所有可用的测试用例 - + Returns: 测试名称到测试函数的映射字典 """ @@ -564,21 +551,21 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() - + # 设置输出模式 OutputControl.set_verbose(not args.quiet) - + # 如果指定了查询内容,更新配置 if args.ask: CONFIG["test_cases"]["basic"]["query"] = args.ask - + # 如果指定了创建配置文件 if args.init_config: create_default_config() exit(0) - + test_cases = get_test_cases() - + try: if "all" in args.tests: # 运行所有测试 @@ -586,11 +573,11 @@ if __name__ == "__main__": print("\n【基本功能测试】") run_test(test_non_stream_chat, "非流式调用测试") run_test(test_stream_chat, "流式调用测试") - + if OutputControl.is_verbose(): print("\n【查询模式测试】") run_test(test_query_modes, "查询模式测试") - + if OutputControl.is_verbose(): print("\n【错误处理测试】") run_test(test_error_handling, "错误处理测试")