diff --git a/examples/lightrag_gemini_demo_no_tiktoken.py b/examples/lightrag_gemini_demo_no_tiktoken.py
index 7ebaf5f2..92c74201 100644
--- a/examples/lightrag_gemini_demo_no_tiktoken.py
+++ b/examples/lightrag_gemini_demo_no_tiktoken.py
@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
         "google/gemma3": _TokenizerConfig(
             tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
             tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
-        )
-    }
+        ),
+    }
 
-    def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
+    def __init__(
+        self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
+    ):
         # https://github.com/google/gemma_pytorch/tree/main/tokenizer
         if "1.5" in model_name or "1.0" in model_name:
             # up to gemini 1.5 gemma2 is a comparable local tokenizer
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
         else:
             model_data = None
         if not model_data:
-            model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
+            model_data = self._load_from_url(
+                file_url=file_url, expected_hash=expected_hash
+            )
             self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
 
         tokenizer = spm.SentencePieceProcessor()
@@ -140,7 +144,7 @@ class GemmaTokenizer(Tokenizer):
 
     # def encode(self, content: str) -> list[int]:
     #     return self.tokenizer.encode(content)
-
+
     # def decode(self, tokens: list[int]) -> str:
     #     return self.tokenizer.decode(tokens)
 
@@ -187,7 +191,10 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
+        tokenizer=GemmaTokenizer(
+            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
+            model_name="gemini-2.0-flash",
+        ),
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=384,
diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py
index e33e3992..3aabfe35 100644
--- a/lightrag/api/routers/ollama_api.py
+++ b/lightrag/api/routers/ollama_api.py
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
 import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
-from lightrag.utils import TiktokenTokenizer 
+from lightrag.utils import TiktokenTokenizer
 from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
 from fastapi import Depends
 
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 22f2aa0d..9ae3a7ef 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -7,7 +7,18 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    cast,
+    final,
+    Literal,
+    Optional,
+    List,
+    Dict,
+)
 
 from lightrag.kg import (
     STORAGES,
@@ -1147,11 +1158,7 @@ class LightRAG:
         for chunk_data in custom_kg.get("chunks", []):
             chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
-            tokens = len(
-                self.tokenizer.encode(
-                    chunk_content
-                )
-            )
+            tokens = len(self.tokenizer.encode(chunk_content))
             chunk_order_index = (
                 0
                 if "chunk_order_index" not in chunk_data.keys()
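For orientation between the file sections above and below: the GemmaTokenizer hunks call self._load_from_url(file_url=..., expected_hash=...) to fetch the pinned SentencePiece model and check it against the recorded SHA-256 digest before caching it locally. The helper's body is not part of this diff, so the following is only a hedged sketch of the pattern that call site implies; the standalone function name and exact error handling are assumptions.

    # Sketch only: mirrors the _load_from_url(file_url=..., expected_hash=...) call site
    # in the Gemini demo above; the demo's actual implementation is not shown in this diff.
    import hashlib
    import urllib.request


    def load_from_url(file_url: str, expected_hash: str) -> bytes:
        """Download a tokenizer model and verify its SHA-256 digest before trusting it."""
        with urllib.request.urlopen(file_url) as response:
            model_data = response.read()
        actual_hash = hashlib.sha256(model_data).hexdigest()
        if actual_hash != expected_hash:
            raise ValueError(
                f"Hash mismatch for {file_url}: expected {expected_hash}, got {actual_hash}"
            )
        return model_data

Pinning both the commit-addressed URL and the digest means a silently changed upstream file fails loudly instead of quietly producing a different tokenization.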
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 9cc4cf0c..8eb7cf24 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -88,9 +88,7 @@ def chunking_by_token_size(
         for index, start in enumerate(
             range(0, len(tokens), max_token_size - overlap_token_size)
         ):
-            chunk_content = tokenizer.decode(
-                tokens[start : start + max_token_size]
-            )
+            chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
             results.append(
                 {
                     "tokens": min(max_token_size, len(tokens) - start),
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = tokenizer.decode(
-        tokens[:llm_max_tokens]
-    )
+    use_description = tokenizer.decode(tokens[:llm_max_tokens])
     context_base = dict(
         entity_name=entity_or_relation_name,
         description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -1378,10 +1374,15 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        text_chunks_db,
+        knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        knowledge_graph_inst,
     )
 
     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            knowledge_graph_inst,
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            text_chunks_db,
+            knowledge_graph_inst,
         ),
     )
     logger.info(
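The chunking_by_token_size hunk above only reflows the call, but the window arithmetic it wraps is worth spelling out: chunk starts advance by max_token_size - overlap_token_size tokens and each chunk decodes at most max_token_size tokens, so consecutive chunks share overlap_token_size tokens. A minimal sketch of that arithmetic with stand-in numbers (the real function operates on tokenizer.encode()/decode() output and records full document metadata):

    # Sketch of the sliding-window slicing in chunking_by_token_size; the sizes here
    # (1200/100) are illustrative, not LightRAG defaults, and ints stand in for token ids.
    def window_starts(num_tokens: int, max_token_size: int, overlap_token_size: int) -> list[int]:
        step = max_token_size - overlap_token_size
        return list(range(0, num_tokens, step))


    tokens = list(range(2500))
    for index, start in enumerate(window_starts(len(tokens), 1200, 100)):
        chunk = tokens[start : start + 1200]
        print(index, start, len(chunk))  # -> (0, 0, 1200), (1, 1100, 1200), (2, 2200, 300)

The "tokens" field recorded in the hunk, min(max_token_size, len(tokens) - start), is exactly the len(chunk) printed here for the final, shorter window.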
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 0d490612..c6991629 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -12,7 +12,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import xml.etree.ElementTree as ET
 import numpy as np
 from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
     """
     Defines the interface for a tokenizer, requiring encode and decode methods.
     """
+
     def encode(self, content: str) -> List[int]:
         """Encodes a string into a list of tokens."""
         ...
@@ -319,10 +320,12 @@
         """Decodes a list of tokens into a string."""
         ...
 
+
 class Tokenizer:
     """
     A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
     """
+
     def __init__(self, model_name: str, tokenizer: TokenizerInterface):
         """
         Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
     """
     A Tokenizer implementation using the tiktoken library.
     """
+
     def __init__(self, model_name: str = "gpt-4o-mini"):
         """
         Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
             tokenizer = tiktoken.encoding_for_model(model_name)
             super().__init__(model_name=model_name, tokenizer=tokenizer)
         except KeyError:
-            raise ValueError(
-                f"Invalid model_name: {model_name}."
-            )
+            raise ValueError(f"Invalid model_name: {model_name}.")
 
 
 def pack_user_ass_to_openai_messages(*args: str):
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:
 
 
 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
+    list_data: list[Any],
+    key: Callable[[Any], str],
+    max_token_size: int,
+    tokenizer: Tokenizer,
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
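Taken together, the lightrag/utils.py hunks show the seam these formatting fixes run along: anything that exposes encode/decode per TokenizerInterface can be wrapped in Tokenizer and passed to LightRAG(tokenizer=...), exactly as the Gemini demo does with GemmaTokenizer and as TiktokenTokenizer does by default. A minimal sketch, assuming a Hugging Face tokenizer as the backend; HFTokenizerWrapper and the model name are illustrative and not part of this diff, and the caller still supplies llm_model_func and embedding_func as the demo does.

    # Hedged sketch: plugging a non-tiktoken tokenizer through the Tokenizer wrapper
    # defined in lightrag/utils.py above. HFTokenizerWrapper is a hypothetical adapter.
    from transformers import AutoTokenizer  # assumes the transformers package is installed

    from lightrag import LightRAG
    from lightrag.utils import Tokenizer


    class HFTokenizerWrapper:
        """Satisfies TokenizerInterface: encode(str) -> list[int], decode(list[int]) -> str."""

        def __init__(self, model_name: str):
            self._tok = AutoTokenizer.from_pretrained(model_name)

        def encode(self, content: str) -> list[int]:
            return self._tok.encode(content, add_special_tokens=False)

        def decode(self, tokens: list[int]) -> str:
            return self._tok.decode(tokens)


    def build_rag(llm_model_func, embedding_func) -> LightRAG:
        """Assemble LightRAG with a custom tokenizer; LLM/embedding callables come from the caller."""
        return LightRAG(
            working_dir="./rag_storage",
            tokenizer=Tokenizer(
                model_name="bert-base-uncased",
                tokenizer=HFTokenizerWrapper("bert-base-uncased"),
            ),
            llm_model_func=llm_model_func,
            embedding_func=embedding_func,
        )

Because chunking_by_token_size, _handle_entity_relation_summary, and truncate_list_by_token_size all go through this tokenizer, swapping it changes chunk boundaries and truncation limits consistently across the pipeline.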