Merge branch 'HKUDS:main' into main
This commit is contained in:
@@ -45,6 +45,7 @@ from .storage import (
|
||||
|
||||
from .prompt import GRAPH_FIELD_SEP
|
||||
|
||||
|
||||
# future KG integrations
|
||||
|
||||
# from .kg.ArangoDB_impl import (
|
||||
@@ -168,7 +169,7 @@ class LightRAG:
|
||||
|
||||
# LLM
|
||||
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
||||
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_max_token_size: int = 32768
|
||||
llm_model_max_async: int = 16
|
||||
llm_model_kwargs: dict = field(default_factory=dict)
|
||||
@@ -187,6 +188,10 @@ class LightRAG:
|
||||
# Add new field for document status storage type
|
||||
doc_status_storage: str = field(default="JsonDocStatusStorage")
|
||||
|
||||
# Custom Chunking Function
|
||||
chunking_func: callable = chunking_by_token_size
|
||||
chunking_func_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
log_file = os.path.join("lightrag.log")
|
||||
set_logger(log_file)
|
||||
@@ -315,15 +320,25 @@ class LightRAG:
|
||||
"JsonDocStatusStorage": JsonDocStatusStorage,
|
||||
}
|
||||
|
||||
def insert(self, string_or_strings):
|
||||
def insert(
|
||||
self, string_or_strings, split_by_character=None, split_by_character_only=False
|
||||
):
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.ainsert(string_or_strings))
|
||||
return loop.run_until_complete(
|
||||
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
|
||||
)
|
||||
|
||||
async def ainsert(self, string_or_strings):
|
||||
async def ainsert(
|
||||
self, string_or_strings, split_by_character=None, split_by_character_only=False
|
||||
):
|
||||
"""Insert documents with checkpoint support
|
||||
|
||||
Args:
|
||||
string_or_strings: Single document string or list of document strings
|
||||
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
||||
chunk_size, split the sub chunk by token size.
|
||||
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
||||
split_by_character is None, this parameter is ignored.
|
||||
"""
|
||||
if isinstance(string_or_strings, str):
|
||||
string_or_strings = [string_or_strings]
|
||||
@@ -360,7 +375,7 @@ class LightRAG:
|
||||
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
|
||||
|
||||
for doc_id, doc in tqdm_async(
|
||||
batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
|
||||
batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
|
||||
):
|
||||
try:
|
||||
# Update status to processing
|
||||
@@ -379,11 +394,14 @@ class LightRAG:
|
||||
**dp,
|
||||
"full_doc_id": doc_id,
|
||||
}
|
||||
for dp in chunking_by_token_size(
|
||||
for dp in self.chunking_func(
|
||||
doc["content"],
|
||||
split_by_character=split_by_character,
|
||||
split_by_character_only=split_by_character_only,
|
||||
overlap_token_size=self.chunk_overlap_token_size,
|
||||
max_token_size=self.chunk_token_size,
|
||||
tiktoken_model=self.tiktoken_model_name,
|
||||
**self.chunking_func_kwargs,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -455,6 +473,73 @@ class LightRAG:
|
||||
# Ensure all indexes are updated after each document
|
||||
await self._insert_done()
|
||||
|
||||
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.ainsert_custom_chunks(full_text, text_chunks)
|
||||
)
|
||||
|
||||
async def ainsert_custom_chunks(self, full_text: str, text_chunks: list[str]):
|
||||
update_storage = False
|
||||
try:
|
||||
doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-")
|
||||
new_docs = {doc_key: {"content": full_text.strip()}}
|
||||
|
||||
_add_doc_keys = await self.full_docs.filter_keys([doc_key])
|
||||
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
||||
if not len(new_docs):
|
||||
logger.warning("This document is already in the storage.")
|
||||
return
|
||||
|
||||
update_storage = True
|
||||
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
||||
|
||||
inserting_chunks = {}
|
||||
for chunk_text in text_chunks:
|
||||
chunk_text_stripped = chunk_text.strip()
|
||||
chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-")
|
||||
|
||||
inserting_chunks[chunk_key] = {
|
||||
"content": chunk_text_stripped,
|
||||
"full_doc_id": doc_key,
|
||||
}
|
||||
|
||||
_add_chunk_keys = await self.text_chunks.filter_keys(
|
||||
list(inserting_chunks.keys())
|
||||
)
|
||||
inserting_chunks = {
|
||||
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
||||
}
|
||||
if not len(inserting_chunks):
|
||||
logger.warning("All chunks are already in the storage.")
|
||||
return
|
||||
|
||||
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
||||
|
||||
await self.chunks_vdb.upsert(inserting_chunks)
|
||||
|
||||
logger.info("[Entity Extraction]...")
|
||||
maybe_new_kg = await extract_entities(
|
||||
inserting_chunks,
|
||||
knowledge_graph_inst=self.chunk_entity_relation_graph,
|
||||
entity_vdb=self.entities_vdb,
|
||||
relationships_vdb=self.relationships_vdb,
|
||||
global_config=asdict(self),
|
||||
)
|
||||
|
||||
if maybe_new_kg is None:
|
||||
logger.warning("No new entities and relationships found")
|
||||
return
|
||||
else:
|
||||
self.chunk_entity_relation_graph = maybe_new_kg
|
||||
|
||||
await self.full_docs.upsert(new_docs)
|
||||
await self.text_chunks.upsert(inserting_chunks)
|
||||
|
||||
finally:
|
||||
if update_storage:
|
||||
await self._insert_done()
|
||||
|
||||
async def _insert_done(self):
|
||||
tasks = []
|
||||
for storage_inst in [
|
||||
|
Reference in New Issue
Block a user