From b94cae9990f5b7626643393c58f221bce2abc215 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 18:18:14 +0800 Subject: [PATCH 01/10] draft implementation of /api/generate endpoint --- lightrag/api/lightrag_server.py | 173 ++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 16f3ae49..d417d732 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -475,6 +475,25 @@ class OllamaChatResponse(BaseModel): message: OllamaMessage done: bool +class OllamaGenerateRequest(BaseModel): + model: str = LIGHTRAG_MODEL + prompt: str + system: Optional[str] = None + stream: bool = False + options: Optional[Dict[str, Any]] = None + +class OllamaGenerateResponse(BaseModel): + model: str + created_at: str + response: str + done: bool + context: Optional[List[int]] + total_duration: Optional[int] + load_duration: Optional[int] + prompt_eval_count: Optional[int] + prompt_eval_duration: Optional[int] + eval_count: Optional[int] + eval_duration: Optional[int] class OllamaVersionResponse(BaseModel): version: str @@ -1237,6 +1256,160 @@ def create_app(args): return query, SearchMode.hybrid + @app.post("/api/generate") + async def generate(raw_request: Request, request: OllamaGenerateRequest): + """Handle generate completion requests""" + try: + # 获取查询内容 + query = request.prompt + + # 解析查询模式 + cleaned_query, mode = parse_query_mode(query) + + # 开始计时 + start_time = time.time_ns() + + # 计算输入token数量 + prompt_tokens = estimate_tokens(cleaned_query) + + # 调用RAG进行查询 + query_param = QueryParam( + mode=mode, + stream=request.stream, + only_need_context=False + ) + + # 如果有 system prompt,更新 rag 的 llm_model_kwargs + if request.system: + rag.llm_model_kwargs["system_prompt"] = request.system + + if request.stream: + from fastapi.responses import StreamingResponse + + response = await rag.aquery( + cleaned_query, + param=query_param + ) + + async def stream_generator(): + try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + + # 处理响应 + if isinstance(response, str): + # 如果是字符串,分两部分发送 + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "response": response, + "done": False + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": 
LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return + + except Exception as e: + logging.error(f"Error in stream_generator: {str(e)}") + raise + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + response_text = await rag.aquery(cleaned_query, param=query_param) + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "response": str(response_text), + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + @app.post("/api/chat") async def chat(raw_request: Request, request: OllamaChatRequest): """Handle chat completion requests""" From 2c8885792c2bb8f0d1e645d19343d6016b34d4ae Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 18:39:43 +0800 Subject: [PATCH 02/10] =?UTF-8?q?Refactor=20/api/generate=EF=BC=9Ause=20ll?= =?UTF-8?q?m=5Fmodel=5Ffunc=20directly?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_server.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d417d732..36617947 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1272,23 +1272,17 @@ def create_app(args): # 计算输入token数量 prompt_tokens = estimate_tokens(cleaned_query) - # 调用RAG进行查询 - query_param = QueryParam( - mode=mode, - stream=request.stream, - only_need_context=False - ) - - # 如果有 system prompt,更新 rag 的 llm_model_kwargs + # 直接使用 llm_model_func 进行查询 if request.system: rag.llm_model_kwargs["system_prompt"] = request.system if request.stream: from fastapi.responses import StreamingResponse - response = await rag.aquery( - cleaned_query, - param=query_param + response = await rag.llm_model_func( + cleaned_query, + stream=True, + **rag.llm_model_kwargs ) async def stream_generator(): @@ -1383,7 +1377,11 @@ def create_app(args): ) else: first_chunk_time = time.time_ns() - response_text = await rag.aquery(cleaned_query, param=query_param) + response_text = await rag.llm_model_func( + cleaned_query, + stream=False, + **rag.llm_model_kwargs + ) last_chunk_time = time.time_ns() if not response_text: From c26d799bb6c3c9da7e25621e617edeb237d1730d Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 19:09:31 +0800 Subject: [PATCH 03/10] Add generate API tests and enhance chat API tests - Add non-streaming 
generate API test - Add streaming generate API test - Add generate error handling tests - Add generate performance stats test - Add generate concurrent request test --- test_lightrag_ollama_chat.py | 319 ++++++++++++++++++++++++++++++++--- 1 file changed, 294 insertions(+), 25 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 96aee692..c190e7ac 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -108,7 +108,10 @@ DEFAULT_CONFIG = { "max_retries": 3, "retry_delay": 1, }, - "test_cases": {"basic": {"query": "唐僧有几个徒弟"}}, + "test_cases": { + "basic": {"query": "唐僧有几个徒弟"}, + "generate": {"query": "电视剧西游记导演是谁"} + }, } @@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) CONFIG = load_config() -def get_base_url() -> str: - """Return the base URL""" +def get_base_url(endpoint: str = "chat") -> str: + """Return the base URL for specified endpoint + Args: + endpoint: API endpoint name (chat or generate) + Returns: + Complete URL for the endpoint + """ server = CONFIG["server"] - return f"http://{server['host']}:{server['port']}/api/chat" + return f"http://{server['host']}:{server['port']}/api/{endpoint}" -def create_request_data( +def create_chat_request_data( content: str, stream: bool = False, model: str = None ) -> Dict[str, Any]: - """Create basic request data + """Create chat request data Args: content: User message content stream: Whether to use streaming response model: Model name Returns: - Dictionary containing complete request data + Dictionary containing complete chat request data """ return { "model": model or CONFIG["server"]["model"], @@ -197,6 +205,34 @@ def create_request_data( "stream": stream, } +def create_generate_request_data( + prompt: str, + system: str = None, + stream: bool = False, + model: str = None, + options: Dict[str, Any] = None +) -> Dict[str, Any]: + """Create generate request data + Args: + prompt: Generation prompt + system: System prompt + stream: Whether to use streaming response + model: Model name + options: Additional options + Returns: + Dictionary containing complete generate request data + """ + data = { + "model": model or CONFIG["server"]["model"], + "prompt": prompt, + "stream": stream + } + if system: + data["system"] = system + if options: + data["options"] = options + return data + # Global test statistics STATS = TestStats() @@ -219,10 +255,10 @@ def run_test(func: Callable, name: str) -> None: raise -def test_non_stream_chat(): +def test_non_stream_chat() -> None: """Test non-streaming call to /api/chat endpoint""" url = get_base_url() - data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False) + data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False) # Send request response = make_request(url, data) @@ -239,7 +275,7 @@ def test_non_stream_chat(): ) -def test_stream_chat(): +def test_stream_chat() -> None: """Test streaming call to /api/chat endpoint Use JSON Lines format to process streaming responses, each line is a complete JSON object. @@ -258,7 +294,7 @@ def test_stream_chat(): The last message will contain performance statistics, with done set to true. 
""" url = get_base_url() - data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True) + data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True) # Send request and get streaming response response = make_request(url, data, stream=True) @@ -295,7 +331,7 @@ def test_stream_chat(): print() -def test_query_modes(): +def test_query_modes() -> None: """Test different query mode prefixes Supported query modes: @@ -313,7 +349,7 @@ def test_query_modes(): for mode in modes: if OutputControl.is_verbose(): print(f"\n=== Testing /{mode} mode ===") - data = create_request_data( + data = create_chat_request_data( f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) @@ -354,7 +390,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: return error_data.get(error_type, error_data["empty_messages"]) -def test_stream_error_handling(): +def test_stream_error_handling() -> None: """Test error handling for streaming responses Test scenarios: @@ -400,7 +436,7 @@ def test_stream_error_handling(): response.close() -def test_error_handling(): +def test_error_handling() -> None: """Test error handling for non-streaming responses Test scenarios: @@ -447,6 +483,228 @@ def test_error_handling(): print_json_response(response.json(), "Error message") +def test_non_stream_generate() -> None: + """Test non-streaming call to /api/generate endpoint""" + url = get_base_url("generate") + data = create_generate_request_data( + CONFIG["test_cases"]["generate"]["query"], + stream=False + ) + + # Send request + response = make_request(url, data) + + # Print response + if OutputControl.is_verbose(): + print("\n=== Non-streaming generate response ===") + response_json = response.json() + + # Print response content + print_json_response( + { + "model": response_json["model"], + "response": response_json["response"], + "done": response_json["done"] + }, + "Response content" + ) + +def test_stream_generate() -> None: + """Test streaming call to /api/generate endpoint""" + url = get_base_url("generate") + data = create_generate_request_data( + CONFIG["test_cases"]["generate"]["query"], + stream=True + ) + + # Send request and get streaming response + response = make_request(url, data, stream=True) + + if OutputControl.is_verbose(): + print("\n=== Streaming generate response ===") + output_buffer = [] + try: + for line in response.iter_lines(): + if line: # Skip empty lines + try: + # Decode and parse JSON + data = json.loads(line.decode("utf-8")) + if data.get("done", True): # If it's the completion marker + if "total_duration" in data: # Final performance statistics message + break + else: # Normal content message + content = data.get("response", "") + if content: # Only collect non-empty content + output_buffer.append(content) + print(content, end="", flush=True) # Print content in real-time + except json.JSONDecodeError: + print("Error decoding JSON from response line") + finally: + response.close() # Ensure the response connection is closed + + # Print a newline + print() + +def test_generate_with_system() -> None: + """Test generate with system prompt""" + url = get_base_url("generate") + data = create_generate_request_data( + CONFIG["test_cases"]["generate"]["query"], + system="你是一个知识渊博的助手", + stream=False + ) + + # Send request + response = make_request(url, data) + + # Print response + if OutputControl.is_verbose(): + print("\n=== Generate with system prompt response ===") + response_json = response.json() + + # Print response content + 
print_json_response( + { + "model": response_json["model"], + "response": response_json["response"], + "done": response_json["done"] + }, + "Response content" + ) + +def test_generate_error_handling() -> None: + """Test error handling for generate endpoint""" + url = get_base_url("generate") + + # Test empty prompt + if OutputControl.is_verbose(): + print("\n=== Testing empty prompt ===") + data = create_generate_request_data("", stream=False) + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + # Test invalid options + if OutputControl.is_verbose(): + print("\n=== Testing invalid options ===") + data = create_generate_request_data( + CONFIG["test_cases"]["basic"]["query"], + options={"invalid_option": "value"}, + stream=False + ) + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + # Test very long input + if OutputControl.is_verbose(): + print("\n=== Testing very long input ===") + long_text = "测试" * 10000 # Create a very long input + data = create_generate_request_data(long_text, stream=False) + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + +def test_generate_performance_stats() -> None: + """Test performance statistics in generate response""" + url = get_base_url("generate") + + # Test with different length inputs to verify token counting + inputs = [ + "你好", # Short Chinese + "Hello world", # Short English + "这是一个较长的中文输入,用来测试token数量的估算是否准确。", # Medium Chinese + "This is a longer English input that will be used to test the accuracy of token count estimation." # Medium English + ] + + for test_input in inputs: + if OutputControl.is_verbose(): + print(f"\n=== Testing performance stats with input: {test_input} ===") + data = create_generate_request_data(test_input, stream=False) + response = make_request(url, data) + response_json = response.json() + + # Verify performance statistics exist and are reasonable + stats = { + "total_duration": response_json.get("total_duration"), + "prompt_eval_count": response_json.get("prompt_eval_count"), + "prompt_eval_duration": response_json.get("prompt_eval_duration"), + "eval_count": response_json.get("eval_count"), + "eval_duration": response_json.get("eval_duration") + } + print_json_response(stats, "Performance statistics") + +def test_generate_concurrent() -> None: + """Test concurrent generate requests""" + import asyncio + import aiohttp + from contextlib import asynccontextmanager + + @asynccontextmanager + async def get_session(): + async with aiohttp.ClientSession() as session: + yield session + + async def make_request(session, prompt: str): + url = get_base_url("generate") + data = create_generate_request_data(prompt, stream=False) + try: + async with session.post(url, json=data) as response: + return await response.json() + except Exception as e: + return {"error": str(e)} + + async def run_concurrent_requests(): + prompts = [ + "第一个问题", + "第二个问题", + "第三个问题", + "第四个问题", + "第五个问题" + ] + + async with get_session() as session: + tasks = [make_request(session, prompt) for prompt in prompts] + results = await asyncio.gather(*tasks) + return results + + if OutputControl.is_verbose(): + print("\n=== Testing concurrent generate requests ===") + + # Run concurrent requests + results = asyncio.run(run_concurrent_requests()) + + # Print results + for i, result in 
enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + +def test_generate_query_modes() -> None: + """Test different query mode prefixes for generate endpoint""" + url = get_base_url("generate") + modes = ["local", "global", "naive", "hybrid", "mix"] + + for mode in modes: + if OutputControl.is_verbose(): + print(f"\n=== Testing /{mode} mode for generate ===") + data = create_generate_request_data( + f"/{mode} {CONFIG['test_cases']['generate']['query']}", + stream=False + ) + + # Send request + response = make_request(url, data) + response_json = response.json() + + # Print response content + print_json_response( + { + "model": response_json["model"], + "response": response_json["response"], + "done": response_json["done"] + } + ) + def get_test_cases() -> Dict[str, Callable]: """Get all available test cases Returns: @@ -458,6 +716,13 @@ def get_test_cases() -> Dict[str, Callable]: "modes": test_query_modes, "errors": test_error_handling, "stream_errors": test_stream_error_handling, + "non_stream_generate": test_non_stream_generate, + "stream_generate": test_stream_generate, + "generate_with_system": test_generate_with_system, + "generate_modes": test_generate_query_modes, + "generate_errors": test_generate_error_handling, + "generate_stats": test_generate_performance_stats, + "generate_concurrent": test_generate_concurrent } @@ -544,18 +809,22 @@ if __name__ == "__main__": if "all" in args.tests: # Run all tests if OutputControl.is_verbose(): - print("\n【Basic Functionality Tests】") - run_test(test_non_stream_chat, "Non-streaming Call Test") - run_test(test_stream_chat, "Streaming Call Test") + print("\n【Chat API Tests】") + run_test(test_non_stream_chat, "Non-streaming Chat Test") + run_test(test_stream_chat, "Streaming Chat Test") + run_test(test_query_modes, "Chat Query Mode Test") + run_test(test_error_handling, "Chat Error Handling Test") + run_test(test_stream_error_handling, "Chat Streaming Error Test") if OutputControl.is_verbose(): - print("\n【Query Mode Tests】") - run_test(test_query_modes, "Query Mode Test") - - if OutputControl.is_verbose(): - print("\n【Error Handling Tests】") - run_test(test_error_handling, "Error Handling Test") - run_test(test_stream_error_handling, "Streaming Error Handling Test") + print("\n【Generate API Tests】") + run_test(test_non_stream_generate, "Non-streaming Generate Test") + run_test(test_stream_generate, "Streaming Generate Test") + run_test(test_generate_with_system, "Generate with System Prompt Test") + run_test(test_generate_query_modes, "Generate Query Mode Test") + run_test(test_generate_error_handling, "Generate Error Handling Test") + run_test(test_generate_performance_stats, "Generate Performance Stats Test") + run_test(test_generate_concurrent, "Generate Concurrent Test") else: # Run specified tests for test_name in args.tests: From 385661b10e1427d42a9a618fa6940533ae1e9f4c Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 19:25:45 +0800 Subject: [PATCH 04/10] Removed query mode parsing and related tests - Removed query mode parsing logic - Removed test_generate_query_modes - Simplified generate endpoint - Updated test cases list - Cleaned up unused code --- lightrag/api/lightrag_server.py | 12 ++++-------- test_lightrag_ollama_chat.py | 28 ---------------------------- 2 files changed, 4 insertions(+), 36 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 36617947..855424a6 100644 --- a/lightrag/api/lightrag_server.py +++ 
b/lightrag/api/lightrag_server.py @@ -1260,17 +1260,13 @@ def create_app(args): async def generate(raw_request: Request, request: OllamaGenerateRequest): """Handle generate completion requests""" try: - # 获取查询内容 query = request.prompt - - # 解析查询模式 - cleaned_query, mode = parse_query_mode(query) - + # 开始计时 start_time = time.time_ns() # 计算输入token数量 - prompt_tokens = estimate_tokens(cleaned_query) + prompt_tokens = estimate_tokens(query) # 直接使用 llm_model_func 进行查询 if request.system: @@ -1280,7 +1276,7 @@ def create_app(args): from fastapi.responses import StreamingResponse response = await rag.llm_model_func( - cleaned_query, + query, stream=True, **rag.llm_model_kwargs ) @@ -1378,7 +1374,7 @@ def create_app(args): else: first_chunk_time = time.time_ns() response_text = await rag.llm_model_func( - cleaned_query, + query, stream=False, **rag.llm_model_kwargs ) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index c190e7ac..3d3034b9 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -679,32 +679,6 @@ def test_generate_concurrent() -> None: print(f"\nRequest {i} result:") print_json_response(result) -def test_generate_query_modes() -> None: - """Test different query mode prefixes for generate endpoint""" - url = get_base_url("generate") - modes = ["local", "global", "naive", "hybrid", "mix"] - - for mode in modes: - if OutputControl.is_verbose(): - print(f"\n=== Testing /{mode} mode for generate ===") - data = create_generate_request_data( - f"/{mode} {CONFIG['test_cases']['generate']['query']}", - stream=False - ) - - # Send request - response = make_request(url, data) - response_json = response.json() - - # Print response content - print_json_response( - { - "model": response_json["model"], - "response": response_json["response"], - "done": response_json["done"] - } - ) - def get_test_cases() -> Dict[str, Callable]: """Get all available test cases Returns: @@ -719,7 +693,6 @@ def get_test_cases() -> Dict[str, Callable]: "non_stream_generate": test_non_stream_generate, "stream_generate": test_stream_generate, "generate_with_system": test_generate_with_system, - "generate_modes": test_generate_query_modes, "generate_errors": test_generate_error_handling, "generate_stats": test_generate_performance_stats, "generate_concurrent": test_generate_concurrent @@ -821,7 +794,6 @@ if __name__ == "__main__": run_test(test_non_stream_generate, "Non-streaming Generate Test") run_test(test_stream_generate, "Streaming Generate Test") run_test(test_generate_with_system, "Generate with System Prompt Test") - run_test(test_generate_query_modes, "Generate Query Mode Test") run_test(test_generate_error_handling, "Generate Error Handling Test") run_test(test_generate_performance_stats, "Generate Performance Stats Test") run_test(test_generate_concurrent, "Generate Concurrent Test") From 5967dd8da762933d839b4921e5f5c87f64c26013 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 19:30:14 +0800 Subject: [PATCH 05/10] Remove long input test case from error handling - Deleted very long input test - Simplified error handling test - Improved code readability - Reduced unnecessary test cases - Streamlined test function --- test_lightrag_ollama_chat.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 3d3034b9..7be23131 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -595,15 +595,7 @@ def test_generate_error_handling() -> None: response = 
make_request(url, data) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") - - # Test very long input - if OutputControl.is_verbose(): - print("\n=== Testing very long input ===") - long_text = "测试" * 10000 # Create a very long input - data = create_generate_request_data(long_text, stream=False) - response = make_request(url, data) - print(f"Status code: {response.status_code}") - print_json_response(response.json(), "Error message") + def test_generate_performance_stats() -> None: """Test performance statistics in generate response""" From 032af348b0699651ac8028ae1b7676de54d99010 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 19:38:35 +0800 Subject: [PATCH 06/10] Remove performance stats test from generate endpoint - Deleted performance stats test function - Updated test cases dictionary - Removed test from main execution block - Simplified test suite - Focused on core functionality --- test_lightrag_ollama_chat.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 7be23131..21014735 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -597,35 +597,6 @@ def test_generate_error_handling() -> None: print_json_response(response.json(), "Error message") -def test_generate_performance_stats() -> None: - """Test performance statistics in generate response""" - url = get_base_url("generate") - - # Test with different length inputs to verify token counting - inputs = [ - "你好", # Short Chinese - "Hello world", # Short English - "这是一个较长的中文输入,用来测试token数量的估算是否准确。", # Medium Chinese - "This is a longer English input that will be used to test the accuracy of token count estimation." 
# Medium English - ] - - for test_input in inputs: - if OutputControl.is_verbose(): - print(f"\n=== Testing performance stats with input: {test_input} ===") - data = create_generate_request_data(test_input, stream=False) - response = make_request(url, data) - response_json = response.json() - - # Verify performance statistics exist and are reasonable - stats = { - "total_duration": response_json.get("total_duration"), - "prompt_eval_count": response_json.get("prompt_eval_count"), - "prompt_eval_duration": response_json.get("prompt_eval_duration"), - "eval_count": response_json.get("eval_count"), - "eval_duration": response_json.get("eval_duration") - } - print_json_response(stats, "Performance statistics") - def test_generate_concurrent() -> None: """Test concurrent generate requests""" import asyncio @@ -686,7 +657,6 @@ def get_test_cases() -> Dict[str, Callable]: "stream_generate": test_stream_generate, "generate_with_system": test_generate_with_system, "generate_errors": test_generate_error_handling, - "generate_stats": test_generate_performance_stats, "generate_concurrent": test_generate_concurrent } @@ -787,7 +757,6 @@ if __name__ == "__main__": run_test(test_stream_generate, "Streaming Generate Test") run_test(test_generate_with_system, "Generate with System Prompt Test") run_test(test_generate_error_handling, "Generate Error Handling Test") - run_test(test_generate_performance_stats, "Generate Performance Stats Test") run_test(test_generate_concurrent, "Generate Concurrent Test") else: # Run specified tests From 24a460f84b5e6ee47dcc63bc5f0d8f7db8f57a1b Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 22:24:56 +0800 Subject: [PATCH 07/10] Update README with Open WebUI task model config --- lightrag/api/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 00a8db18..89906006 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -94,6 +94,7 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode q After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. +To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model". 
## Configuration From 930b6cf00611c65353869ed77cfc1873977817dd Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 23:33:12 +0800 Subject: [PATCH 08/10] Directly pass the session title or keyword generation request from Open WebUI to the underlying LLM --- lightrag/api/lightrag_server.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 855424a6..93667e84 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -467,6 +467,7 @@ class OllamaChatRequest(BaseModel): messages: List[OllamaMessage] stream: bool = True # Default to streaming mode options: Optional[Dict[str, Any]] = None + system: Optional[str] = None class OllamaChatResponse(BaseModel): @@ -1536,7 +1537,25 @@ def create_app(args): ) else: first_chunk_time = time.time_ns() - response_text = await rag.aquery(cleaned_query, param=query_param) + + # 判断是否包含特定字符串,使用正则表达式进行匹配 + logging.info(f"Cleaned query content: {cleaned_query}") + match_result = re.search(r'\\n\\nUSER:', cleaned_query) + logging.info(f"Regex match result: {bool(match_result)}") + + if match_result: + + if request.system: + rag.llm_model_kwargs["system_prompt"] = request.system + + response_text = await rag.llm_model_func( + cleaned_query, + stream=False, + **rag.llm_model_kwargs + ) + else: + response_text = await rag.aquery(cleaned_query, param=query_param) + last_chunk_time = time.time_ns() if not response_text: From 11873625a37dc58937ba03214080683693fa1cef Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 23:39:33 +0800 Subject: [PATCH 09/10] Fix regex pattern for chat history detection - Updated regex to match newline characters - Added re.MULTILINE flag for multiline matching - Improved logging for regex match results - Enhanced query content cleaning process - Ensured consistent chat history detection --- lightrag/api/lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 93667e84..ab9fe732 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1540,7 +1540,7 @@ def create_app(args): # 判断是否包含特定字符串,使用正则表达式进行匹配 logging.info(f"Cleaned query content: {cleaned_query}") - match_result = re.search(r'\\n\\nUSER:', cleaned_query) + match_result = re.search(r'\n\nUSER:', cleaned_query, re.MULTILINE) logging.info(f"Regex match result: {bool(match_result)}") if match_result: From f30a69e2012db9db7d5d15c921e55e7fd4989e0b Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 24 Jan 2025 23:50:47 +0800 Subject: [PATCH 10/10] Fix linting, remove redundant comments and clean up code for better readability --- lightrag/api/lightrag_server.py | 90 ++++++++++++++------------- test_lightrag_ollama_chat.py | 73 +++++++++++++------------- 2 files changed, 76 insertions(+), 87 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ab9fe732..24da3009 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel): message: OllamaMessage done: bool + class OllamaGenerateRequest(BaseModel): model: str = LIGHTRAG_MODEL prompt: str @@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel): stream: bool = False options: Optional[Dict[str, Any]] = None + class OllamaGenerateResponse(BaseModel): model: str created_at: str @@ -490,12 +492,13 @@ class OllamaGenerateResponse(BaseModel): 
done: bool context: Optional[List[int]] total_duration: Optional[int] - load_duration: Optional[int] + load_duration: Optional[int] prompt_eval_count: Optional[int] prompt_eval_duration: Optional[int] eval_count: Optional[int] eval_duration: Optional[int] + class OllamaVersionResponse(BaseModel): version: str @@ -1262,52 +1265,45 @@ def create_app(args): """Handle generate completion requests""" try: query = request.prompt - - # 开始计时 start_time = time.time_ns() - - # 计算输入token数量 prompt_tokens = estimate_tokens(query) - - # 直接使用 llm_model_func 进行查询 + if request.system: rag.llm_model_kwargs["system_prompt"] = request.system - + if request.stream: from fastapi.responses import StreamingResponse - + response = await rag.llm_model_func( - query, - stream=True, - **rag.llm_model_kwargs + query, stream=True, **rag.llm_model_kwargs ) - + async def stream_generator(): try: first_chunk_time = None last_chunk_time = None total_response = "" - - # 处理响应 + + # Ensure response is an async generator if isinstance(response, str): - # 如果是字符串,分两部分发送 + # If it's a string, send in two parts first_chunk_time = time.time_ns() last_chunk_time = first_chunk_time total_response = response - + data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "response": response, - "done": False + "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - + completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time - + data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, @@ -1317,7 +1313,7 @@ def create_app(args): "prompt_eval_count": prompt_tokens, "prompt_eval_duration": prompt_eval_time, "eval_count": completion_tokens, - "eval_duration": eval_time + "eval_duration": eval_time, } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: @@ -1325,23 +1321,23 @@ def create_app(args): if chunk: if first_chunk_time is None: first_chunk_time = time.time_ns() - + last_chunk_time = time.time_ns() - + total_response += chunk data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "response": chunk, - "done": False + "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - + completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time - + data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, @@ -1351,15 +1347,15 @@ def create_app(args): "prompt_eval_count": prompt_tokens, "prompt_eval_duration": prompt_eval_time, "eval_count": completion_tokens, - "eval_duration": eval_time + "eval_duration": eval_time, } yield f"{json.dumps(data, ensure_ascii=False)}\n" return - + except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise - + return StreamingResponse( stream_generator(), media_type="application/x-ndjson", @@ -1375,20 +1371,18 @@ def create_app(args): else: first_chunk_time = time.time_ns() response_text = await rag.llm_model_func( - query, - stream=False, - **rag.llm_model_kwargs + query, stream=False, **rag.llm_model_kwargs ) last_chunk_time = time.time_ns() - + if not response_text: response_text = "No response generated" - + completion_tokens = estimate_tokens(str(response_text)) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time - + return { "model": LIGHTRAG_MODEL, "created_at": 
LIGHTRAG_CREATED_AT, @@ -1399,7 +1393,7 @@ def create_app(args): "prompt_eval_count": prompt_tokens, "prompt_eval_duration": prompt_eval_time, "eval_count": completion_tokens, - "eval_duration": eval_time + "eval_duration": eval_time, } except Exception as e: trace_exception(e) @@ -1417,16 +1411,12 @@ def create_app(args): # Get the last message as query query = messages[-1].content - # 解析查询模式 + # Check for query prefix cleaned_query, mode = parse_query_mode(query) - # 开始计时 start_time = time.time_ns() - - # 计算输入token数量 prompt_tokens = estimate_tokens(cleaned_query) - # 调用RAG进行查询 query_param = QueryParam( mode=mode, stream=request.stream, only_need_context=False ) @@ -1537,25 +1527,21 @@ def create_app(args): ) else: first_chunk_time = time.time_ns() - - # 判断是否包含特定字符串,使用正则表达式进行匹配 - logging.info(f"Cleaned query content: {cleaned_query}") - match_result = re.search(r'\n\nUSER:', cleaned_query, re.MULTILINE) - logging.info(f"Regex match result: {bool(match_result)}") - - if match_result: + # Determine if the request is from Open WebUI's session title and session keyword generation task + match_result = re.search( + r"\n\nUSER:", cleaned_query, re.MULTILINE + ) + if match_result: if request.system: rag.llm_model_kwargs["system_prompt"] = request.system response_text = await rag.llm_model_func( - cleaned_query, - stream=False, - **rag.llm_model_kwargs + cleaned_query, stream=False, **rag.llm_model_kwargs ) else: response_text = await rag.aquery(cleaned_query, param=query_param) - + last_chunk_time = time.time_ns() if not response_text: diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 21014735..d1e61d39 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -110,7 +110,7 @@ DEFAULT_CONFIG = { }, "test_cases": { "basic": {"query": "唐僧有几个徒弟"}, - "generate": {"query": "电视剧西游记导演是谁"} + "generate": {"query": "电视剧西游记导演是谁"}, }, } @@ -205,12 +205,13 @@ def create_chat_request_data( "stream": stream, } + def create_generate_request_data( - prompt: str, + prompt: str, system: str = None, - stream: bool = False, + stream: bool = False, model: str = None, - options: Dict[str, Any] = None + options: Dict[str, Any] = None, ) -> Dict[str, Any]: """Create generate request data Args: @@ -225,7 +226,7 @@ def create_generate_request_data( data = { "model": model or CONFIG["server"]["model"], "prompt": prompt, - "stream": stream + "stream": stream, } if system: data["system"] = system @@ -258,7 +259,9 @@ def run_test(func: Callable, name: str) -> None: def test_non_stream_chat() -> None: """Test non-streaming call to /api/chat endpoint""" url = get_base_url() - data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False) + data = create_chat_request_data( + CONFIG["test_cases"]["basic"]["query"], stream=False + ) # Send request response = make_request(url, data) @@ -487,8 +490,7 @@ def test_non_stream_generate() -> None: """Test non-streaming call to /api/generate endpoint""" url = get_base_url("generate") data = create_generate_request_data( - CONFIG["test_cases"]["generate"]["query"], - stream=False + CONFIG["test_cases"]["generate"]["query"], stream=False ) # Send request @@ -504,17 +506,17 @@ def test_non_stream_generate() -> None: { "model": response_json["model"], "response": response_json["response"], - "done": response_json["done"] + "done": response_json["done"], }, - "Response content" + "Response content", ) + def test_stream_generate() -> None: """Test streaming call to /api/generate endpoint""" url = get_base_url("generate") 
data = create_generate_request_data( - CONFIG["test_cases"]["generate"]["query"], - stream=True + CONFIG["test_cases"]["generate"]["query"], stream=True ) # Send request and get streaming response @@ -530,13 +532,17 @@ def test_stream_generate() -> None: # Decode and parse JSON data = json.loads(line.decode("utf-8")) if data.get("done", True): # If it's the completion marker - if "total_duration" in data: # Final performance statistics message + if ( + "total_duration" in data + ): # Final performance statistics message break else: # Normal content message content = data.get("response", "") if content: # Only collect non-empty content output_buffer.append(content) - print(content, end="", flush=True) # Print content in real-time + print( + content, end="", flush=True + ) # Print content in real-time except json.JSONDecodeError: print("Error decoding JSON from response line") finally: @@ -545,13 +551,14 @@ def test_stream_generate() -> None: # Print a newline print() + def test_generate_with_system() -> None: """Test generate with system prompt""" url = get_base_url("generate") data = create_generate_request_data( CONFIG["test_cases"]["generate"]["query"], system="你是一个知识渊博的助手", - stream=False + stream=False, ) # Send request @@ -567,15 +574,16 @@ def test_generate_with_system() -> None: { "model": response_json["model"], "response": response_json["response"], - "done": response_json["done"] + "done": response_json["done"], }, - "Response content" + "Response content", ) + def test_generate_error_handling() -> None: """Test error handling for generate endpoint""" url = get_base_url("generate") - + # Test empty prompt if OutputControl.is_verbose(): print("\n=== Testing empty prompt ===") @@ -583,14 +591,14 @@ def test_generate_error_handling() -> None: response = make_request(url, data) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") - + # Test invalid options if OutputControl.is_verbose(): print("\n=== Testing invalid options ===") data = create_generate_request_data( CONFIG["test_cases"]["basic"]["query"], options={"invalid_option": "value"}, - stream=False + stream=False, ) response = make_request(url, data) print(f"Status code: {response.status_code}") @@ -602,12 +610,12 @@ def test_generate_concurrent() -> None: import asyncio import aiohttp from contextlib import asynccontextmanager - + @asynccontextmanager async def get_session(): async with aiohttp.ClientSession() as session: yield session - + async def make_request(session, prompt: str): url = get_base_url("generate") data = create_generate_request_data(prompt, stream=False) @@ -616,32 +624,27 @@ def test_generate_concurrent() -> None: return await response.json() except Exception as e: return {"error": str(e)} - + async def run_concurrent_requests(): - prompts = [ - "第一个问题", - "第二个问题", - "第三个问题", - "第四个问题", - "第五个问题" - ] - + prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] + async with get_session() as session: tasks = [make_request(session, prompt) for prompt in prompts] results = await asyncio.gather(*tasks) return results - + if OutputControl.is_verbose(): print("\n=== Testing concurrent generate requests ===") - + # Run concurrent requests results = asyncio.run(run_concurrent_requests()) - + # Print results for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) + def get_test_cases() -> Dict[str, Callable]: """Get all available test cases Returns: @@ -657,7 +660,7 @@ def get_test_cases() -> Dict[str, Callable]: 
"stream_generate": test_stream_generate, "generate_with_system": test_generate_with_system, "generate_errors": test_generate_error_handling, - "generate_concurrent": test_generate_concurrent + "generate_concurrent": test_generate_concurrent, }