fix linting

This commit is contained in:
drahnreb
2025-04-18 16:14:31 +02:00
parent e71f466910
commit 9c6b5aefcb
5 changed files with 53 additions and 28 deletions

View File

@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
"google/gemma3": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
)
}
),
}
def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
def __init__(
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
):
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
if "1.5" in model_name or "1.0" in model_name:
# up to gemini 1.5 gemma2 is a comparable local tokenizer
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
else:
model_data = None
if not model_data:
model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
model_data = self._load_from_url(
file_url=file_url, expected_hash=expected_hash
)
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
tokenizer = spm.SentencePieceProcessor()
@@ -140,7 +144,7 @@ class GemmaTokenizer(Tokenizer):
# def encode(self, content: str) -> list[int]:
# return self.tokenizer.encode(content)
# def decode(self, tokens: list[int]) -> str:
# return self.tokenizer.decode(tokens)
@@ -187,7 +191,10 @@ async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
# tiktoken_model_name="gpt-4o-mini",
tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
tokenizer=GemmaTokenizer(
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
model_name="gemini-2.0-flash",
),
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=384,