添加字符分割功能,在“insert”函数中如果增加参数split_by_character,则会按照split_by_character进行字符分割,此时如果每个分割后的chunk的tokens大于max_token_size,则会继续按token_size分割(todo:考虑字符分割后过短的chunk处理)

This commit is contained in:
童石渊
2025-01-07 00:28:15 +08:00
parent 39a366a3dc
commit 536d6f2283
2 changed files with 171 additions and 146 deletions

View File

@@ -45,6 +45,7 @@ from .storage import (
from .prompt import GRAPH_FIELD_SEP
# future KG integrations
# from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
@@ -267,7 +268,7 @@ class LightRAG:
self.llm_model_func,
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -313,15 +314,16 @@ class LightRAG:
"JsonDocStatusStorage": JsonDocStatusStorage,
}
def insert(self, string_or_strings):
def insert(self, string_or_strings, split_by_character=None):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))
return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))
async def ainsert(self, string_or_strings):
async def ainsert(self, string_or_strings, split_by_character):
"""Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
@@ -355,10 +357,10 @@ class LightRAG:
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
batch_docs = dict(list(new_docs.items())[i: i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
try:
# Update status to processing
@@ -379,6 +381,7 @@ class LightRAG:
}
for dp in chunking_by_token_size(
doc["content"],
split_by_character=split_by_character,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
@@ -545,7 +548,7 @@ class LightRAG:
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,
@@ -594,9 +597,9 @@ class LightRAG:
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
@@ -621,7 +624,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -637,7 +640,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -656,7 +659,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -897,7 +900,7 @@ class LightRAG:
dp
for dp in self.entities_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if entities_with_chunk:
logger.error(
@@ -909,7 +912,7 @@ class LightRAG:
dp
for dp in self.relationships_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if relations_with_chunk:
logger.error(
@@ -926,7 +929,7 @@ class LightRAG:
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
self, entity_name: str, include_vector_data: bool = False
):
"""Get detailed information of an entity
@@ -977,7 +980,7 @@ class LightRAG:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Get detailed information of a relationship
@@ -1019,7 +1022,7 @@ class LightRAG:
return result
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information