add: option to replace the default tiktoken Tokenizer with a custom one
@@ -12,10 +12,9 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
 import xml.etree.ElementTree as ET
 import numpy as np
-import tiktoken
 from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv
 
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
         pass
 
 
-ENCODER = None
-
-
 @dataclass
 class EmbeddingFunc:
     embedding_dim: int
@@ -311,20 +307,87 @@ def write_json(json_obj, file_name):
         json.dump(json_obj, f, indent=2, ensure_ascii=False)
 
 
-def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
-    global ENCODER
-    if ENCODER is None:
-        ENCODER = tiktoken.encoding_for_model(model_name)
-    tokens = ENCODER.encode(content)
-    return tokens
+class TokenizerInterface(Protocol):
+    """
+    Defines the interface for a tokenizer, requiring encode and decode methods.
+    """
+
+    def encode(self, content: str) -> List[int]:
+        """Encodes a string into a list of tokens."""
+        ...
+
+    def decode(self, tokens: List[int]) -> str:
+        """Decodes a list of tokens into a string."""
+        ...
+
+
+class Tokenizer:
+    """
+    A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
+    """
+
+    def __init__(self, model_name: str, tokenizer: TokenizerInterface):
+        """
+        Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
+
+        Args:
+            model_name: The associated model name for the tokenizer.
+            tokenizer: An instance of a class implementing the TokenizerInterface.
+        """
+        self.model_name: str = model_name
+        self.tokenizer: TokenizerInterface = tokenizer
+
+    def encode(self, content: str) -> List[int]:
+        """
+        Encodes a string into a list of tokens using the underlying tokenizer.
+
+        Args:
+            content: The string to encode.
+
+        Returns:
+            A list of integer tokens.
+        """
+        return self.tokenizer.encode(content)
+
+    def decode(self, tokens: List[int]) -> str:
+        """
+        Decodes a list of tokens into a string using the underlying tokenizer.
+
+        Args:
+            tokens: A list of integer tokens to decode.
+
+        Returns:
+            The decoded string.
+        """
+        return self.tokenizer.decode(tokens)
 
 
-def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
-    global ENCODER
-    if ENCODER is None:
-        ENCODER = tiktoken.encoding_for_model(model_name)
-    content = ENCODER.decode(tokens)
-    return content
+class TiktokenTokenizer(Tokenizer):
+    """
+    A Tokenizer implementation using the tiktoken library.
+    """
+
+    def __init__(self, model_name: str = "gpt-4o-mini"):
+        """
+        Initializes the TiktokenTokenizer with a specified model name.
+
+        Args:
+            model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
+
+        Raises:
+            ImportError: If tiktoken is not installed.
+            ValueError: If the model_name is invalid.
+        """
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
+            )
+
+        try:
+            tokenizer = tiktoken.encoding_for_model(model_name)
+            super().__init__(model_name=model_name, tokenizer=tokenizer)
+        except KeyError:
+            raise ValueError(
+                f"Invalid model_name: {model_name}."
+            )
 
 
 def pack_user_ass_to_openai_messages(*args: str):
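Note: because TokenizerInterface is a structural Protocol, any object exposing compatible encode/decode methods can back a Tokenizer. The sketch below is illustrative only and is not part of this commit; it assumes these classes live in lightrag.utils and that the Hugging Face transformers package is installed.

# Illustrative sketch (not part of this diff): wrap a Hugging Face tokenizer
# so it can stand in for the default tiktoken-based one.
from transformers import AutoTokenizer  # assumption: transformers is installed

from lightrag.utils import Tokenizer  # assumption: module path of the file changed above


class HFTokenizer(Tokenizer):
    def __init__(self, model_name: str = "bert-base-uncased"):
        hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # AutoTokenizer already provides encode() -> list[int] and
        # decode() -> str, so it satisfies TokenizerInterface structurally.
        super().__init__(model_name=model_name, tokenizer=hf_tokenizer)


custom = HFTokenizer()
tokens = custom.encode("hello world")  # list[int]
text = custom.decode(tokens)           # str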
@@ -368,7 +431,7 @@ def truncate_list_by_token_size(
         return []
     tokens = 0
     for i, data in enumerate(list_data):
-        tokens += len(encode_string_by_tiktoken(key(data)))
+        tokens += len(tokenizer.encode(key(data)))
         if tokens > max_token_size:
             return list_data[:i]
     return list_data
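For context, the change above switches truncate_list_by_token_size from the removed module-level tiktoken helper to a tokenizer instance. A minimal usage sketch follows; it assumes the function's updated signature (not visible in this hunk) accepts the tokenizer as a keyword argument.

# Minimal sketch, assuming truncate_list_by_token_size now takes a `tokenizer`
# keyword (the signature change itself is outside the hunk shown above).
from lightrag.utils import TiktokenTokenizer, truncate_list_by_token_size

tokenizer = TiktokenTokenizer(model_name="gpt-4o-mini")
chunks = [{"content": "first chunk"}, {"content": "a much longer second chunk"}]
kept = truncate_list_by_token_size(
    chunks,
    key=lambda item: item["content"],  # extract the text whose tokens are counted
    max_token_size=32,
    tokenizer=tokenizer,               # assumption: new parameter added by this change
)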