fix linting
@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
                 "google/gemma3": _TokenizerConfig(
                     tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
                     tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
                 )
             }
         ),
     }

-    def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
+    def __init__(
+        self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
+    ):
         # https://github.com/google/gemma_pytorch/tree/main/tokenizer
         if "1.5" in model_name or "1.0" in model_name:
             # up to gemini 1.5 gemma2 is a comparable local tokenizer
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
         else:
             model_data = None
         if not model_data:
-            model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
+            model_data = self._load_from_url(
+                file_url=file_url, expected_hash=expected_hash
+            )
             self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)

         tokenizer = spm.SentencePieceProcessor()
@@ -140,7 +144,7 @@ class GemmaTokenizer(Tokenizer):

     # def encode(self, content: str) -> list[int]:
     # return self.tokenizer.encode(content)

     # def decode(self, tokens: list[int]) -> str:
     # return self.tokenizer.decode(tokens)

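The hunks above touch the cache-or-download path for the SentencePiece model. As a rough sketch of what a hash-verified fetch like `_load_from_url` plus `save_tokenizer_to_cache` amounts to (the helper below and its cache layout are illustrative assumptions, not this example's actual code), it could look like:

import hashlib
import urllib.request
from pathlib import Path

import sentencepiece as spm

GEMMA3_SPM_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model"
GEMMA3_SPM_SHA256 = "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"


def fetch_spiece_model(url: str, expected_hash: str, cache_path: Path) -> bytes:
    # Hypothetical helper: reuse the cached copy only if its SHA-256 still matches.
    if cache_path.exists():
        data = cache_path.read_bytes()
        if hashlib.sha256(data).hexdigest() == expected_hash:
            return data
    data = urllib.request.urlopen(url).read()
    if hashlib.sha256(data).hexdigest() != expected_hash:
        raise ValueError("tokenizer model hash mismatch")
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_bytes(data)
    return data


model_data = fetch_spiece_model(
    GEMMA3_SPM_URL, GEMMA3_SPM_SHA256, Path("cache/gemma3.spiece.model")
)
tokenizer = spm.SentencePieceProcessor()
tokenizer.LoadFromSerializedProto(model_data)  # same processor type the diff constructs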
@@ -187,7 +191,10 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
+        tokenizer=GemmaTokenizer(
+            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
+            model_name="gemini-2.0-flash",
+        ),
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=384,
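A quick way to sanity-check the tokenizer wired into LightRAG above is to use it standalone. This is only a sketch that reuses the example's WORKING_DIR and GemmaTokenizer, and assumes the encode/decode methods inherited from the Tokenizer base class:

from pathlib import Path

# Illustrative smoke test, reusing names defined in the example above.
tok = GemmaTokenizer(
    tokenizer_dir=Path(WORKING_DIR) / "vertexai_tokenizer_model",
    model_name="gemini-2.0-flash",
)
ids = tok.encode("LightRAG with a Gemma tokenizer")
print(len(ids), "tokens ->", tok.decode(ids))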
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
 import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
-from lightrag.utils import TiktokenTokenizer
+from lightrag.utils import TiktokenTokenizer
 from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
 from fastapi import Depends

@@ -7,7 +7,18 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    cast,
+    final,
+    Literal,
+    Optional,
+    List,
+    Dict,
+)

 from lightrag.kg import (
     STORAGES,
@@ -1147,11 +1158,7 @@ class LightRAG:
         for chunk_data in custom_kg.get("chunks", []):
             chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
-            tokens = len(
-                self.tokenizer.encode(
-                    chunk_content
-                )
-            )
+            tokens = len(self.tokenizer.encode(chunk_content))
             chunk_order_index = (
                 0
                 if "chunk_order_index" not in chunk_data.keys()
@@ -88,9 +88,7 @@ def chunking_by_token_size(
     for index, start in enumerate(
         range(0, len(tokens), max_token_size - overlap_token_size)
     ):
-        chunk_content = tokenizer.decode(
-            tokens[start : start + max_token_size]
-        )
+        chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
         results.append(
             {
                 "tokens": min(max_token_size, len(tokens) - start),
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens: # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = tokenizer.decode(
-        tokens[:llm_max_tokens]
-    )
+    use_description = tokenizer.decode(tokens[:llm_max_tokens])
     context_base = dict(
         entity_name=entity_or_relation_name,
         description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -1378,10 +1374,15 @@ async def _get_node_data(
     ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        text_chunks_db,
+        knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        knowledge_graph_inst,
     )

     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            knowledge_graph_inst,
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            text_chunks_db,
+            knowledge_graph_inst,
         ),
     )
     logger.info(
@@ -12,7 +12,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import xml.etree.ElementTree as ET
 import numpy as np
 from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
     """
     Defines the interface for a tokenizer, requiring encode and decode methods.
     """
+
     def encode(self, content: str) -> List[int]:
         """Encodes a string into a list of tokens."""
         ...
@@ -319,10 +320,12 @@ class TokenizerInterface(Protocol):
         """Decodes a list of tokens into a string."""
         ...

+
 class Tokenizer:
     """
     A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
     """
+
     def __init__(self, model_name: str, tokenizer: TokenizerInterface):
         """
         Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
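Because TokenizerInterface is a Protocol, any object with matching encode and decode methods can back the Tokenizer wrapper, which the docstring above says exposes a consistent encode/decode interface. A minimal sketch (the whitespace tokenizer is a toy stand-in, and the lightrag.utils import path is assumed from the imports elsewhere in this diff):

from typing import Dict, List

from lightrag.utils import Tokenizer  # assumed import path


class WhitespaceTokenizer:
    # Toy TokenizerInterface implementation: one id per distinct whitespace-separated word.
    def __init__(self) -> None:
        self._id_to_word: List[str] = []
        self._word_to_id: Dict[str, int] = {}

    def encode(self, content: str) -> List[int]:
        ids: List[int] = []
        for word in content.split():
            if word not in self._word_to_id:
                self._word_to_id[word] = len(self._id_to_word)
                self._id_to_word.append(word)
            ids.append(self._word_to_id[word])
        return ids

    def decode(self, tokens: List[int]) -> str:
        return " ".join(self._id_to_word[t] for t in tokens)


# Wrap it so downstream code sees the same interface as TiktokenTokenizer.
tok = Tokenizer(model_name="whitespace-demo", tokenizer=WhitespaceTokenizer())
print(tok.decode(tok.encode("hello hello world")))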
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
     """
     A Tokenizer implementation using the tiktoken library.
     """
+
     def __init__(self, model_name: str = "gpt-4o-mini"):
         """
         Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
             tokenizer = tiktoken.encoding_for_model(model_name)
             super().__init__(model_name=model_name, tokenizer=tokenizer)
         except KeyError:
-            raise ValueError(
-                f"Invalid model_name: {model_name}."
-            )
+            raise ValueError(f"Invalid model_name: {model_name}.")


 def pack_user_ass_to_openai_messages(*args: str):
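A small assumed usage example for the TiktokenTokenizer path above: a model name tiktoken knows yields an encoder, while an unknown one surfaces the ValueError raised in the except branch shown in this hunk.

from lightrag.utils import TiktokenTokenizer  # import path as used elsewhere in this diff

tok = TiktokenTokenizer(model_name="gpt-4o-mini")
print(len(tok.encode("hello world")))  # token count for a short string

try:
    TiktokenTokenizer(model_name="not-a-real-model")
except ValueError as exc:
    print(exc)  # Invalid model_name: not-a-real-model.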
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:


 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
+    list_data: list[Any],
+    key: Callable[[Any], str],
+    max_token_size: int,
+    tokenizer: Tokenizer,
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
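The signature above takes the list, a key function, a token budget, and a tokenizer, and the guard suggests an early exit for non-positive budgets. A sketch of how such cumulative-token truncation typically works, written as an independent illustration rather than the library's actual body:

from typing import Any, Callable, List


def truncate_by_token_size_sketch(
    list_data: List[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    encode: Callable[[str], List[Any]],
) -> List[Any]:
    # Keep items while the running token total stays within the budget.
    if max_token_size <= 0:
        return []
    kept: List[Any] = []
    total = 0
    for item in list_data:
        total += len(encode(key(item)))
        if total > max_token_size:
            break
        kept.append(item)
    return kept


# Toy usage with a whitespace "encoder": only the first document fits a 4-token budget.
docs = [{"text": "alpha beta"}, {"text": "gamma delta epsilon"}, {"text": "zeta"}]
print(
    truncate_by_token_size_sketch(
        docs, key=lambda d: d["text"], max_token_size=4, encode=str.split
    )
)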