From f30a69e2012db9db7d5d15c921e55e7fd4989e0b Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 24 Jan 2025 23:50:47 +0800
Subject: [PATCH] Fix linting, remove redundant comments and clean up code for
 better readability

---
 lightrag/api/lightrag_server.py | 90 ++++++++++++++-------------------
 test_lightrag_ollama_chat.py    | 73 +++++++++++++-------------
 2 files changed, 76 insertions(+), 87 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index ab9fe732..24da3009 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel):
     message: OllamaMessage
     done: bool
 
+
 class OllamaGenerateRequest(BaseModel):
     model: str = LIGHTRAG_MODEL
     prompt: str
@@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel):
     stream: bool = False
     options: Optional[Dict[str, Any]] = None
 
+
 class OllamaGenerateResponse(BaseModel):
     model: str
     created_at: str
@@ -490,12 +492,13 @@ class OllamaGenerateResponse(BaseModel):
     done: bool
     context: Optional[List[int]]
     total_duration: Optional[int]
-    load_duration: Optional[int]
+    load_duration: Optional[int]
     prompt_eval_count: Optional[int]
    prompt_eval_duration: Optional[int]
     eval_count: Optional[int]
     eval_duration: Optional[int]
 
+
 class OllamaVersionResponse(BaseModel):
     version: str
 
@@ -1262,52 +1265,45 @@ def create_app(args):
         """Handle generate completion requests"""
         try:
             query = request.prompt
-
-            # 开始计时
             start_time = time.time_ns()
-
-            # 计算输入token数量
             prompt_tokens = estimate_tokens(query)
-
-            # 直接使用 llm_model_func 进行查询
+
             if request.system:
                 rag.llm_model_kwargs["system_prompt"] = request.system
-
+
             if request.stream:
                 from fastapi.responses import StreamingResponse
-
+
                 response = await rag.llm_model_func(
-                    query,
-                    stream=True,
-                    **rag.llm_model_kwargs
+                    query, stream=True, **rag.llm_model_kwargs
                 )
-
+
                 async def stream_generator():
                     try:
                         first_chunk_time = None
                         last_chunk_time = None
                         total_response = ""
-
-                        # 处理响应
+
+                        # Ensure response is an async generator
                        if isinstance(response, str):
-                            # 如果是字符串，分两部分发送
+                            # If it's a string, send in two parts
                             first_chunk_time = time.time_ns()
                             last_chunk_time = first_chunk_time
                             total_response = response
-
+
                             data = {
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
                                 "response": response,
-                                "done": False
+                                "done": False,
                             }
                             yield f"{json.dumps(data, ensure_ascii=False)}\n"
-
+
                             completion_tokens = estimate_tokens(total_response)
                             total_time = last_chunk_time - start_time
                             prompt_eval_time = first_chunk_time - start_time
                             eval_time = last_chunk_time - first_chunk_time
-
+
                             data = {
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
@@ -1317,7 +1313,7 @@ def create_app(args):
                                 "prompt_eval_count": prompt_tokens,
                                 "prompt_eval_duration": prompt_eval_time,
                                 "eval_count": completion_tokens,
-                                "eval_duration": eval_time
+                                "eval_duration": eval_time,
                             }
                             yield f"{json.dumps(data, ensure_ascii=False)}\n"
                         else:
@@ -1325,23 +1321,23 @@ def create_app(args):
                                 if chunk:
                                     if first_chunk_time is None:
                                         first_chunk_time = time.time_ns()
-
+
                                     last_chunk_time = time.time_ns()
-
+
                                     total_response += chunk
                                     data = {
                                         "model": LIGHTRAG_MODEL,
                                         "created_at": LIGHTRAG_CREATED_AT,
                                         "response": chunk,
-                                        "done": False
+                                        "done": False,
                                     }
                                     yield f"{json.dumps(data, ensure_ascii=False)}\n"
-
+
                             completion_tokens = estimate_tokens(total_response)
                             total_time = last_chunk_time - start_time
                             prompt_eval_time = first_chunk_time - start_time
                             eval_time = last_chunk_time - first_chunk_time
-
+
                             data = {
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
@@ -1351,15 +1347,15 @@ def create_app(args):
                                 "prompt_eval_count": prompt_tokens,
                                 "prompt_eval_duration": prompt_eval_time,
                                 "eval_count": completion_tokens,
-                                "eval_duration": eval_time
+                                "eval_duration": eval_time,
                             }
                             yield f"{json.dumps(data, ensure_ascii=False)}\n"
                             return
-
+
                     except Exception as e:
                         logging.error(f"Error in stream_generator: {str(e)}")
                         raise
