fix linting
@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
         "google/gemma3": _TokenizerConfig(
             tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
             tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
-        )
+        ),
     }
 
-    def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
+    def __init__(
+        self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
+    ):
         # https://github.com/google/gemma_pytorch/tree/main/tokenizer
         if "1.5" in model_name or "1.0" in model_name:
             # up to gemini 1.5 gemma2 is a comparable local tokenizer
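
For orientation, a minimal sketch of how the model name could be mapped to one of the _TokenizerConfig entries above, assuming a gemma2 entry exists alongside the gemma3 one shown here; the helper name _select_config is illustrative, not part of the actual class:

from dataclasses import dataclass

@dataclass(frozen=True)
class _TokenizerConfig:
    tokenizer_model_url: str   # where to fetch the SentencePiece model
    tokenizer_model_hash: str  # expected SHA-256 digest of that file

def _select_config(model_name: str, tokenizers: dict[str, _TokenizerConfig]) -> _TokenizerConfig:
    # Per the comment above, gemma2 approximates Gemini 1.0/1.5 token counts;
    # anything newer falls back to the gemma3 vocabulary.
    if "1.5" in model_name or "1.0" in model_name:
        return tokenizers["google/gemma2"]
    return tokenizers["google/gemma3"]
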
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
         else:
             model_data = None
         if not model_data:
-            model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
+            model_data = self._load_from_url(
+                file_url=file_url, expected_hash=expected_hash
+            )
             self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
 
         tokenizer = spm.SentencePieceProcessor()
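
A hedged standalone sketch of the download-verify-cache flow that _load_from_url and save_tokenizer_to_cache appear to implement; the function bodies below are assumptions for illustration, not the library's code:

import hashlib
import urllib.request
from pathlib import Path

def load_from_url(file_url: str, expected_hash: str) -> bytes:
    # Fetch the tokenizer model and refuse to use it if the SHA-256 digest differs.
    model_data = urllib.request.urlopen(file_url).read()
    actual_hash = hashlib.sha256(model_data).hexdigest()
    if actual_hash != expected_hash:
        raise ValueError(f"Hash mismatch: expected {expected_hash}, got {actual_hash}")
    return model_data

def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
    # Persist the verified model so later runs can load it without a network call.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_bytes(model_data)
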
@@ -140,7 +144,7 @@ class GemmaTokenizer(Tokenizer):
 
     # def encode(self, content: str) -> list[int]:
     #     return self.tokenizer.encode(content)
 
     # def decode(self, tokens: list[int]) -> str:
     #     return self.tokenizer.decode(tokens)
 
@@ -187,7 +191,10 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
+        tokenizer=GemmaTokenizer(
+            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
+            model_name="gemini-2.0-flash",
+        ),
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=384,
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
 import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
 from lightrag.utils import TiktokenTokenizer
 from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
 from fastapi import Depends
 
@@ -7,7 +7,18 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    cast,
+    final,
+    Literal,
+    Optional,
+    List,
+    Dict,
+)
 
 from lightrag.kg import (
     STORAGES,
@@ -1147,11 +1158,7 @@ class LightRAG:
         for chunk_data in custom_kg.get("chunks", []):
             chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
-            tokens = len(
-                self.tokenizer.encode(
-                    chunk_content
-                )
-            )
+            tokens = len(self.tokenizer.encode(chunk_content))
             chunk_order_index = (
                 0
                 if "chunk_order_index" not in chunk_data.keys()
@@ -88,9 +88,7 @@ def chunking_by_token_size(
     for index, start in enumerate(
         range(0, len(tokens), max_token_size - overlap_token_size)
     ):
-        chunk_content = tokenizer.decode(
-            tokens[start : start + max_token_size]
-        )
+        chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
         results.append(
             {
                 "tokens": min(max_token_size, len(tokens) - start),
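
The loop above walks the token list in strides of max_token_size - overlap_token_size, so consecutive chunks share overlap_token_size tokens. A minimal illustrative version of that windowing (not the library function itself):

def window_chunks(tokens: list[int], max_token_size: int, overlap_token_size: int) -> list[list[int]]:
    # Each chunk holds at most max_token_size tokens; adjacent chunks overlap by overlap_token_size.
    step = max_token_size - overlap_token_size
    return [tokens[start : start + max_token_size] for start in range(0, len(tokens), step)]

# Example: 10 tokens, window of 4, overlap of 2 -> windows start at 0, 2, 4, 6, 8.
assert [len(c) for c in window_chunks(list(range(10)), 4, 2)] == [4, 4, 4, 4, 2]
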
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = tokenizer.decode(
-        tokens[:llm_max_tokens]
-    )
+    use_description = tokenizer.decode(tokens[:llm_max_tokens])
     context_base = dict(
         entity_name=entity_or_relation_name,
         description_list=use_description.split(GRAPH_FIELD_SEP),
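
In the hunk above, a long description is clipped to llm_max_tokens tokens before it is handed to the summarization prompt. A hedged sketch of that step in isolation, where tokenizer is anything satisfying the TokenizerInterface protocol:

def clip_to_token_budget(description: str, tokenizer, llm_max_tokens: int) -> str:
    # Token-level truncation keeps the summarization prompt inside the model's context window.
    tokens = tokenizer.encode(description)
    if len(tokens) <= llm_max_tokens:
        return description
    return tokenizer.decode(tokens[:llm_max_tokens])
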
@@ -1378,10 +1374,15 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        text_chunks_db,
+        knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        knowledge_graph_inst,
     )
 
     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            knowledge_graph_inst,
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            text_chunks_db,
+            knowledge_graph_inst,
         ),
     )
     logger.info(
@@ -12,7 +12,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import xml.etree.ElementTree as ET
 import numpy as np
 from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
     """
     Defines the interface for a tokenizer, requiring encode and decode methods.
     """
+
     def encode(self, content: str) -> List[int]:
         """Encodes a string into a list of tokens."""
         ...
@@ -319,10 +320,12 @@ class TokenizerInterface(Protocol):
         """Decodes a list of tokens into a string."""
         ...
 
+
 class Tokenizer:
     """
     A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
     """
+
     def __init__(self, model_name: str, tokenizer: TokenizerInterface):
         """
         Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
     """
     A Tokenizer implementation using the tiktoken library.
     """
+
     def __init__(self, model_name: str = "gpt-4o-mini"):
         """
         Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
             tokenizer = tiktoken.encoding_for_model(model_name)
             super().__init__(model_name=model_name, tokenizer=tokenizer)
         except KeyError:
-            raise ValueError(
-                f"Invalid model_name: {model_name}."
-            )
+            raise ValueError(f"Invalid model_name: {model_name}.")
 
 
 def pack_user_ass_to_openai_messages(*args: str):
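
A brief usage example of the wrapper classes above, assuming tiktoken is installed and that Tokenizer exposes encode/decode as the TokenizerInterface protocol requires:

from lightrag.utils import TiktokenTokenizer

tokenizer = TiktokenTokenizer(model_name="gpt-4o-mini")
token_ids = tokenizer.encode("hello world")   # list[int] of token ids
round_trip = tokenizer.decode(token_ids)      # back to the original string
assert round_trip == "hello world"
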
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:
 
 
 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
+    list_data: list[Any],
+    key: Callable[[Any], str],
+    max_token_size: int,
+    tokenizer: Tokenizer,
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
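
For reference, a hedged sketch of what a token-budget truncation like truncate_list_by_token_size can look like: keep items from the front of the list while the cumulative token count of key(item) stays within max_token_size. The version below is illustrative only and returns the kept items rather than mirroring the library's exact signature:

from typing import Any, Callable

def truncate_by_token_budget(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    encode: Callable[[str], list[int]],
) -> list[Any]:
    # Stop as soon as adding the next item would exceed the token budget.
    if max_token_size <= 0:
        return []
    total_tokens = 0
    kept: list[Any] = []
    for item in list_data:
        total_tokens += len(encode(key(item)))
        if total_tokens > max_token_size:
            break
        kept.append(item)
    return kept
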