Add character-based splitting: if the "insert" function is given the new split_by_character parameter, the text is first split on split_by_character; any resulting chunk whose token count exceeds max_token_size is then further split by token size. (TODO: handle chunks that end up too short after character splitting.)
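A minimal usage sketch of the new parameter (the LightRAG construction and document strings are illustrative, not part of this commit):

    from lightrag import LightRAG

    rag = LightRAG(working_dir="./rag_storage")  # illustrative setup

    # Default: pure token-window chunking, exactly as before this commit.
    rag.insert("one long document ...")

    # New: split on a delimiter first; any piece longer than chunk_token_size
    # is then re-split into overlapping token windows.
    rag.insert("para one\n\npara two\n\npara three", split_by_character="\n\n")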
@@ -45,6 +45,7 @@ from .storage import (
 
 from .prompt import GRAPH_FIELD_SEP
 
+# future KG integrations
 # from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
 
     # LLM
     llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -313,15 +314,16 @@ class LightRAG:
         "JsonDocStatusStorage": JsonDocStatusStorage,
     }
 
-    def insert(self, string_or_strings):
+    def insert(self, string_or_strings, split_by_character=None):
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.ainsert(string_or_strings))
+        return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))
 
-    async def ainsert(self, string_or_strings):
+    async def ainsert(self, string_or_strings, split_by_character):
         """Insert documents with checkpoint support
 
         Args:
             string_or_strings: Single document string or list of document strings
+            split_by_character: if split_by_character is not None, split the string by character
         """
         if isinstance(string_or_strings, str):
             string_or_strings = [string_or_strings]
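Note that ainsert takes split_by_character with no default, so direct async callers must pass it explicitly (None preserves the old token-only behavior). A hedged sketch of a direct async call (setup is illustrative):

    import asyncio
    from lightrag import LightRAG

    async def main():
        rag = LightRAG(working_dir="./rag_storage")  # illustrative setup
        # Keyword form; ainsert itself has no default for split_by_character:
        await rag.ainsert(["doc one", "doc two"], split_by_character="\n\n")
        await rag.ainsert("token-only doc", None)

    asyncio.run(main())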
@@ -355,10 +357,10 @@ class LightRAG:
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
         for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            batch_docs = dict(list(new_docs.items())[i: i + batch_size])
 
             for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 try:
                     # Update status to processing
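The slicing above partitions the pending-document dict into insert_batch_size groups; the same logic in isolation (data is illustrative):

    new_docs = {f"doc-{n}": {"content": f"text {n}"} for n in range(25)}
    batch_size = 10  # mirrors addon_params.get("insert_batch_size", 10)
    for i in range(0, len(new_docs), batch_size):
        batch_docs = dict(list(new_docs.items())[i: i + batch_size])
        print(f"batch {i // batch_size + 1}: {len(batch_docs)} docs")  # 10, 10, 5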
@@ -379,6 +381,7 @@ class LightRAG:
                         }
                         for dp in chunking_by_token_size(
                             doc["content"],
+                            split_by_character=split_by_character,
                             overlap_token_size=self.chunk_overlap_token_size,
                             max_token_size=self.chunk_token_size,
                             tiktoken_model=self.tiktoken_model_name,
@@ -34,15 +34,37 @@ import time
 
 
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
 ):
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        for chunk in raw_chunks:
+            _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+            if len(_tokens) > max_token_size:
+                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                    chunk_content = decode_tokens_by_tiktoken(
+                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                    )
+                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+            else:
+                new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start: start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
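A quick exercise of both code paths (assuming the function is importable from lightrag.operate, as in the upstream layout; the text is illustrative and exact token counts depend on the gpt-4o tokenizer):

    from lightrag.operate import chunking_by_token_size  # assumed module path

    text = "alpha\n\nbeta\n\n" + "gamma " * 3000  # last piece far exceeds 1024 tokens

    # Token-window path: overlapping 1024-token windows over the whole string.
    plain = chunking_by_token_size(text)

    # Character-first path: "alpha" and "beta" stay whole; the long piece is
    # re-split into 1024-token windows with 128-token overlap.
    split = chunking_by_token_size(text, split_by_character="\n\n")
    for c in split[:3]:
        print(c["chunk_order_index"], c["tokens"], c["content"][:20])

As the commit message's TODO notes, pieces left short by the character split are kept as-is; nothing merges undersized chunks yet.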
@@ -1213,7 +1235,7 @@ async def naive_query(
 
     if len(response) > len(sys_prompt):
         response = (
-            response[len(sys_prompt) :]
+            response[len(sys_prompt):]
             .replace(sys_prompt, "")
             .replace("user", "")
             .replace("model", "")
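The reformatted slice strips an echoed system prompt from the raw completion; the same trimming in isolation (strings are illustrative):

    sys_prompt = "---Role--- You are a helpful assistant.\n"
    response = sys_prompt + "Paris is the capital of France."
    if len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt):]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
        )
    print(response)  # -> "Paris is the capital of France."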