pre-commit run --all-files

yangdx
2025-01-17 14:20:55 +08:00
parent 48f70ff8b4
commit fa9765ecd9
2 changed files with 163 additions and 149 deletions

View File

@@ -24,22 +24,25 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv

load_dotenv()


def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text
    Chinese characters: approximately 1.5 tokens per character
    English characters: approximately 0.25 tokens per character
    """
    # Use regex to match Chinese and non-Chinese characters separately
-    chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
-    non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
+    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
+    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
    # Calculate estimated token count
    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
    return int(tokens)


# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
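A quick sanity check of the token heuristic above (illustrative only, not part of this commit), using the same regex-based counting:

import re

def estimate_tokens(text: str) -> int:
    # Same heuristic as above: ~1.5 tokens per Chinese character, ~0.25 per other character
    chinese = len(re.findall(r"[\u4e00-\u9fff]", text))
    other = len(re.findall(r"[^\u4e00-\u9fff]", text))
    return int(chinese * 1.5 + other * 0.25)

print(estimate_tokens("你好 world"))  # 2 * 1.5 + 6 * 0.25 = 4.5 -> 4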
@@ -48,6 +51,7 @@ LIGHTRAG_SIZE = 7365960935
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
@@ -61,6 +65,7 @@ async def llm_model_func(
        **kwargs,
    )


def get_default_host(binding_type: str) -> str:
    default_hosts = {
        "ollama": "http://m4.lan.znipower.com:11434",
@@ -245,27 +250,32 @@ class SearchMode(str, Enum):
    hybrid = "hybrid"
    mix = "mix"


# Ollama API compatible models
class OllamaMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class OllamaChatRequest(BaseModel):
    model: str = LIGHTRAG_MODEL
    messages: List[OllamaMessage]
    stream: bool = True  # Default to streaming mode
    options: Optional[Dict[str, Any]] = None


class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool


class OllamaVersionResponse(BaseModel):
    version: str


class OllamaModelDetails(BaseModel):
    parent_model: str
    format: str
@@ -274,6 +284,7 @@ class OllamaModelDetails(BaseModel):
    parameter_size: str
    quantization_level: str


class OllamaModel(BaseModel):
    name: str
    model: str
@@ -282,9 +293,11 @@ class OllamaModel(BaseModel):
    modified_at: str
    details: OllamaModelDetails


class OllamaTagResponse(BaseModel):
    models: List[OllamaModel]


# Original LightRAG models
class QueryRequest(BaseModel):
    query: str
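For context, the Pydantic models above mirror Ollama's chat API, so a standard Ollama-style request body works against /api/chat. A minimal sketch (the base URL is an assumption, not taken from this commit):

import requests

payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "唐僧有几个徒弟"}],
    "stream": False,  # OllamaChatRequest defaults to streaming; disable it here
}
# Hypothetical base URL; substitute the host and port this server actually listens on.
resp = requests.post("http://localhost:8000/api/chat", json=payload, timeout=30)
print(resp.json()["message"]["content"])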
@@ -292,9 +305,11 @@ class QueryRequest(BaseModel):
    stream: bool = False
    only_need_context: bool = False


class QueryResponse(BaseModel):
    response: str


class InsertTextRequest(BaseModel):
    text: str
    description: Optional[str] = None
@@ -395,7 +410,9 @@ def create_app(args):
            embedding_dim=1024,
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
-                texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434"
+                texts,
+                embed_model="bge-m3:latest",
+                host="http://m4.lan.znipower.com:11434",
            ),
        ),
    )
@@ -493,7 +510,7 @@ def create_app(args):
        # If response is a string (e.g. cache hit), return directly
        if isinstance(response, str):
            return QueryResponse(response=response)

        # If it's an async generator, decide whether to stream based on stream parameter
        if request.stream:
            result = ""
@@ -546,8 +563,8 @@ def create_app(args):
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Headers": "Content-Type",
"X-Accel-Buffering": "no" # 禁用 Nginx 缓冲 "X-Accel-Buffering": "no", # 禁用 Nginx 缓冲
} },
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -652,29 +669,29 @@ def create_app(args):
    @app.get("/api/version")
    async def get_version():
        """Get Ollama version information"""
-        return OllamaVersionResponse(
-            version="0.5.4"
-        )
+        return OllamaVersionResponse(version="0.5.4")

    @app.get("/api/tags")
    async def get_tags():
        """Get available models"""
        return OllamaTagResponse(
-            models=[{
-                "name": LIGHTRAG_MODEL,
-                "model": LIGHTRAG_MODEL,
-                "size": LIGHTRAG_SIZE,
-                "digest": LIGHTRAG_DIGEST,
-                "modified_at": LIGHTRAG_CREATED_AT,
-                "details": {
-                    "parent_model": "",
-                    "format": "gguf",
-                    "family": LIGHTRAG_NAME,
-                    "families": [LIGHTRAG_NAME],
-                    "parameter_size": "13B",
-                    "quantization_level": "Q4_0"
-                }
-            }]
+            models=[
+                {
+                    "name": LIGHTRAG_MODEL,
+                    "model": LIGHTRAG_MODEL,
+                    "size": LIGHTRAG_SIZE,
+                    "digest": LIGHTRAG_DIGEST,
+                    "modified_at": LIGHTRAG_CREATED_AT,
+                    "details": {
+                        "parent_model": "",
+                        "format": "gguf",
+                        "family": LIGHTRAG_NAME,
+                        "families": [LIGHTRAG_NAME],
+                        "parameter_size": "13B",
+                        "quantization_level": "Q4_0",
+                    },
+                }
+            ]
        )

    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
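Because /api/version and /api/tags imitate Ollama's management endpoints, a generic Ollama client can discover the emulated model. A minimal sketch under the same hypothetical base URL as above:

import requests

base = "http://localhost:8000"  # hypothetical; use the server's real host/port
print(requests.get(f"{base}/api/version", timeout=10).json())  # {'version': '0.5.4'}
tags = requests.get(f"{base}/api/tags", timeout=10).json()
print([m["name"] for m in tags["models"]])  # e.g. ['lightrag:latest']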
@@ -686,15 +703,15 @@ def create_app(args):
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive, "/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid, "/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix "/mix ": SearchMode.mix,
} }
for prefix, mode in mode_map.items(): for prefix, mode in mode_map.items():
if query.startswith(prefix): if query.startswith(prefix):
# After removing prefix an leading spaces # After removing prefix an leading spaces
cleaned_query = query[len(prefix):].lstrip() cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode return cleaned_query, mode
return query, SearchMode.hybrid return query, SearchMode.hybrid
@app.post("/api/chat") @app.post("/api/chat")
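parse_query_mode strips a recognised mode prefix and otherwise falls back to hybrid. A standalone sketch of that prefix handling (modes reduced to plain strings, listing only the prefixes visible in this hunk):

MODE_PREFIXES = {"/global ": "global", "/naive ": "naive", "/hybrid ": "hybrid", "/mix ": "mix"}

def split_mode(query: str) -> tuple[str, str]:
    for prefix, mode in MODE_PREFIXES.items():
        if query.startswith(prefix):
            return query[len(prefix):].lstrip(), mode
    return query, "hybrid"  # same default as parse_query_mode

print(split_mode("/mix 唐僧有几个徒弟"))  # ('唐僧有几个徒弟', 'mix')
print(split_mode("plain question"))  # ('plain question', 'hybrid')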
@@ -705,32 +722,29 @@ def create_app(args):
            messages = request.messages
            if not messages:
                raise HTTPException(status_code=400, detail="No messages provided")

            # Get the last message as query
            query = messages[-1].content

            # Parse the query mode
            cleaned_query, mode = parse_query_mode(query)

            # Start timing
            start_time = time.time_ns()

            # Count the input tokens
            prompt_tokens = estimate_tokens(cleaned_query)

            # Call RAG to run the query
            query_param = QueryParam(
-                mode=mode,
-                stream=request.stream,
-                only_need_context=False
+                mode=mode, stream=request.stream, only_need_context=False
            )

            if request.stream:
                from fastapi.responses import StreamingResponse

                response = await rag.aquery(  # Need await to get async generator
-                    cleaned_query,
-                    param=query_param
+                    cleaned_query, param=query_param
                )

                async def stream_generator():
@@ -738,33 +752,37 @@ def create_app(args):
                    first_chunk_time = None
                    last_chunk_time = None
                    total_response = ""

                    # Ensure response is an async generator
                    if isinstance(response, str):
                        # If it's a string, send in two parts
                        first_chunk_time = time.time_ns()
                        last_chunk_time = first_chunk_time
                        total_response = response

                        # First message: send the query content
                        data = {
                            "model": LIGHTRAG_MODEL,
                            "created_at": LIGHTRAG_CREATED_AT,
                            "message": {
                                "role": "assistant",
                                "content": response,
-                                "images": None
+                                "images": None,
                            },
-                            "done": False
+                            "done": False,
                        }
                        yield f"{json.dumps(data, ensure_ascii=False)}\n"

                        # Calculate the metrics
                        completion_tokens = estimate_tokens(total_response)
                        total_time = last_chunk_time - start_time  # total time
-                        prompt_eval_time = first_chunk_time - start_time  # time before the first response
-                        eval_time = last_chunk_time - first_chunk_time  # time to generate the response
+                        prompt_eval_time = (
+                            first_chunk_time - start_time
+                        )  # time before the first response
+                        eval_time = (
+                            last_chunk_time - first_chunk_time
+                        )  # time to generate the response

                        # Second message: send the statistics
                        data = {
                            "model": LIGHTRAG_MODEL,
@@ -775,7 +793,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": completion_tokens, # 输出token数 "eval_count": completion_tokens, # 输出token数
"eval_duration": eval_time # 生成响应的时间 "eval_duration": eval_time, # 生成响应的时间
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
else: else:
@@ -785,10 +803,10 @@ def create_app(args):
                            # Record the time of the first chunk
                            if first_chunk_time is None:
                                first_chunk_time = time.time_ns()

                            # Update the time of the last chunk
                            last_chunk_time = time.time_ns()

                            # Accumulate the response content
                            total_response += chunk
                            data = {
@@ -797,18 +815,22 @@ def create_app(args):
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": chunk, "content": chunk,
"images": None "images": None,
}, },
"done": False "done": False,
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
# 计算各项指标 # 计算各项指标
completion_tokens = estimate_tokens(total_response) completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time # 总时间 total_time = last_chunk_time - start_time # 总时间
prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 prompt_eval_time = (
eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 first_chunk_time - start_time
) # 首个响应之前的时间
eval_time = (
last_chunk_time - first_chunk_time
) # 生成响应的时间
# 发送完成标记,包含性能统计信息 # 发送完成标记,包含性能统计信息
data = { data = {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
@@ -819,14 +841,14 @@ def create_app(args):
"prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": completion_tokens, # 输出token数 "eval_count": completion_tokens, # 输出token数
"eval_duration": eval_time # 生成响应的时间 "eval_duration": eval_time, # 生成响应的时间
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
return # 确保生成器在发送完成标记后立即结束 return # 确保生成器在发送完成标记后立即结束
except Exception as e: except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}") logging.error(f"Error in stream_generator: {str(e)}")
raise raise
return StreamingResponse( return StreamingResponse(
stream_generator(), stream_generator(),
media_type="application/x-ndjson", media_type="application/x-ndjson",
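The statistics message above takes its durations from time.time_ns(), so every *_duration field is in nanoseconds. A hedged sketch of how a client might turn the final done-message into a tokens-per-second figure (field names as emitted above; the sample values are made up for illustration):

def tokens_per_second(stats: dict) -> float:
    # eval_count is the estimated completion tokens, eval_duration is in nanoseconds
    return stats["eval_count"] / (stats["eval_duration"] / 1e9)

sample = {"eval_count": 120, "eval_duration": 4_000_000_000}  # illustrative values only
print(tokens_per_second(sample))  # 30.0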
@@ -836,28 +858,25 @@ def create_app(args):
"Content-Type": "application/x-ndjson", "Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type" "Access-Control-Allow-Headers": "Content-Type",
} },
) )
else: else:
# 非流式响应 # 非流式响应
first_chunk_time = time.time_ns() first_chunk_time = time.time_ns()
response_text = await rag.aquery( response_text = await rag.aquery(cleaned_query, param=query_param)
cleaned_query,
param=query_param
)
last_chunk_time = time.time_ns() last_chunk_time = time.time_ns()
# 确保响应不为空 # 确保响应不为空
if not response_text: if not response_text:
response_text = "No response generated" response_text = "No response generated"
# 计算各项指标 # 计算各项指标
completion_tokens = estimate_tokens(str(response_text)) completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time # 总时间 total_time = last_chunk_time - start_time # 总时间
prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间
eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 eval_time = last_chunk_time - first_chunk_time # 生成响应的时间
# 构造响应,包含性能统计信息 # 构造响应,包含性能统计信息
return { return {
"model": LIGHTRAG_MODEL, "model": LIGHTRAG_MODEL,
@@ -865,7 +884,7 @@ def create_app(args):
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": str(response_text), # 确保转换为字符串 "content": str(response_text), # 确保转换为字符串
"images": None "images": None,
}, },
"done": True, "done": True,
"total_duration": total_time, # 总时间 "total_duration": total_time, # 总时间
@@ -873,7 +892,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": completion_tokens, # 输出token数 "eval_count": completion_tokens, # 输出token数
"eval_duration": eval_time # 生成响应的时间 "eval_duration": eval_time, # 生成响应的时间
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -18,8 +18,10 @@ from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path


class OutputControl:
    """Output control class, manages the verbosity of test output"""

    _verbose: bool = False

    @classmethod
@@ -30,9 +32,11 @@ class OutputControl:
    def is_verbose(cls) -> bool:
        return cls._verbose


@dataclass
class TestResult:
    """Test result data class"""

    name: str
    success: bool
    duration: float
@@ -43,8 +47,10 @@ class TestResult:
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()


class TestStats:
    """Test statistics"""

    def __init__(self):
        self.results: List[TestResult] = []
        self.start_time = datetime.now()
@@ -65,8 +71,8 @@ class TestStats:
"total": len(self.results), "total": len(self.results),
"passed": sum(1 for r in self.results if r.success), "passed": sum(1 for r in self.results if r.success),
"failed": sum(1 for r in self.results if not r.success), "failed": sum(1 for r in self.results if not r.success),
"total_duration": sum(r.duration for r in self.results) "total_duration": sum(r.duration for r in self.results),
} },
} }
with open(path, "w", encoding="utf-8") as f: with open(path, "w", encoding="utf-8") as f:
@@ -92,6 +98,7 @@ class TestStats:
            if not result.success:
                print(f"- {result.name}: {result.error}")


DEFAULT_CONFIG = {
    "server": {
        "host": "localhost",
@@ -99,16 +106,15 @@ DEFAULT_CONFIG = {
"model": "lightrag:latest", "model": "lightrag:latest",
"timeout": 30, "timeout": 30,
"max_retries": 3, "max_retries": 3,
"retry_delay": 1 "retry_delay": 1,
}, },
"test_cases": { "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
"basic": {
"query": "唐僧有几个徒弟"
}
}
} }
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
def make_request(
url: str, data: Dict[str, Any], stream: bool = False
) -> requests.Response:
"""Send an HTTP request with retry mechanism """Send an HTTP request with retry mechanism
Args: Args:
url: Request URL url: Request URL
@@ -127,12 +133,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
    for attempt in range(max_retries):
        try:
-            response = requests.post(
-                url,
-                json=data,
-                stream=stream,
-                timeout=timeout
-            )
+            response = requests.post(url, json=data, stream=stream, timeout=timeout)
            return response
        except requests.exceptions.RequestException as e:
            if attempt == max_retries - 1:  # Last retry
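The retry loop above re-issues requests.post up to max_retries times with a fixed delay between attempts. A self-contained sketch of the same pattern (the names and defaults here are placeholders, not the script's config values):

import time
import requests

def post_with_retry(url: str, data: dict, max_retries: int = 3,
                    retry_delay: float = 1.0, timeout: float = 30.0) -> requests.Response:
    for attempt in range(max_retries):
        try:
            return requests.post(url, json=data, timeout=timeout)
        except requests.exceptions.RequestException:
            if attempt == max_retries - 1:  # last attempt: re-raise instead of sleeping
                raise
            time.sleep(retry_delay)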
@@ -140,6 +141,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
            print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
            time.sleep(retry_delay)


def load_config() -> Dict[str, Any]:
    """Load configuration file
@@ -154,6 +156,7 @@ def load_config() -> Dict[str, Any]:
            return json.load(f)
    return DEFAULT_CONFIG


def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
    """Format and print JSON response data

    Args:
@@ -166,18 +169,19 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
        print(f"\n=== {title} ===")
    print(json.dumps(data, ensure_ascii=False, indent=indent))


# Global configuration
CONFIG = load_config()


def get_base_url() -> str:
    """Return the base URL"""
    server = CONFIG["server"]
    return f"http://{server['host']}:{server['port']}/api/chat"


def create_request_data(
-    content: str,
-    stream: bool = False,
-    model: str = None
+    content: str, stream: bool = False, model: str = None
) -> Dict[str, Any]:
    """Create basic request data

    Args:
@@ -189,18 +193,15 @@ def create_request_data(
""" """
return { return {
"model": model or CONFIG["server"]["model"], "model": model or CONFIG["server"]["model"],
"messages": [ "messages": [{"role": "user", "content": content}],
{ "stream": stream,
"role": "user",
"content": content
}
],
"stream": stream
} }
# Global test statistics # Global test statistics
STATS = TestStats() STATS = TestStats()
def run_test(func: Callable, name: str) -> None: def run_test(func: Callable, name: str) -> None:
"""Run a test and record the results """Run a test and record the results
Args: Args:
@@ -217,13 +218,11 @@ def run_test(func: Callable, name: str) -> None:
        STATS.add_result(TestResult(name, False, duration, str(e)))
        raise


def test_non_stream_chat():
    """Test non-streaming call to /api/chat endpoint"""
    url = get_base_url()
-    data = create_request_data(
-        CONFIG["test_cases"]["basic"]["query"],
-        stream=False
-    )
+    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)

    # Send request
    response = make_request(url, data)
@@ -234,10 +233,12 @@ def test_non_stream_chat():
    response_json = response.json()

    # Print response content
-    print_json_response({
-        "model": response_json["model"],
-        "message": response_json["message"]
-    }, "Response content")
+    print_json_response(
+        {"model": response_json["model"], "message": response_json["message"]},
+        "Response content",
+    )


def test_stream_chat():
    """Test streaming call to /api/chat endpoint
@@ -257,10 +258,7 @@ def test_stream_chat():
    The last message will contain performance statistics, with done set to true.
    """
    url = get_base_url()
-    data = create_request_data(
-        CONFIG["test_cases"]["basic"]["query"],
-        stream=True
-    )
+    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)

    # Send request and get streaming response
    response = make_request(url, data, stream=True)
@@ -273,9 +271,11 @@ def test_stream_chat():
            if line:  # Skip empty lines
                try:
                    # Decode and parse JSON
-                    data = json.loads(line.decode('utf-8'))
+                    data = json.loads(line.decode("utf-8"))
                    if data.get("done", True):  # If it's the completion marker
-                        if "total_duration" in data:  # Final performance statistics message
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
                            # print_json_response(data, "Performance statistics")
                            break
                    else:  # Normal content message
@@ -283,7 +283,9 @@ def test_stream_chat():
                        content = message.get("content", "")
                        if content:  # Only collect non-empty content
                            output_buffer.append(content)
-                            print(content, end="", flush=True)  # Print content in real-time
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
                except json.JSONDecodeError:
                    print("Error decoding JSON from response line")
    finally:
@@ -292,6 +294,7 @@ def test_stream_chat():
        # Print a newline
        print()


def test_query_modes():
    """Test different query mode prefixes
@@ -311,8 +314,7 @@ def test_query_modes():
        if OutputControl.is_verbose():
            print(f"\n=== Testing /{mode} mode ===")
        data = create_request_data(
-            f"/{mode} {CONFIG['test_cases']['basic']['query']}",
-            stream=False
+            f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
        )

        # Send request
@@ -320,10 +322,10 @@ def test_query_modes():
        response_json = response.json()

        # Print response content
-        print_json_response({
-            "model": response_json["model"],
-            "message": response_json["message"]
-        })
+        print_json_response(
+            {"model": response_json["model"], "message": response_json["message"]}
+        )


def create_error_test_data(error_type: str) -> Dict[str, Any]:
    """Create request data for error testing
@@ -337,33 +339,21 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
        Request dictionary containing error data
    """
    error_data = {
-        "empty_messages": {
-            "model": "lightrag:latest",
-            "messages": [],
-            "stream": True
-        },
+        "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
        "invalid_role": {
            "model": "lightrag:latest",
-            "messages": [
-                {
-                    "invalid_role": "user",
-                    "content": "Test message"
-                }
-            ],
-            "stream": True
+            "messages": [{"invalid_role": "user", "content": "Test message"}],
+            "stream": True,
        },
        "missing_content": {
            "model": "lightrag:latest",
-            "messages": [
-                {
-                    "role": "user"
-                }
-            ],
-            "stream": True
-        }
+            "messages": [{"role": "user"}],
+            "stream": True,
+        },
    }
    return error_data.get(error_type, error_data["empty_messages"])


def test_stream_error_handling():
    """Test error handling for streaming responses
@@ -409,6 +399,7 @@ def test_stream_error_handling():
            print_json_response(response.json(), "Error message")
            response.close()


def test_error_handling():
    """Test error handling for non-streaming responses
@@ -455,6 +446,7 @@ def test_error_handling():
        print(f"Status code: {response.status_code}")
        print_json_response(response.json(), "Error message")


def get_test_cases() -> Dict[str, Callable]:
    """Get all available test cases

    Returns:
@@ -465,9 +457,10 @@ def get_test_cases() -> Dict[str, Callable]:
"stream": test_stream_chat, "stream": test_stream_chat,
"modes": test_query_modes, "modes": test_query_modes,
"errors": test_error_handling, "errors": test_error_handling,
"stream_errors": test_stream_error_handling "stream_errors": test_stream_error_handling,
} }
def create_default_config(): def create_default_config():
"""Create a default configuration file""" """Create a default configuration file"""
config_path = Path("config.json") config_path = Path("config.json")
@@ -476,6 +469,7 @@ def create_default_config():
            json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
        print(f"Default configuration file created: {config_path}")


def parse_args() -> argparse.Namespace:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
@@ -496,38 +490,39 @@ Configuration file (config.json):
        }
    }
}
-"""
+""",
    )

    parser.add_argument(
-        "-q", "--quiet",
+        "-q",
+        "--quiet",
        action="store_true",
-        help="Silent mode, only display test result summary"
+        help="Silent mode, only display test result summary",
    )
    parser.add_argument(
-        "-a", "--ask",
+        "-a",
+        "--ask",
        type=str,
-        help="Specify query content, which will override the query settings in the configuration file"
+        help="Specify query content, which will override the query settings in the configuration file",
    )
    parser.add_argument(
-        "--init-config",
-        action="store_true",
-        help="Create default configuration file"
+        "--init-config", action="store_true", help="Create default configuration file"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
-        help="Test result output file path, default is not to output to a file"
+        help="Test result output file path, default is not to output to a file",
    )
    parser.add_argument(
        "--tests",
        nargs="+",
        choices=list(get_test_cases().keys()) + ["all"],
        default=["all"],
-        help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
+        help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()