Merge pull request #353 from b10902118/dev
JSON mode support for Ollama and OpenAI GPT
@@ -29,7 +29,11 @@ import torch
 from pydantic import BaseModel, Field
 from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
-from .utils import compute_args_hash, wrap_embedding_func_with_attrs
+from .utils import (
+    compute_args_hash,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -66,9 +70,14 @@ async def openai_complete_if_cache(
     if if_cache_return is not None:
         return if_cache_return["return"]

-    response = await openai_async_client.chat.completions.create(
-        model=model, messages=messages, **kwargs
-    )
+    if "response_format" in kwargs:
+        response = await openai_async_client.beta.chat.completions.parse(
+            model=model, messages=messages, **kwargs
+        )
+    else:
+        response = await openai_async_client.chat.completions.create(
+            model=model, messages=messages, **kwargs
+        )
     content = response.choices[0].message.content
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
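For context on the new branch: in the OpenAI Python SDK, passing a Pydantic class as response_format to beta.chat.completions.parse constrains generation to that schema. A minimal, self-contained sketch (the model name, prompt, and schema are illustrative, not taken from this diff):

import asyncio
from typing import List

from openai import AsyncOpenAI
from pydantic import BaseModel


class KeywordList(BaseModel):
    # hypothetical schema, for illustration only
    keywords: List[str]


async def demo() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment
    response = await client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Give three keywords about graph RAG."}],
        response_format=KeywordList,
    )
    # .content holds the raw JSON string (what the code above reads);
    # .parsed holds the schema-validated Pydantic object.
    print(response.choices[0].message.content)
    print(response.choices[0].message.parsed)


asyncio.run(demo())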
@@ -301,7 +310,7 @@ async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     kwargs.pop("max_tokens", None)
-    kwargs.pop("response_format", None)
+    # kwargs.pop("response_format", None)  # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)

@@ -345,9 +354,9 @@ def initialize_lmdeploy_pipeline(
         backend_config=TurbomindEngineConfig(
             tp=tp, model_format=model_format, quant_policy=quant_policy
         ),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template)
-        if chat_template
-        else None,
+        chat_template_config=(
+            ChatTemplateConfig(model_name=chat_template) if chat_template else None
+        ),
         log_level="WARNING",
     )
     return lmdeploy_pipe
@@ -458,9 +467,16 @@ async def lmdeploy_model_if_cache(
     return response


+class GPTKeywordExtractionFormat(BaseModel):
+    high_level_keywords: List[str]
+    low_level_keywords: List[str]
+
+
 async def gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
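A sketch of how a caller might exercise the new flag; the direct call and prompt are illustrative (in this PR the real caller is kg_query, further down):

import asyncio
import json

from lightrag.llm import gpt_4o_complete  # import path assumed


async def main() -> None:
    # With keyword_extraction=True, response_format is set to
    # GPTKeywordExtractionFormat, so the reply is a bare JSON string
    # that json.loads can parse with no salvage step.
    result = await gpt_4o_complete(
        "Extract keywords from: graph-based retrieval-augmented generation",
        keyword_extraction=True,
    )
    data = json.loads(result)
    print(data["high_level_keywords"], data["low_level_keywords"])


asyncio.run(main())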
@@ -471,8 +487,10 @@ async def gpt_4o_complete(


 async def gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
@@ -483,45 +501,56 @@ async def gpt_4o_mini_complete(


 async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await azure_openai_complete_if_cache(
+    result = await azure_openai_complete_if_cache(
         "conversation-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def bedrock_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await bedrock_complete_if_cache(
+    result = await bedrock_complete_if_cache(
         "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-    return await hf_model_if_cache(
+    result = await hf_model_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result


 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,
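On the Ollama side there is no schema object: keyword_extraction=True simply forwards format="json" to the client, which puts the server into JSON mode (decoding is constrained to one well-formed JSON object, not to any particular schema, so the prompt must still spell out the expected keys). A minimal sketch against the ollama Python client; model name and prompt are assumptions:

import asyncio

import ollama


async def main() -> None:
    client = ollama.AsyncClient()  # default host http://localhost:11434
    response = await client.chat(
        model="llama3",
        messages=[{"role": "user", "content": "List two colors as a JSON object."}],
        format="json",  # the same flag ollama_model_if_cache now lets through
    )
    print(response["message"]["content"])


asyncio.run(main())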
@@ -17,7 +17,6 @@ from .utils import (
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
-    locate_json_string_body_from_string,
 )
 from .base import (
     BaseGraphStorage,
@@ -461,12 +460,12 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt)
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
     logger.info("kw_prompt result:")
     print(result)
     try:
-        json_text = locate_json_string_body_from_string(result)
-        keywords_data = json.loads(json_text)
+        # json_text = locate_json_string_body_from_string(result)  # handled in use_model_func
+        keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])

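With extraction handled inside use_model_func, kg_query now expects result to already be a bare JSON document whose shape mirrors GPTKeywordExtractionFormat. An illustrative round trip (values made up):

import json

result = '{"high_level_keywords": ["graph RAG"], "low_level_keywords": ["entity", "relation"]}'
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])  # ["graph RAG"]
ll_keywords = keywords_data.get("low_level_keywords", [])   # ["entity", "relation"]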
@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
         maybe_json_str = maybe_json_str.replace("\\n", "")
         maybe_json_str = maybe_json_str.replace("\n", "")
         maybe_json_str = maybe_json_str.replace("'", '"')
-        json.loads(maybe_json_str)
+        # json.loads(maybe_json_str)  # don't check here, cannot validate schema after all
         return maybe_json_str
     except Exception:
         pass
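The final hunk relaxes the fallback extractor: it still slices the brace-delimited body out of a chatty reply and normalizes newlines and quotes, but it no longer round-trips the candidate through json.loads, since a syntactically valid string could still violate the keyword schema and the caller parses it anyway. A usage sketch under that reading (import path assumed):

from lightrag.utils import locate_json_string_body_from_string

reply = 'Sure! Here you go:\n{"high_level_keywords": [], "low_level_keywords": []}'
body = locate_json_string_body_from_string(reply)
# body is the {...} substring with newlines stripped; malformed candidates
# are now returned unchecked and only fail at the caller's json.loads.
print(body)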