Merge branch 'drahnreb/add-custom-tokenizer'

yangdx
2025-04-20 12:22:10 +08:00
7 changed files with 413 additions and 71 deletions

View File

@@ -1090,7 +1090,8 @@ rag.clear_cache(modes=["local"])
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
| **tokenizer** | `Tokenizer` | The tokenizer used to convert text into tokens (numbers) and back, via `.encode()` and `.decode()` functions following the `TokenizerInterface` protocol. If not specified, the default Tiktoken tokenizer is used. | `TiktokenTokenizer` |
| **tiktoken_model_name** | `str` | If you are using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |

View File

@@ -1156,7 +1156,8 @@ Valid modes are:
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
| **tokenizer** | `Tokenizer` | The tokenizer used to convert text into tokens (numbers) and back, via `.encode()` and `.decode()` functions following the `TokenizerInterface` protocol. If not specified, the default Tiktoken tokenizer is used. | `TiktokenTokenizer` |
| **tiktoken_model_name** | `str` | If you're using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
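
For illustration, a minimal sketch of wiring the new `tokenizer` parameter into `LightRAG` (the stub LLM and embedding callables below are placeholders, not part of this change):

```python
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc, TiktokenTokenizer

async def stub_llm_func(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    # Placeholder LLM binding -- replace with a real model call.
    return "stub response"

async def stub_embedding_func(texts: list[str]) -> np.ndarray:
    # Placeholder embedding -- replace with a real embedding model.
    return np.zeros((len(texts), 384))

rag = LightRAG(
    working_dir="./rag_storage",
    # Explicit tokenizer; if omitted, a TiktokenTokenizer is built from
    # tiktoken_model_name (default "gpt-4o-mini").
    tokenizer=TiktokenTokenizer(model_name="gpt-4o"),
    llm_model_func=stub_llm_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=384, max_token_size=8192, func=stub_embedding_func
    ),
)
# Remember to await rag.initialize_storages() before inserting or querying.
```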

View File

@@ -0,0 +1,230 @@
# pip install -q -U google-genai to use gemini as a client
import os
from typing import Optional
import dataclasses
from pathlib import Path
import hashlib
import numpy as np
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc, Tokenizer
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer
from lightrag.kg.shared_storage import initialize_pipeline_status
import sentencepiece as spm
import requests
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
class GemmaTokenizer(Tokenizer):
# adapted from google-cloud-aiplatform[tokenization]
@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
tokenizer_model_url: str
tokenizer_model_hash: str
_TOKENIZERS = {
"google/gemma2": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
),
"google/gemma3": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
),
}
def __init__(
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
):
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
        if "1.5" in model_name or "1.0" in model_name:
            # Up to Gemini 1.5, Gemma 2 is a comparable local tokenizer
            # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
            tokenizer_name = "google/gemma2"
        else:
            # For Gemini 2.0 and later, Gemma 3 is used
            tokenizer_name = "google/gemma3"
file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
tokenizer_model_name = file_url.rsplit("/", 1)[1]
expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
        # Resolve the cache file path up front so it is always defined,
        # even when the cache directory does not exist yet.
        tokenizer_dir = Path(tokenizer_dir) if tokenizer_dir is not None else None
        file_path = tokenizer_dir / tokenizer_model_name if tokenizer_dir else None
        model_data = None
        if file_path is not None and tokenizer_dir.is_dir():
            model_data = self._maybe_load_from_cache(
                file_path=file_path, expected_hash=expected_hash
            )
        if not model_data:
            model_data = self._load_from_url(
                file_url=file_url, expected_hash=expected_hash
            )
            if file_path is not None:
                self.save_tokenizer_to_cache(
                    cache_path=file_path, model_data=model_data
                )
tokenizer = spm.SentencePieceProcessor()
tokenizer.LoadFromSerializedProto(model_data)
super().__init__(model_name=model_name, tokenizer=tokenizer)
def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
"""Returns true if the content is valid by checking the hash."""
return hashlib.sha256(model_data).hexdigest() == expected_hash
    def _maybe_load_from_cache(
        self, file_path: Path, expected_hash: str
    ) -> Optional[bytes]:
        """Loads the model data from the cache path, or None if missing or corrupted."""
        if not file_path.is_file():
            return None
with open(file_path, "rb") as f:
content = f.read()
if self._is_valid_model(model_data=content, expected_hash=expected_hash):
return content
# Cached file corrupted.
self._maybe_remove_file(file_path)
def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
"""Loads model bytes from the given file url."""
resp = requests.get(file_url)
resp.raise_for_status()
content = resp.content
if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
actual_hash = hashlib.sha256(content).hexdigest()
raise ValueError(
f"Downloaded model file is corrupted."
f" Expected hash {expected_hash}. Got file hash {actual_hash}."
)
return content
@staticmethod
def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
"""Saves the model data to the cache path."""
try:
if not cache_path.is_file():
cache_dir = cache_path.parent
cache_dir.mkdir(parents=True, exist_ok=True)
with open(cache_path, "wb") as f:
f.write(model_data)
except OSError:
# Don't raise if we cannot write file.
pass
@staticmethod
def _maybe_remove_file(file_path: Path) -> None:
"""Removes the file if exists."""
if not file_path.is_file():
return
try:
file_path.unlink()
except OSError:
# Don't raise if we cannot remove file.
pass
# def encode(self, content: str) -> list[int]:
# return self.tokenizer.encode(content)
# def decode(self, tokens: list[int]) -> str:
# return self.tokenizer.decode(tokens)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-1.5-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
)
# 4. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
# tiktoken_model_name="gpt-4o-mini",
tokenizer=GemmaTokenizer(
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
model_name="gemini-2.0-flash",
),
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
file_path = "story.txt"
with open(file_path, "r") as file:
text = file.read()
rag.insert(text)
response = rag.query(
query="What is the main theme of the story?",
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)
print(response)
if __name__ == "__main__":
main()

View File

@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken
from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
from fastapi import Depends
@@ -97,7 +97,7 @@ class OllamaTagResponse(BaseModel):
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text using tiktoken"""
tokens = encode_string_by_tiktoken(text)
tokens = TiktokenTokenizer().encode(text)
return len(tokens)
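
For illustration, the replacement keeps the same call shape while delegating to the new tokenizer wrapper; a minimal sketch using the default model:

```python
from lightrag.utils import TiktokenTokenizer

def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text using the default Tiktoken tokenizer."""
    return len(TiktokenTokenizer().encode(text))

print(estimate_tokens("How many tokens is this prompt?"))
```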

View File

@@ -7,7 +7,18 @@ import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
cast,
final,
Literal,
Optional,
List,
Dict,
)
from lightrag.kg import (
STORAGES,
@@ -41,11 +52,12 @@ from .operate import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
from .utils import (
Tokenizer,
TiktokenTokenizer,
EmbeddingFunc,
always_get_an_event_loop,
compute_mdhash_id,
convert_response_to_json,
encode_string_by_tiktoken,
lazy_external_import,
limit_async_func_call,
get_content_summary,
@@ -122,33 +134,38 @@ class LightRAG:
)
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text."""
tokenizer: Optional[Tokenizer] = field(default=None)
"""
    An optional Tokenizer instance used to convert text to and from token IDs.
If None, and a `tiktoken_model_name` is provided, a TiktokenTokenizer will be created.
If both are None, the default TiktokenTokenizer is used.
"""
"""Maximum number of tokens used for summarizing extracted entities."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""
chunking_func: Callable[
[
Tokenizer,
str,
str | None,
Optional[str],
bool,
int,
int,
str,
],
list[dict[str, Any]],
List[Dict[str, Any]],
] = field(default_factory=lambda: chunking_by_token_size)
"""
Custom chunking function for splitting text into chunks before processing.
The function should take the following parameters:
- `tokenizer`: A Tokenizer instance to use for tokenization.
- `content`: The text to be split into chunks.
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
- `split_by_character_only`: If True, the text is split only on the specified character.
- `chunk_token_size`: The maximum number of tokens per chunk.
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
The function should return a list of dictionaries, where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk.
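
For illustration, a custom `chunking_func` matching this signature could look like the sketch below. The paragraph-splitting rule and the name `chunking_by_paragraph` are illustrative; `split_by_character` handling and overlap are omitted for brevity, and the returned keys are assumed to mirror the default `chunking_by_token_size` output:

```python
from typing import Any, Optional

from lightrag.utils import Tokenizer

def chunking_by_paragraph(
    tokenizer: Tokenizer,
    content: str,
    split_by_character: Optional[str] = None,
    split_by_character_only: bool = False,
    chunk_overlap_token_size: int = 128,
    chunk_token_size: int = 1024,
) -> list[dict[str, Any]]:
    # Split on blank lines and report token counts via the injected tokenizer.
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    return [
        {
            "tokens": len(tokenizer.encode(paragraph)),
            "content": paragraph,
            "chunk_order_index": index,
        }
        for index, paragraph in enumerate(paragraphs)
    ]
```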
@@ -310,7 +327,15 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
# Init Tokenizer
        # Post-initialization hook for backward-compatible tokenizer initialization based on the provided parameters
if self.tokenizer is None:
if self.tiktoken_model_name:
self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
else:
self.tokenizer = TiktokenTokenizer()
# Init Embedding
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
@@ -603,11 +628,7 @@ class LightRAG:
inserting_chunks: dict[str, Any] = {}
for index, chunk_text in enumerate(text_chunks):
chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
tokens = len(
encode_string_by_tiktoken(
chunk_text, model_name=self.tiktoken_model_name
)
)
tokens = len(self.tokenizer.encode(chunk_text))
inserting_chunks[chunk_key] = {
"content": chunk_text,
"full_doc_id": doc_key,
@@ -900,12 +921,12 @@ class LightRAG:
"file_path": file_path, # Add file path to each chunk
}
for dp in self.chunking_func(
self.tokenizer,
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
}
@@ -1133,11 +1154,7 @@ class LightRAG:
for chunk_data in custom_kg.get("chunks", []):
chunk_content = clean_text(chunk_data["content"])
source_id = chunk_data["source_id"]
tokens = len(
encode_string_by_tiktoken(
chunk_content, model_name=self.tiktoken_model_name
)
)
tokens = len(self.tokenizer.encode(chunk_content))
chunk_order_index = (
0
if "chunk_order_index" not in chunk_data.keys()

View File

@@ -12,8 +12,7 @@ from .utils import (
logger,
clean_str,
compute_mdhash_id,
decode_tokens_by_tiktoken,
encode_string_by_tiktoken,
Tokenizer,
is_float_regex,
list_of_list_to_csv,
normalize_extracted_info,
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
def chunking_by_token_size(
tokenizer: Tokenizer,
content: str,
split_by_character: str | None = None,
split_by_character_only: bool = False,
overlap_token_size: int = 128,
max_token_size: int = 1024,
tiktoken_model: str = "gpt-4o",
) -> list[dict[str, Any]]:
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
tokens = tokenizer.encode(content)
results: list[dict[str, Any]] = []
if split_by_character:
raw_chunks = content.split(split_by_character)
new_chunks = []
if split_by_character_only:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
_tokens = tokenizer.encode(chunk)
new_chunks.append((len(_tokens), chunk))
else:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
_tokens = tokenizer.encode(chunk)
if len(_tokens) > max_token_size:
for start in range(
0, len(_tokens), max_token_size - overlap_token_size
):
chunk_content = decode_tokens_by_tiktoken(
_tokens[start : start + max_token_size],
model_name=tiktoken_model,
chunk_content = tokenizer.decode(
_tokens[start : start + max_token_size]
)
new_chunks.append(
(min(max_token_size, len(_tokens) - start), chunk_content)
@@ -90,9 +88,7 @@ def chunking_by_token_size(
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
):
chunk_content = decode_tokens_by_tiktoken(
tokens[start : start + max_token_size], model_name=tiktoken_model
)
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
@@ -116,19 +112,19 @@ async def _handle_entity_relation_summary(
If too long, use LLM to summarize.
"""
use_llm_func: callable = global_config["llm_model_func"]
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["summary_to_max_tokens"]
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
tokens = tokenizer.encode(description)
if len(tokens) < summary_max_tokens: # No need for summary
return description
prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = decode_tokens_by_tiktoken(
tokens[:llm_max_tokens], model_name=tiktoken_model_name
)
use_description = tokenizer.decode(tokens[:llm_max_tokens])
context_base = dict(
entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -865,7 +861,8 @@ async def kg_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
@@ -987,7 +984,8 @@ async def extract_keywords_only(
query=text, examples=examples, language=language, history=history_context
)
len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(kw_prompt))
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
# 5. Call the LLM for keyword extraction
@@ -1054,6 +1052,8 @@ async def mix_kg_vector_query(
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
# get tokenizer
tokenizer: Tokenizer = global_config["tokenizer"]
# 1. Cache handling
use_model_func = (
query_param.model_func
@@ -1153,6 +1153,7 @@ async def mix_kg_vector_query(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
if not maybe_trun_chunks:
@@ -1210,7 +1211,7 @@ async def mix_kg_vector_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
# 6. Generate response
@@ -1373,17 +1374,24 @@ async def _get_node_data(
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entitytext chunk
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
node_datas,
query_param,
knowledge_graph_inst,
)
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
tokenizer=tokenizer,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1558,14 +1566,15 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found")
return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
@@ -1619,6 +1628,7 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1626,6 +1636,7 @@ async def _find_most_related_edges_from_entities(
all_edges_data,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
tokenizer=tokenizer,
)
logger.debug(
@@ -1681,6 +1692,7 @@ async def _get_edge_data(
}
edge_datas.append(combined)
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1688,13 +1700,19 @@ async def _get_edge_data(
edge_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
tokenizer=tokenizer,
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst
edge_datas,
query_param,
knowledge_graph_inst,
),
_find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
),
)
logger.info(
@@ -1804,11 +1822,13 @@ async def _find_most_related_entities_from_relationships(
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
tokenizer=tokenizer,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1863,10 +1883,12 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering")
return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
@@ -1937,10 +1959,12 @@ async def naive_query(
logger.warning("No valid chunks found after filtering")
return PROMPTS["fail_response"]
tokenizer: Tokenizer = global_config["tokenizer"]
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
if not maybe_trun_chunks:
@@ -1978,7 +2002,7 @@ async def naive_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
@@ -2125,7 +2149,8 @@ async def kg_query_with_keywords(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
# 6. Generate response

View File

@@ -12,10 +12,9 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
pass
ENCODER = None
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -311,20 +307,89 @@ def write_json(json_obj, file_name):
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
class TokenizerInterface(Protocol):
"""
Defines the interface for a tokenizer, requiring encode and decode methods.
"""
def encode(self, content: str) -> List[int]:
"""Encodes a string into a list of tokens."""
...
def decode(self, tokens: List[int]) -> str:
"""Decodes a list of tokens into a string."""
...
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
class Tokenizer:
"""
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
"""
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
"""
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
Args:
model_name: The associated model name for the tokenizer.
tokenizer: An instance of a class implementing the TokenizerInterface.
"""
self.model_name: str = model_name
self.tokenizer: TokenizerInterface = tokenizer
def encode(self, content: str) -> List[int]:
"""
Encodes a string into a list of tokens using the underlying tokenizer.
Args:
content: The string to encode.
Returns:
A list of integer tokens.
"""
return self.tokenizer.encode(content)
def decode(self, tokens: List[int]) -> str:
"""
Decodes a list of tokens into a string using the underlying tokenizer.
Args:
tokens: A list of integer tokens to decode.
Returns:
The decoded string.
"""
return self.tokenizer.decode(tokens)
class TiktokenTokenizer(Tokenizer):
"""
A Tokenizer implementation using the tiktoken library.
"""
def __init__(self, model_name: str = "gpt-4o-mini"):
"""
Initializes the TiktokenTokenizer with a specified model name.
Args:
model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
Raises:
ImportError: If tiktoken is not installed.
ValueError: If the model_name is invalid.
"""
try:
import tiktoken
except ImportError:
            raise ImportError(
                "tiktoken is not installed. Please install it with `pip install tiktoken` or provide a custom `tokenizer` instead."
            )
try:
tokenizer = tiktoken.encoding_for_model(model_name)
super().__init__(model_name=model_name, tokenizer=tokenizer)
except KeyError:
raise ValueError(f"Invalid model_name: {model_name}.")
def pack_user_ass_to_openai_messages(*args: str):
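
For illustration, any object exposing `encode`/`decode` can be plugged in through the `Tokenizer` wrapper above; a sketch wrapping a Hugging Face tokenizer (the `transformers` dependency and the `HFTokenizer` name are assumptions, not part of this commit):

```python
from typing import List

from lightrag.utils import Tokenizer

class HFTokenizer(Tokenizer):
    """Sketch: adapt a Hugging Face tokenizer to the TokenizerInterface protocol."""

    def __init__(self, model_name: str = "bert-base-uncased"):
        from transformers import AutoTokenizer  # assumed optional dependency

        hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

        class _Adapter:
            # Expose plain encode/decode as TokenizerInterface requires.
            def encode(self, content: str) -> List[int]:
                return hf_tokenizer.encode(content, add_special_tokens=False)

            def decode(self, tokens: List[int]) -> str:
                return hf_tokenizer.decode(tokens)

        super().__init__(model_name=model_name, tokenizer=_Adapter())
```

An instance of such a class can then be passed as `tokenizer=` when constructing `LightRAG`, just like the `GemmaTokenizer` in the Gemini example above.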
@@ -361,14 +426,17 @@ def is_float_regex(value: str) -> bool:
def truncate_list_by_token_size(
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
list_data: list[Any],
key: Callable[[Any], str],
max_token_size: int,
tokenizer: Tokenizer,
) -> list[Any]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
tokens += len(tokenizer.encode(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data
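
For illustration, callers now supply the tokenizer explicitly, as in the updated call sites above; a minimal usage sketch with made-up sample data:

```python
from lightrag.utils import TiktokenTokenizer, truncate_list_by_token_size

tokenizer = TiktokenTokenizer(model_name="gpt-4o-mini")
chunks = [{"content": "first retrieved chunk"}, {"content": "second retrieved chunk"}]

kept = truncate_list_by_token_size(
    chunks,
    key=lambda item: item["content"],
    max_token_size=4000,
    tokenizer=tokenizer,
)
print(len(kept))  # number of items kept before the token budget is exceeded
```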