add: to optionally replace default tiktoken Tokenizer with a custom one
This commit is contained in:
@@ -1090,7 +1090,8 @@ rag.clear_cache(modes=["local"])
|
|||||||
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
||||||
| **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
|
| **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
|
||||||
| **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
|
| **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
|
||||||
| **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` |
|
| **tokenizer** | `Tokenizer` | 用于将文本转换为 tokens(数字)以及使用遵循 TokenizerInterface 协议的 .encode() 和 .decode() 函数将 tokens 转换回文本的函数。 如果您不指定,它将使用默认的 Tiktoken tokenizer。 | `TiktokenTokenizer` |
|
||||||
|
| **tiktoken_model_name** | `str` | 如果您使用的是默认的 Tiktoken tokenizer,那么这是要使用的特定 Tiktoken 模型的名称。如果您提供自己的 tokenizer,则忽略此设置。 | `gpt-4o-mini` |
|
||||||
| **entity_extract_max_gleaning** | `int` | 实体提取过程中的循环次数,附加历史消息 | `1` |
|
| **entity_extract_max_gleaning** | `int` | 实体提取过程中的循环次数,附加历史消息 | `1` |
|
||||||
| **entity_summary_to_max_tokens** | `int` | 每个实体摘要的最大令牌大小 | `500` |
|
| **entity_summary_to_max_tokens** | `int` | 每个实体摘要的最大令牌大小 | `500` |
|
||||||
| **node_embedding_algorithm** | `str` | 节点嵌入算法(当前未使用) | `node2vec` |
|
| **node_embedding_algorithm** | `str` | 节点嵌入算法(当前未使用) | `node2vec` |
|
||||||
|
@@ -1156,7 +1156,8 @@ Valid modes are:
|
|||||||
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
||||||
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
|
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
|
||||||
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
|
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
|
||||||
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
|
| **tokenizer** | `Tokenizer` | The function used to convert text into tokens (numbers) and back using .encode() and .decode() functions following `TokenizerInterface` protocol. If you don't specify one, it will use the default Tiktoken tokenizer. | `TiktokenTokenizer` |
|
||||||
|
| **tiktoken_model_name** | `str` | If you're using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
|
||||||
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
|
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
|
||||||
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
|
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
|
||||||
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
|
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
|
||||||
|
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
|
|||||||
import asyncio
|
import asyncio
|
||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.utils import encode_string_by_tiktoken
|
from lightrag.utils import TiktokenTokenizer
|
||||||
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ class OllamaTagResponse(BaseModel):
|
|||||||
|
|
||||||
def estimate_tokens(text: str) -> int:
|
def estimate_tokens(text: str) -> int:
|
||||||
"""Estimate the number of tokens in text using tiktoken"""
|
"""Estimate the number of tokens in text using tiktoken"""
|
||||||
tokens = encode_string_by_tiktoken(text)
|
tokens = TiktokenTokenizer().encode(text)
|
||||||
return len(tokens)
|
return len(tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -41,11 +41,12 @@ from .operate import (
|
|||||||
)
|
)
|
||||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
Tokenizer,
|
||||||
|
TiktokenTokenizer,
|
||||||
EmbeddingFunc,
|
EmbeddingFunc,
|
||||||
always_get_an_event_loop,
|
always_get_an_event_loop,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
convert_response_to_json,
|
convert_response_to_json,
|
||||||
encode_string_by_tiktoken,
|
|
||||||
lazy_external_import,
|
lazy_external_import,
|
||||||
limit_async_func_call,
|
limit_async_func_call,
|
||||||
get_content_summary,
|
get_content_summary,
|
||||||
@@ -122,33 +123,38 @@ class LightRAG:
|
|||||||
)
|
)
|
||||||
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
|
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
|
||||||
|
|
||||||
tiktoken_model_name: str = field(default="gpt-4o-mini")
|
tokenizer: Optional[Tokenizer] = field(default=None)
|
||||||
"""Model name used for tokenization when chunking text."""
|
"""
|
||||||
|
A function that returns a Tokenizer instance.
|
||||||
|
If None, and a `tiktoken_model_name` is provided, a TiktokenTokenizer will be created.
|
||||||
|
If both are None, the default TiktokenTokenizer is used.
|
||||||
|
"""
|
||||||
|
|
||||||
"""Maximum number of tokens used for summarizing extracted entities."""
|
tiktoken_model_name: str = field(default="gpt-4o-mini")
|
||||||
|
"""Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""
|
||||||
|
|
||||||
chunking_func: Callable[
|
chunking_func: Callable[
|
||||||
[
|
[
|
||||||
|
Tokenizer,
|
||||||
str,
|
str,
|
||||||
str | None,
|
Optional[str],
|
||||||
bool,
|
bool,
|
||||||
int,
|
int,
|
||||||
int,
|
int,
|
||||||
str,
|
|
||||||
],
|
],
|
||||||
list[dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
] = field(default_factory=lambda: chunking_by_token_size)
|
] = field(default_factory=lambda: chunking_by_token_size)
|
||||||
"""
|
"""
|
||||||
Custom chunking function for splitting text into chunks before processing.
|
Custom chunking function for splitting text into chunks before processing.
|
||||||
|
|
||||||
The function should take the following parameters:
|
The function should take the following parameters:
|
||||||
|
|
||||||
|
- `tokenizer`: A Tokenizer instance to use for tokenization.
|
||||||
- `content`: The text to be split into chunks.
|
- `content`: The text to be split into chunks.
|
||||||
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
|
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
|
||||||
- `split_by_character_only`: If True, the text is split only on the specified character.
|
- `split_by_character_only`: If True, the text is split only on the specified character.
|
||||||
- `chunk_token_size`: The maximum number of tokens per chunk.
|
- `chunk_token_size`: The maximum number of tokens per chunk.
|
||||||
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
||||||
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
|
|
||||||
|
|
||||||
The function should return a list of dictionaries, where each dictionary contains the following keys:
|
The function should return a list of dictionaries, where each dictionary contains the following keys:
|
||||||
- `tokens`: The number of tokens in the chunk.
|
- `tokens`: The number of tokens in the chunk.
|
||||||
@@ -310,7 +316,15 @@ class LightRAG:
|
|||||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
||||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||||
|
|
||||||
# Init LLM
|
# Init Tokenizer
|
||||||
|
# Post-initialization hook to handle backward compatabile tokenizer initialization based on provided parameters
|
||||||
|
if self.tokenizer is None:
|
||||||
|
if self.tiktoken_model_name:
|
||||||
|
self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
|
||||||
|
else:
|
||||||
|
self.tokenizer = TiktokenTokenizer()
|
||||||
|
|
||||||
|
# Init Embedding
|
||||||
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
||||||
self.embedding_func
|
self.embedding_func
|
||||||
)
|
)
|
||||||
@@ -900,12 +914,12 @@ class LightRAG:
|
|||||||
"file_path": file_path, # Add file path to each chunk
|
"file_path": file_path, # Add file path to each chunk
|
||||||
}
|
}
|
||||||
for dp in self.chunking_func(
|
for dp in self.chunking_func(
|
||||||
|
self.tokenizer,
|
||||||
status_doc.content,
|
status_doc.content,
|
||||||
split_by_character,
|
split_by_character,
|
||||||
split_by_character_only,
|
split_by_character_only,
|
||||||
self.chunk_overlap_token_size,
|
self.chunk_overlap_token_size,
|
||||||
self.chunk_token_size,
|
self.chunk_token_size,
|
||||||
self.tiktoken_model_name,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1134,8 +1148,8 @@ class LightRAG:
|
|||||||
chunk_content = clean_text(chunk_data["content"])
|
chunk_content = clean_text(chunk_data["content"])
|
||||||
source_id = chunk_data["source_id"]
|
source_id = chunk_data["source_id"]
|
||||||
tokens = len(
|
tokens = len(
|
||||||
encode_string_by_tiktoken(
|
self.tokenizer.encode(
|
||||||
chunk_content, model_name=self.tiktoken_model_name
|
chunk_content
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
chunk_order_index = (
|
chunk_order_index = (
|
||||||
|
@@ -12,8 +12,7 @@ from .utils import (
|
|||||||
logger,
|
logger,
|
||||||
clean_str,
|
clean_str,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
decode_tokens_by_tiktoken,
|
Tokenizer,
|
||||||
encode_string_by_tiktoken,
|
|
||||||
is_float_regex,
|
is_float_regex,
|
||||||
list_of_list_to_csv,
|
list_of_list_to_csv,
|
||||||
normalize_extracted_info,
|
normalize_extracted_info,
|
||||||
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
|
|||||||
|
|
||||||
|
|
||||||
def chunking_by_token_size(
|
def chunking_by_token_size(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
content: str,
|
content: str,
|
||||||
split_by_character: str | None = None,
|
split_by_character: str | None = None,
|
||||||
split_by_character_only: bool = False,
|
split_by_character_only: bool = False,
|
||||||
overlap_token_size: int = 128,
|
overlap_token_size: int = 128,
|
||||||
max_token_size: int = 1024,
|
max_token_size: int = 1024,
|
||||||
tiktoken_model: str = "gpt-4o",
|
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
|
tokens = tokenizer.encode(content)
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
if split_by_character:
|
if split_by_character:
|
||||||
raw_chunks = content.split(split_by_character)
|
raw_chunks = content.split(split_by_character)
|
||||||
new_chunks = []
|
new_chunks = []
|
||||||
if split_by_character_only:
|
if split_by_character_only:
|
||||||
for chunk in raw_chunks:
|
for chunk in raw_chunks:
|
||||||
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
|
_tokens = tokenizer.encode(chunk)
|
||||||
new_chunks.append((len(_tokens), chunk))
|
new_chunks.append((len(_tokens), chunk))
|
||||||
else:
|
else:
|
||||||
for chunk in raw_chunks:
|
for chunk in raw_chunks:
|
||||||
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
|
_tokens = tokenizer.encode(chunk)
|
||||||
if len(_tokens) > max_token_size:
|
if len(_tokens) > max_token_size:
|
||||||
for start in range(
|
for start in range(
|
||||||
0, len(_tokens), max_token_size - overlap_token_size
|
0, len(_tokens), max_token_size - overlap_token_size
|
||||||
):
|
):
|
||||||
chunk_content = decode_tokens_by_tiktoken(
|
chunk_content = tokenizer.decode(
|
||||||
_tokens[start : start + max_token_size],
|
_tokens[start : start + max_token_size]
|
||||||
model_name=tiktoken_model,
|
|
||||||
)
|
)
|
||||||
new_chunks.append(
|
new_chunks.append(
|
||||||
(min(max_token_size, len(_tokens) - start), chunk_content)
|
(min(max_token_size, len(_tokens) - start), chunk_content)
|
||||||
@@ -90,8 +88,8 @@ def chunking_by_token_size(
|
|||||||
for index, start in enumerate(
|
for index, start in enumerate(
|
||||||
range(0, len(tokens), max_token_size - overlap_token_size)
|
range(0, len(tokens), max_token_size - overlap_token_size)
|
||||||
):
|
):
|
||||||
chunk_content = decode_tokens_by_tiktoken(
|
chunk_content = tokenizer.decode(
|
||||||
tokens[start : start + max_token_size], model_name=tiktoken_model
|
tokens[start : start + max_token_size]
|
||||||
)
|
)
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
@@ -116,6 +114,7 @@ async def _handle_entity_relation_summary(
|
|||||||
If too long, use LLM to summarize.
|
If too long, use LLM to summarize.
|
||||||
"""
|
"""
|
||||||
use_llm_func: callable = global_config["llm_model_func"]
|
use_llm_func: callable = global_config["llm_model_func"]
|
||||||
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
llm_max_tokens = global_config["llm_model_max_token_size"]
|
llm_max_tokens = global_config["llm_model_max_token_size"]
|
||||||
tiktoken_model_name = global_config["tiktoken_model_name"]
|
tiktoken_model_name = global_config["tiktoken_model_name"]
|
||||||
summary_max_tokens = global_config["summary_to_max_tokens"]
|
summary_max_tokens = global_config["summary_to_max_tokens"]
|
||||||
@@ -124,10 +123,12 @@ async def _handle_entity_relation_summary(
|
|||||||
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
|
tokens = tokenizer.encode(description)
|
||||||
|
if len(tokens) < summary_max_tokens: # No need for summary
|
||||||
|
return description
|
||||||
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
||||||
use_description = decode_tokens_by_tiktoken(
|
use_description = tokenizer.decode(
|
||||||
tokens[:llm_max_tokens], model_name=tiktoken_model_name
|
tokens[:llm_max_tokens]
|
||||||
)
|
)
|
||||||
context_base = dict(
|
context_base = dict(
|
||||||
entity_name=entity_or_relation_name,
|
entity_name=entity_or_relation_name,
|
||||||
@@ -865,7 +866,8 @@ async def kg_query(
|
|||||||
if query_param.only_need_prompt:
|
if query_param.only_need_prompt:
|
||||||
return sys_prompt
|
return sys_prompt
|
||||||
|
|
||||||
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
|
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||||
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
||||||
|
|
||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
@@ -987,7 +989,8 @@ async def extract_keywords_only(
|
|||||||
query=text, examples=examples, language=language, history=history_context
|
query=text, examples=examples, language=language, history=history_context
|
||||||
)
|
)
|
||||||
|
|
||||||
len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
|
len_of_prompts = len(tokenizer.encode(kw_prompt))
|
||||||
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
||||||
|
|
||||||
# 5. Call the LLM for keyword extraction
|
# 5. Call the LLM for keyword extraction
|
||||||
@@ -1210,7 +1213,8 @@ async def mix_kg_vector_query(
|
|||||||
if query_param.only_need_prompt:
|
if query_param.only_need_prompt:
|
||||||
return sys_prompt
|
return sys_prompt
|
||||||
|
|
||||||
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
|
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||||
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
|
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
|
||||||
|
|
||||||
# 6. Generate response
|
# 6. Generate response
|
||||||
@@ -1978,7 +1982,8 @@ async def naive_query(
|
|||||||
if query_param.only_need_prompt:
|
if query_param.only_need_prompt:
|
||||||
return sys_prompt
|
return sys_prompt
|
||||||
|
|
||||||
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
|
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||||
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
|
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
|
||||||
|
|
||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
@@ -2125,7 +2130,8 @@ async def kg_query_with_keywords(
|
|||||||
if query_param.only_need_prompt:
|
if query_param.only_need_prompt:
|
||||||
return sys_prompt
|
return sys_prompt
|
||||||
|
|
||||||
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
|
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||||
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
|
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
|
||||||
|
|
||||||
# 6. Generate response
|
# 6. Generate response
|
||||||
|
@@ -12,10 +12,9 @@ import re
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Callable, TYPE_CHECKING
|
from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
|
||||||
from lightrag.prompt import PROMPTS
|
from lightrag.prompt import PROMPTS
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
ENCODER = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
@@ -311,20 +307,87 @@ def write_json(json_obj, file_name):
|
|||||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
class TokenizerInterface(Protocol):
|
||||||
global ENCODER
|
"""
|
||||||
if ENCODER is None:
|
Defines the interface for a tokenizer, requiring encode and decode methods.
|
||||||
ENCODER = tiktoken.encoding_for_model(model_name)
|
"""
|
||||||
tokens = ENCODER.encode(content)
|
def encode(self, content: str) -> List[int]:
|
||||||
return tokens
|
"""Encodes a string into a list of tokens."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def decode(self, tokens: List[int]) -> str:
|
||||||
|
"""Decodes a list of tokens into a string."""
|
||||||
|
...
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
"""
|
||||||
|
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
|
||||||
|
"""
|
||||||
|
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
|
||||||
|
"""
|
||||||
|
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: The associated model name for the tokenizer.
|
||||||
|
tokenizer: An instance of a class implementing the TokenizerInterface.
|
||||||
|
"""
|
||||||
|
self.model_name: str = model_name
|
||||||
|
self.tokenizer: TokenizerInterface = tokenizer
|
||||||
|
|
||||||
|
def encode(self, content: str) -> List[int]:
|
||||||
|
"""
|
||||||
|
Encodes a string into a list of tokens using the underlying tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to encode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of integer tokens.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.encode(content)
|
||||||
|
|
||||||
|
def decode(self, tokens: List[int]) -> str:
|
||||||
|
"""
|
||||||
|
Decodes a list of tokens into a string using the underlying tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: A list of integer tokens to decode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decoded string.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.decode(tokens)
|
||||||
|
|
||||||
|
|
||||||
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
class TiktokenTokenizer(Tokenizer):
|
||||||
global ENCODER
|
"""
|
||||||
if ENCODER is None:
|
A Tokenizer implementation using the tiktoken library.
|
||||||
ENCODER = tiktoken.encoding_for_model(model_name)
|
"""
|
||||||
content = ENCODER.decode(tokens)
|
def __init__(self, model_name: str = "gpt-4o-mini"):
|
||||||
return content
|
"""
|
||||||
|
Initializes the TiktokenTokenizer with a specified model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If tiktoken is not installed.
|
||||||
|
ValueError: If the model_name is invalid.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tokenizer = tiktoken.encoding_for_model(model_name)
|
||||||
|
super().__init__(model_name=model_name, tokenizer=tokenizer)
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid model_name: {model_name}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pack_user_ass_to_openai_messages(*args: str):
|
def pack_user_ass_to_openai_messages(*args: str):
|
||||||
@@ -368,7 +431,7 @@ def truncate_list_by_token_size(
|
|||||||
return []
|
return []
|
||||||
tokens = 0
|
tokens = 0
|
||||||
for i, data in enumerate(list_data):
|
for i, data in enumerate(list_data):
|
||||||
tokens += len(encode_string_by_tiktoken(key(data)))
|
tokens += len(tokenizer.encode(key(data)))
|
||||||
if tokens > max_token_size:
|
if tokens > max_token_size:
|
||||||
return list_data[:i]
|
return list_data[:i]
|
||||||
return list_data
|
return list_data
|
||||||
|
Reference in New Issue
Block a user