diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py
index bb9b1ac5..af991c19 100644
--- a/lightrag/api/lightrag_ollama.py
+++ b/lightrag/api/lightrag_ollama.py
@@ -24,22 +24,25 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.status import HTTP_403_FORBIDDEN
 from dotenv import load_dotenv
 
+
 load_dotenv()
 
+
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
     Chinese characters: approximately 1.5 tokens per character
     English characters: approximately 0.25 tokens per character
     """
     # Use regex to match Chinese and non-Chinese characters separately
-    chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
-    non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
-
+    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
+    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
+
     # Calculate estimated token count
     tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
-
+
     return int(tokens)
 
+
 # Constants for model information
 LIGHTRAG_NAME = "lightrag"
 LIGHTRAG_TAG = "latest"
@@ -48,6 +51,7 @@ LIGHTRAG_SIZE = 7365960935
 LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
 LIGHTRAG_DIGEST = "sha256:lightrag"
 
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -61,6 +65,7 @@ async def llm_model_func(
         **kwargs,
     )
 
+
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": "http://m4.lan.znipower.com:11434",
@@ -245,27 +250,32 @@ class SearchMode(str, Enum):
     hybrid = "hybrid"
     mix = "mix"
 
+
 # Ollama API compatible models
 class OllamaMessage(BaseModel):
     role: str
     content: str
     images: Optional[List[str]] = None
 
+
 class OllamaChatRequest(BaseModel):
     model: str = LIGHTRAG_MODEL
     messages: List[OllamaMessage]
     stream: bool = True  # Default to streaming mode
     options: Optional[Dict[str, Any]] = None
 
+
 class OllamaChatResponse(BaseModel):
     model: str
     created_at: str
     message: OllamaMessage
     done: bool
 
+
 class OllamaVersionResponse(BaseModel):
     version: str
 
+
 class OllamaModelDetails(BaseModel):
     parent_model: str
     format: str
@@ -274,6 +284,7 @@ class OllamaModelDetails(BaseModel):
     parameter_size: str
     quantization_level: str
 
+
 class OllamaModel(BaseModel):
     name: str
     model: str
@@ -282,9 +293,11 @@ class OllamaModel(BaseModel):
     modified_at: str
     details: OllamaModelDetails
 
+
 class OllamaTagResponse(BaseModel):
     models: List[OllamaModel]
 
+
 # Original LightRAG models
 class QueryRequest(BaseModel):
     query: str
@@ -292,9 +305,11 @@ class QueryRequest(BaseModel):
     stream: bool = False
     only_need_context: bool = False
 
+
 class QueryResponse(BaseModel):
     response: str
 
+
 class InsertTextRequest(BaseModel):
     text: str
     description: Optional[str] = None
@@ -395,7 +410,9 @@ def create_app(args):
             embedding_dim=1024,
             max_token_size=8192,
             func=lambda texts: ollama_embedding(
-                texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434"
+                texts,
+                embed_model="bge-m3:latest",
+                host="http://m4.lan.znipower.com:11434",
             ),
         ),
     )
@@ -493,7 +510,7 @@ def create_app(args):
         # If response is a string (e.g. cache hit), return directly
         if isinstance(response, str):
             return QueryResponse(response=response)
-
+
         # If it's an async generator, decide whether to stream based on stream parameter
         if request.stream:
             result = ""
@@ -546,8 +563,8 @@ def create_app(args):
                     "Access-Control-Allow-Origin": "*",
                     "Access-Control-Allow-Methods": "POST, OPTIONS",
                     "Access-Control-Allow-Headers": "Content-Type",
-                    "X-Accel-Buffering": "no"  # Disable Nginx buffering
-                }
+                    "X-Accel-Buffering": "no",  # Disable Nginx buffering
+                },
             )
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
@@ -652,29 +669,29 @@ def create_app(args):
     @app.get("/api/version")
     async def get_version():
         """Get Ollama version information"""
-        return OllamaVersionResponse(
-            version="0.5.4"
-        )
+        return OllamaVersionResponse(version="0.5.4")
 
     @app.get("/api/tags")
     async def get_tags():
         """Get available models"""
         return OllamaTagResponse(
-            models=[{
-                "name": LIGHTRAG_MODEL,
-                "model": LIGHTRAG_MODEL,
-                "size": LIGHTRAG_SIZE,
-                "digest": LIGHTRAG_DIGEST,
-                "modified_at": LIGHTRAG_CREATED_AT,
-                "details": {
-                    "parent_model": "",
-                    "format": "gguf",
-                    "family": LIGHTRAG_NAME,
-                    "families": [LIGHTRAG_NAME],
-                    "parameter_size": "13B",
-                    "quantization_level": "Q4_0"
-                }
-            }]
+            models=[
+                {
+                    "name": LIGHTRAG_MODEL,
+                    "model": LIGHTRAG_MODEL,
+                    "size": LIGHTRAG_SIZE,
+                    "digest": LIGHTRAG_DIGEST,
+                    "modified_at": LIGHTRAG_CREATED_AT,
+                    "details": {
+                        "parent_model": "",
+                        "format": "gguf",
+                        "family": LIGHTRAG_NAME,
+                        "families": [LIGHTRAG_NAME],
+                        "parameter_size": "13B",
+                        "quantization_level": "Q4_0",
+                    },
+                }
+            ]
         )
 
     def parse_query_mode(query: str) -> tuple[str, SearchMode]:
@@ -686,15 +703,15 @@ def create_app(args):
             "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
             "/naive ": SearchMode.naive,
             "/hybrid ": SearchMode.hybrid,
-            "/mix ": SearchMode.mix
+            "/mix ": SearchMode.mix,
         }
-
+
         for prefix, mode in mode_map.items():
            if query.startswith(prefix):
                 # After removing prefix and leading spaces
-                cleaned_query = query[len(prefix):].lstrip()
+                cleaned_query = query[len(prefix) :].lstrip()
                 return cleaned_query, mode
-
+
         return query, SearchMode.hybrid
 
     @app.post("/api/chat")
@@ -705,32 +722,29 @@ def create_app(args):
             messages = request.messages
             if not messages:
                 raise HTTPException(status_code=400, detail="No messages provided")
-
+
             # Get the last message as query
             query = messages[-1].content
-
+
             # Parse the query mode
             cleaned_query, mode = parse_query_mode(query)
-
+
             # Start timing
             start_time = time.time_ns()
-
+
            # Count the input tokens
             prompt_tokens = estimate_tokens(cleaned_query)
-
+
            # Call RAG to run the query
             query_param = QueryParam(
-                mode=mode,
-                stream=request.stream,
-                only_need_context=False
+                mode=mode, stream=request.stream, only_need_context=False
             )
-
+
             if request.stream:
                 from fastapi.responses import StreamingResponse
-
+
                 response = await rag.aquery(  # Need await to get async generator
-                    cleaned_query,
-                    param=query_param
+                    cleaned_query, param=query_param
                 )
 
                 async def stream_generator():
@@ -738,33 +752,37 @@ def create_app(args):
                         first_chunk_time = None
                         last_chunk_time = None
                         total_response = ""
-
+
                         # Ensure response is an async generator
                         if isinstance(response, str):
                             # If it's a string, send in two parts
                             first_chunk_time = time.time_ns()
                             last_chunk_time = first_chunk_time
                             total_response = response
-
+
                             # First message: send the query content
                             data = {
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
                                 "message": {
-                                    "role": "assistant",
+                                    "role": "assistant",
                                     "content": response,
-                                    "images": None
+                                    "images": None,
                                 },
-                                "done": False
+                                "done": False,
                             }
f"{json.dumps(data, ensure_ascii=False)}\n" - + # 计算各项指标 completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time # 总时间 - prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 - eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 - + prompt_eval_time = ( + first_chunk_time - start_time + ) # 首个响应之前的时间 + eval_time = ( + last_chunk_time - first_chunk_time + ) # 生成响应的时间 + # 第二次发送统计信息 data = { "model": LIGHTRAG_MODEL, @@ -775,7 +793,7 @@ def create_app(args): "prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time # 生成响应的时间 + "eval_duration": eval_time, # 生成响应的时间 } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: @@ -785,10 +803,10 @@ def create_app(args): # 记录第一个chunk的时间 if first_chunk_time is None: first_chunk_time = time.time_ns() - + # 更新最后一个chunk的时间 last_chunk_time = time.time_ns() - + # 累积响应内容 total_response += chunk data = { @@ -797,18 +815,22 @@ def create_app(args): "message": { "role": "assistant", "content": chunk, - "images": None + "images": None, }, - "done": False + "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - + # 计算各项指标 completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time # 总时间 - prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 - eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 - + prompt_eval_time = ( + first_chunk_time - start_time + ) # 首个响应之前的时间 + eval_time = ( + last_chunk_time - first_chunk_time + ) # 生成响应的时间 + # 发送完成标记,包含性能统计信息 data = { "model": LIGHTRAG_MODEL, @@ -819,14 +841,14 @@ def create_app(args): "prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time # 生成响应的时间 + "eval_duration": eval_time, # 生成响应的时间 } yield f"{json.dumps(data, ensure_ascii=False)}\n" return # 确保生成器在发送完成标记后立即结束 except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise - + return StreamingResponse( stream_generator(), media_type="application/x-ndjson", @@ -836,28 +858,25 @@ def create_app(args): "Content-Type": "application/x-ndjson", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type" - } + "Access-Control-Allow-Headers": "Content-Type", + }, ) else: # 非流式响应 first_chunk_time = time.time_ns() - response_text = await rag.aquery( - cleaned_query, - param=query_param - ) + response_text = await rag.aquery(cleaned_query, param=query_param) last_chunk_time = time.time_ns() - + # 确保响应不为空 if not response_text: response_text = "No response generated" - + # 计算各项指标 completion_tokens = estimate_tokens(str(response_text)) total_time = last_chunk_time - start_time # 总时间 prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 - + # 构造响应,包含性能统计信息 return { "model": LIGHTRAG_MODEL, @@ -865,7 +884,7 @@ def create_app(args): "message": { "role": "assistant", "content": str(response_text), # 确保转换为字符串 - "images": None + "images": None, }, "done": True, "total_duration": total_time, # 总时间 @@ -873,7 +892,7 @@ def create_app(args): "prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time # 生成响应的时间 + "eval_duration": eval_time, # 生成响应的时间 } except Exception as e: raise 
diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py
index 4f6cab29..96aee692 100644
--- a/test_lightrag_ollama_chat.py
+++ b/test_lightrag_ollama_chat.py
@@ -18,8 +18,10 @@ from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
 
+
 class OutputControl:
     """Output control class, manages the verbosity of test output"""
+
     _verbose: bool = False
 
     @classmethod
@@ -30,9 +32,11 @@ class OutputControl:
     def is_verbose(cls) -> bool:
         return cls._verbose
 
+
 @dataclass
 class TestResult:
     """Test result data class"""
+
     name: str
     success: bool
     duration: float
@@ -43,8 +47,10 @@ class TestResult:
         if not self.timestamp:
             self.timestamp = datetime.now().isoformat()
 
+
 class TestStats:
     """Test statistics"""
+
     def __init__(self):
         self.results: List[TestResult] = []
         self.start_time = datetime.now()
@@ -65,8 +71,8 @@ class TestStats:
                 "total": len(self.results),
                 "passed": sum(1 for r in self.results if r.success),
                 "failed": sum(1 for r in self.results if not r.success),
-                "total_duration": sum(r.duration for r in self.results)
-            }
+                "total_duration": sum(r.duration for r in self.results),
+            },
         }
 
         with open(path, "w", encoding="utf-8") as f:
@@ -92,6 +98,7 @@ class TestStats:
             if not result.success:
                 print(f"- {result.name}: {result.error}")
 
+
 DEFAULT_CONFIG = {
     "server": {
         "host": "localhost",
@@ -99,16 +106,15 @@ DEFAULT_CONFIG = {
         "model": "lightrag:latest",
         "timeout": 30,
         "max_retries": 3,
-        "retry_delay": 1
+        "retry_delay": 1,
     },
-    "test_cases": {
-        "basic": {
-            "query": "唐僧有几个徒弟"
-        }
-    }
+    "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
 }
 
-def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
+
+def make_request(
+    url: str, data: Dict[str, Any], stream: bool = False
+) -> requests.Response:
     """Send an HTTP request with retry mechanism
     Args:
         url: Request URL
