fix linting

drahnreb
2025-04-18 16:14:31 +02:00
parent e71f466910
commit 9c6b5aefcb
5 changed files with 53 additions and 28 deletions

View File

@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
"google/gemma3": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
)
),
}
def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
def __init__(
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
):
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
if "1.5" in model_name or "1.0" in model_name:
# up to gemini 1.5 gemma2 is a comparable local tokenizer
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
         else:
             model_data = None
         if not model_data:
-            model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
+            model_data = self._load_from_url(
+                file_url=file_url, expected_hash=expected_hash
+            )
             self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)

         tokenizer = spm.SentencePieceProcessor()
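
For orientation, the cache-miss branch above fetches the tokenizer model and checks it against the pinned hash from the _TokenizerConfig entries. A hedged sketch of what a helper like _load_from_url typically does, assuming a plain urllib download and a SHA-256 digest (the commit's actual helper may differ):

import hashlib
import urllib.request

def load_model_bytes(file_url: str, expected_hash: str) -> bytes:
    # Download the serialized tokenizer model and refuse to use it if the digest differs.
    with urllib.request.urlopen(file_url) as resp:
        data = resp.read()
    digest = hashlib.sha256(data).hexdigest()
    if digest != expected_hash:
        raise ValueError(f"Hash mismatch for {file_url}: expected {expected_hash}, got {digest}")
    return data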
@@ -187,7 +191,10 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
+        tokenizer=GemmaTokenizer(
+            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
+            model_name="gemini-2.0-flash",
+        ),
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=384,
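
For context, this example file wires a SentencePiece-backed Gemma tokenizer into LightRAG through the tokenizer= argument instead of the default tiktoken model. A minimal sketch of the same pattern, assuming the Tokenizer wrapper lives in lightrag.utils (as in this commit's fourth file) and that the .model file is already on disk; the class name and path below are illustrative, not the commit's exact code:

import sentencepiece as spm
from lightrag.utils import Tokenizer  # wrapper shown in this commit's utils.py changes

class SentencePieceTokenizer(Tokenizer):
    # Hypothetical wrapper: any object exposing encode()/decode() over token ids
    # satisfies the TokenizerInterface protocol defined further down in this commit.
    def __init__(self, model_path: str, model_name: str = "gemini-2.0-flash"):
        sp = spm.SentencePieceProcessor()
        sp.load(model_path)  # SentencePiece .model file downloaded beforehand
        super().__init__(model_name=model_name, tokenizer=sp)

# usage sketch: rag = LightRAG(working_dir=WORKING_DIR, tokenizer=SentencePieceTokenizer("gemma3.spiece.model"), ...)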

View File

@@ -7,7 +7,18 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    cast,
+    final,
+    Literal,
+    Optional,
+    List,
+    Dict,
+)

 from lightrag.kg import (
     STORAGES,
@@ -1147,11 +1158,7 @@ class LightRAG:
         for chunk_data in custom_kg.get("chunks", []):
             chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
-            tokens = len(
-                self.tokenizer.encode(
-                    chunk_content
-                )
-            )
+            tokens = len(self.tokenizer.encode(chunk_content))
             chunk_order_index = (
                 0
                 if "chunk_order_index" not in chunk_data.keys()

View File

@@ -88,9 +88,7 @@ def chunking_by_token_size(
         for index, start in enumerate(
             range(0, len(tokens), max_token_size - overlap_token_size)
         ):
-            chunk_content = tokenizer.decode(
-                tokens[start : start + max_token_size]
-            )
+            chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
             results.append(
                 {
                     "tokens": min(max_token_size, len(tokens) - start),
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = tokenizer.decode(
-        tokens[:llm_max_tokens]
-    )
+    use_description = tokenizer.decode(tokens[:llm_max_tokens])
     context_base = dict(
         entity_name=entity_or_relation_name,
         description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -1378,10 +1374,15 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        text_chunks_db,
+        knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        knowledge_graph_inst,
     )

     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            knowledge_graph_inst,
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            text_chunks_db,
+            knowledge_graph_inst,
         ),
     )
     logger.info(

View File

@@ -12,7 +12,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import xml.etree.ElementTree as ET
 import numpy as np
 from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
"""
Defines the interface for a tokenizer, requiring encode and decode methods.
"""
def encode(self, content: str) -> List[int]:
"""Encodes a string into a list of tokens."""
...
@@ -319,10 +320,12 @@ class TokenizerInterface(Protocol):
"""Decodes a list of tokens into a string."""
...
class Tokenizer:
"""
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
"""
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
"""
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
"""
A Tokenizer implementation using the tiktoken library.
"""
def __init__(self, model_name: str = "gpt-4o-mini"):
"""
Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
             tokenizer = tiktoken.encoding_for_model(model_name)
             super().__init__(model_name=model_name, tokenizer=tokenizer)
         except KeyError:
-            raise ValueError(
-                f"Invalid model_name: {model_name}."
-            )
+            raise ValueError(f"Invalid model_name: {model_name}.")


 def pack_user_ass_to_openai_messages(*args: str):
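
As a quick usage sketch of the wrapper classes touched here, assuming tiktoken is installed and that Tokenizer exposes encode()/decode() delegating to the wrapped tokenizer:

from lightrag.utils import TiktokenTokenizer

tok = TiktokenTokenizer(model_name="gpt-4o-mini")  # an unknown model_name raises ValueError, as shown above
ids = tok.encode("hello world")   # list[int]
assert tok.decode(ids) == "hello world"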
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:


 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
+    list_data: list[Any],
+    key: Callable[[Any], str],
+    max_token_size: int,
+    tokenizer: Tokenizer,
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
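
The hunk ends mid-function; for orientation, truncation of this kind usually accumulates per-item token counts until the budget is hit. A sketch under that assumption (not necessarily the exact body in utils.py):

def truncate_by_token_budget(list_data, key, max_token_size, tokenizer):
    # Keep leading items whose cumulative token count stays within max_token_size.
    if max_token_size <= 0:
        return []
    total = 0
    for i, item in enumerate(list_data):
        total += len(tokenizer.encode(key(item)))
        if total > max_token_size:
            return list_data[:i]
    return list_data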