Removed query mode parsing and related tests
- Removed query mode parsing logic
- Removed test_generate_query_modes
- Simplified generate endpoint
- Updated test cases list
- Cleaned up unused code
This commit is contained in:
@@ -1260,17 +1260,13 @@ def create_app(args):
|
||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||
"""Handle generate completion requests"""
|
||||
try:
|
||||
# 获取查询内容
|
||||
query = request.prompt
|
||||
|
||||
# 解析查询模式
|
||||
cleaned_query, mode = parse_query_mode(query)
|
||||
|
||||
|
||||
# 开始计时
|
||||
start_time = time.time_ns()
|
||||
|
||||
# 计算输入token数量
|
||||
prompt_tokens = estimate_tokens(cleaned_query)
|
||||
prompt_tokens = estimate_tokens(query)
|
||||
|
||||
# 直接使用 llm_model_func 进行查询
|
||||
if request.system:
|
||||
@@ -1280,7 +1276,7 @@ def create_app(args):
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
response = await rag.llm_model_func(
|
||||
cleaned_query,
|
||||
query,
|
||||
stream=True,
|
||||
**rag.llm_model_kwargs
|
||||
)
|
||||
@@ -1378,7 +1374,7 @@ def create_app(args):
|
||||
else:
|
||||
first_chunk_time = time.time_ns()
|
||||
response_text = await rag.llm_model_func(
|
||||
cleaned_query,
|
||||
query,
|
||||
stream=False,
|
||||
**rag.llm_model_kwargs
|
||||
)
|
||||
|
@@ -679,32 +679,6 @@ def test_generate_concurrent() -> None:
|
||||
print(f"\nRequest {i} result:")
|
||||
print_json_response(result)
|
||||
|
||||
def test_generate_query_modes() -> None:
|
||||
"""Test different query mode prefixes for generate endpoint"""
|
||||
url = get_base_url("generate")
|
||||
modes = ["local", "global", "naive", "hybrid", "mix"]
|
||||
|
||||
for mode in modes:
|
||||
if OutputControl.is_verbose():
|
||||
print(f"\n=== Testing /{mode} mode for generate ===")
|
||||
data = create_generate_request_data(
|
||||
f"/{mode} {CONFIG['test_cases']['generate']['query']}",
|
||||
stream=False
|
||||
)
|
||||
|
||||
# Send request
|
||||
response = make_request(url, data)
|
||||
response_json = response.json()
|
||||
|
||||
# Print response content
|
||||
print_json_response(
|
||||
{
|
||||
"model": response_json["model"],
|
||||
"response": response_json["response"],
|
||||
"done": response_json["done"]
|
||||
}
|
||||
)
|
||||
|
||||
def get_test_cases() -> Dict[str, Callable]:
|
||||
"""Get all available test cases
|
||||
Returns:
|
||||
@@ -719,7 +693,6 @@ def get_test_cases() -> Dict[str, Callable]:
|
||||
"non_stream_generate": test_non_stream_generate,
|
||||
"stream_generate": test_stream_generate,
|
||||
"generate_with_system": test_generate_with_system,
|
||||
"generate_modes": test_generate_query_modes,
|
||||
"generate_errors": test_generate_error_handling,
|
||||
"generate_stats": test_generate_performance_stats,
|
||||
"generate_concurrent": test_generate_concurrent
|
||||
@@ -821,7 +794,6 @@ if __name__ == "__main__":
|
||||
run_test(test_non_stream_generate, "Non-streaming Generate Test")
|
||||
run_test(test_stream_generate, "Streaming Generate Test")
|
||||
run_test(test_generate_with_system, "Generate with System Prompt Test")
|
||||
run_test(test_generate_query_modes, "Generate Query Mode Test")
|
||||
run_test(test_generate_error_handling, "Generate Error Handling Test")
|
||||
run_test(test_generate_performance_stats, "Generate Performance Stats Test")
|
||||
run_test(test_generate_concurrent, "Generate Concurrent Test")
|
||||
|
Reference in New Issue
Block a user