add: option to replace the default tiktoken Tokenizer with a custom one
@@ -12,8 +12,7 @@ from .utils import (
     logger,
     clean_str,
     compute_mdhash_id,
-    decode_tokens_by_tiktoken,
-    encode_string_by_tiktoken,
+    Tokenizer,
     is_float_regex,
     list_of_list_to_csv,
     normalize_extracted_info,
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
 
 
 def chunking_by_token_size(
+    tokenizer: Tokenizer,
     content: str,
     split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
-    tiktoken_model: str = "gpt-4o",
 ) -> list[dict[str, Any]]:
-    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
+    tokens = tokenizer.encode(content)
     results: list[dict[str, Any]] = []
     if split_by_character:
         raw_chunks = content.split(split_by_character)
         new_chunks = []
         if split_by_character_only:
             for chunk in raw_chunks:
-                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                _tokens = tokenizer.encode(chunk)
                 new_chunks.append((len(_tokens), chunk))
         else:
             for chunk in raw_chunks:
-                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                _tokens = tokenizer.encode(chunk)
                 if len(_tokens) > max_token_size:
                     for start in range(
                         0, len(_tokens), max_token_size - overlap_token_size
                     ):
-                        chunk_content = decode_tokens_by_tiktoken(
-                            _tokens[start : start + max_token_size],
-                            model_name=tiktoken_model,
+                        chunk_content = tokenizer.decode(
+                            _tokens[start : start + max_token_size]
                         )
                         new_chunks.append(
                             (min(max_token_size, len(_tokens) - start), chunk_content)
@@ -90,8 +88,8 @@ def chunking_by_token_size(
         for index, start in enumerate(
             range(0, len(tokens), max_token_size - overlap_token_size)
         ):
-            chunk_content = decode_tokens_by_tiktoken(
-                tokens[start : start + max_token_size], model_name=tiktoken_model
+            chunk_content = tokenizer.decode(
+                tokens[start : start + max_token_size]
             )
             results.append(
                 {
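
For orientation, the hunks above only ever call two methods on the injected object: encode(str) returning a token list and decode(tokens) returning a string. Below is a minimal, hypothetical sketch of a custom tokenizer against that surface; the HFTokenizer name and the Hugging Face wrapper are illustrative assumptions and are not part of this commit, and the real Tokenizer class imported from .utils may expect additional structure (e.g. subclassing).

# Hypothetical sketch, not part of this commit: a custom tokenizer that satisfies
# the two methods exercised by chunking_by_token_size() above.
from transformers import AutoTokenizer  # assumed dependency for this sketch


class HFTokenizer:
    """Duck-typed stand-in for the Tokenizer interface used in this diff."""

    def __init__(self, model_name: str = "bert-base-uncased"):
        self._tok = AutoTokenizer.from_pretrained(model_name)

    def encode(self, content: str) -> list[int]:
        # Plain token ids, so len(...) and slicing behave like the tiktoken output they replace.
        return self._tok.encode(content, add_special_tokens=False)

    def decode(self, tokens: list[int]) -> str:
        return self._tok.decode(tokens)


# Usage against the new signature (tokenizer is now the first positional argument),
# assuming chunking_by_token_size is imported from this module:
# chunks = chunking_by_token_size(HFTokenizer(), content="some long text ...")
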
@@ -116,6 +114,7 @@ async def _handle_entity_relation_summary(
     If too long, use LLM to summarize.
     """
     use_llm_func: callable = global_config["llm_model_func"]
+    tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     tiktoken_model_name = global_config["tiktoken_model_name"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -124,10 +123,12 @@ async def _handle_entity_relation_summary(
         "language", PROMPTS["DEFAULT_LANGUAGE"]
     )
 
-    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
+    tokens = tokenizer.encode(description)
    if len(tokens) < summary_max_tokens:  # No need for summary
        return description
    prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = decode_tokens_by_tiktoken(
-        tokens[:llm_max_tokens], model_name=tiktoken_model_name
+    use_description = tokenizer.decode(
+        tokens[:llm_max_tokens]
     )
     context_base = dict(
         entity_name=entity_or_relation_name,
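
The summary path above boils down to an encode / length-check / decode-prefix pattern. As a reading aid, here is a hedged standalone sketch of that pattern; the helper name is invented for illustration and does not exist in the codebase.

def truncate_to_token_budget(tokenizer, text: str, max_tokens: int) -> str:
    # Mirrors the step above: keep the text unchanged when it fits the budget,
    # otherwise decode only the first max_tokens tokens back into a string.
    tokens = tokenizer.encode(text)
    if len(tokens) < max_tokens:
        return text
    return tokenizer.decode(tokens[:max_tokens])
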
@@ -865,7 +866,8 @@ async def kg_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     response = await use_model_func(
@@ -987,7 +989,8 @@ async def extract_keywords_only(
         query=text, examples=examples, language=language, history=history_context
     )
 
-    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(kw_prompt))
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     # 5. Call the LLM for keyword extraction
@@ -1210,7 +1213,8 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
 
     # 6. Generate response
@@ -1978,7 +1982,8 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
 
     response = await use_model_func(
@@ -2125,7 +2130,8 @@ async def kg_query_with_keywords(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
 
     # 6. Generate response
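
Every query-side hunk above applies the same two-line change: fetch the shared tokenizer from global_config and count prompt tokens with it instead of calling encode_string_by_tiktoken. A hedged consolidation of that repeated pattern could look like the helper below; the function is illustrative only and is not introduced by this commit.

import logging

logger = logging.getLogger(__name__)


def log_prompt_tokens(global_config: dict, prompt: str, tag: str) -> int:
    # Same pattern as the five hunks above: the tokenizer instance is shared via
    # global_config["tokenizer"], and the prompt length is logged in tokens.
    tokenizer = global_config["tokenizer"]
    num_tokens = len(tokenizer.encode(prompt))
    logger.debug(f"[{tag}]Prompt Tokens: {num_tokens}")
    return num_tokens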