cleaned insert by using pipeline

Yannick Stephan
2025-02-09 11:10:46 +01:00
parent af477e8a26
commit fd77099af5


@@ -4,17 +4,12 @@ from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Dict
from typing import Any, Type, Union
import traceback
from .operate import (
chunking_by_token_size,
extract_entities,
# local_query,global_query,hybrid_query,
kg_query,
naive_query,
mix_kg_vector_query,
extract_keywords_only,
kg_query_with_keywords,
extract_entities
# local_query,global_query,hybrid_query,
)
from .utils import (
@@ -30,8 +25,6 @@ from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
DocStatus,
)
@@ -176,7 +169,7 @@ class LightRAG:
enable_llm_cache_for_entity_extract: bool = True
# extension
addon_params: dict = field(default_factory=dict)
addon_params: dict[str, Any] = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
# Add new field for document status storage type
@@ -251,7 +244,7 @@ class LightRAG:
),
embedding_func=self.embedding_func,
)
self.text_chunks: BaseVectorStorage = self.key_string_value_json_storage_cls(
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
@@ -281,7 +274,7 @@ class LightRAG:
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb = self.vector_db_storage_cls(
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
@@ -310,7 +303,7 @@ class LightRAG:
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls(
self.doc_status: BaseKVStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
@@ -359,7 +352,9 @@ class LightRAG:
)
async def ainsert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
self, string_or_strings: Union[str, list[str]],
split_by_character: str | None = None,
split_by_character_only: bool = False
):
"""Insert documents with checkpoint support
@@ -370,154 +365,10 @@ class LightRAG:
split_by_character_only: if True, split the string by the given character only;
when split_by_character is None, this parameter is ignored.
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
await self.apipeline_process_documents(string_or_strings)
await self.apipeline_process_chunks(split_by_character, split_by_character_only)
await self.apipeline_process_extract_graph()
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
for content in unique_contents
}
# 3. Filter out already processed documents
# _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
_add_doc_keys = set()
for doc_id in new_docs.keys():
current_doc = await self.doc_status.get_by_id(doc_id)
if current_doc is None:
_add_doc_keys.add(doc_id)
continue # skip to the next doc_id
status = None
if isinstance(current_doc, dict):
status = current_doc["status"]
else:
status = current_doc.status
if status == DocStatus.FAILED:
_add_doc_keys.add(doc_id)
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return
logger.info(f"Processing {len(new_docs)} new unique documents")
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
try:
# Update status to processing
doc_status = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
}
await self.doc_status.upsert({doc_id: doc_status})
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
doc["content"],
split_by_character=split_by_character,
split_by_character_only=split_by_character_only,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
**self.chunking_func_kwargs,
)
}
# Update status with chunks information
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Extract and store entities and relationships
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if maybe_new_kg is None:
raise Exception(
"Failed to extract entities and relationships"
)
self.chunk_entity_relation_graph = maybe_new_kg
# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks)
# Update status to processed
doc_status.update(
{
"status": DocStatus.PROCESSED,
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
except Exception as e:
# Mark as failed if any step fails
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
else:
# Only update index when processing succeeds
await self._insert_done()
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
loop = always_get_an_event_loop()
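The hunk above replaces the old monolithic ainsert body with three pipeline stages. A minimal usage sketch, assuming an already configured LightRAG instance bound to the hypothetical name rag (the method names come from this diff; the surrounding scaffolding does not):

import asyncio

async def ingest(rag, docs: list[str]) -> None:
    # Mirrors the new ainsert body: register documents, chunk them, then extract the graph.
    await rag.apipeline_process_documents(docs)
    await rag.apipeline_process_chunks(split_by_character=None, split_by_character_only=False)
    await rag.apipeline_process_extract_graph()

# Example: asyncio.run(ingest(rag, ["first document", "second document"]))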
@@ -597,34 +448,32 @@ class LightRAG:
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
logger.info(
f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
)
# 2. Generate document IDs and initial status
new_docs = {
new_docs: dict[str, Any] = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": None,
"updated_at": datetime.now().isoformat(),
}
for content in unique_contents
}
# 3. Filter out already processed documents
_not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(
f"Skipping {len(new_docs) - len(_not_stored_doc_keys)} already existing documents"
)
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
_add_doc_keys: set[str] = set()
for doc_id in new_docs.keys():
current_doc = await self.doc_status.get_by_id(doc_id)
if not current_doc or current_doc["status"] == DocStatus.FAILED:
_add_doc_keys.add(doc_id)
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return None
return
# 4. Store original document
for doc_id, doc in new_docs.items():
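Steps 1 and 2 in the hunk above de-duplicate the incoming texts and key them by content hash. A standalone sketch of that pairing, assuming compute_mdhash_id is importable from lightrag.utils as the module's from .utils import suggests:

from lightrag.utils import compute_mdhash_id

def dedupe_and_key(docs: list[str]) -> dict[str, str]:
    # Strip whitespace, drop exact duplicates, and key each document by its hash ID.
    unique_contents = list({doc.strip() for doc in docs})
    return {compute_mdhash_id(content, prefix="doc-"): content for content in unique_contents}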
@@ -633,96 +482,121 @@ class LightRAG:
)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(self):
async def apipeline_process_chunks(
self,
split_by_character: str | None = None,
split_by_character_only: bool = False
) -> None:
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
_todo_doc_keys = []
to_process_doc_keys: list[str] = []
_failed_doc = await self.full_docs.get_by_status_and_ids(
# Process failed documents
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.FAILED
)
_pendding_doc = await self.full_docs.get_by_status_and_ids(
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
# Process Pending
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.PENDING
)
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc:
_todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
if not _todo_doc_keys:
if not to_process_doc_keys:
logger.info("All documents have been processed or are duplicates")
return None
else:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
return
new_docs = {
doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
}
full_docs_ids = await self.full_docs.get_by_ids(to_process_doc_keys)
new_docs = {}
if full_docs_ids:
new_docs = {doc["id"]: doc for doc in full_docs_ids or []}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return
# 2. split docs into chunks, insert chunks, update doc status
chunk_cnt = 0
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(),
desc=f"Level 1 - Spliting doc in batch {i // batch_size + 1}",
batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
doc_status: dict[str, Any] = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
}
try:
await self.doc_status.upsert({doc_id: doc_status})
# Generate chunks from document
chunks = {
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
"status": DocStatus.PROCESSED,
}
for dp in chunking_by_token_size(
for dp in self.chunking_func(
doc["content"],
split_by_character=split_by_character,
split_by_character_only=split_by_character_only,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
**self.chunking_func_kwargs,
)
}
chunk_cnt += len(chunks)
try:
# Store chunks in vector database
# Update status with chunks information
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
await self.chunks_vdb.upsert(chunks)
# Update doc status
await self.text_chunks.upsert(
{**chunks, "status": DocStatus.PENDING}
)
except Exception as e:
# Mark as failed if any step fails
await self.text_chunks.upsert(
{**chunks, "status": DocStatus.FAILED}
)
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
except Exception as e:
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
logger.error(f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}")
continue
logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
async def apipeline_process_extract_graph(self):
"""Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks
_todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(
to_process_doc_keys: list[str] = []
# Process failed documents
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.FAILED
)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
# Process Pending
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.PENDING
)
if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
if not _todo_chunk_keys:
logger.info("All chunks have been processed or are duplicates")
return None
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
if not to_process_doc_keys:
logger.info("All documents have been processed or are duplicates")
return
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
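Both pipeline stages batch their work with the insert_batch_size value read from addon_params (default 10 in this diff). A small sketch of that slicing pattern over a dict of documents:

from typing import Any, Iterator

def iter_batches(docs: dict[str, Any], batch_size: int = 10) -> Iterator[dict[str, Any]]:
    # Yield successive slices of at most batch_size items, as the range-based loops above do.
    items = list(docs.items())
    for i in range(0, len(items), batch_size):
        yield dict(items[i : i + batch_size])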
@@ -731,9 +605,9 @@ class LightRAG:
batch_size
) # Control the number of tasks that are processed simultaneously
async def process_chunk(chunk_id):
async def process_chunk(chunk_id: str):
async with semaphore:
chunks = {
chunks: dict[str, Any] = {
i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
}
# Extract and store entities and relationships
@@ -761,13 +635,13 @@ class LightRAG:
raise e
with tqdm_async(
total=len(_todo_chunk_keys),
total=len(to_process_doc_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0,
) as progress:
tasks = []
for chunk_id in _todo_chunk_keys:
tasks: list[asyncio.Task[None]] = []
for chunk_id in to_process_doc_keys:
task = asyncio.create_task(process_chunk(chunk_id))
tasks.append(task)
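The final hunk keeps the semaphore-bounded task pattern for entity extraction. A self-contained sketch of that concurrency control, with worker standing in as a placeholder for process_chunk (not a name from the diff):

import asyncio
from typing import Awaitable, Callable

async def run_bounded(
    chunk_ids: list[str],
    worker: Callable[[str], Awaitable[None]],
    max_parallel: int,
) -> None:
    # Cap how many chunks are processed at once while still running them concurrently.
    semaphore = asyncio.Semaphore(max_parallel)

    async def guarded(chunk_id: str) -> None:
        async with semaphore:
            await worker(chunk_id)

    await asyncio.gather(*(asyncio.create_task(guarded(cid)) for cid in chunk_ids))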