support JSON output for ollama and openai

Author: b10902118
Date: 2024-11-29 21:41:37 +08:00
Parent: 4223b4f603
Commit: 753c1e6714
3 changed files with 42 additions and 18 deletions
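In short: kg_query now requests keyword extraction explicitly through a keyword_extraction flag on the configured llm_model_func. The OpenAI completions (gpt_4o_complete, gpt_4o_mini_complete) forward a Pydantic model as response_format so the API returns structured JSON; the Ollama path switches the model into JSON mode; Azure OpenAI, Bedrock, and Hugging Face still return plain text, so those wrappers extract the JSON body with locate_json_string_body_from_string as a stopgap (marked TODO: use JSON API).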

Changed file 1 of 3:

@@ -29,7 +29,11 @@ import torch
 from pydantic import BaseModel, Field
 from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
-from .utils import compute_args_hash, wrap_embedding_func_with_attrs
+from .utils import (
+    compute_args_hash,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -301,7 +305,7 @@ async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     kwargs.pop("max_tokens", None)
-    kwargs.pop("response_format", None)
+    # kwargs.pop("response_format", None)  # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
@@ -345,9 +349,9 @@ def initialize_lmdeploy_pipeline(
         backend_config=TurbomindEngineConfig(
             tp=tp, model_format=model_format, quant_policy=quant_policy
         ),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template)
-        if chat_template
-        else None,
+        chat_template_config=(
+            ChatTemplateConfig(model_name=chat_template) if chat_template else None
+        ),
         log_level="WARNING",
     )
     return lmdeploy_pipe
@@ -458,9 +462,16 @@ async def lmdeploy_model_if_cache(
     return response


+class GPTKeywordExtractionFormat(BaseModel):
+    high_level_keywords: List[str]
+    low_level_keywords: List[str]
+
+
 async def gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
@@ -471,8 +482,10 @@ async def gpt_4o_complete(


 async def gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
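The diff does not show how openai_complete_if_cache consumes the Pydantic class it receives as response_format. A minimal sketch of the mechanism, assuming the structured-outputs parse endpoint of the OpenAI Python SDK (openai-python 1.40+); the helper name extract_keywords is hypothetical and not part of this commit:

# Hypothetical sketch: a Pydantic class passed as response_format driving
# OpenAI structured outputs; this is not the code in this diff.
from typing import List

from openai import AsyncOpenAI
from pydantic import BaseModel


class GPTKeywordExtractionFormat(BaseModel):
    high_level_keywords: List[str]
    low_level_keywords: List[str]


async def extract_keywords(prompt: str) -> GPTKeywordExtractionFormat:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        response_format=GPTKeywordExtractionFormat,  # schema-enforced reply
    )
    return response.choices[0].message.parsed  # already a validated instance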
@@ -483,45 +496,56 @@ async def gpt_4o_mini_complete(


 async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await azure_openai_complete_if_cache(
+    result = await azure_openai_complete_if_cache(
         "conversation-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def bedrock_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await bedrock_complete_if_cache(
+    result = await bedrock_complete_if_cache(
         "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-    return await hf_model_if_cache(
+    result = await hf_model_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,
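For Ollama, keyword_extraction sets response_format to the string "json", and ollama_model_if_cache no longer pops it; how that flag reaches the Ollama client's format parameter is outside this diff. A hedged sketch of Ollama's JSON mode via the ollama Python client, with a hypothetical helper name:

# Hypothetical sketch of Ollama JSON mode; this diff does not show the
# mapping from response_format onto the client's format argument.
import asyncio

import ollama


async def ollama_json(prompt: str, model: str = "llama3") -> str:
    client = ollama.AsyncClient()
    response = await client.chat(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        format="json",  # server constrains the reply to valid JSON
    )
    return response["message"]["content"]


if __name__ == "__main__":
    print(asyncio.run(ollama_json('Return {"ok": true} as JSON.')))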

Changed file 2 of 3:

@@ -461,12 +461,12 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt)
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
     logger.info("kw_prompt result:")
     print(result)
     try:
-        json_text = locate_json_string_body_from_string(result)
-        keywords_data = json.loads(json_text)
+        # json_text = locate_json_string_body_from_string(result)  # handled in use_model_func
+        keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])
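With extraction moved into the model functions, kg_query now feeds the result straight to json.loads. A hypothetical defensive variant, not part of this commit, would keep the old locator as a fallback for backends that wrap the JSON in prose (the lightrag.utils import path is assumed):

# Hypothetical fallback parser, not in this commit: try the raw result
# first, then fall back to locating a JSON body inside surrounding text.
import json

from lightrag.utils import locate_json_string_body_from_string  # assumed path


def parse_keywords(result: str) -> dict:
    try:
        return json.loads(result)
    except json.JSONDecodeError:
        json_text = locate_json_string_body_from_string(result)
        return json.loads(json_text) if json_text else {}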

Changed file 3 of 3:

@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
         maybe_json_str = maybe_json_str.replace("\\n", "")
         maybe_json_str = maybe_json_str.replace("\n", "")
         maybe_json_str = maybe_json_str.replace("'", '"')
-        json.loads(maybe_json_str)
+        # json.loads(maybe_json_str)  # don't check here, cannot validate schema after all
         return maybe_json_str
     except Exception:
         pass
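Per the new comment, a bare json.loads round-trip only proves syntactic validity, not the expected schema, so the check was dropped. If validation is wanted, a hypothetical call-site check against the Pydantic model added in this commit could look like:

# Hypothetical schema check at the call site (Pydantic v2 API); the commit
# deliberately skips validation inside the locator itself.
from pydantic import ValidationError

# GPTKeywordExtractionFormat is the Pydantic model added in this commit.


def is_valid_keywords_json(json_text: str) -> bool:
    try:
        GPTKeywordExtractionFormat.model_validate_json(json_text)
        return True
    except ValidationError:  # raised for both bad JSON and wrong schema
        return False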