add: to optionally replace default tiktoken Tokenizer with a custom one

This commit is contained in:
drahnreb
2025-04-17 10:56:23 +02:00
parent 4fd40fd798
commit 20ba1eb9c2
6 changed files with 138 additions and 53 deletions

View File

@@ -12,10 +12,9 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
pass
ENCODER = None
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -311,20 +307,87 @@ def write_json(json_obj, file_name):
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
class TokenizerInterface(Protocol):
"""
Defines the interface for a tokenizer, requiring encode and decode methods.
"""
def encode(self, content: str) -> List[int]:
"""Encodes a string into a list of tokens."""
...
def decode(self, tokens: List[int]) -> str:
"""Decodes a list of tokens into a string."""
...
class Tokenizer:
"""
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
"""
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
"""
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
Args:
model_name: The associated model name for the tokenizer.
tokenizer: An instance of a class implementing the TokenizerInterface.
"""
self.model_name: str = model_name
self.tokenizer: TokenizerInterface = tokenizer
def encode(self, content: str) -> List[int]:
"""
Encodes a string into a list of tokens using the underlying tokenizer.
Args:
content: The string to encode.
Returns:
A list of integer tokens.
"""
return self.tokenizer.encode(content)
def decode(self, tokens: List[int]) -> str:
"""
Decodes a list of tokens into a string using the underlying tokenizer.
Args:
tokens: A list of integer tokens to decode.
Returns:
The decoded string.
"""
return self.tokenizer.decode(tokens)
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
class TiktokenTokenizer(Tokenizer):
"""
A Tokenizer implementation using the tiktoken library.
"""
def __init__(self, model_name: str = "gpt-4o-mini"):
"""
Initializes the TiktokenTokenizer with a specified model name.
Args:
model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
Raises:
ImportError: If tiktoken is not installed.
ValueError: If the model_name is invalid.
"""
try:
import tiktoken
except ImportError:
raise ImportError(
"tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
)
try:
tokenizer = tiktoken.encoding_for_model(model_name)
super().__init__(model_name=model_name, tokenizer=tokenizer)
except KeyError:
raise ValueError(
f"Invalid model_name: {model_name}."
)
def pack_user_ass_to_openai_messages(*args: str):
@@ -368,7 +431,7 @@ def truncate_list_by_token_size(
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
tokens += len(tokenizer.encode(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data