Merge pull request #353 from b10902118/dev

JSON mode support for Ollama and OpenAI GPT
zrguo (committed by GitHub), 2024-12-02 15:52:05 +08:00
3 changed files with 50 additions and 22 deletions

File 1 of 3

@@ -29,7 +29,11 @@ import torch
 from pydantic import BaseModel, Field
 from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
-from .utils import compute_args_hash, wrap_embedding_func_with_attrs
+from .utils import (
+    compute_args_hash,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -66,6 +70,11 @@ async def openai_complete_if_cache(
     if if_cache_return is not None:
         return if_cache_return["return"]

-    response = await openai_async_client.chat.completions.create(
-        model=model, messages=messages, **kwargs
-    )
+    if "response_format" in kwargs:
+        response = await openai_async_client.beta.chat.completions.parse(
+            model=model, messages=messages, **kwargs
+        )
+    else:
+        response = await openai_async_client.chat.completions.create(
+            model=model, messages=messages, **kwargs
+        )
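A note on the new branch above: `beta.chat.completions.parse` is the OpenAI Python SDK's structured-outputs entry point. Passing a Pydantic model class as `response_format` constrains the completion to that schema, and the message content comes back as a matching JSON string. A minimal standalone sketch, assuming `OPENAI_API_KEY` is set; the model name, prompt, and schema are illustrative:

import asyncio
from typing import List

from openai import AsyncOpenAI
from pydantic import BaseModel


class Keywords(BaseModel):
    high_level_keywords: List[str]
    low_level_keywords: List[str]


async def main() -> None:
    client = AsyncOpenAI()
    response = await client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Extract keywords from: hybrid search"}],
        response_format=Keywords,  # the Pydantic class itself, not an instance
    )
    print(response.choices[0].message.parsed)  # a validated Keywords instance
    print(response.choices[0].message.content)  # the same data as a JSON string


asyncio.run(main())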
@@ -301,7 +310,7 @@ async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     kwargs.pop("max_tokens", None)
-    kwargs.pop("response_format", None)
+    # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
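The pop is commented out rather than extended because Ollama uses a different knob: its chat API takes `format="json"`, which `ollama_model_complete` below sets when `keyword_extraction=True`. A standalone sketch of that mode with the `ollama` Python client, assuming a local server is running; the model name is illustrative:

import asyncio
import json

import ollama


async def main() -> None:
    client = ollama.AsyncClient()
    response = await client.chat(
        model="llama3",  # any locally pulled model
        messages=[{"role": "user", "content": "Name two colors as a JSON object."}],
        format="json",  # constrains the reply to valid JSON
    )
    print(json.loads(response["message"]["content"]))


asyncio.run(main())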
@@ -345,9 +354,9 @@ def initialize_lmdeploy_pipeline(
         backend_config=TurbomindEngineConfig(
             tp=tp, model_format=model_format, quant_policy=quant_policy
         ),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template)
-        if chat_template
-        else None,
+        chat_template_config=(
+            ChatTemplateConfig(model_name=chat_template) if chat_template else None
+        ),
         log_level="WARNING",
     )
     return lmdeploy_pipe
@@ -458,9 +467,16 @@ async def lmdeploy_model_if_cache(
     return response


+class GPTKeywordExtractionFormat(BaseModel):
+    high_level_keywords: List[str]
+    low_level_keywords: List[str]
+
+
 async def gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
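With the schema class in place, callers opt in per request. A hypothetical end-to-end use, assuming the module is importable as `lightrag.llm` and that the cache store (`hashing_kv`) is optional here:

import asyncio
import json

from lightrag.llm import gpt_4o_complete  # import path assumed


async def main() -> None:
    result = await gpt_4o_complete(
        "Extract keywords: hybrid search with BM25 and dense vectors",
        keyword_extraction=True,  # attaches GPTKeywordExtractionFormat
    )
    data = json.loads(result)  # bare JSON, no markdown fences to strip
    print(data["high_level_keywords"], data["low_level_keywords"])


asyncio.run(main())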
@@ -471,8 +487,10 @@ async def gpt_4o_complete(
 async def gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
@@ -483,45 +501,56 @@ async def gpt_4o_mini_complete(
 async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await azure_openai_complete_if_cache(
+    result = await azure_openai_complete_if_cache(
         "conversation-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction: # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def bedrock_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await bedrock_complete_if_cache(
+    result = await bedrock_complete_if_cache(
         "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction: # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-    return await hf_model_if_cache(
+    result = await hf_model_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction: # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,
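Azure OpenAI, Bedrock, and Hugging Face get no native JSON mode in this change; the TODO comments mark them for later. Until then they share a fallback: complete normally, then carve the first JSON body out of the free-form reply. The shape of that pattern as a sketch, where `backend_complete` is a stand-in for any of the `*_complete_if_cache` helpers, not a repo function:

from typing import Awaitable, Callable

from lightrag.utils import locate_json_string_body_from_string  # import path assumed


async def complete_maybe_json(
    prompt: str,
    backend_complete: Callable[[str], Awaitable[str]],
    keyword_extraction: bool = False,
) -> str:
    result = await backend_complete(prompt)
    if keyword_extraction:  # the diff's TODO: replace with a real JSON API
        return locate_json_string_body_from_string(result)
    return result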

File 2 of 3

@@ -17,7 +17,6 @@ from .utils import (
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
-    locate_json_string_body_from_string,
 )
 from .base import (
     BaseGraphStorage,
@@ -461,12 +460,12 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt)
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
     logger.info("kw_prompt result:")
     print(result)
     try:
-        json_text = locate_json_string_body_from_string(result)
-        keywords_data = json.loads(json_text)
+        # json_text = locate_json_string_body_from_string(result) # handled in use_model_func
+        keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])
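Because every `llm_model_func` now hands back bare JSON when called with `keyword_extraction=True`, `kg_query` can feed the result straight to `json.loads`. What a well-formed reply parses into, with illustrative values:

import json

result = '{"high_level_keywords": ["graph RAG"], "low_level_keywords": ["entity", "relation"]}'
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
print(hl_keywords, ll_keywords)  # ['graph RAG'] ['entity', 'relation']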

File 3 of 3

@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
             maybe_json_str = maybe_json_str.replace("\\n", "")
             maybe_json_str = maybe_json_str.replace("\n", "")
             maybe_json_str = maybe_json_str.replace("'", '"')
-            json.loads(maybe_json_str)
+            # json.loads(maybe_json_str) # don't check here, cannot validate schema after all
             return maybe_json_str
     except Exception:
         pass
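With the `json.loads` check commented out, the helper returns the first `{...}` span it finds without verifying that it parses; validation now happens at the caller, whose own `json.loads` can fail with better context. Approximate behavior, with the regex search assumed from the unchanged part of the function:

import json
import re


def locate_json_body(content: str):
    match = re.search(r"{.*}", content, re.DOTALL)  # first {...} span, greedy
    if match is None:
        return None
    maybe_json_str = match.group(0).replace("\\n", "").replace("\n", "")
    # Note: this naive quote swap can corrupt apostrophes inside values.
    return maybe_json_str.replace("'", '"')


reply = 'Sure! Here you go:\n{"high_level_keywords": ["RAG"]}'
print(json.loads(locate_json_body(reply)))  # {'high_level_keywords': ['RAG']}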