Merge branch 'drahnreb/add-custom-tokenizer'

Author: yangdx
Date: 2025-04-20 12:22:10 +08:00
7 changed files with 413 additions and 71 deletions

View File

@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken
from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
from fastapi import Depends
@@ -97,7 +97,7 @@ class OllamaTagResponse(BaseModel):
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text using tiktoken"""
tokens = encode_string_by_tiktoken(text)
tokens = TiktokenTokenizer().encode(text)
return len(tokens)
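For reference, the replacement call can be exercised on its own; a minimal sketch, assuming `tiktoken` is installed and `TiktokenTokenizer` is importable from `lightrag.utils` as above:

from lightrag.utils import TiktokenTokenizer

# One-off token count, mirroring the updated estimate_tokens() above.
tokenizer = TiktokenTokenizer()  # defaults to the "gpt-4o-mini" encoding
token_count = len(tokenizer.encode("LightRAG now accepts custom tokenizers."))
print(token_count)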

View File

@@ -7,7 +7,18 @@ import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
cast,
final,
Literal,
Optional,
List,
Dict,
)
from lightrag.kg import (
STORAGES,
@@ -41,11 +52,12 @@ from .operate import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
from .utils import (
Tokenizer,
TiktokenTokenizer,
EmbeddingFunc,
always_get_an_event_loop,
compute_mdhash_id,
convert_response_to_json,
encode_string_by_tiktoken,
lazy_external_import,
limit_async_func_call,
get_content_summary,
@@ -122,33 +134,38 @@ class LightRAG:
)
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text."""
tokenizer: Optional[Tokenizer] = field(default=None)
"""
A Tokenizer instance used for chunking and token counting.
If None and a `tiktoken_model_name` is provided, a TiktokenTokenizer for that model is created.
If both are None, the default TiktokenTokenizer is used.
"""
"""Maximum number of tokens used for summarizing extracted entities."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""
chunking_func: Callable[
[
Tokenizer,
str,
str | None,
Optional[str],
bool,
int,
int,
str,
],
list[dict[str, Any]],
List[Dict[str, Any]],
] = field(default_factory=lambda: chunking_by_token_size)
"""
Custom chunking function for splitting text into chunks before processing.
The function should take the following parameters:
- `tokenizer`: A Tokenizer instance to use for tokenization.
- `content`: The text to be split into chunks.
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
- `split_by_character_only`: If True, the text is split only on the specified character.
- `chunk_token_size`: The maximum number of tokens per chunk.
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
The function should return a list of dictionaries, where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk.
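A drop-in replacement only needs to match this signature; a minimal sketch (the paragraph splitter is purely illustrative, and the `tokens`/`content`/`chunk_order_index` keys are assumed to mirror the default chunker):

from typing import Any, Dict, List, Optional

from lightrag.utils import Tokenizer

def chunk_by_paragraph(
    tokenizer: Tokenizer,
    content: str,
    split_by_character: Optional[str] = None,
    split_by_character_only: bool = False,
    overlap_token_size: int = 128,
    max_token_size: int = 1024,
) -> List[Dict[str, Any]]:
    """Illustrative chunker: one chunk per blank-line-separated paragraph."""
    chunks: List[Dict[str, Any]] = []
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    for index, paragraph in enumerate(paragraphs):
        chunks.append(
            {
                "tokens": len(tokenizer.encode(paragraph)),
                "content": paragraph,
                "chunk_order_index": index,
            }
        )
    return chunks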
@@ -310,7 +327,15 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
# Init Tokenizer
# Post-initialization hook: handle backward-compatible tokenizer initialization based on the provided parameters
if self.tokenizer is None:
if self.tiktoken_model_name:
self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
else:
self.tokenizer = TiktokenTokenizer()
# Init Embedding
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
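With this hook, both construction styles resolve to a working tokenizer; a minimal sketch (other required LightRAG arguments, such as the LLM and embedding functions, are omitted, and `working_dir` is shown only as a placeholder):

from lightrag import LightRAG
from lightrag.utils import TiktokenTokenizer

# Backward compatible: only a tiktoken model name is given; the post-init
# hook builds TiktokenTokenizer("gpt-4o-mini") automatically.
rag_legacy = LightRAG(
    working_dir="./rag_storage",
    tiktoken_model_name="gpt-4o-mini",
)

# New style: pass a Tokenizer instance directly; it takes precedence over
# tiktoken_model_name.
rag_custom = LightRAG(
    working_dir="./rag_storage",
    tokenizer=TiktokenTokenizer("gpt-4o"),
)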
@@ -603,11 +628,7 @@ class LightRAG:
inserting_chunks: dict[str, Any] = {}
for index, chunk_text in enumerate(text_chunks):
chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
tokens = len(
encode_string_by_tiktoken(
chunk_text, model_name=self.tiktoken_model_name
)
)
tokens = len(self.tokenizer.encode(chunk_text))
inserting_chunks[chunk_key] = {
"content": chunk_text,
"full_doc_id": doc_key,
@@ -900,12 +921,12 @@ class LightRAG:
"file_path": file_path, # Add file path to each chunk
}
for dp in self.chunking_func(
self.tokenizer,
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
}
@@ -1133,11 +1154,7 @@ class LightRAG:
for chunk_data in custom_kg.get("chunks", []):
chunk_content = clean_text(chunk_data["content"])
source_id = chunk_data["source_id"]
tokens = len(
encode_string_by_tiktoken(
chunk_content, model_name=self.tiktoken_model_name
)
)
tokens = len(self.tokenizer.encode(chunk_content))
chunk_order_index = (
0
if "chunk_order_index" not in chunk_data.keys()

View File

@@ -12,8 +12,7 @@ from .utils import (
logger,
clean_str,
compute_mdhash_id,
decode_tokens_by_tiktoken,
encode_string_by_tiktoken,
Tokenizer,
is_float_regex,
list_of_list_to_csv,
normalize_extracted_info,
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
def chunking_by_token_size(
tokenizer: Tokenizer,
content: str,
split_by_character: str | None = None,
split_by_character_only: bool = False,
overlap_token_size: int = 128,
max_token_size: int = 1024,
tiktoken_model: str = "gpt-4o",
) -> list[dict[str, Any]]:
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
tokens = tokenizer.encode(content)
results: list[dict[str, Any]] = []
if split_by_character:
raw_chunks = content.split(split_by_character)
new_chunks = []
if split_by_character_only:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
_tokens = tokenizer.encode(chunk)
new_chunks.append((len(_tokens), chunk))
else:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
_tokens = tokenizer.encode(chunk)
if len(_tokens) > max_token_size:
for start in range(
0, len(_tokens), max_token_size - overlap_token_size
):
chunk_content = decode_tokens_by_tiktoken(
_tokens[start : start + max_token_size],
model_name=tiktoken_model,
chunk_content = tokenizer.decode(
_tokens[start : start + max_token_size]
)
new_chunks.append(
(min(max_token_size, len(_tokens) - start), chunk_content)
@@ -90,9 +88,7 @@ def chunking_by_token_size(
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
):
chunk_content = decode_tokens_by_tiktoken(
tokens[start : start + max_token_size], model_name=tiktoken_model
)
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
@@ -116,19 +112,19 @@ async def _handle_entity_relation_summary(
If too long, use LLM to summarize.
"""
use_llm_func: callable = global_config["llm_model_func"]
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["summary_to_max_tokens"]
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
tokens = tokenizer.encode(description)
if len(tokens) < summary_max_tokens: # No need for summary
return description
prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = decode_tokens_by_tiktoken(
tokens[:llm_max_tokens], model_name=tiktoken_model_name
)
use_description = tokenizer.decode(tokens[:llm_max_tokens])
context_base = dict(
entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -865,7 +861,8 @@ async def kg_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
@@ -987,7 +984,8 @@ async def extract_keywords_only(
query=text, examples=examples, language=language, history=history_context
)
len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(kw_prompt))
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
# 5. Call the LLM for keyword extraction
@@ -1054,6 +1052,8 @@ async def mix_kg_vector_query(
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
# get tokenizer
tokenizer: Tokenizer = global_config["tokenizer"]
# 1. Cache handling
use_model_func = (
query_param.model_func
@@ -1153,6 +1153,7 @@ async def mix_kg_vector_query(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
if not maybe_trun_chunks:
@@ -1210,7 +1211,7 @@ async def mix_kg_vector_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
# 6. Generate response
@@ -1373,17 +1374,24 @@ async def _get_node_data(
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entitytext chunk
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
node_datas,
query_param,
knowledge_graph_inst,
)
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
tokenizer=tokenizer,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1558,14 +1566,15 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found")
return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
@@ -1619,6 +1628,7 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1626,6 +1636,7 @@ async def _find_most_related_edges_from_entities(
all_edges_data,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
tokenizer=tokenizer,
)
logger.debug(
@@ -1681,6 +1692,7 @@ async def _get_edge_data(
}
edge_datas.append(combined)
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1688,13 +1700,19 @@ async def _get_edge_data(
edge_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
tokenizer=tokenizer,
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst
edge_datas,
query_param,
knowledge_graph_inst,
),
_find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
),
)
logger.info(
@@ -1804,11 +1822,13 @@ async def _find_most_related_entities_from_relationships(
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
tokenizer=tokenizer,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1863,10 +1883,12 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering")
return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
@@ -1937,10 +1959,12 @@ async def naive_query(
logger.warning("No valid chunks found after filtering")
return PROMPTS["fail_response"]
tokenizer: Tokenizer = global_config["tokenizer"]
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
if not maybe_trun_chunks:
@@ -1978,7 +2002,7 @@ async def naive_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
@@ -2125,7 +2149,8 @@ async def kg_query_with_keywords(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
# 6. Generate response

View File

@@ -12,10 +12,9 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
pass
ENCODER = None
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -311,20 +307,89 @@ def write_json(json_obj, file_name):
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
class TokenizerInterface(Protocol):
"""
Defines the interface for a tokenizer, requiring encode and decode methods.
"""
def encode(self, content: str) -> List[int]:
"""Encodes a string into a list of tokens."""
...
def decode(self, tokens: List[int]) -> str:
"""Decodes a list of tokens into a string."""
...
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
class Tokenizer:
"""
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
"""
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
"""
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
Args:
model_name: The associated model name for the tokenizer.
tokenizer: An instance of a class implementing the TokenizerInterface.
"""
self.model_name: str = model_name
self.tokenizer: TokenizerInterface = tokenizer
def encode(self, content: str) -> List[int]:
"""
Encodes a string into a list of tokens using the underlying tokenizer.
Args:
content: The string to encode.
Returns:
A list of integer tokens.
"""
return self.tokenizer.encode(content)
def decode(self, tokens: List[int]) -> str:
"""
Decodes a list of tokens into a string using the underlying tokenizer.
Args:
tokens: A list of integer tokens to decode.
Returns:
The decoded string.
"""
return self.tokenizer.decode(tokens)
class TiktokenTokenizer(Tokenizer):
"""
A Tokenizer implementation using the tiktoken library.
"""
def __init__(self, model_name: str = "gpt-4o-mini"):
"""
Initializes the TiktokenTokenizer with a specified model name.
Args:
model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
Raises:
ImportError: If tiktoken is not installed.
ValueError: If the model_name is invalid.
"""
try:
import tiktoken
except ImportError:
raise ImportError(
"tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
)
try:
tokenizer = tiktoken.encoding_for_model(model_name)
super().__init__(model_name=model_name, tokenizer=tokenizer)
except KeyError:
raise ValueError(f"Invalid model_name: {model_name}.")
def pack_user_ass_to_openai_messages(*args: str):
@@ -361,14 +426,17 @@ def is_float_regex(value: str) -> bool:
def truncate_list_by_token_size(
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
list_data: list[Any],
key: Callable[[Any], str],
max_token_size: int,
tokenizer: Tokenizer,
) -> list[Any]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
tokens += len(tokenizer.encode(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data
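The helper now takes the tokenizer explicitly instead of relying on a module-level tiktoken encoder; a minimal usage sketch:

from lightrag.utils import TiktokenTokenizer, truncate_list_by_token_size

tokenizer = TiktokenTokenizer("gpt-4o-mini")
chunks = [{"content": "first passage ..."}, {"content": "second passage ..."}]

kept = truncate_list_by_token_size(
    chunks,
    key=lambda x: x["content"],
    max_token_size=4000,
    tokenizer=tokenizer,
)
# `kept` is the longest prefix of `chunks` whose cumulative token count
# stays within max_token_size.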