support JSON output for ollama and openai

This commit is contained in:
b10902118
2024-11-29 21:41:37 +08:00
parent f3fe5d3b64
commit b0dd600429
3 changed files with 42 additions and 18 deletions

View File

@@ -29,7 +29,11 @@ import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
from .utils import (
compute_args_hash,
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -301,7 +305,7 @@ async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
@@ -345,9 +349,9 @@ def initialize_lmdeploy_pipeline(
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=ChatTemplateConfig(model_name=chat_template)
if chat_template
else None,
chat_template_config=(
ChatTemplateConfig(model_name=chat_template) if chat_template else None
),
log_level="WARNING",
)
return lmdeploy_pipe
@@ -458,9 +462,16 @@ async def lmdeploy_model_if_cache(
return response
class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: List[str]
low_level_keywords: List[str]
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
@@ -471,8 +482,10 @@ async def gpt_4o_complete(
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
@@ -483,45 +496,56 @@ async def gpt_4o_mini_complete(
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await azure_openai_complete_if_cache(
result = await azure_openai_complete_if_cache(
"conversation-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await bedrock_complete_if_cache(
result = await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
result = await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
if keyword_extraction:
kwargs["response_format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,

View File

@@ -461,12 +461,12 @@ async def kg_query(
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func(kw_prompt)
result = await use_model_func(kw_prompt, keyword_extraction=True)
logger.info("kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
# json_text = locate_json_string_body_from_string(result) # handled in use_model_func
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])

View File

@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
maybe_json_str = maybe_json_str.replace("\\n", "")
maybe_json_str = maybe_json_str.replace("\n", "")
maybe_json_str = maybe_json_str.replace("'", '"')
json.loads(maybe_json_str)
# json.loads(maybe_json_str) # don't check here, cannot validate schema after all
return maybe_json_str
except Exception:
pass