fix linting

drahnreb
2025-04-18 16:14:31 +02:00
parent e71f466910
commit 9c6b5aefcb
5 changed files with 53 additions and 28 deletions

View File

@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
"google/gemma3": _TokenizerConfig( "google/gemma3": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model", tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c", tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
) ),
} }
def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None): def __init__(
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
):
# https://github.com/google/gemma_pytorch/tree/main/tokenizer # https://github.com/google/gemma_pytorch/tree/main/tokenizer
if "1.5" in model_name or "1.0" in model_name: if "1.5" in model_name or "1.0" in model_name:
# up to gemini 1.5 gemma2 is a comparable local tokenizer # up to gemini 1.5 gemma2 is a comparable local tokenizer
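For orientation, this is roughly how a registry like the one above is consumed: the constructor maps a Gemini model name to a tokenizer entry and pulls its URL and hash. A minimal sketch; the `_TOKENIZERS` name, the elided gemma2 fields, and the selection helper are assumptions, and only the gemma3 values are taken from the diff:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class _TokenizerConfig:
    tokenizer_model_url: str
    tokenizer_model_hash: str


_TOKENIZERS = {
    # gemma2 entry elided here; see the repo for its URL and hash.
    "google/gemma3": _TokenizerConfig(
        tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
        tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
    ),
}


def _pick_tokenizer_key(model_name: str) -> str:
    # Mirrors the version check in __init__: Gemini 1.0/1.5 use the
    # gemma2 tokenizer; newer models (e.g. gemini-2.0-flash) use gemma3.
    if "1.5" in model_name or "1.0" in model_name:
        return "google/gemma2"
    return "google/gemma3"
```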
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
             else:
                 model_data = None
             if not model_data:
-                model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
+                model_data = self._load_from_url(
+                    file_url=file_url, expected_hash=expected_hash
+                )
                 self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)

         tokenizer = spm.SentencePieceProcessor()
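On a cache miss, `_load_from_url` fetches the SentencePiece model and the result is written back via `save_tokenizer_to_cache`. A minimal sketch of such a helper, assuming `expected_hash` is a SHA-256 hex digest (the 64-character value above is consistent with that); an illustration, not the repo's exact implementation:

```python
import hashlib
import urllib.request


def _load_from_url(file_url: str, expected_hash: str) -> bytes:
    """Download a tokenizer model and verify its integrity before use."""
    with urllib.request.urlopen(file_url) as resp:
        data = resp.read()
    actual = hashlib.sha256(data).hexdigest()
    if actual != expected_hash:
        raise ValueError(f"hash mismatch: expected {expected_hash}, got {actual}")
    return data
```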
@@ -187,7 +191,10 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
+        tokenizer=GemmaTokenizer(
+            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
+            model_name="gemini-2.0-flash",
+        ),
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=384,
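For completeness, a hedged usage sketch of the resulting instance; `ainsert`, `aquery`, and `QueryParam(mode="hybrid")` are LightRAG's public API as I understand it, so treat the specifics as assumptions:

```python
import asyncio

from lightrag import QueryParam


async def main():
    rag = await initialize_rag()
    await rag.ainsert("Some document text to index.")
    answer = await rag.aquery(
        "What does the document say?", param=QueryParam(mode="hybrid")
    )
    print(answer)


asyncio.run(main())
```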

View File

@@ -7,7 +7,18 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    cast,
+    final,
+    Literal,
+    Optional,
+    List,
+    Dict,
+)

 from lightrag.kg import (
     STORAGES,
@@ -1147,11 +1158,7 @@ class LightRAG:
         for chunk_data in custom_kg.get("chunks", []):
             chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
-            tokens = len(
-                self.tokenizer.encode(
-                    chunk_content
-                )
-            )
+            tokens = len(self.tokenizer.encode(chunk_content))
             chunk_order_index = (
                 0
                 if "chunk_order_index" not in chunk_data.keys()

View File

@@ -88,9 +88,7 @@ def chunking_by_token_size(
     for index, start in enumerate(
         range(0, len(tokens), max_token_size - overlap_token_size)
     ):
-        chunk_content = tokenizer.decode(
-            tokens[start : start + max_token_size]
-        )
+        chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
         results.append(
             {
                 "tokens": min(max_token_size, len(tokens) - start),
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = tokenizer.decode(
-        tokens[:llm_max_tokens]
-    )
+    use_description = tokenizer.decode(tokens[:llm_max_tokens])
     context_base = dict(
         entity_name=entity_or_relation_name,
         description_list=use_description.split(GRAPH_FIELD_SEP),
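The reason this decodes a token slice instead of slicing the string: context budgets are counted in tokens, not characters, so `tokens[:llm_max_tokens]` keeps the prompt under budget where `description[:n]` would not. A toy comparison using tiktoken (the encoding choice and budget are illustrative):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "internationalization " * 1000
budget = 100

by_tokens = enc.decode(enc.encode(text)[:budget])  # ~100 tokens (re-encoding can shift boundaries slightly)
by_chars = text[:budget]                           # token count is anyone's guess
print(len(enc.encode(by_tokens)), len(enc.encode(by_chars)))
```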
@@ -1378,10 +1374,15 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        text_chunks_db,
+        knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        knowledge_graph_inst,
     )

     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            knowledge_graph_inst,
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            text_chunks_db,
+            knowledge_graph_inst,
         ),
     )
     logger.info(
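The two lookups are independent, so `asyncio.gather` overlaps their I/O instead of awaiting them back to back. The pattern in isolation (stand-in coroutines, not the repo's helpers):

```python
import asyncio


async def fetch_entities() -> list[str]:
    await asyncio.sleep(0.1)  # stands in for graph / storage I/O
    return ["entity-1", "entity-2"]


async def fetch_text_units() -> list[str]:
    await asyncio.sleep(0.1)
    return ["chunk-1"]


async def main() -> None:
    # Total latency is roughly max(0.1, 0.1), not 0.1 + 0.1.
    entities, text_units = await asyncio.gather(fetch_entities(), fetch_text_units())
    print(entities, text_units)


asyncio.run(main())
```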

View File

@@ -12,7 +12,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import xml.etree.ElementTree as ET
 import numpy as np
 from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
""" """
Defines the interface for a tokenizer, requiring encode and decode methods. Defines the interface for a tokenizer, requiring encode and decode methods.
""" """
def encode(self, content: str) -> List[int]: def encode(self, content: str) -> List[int]:
"""Encodes a string into a list of tokens.""" """Encodes a string into a list of tokens."""
... ...
@@ -319,10 +320,12 @@ class TokenizerInterface(Protocol):
"""Decodes a list of tokens into a string.""" """Decodes a list of tokens into a string."""
... ...
class Tokenizer: class Tokenizer:
""" """
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding. A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
""" """
def __init__(self, model_name: str, tokenizer: TokenizerInterface): def __init__(self, model_name: str, tokenizer: TokenizerInterface):
""" """
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance. Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
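Because `TokenizerInterface` is a `Protocol`, conformance is structural: anything with matching `encode`/`decode` methods can be handed to `Tokenizer` without inheriting from it. A deliberately trivial example (the whitespace tokenizer is illustrative only):

```python
from typing import List


class WhitespaceTokenizer:
    """Satisfies TokenizerInterface structurally; no inheritance needed."""

    def __init__(self) -> None:
        self._vocab: dict[str, int] = {}
        self._words: List[str] = []

    def encode(self, content: str) -> List[int]:
        ids: List[int] = []
        for word in content.split():
            if word not in self._vocab:
                self._vocab[word] = len(self._words)
                self._words.append(word)
            ids.append(self._vocab[word])
        return ids

    def decode(self, tokens: List[int]) -> str:
        return " ".join(self._words[t] for t in tokens)


# Plugs straight into the wrapper above:
# tok = Tokenizer(model_name="whitespace-demo", tokenizer=WhitespaceTokenizer())
```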
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
""" """
A Tokenizer implementation using the tiktoken library. A Tokenizer implementation using the tiktoken library.
""" """
def __init__(self, model_name: str = "gpt-4o-mini"): def __init__(self, model_name: str = "gpt-4o-mini"):
""" """
Initializes the TiktokenTokenizer with a specified model name. Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
             tokenizer = tiktoken.encoding_for_model(model_name)
             super().__init__(model_name=model_name, tokenizer=tokenizer)
         except KeyError:
-            raise ValueError(
-                f"Invalid model_name: {model_name}."
-            )
+            raise ValueError(f"Invalid model_name: {model_name}.")


 def pack_user_ass_to_openai_messages(*args: str):
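`tiktoken.encoding_for_model` raises `KeyError` for model names it does not know, which the tightened `raise` above surfaces as a `ValueError`. Callers who prefer to degrade rather than fail sometimes fall back to a fixed encoding; a sketch of that alternative (explicitly not what this commit does):

```python
import tiktoken


def encoding_for(model_name: str) -> tiktoken.Encoding:
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Fallback choice is illustrative, not LightRAG behavior.
        return tiktoken.get_encoding("cl100k_base")
```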
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:
 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
+    list_data: list[Any],
+    key: Callable[[Any], str],
+    max_token_size: int,
+    tokenizer: Tokenizer,
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
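For reference, the usual shape of such a helper: accumulate token counts of `key(item)` and cut the list where the budget is exceeded. A hedged reconstruction from the signature alone, not the repo's exact body; note the upstream annotation says `list[int]`, although the natural return value is a prefix of `list_data`:

```python
from typing import Any, Callable


def truncate_list_by_token_size_sketch(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    tokenizer: Tokenizer,
) -> list[Any]:
    """Keep the longest prefix whose cumulative token count fits the budget."""
    if max_token_size <= 0:
        return []
    total = 0
    for i, item in enumerate(list_data):
        total += len(tokenizer.encode(key(item)))
        if total > max_token_size:
            return list_data[:i]
    return list_data
```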