-
+
                 return StreamingResponse(
                     stream_generator(),
                     media_type="application/x-ndjson",
@@ -1375,20 +1371,18 @@ def create_app(args):
             else:
                 first_chunk_time = time.time_ns()
                 response_text = await rag.llm_model_func(
-                    query,
-                    stream=False,
-                    **rag.llm_model_kwargs
+                    query, stream=False, **rag.llm_model_kwargs
                 )
                 last_chunk_time = time.time_ns()
-
+
                 if not response_text:
                     response_text = "No response generated"
-
+
                 completion_tokens = estimate_tokens(str(response_text))
                 total_time = last_chunk_time - start_time
                 prompt_eval_time = first_chunk_time - start_time
                 eval_time = last_chunk_time - first_chunk_time
-
+
                 return {
                     "model": LIGHTRAG_MODEL,
                     "created_at": LIGHTRAG_CREATED_AT,
@@ -1399,7 +1393,7 @@ def create_app(args):
                     "prompt_eval_count": prompt_tokens,
                     "prompt_eval_duration": prompt_eval_time,
                     "eval_count": completion_tokens,
-                    "eval_duration": eval_time
+                    "eval_duration": eval_time,
                 }
         except Exception as e:
             trace_exception(e)
@@ -1417,16 +1411,12 @@ def create_app(args):
             # Get the last message as query
             query = messages[-1].content
 
-            # 解析查询模式
+            # Check for query prefix
             cleaned_query, mode = parse_query_mode(query)
 
-            # 开始计时
             start_time = time.time_ns()
-
-            # 计算输入token数量
             prompt_tokens = estimate_tokens(cleaned_query)
 
-            # 调用RAG进行查询
             query_param = QueryParam(
                 mode=mode, stream=request.stream, only_need_context=False
             )
@@ -1537,25 +1527,21 @@ def create_app(args):
                 )
             else:
                 first_chunk_time = time.time_ns()
 
-
-                # 判断是否包含特定字符串，使用正则表达式进行匹配
-                logging.info(f"Cleaned query content: {cleaned_query}")
-                match_result = re.search(r'\n\nUSER:', cleaned_query, re.MULTILINE)
-                logging.info(f"Regex match result: {bool(match_result)}")
-
-                if match_result:
+                # Determine if the request is from Open WebUI's session title and session keyword generation task
+                match_result = re.search(
+                    r"\n\nUSER:", cleaned_query, re.MULTILINE
+                )
+                if match_result:
                     if request.system:
                         rag.llm_model_kwargs["system_prompt"] = request.system
 
                     response_text = await rag.llm_model_func(
-                        cleaned_query,
-                        stream=False,
-                        **rag.llm_model_kwargs
+                        cleaned_query, stream=False, **rag.llm_model_kwargs
                     )
                 else:
                     response_text = await rag.aquery(cleaned_query, param=query_param)
-
+
                 last_chunk_time = time.time_ns()
 
                 if not response_text:
diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py
index 21014735..d1e61d39 100644
--- a/test_lightrag_ollama_chat.py
+++ b/test_lightrag_ollama_chat.py
@@ -110,7 +110,7 @@ DEFAULT_CONFIG = {
     },
     "test_cases": {
         "basic": {"query": "唐僧有几个徒弟"},
-        "generate": {"query": "电视剧西游记导演是谁"}
+        "generate": {"query": "电视剧西游记导演是谁"},
     },
 }
 
@@ -205,12 +205,13 @@ def create_chat_request_data(
         "stream": stream,
     }
 
