Merge pull request #353 from b10902118/dev

JSON mode support for ollama and openai gpt
This commit is contained in:
zrguo
2024-12-02 15:52:05 +08:00
committed by GitHub
3 changed files with 50 additions and 22 deletions

View File

@@ -29,7 +29,11 @@ import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
from .utils import (
compute_args_hash,
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -66,6 +70,11 @@ async def openai_complete_if_cache(
if if_cache_return is not None:
return if_cache_return["return"]
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
@@ -301,7 +310,7 @@ async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
@@ -345,9 +354,9 @@ def initialize_lmdeploy_pipeline(
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=ChatTemplateConfig(model_name=chat_template)
if chat_template
else None,
chat_template_config=(
ChatTemplateConfig(model_name=chat_template) if chat_template else None
),
log_level="WARNING",
)
return lmdeploy_pipe
@@ -458,9 +467,16 @@ async def lmdeploy_model_if_cache(
return response
class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: List[str]
low_level_keywords: List[str]
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
@@ -471,8 +487,10 @@ async def gpt_4o_complete(
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
@@ -483,45 +501,56 @@ async def gpt_4o_mini_complete(
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await azure_openai_complete_if_cache(
result = await azure_openai_complete_if_cache(
"conversation-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await bedrock_complete_if_cache(
result = await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
result = await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
if keyword_extraction:
kwargs["format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,

View File

@@ -17,7 +17,6 @@ from .utils import (
split_string_by_multi_markers,
truncate_list_by_token_size,
process_combine_contexts,
locate_json_string_body_from_string,
)
from .base import (
BaseGraphStorage,
@@ -461,12 +460,12 @@ async def kg_query(
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func(kw_prompt)
result = await use_model_func(kw_prompt, keyword_extraction=True)
logger.info("kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
# json_text = locate_json_string_body_from_string(result) # handled in use_model_func
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])

View File

@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
maybe_json_str = maybe_json_str.replace("\\n", "")
maybe_json_str = maybe_json_str.replace("\n", "")
maybe_json_str = maybe_json_str.replace("'", '"')
json.loads(maybe_json_str)
# json.loads(maybe_json_str) # don't check here, cannot validate schema after all
return maybe_json_str
except Exception:
pass