support JSON output for ollama and openai

b10902118
2024-11-29 21:41:37 +08:00
parent 4223b4f603
commit 753c1e6714
3 changed files with 42 additions and 18 deletions
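
The commit threads a new keyword_extraction flag through the *_complete wrappers: the OpenAI paths pass a structured response_format, the Ollama path requests format="json", and the Azure/Bedrock/HF paths fall back to extracting a JSON body from the free-text reply. A minimal caller-side sketch of the new flag, assuming the module path lightrag.llm and a configured OPENAI_API_KEY (the call itself is illustrative, not part of the diff):

    import asyncio

    from lightrag.llm import gpt_4o_mini_complete  # module path assumed

    async def main():
        # With keyword_extraction=True the wrapper asks the backend for JSON output.
        result = await gpt_4o_mini_complete(
            "Extract keywords from: graph-based retrieval-augmented generation",
            keyword_extraction=True,
        )
        print(result)  # JSON with high_level_keywords / low_level_keywords

    asyncio.run(main())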


@@ -29,7 +29,11 @@ import torch
 from pydantic import BaseModel, Field
 from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
-from .utils import compute_args_hash, wrap_embedding_func_with_attrs
+from .utils import (
+    compute_args_hash,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
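
The newly imported helper serves as a fallback for backends without a native JSON mode (see the Azure/Bedrock/HF hunks below). A minimal sketch of what such a helper might look like; the real implementation lives in .utils and may differ:

    import re
    from typing import Union

    def locate_json_string_body_from_string(content: str) -> Union[str, None]:
        # Grab the first {...} span from a free-text LLM response.
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
        return maybe_json_str.group(0) if maybe_json_str is not None else None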
@@ -301,7 +305,7 @@ async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     kwargs.pop("max_tokens", None)
-    kwargs.pop("response_format", None)
+    # kwargs.pop("response_format", None)  # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
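
With the pop removed, a format hint can now reach the Ollama client, whose chat endpoint accepts format="json" to constrain output to valid JSON. A standalone sketch against the ollama Python client; the model name is an assumption:

    import asyncio

    import ollama

    async def main():
        client = ollama.AsyncClient()
        response = await client.chat(
            model="qwen2.5",  # any locally pulled model; name is an assumption
            messages=[{"role": "user", "content": "Answer as a JSON object."}],
            format="json",  # constrains the reply to valid JSON
        )
        print(response["message"]["content"])

    asyncio.run(main())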
@@ -345,9 +349,9 @@ def initialize_lmdeploy_pipeline(
         backend_config=TurbomindEngineConfig(
             tp=tp, model_format=model_format, quant_policy=quant_policy
         ),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template)
-        if chat_template
-        else None,
+        chat_template_config=(
+            ChatTemplateConfig(model_name=chat_template) if chat_template else None
+        ),
         log_level="WARNING",
     )
     return lmdeploy_pipe
@@ -458,9 +462,16 @@ async def lmdeploy_model_if_cache(
     return response


+class GPTKeywordExtractionFormat(BaseModel):
+    high_level_keywords: List[str]
+    low_level_keywords: List[str]
+
+
 async def gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
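
Passing a Pydantic class as response_format lines up with the OpenAI SDK's structured-output endpoint. A standalone sketch of that mechanism, assuming openai_complete_if_cache forwards the class to something like client.beta.chat.completions.parse:

    from typing import List

    from openai import AsyncOpenAI
    from pydantic import BaseModel

    class GPTKeywordExtractionFormat(BaseModel):
        high_level_keywords: List[str]
        low_level_keywords: List[str]

    async def extract_keywords(prompt: str) -> GPTKeywordExtractionFormat:
        client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        completion = await client.beta.chat.completions.parse(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            response_format=GPTKeywordExtractionFormat,  # schema enforced server-side
        )
        return completion.choices[0].message.parsed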
@@ -471,8 +482,10 @@ async def gpt_4o_complete(
 async def gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
@@ -483,45 +496,56 @@ async def gpt_4o_mini_complete(
 async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await azure_openai_complete_if_cache(
+    result = await azure_openai_complete_if_cache(
         "conversation-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def bedrock_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await bedrock_complete_if_cache(
+    result = await bedrock_complete_if_cache(
         "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-    return await hf_model_if_cache(
+    result = await hf_model_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,