@@ -127,12 +133,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
 
     for attempt in range(max_retries):
         try:
-            response = requests.post(
-                url,
-                json=data,
-                stream=stream,
-                timeout=timeout
-            )
+            response = requests.post(url, json=data, stream=stream, timeout=timeout)
             return response
         except requests.exceptions.RequestException as e:
             if attempt == max_retries - 1:  # Last retry
@@ -140,6 +141,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
             print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
             time.sleep(retry_delay)
 
+
 def load_config() -> Dict[str, Any]:
     """Load configuration file
 
@@ -154,6 +156,7 @@ def load_config() -> Dict[str, Any]:
             return json.load(f)
     return DEFAULT_CONFIG
 
+
 def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
     """Format and print JSON response data
     Args:
@@ -166,18 +169,19 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
         print(f"\n=== {title} ===")
     print(json.dumps(data, ensure_ascii=False, indent=indent))
 
+
 # Global configuration
 CONFIG = load_config()
 
+
 def get_base_url() -> str:
     """Return the base URL"""
     server = CONFIG["server"]
     return f"http://{server['host']}:{server['port']}/api/chat"
 
+
 def create_request_data(
-    content: str,
-    stream: bool = False,
-    model: str = None
+    content: str, stream: bool = False, model: str = None
 ) -> Dict[str, Any]:
     """Create basic request data
     Args:
@@ -189,18 +193,15 @@ def create_request_data(
     """
     return {
         "model": model or CONFIG["server"]["model"],
-        "messages": [
-            {
-                "role": "user",
-                "content": content
-            }
-        ],
-        "stream": stream
+        "messages": [{"role": "user", "content": content}],
+        "stream": stream,
     }
 
+
 # Global test statistics
 STATS = TestStats()
 
+
 def run_test(func: Callable, name: str) -> None:
     """Run a test and record the results
     Args:
         func: Test function
         name: Test name
     """
@@ -217,13 +218,11 @@ def run_test(func: Callable, name: str) -> None:
         STATS.add_result(TestResult(name, False, duration, str(e)))
         raise
 
+
 def test_non_stream_chat():
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_request_data(
-        CONFIG["test_cases"]["basic"]["query"],
-        stream=False
-    )
+    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
 
     # Send request
     response = make_request(url, data)
@@ -234,10 +233,12 @@ def test_non_stream_chat():
     response_json = response.json()
 
     # Print response content
-    print_json_response({
-        "model": response_json["model"],
-        "message": response_json["message"]
-    }, "Response content")
+    print_json_response(
+        {"model": response_json["model"], "message": response_json["message"]},
+        "Response content",
+    )
+
+
 def test_stream_chat():
     """Test streaming call to /api/chat endpoint
@@ -257,10 +258,7 @@ def test_stream_chat():
     The last message will contain performance statistics, with done set to true.
     """
     url = get_base_url()
-    data = create_request_data(
-        CONFIG["test_cases"]["basic"]["query"],
-        stream=True
-    )
+    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
 
     # Send request and get streaming response
     response = make_request(url, data, stream=True)
@@ -273,9 +271,11 @@ def test_stream_chat():
             if line:  # Skip empty lines
                 try:
                     # Decode and parse JSON
-                    data = json.loads(line.decode('utf-8'))
+                    data = json.loads(line.decode("utf-8"))
                     if data.get("done", True):  # If it's the completion marker
-                        if "total_duration" in data:  # Final performance statistics message
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
                             # print_json_response(data, "Performance statistics")
                             break
                     else:  # Normal content message
@@ -283,7 +283,9 @@ def test_stream_chat():
                         content = message.get("content", "")
                         if content:  # Only collect non-empty content
                             output_buffer.append(content)
-                            print(content, end="", flush=True)  # Print content in real-time
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
                 except json.JSONDecodeError:
                     print("Error decoding JSON from response line")
     finally:
@@ -292,6 +294,7 @@ def test_stream_chat():
     # Print a newline
     print()
 
+
 def test_query_modes():
     """Test different query mode prefixes
 
@@ -311,8 +314,7 @@ def test_query_modes():
         if OutputControl.is_verbose():
             print(f"\n=== Testing /{mode} mode ===")
         data = create_request_data(
-            f"/{mode} {CONFIG['test_cases']['basic']['query']}",
-            stream=False
+            f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
         )
 
         # Send request
@@ -320,10 +322,10 @@ def test_query_modes():
         response_json = response.json()
 
         # Print response content
-        print_json_response({
-            "model": response_json["model"],
-            "message": response_json["message"]
-        })
+        print_json_response(
+            {"model": response_json["model"], "message": response_json["message"]}
+        )
+
 
 def create_error_test_data(error_type: str) -> Dict[str, Any]:
     """Create request data for error testing
