diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 00a8db18..89906006 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -94,6 +94,7 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode q After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. +To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model". ## Configuration diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ea9aaa2f..e5f68d72 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -533,6 +533,7 @@ class OllamaChatRequest(BaseModel): messages: List[OllamaMessage] stream: bool = True # Default to streaming mode options: Optional[Dict[str, Any]] = None + system: Optional[str] = None class OllamaChatResponse(BaseModel): @@ -542,6 +543,28 @@ class OllamaChatResponse(BaseModel): done: bool +class OllamaGenerateRequest(BaseModel): + model: str = LIGHTRAG_MODEL + prompt: str + system: Optional[str] = None + stream: bool = False + options: Optional[Dict[str, Any]] = None + + +class OllamaGenerateResponse(BaseModel): + model: str + created_at: str + response: str + done: bool + context: Optional[List[int]] + total_duration: Optional[int] + load_duration: Optional[int] + prompt_eval_count: Optional[int] + prompt_eval_duration: Optional[int] + eval_count: Optional[int] + eval_duration: Optional[int] + + class OllamaVersionResponse(BaseModel): version: str @@ -1417,6 +1440,145 @@ def create_app(args): return query, SearchMode.hybrid + @app.post("/api/generate") + async def generate(raw_request: Request, request: OllamaGenerateRequest): + """Handle generate completion requests""" + try: + query = request.prompt + start_time = time.time_ns() + prompt_tokens = estimate_tokens(query) + + if request.system: + rag.llm_model_kwargs["system_prompt"] = request.system + + if request.stream: + from fastapi.responses import StreamingResponse + + response = await rag.llm_model_func( + query, stream=True, **rag.llm_model_kwargs + ) + + async def stream_generator(): + try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "response": response, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + 
+ last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return + + except Exception as e: + logging.error(f"Error in stream_generator: {str(e)}") + raise + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + response_text = await rag.llm_model_func( + query, stream=False, **rag.llm_model_kwargs + ) + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "response": str(response_text), + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + @app.post("/api/chat") async def chat(raw_request: Request, request: OllamaChatRequest): """Handle chat completion requests""" @@ -1429,16 +1591,12 @@ def create_app(args): # Get the last message as query query = messages[-1].content - # 解析查询模式 + # Check for query prefix cleaned_query, mode = parse_query_mode(query) - # 开始计时 start_time = time.time_ns() - - # 计算输入token数量 prompt_tokens = estimate_tokens(cleaned_query) - # 调用RAG进行查询 query_param = QueryParam( mode=mode, stream=request.stream, only_need_context=False ) @@ -1549,7 +1707,21 @@ def create_app(args): ) else: first_chunk_time = time.time_ns() - response_text = await rag.aquery(cleaned_query, param=query_param) + + # Determine if the request is from Open WebUI's session title and session keyword generation task + match_result = re.search( + r"\n\nUSER:", cleaned_query, re.MULTILINE + ) + if match_result: + if request.system: + rag.llm_model_kwargs["system_prompt"] = request.system + + response_text = await rag.llm_model_func( + cleaned_query, stream=False, **rag.llm_model_kwargs + ) + else: + response_text = await rag.aquery(cleaned_query, param=query_param) + last_chunk_time = time.time_ns() if not response_text: diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 96aee692..d1e61d39 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -108,7 +108,10 @@ DEFAULT_CONFIG = { "max_retries": 3, "retry_delay": 1, 
}, - "test_cases": {"basic": {"query": "唐僧有几个徒弟"}}, + "test_cases": { + "basic": {"query": "唐僧有几个徒弟"}, + "generate": {"query": "电视剧西游记导演是谁"}, + }, } @@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) CONFIG = load_config() -def get_base_url() -> str: - """Return the base URL""" +def get_base_url(endpoint: str = "chat") -> str: + """Return the base URL for specified endpoint + Args: + endpoint: API endpoint name (chat or generate) + Returns: + Complete URL for the endpoint + """ server = CONFIG["server"] - return f"http://{server['host']}:{server['port']}/api/chat" + return f"http://{server['host']}:{server['port']}/api/{endpoint}" -def create_request_data( +def create_chat_request_data( content: str, stream: bool = False, model: str = None ) -> Dict[str, Any]: - """Create basic request data + """Create chat request data Args: content: User message content stream: Whether to use streaming response model: Model name Returns: - Dictionary containing complete request data + Dictionary containing complete chat request data """ return { "model": model or CONFIG["server"]["model"], @@ -198,6 +206,35 @@ def create_request_data( } +def create_generate_request_data( + prompt: str, + system: str = None, + stream: bool = False, + model: str = None, + options: Dict[str, Any] = None, +) -> Dict[str, Any]: + """Create generate request data + Args: + prompt: Generation prompt + system: System prompt + stream: Whether to use streaming response + model: Model name + options: Additional options + Returns: + Dictionary containing complete generate request data + """ + data = { + "model": model or CONFIG["server"]["model"], + "prompt": prompt, + "stream": stream, + } + if system: + data["system"] = system + if options: + data["options"] = options + return data + + # Global test statistics STATS = TestStats() @@ -219,10 +256,12 @@ def run_test(func: Callable, name: str) -> None: raise -def test_non_stream_chat(): +def test_non_stream_chat() -> None: """Test non-streaming call to /api/chat endpoint""" url = get_base_url() - data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False) + data = create_chat_request_data( + CONFIG["test_cases"]["basic"]["query"], stream=False + ) # Send request response = make_request(url, data) @@ -239,7 +278,7 @@ def test_non_stream_chat(): ) -def test_stream_chat(): +def test_stream_chat() -> None: """Test streaming call to /api/chat endpoint Use JSON Lines format to process streaming responses, each line is a complete JSON object. @@ -258,7 +297,7 @@ def test_stream_chat(): The last message will contain performance statistics, with done set to true. 
""" url = get_base_url() - data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True) + data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True) # Send request and get streaming response response = make_request(url, data, stream=True) @@ -295,7 +334,7 @@ def test_stream_chat(): print() -def test_query_modes(): +def test_query_modes() -> None: """Test different query mode prefixes Supported query modes: @@ -313,7 +352,7 @@ def test_query_modes(): for mode in modes: if OutputControl.is_verbose(): print(f"\n=== Testing /{mode} mode ===") - data = create_request_data( + data = create_chat_request_data( f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) @@ -354,7 +393,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: return error_data.get(error_type, error_data["empty_messages"]) -def test_stream_error_handling(): +def test_stream_error_handling() -> None: """Test error handling for streaming responses Test scenarios: @@ -400,7 +439,7 @@ def test_stream_error_handling(): response.close() -def test_error_handling(): +def test_error_handling() -> None: """Test error handling for non-streaming responses Test scenarios: @@ -447,6 +486,165 @@ def test_error_handling(): print_json_response(response.json(), "Error message") +def test_non_stream_generate() -> None: + """Test non-streaming call to /api/generate endpoint""" + url = get_base_url("generate") + data = create_generate_request_data( + CONFIG["test_cases"]["generate"]["query"], stream=False + ) + + # Send request + response = make_request(url, data) + + # Print response + if OutputControl.is_verbose(): + print("\n=== Non-streaming generate response ===") + response_json = response.json() + + # Print response content + print_json_response( + { + "model": response_json["model"], + "response": response_json["response"], + "done": response_json["done"], + }, + "Response content", + ) + + +def test_stream_generate() -> None: + """Test streaming call to /api/generate endpoint""" + url = get_base_url("generate") + data = create_generate_request_data( + CONFIG["test_cases"]["generate"]["query"], stream=True + ) + + # Send request and get streaming response + response = make_request(url, data, stream=True) + + if OutputControl.is_verbose(): + print("\n=== Streaming generate response ===") + output_buffer = [] + try: + for line in response.iter_lines(): + if line: # Skip empty lines + try: + # Decode and parse JSON + data = json.loads(line.decode("utf-8")) + if data.get("done", True): # If it's the completion marker + if ( + "total_duration" in data + ): # Final performance statistics message + break + else: # Normal content message + content = data.get("response", "") + if content: # Only collect non-empty content + output_buffer.append(content) + print( + content, end="", flush=True + ) # Print content in real-time + except json.JSONDecodeError: + print("Error decoding JSON from response line") + finally: + response.close() # Ensure the response connection is closed + + # Print a newline + print() + + +def test_generate_with_system() -> None: + """Test generate with system prompt""" + url = get_base_url("generate") + data = create_generate_request_data( + CONFIG["test_cases"]["generate"]["query"], + system="你是一个知识渊博的助手", + stream=False, + ) + + # Send request + response = make_request(url, data) + + # Print response + if OutputControl.is_verbose(): + print("\n=== Generate with system prompt response ===") + response_json = response.json() + + # Print response content + 
print_json_response( + { + "model": response_json["model"], + "response": response_json["response"], + "done": response_json["done"], + }, + "Response content", + ) + + +def test_generate_error_handling() -> None: + """Test error handling for generate endpoint""" + url = get_base_url("generate") + + # Test empty prompt + if OutputControl.is_verbose(): + print("\n=== Testing empty prompt ===") + data = create_generate_request_data("", stream=False) + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + # Test invalid options + if OutputControl.is_verbose(): + print("\n=== Testing invalid options ===") + data = create_generate_request_data( + CONFIG["test_cases"]["basic"]["query"], + options={"invalid_option": "value"}, + stream=False, + ) + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + +def test_generate_concurrent() -> None: + """Test concurrent generate requests""" + import asyncio + import aiohttp + from contextlib import asynccontextmanager + + @asynccontextmanager + async def get_session(): + async with aiohttp.ClientSession() as session: + yield session + + async def make_request(session, prompt: str): + url = get_base_url("generate") + data = create_generate_request_data(prompt, stream=False) + try: + async with session.post(url, json=data) as response: + return await response.json() + except Exception as e: + return {"error": str(e)} + + async def run_concurrent_requests(): + prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] + + async with get_session() as session: + tasks = [make_request(session, prompt) for prompt in prompts] + results = await asyncio.gather(*tasks) + return results + + if OutputControl.is_verbose(): + print("\n=== Testing concurrent generate requests ===") + + # Run concurrent requests + results = asyncio.run(run_concurrent_requests()) + + # Print results + for i, result in enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + + def get_test_cases() -> Dict[str, Callable]: """Get all available test cases Returns: @@ -458,6 +656,11 @@ def get_test_cases() -> Dict[str, Callable]: "modes": test_query_modes, "errors": test_error_handling, "stream_errors": test_stream_error_handling, + "non_stream_generate": test_non_stream_generate, + "stream_generate": test_stream_generate, + "generate_with_system": test_generate_with_system, + "generate_errors": test_generate_error_handling, + "generate_concurrent": test_generate_concurrent, } @@ -544,18 +747,20 @@ if __name__ == "__main__": if "all" in args.tests: # Run all tests if OutputControl.is_verbose(): - print("\n【Basic Functionality Tests】") - run_test(test_non_stream_chat, "Non-streaming Call Test") - run_test(test_stream_chat, "Streaming Call Test") + print("\n【Chat API Tests】") + run_test(test_non_stream_chat, "Non-streaming Chat Test") + run_test(test_stream_chat, "Streaming Chat Test") + run_test(test_query_modes, "Chat Query Mode Test") + run_test(test_error_handling, "Chat Error Handling Test") + run_test(test_stream_error_handling, "Chat Streaming Error Test") if OutputControl.is_verbose(): - print("\n【Query Mode Tests】") - run_test(test_query_modes, "Query Mode Test") - - if OutputControl.is_verbose(): - print("\n【Error Handling Tests】") - run_test(test_error_handling, "Error Handling Test") - run_test(test_stream_error_handling, "Streaming Error Handling Test") + print("\n【Generate API 
Tests】") + run_test(test_non_stream_generate, "Non-streaming Generate Test") + run_test(test_stream_generate, "Streaming Generate Test") + run_test(test_generate_with_system, "Generate with System Prompt Test") + run_test(test_generate_error_handling, "Generate Error Handling Test") + run_test(test_generate_concurrent, "Generate Concurrent Test") else: # Run specified tests for test_name in args.tests: