Fix linting, remove redundant comments, and clean up code for better readability

This commit is contained in:
yangdx
2025-01-24 23:50:47 +08:00
parent 11873625a3
commit f30a69e201
2 changed files with 76 additions and 87 deletions

View File

@@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel):
message: OllamaMessage
done: bool
class OllamaGenerateRequest(BaseModel):
model: str = LIGHTRAG_MODEL
prompt: str
@@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel):
stream: bool = False
options: Optional[Dict[str, Any]] = None
class OllamaGenerateResponse(BaseModel):
model: str
created_at: str
@@ -490,12 +492,13 @@ class OllamaGenerateResponse(BaseModel):
done: bool
context: Optional[List[int]]
total_duration: Optional[int]
load_duration: Optional[int]
prompt_eval_count: Optional[int]
prompt_eval_duration: Optional[int]
eval_count: Optional[int]
eval_duration: Optional[int]
class OllamaVersionResponse(BaseModel):
version: str
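
For reference, a minimal sketch of exercising the Ollama-compatible /api/generate endpoint that these models describe; the host, port, and model name are assumptions and should be adjusted to the actual deployment:

import requests  # any HTTP client works; requests is only an example

payload = {
    "model": "lightrag:latest",  # placeholder for LIGHTRAG_MODEL
    "prompt": "What is LightRAG?",
    "stream": False,
}
resp = requests.post("http://localhost:9621/api/generate", json=payload, timeout=120)
body = resp.json()  # fields mirror OllamaGenerateResponse: model, created_at, response, done, timing stats
print(body["response"])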
@@ -1262,52 +1265,45 @@ def create_app(args):
"""Handle generate completion requests"""
try:
query = request.prompt
# Start timing
start_time = time.time_ns()
# Calculate the number of input tokens
prompt_tokens = estimate_tokens(query)
# Use llm_model_func directly for the query
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.llm_model_func(
query,
stream=True,
**rag.llm_model_kwargs
query, stream=True, **rag.llm_model_kwargs
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Process the response
# Ensure response is an async generator
if isinstance(response, str):
# If it is a string, send it in two parts
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": response,
"done": False
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
@@ -1317,7 +1313,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
@@ -1325,23 +1321,23 @@ def create_app(args):
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
@@ -1351,15 +1347,15 @@ def create_app(args):
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
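
A rough client-side sketch for consuming the NDJSON stream returned above (server address and model name are assumptions); each line is one JSON object, and the final object carries the timing statistics:

import json

import requests

with requests.post(
    "http://localhost:9621/api/generate",  # assumed server address
    json={"model": "lightrag:latest", "prompt": "hello", "stream": True},
    stream=True,
) as r:
    for line in r.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk.get("done"):
            break  # final chunk: total_duration, eval_count, etc.
        print(chunk.get("response", ""), end="", flush=True)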
@@ -1375,20 +1371,18 @@ def create_app(args):
else:
first_chunk_time = time.time_ns()
response_text = await rag.llm_model_func(
query,
stream=False,
**rag.llm_model_kwargs
query, stream=False, **rag.llm_model_kwargs
)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
@@ -1399,7 +1393,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
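
All duration fields above come from time.time_ns() and are therefore in nanoseconds, mirroring Ollama's convention. A small helper a client might use to turn the non-streaming response into tokens per second (field names follow the response dict built above):

def tokens_per_second(resp: dict) -> float:
    """Derive generation speed from eval_count and eval_duration (nanoseconds)."""
    duration_ns = resp.get("eval_duration") or 0
    if duration_ns <= 0:
        return 0.0
    return resp.get("eval_count", 0) / (duration_ns / 1e9)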
@@ -1417,16 +1411,12 @@ def create_app(args):
# Get the last message as query
query = messages[-1].content
# Parse the query mode
# Check for query prefix
cleaned_query, mode = parse_query_mode(query)
# Start timing
start_time = time.time_ns()
# Calculate the number of input tokens
prompt_tokens = estimate_tokens(cleaned_query)
# Call RAG to perform the query
query_param = QueryParam(
mode=mode, stream=request.stream, only_need_context=False
)
@@ -1537,25 +1527,21 @@ def create_app(args):
)
else:
first_chunk_time = time.time_ns()
# Check whether it contains a specific string, matching with a regular expression
logging.info(f"Cleaned query content: {cleaned_query}")
match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
logging.info(f"Regex match result: {bool(match_result)}")
if match_result:
# Determine if the request is from Open WebUI's session title and session keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
)
if match_result:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await rag.llm_model_func(
cleaned_query,
stream=False,
**rag.llm_model_kwargs
cleaned_query, stream=False, **rag.llm_model_kwargs
)
else:
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:

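For context, parse_query_mode used above maps an optional prefix on the incoming prompt to a LightRAG query mode before building QueryParam. A simplified sketch of that idea; the exact prefixes and the default mode here are assumptions, not the file's verbatim implementation:

def parse_query_mode_sketch(query: str):
    """Strip a leading '/mode ' prefix and return (cleaned_query, mode)."""
    for prefix, mode in (("/local ", "local"), ("/global ", "global"),
                         ("/naive ", "naive"), ("/hybrid ", "hybrid")):
        if query.startswith(prefix):
            return query[len(prefix):], mode
    return query, "hybrid"  # assumed default mode
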
View File

@@ -110,7 +110,7 @@ DEFAULT_CONFIG = {
},
"test_cases": {
"basic": {"query": "唐僧有几个徒弟"},
"generate": {"query": "电视剧西游记导演是谁"}
"generate": {"query": "电视剧西游记导演是谁"},
},
}
@@ -205,12 +205,13 @@ def create_chat_request_data(
"stream": stream,
}
def create_generate_request_data(
prompt: str,
system: str = None,
stream: bool = False,
model: str = None,
options: Dict[str, Any] = None
options: Dict[str, Any] = None,
) -> Dict[str, Any]:
"""Create generate request data
Args:
@@ -225,7 +226,7 @@ def create_generate_request_data(
data = {
"model": model or CONFIG["server"]["model"],
"prompt": prompt,
"stream": stream
"stream": stream,
}
if system:
data["system"] = system
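
As a quick illustration, the reformatted helper above produces a payload like the following (query and system prompt taken from the test cases in this file):

data = create_generate_request_data(
    prompt=CONFIG["test_cases"]["generate"]["query"],
    system="你是一个知识渊博的助手",
    stream=False,
)
# -> {"model": <configured model>, "prompt": "电视剧西游记导演是谁",
#     "stream": False, "system": "你是一个知识渊博的助手"}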
@@ -258,7 +259,9 @@ def run_test(func: Callable, name: str) -> None:
def test_non_stream_chat() -> None:
"""Test non-streaming call to /api/chat endpoint"""
url = get_base_url()
data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
data = create_chat_request_data(
CONFIG["test_cases"]["basic"]["query"], stream=False
)
# Send request
response = make_request(url, data)
@@ -487,8 +490,7 @@ def test_non_stream_generate() -> None:
"""Test non-streaming call to /api/generate endpoint"""
url = get_base_url("generate")
data = create_generate_request_data(
CONFIG["test_cases"]["generate"]["query"],
stream=False
CONFIG["test_cases"]["generate"]["query"], stream=False
)
# Send request
@@ -504,17 +506,17 @@ def test_non_stream_generate() -> None:
{
"model": response_json["model"],
"response": response_json["response"],
"done": response_json["done"]
"done": response_json["done"],
},
"Response content"
"Response content",
)
def test_stream_generate() -> None:
"""Test streaming call to /api/generate endpoint"""
url = get_base_url("generate")
data = create_generate_request_data(
CONFIG["test_cases"]["generate"]["query"],
stream=True
CONFIG["test_cases"]["generate"]["query"], stream=True
)
# Send request and get streaming response
@@ -530,13 +532,17 @@ def test_stream_generate() -> None:
# Decode and parse JSON
data = json.loads(line.decode("utf-8"))
if data.get("done", True): # If it's the completion marker
if "total_duration" in data: # Final performance statistics message
if (
"total_duration" in data
): # Final performance statistics message
break
else: # Normal content message
content = data.get("response", "")
if content: # Only collect non-empty content
output_buffer.append(content)
print(content, end="", flush=True) # Print content in real-time
print(
content, end="", flush=True
) # Print content in real-time
except json.JSONDecodeError:
print("Error decoding JSON from response line")
finally:
@@ -545,13 +551,14 @@ def test_stream_generate() -> None:
# Print a newline
print()
def test_generate_with_system() -> None:
"""Test generate with system prompt"""
url = get_base_url("generate")
data = create_generate_request_data(
CONFIG["test_cases"]["generate"]["query"],
system="你是一个知识渊博的助手",
stream=False
stream=False,
)
# Send request
@@ -567,15 +574,16 @@ def test_generate_with_system() -> None:
{
"model": response_json["model"],
"response": response_json["response"],
"done": response_json["done"]
"done": response_json["done"],
},
"Response content"
"Response content",
)
def test_generate_error_handling() -> None:
"""Test error handling for generate endpoint"""
url = get_base_url("generate")
# Test empty prompt
if OutputControl.is_verbose():
print("\n=== Testing empty prompt ===")
@@ -583,14 +591,14 @@ def test_generate_error_handling() -> None:
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
# Test invalid options
if OutputControl.is_verbose():
print("\n=== Testing invalid options ===")
data = create_generate_request_data(
CONFIG["test_cases"]["basic"]["query"],
options={"invalid_option": "value"},
stream=False
stream=False,
)
response = make_request(url, data)
print(f"Status code: {response.status_code}")
@@ -602,12 +610,12 @@ def test_generate_concurrent() -> None:
import asyncio
import aiohttp
from contextlib import asynccontextmanager
@asynccontextmanager
async def get_session():
async with aiohttp.ClientSession() as session:
yield session
async def make_request(session, prompt: str):
url = get_base_url("generate")
data = create_generate_request_data(prompt, stream=False)
@@ -616,32 +624,27 @@ def test_generate_concurrent() -> None:
return await response.json()
except Exception as e:
return {"error": str(e)}
async def run_concurrent_requests():
prompts = [
"第一个问题",
"第二个问题",
"第三个问题",
"第四个问题",
"第五个问题"
]
prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
async with get_session() as session:
tasks = [make_request(session, prompt) for prompt in prompts]
results = await asyncio.gather(*tasks)
return results
if OutputControl.is_verbose():
print("\n=== Testing concurrent generate requests ===")
# Run concurrent requests
results = asyncio.run(run_concurrent_requests())
# Print results
for i, result in enumerate(results, 1):
print(f"\nRequest {i} result:")
print_json_response(result)
def get_test_cases() -> Dict[str, Callable]:
"""Get all available test cases
Returns:
@@ -657,7 +660,7 @@ def get_test_cases() -> Dict[str, Callable]:
"stream_generate": test_stream_generate,
"generate_with_system": test_generate_with_system,
"generate_errors": test_generate_error_handling,
"generate_concurrent": test_generate_concurrent
"generate_concurrent": test_generate_concurrent,
}
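
A minimal sketch of driving these cases with the run_test helper shown earlier; the real script's argument parsing and test selection may differ:

if __name__ == "__main__":
    for name, func in get_test_cases().items():
        run_test(func, name)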