add: option to replace the default tiktoken Tokenizer with a custom one

Author: drahnreb
Date:   2025-04-17 10:56:23 +02:00
Parent: 4fd40fd798
Commit: 20ba1eb9c2

6 changed files with 138 additions and 53 deletions

@@ -12,8 +12,7 @@ from .utils import (
     logger,
     clean_str,
     compute_mdhash_id,
-    decode_tokens_by_tiktoken,
-    encode_string_by_tiktoken,
+    Tokenizer,
     is_float_regex,
     list_of_list_to_csv,
     normalize_extracted_info,
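
The two tiktoken-specific helpers are replaced by a single Tokenizer import. Judging by how it is used throughout this file (tokenizer.encode(str) -> list[int], tokenizer.decode(list[int]) -> str), any object exposing that pair of methods should satisfy the contract. A minimal sketch of a custom implementation, assuming that duck-typed interface holds and using a Hugging Face tokenizer as an example backend (neither HFTokenizer nor the model name comes from this commit):

# Sketch only: the Tokenizer class lives in .utils and is not shown in
# this diff, so the exact contract is inferred from the call sites below.
from transformers import AutoTokenizer


class HFTokenizer:
    """Duck-typed tokenizer backed by a Hugging Face vocabulary."""

    def __init__(self, model_name: str = "bert-base-uncased"):
        self._tok = AutoTokenizer.from_pretrained(model_name)

    def encode(self, content: str) -> list[int]:
        # No special tokens: the chunking code slices raw content tokens.
        return self._tok.encode(content, add_special_tokens=False)

    def decode(self, tokens: list[int]) -> str:
        return self._tok.decode(tokens)
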
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
 def chunking_by_token_size(
+    tokenizer: Tokenizer,
     content: str,
     split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
-    tiktoken_model: str = "gpt-4o",
 ) -> list[dict[str, Any]]:
-    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
+    tokens = tokenizer.encode(content)
     results: list[dict[str, Any]] = []
     if split_by_character:
         raw_chunks = content.split(split_by_character)
         new_chunks = []
         if split_by_character_only:
             for chunk in raw_chunks:
-                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                _tokens = tokenizer.encode(chunk)
                 new_chunks.append((len(_tokens), chunk))
         else:
             for chunk in raw_chunks:
-                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                _tokens = tokenizer.encode(chunk)
                 if len(_tokens) > max_token_size:
                     for start in range(
                         0, len(_tokens), max_token_size - overlap_token_size
                     ):
-                        chunk_content = decode_tokens_by_tiktoken(
-                            _tokens[start : start + max_token_size],
-                            model_name=tiktoken_model,
+                        chunk_content = tokenizer.decode(
+                            _tokens[start : start + max_token_size]
                         )
                         new_chunks.append(
                             (min(max_token_size, len(_tokens) - start), chunk_content)
@@ -90,8 +88,8 @@ def chunking_by_token_size(
     for index, start in enumerate(
         range(0, len(tokens), max_token_size - overlap_token_size)
     ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
+        chunk_content = tokenizer.decode(
+            tokens[start : start + max_token_size]
         )
         results.append(
             {
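
With tiktoken_model gone from the signature, callers now inject the tokenizer as the first positional argument. A hypothetical call site, using a thin tiktoken adapter to reproduce the previous default behaviour (TiktokenWrapper is illustrative; the library's actual default class is not shown in this diff):

import tiktoken


class TiktokenWrapper:
    """Adapter giving tiktoken the encode/decode interface expected above."""

    def __init__(self, model_name: str = "gpt-4o"):
        self._enc = tiktoken.encoding_for_model(model_name)

    def encode(self, content: str) -> list[int]:
        return self._enc.encode(content)

    def decode(self, tokens: list[int]) -> str:
        return self._enc.decode(tokens)


chunks = chunking_by_token_size(
    TiktokenWrapper(),  # was: tiktoken_model="gpt-4o"
    content="some long document ...",
    overlap_token_size=128,
    max_token_size=1024,
)
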
@@ -116,6 +114,7 @@ async def _handle_entity_relation_summary(
     If too long, use LLM to summarize.
     """
     use_llm_func: callable = global_config["llm_model_func"]
+    tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     tiktoken_model_name = global_config["tiktoken_model_name"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -124,10 +123,12 @@ async def _handle_entity_relation_summary(
         "language", PROMPTS["DEFAULT_LANGUAGE"]
     )
-    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
+    tokens = tokenizer.encode(description)
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = decode_tokens_by_tiktoken(
-        tokens[:llm_max_tokens], model_name=tiktoken_model_name
+    use_description = tokenizer.decode(
+        tokens[:llm_max_tokens]
     )
     context_base = dict(
         entity_name=entity_or_relation_name,
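
The summary path keeps the same encode, slice, decode pattern as before, just routed through the injected tokenizer: the description is clipped to llm_max_tokens in token space so the prompt stays within the model's limit. The pattern in isolation, as a hypothetical helper that is not part of this commit:

def truncate_by_tokens(tokenizer: Tokenizer, text: str, max_tokens: int) -> str:
    """Clip text to at most max_tokens tokens, decoding only the kept slice."""
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text  # short enough; skip the lossy decode round-trip
    return tokenizer.decode(tokens[:max_tokens])
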
@@ -865,7 +866,8 @@ async def kg_query(
     if query_param.only_need_prompt:
         return sys_prompt
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     response = await use_model_func(
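
The same substitution repeats in the four query paths below (extract_keywords_only, mix_kg_vector_query, naive_query, kg_query_with_keywords): fetch the configured tokenizer from global_config, then count prompt tokens for the debug log. Factored out, it might look like this hypothetical helper (not part of the commit):

def log_prompt_tokens(global_config: dict, label: str, prompt: str) -> int:
    """Count prompt tokens with the configured tokenizer and log the number."""
    tokenizer: Tokenizer = global_config["tokenizer"]
    n = len(tokenizer.encode(prompt))
    logger.debug(f"[{label}]Prompt Tokens: {n}")
    return n
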
@@ -987,7 +989,8 @@ async def extract_keywords_only(
         query=text, examples=examples, language=language, history=history_context
     )
-    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(kw_prompt))
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
@@ -1210,7 +1213,8 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")

     # 6. Generate response
@@ -1978,7 +1982,8 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")

     response = await use_model_func(
@@ -2125,7 +2130,8 @@ async def kg_query_with_keywords(
     if query_param.only_need_prompt:
         return sys_prompt
-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")

     # 6. Generate response
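
This file only consumes global_config["tokenizer"]; the other changed files presumably define the Tokenizer class and wire in a tiktoken-backed default when the caller supplies nothing, which is what the commit message promises. End-to-end usage might then look like the following sketch, where both the LightRAG constructor's tokenizer parameter name and the HFTokenizer class from the earlier example are assumptions rather than confirmed API:

# Assumed wiring, based on the commit message rather than this file's diff:
# a custom tokenizer passed at construction time ends up in
# global_config["tokenizer"]; omitting it keeps the tiktoken default.
rag = LightRAG(
    working_dir="./rag_storage",
    tokenizer=HFTokenizer("bert-base-uncased"),
)
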