@@ -337,33 +339,21 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
         Request dictionary containing error data
     """
     error_data = {
-        "empty_messages": {
-            "model": "lightrag:latest",
-            "messages": [],
-            "stream": True
-        },
+        "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
"messages": [], "stream": True}, "invalid_role": { "model": "lightrag:latest", - "messages": [ - { - "invalid_role": "user", - "content": "Test message" - } - ], - "stream": True + "messages": [{"invalid_role": "user", "content": "Test message"}], + "stream": True, }, "missing_content": { "model": "lightrag:latest", - "messages": [ - { - "role": "user" - } - ], - "stream": True - } + "messages": [{"role": "user"}], + "stream": True, + }, } return error_data.get(error_type, error_data["empty_messages"]) + def test_stream_error_handling(): """Test error handling for streaming responses @@ -409,6 +399,7 @@ def test_stream_error_handling(): print_json_response(response.json(), "Error message") response.close() + def test_error_handling(): """Test error handling for non-streaming responses @@ -455,6 +446,7 @@ def test_error_handling(): print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") + def get_test_cases() -> Dict[str, Callable]: """Get all available test cases Returns: @@ -465,9 +457,10 @@ def get_test_cases() -> Dict[str, Callable]: "stream": test_stream_chat, "modes": test_query_modes, "errors": test_error_handling, - "stream_errors": test_stream_error_handling + "stream_errors": test_stream_error_handling, } + def create_default_config(): """Create a default configuration file""" config_path = Path("config.json") @@ -476,6 +469,7 @@ def create_default_config(): json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) print(f"Default configuration file created: {config_path}") + def parse_args() -> argparse.Namespace: """Parse command line arguments""" parser = argparse.ArgumentParser( @@ -496,38 +490,39 @@ Configuration file (config.json): } } } -""" +""", ) parser.add_argument( - "-q", "--quiet", + "-q", + "--quiet", action="store_true", - help="Silent mode, only display test result summary" + help="Silent mode, only display test result summary", ) parser.add_argument( - "-a", "--ask", + "-a", + "--ask", type=str, - help="Specify query content, which will override the query settings in the configuration file" + help="Specify query content, which will override the query settings in the configuration file", ) parser.add_argument( - "--init-config", - action="store_true", - help="Create default configuration file" + "--init-config", action="store_true", help="Create default configuration file" ) parser.add_argument( "--output", type=str, default="", - help="Test result output file path, default is not to output to a file" + help="Test result output file path, default is not to output to a file", ) parser.add_argument( "--tests", nargs="+", choices=list(get_test_cases().keys()) + ["all"], default=["all"], - help="Test cases to run, options: %(choices)s. Use 'all' to run all tests" + help="Test cases to run, options: %(choices)s. Use 'all' to run all tests", ) return parser.parse_args() + if __name__ == "__main__": args = parse_args()