Merge pull request #644 from danielaskdd/Add-Ollama-generate-API-support
Add Ollama generate API support
@@ -94,6 +94,7 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query.

 After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.

+To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model".

 ## Configuration
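Under the hood, Open WebUI simply talks to this Ollama-compatible endpoint, so the connection can also be verified directly. A minimal sketch with the `requests` package, assuming the server is reachable at its configured host and port (localhost:9621 is used here only as an example):

```python
import requests

# Hypothetical direct call to the Ollama-compatible chat endpoint that Open WebUI uses.
# Adjust the URL to match your lightrag-server host/port configuration.
resp = requests.post(
    "http://localhost:9621/api/chat",
    json={
        "model": "lightrag:latest",
        "messages": [{"role": "user", "content": "/mix 唐僧有几个徒弟"}],
        "stream": False,
    },
    timeout=120,
)
print(resp.json())
```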
@@ -533,6 +533,7 @@ class OllamaChatRequest(BaseModel):
     messages: List[OllamaMessage]
     stream: bool = True  # Default to streaming mode
     options: Optional[Dict[str, Any]] = None
+    system: Optional[str] = None


 class OllamaChatResponse(BaseModel):
@@ -542,6 +543,28 @@ class OllamaChatResponse(BaseModel):
     done: bool


+class OllamaGenerateRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    prompt: str
+    system: Optional[str] = None
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+
+class OllamaGenerateResponse(BaseModel):
+    model: str
+    created_at: str
+    response: str
+    done: bool
+    context: Optional[List[int]]
+    total_duration: Optional[int]
+    load_duration: Optional[int]
+    prompt_eval_count: Optional[int]
+    prompt_eval_duration: Optional[int]
+    eval_count: Optional[int]
+    eval_duration: Optional[int]
+
+
 class OllamaVersionResponse(BaseModel):
     version: str
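For reference, the wire format implied by these models is a plain JSON body. A minimal non-streaming sketch using `requests` (the URL is an assumption; the field names come from OllamaGenerateRequest and OllamaGenerateResponse above):

```python
import requests

# Only "prompt" is required; "system", "stream" and "options" are optional.
payload = {
    "model": "lightrag:latest",
    "prompt": "电视剧西游记导演是谁",
    "stream": False,  # non-streaming: the server replies with a single JSON object
}

resp = requests.post("http://localhost:9621/api/generate", json=payload, timeout=120)
body = resp.json()

# The reply mirrors OllamaGenerateResponse: generated text plus Ollama-style counters.
print(body["response"])
print(body["prompt_eval_count"], body["eval_count"], body["total_duration"])
```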
@@ -1417,6 +1440,145 @@ def create_app(args):

         return query, SearchMode.hybrid

+    @app.post("/api/generate")
+    async def generate(raw_request: Request, request: OllamaGenerateRequest):
+        """Handle generate completion requests"""
+        try:
+            query = request.prompt
+            start_time = time.time_ns()
+            prompt_tokens = estimate_tokens(query)
+
+            if request.system:
+                rag.llm_model_kwargs["system_prompt"] = request.system
+
+            if request.stream:
+                from fastapi.responses import StreamingResponse
+
+                response = await rag.llm_model_func(
+                    query, stream=True, **rag.llm_model_kwargs
+                )
+
+                async def stream_generator():
+                    try:
+                        first_chunk_time = None
+                        last_chunk_time = None
+                        total_response = ""
+
+                        # Ensure response is an async generator
+                        if isinstance(response, str):
+                            # If it's a string, send in two parts
+                            first_chunk_time = time.time_ns()
+                            last_chunk_time = first_chunk_time
+                            total_response = response
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "response": response,
+                                "done": False,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        else:
+                            async for chunk in response:
+                                if chunk:
+                                    if first_chunk_time is None:
+                                        first_chunk_time = time.time_ns()
+
+                                    last_chunk_time = time.time_ns()
+
+                                    total_response += chunk
+                                    data = {
+                                        "model": LIGHTRAG_MODEL,
+                                        "created_at": LIGHTRAG_CREATED_AT,
+                                        "response": chunk,
+                                        "done": False,
+                                    }
+                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                            return
+
+                    except Exception as e:
+                        logging.error(f"Error in stream_generator: {str(e)}")
+                        raise
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="application/x-ndjson",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "application/x-ndjson",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type",
+                    },
+                )
+            else:
+                first_chunk_time = time.time_ns()
+                response_text = await rag.llm_model_func(
+                    query, stream=False, **rag.llm_model_kwargs
+                )
+                last_chunk_time = time.time_ns()
+
+                if not response_text:
+                    response_text = "No response generated"
+
+                completion_tokens = estimate_tokens(str(response_text))
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                return {
+                    "model": LIGHTRAG_MODEL,
+                    "created_at": LIGHTRAG_CREATED_AT,
+                    "response": str(response_text),
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time,
+                }
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.post("/api/chat")
     async def chat(raw_request: Request, request: OllamaChatRequest):
         """Handle chat completion requests"""
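When `stream` is true, the handler above emits NDJSON: one JSON object per line carrying a partial `response` with `done: false`, followed by a final object with `done: true` and the timing statistics. A minimal client-side sketch for consuming that stream (URL and model name are assumptions):

```python
import json
import requests

payload = {"model": "lightrag:latest", "prompt": "电视剧西游记导演是谁", "stream": True}

with requests.post(
    "http://localhost:9621/api/generate", json=payload, stream=True, timeout=120
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # skip keep-alive blank lines
        chunk = json.loads(line.decode("utf-8"))
        if chunk.get("done"):
            # The final record carries prompt_eval_count, eval_count, total_duration, etc.
            stats = {k: v for k, v in chunk.items() if k.endswith(("_count", "_duration"))}
            print("\n--- stats:", stats)
            break
        print(chunk.get("response", ""), end="", flush=True)
```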
@@ -1429,16 +1591,12 @@ def create_app(args):
             # Get the last message as query
             query = messages[-1].content

-            # 解析查询模式
+            # Check for query prefix
             cleaned_query, mode = parse_query_mode(query)

-            # 开始计时
             start_time = time.time_ns()

-            # 计算输入token数量
             prompt_tokens = estimate_tokens(cleaned_query)

-            # 调用RAG进行查询
             query_param = QueryParam(
                 mode=mode, stream=request.stream, only_need_context=False
             )
@@ -1549,7 +1707,21 @@ def create_app(args):
                 )
             else:
                 first_chunk_time = time.time_ns()
-                response_text = await rag.aquery(cleaned_query, param=query_param)
+
+                # Determine if the request is from Open WebUI's session title and session keyword generation task
+                match_result = re.search(
+                    r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
+                )
+                if match_result:
+                    if request.system:
+                        rag.llm_model_kwargs["system_prompt"] = request.system
+
+                    response_text = await rag.llm_model_func(
+                        cleaned_query, stream=False, **rag.llm_model_kwargs
+                    )
+                else:
+                    response_text = await rag.aquery(cleaned_query, param=query_param)
+
                 last_chunk_time = time.time_ns()

                 if not response_text:
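The regex check above is what keeps Open WebUI's background tasks (session title and session keyword generation) out of the RAG pipeline: those prompts embed the conversation behind a `<chat_history>` marker, so they are answered by the underlying LLM directly instead of `rag.aquery()`. A standalone sketch of just that predicate:

```python
import re

def is_openwebui_meta_task(query: str) -> bool:
    """Return True if the prompt looks like an Open WebUI title/keyword task.

    Matches the same marker the handler above searches for; such requests are
    routed to the plain LLM (llm_model_func) rather than the RAG query path.
    """
    return re.search(r"\n<chat_history>\nUSER:", query, re.MULTILINE) is not None
```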
@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
         "max_retries": 3,
         "retry_delay": 1,
     },
-    "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
+    "test_cases": {
+        "basic": {"query": "唐僧有几个徒弟"},
+        "generate": {"query": "电视剧西游记导演是谁"},
+    },
 }
@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
 CONFIG = load_config()


-def get_base_url() -> str:
-    """Return the base URL"""
+def get_base_url(endpoint: str = "chat") -> str:
+    """Return the base URL for specified endpoint
+
+    Args:
+        endpoint: API endpoint name (chat or generate)
+    Returns:
+        Complete URL for the endpoint
+    """
     server = CONFIG["server"]
-    return f"http://{server['host']}:{server['port']}/api/chat"
+    return f"http://{server['host']}:{server['port']}/api/{endpoint}"


-def create_request_data(
+def create_chat_request_data(
     content: str, stream: bool = False, model: str = None
 ) -> Dict[str, Any]:
-    """Create basic request data
+    """Create chat request data
     Args:
         content: User message content
         stream: Whether to use streaming response
         model: Model name
     Returns:
-        Dictionary containing complete request data
+        Dictionary containing complete chat request data
     """
     return {
         "model": model or CONFIG["server"]["model"],
@@ -198,6 +206,35 @@ def create_request_data(
     }


+def create_generate_request_data(
+    prompt: str,
+    system: str = None,
+    stream: bool = False,
+    model: str = None,
+    options: Dict[str, Any] = None,
+) -> Dict[str, Any]:
+    """Create generate request data
+    Args:
+        prompt: Generation prompt
+        system: System prompt
+        stream: Whether to use streaming response
+        model: Model name
+        options: Additional options
+    Returns:
+        Dictionary containing complete generate request data
+    """
+    data = {
+        "model": model or CONFIG["server"]["model"],
+        "prompt": prompt,
+        "stream": stream,
+    }
+    if system:
+        data["system"] = system
+    if options:
+        data["options"] = options
+    return data
+
+
 # Global test statistics
 STATS = TestStats()
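For example, assuming `CONFIG["server"]["model"]` resolves to `lightrag:latest`, the helper above builds a request body like this:

```python
data = create_generate_request_data(
    "电视剧西游记导演是谁",
    system="你是一个知识渊博的助手",
    stream=True,
)
assert data == {
    "model": "lightrag:latest",  # taken from CONFIG["server"]["model"]
    "prompt": "电视剧西游记导演是谁",
    "stream": True,
    "system": "你是一个知识渊博的助手",
}
```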
@@ -219,10 +256,12 @@ def run_test(func: Callable, name: str) -> None:
         raise


-def test_non_stream_chat():
+def test_non_stream_chat() -> None:
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
+    data = create_chat_request_data(
+        CONFIG["test_cases"]["basic"]["query"], stream=False
+    )

     # Send request
     response = make_request(url, data)
@@ -239,7 +278,7 @@ def test_non_stream_chat():
     )


-def test_stream_chat():
+def test_stream_chat() -> None:
     """Test streaming call to /api/chat endpoint

     Use JSON Lines format to process streaming responses, each line is a complete JSON object.
@@ -258,7 +297,7 @@ def test_stream_chat():
     The last message will contain performance statistics, with done set to true.
     """
     url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
+    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)

     # Send request and get streaming response
     response = make_request(url, data, stream=True)
@@ -295,7 +334,7 @@ def test_stream_chat():
     print()


-def test_query_modes():
+def test_query_modes() -> None:
     """Test different query mode prefixes

     Supported query modes:
@@ -313,7 +352,7 @@ def test_query_modes():
     for mode in modes:
         if OutputControl.is_verbose():
             print(f"\n=== Testing /{mode} mode ===")
-        data = create_request_data(
+        data = create_chat_request_data(
             f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
         )

@@ -354,7 +393,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
     return error_data.get(error_type, error_data["empty_messages"])


-def test_stream_error_handling():
+def test_stream_error_handling() -> None:
     """Test error handling for streaming responses

     Test scenarios:
@@ -400,7 +439,7 @@ def test_stream_error_handling():
     response.close()


-def test_error_handling():
+def test_error_handling() -> None:
     """Test error handling for non-streaming responses

     Test scenarios:
@@ -447,6 +486,165 @@ def test_error_handling():
     print_json_response(response.json(), "Error message")


+def test_non_stream_generate() -> None:
+    """Test non-streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"], stream=False
+    )
+
+    # Send request
+    response = make_request(url, data)
+
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Non-streaming generate response ===")
+    response_json = response.json()
+
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"],
+        },
+        "Response content",
+    )
+
+
+def test_stream_generate() -> None:
+    """Test streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"], stream=True
+    )
+
+    # Send request and get streaming response
+    response = make_request(url, data, stream=True)
+
+    if OutputControl.is_verbose():
+        print("\n=== Streaming generate response ===")
+    output_buffer = []
+    try:
+        for line in response.iter_lines():
+            if line:  # Skip empty lines
+                try:
+                    # Decode and parse JSON
+                    data = json.loads(line.decode("utf-8"))
+                    if data.get("done", True):  # If it's the completion marker
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
+                            break
+                    else:  # Normal content message
+                        content = data.get("response", "")
+                        if content:  # Only collect non-empty content
+                            output_buffer.append(content)
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
+                except json.JSONDecodeError:
+                    print("Error decoding JSON from response line")
+    finally:
+        response.close()  # Ensure the response connection is closed
+
+    # Print a newline
+    print()
+
+
+def test_generate_with_system() -> None:
+    """Test generate with system prompt"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        system="你是一个知识渊博的助手",
+        stream=False,
+    )
+
+    # Send request
+    response = make_request(url, data)
+
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Generate with system prompt response ===")
+    response_json = response.json()
+
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"],
+        },
+        "Response content",
+    )
+
+
+def test_generate_error_handling() -> None:
+    """Test error handling for generate endpoint"""
+    url = get_base_url("generate")
+
+    # Test empty prompt
+    if OutputControl.is_verbose():
+        print("\n=== Testing empty prompt ===")
+    data = create_generate_request_data("", stream=False)
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+    # Test invalid options
+    if OutputControl.is_verbose():
+        print("\n=== Testing invalid options ===")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["basic"]["query"],
+        options={"invalid_option": "value"},
+        stream=False,
+    )
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+
+def test_generate_concurrent() -> None:
+    """Test concurrent generate requests"""
+    import asyncio
+    import aiohttp
+    from contextlib import asynccontextmanager
+
+    @asynccontextmanager
+    async def get_session():
+        async with aiohttp.ClientSession() as session:
+            yield session
+
+    async def make_request(session, prompt: str):
+        url = get_base_url("generate")
+        data = create_generate_request_data(prompt, stream=False)
+        try:
+            async with session.post(url, json=data) as response:
+                return await response.json()
+        except Exception as e:
+            return {"error": str(e)}
+
+    async def run_concurrent_requests():
+        prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
+
+        async with get_session() as session:
+            tasks = [make_request(session, prompt) for prompt in prompts]
+            results = await asyncio.gather(*tasks)
+            return results
+
+    if OutputControl.is_verbose():
+        print("\n=== Testing concurrent generate requests ===")
+
+    # Run concurrent requests
+    results = asyncio.run(run_concurrent_requests())
+
+    # Print results
+    for i, result in enumerate(results, 1):
+        print(f"\nRequest {i} result:")
+        print_json_response(result)
+
+
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
@@ -458,6 +656,11 @@ def get_test_cases() -> Dict[str, Callable]:
         "modes": test_query_modes,
         "errors": test_error_handling,
         "stream_errors": test_stream_error_handling,
+        "non_stream_generate": test_non_stream_generate,
+        "stream_generate": test_stream_generate,
+        "generate_with_system": test_generate_with_system,
+        "generate_errors": test_generate_error_handling,
+        "generate_concurrent": test_generate_concurrent,
     }
@@ -544,18 +747,20 @@ if __name__ == "__main__":
     if "all" in args.tests:
         # Run all tests
         if OutputControl.is_verbose():
-            print("\n【Basic Functionality Tests】")
-        run_test(test_non_stream_chat, "Non-streaming Call Test")
-        run_test(test_stream_chat, "Streaming Call Test")
+            print("\n【Chat API Tests】")
+        run_test(test_non_stream_chat, "Non-streaming Chat Test")
+        run_test(test_stream_chat, "Streaming Chat Test")
+        run_test(test_query_modes, "Chat Query Mode Test")
+        run_test(test_error_handling, "Chat Error Handling Test")
+        run_test(test_stream_error_handling, "Chat Streaming Error Test")

         if OutputControl.is_verbose():
-            print("\n【Query Mode Tests】")
-        run_test(test_query_modes, "Query Mode Test")
-
-        if OutputControl.is_verbose():
-            print("\n【Error Handling Tests】")
-        run_test(test_error_handling, "Error Handling Test")
-        run_test(test_stream_error_handling, "Streaming Error Handling Test")
+            print("\n【Generate API Tests】")
+        run_test(test_non_stream_generate, "Non-streaming Generate Test")
+        run_test(test_stream_generate, "Streaming Generate Test")
+        run_test(test_generate_with_system, "Generate with System Prompt Test")
+        run_test(test_generate_error_handling, "Generate Error Handling Test")
+        run_test(test_generate_concurrent, "Generate Concurrent Test")
     else:
         # Run specified tests
         for test_name in args.tests: