Fix linting, remove redundant comments and clean up code for better readability
@@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel):
     message: OllamaMessage
     done: bool

+
 class OllamaGenerateRequest(BaseModel):
     model: str = LIGHTRAG_MODEL
     prompt: str
@@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel):
     stream: bool = False
     options: Optional[Dict[str, Any]] = None

+
 class OllamaGenerateResponse(BaseModel):
     model: str
     created_at: str
@@ -496,6 +498,7 @@ class OllamaGenerateResponse(BaseModel):
     eval_count: Optional[int]
     eval_duration: Optional[int]

+
 class OllamaVersionResponse(BaseModel):
     version: str

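
For orientation, a request body accepted by the request model above might look like the following sketch. This is not part of the change: the model tag is an assumed placeholder, and the prompt/system strings are simply borrowed from the test configuration further down.

# Illustrative /api/generate request body matching OllamaGenerateRequest;
# the "lightrag:latest" tag is an assumption, not a value defined in this diff.
generate_request = {
    "model": "lightrag:latest",
    "prompt": "电视剧西游记导演是谁",
    "system": "你是一个知识渊博的助手",
    "stream": False,
    "options": {},
}
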
@@ -1262,14 +1265,9 @@ def create_app(args):
         """Handle generate completion requests"""
         try:
             query = request.prompt

-            # 开始计时
             start_time = time.time_ns()
-
-            # 计算输入token数量
             prompt_tokens = estimate_tokens(query)

-            # 直接使用 llm_model_func 进行查询
             if request.system:
                 rag.llm_model_kwargs["system_prompt"] = request.system

@@ -1277,9 +1275,7 @@ def create_app(args):
                 from fastapi.responses import StreamingResponse

                 response = await rag.llm_model_func(
-                    query,
-                    stream=True,
-                    **rag.llm_model_kwargs
+                    query, stream=True, **rag.llm_model_kwargs
                 )

                 async def stream_generator():
@@ -1288,9 +1284,9 @@ def create_app(args):
                     last_chunk_time = None
                     total_response = ""

-                    # 处理响应
+                    # Ensure response is an async generator
                     if isinstance(response, str):
-                        # 如果是字符串,分两部分发送
+                        # If it's a string, send in two parts
                         first_chunk_time = time.time_ns()
                         last_chunk_time = first_chunk_time
                         total_response = response
@@ -1299,7 +1295,7 @@ def create_app(args):
                             "model": LIGHTRAG_MODEL,
                             "created_at": LIGHTRAG_CREATED_AT,
                             "response": response,
-                            "done": False
+                            "done": False,
                         }
                         yield f"{json.dumps(data, ensure_ascii=False)}\n"

@@ -1317,7 +1313,7 @@ def create_app(args):
                             "prompt_eval_count": prompt_tokens,
                             "prompt_eval_duration": prompt_eval_time,
                             "eval_count": completion_tokens,
-                            "eval_duration": eval_time
+                            "eval_duration": eval_time,
                         }
                         yield f"{json.dumps(data, ensure_ascii=False)}\n"
                     else:
@@ -1333,7 +1329,7 @@ def create_app(args):
                                 "model": LIGHTRAG_MODEL,
                                 "created_at": LIGHTRAG_CREATED_AT,
                                 "response": chunk,
-                                "done": False
+                                "done": False,
                             }
                             yield f"{json.dumps(data, ensure_ascii=False)}\n"

@@ -1351,7 +1347,7 @@ def create_app(args):
                             "prompt_eval_count": prompt_tokens,
                             "prompt_eval_duration": prompt_eval_time,
                             "eval_count": completion_tokens,
-                            "eval_duration": eval_time
+                            "eval_duration": eval_time,
                         }
                         yield f"{json.dumps(data, ensure_ascii=False)}\n"
                         return
@@ -1375,9 +1371,7 @@ def create_app(args):
             else:
                 first_chunk_time = time.time_ns()
                 response_text = await rag.llm_model_func(
-                    query,
-                    stream=False,
-                    **rag.llm_model_kwargs
+                    query, stream=False, **rag.llm_model_kwargs
                 )
                 last_chunk_time = time.time_ns()

@@ -1399,7 +1393,7 @@ def create_app(args):
                     "prompt_eval_count": prompt_tokens,
                     "prompt_eval_duration": prompt_eval_time,
                     "eval_count": completion_tokens,
-                    "eval_duration": eval_time
+                    "eval_duration": eval_time,
                 }
         except Exception as e:
             trace_exception(e)
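
For reference, the streaming branch above emits newline-delimited JSON, one message per line, ending with a statistics message. A minimal client sketch that consumes such a stream is shown below; the host, port, and model tag are illustrative assumptions, and the prompt is taken from the test configuration in this commit.

import json
import requests

# Minimal sketch of a client for the streaming /api/generate branch above.
# URL and model tag are assumptions; the prompt comes from the test config.
with requests.post(
    "http://localhost:9621/api/generate",
    json={"model": "lightrag:latest", "prompt": "唐僧有几个徒弟", "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line.decode("utf-8"))
        if "total_duration" in chunk:  # final statistics message
            break
        print(chunk.get("response", ""), end="", flush=True)
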
@@ -1417,16 +1411,12 @@ def create_app(args):
             # Get the last message as query
             query = messages[-1].content

-            # 解析查询模式
+            # Check for query prefix
             cleaned_query, mode = parse_query_mode(query)

-            # 开始计时
             start_time = time.time_ns()
-
-            # 计算输入token数量
             prompt_tokens = estimate_tokens(cleaned_query)

-            # 调用RAG进行查询
             query_param = QueryParam(
                 mode=mode, stream=request.stream, only_need_context=False
             )
@@ -1538,20 +1528,16 @@ def create_app(args):
             else:
                 first_chunk_time = time.time_ns()

-                # 判断是否包含特定字符串,使用正则表达式进行匹配
-                logging.info(f"Cleaned query content: {cleaned_query}")
-                match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
-                logging.info(f"Regex match result: {bool(match_result)}")
-
+                # Determine if the request is from Open WebUI's session title and session keyword generation task
+                match_result = re.search(
+                    r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
+                )
                 if match_result:
-
                     if request.system:
                         rag.llm_model_kwargs["system_prompt"] = request.system

                     response_text = await rag.llm_model_func(
-                        cleaned_query,
-                        stream=False,
-                        **rag.llm_model_kwargs
+                        cleaned_query, stream=False, **rag.llm_model_kwargs
                     )
                 else:
                     response_text = await rag.aquery(cleaned_query, param=query_param)
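
The regex in the hunk above separates Open WebUI's title/keyword generation calls, which bypass RAG and go straight to the LLM, from normal chat turns that are answered via rag.aquery. A small self-contained sketch of the same check, with made-up query strings, behaves like this:

import re

# Self-contained sketch of the detection above; both query strings are made up.
pattern = r"\n<chat_history>\nUSER:"

title_task = "Create a title for this conversation.\n<chat_history>\nUSER: 你好\nASSISTANT: 你好!"
normal_chat = "唐僧有几个徒弟"

print(bool(re.search(pattern, title_task, re.MULTILINE)))   # True  -> direct llm_model_func call
print(bool(re.search(pattern, normal_chat, re.MULTILINE)))  # False -> rag.aquery(...)
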
@@ -110,7 +110,7 @@ DEFAULT_CONFIG = {
     },
     "test_cases": {
         "basic": {"query": "唐僧有几个徒弟"},
-        "generate": {"query": "电视剧西游记导演是谁"}
+        "generate": {"query": "电视剧西游记导演是谁"},
     },
 }

@@ -205,12 +205,13 @@ def create_chat_request_data(
         "stream": stream,
     }

+
 def create_generate_request_data(
     prompt: str,
     system: str = None,
     stream: bool = False,
     model: str = None,
-    options: Dict[str, Any] = None
+    options: Dict[str, Any] = None,
 ) -> Dict[str, Any]:
     """Create generate request data
     Args:
@@ -225,7 +226,7 @@ def create_generate_request_data(
     data = {
         "model": model or CONFIG["server"]["model"],
         "prompt": prompt,
-        "stream": stream
+        "stream": stream,
     }
     if system:
         data["system"] = system
@@ -258,7 +259,9 @@ def run_test(func: Callable, name: str) -> None:
 def test_non_stream_chat() -> None:
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
+    data = create_chat_request_data(
+        CONFIG["test_cases"]["basic"]["query"], stream=False
+    )

     # Send request
     response = make_request(url, data)
@@ -487,8 +490,7 @@ def test_non_stream_generate() -> None:
     """Test non-streaming call to /api/generate endpoint"""
     url = get_base_url("generate")
     data = create_generate_request_data(
-        CONFIG["test_cases"]["generate"]["query"],
-        stream=False
+        CONFIG["test_cases"]["generate"]["query"], stream=False
     )

     # Send request
@@ -504,17 +506,17 @@ def test_non_stream_generate() -> None:
         {
             "model": response_json["model"],
             "response": response_json["response"],
-            "done": response_json["done"]
+            "done": response_json["done"],
         },
-        "Response content"
+        "Response content",
     )

+
 def test_stream_generate() -> None:
     """Test streaming call to /api/generate endpoint"""
     url = get_base_url("generate")
     data = create_generate_request_data(
-        CONFIG["test_cases"]["generate"]["query"],
-        stream=True
+        CONFIG["test_cases"]["generate"]["query"], stream=True
     )

     # Send request and get streaming response
@@ -530,13 +532,17 @@ def test_stream_generate() -> None:
                     # Decode and parse JSON
                     data = json.loads(line.decode("utf-8"))
                     if data.get("done", True):  # If it's the completion marker
-                        if "total_duration" in data:  # Final performance statistics message
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
                             break
                     else:  # Normal content message
                         content = data.get("response", "")
                         if content:  # Only collect non-empty content
                             output_buffer.append(content)
-                            print(content, end="", flush=True)  # Print content in real-time
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
                 except json.JSONDecodeError:
                     print("Error decoding JSON from response line")
     finally:
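
The loop above distinguishes two kinds of NDJSON messages: content chunks and a final statistics record. Their approximate shapes are sketched below; all field values are made up, only the keys follow the handler shown earlier in this commit.

# Illustrative shapes of the two message kinds parsed above; values are made up.
content_message = {
    "model": "lightrag:latest",
    "created_at": "2025-01-01T00:00:00Z",
    "response": "部分回答文本",
    "done": False,
}
final_statistics = {
    "model": "lightrag:latest",
    "created_at": "2025-01-01T00:00:00Z",
    "done": True,
    "total_duration": 1_234_567_890,
    "prompt_eval_count": 10,
    "prompt_eval_duration": 200_000_000,
    "eval_count": 120,
    "eval_duration": 1_000_000_000,
}
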
@@ -545,13 +551,14 @@ def test_stream_generate() -> None:
     # Print a newline
     print()

+
 def test_generate_with_system() -> None:
     """Test generate with system prompt"""
     url = get_base_url("generate")
     data = create_generate_request_data(
         CONFIG["test_cases"]["generate"]["query"],
         system="你是一个知识渊博的助手",
-        stream=False
+        stream=False,
     )

     # Send request
@@ -567,11 +574,12 @@ def test_generate_with_system() -> None:
         {
             "model": response_json["model"],
             "response": response_json["response"],
-            "done": response_json["done"]
+            "done": response_json["done"],
         },
-        "Response content"
+        "Response content",
     )

+
 def test_generate_error_handling() -> None:
     """Test error handling for generate endpoint"""
     url = get_base_url("generate")
@@ -590,7 +598,7 @@ def test_generate_error_handling() -> None:
     data = create_generate_request_data(
         CONFIG["test_cases"]["basic"]["query"],
         options={"invalid_option": "value"},
-        stream=False
+        stream=False,
     )
     response = make_request(url, data)
     print(f"Status code: {response.status_code}")
@@ -618,13 +626,7 @@ def test_generate_concurrent() -> None:
             return {"error": str(e)}

     async def run_concurrent_requests():
-        prompts = [
-            "第一个问题",
-            "第二个问题",
-            "第三个问题",
-            "第四个问题",
-            "第五个问题"
-        ]
+        prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]

         async with get_session() as session:
             tasks = [make_request(session, prompt) for prompt in prompts]
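
The concurrent test fans the five prompts out as asyncio tasks and gathers the results. A self-contained sketch of the same fan-out pattern is below; echo() is a dummy coroutine standing in for the session-based make_request helper, which is not reproduced here.

import asyncio

# Sketch of the fan-out pattern used above; echo() is a stand-in for the
# real aiohttp-based request helper.
async def echo(prompt: str) -> dict:
    await asyncio.sleep(0)  # placeholder for the HTTP round trip
    return {"response": f"answer to {prompt}"}

async def run_concurrent_requests():
    prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
    tasks = [echo(p) for p in prompts]
    return await asyncio.gather(*tasks)

for i, result in enumerate(asyncio.run(run_concurrent_requests()), 1):
    print(f"\nRequest {i} result:")
    print(result)
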
@@ -642,6 +644,7 @@ def test_generate_concurrent() -> None:
         print(f"\nRequest {i} result:")
         print_json_response(result)

+
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
@@ -657,7 +660,7 @@ def get_test_cases() -> Dict[str, Callable]:
         "stream_generate": test_stream_generate,
         "generate_with_system": test_generate_with_system,
         "generate_errors": test_generate_error_handling,
-        "generate_concurrent": test_generate_concurrent
+        "generate_concurrent": test_generate_concurrent,
     }