+
 def create_generate_request_data(
-    prompt: str,
+    prompt: str,
     system: str = None,
-    stream: bool = False,
+    stream: bool = False,
     model: str = None,
-    options: Dict[str, Any] = None
+    options: Dict[str, Any] = None,
 ) -> Dict[str, Any]:
     """Create generate request data
     Args:
@@ -225,7 +226,7 @@ def create_generate_request_data(
     data = {
         "model": model or CONFIG["server"]["model"],
         "prompt": prompt,
-        "stream": stream
+        "stream": stream,
     }
     if system:
         data["system"] = system
@@ -258,7 +259,9 @@ def run_test(func: Callable, name: str) -> None:
 def test_non_stream_chat() -> None:
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
+    data = create_chat_request_data(
+        CONFIG["test_cases"]["basic"]["query"], stream=False
+    )
 
     # Send request
     response = make_request(url, data)
@@ -487,8 +490,7 @@ def test_non_stream_generate() -> None:
     """Test non-streaming call to /api/generate endpoint"""
     url = get_base_url("generate")
     data = create_generate_request_data(
-        CONFIG["test_cases"]["generate"]["query"],
-        stream=False
+        CONFIG["test_cases"]["generate"]["query"], stream=False
     )
 
     # Send request
@@ -504,17 +506,17 @@ def test_non_stream_generate() -> None:
         {
             "model": response_json["model"],
             "response": response_json["response"],
-            "done": response_json["done"]
+            "done": response_json["done"],
         },
-        "Response content"
+        "Response content",
     )
 
+
 def test_stream_generate() -> None:
     """Test streaming call to /api/generate endpoint"""
     url = get_base_url("generate")
     data = create_generate_request_data(
-        CONFIG["test_cases"]["generate"]["query"],
-        stream=True
+        CONFIG["test_cases"]["generate"]["query"], stream=True
     )
 
     # Send request and get streaming response
@@ -530,13 +532,17 @@ def test_stream_generate() -> None:
                     # Decode and parse JSON
                     data = json.loads(line.decode("utf-8"))
                     if data.get("done", True):  # If it's the completion marker
-                        if "total_duration" in data:  # Final performance statistics message
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
                             break
                     else:  # Normal content message
                         content = data.get("response", "")
                         if content:  # Only collect non-empty content
                             output_buffer.append(content)
-                            print(content, end="", flush=True)  # Print content in real-time
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
                 except json.JSONDecodeError:
                     print("Error decoding JSON from response line")
     finally:
@@ -545,13 +551,14 @@ def test_stream_generate() -> None:
     # Print a newline
     print()
 
+
 def test_generate_with_system() -> None:
     """Test generate with system prompt"""
     url = get_base_url("generate")
     data = create_generate_request_data(
         CONFIG["test_cases"]["generate"]["query"],
         system="你是一个知识渊博的助手",
-        stream=False
+        stream=False,
     )
 
     # Send request
@@ -567,15 +574,16 @@ def test_generate_with_system() -> None:
         {
             "model": response_json["model"],
             "response": response_json["response"],
-            "done": response_json["done"]
+            "done": response_json["done"],
         },
-        "Response content"
+        "Response content",
     )
 
+
 def test_generate_error_handling() -> None:
     """Test error handling for generate endpoint"""
     url = get_base_url("generate")
-
+
     # Test empty prompt
     if OutputControl.is_verbose():
         print("\n=== Testing empty prompt ===")
@@ -583,14 +591,14 @@ def test_generate_error_handling() -> None:
     response = make_request(url, data)
     print(f"Status code: {response.status_code}")
     print_json_response(response.json(), "Error message")
-
+
     # Test invalid options
     if OutputControl.is_verbose():
         print("\n=== Testing invalid options ===")
     data = create_generate_request_data(
         CONFIG["test_cases"]["basic"]["query"],
         options={"invalid_option": "value"},
-        stream=False
+        stream=False,
     )
     response = make_request(url, data)
     print(f"Status code: {response.status_code}")
@@ -602,12 +610,12 @@ def test_generate_concurrent() -> None:
     import asyncio
     import aiohttp
     from contextlib import asynccontextmanager
-
+
     @asynccontextmanager
     async def get_session():
         async with aiohttp.ClientSession() as session:
             yield session
-
+
     async def make_request(session, prompt: str):
         url = get_base_url("generate")
         data = create_generate_request_data(prompt, stream=False)
@@ -616,32 +624,27 @@ def test_generate_concurrent() -> None:
                 return await response.json()
         except Exception as e:
             return {"error": str(e)}
-
+
     async def run_concurrent_requests():
-        prompts = [
-            "第一个问题",
-            "第二个问题",
-            "第三个问题",
-            "第四个问题",
-            "第五个问题"
-        ]
-
+        prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
+
         async with get_session() as session:
             tasks = [make_request(session, prompt) for prompt in prompts]
             results = await asyncio.gather(*tasks)
             return results
-
+
     if OutputControl.is_verbose():
         print("\n=== Testing concurrent generate requests ===")
-
+
     # Run concurrent requests
     results = asyncio.run(run_concurrent_requests())
-
+
    # Print results
     for i, result in enumerate(results, 1):
         print(f"\nRequest {i} result:")
         print_json_response(result)
 
+
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
@@ -657,7 +660,7 @@ def get_test_cases() -> Dict[str, Callable]:
         "stream_generate": test_stream_generate,
         "generate_with_system": test_generate_with_system,
         "generate_errors": test_generate_error_handling,
-        "generate_concurrent": test_generate_concurrent
+        "generate_concurrent": test_generate_concurrent,
     }