Add character-based splitting: when the insert function is given a split_by_character argument, documents are first split on that character; any resulting chunk whose token count exceeds max_token_size is then further split by token size. (TODO: decide how to handle chunks that come out too short after character splitting.)
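In short, chunking becomes a two-stage process: first split on the given character, then fall back to token-window splitting for any piece whose token count still exceeds max_token_size. The code below is a minimal sketch of that intended logic, not the repository's actual chunking_by_token_size; it assumes a tiktoken-style encoder, and the function name and default values are illustrative.

import tiktoken


def chunk_text(
    content: str,
    split_by_character: str | None = None,
    max_token_size: int = 1024,
    overlap_token_size: int = 128,
    tiktoken_model: str = "gpt-4o-mini",
) -> list[str]:
    encoder = tiktoken.encoding_for_model(tiktoken_model)

    def split_by_tokens(text: str) -> list[str]:
        # Sliding window over the token sequence: windows of
        # max_token_size, advanced by max_token_size - overlap_token_size.
        tokens = encoder.encode(text)
        step = max_token_size - overlap_token_size
        return [
            encoder.decode(tokens[start : start + max_token_size])
            for start in range(0, len(tokens), step)
        ]

    if split_by_character is None:
        # Unchanged default path: pure token-based chunking.
        return split_by_tokens(content)

    chunks: list[str] = []
    for piece in content.split(split_by_character):
        if len(encoder.encode(piece)) > max_token_size:
            # A character-split piece that is still too long falls back
            # to token-based splitting, as the commit message describes.
            chunks.extend(split_by_tokens(piece))
        else:
            # TODO from the commit message: very short pieces are kept
            # as-is here; merging them with neighbors is left open.
            chunks.append(piece)
    return chunks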
@@ -45,6 +45,7 @@ from .storage import (
+from .prompt import GRAPH_FIELD_SEP

 # future KG integrations

 # from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
     # LLM
     llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -267,7 +268,7 @@ class LightRAG:
                 self.llm_model_func,
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -313,15 +314,16 @@ class LightRAG:
             "JsonDocStatusStorage": JsonDocStatusStorage,
         }

-    def insert(self, string_or_strings):
+    def insert(self, string_or_strings, split_by_character=None):
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.ainsert(string_or_strings))
+        return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))

-    async def ainsert(self, string_or_strings):
+    async def ainsert(self, string_or_strings, split_by_character):
         """Insert documents with checkpoint support

         Args:
             string_or_strings: Single document string or list of document strings
+            split_by_character: if split_by_character is not None, split the string by character
         """
         if isinstance(string_or_strings, str):
             string_or_strings = [string_or_strings]
@@ -355,10 +357,10 @@ class LightRAG:
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
         for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            batch_docs = dict(list(new_docs.items())[i: i + batch_size])

             for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 try:
                     # Update status to processing
@@ -379,6 +381,7 @@ class LightRAG:
                         }
                         for dp in chunking_by_token_size(
                             doc["content"],
+                            split_by_character=split_by_character,
                             overlap_token_size=self.chunk_overlap_token_size,
                             max_token_size=self.chunk_token_size,
                             tiktoken_model=self.tiktoken_model_name,
@@ -545,7 +548,7 @@ class LightRAG:
             # Check if nodes exist in the knowledge graph
             for need_insert_id in [src_id, tgt_id]:
                 if not (
-                    await self.chunk_entity_relation_graph.has_node(need_insert_id)
+                    await self.chunk_entity_relation_graph.has_node(need_insert_id)
                 ):
                     await self.chunk_entity_relation_graph.upsert_node(
                         need_insert_id,
@@ -594,9 +597,9 @@ class LightRAG:
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
                 "content": dp["keywords"]
-                + dp["src_id"]
-                + dp["tgt_id"]
-                + dp["description"],
+                + dp["src_id"]
+                + dp["tgt_id"]
+                + dp["description"],
             }
             for dp in all_relationships_data
         }
@@ -621,7 +624,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -637,7 +640,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -656,7 +659,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -897,7 +900,7 @@ class LightRAG:
                 dp
                 for dp in self.entities_vdb.client_storage["data"]
                 if chunk_id
-                in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
             ]
             if entities_with_chunk:
                 logger.error(
@@ -909,7 +912,7 @@ class LightRAG:
                 dp
                 for dp in self.relationships_vdb.client_storage["data"]
                 if chunk_id
-                in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
             ]
             if relations_with_chunk:
                 logger.error(
@@ -926,7 +929,7 @@ class LightRAG:
         return asyncio.run(self.adelete_by_doc_id(doc_id))

     async def get_entity_info(
-        self, entity_name: str, include_vector_data: bool = False
+        self, entity_name: str, include_vector_data: bool = False
     ):
         """Get detailed information of an entity
@@ -977,7 +980,7 @@ class LightRAG:
         tracemalloc.stop()

     async def get_relation_info(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
+        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
         """Get detailed information of a relationship
@@ -1019,7 +1022,7 @@ class LightRAG:
         return result

     def get_relation_info_sync(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
+        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
         """Synchronous version of getting relationship information
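With this change a caller can opt into character-first splitting at insert time. A hypothetical usage sketch follows; the working directory and input file are placeholders, and the LLM and embedding setup rely on the defaults visible in the diff.

from lightrag import LightRAG

rag = LightRAG(working_dir="./rag_storage")  # placeholder path

with open("./book.txt", encoding="utf-8") as f:  # placeholder document
    text = f.read()

# Unchanged default: pure token-based chunking.
rag.insert(text)

# New: split on newlines first; any piece longer than chunk_token_size
# is then re-split by token count.
rag.insert(text, split_by_character="\n")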