Add character-based splitting: if the "insert" function is given the new split_by_character argument, the text is first split on that character; any resulting chunk whose token count still exceeds max_token_size is then further split by token size. (TODO: decide how to handle chunks that end up too short after character splitting.)

Author: 童石渊
Date: 2025-01-07 00:28:15 +08:00
parent 39a366a3dc
commit 536d6f2283
2 changed files with 171 additions and 146 deletions
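
A minimal usage sketch of the new parameter (the constructor arguments and sample text here are illustrative, not part of this commit; LightRAG's other settings are left at their defaults):

    from lightrag import LightRAG

    rag = LightRAG(working_dir="./rag_storage")  # other settings elided

    text = "First section.\n\nSecond section.\n\nThird section."

    # Split on blank lines first; any piece still longer than chunk_token_size
    # tokens is re-split into token windows with overlap (see the diff below).
    rag.insert(text, split_by_character="\n\n")

    # Omitting the argument keeps the old pure token-size chunking.
    rag.insert(text)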

View File

@@ -45,6 +45,7 @@ from .storage import (
 from .prompt import GRAPH_FIELD_SEP
 # future KG integrations
 # from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
     # LLM
     llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -313,15 +314,16 @@ class LightRAG:
"JsonDocStatusStorage": JsonDocStatusStorage, "JsonDocStatusStorage": JsonDocStatusStorage,
} }
def insert(self, string_or_strings): def insert(self, string_or_strings, split_by_character=None):
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings)) return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))
async def ainsert(self, string_or_strings): async def ainsert(self, string_or_strings, split_by_character):
"""Insert documents with checkpoint support """Insert documents with checkpoint support
Args: Args:
string_or_strings: Single document string or list of document strings string_or_strings: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character
""" """
if isinstance(string_or_strings, str): if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings] string_or_strings = [string_or_strings]
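
Note that ainsert declares split_by_character without a default, so code calling the async API directly must now pass it explicitly; only the synchronous insert wrapper defaults it to None. A short async sketch (rag as in the example above):

    import asyncio

    async def main():
        # Passing None preserves the previous token-size-only chunking.
        await rag.ainsert("some document text", split_by_character=None)

    asyncio.run(main())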
@@ -355,10 +357,10 @@ class LightRAG:
             # Process documents in batches
             batch_size = self.addon_params.get("insert_batch_size", 10)
             for i in range(0, len(new_docs), batch_size):
-                batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+                batch_docs = dict(list(new_docs.items())[i: i + batch_size])
                 for doc_id, doc in tqdm_async(
-                    batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+                    batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
                 ):
                     try:
                         # Update status to processing
@@ -379,6 +381,7 @@ class LightRAG:
                             }
                             for dp in chunking_by_token_size(
                                 doc["content"],
+                                split_by_character=split_by_character,
                                 overlap_token_size=self.chunk_overlap_token_size,
                                 max_token_size=self.chunk_token_size,
                                 tiktoken_model=self.tiktoken_model_name,
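
The batch size above is read from addon_params with a default of 10. A configuration sketch (the insert_batch_size key is the one read in this hunk; the constructor call is assumed from LightRAG's usual setup):

    rag = LightRAG(
        working_dir="./rag_storage",
        addon_params={"insert_batch_size": 20},  # process 20 docs per batch instead of 10
    )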

View File

@@ -34,15 +34,37 @@ import time
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
 ):
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        for chunk in raw_chunks:
+            _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+            if len(_tokens) > max_token_size:
+                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                    chunk_content = decode_tokens_by_tiktoken(
+                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                    )
+                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+            else:
+                new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start: start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
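
To make the two code paths concrete, here is a standalone sketch of the same splitting logic using tiktoken directly (sketch_chunks is a hypothetical name; the commit's encode_string_by_tiktoken / decode_tokens_by_tiktoken are wrappers over the same encode/decode calls, and a tiktoken version that recognizes "gpt-4o" is assumed):

    import tiktoken

    def sketch_chunks(content, split_by_character=None,
                      overlap_token_size=128, max_token_size=1024, model="gpt-4o"):
        enc = tiktoken.encoding_for_model(model)
        pieces = content.split(split_by_character) if split_by_character else [content]
        chunks = []
        for piece in pieces:
            tokens = enc.encode(piece)
            if split_by_character and len(tokens) <= max_token_size:
                # Short enough: the character-split piece becomes one chunk as-is.
                chunks.append(piece.strip())
                continue
            # Over-long pieces (and the no-separator case) fall back to
            # fixed-size token windows with overlap_token_size of overlap.
            for start in range(0, len(tokens), max_token_size - overlap_token_size):
                chunks.append(enc.decode(tokens[start: start + max_token_size]).strip())
        return chunks

    print(sketch_chunks("a\n\nb\n\nc", split_by_character="\n\n"))  # ['a', 'b', 'c']

One behavioral detail visible in the diff: in the character-split path, over-long pieces are re-split with overlap, while short pieces are kept whole, which is exactly why the commit message flags very short chunks as an open TODO.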
@@ -1213,7 +1235,7 @@ async def naive_query(
     if len(response) > len(sys_prompt):
         response = (
-            response[len(sys_prompt) :]
+            response[len(sys_prompt):]
             .replace(sys_prompt, "")
             .replace("user", "")
             .replace("model", "")