Merge pull request #353 from b10902118/dev
JSON mode support for ollama and openai gpt
This commit is contained in:
@@ -29,7 +29,11 @@ import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Callable, Any
|
||||
from .base import BaseKVStorage
|
||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
from .utils import (
|
||||
compute_args_hash,
|
||||
wrap_embedding_func_with_attrs,
|
||||
locate_json_string_body_from_string,
|
||||
)
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
@@ -66,6 +70,11 @@ async def openai_complete_if_cache(
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
if "response_format" in kwargs:
|
||||
response = await openai_async_client.beta.chat.completions.parse(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
else:
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
@@ -301,7 +310,7 @@ async def ollama_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs.pop("response_format", None)
|
||||
# kwargs.pop("response_format", None) # allow json
|
||||
host = kwargs.pop("host", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
|
||||
@@ -345,9 +354,9 @@ def initialize_lmdeploy_pipeline(
|
||||
backend_config=TurbomindEngineConfig(
|
||||
tp=tp, model_format=model_format, quant_policy=quant_policy
|
||||
),
|
||||
chat_template_config=ChatTemplateConfig(model_name=chat_template)
|
||||
if chat_template
|
||||
else None,
|
||||
chat_template_config=(
|
||||
ChatTemplateConfig(model_name=chat_template) if chat_template else None
|
||||
),
|
||||
log_level="WARNING",
|
||||
)
|
||||
return lmdeploy_pipe
|
||||
@@ -458,9 +467,16 @@ async def lmdeploy_model_if_cache(
|
||||
return response
|
||||
|
||||
|
||||
class GPTKeywordExtractionFormat(BaseModel):
|
||||
high_level_keywords: List[str]
|
||||
low_level_keywords: List[str]
|
||||
|
||||
|
||||
async def gpt_4o_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
if keyword_extraction:
|
||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o",
|
||||
prompt,
|
||||
@@ -471,8 +487,10 @@ async def gpt_4o_complete(
|
||||
|
||||
|
||||
async def gpt_4o_mini_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
if keyword_extraction:
|
||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o-mini",
|
||||
prompt,
|
||||
@@ -483,45 +501,56 @@ async def gpt_4o_mini_complete(
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
return await azure_openai_complete_if_cache(
|
||||
result = await azure_openai_complete_if_cache(
|
||||
"conversation-4o-mini",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction: # TODO: use JSON API
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
return await bedrock_complete_if_cache(
|
||||
result = await bedrock_complete_if_cache(
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction: # TODO: use JSON API
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
async def hf_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await hf_model_if_cache(
|
||||
result = await hf_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction: # TODO: use JSON API
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
async def ollama_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
if keyword_extraction:
|
||||
kwargs["format"] = "json"
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await ollama_model_if_cache(
|
||||
model_name,
|
||||
|
@@ -17,7 +17,6 @@ from .utils import (
|
||||
split_string_by_multi_markers,
|
||||
truncate_list_by_token_size,
|
||||
process_combine_contexts,
|
||||
locate_json_string_body_from_string,
|
||||
)
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
@@ -461,12 +460,12 @@ async def kg_query(
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
|
||||
result = await use_model_func(kw_prompt)
|
||||
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
||||
logger.info("kw_prompt result:")
|
||||
print(result)
|
||||
try:
|
||||
json_text = locate_json_string_body_from_string(result)
|
||||
keywords_data = json.loads(json_text)
|
||||
# json_text = locate_json_string_body_from_string(result) # handled in use_model_func
|
||||
keywords_data = json.loads(result)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
|
||||
|
@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||
maybe_json_str = maybe_json_str.replace("\\n", "")
|
||||
maybe_json_str = maybe_json_str.replace("\n", "")
|
||||
maybe_json_str = maybe_json_str.replace("'", '"')
|
||||
json.loads(maybe_json_str)
|
||||
# json.loads(maybe_json_str) # don't check here, cannot validate schema after all
|
||||
return maybe_json_str
|
||||
except Exception:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user