add: option to replace the default tiktoken Tokenizer with a custom one

drahnreb
2025-04-17 10:56:23 +02:00
parent 4fd40fd798
commit 20ba1eb9c2
6 changed files with 138 additions and 53 deletions

View File

@@ -1090,7 +1090,8 @@ rag.clear_cache(modes=["local"])
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
-| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
+| **tokenizer** | `Tokenizer` | The function used to convert text into tokens (numbers) and back via `.encode()` and `.decode()` functions following the `TokenizerInterface` protocol. If you don't specify one, the default Tiktoken tokenizer is used. | `TiktokenTokenizer` |
+| **tiktoken_model_name** | `str` | If you are using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |

View File

@@ -1156,7 +1156,8 @@ Valid modes are:
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
-| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
+| **tokenizer** | `Tokenizer` | The function used to convert text into tokens (numbers) and back using `.encode()` and `.decode()` functions following the `TokenizerInterface` protocol. If you don't specify one, it will use the default Tiktoken tokenizer. | `TiktokenTokenizer` |
+| **tiktoken_model_name** | `str` | If you're using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
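For illustration, a minimal sketch of how the `tokenizer` and `tiktoken_model_name` options interact when constructing `LightRAG`. Other constructor arguments (the LLM and embedding functions, storage settings) are omitted; the `working_dir` value is illustrative only:

```python
from lightrag import LightRAG
from lightrag.utils import TiktokenTokenizer

# Default: no tokenizer given, so a TiktokenTokenizer is built from tiktoken_model_name.
rag_default = LightRAG(
    working_dir="./rag_storage",  # illustrative value
    tiktoken_model_name="gpt-4o-mini",
)

# Custom: an explicit tokenizer takes precedence; tiktoken_model_name is then ignored.
rag_custom = LightRAG(
    working_dir="./rag_storage",
    tokenizer=TiktokenTokenizer(model_name="gpt-4o"),
)
```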

View File

@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
-from lightrag.utils import encode_string_by_tiktoken
+from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
from fastapi import Depends

@@ -97,7 +97,7 @@ class OllamaTagResponse(BaseModel):

def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text using tiktoken"""
-    tokens = encode_string_by_tiktoken(text)
+    tokens = TiktokenTokenizer().encode(text)
    return len(tokens)

View File

@@ -41,11 +41,12 @@ from .operate import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
from .utils import (
+    Tokenizer,
+    TiktokenTokenizer,
    EmbeddingFunc,
    always_get_an_event_loop,
    compute_mdhash_id,
    convert_response_to_json,
-    encode_string_by_tiktoken,
    lazy_external_import,
    limit_async_func_call,
    get_content_summary,
@@ -122,33 +123,38 @@ class LightRAG:
    )
    """Number of overlapping tokens between consecutive text chunks to preserve context."""

-    tiktoken_model_name: str = field(default="gpt-4o-mini")
-    """Model name used for tokenization when chunking text."""
-
-    """Maximum number of tokens used for summarizing extracted entities."""
+    tokenizer: Optional[Tokenizer] = field(default=None)
+    """
+    A Tokenizer instance used to encode text into tokens and decode them back.
+    If None and a `tiktoken_model_name` is provided, a TiktokenTokenizer is created from that model name;
+    otherwise the default TiktokenTokenizer is used.
+    """
+
+    tiktoken_model_name: str = field(default="gpt-4o-mini")
+    """Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""

    chunking_func: Callable[
        [
+            Tokenizer,
            str,
-            str | None,
+            Optional[str],
            bool,
            int,
            int,
-            str,
        ],
-        list[dict[str, Any]],
+        List[Dict[str, Any]],
    ] = field(default_factory=lambda: chunking_by_token_size)
    """
    Custom chunking function for splitting text into chunks before processing.

    The function should take the following parameters:

+    - `tokenizer`: A Tokenizer instance to use for tokenization.
    - `content`: The text to be split into chunks.
    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
    - `split_by_character_only`: If True, the text is split only on the specified character.
    - `chunk_token_size`: The maximum number of tokens per chunk.
    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
-    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.

    The function should return a list of dictionaries, where each dictionary contains the following keys:
    - `tokens`: The number of tokens in the chunk.
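To make the updated signature concrete, here is a hedged sketch of a custom chunking function that splits on blank lines. The `content` and `chunk_order_index` keys mirror the default `chunking_by_token_size` output, and the unused parameters are kept only to satisfy the expected call signature:

```python
from typing import Any, Dict, List, Optional

from lightrag.utils import Tokenizer


def chunk_by_paragraph(
    tokenizer: Tokenizer,
    content: str,
    split_by_character: Optional[str] = None,
    split_by_character_only: bool = False,
    chunk_overlap_token_size: int = 100,
    chunk_token_size: int = 1200,
) -> List[Dict[str, Any]]:
    """Illustrative chunker: one chunk per non-empty paragraph."""
    chunks: List[Dict[str, Any]] = []
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    for index, paragraph in enumerate(paragraphs):
        chunks.append(
            {
                "tokens": len(tokenizer.encode(paragraph)),
                "content": paragraph,
                "chunk_order_index": index,
            }
        )
    return chunks


# rag = LightRAG(..., chunking_func=chunk_by_paragraph)
```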
@@ -310,7 +316,15 @@ class LightRAG:
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

-        # Init LLM
+        # Init Tokenizer
+        # Post-initialization hook to handle backward-compatible tokenizer initialization based on provided parameters
+        if self.tokenizer is None:
+            if self.tiktoken_model_name:
+                self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
+            else:
+                self.tokenizer = TiktokenTokenizer()
+
+        # Init Embedding
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
            self.embedding_func
        )
@@ -900,12 +914,12 @@ class LightRAG:
"file_path": file_path, # Add file path to each chunk "file_path": file_path, # Add file path to each chunk
} }
for dp in self.chunking_func( for dp in self.chunking_func(
self.tokenizer,
status_doc.content, status_doc.content,
split_by_character, split_by_character,
split_by_character_only, split_by_character_only,
self.chunk_overlap_token_size, self.chunk_overlap_token_size,
self.chunk_token_size, self.chunk_token_size,
self.tiktoken_model_name,
) )
} }
@@ -1134,8 +1148,8 @@ class LightRAG:
            chunk_content = clean_text(chunk_data["content"])
            source_id = chunk_data["source_id"]
            tokens = len(
-                encode_string_by_tiktoken(
-                    chunk_content, model_name=self.tiktoken_model_name
+                self.tokenizer.encode(
+                    chunk_content
                )
            )
            chunk_order_index = (

View File

@@ -12,8 +12,7 @@ from .utils import (
    logger,
    clean_str,
    compute_mdhash_id,
-    decode_tokens_by_tiktoken,
-    encode_string_by_tiktoken,
+    Tokenizer,
    is_float_regex,
    list_of_list_to_csv,
    normalize_extracted_info,
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
def chunking_by_token_size(
+    tokenizer: Tokenizer,
    content: str,
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    overlap_token_size: int = 128,
    max_token_size: int = 1024,
-    tiktoken_model: str = "gpt-4o",
) -> list[dict[str, Any]]:
-    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
+    tokens = tokenizer.encode(content)
    results: list[dict[str, Any]] = []
    if split_by_character:
        raw_chunks = content.split(split_by_character)
        new_chunks = []
        if split_by_character_only:
            for chunk in raw_chunks:
-                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                _tokens = tokenizer.encode(chunk)
                new_chunks.append((len(_tokens), chunk))
        else:
            for chunk in raw_chunks:
-                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                _tokens = tokenizer.encode(chunk)
                if len(_tokens) > max_token_size:
                    for start in range(
                        0, len(_tokens), max_token_size - overlap_token_size
                    ):
-                        chunk_content = decode_tokens_by_tiktoken(
-                            _tokens[start : start + max_token_size],
-                            model_name=tiktoken_model,
+                        chunk_content = tokenizer.decode(
+                            _tokens[start : start + max_token_size]
                        )
                        new_chunks.append(
                            (min(max_token_size, len(_tokens) - start), chunk_content)
@@ -90,8 +88,8 @@ def chunking_by_token_size(
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
+        chunk_content = tokenizer.decode(
+            tokens[start : start + max_token_size]
        )
        results.append(
            {
@@ -116,6 +114,7 @@ async def _handle_entity_relation_summary(
    If too long, use LLM to summarize.
    """
    use_llm_func: callable = global_config["llm_model_func"]
+    tokenizer: Tokenizer = global_config["tokenizer"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -124,10 +123,12 @@ async def _handle_entity_relation_summary(
"language", PROMPTS["DEFAULT_LANGUAGE"] "language", PROMPTS["DEFAULT_LANGUAGE"]
) )
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) tokens = tokenizer.encode(description)
if len(tokens) < summary_max_tokens: # No need for summary
return description
prompt_template = PROMPTS["summarize_entity_descriptions"] prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = decode_tokens_by_tiktoken( use_description = tokenizer.decode(
tokens[:llm_max_tokens], model_name=tiktoken_model_name tokens[:llm_max_tokens]
) )
context_base = dict( context_base = dict(
entity_name=entity_or_relation_name, entity_name=entity_or_relation_name,
@@ -865,7 +866,8 @@ async def kg_query(
    if query_param.only_need_prompt:
        return sys_prompt

-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

    response = await use_model_func(
@@ -987,7 +989,8 @@ async def extract_keywords_only(
        query=text, examples=examples, language=language, history=history_context
    )

-    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(kw_prompt))
    logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

    # 5. Call the LLM for keyword extraction
@@ -1210,7 +1213,8 @@ async def mix_kg_vector_query(
    if query_param.only_need_prompt:
        return sys_prompt

-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")

    # 6. Generate response
@@ -1978,7 +1982,8 @@ async def naive_query(
    if query_param.only_need_prompt:
        return sys_prompt

-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")

    response = await use_model_func(
@@ -2125,7 +2130,8 @@ async def kg_query_with_keywords(
    if query_param.only_need_prompt:
        return sys_prompt

-    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    tokenizer: Tokenizer = global_config["tokenizer"]
+    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")

    # 6. Generate response
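To tie the changes in this file together, a quick usage sketch of the refactored `chunking_by_token_size` shown above. The module path `lightrag.operate`, the sample text, and the printed key are assumptions based on the default implementation; tiktoken must be installed:

```python
from lightrag.operate import chunking_by_token_size
from lightrag.utils import TiktokenTokenizer

tokenizer = TiktokenTokenizer(model_name="gpt-4o-mini")
chunks = chunking_by_token_size(
    tokenizer,
    "Long document text goes here ...",
    split_by_character=None,
    overlap_token_size=128,
    max_token_size=1024,
)
for chunk in chunks:
    print(chunk["tokens"])
```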

View File

@@ -12,10 +12,9 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
import xml.etree.ElementTree as ET
import numpy as np
-import tiktoken

from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
        pass


-ENCODER = None
-
-
@dataclass
class EmbeddingFunc:
    embedding_dim: int
@@ -311,20 +307,87 @@ def write_json(json_obj, file_name):
        json.dump(json_obj, f, indent=2, ensure_ascii=False)


-def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
-    global ENCODER
-    if ENCODER is None:
-        ENCODER = tiktoken.encoding_for_model(model_name)
-    tokens = ENCODER.encode(content)
-    return tokens
-
-
-def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
-    global ENCODER
-    if ENCODER is None:
-        ENCODER = tiktoken.encoding_for_model(model_name)
-    content = ENCODER.decode(tokens)
-    return content
+class TokenizerInterface(Protocol):
+    """
+    Defines the interface for a tokenizer, requiring encode and decode methods.
+    """
+
+    def encode(self, content: str) -> List[int]:
+        """Encodes a string into a list of tokens."""
+        ...
+
+    def decode(self, tokens: List[int]) -> str:
+        """Decodes a list of tokens into a string."""
+        ...
+
+
+class Tokenizer:
+    """
+    A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
+    """
+
+    def __init__(self, model_name: str, tokenizer: TokenizerInterface):
+        """
+        Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
+
+        Args:
+            model_name: The associated model name for the tokenizer.
+            tokenizer: An instance of a class implementing the TokenizerInterface.
+        """
+        self.model_name: str = model_name
+        self.tokenizer: TokenizerInterface = tokenizer
+
+    def encode(self, content: str) -> List[int]:
+        """
+        Encodes a string into a list of tokens using the underlying tokenizer.
+
+        Args:
+            content: The string to encode.
+
+        Returns:
+            A list of integer tokens.
+        """
+        return self.tokenizer.encode(content)
+
+    def decode(self, tokens: List[int]) -> str:
+        """
+        Decodes a list of tokens into a string using the underlying tokenizer.
+
+        Args:
+            tokens: A list of integer tokens to decode.
+
+        Returns:
+            The decoded string.
+        """
+        return self.tokenizer.decode(tokens)
+
+
+class TiktokenTokenizer(Tokenizer):
+    """
+    A Tokenizer implementation using the tiktoken library.
+    """
+
+    def __init__(self, model_name: str = "gpt-4o-mini"):
+        """
+        Initializes the TiktokenTokenizer with a specified model name.
+
+        Args:
+            model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
+
+        Raises:
+            ImportError: If tiktoken is not installed.
+            ValueError: If the model_name is invalid.
+        """
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "tiktoken is not installed. Please install it with `pip install tiktoken`, or provide a custom `tokenizer`."
+            )
+
+        try:
+            tokenizer = tiktoken.encoding_for_model(model_name)
+            super().__init__(model_name=model_name, tokenizer=tokenizer)
+        except KeyError:
+            raise ValueError(
+                f"Invalid model_name: {model_name}."
+            )


def pack_user_ass_to_openai_messages(*args: str):
@@ -368,7 +431,7 @@ def truncate_list_by_token_size(
        return []
    tokens = 0
    for i, data in enumerate(list_data):
-        tokens += len(encode_string_by_tiktoken(key(data)))
+        tokens += len(tokenizer.encode(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data
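As a closing illustration, any object exposing `encode`/`decode` can be plugged into the new `Tokenizer` wrapper above. The sketch below wraps a Hugging Face tokenizer; the `transformers` dependency and model name are illustrative assumptions, not part of this change:

```python
from transformers import AutoTokenizer  # assumed extra dependency, not bundled with LightRAG

from lightrag.utils import Tokenizer


class HFTokenizer(Tokenizer):
    """Wraps a Hugging Face tokenizer so it satisfies TokenizerInterface."""

    def __init__(self, model_name: str = "bert-base-uncased"):
        hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # AutoTokenizer already exposes encode() -> list[int] and decode() -> str.
        super().__init__(model_name=model_name, tokenizer=hf_tokenizer)


# rag = LightRAG(..., tokenizer=HFTokenizer("bert-base-uncased"))
```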