cleaned insert by using pipe

Yannick Stephan
2025-02-09 11:10:46 +01:00
parent af477e8a26
commit fd77099af5


@@ -4,17 +4,12 @@ from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Type, cast, Dict
+from typing import Any, Type, Union
+import traceback

from .operate import (
    chunking_by_token_size,
-    extract_entities,
-    # local_query,global_query,hybrid_query,
-    kg_query,
-    naive_query,
-    mix_kg_vector_query,
-    extract_keywords_only,
-    kg_query_with_keywords,
+    extract_entities
+    # local_query,global_query,hybrid_query,,
)
from .utils import (
@@ -30,8 +25,6 @@ from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
-    StorageNameSpace,
-    QueryParam,
    DocStatus,
)
@@ -176,7 +169,7 @@ class LightRAG:
    enable_llm_cache_for_entity_extract: bool = True
    # extension
-    addon_params: dict = field(default_factory=dict)
+    addon_params: dict[str, Any] = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json
    # Add new field for document status storage type
@@ -251,7 +244,7 @@ class LightRAG:
            ),
            embedding_func=self.embedding_func,
        )
-        self.text_chunks: BaseVectorStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
            ),
@@ -281,7 +274,7 @@ class LightRAG:
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
-        self.chunks_vdb = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
            ),
@@ -310,7 +303,7 @@ class LightRAG:
        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
-        self.doc_status = self.doc_status_storage_cls(
+        self.doc_status: BaseKVStorage = self.doc_status_storage_cls(
            namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
            global_config=global_config,
            embedding_func=None,
@@ -359,7 +352,9 @@ class LightRAG:
        )

    async def ainsert(
-        self, string_or_strings, split_by_character=None, split_by_character_only=False
+        self, string_or_strings: Union[str, list[str]],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False
    ):
        """Insert documents with checkpoint support
@@ -370,154 +365,10 @@ class LightRAG:
            split_by_character_only: if split_by_character_only is True, split the string by character only, when
                split_by_character is None, this parameter is ignored.
        """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
-        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
-        # 2. Generate document IDs and initial status
-        new_docs = {
-            compute_mdhash_id(content, prefix="doc-"): {
-                "content": content,
-                "content_summary": self._get_content_summary(content),
-                "content_length": len(content),
-                "status": DocStatus.PENDING,
-                "created_at": datetime.now().isoformat(),
-                "updated_at": datetime.now().isoformat(),
-            }
-            for content in unique_contents
-        }
-        # 3. Filter out already processed documents
-        # _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
-        _add_doc_keys = set()
-        for doc_id in new_docs.keys():
-            current_doc = await self.doc_status.get_by_id(doc_id)
-            if current_doc is None:
-                _add_doc_keys.add(doc_id)
-                continue  # skip to the next doc_id
-            status = None
-            if isinstance(current_doc, dict):
-                status = current_doc["status"]
-            else:
-                status = current_doc.status
-            if status == DocStatus.FAILED:
-                _add_doc_keys.add(doc_id)
-        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-        if not new_docs:
-            logger.info("All documents have been processed or are duplicates")
-            return
-        logger.info(f"Processing {len(new_docs)} new unique documents")
-        # Process documents in batches
-        batch_size = self.addon_params.get("insert_batch_size", 10)
-        for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
-            for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
-            ):
-                try:
-                    # Update status to processing
-                    doc_status = {
-                        "content_summary": doc["content_summary"],
-                        "content_length": doc["content_length"],
-                        "status": DocStatus.PROCESSING,
-                        "created_at": doc["created_at"],
-                        "updated_at": datetime.now().isoformat(),
-                    }
-                    await self.doc_status.upsert({doc_id: doc_status})
-                    # Generate chunks from document
-                    chunks = {
-                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                            **dp,
-                            "full_doc_id": doc_id,
-                        }
-                        for dp in self.chunking_func(
-                            doc["content"],
-                            split_by_character=split_by_character,
-                            split_by_character_only=split_by_character_only,
-                            overlap_token_size=self.chunk_overlap_token_size,
-                            max_token_size=self.chunk_token_size,
-                            tiktoken_model=self.tiktoken_model_name,
-                            **self.chunking_func_kwargs,
-                        )
-                    }
-                    # Update status with chunks information
-                    doc_status.update(
-                        {
-                            "chunks_count": len(chunks),
-                            "updated_at": datetime.now().isoformat(),
-                        }
-                    )
-                    await self.doc_status.upsert({doc_id: doc_status})
-                    try:
-                        # Store chunks in vector database
-                        await self.chunks_vdb.upsert(chunks)
-                        # Extract and store entities and relationships
-                        maybe_new_kg = await extract_entities(
-                            chunks,
-                            knowledge_graph_inst=self.chunk_entity_relation_graph,
-                            entity_vdb=self.entities_vdb,
-                            relationships_vdb=self.relationships_vdb,
-                            llm_response_cache=self.llm_response_cache,
-                            global_config=asdict(self),
-                        )
-                        if maybe_new_kg is None:
-                            raise Exception(
-                                "Failed to extract entities and relationships"
-                            )
-                        self.chunk_entity_relation_graph = maybe_new_kg
-                        # Store original document and chunks
-                        await self.full_docs.upsert(
-                            {doc_id: {"content": doc["content"]}}
-                        )
-                        await self.text_chunks.upsert(chunks)
-                        # Update status to processed
-                        doc_status.update(
-                            {
-                                "status": DocStatus.PROCESSED,
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        )
-                        await self.doc_status.upsert({doc_id: doc_status})
-                    except Exception as e:
-                        # Mark as failed if any step fails
-                        doc_status.update(
-                            {
-                                "status": DocStatus.FAILED,
-                                "error": str(e),
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        )
-                        await self.doc_status.upsert({doc_id: doc_status})
-                        raise e
-                except Exception as e:
-                    import traceback
-                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
-                    logger.error(error_msg)
-                    continue
-                else:
-                    # Only update index when processing succeeds
-                    await self._insert_done()
+        await self.apipeline_process_documents(string_or_strings)
+        await self.apipeline_process_chunks(split_by_character, split_by_character_only)
+        await self.apipeline_process_extract_graph()

    def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
        loop = always_get_an_event_loop()
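
Aside: after this change, ainsert is a thin driver over a three-stage pipeline (register documents, chunk them, extract the knowledge graph). A minimal usage sketch in Python, assuming a LightRAG instance that has been configured elsewhere with working storages and LLM/embedding functions; the constructor argument shown is illustrative only:

    import asyncio
    from lightrag import LightRAG

    async def main() -> None:
        # Hypothetical minimal setup; real deployments also configure the
        # LLM and embedding functions and the storage backends.
        rag = LightRAG(working_dir="./rag_storage")

        # ainsert() now just runs the three pipeline stages in order.
        await rag.ainsert(["Doc one ...", "Doc two ..."], split_by_character="\n")

        # The stages can also be driven individually, e.g. to retry failed steps.
        await rag.apipeline_process_documents(["Another doc ..."])
        await rag.apipeline_process_chunks(split_by_character=None, split_by_character_only=False)
        await rag.apipeline_process_extract_graph()

    asyncio.run(main())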
@@ -597,34 +448,32 @@ class LightRAG:
        # 1. Remove duplicate contents from the list
        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        logger.info(
+            f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
+        )

        # 2. Generate document IDs and initial status
-        new_docs = {
+        new_docs: dict[str, Any] = {
            compute_mdhash_id(content, prefix="doc-"): {
                "content": content,
                "content_summary": self._get_content_summary(content),
                "content_length": len(content),
                "status": DocStatus.PENDING,
                "created_at": datetime.now().isoformat(),
-                "updated_at": None,
+                "updated_at": datetime.now().isoformat(),
            }
            for content in unique_contents
        }

        # 3. Filter out already processed documents
-        _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-        if len(_not_stored_doc_keys) < len(new_docs):
-            logger.info(
-                f"Skipping {len(new_docs) - len(_not_stored_doc_keys)} already existing documents"
-            )
-        new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
+        _add_doc_keys: set[str] = set()
+        for doc_id in new_docs.keys():
+            current_doc = await self.doc_status.get_by_id(doc_id)
+            if not current_doc or current_doc["status"] == DocStatus.FAILED:
+                _add_doc_keys.add(doc_id)
+        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

        if not new_docs:
            logger.info("All documents have been processed or are duplicates")
-            return None
+            return

        # 4. Store original document
        for doc_id, doc in new_docs.items():
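
Aside: deduplication in this stage relies on content-addressed document ids. A small self-contained sketch of the idea; compute_mdhash_id is re-implemented here purely for illustration, on the assumption that it is an MD5 hex digest of the content plus a prefix, as in LightRAG's utils:

    from hashlib import md5

    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # Content-addressed ids: identical documents map to the same "doc-..." key.
        return prefix + md5(content.encode()).hexdigest()

    docs = ["same text", "same text ", "other text"]
    unique = list(set(d.strip() for d in docs))            # strip + set removes exact duplicates
    ids = {compute_mdhash_id(c, prefix="doc-"): c for c in unique}
    assert len(ids) == 2                                   # the two "same text" variants collapse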
@@ -633,96 +482,121 @@ class LightRAG:
            )
        logger.info(f"Stored {len(new_docs)} new unique documents")

-    async def apipeline_process_chunks(self):
+    async def apipeline_process_chunks(
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False
+    ) -> None:
        """Get pendding documents, split into chunks,insert chunks"""
        # 1. get all pending and failed documents
-        _todo_doc_keys = []
-        _failed_doc = await self.full_docs.get_by_status_and_ids(
-            status=DocStatus.FAILED
-        )
-        _pendding_doc = await self.full_docs.get_by_status_and_ids(
-            status=DocStatus.PENDING
-        )
-        if _failed_doc:
-            _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
-        if _pendding_doc:
-            _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
-        if not _todo_doc_keys:
-            logger.info("All documents have been processed or are duplicates")
-            return None
-        else:
-            logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
-        new_docs = {
-            doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
-        }
+        to_process_doc_keys: list[str] = []
+
+        # Process failes
+        to_process_docs = await self.full_docs.get_by_status(
+            status=DocStatus.FAILED
+        )
+        if to_process_docs:
+            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+
+        # Process Pending
+        to_process_docs = await self.full_docs.get_by_status(
+            status=DocStatus.PENDING
+        )
+        if to_process_docs:
+            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+
+        if not to_process_doc_keys:
+            logger.info("All documents have been processed or are duplicates")
+            return
+
+        full_docs_ids = await self.full_docs.get_by_ids(to_process_doc_keys)
+        new_docs = {}
+        if full_docs_ids:
+            new_docs = {doc["id"]: doc for doc in full_docs_ids or []}
+        if not new_docs:
+            logger.info("All documents have been processed or are duplicates")
+            return

        # 2. split docs into chunks, insert chunks, update doc status
-        chunk_cnt = 0
        batch_size = self.addon_params.get("insert_batch_size", 10)
        for i in range(0, len(new_docs), batch_size):
            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
            for doc_id, doc in tqdm_async(
-                batch_docs.items(),
-                desc=f"Level 1 - Spliting doc in batch {i // batch_size + 1}",
+                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
            ):
+                doc_status: dict[str, Any] = {
+                    "content_summary": doc["content_summary"],
+                    "content_length": doc["content_length"],
+                    "status": DocStatus.PROCESSING,
+                    "created_at": doc["created_at"],
+                    "updated_at": datetime.now().isoformat(),
+                }
                try:
+                    await self.doc_status.upsert({doc_id: doc_status})
                    # Generate chunks from document
-                    chunks = {
+                    chunks: dict[str, Any] = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
+                            "status": DocStatus.PROCESSED,
                        }
-                        for dp in chunking_by_token_size(
+                        for dp in self.chunking_func(
                            doc["content"],
+                            split_by_character=split_by_character,
+                            split_by_character_only=split_by_character_only,
                            overlap_token_size=self.chunk_overlap_token_size,
                            max_token_size=self.chunk_token_size,
                            tiktoken_model=self.tiktoken_model_name,
+                            **self.chunking_func_kwargs,
                        )
                    }
-                    chunk_cnt += len(chunks)
-                    try:
-                        # Store chunks in vector database
-                        await self.chunks_vdb.upsert(chunks)
-                        # Update doc status
-                        await self.text_chunks.upsert(
-                            {**chunks, "status": DocStatus.PENDING}
-                        )
-                    except Exception as e:
-                        # Mark as failed if any step fails
-                        await self.text_chunks.upsert(
-                            {**chunks, "status": DocStatus.FAILED}
-                        )
-                        raise e
-                except Exception as e:
-                    import traceback
-                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
-                    logger.error(error_msg)
+                    # Update status with chunks information
+                    doc_status.update(
+                        {
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
+                    await self.doc_status.upsert({doc_id: doc_status})
+                    await self.chunks_vdb.upsert(chunks)
+                except Exception as e:
+                    doc_status.update(
+                        {
+                            "status": DocStatus.FAILED,
+                            "error": str(e),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
+                    await self.doc_status.upsert({doc_id: doc_status})
+                    logger.error(f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}")
                    continue
-        logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")

    async def apipeline_process_extract_graph(self):
        """Get pendding or failed chunks, extract entities and relationships from each chunk"""
        # 1. get all pending and failed chunks
-        _todo_chunk_keys = []
-        _failed_chunks = await self.text_chunks.get_by_status_and_ids(
-            status=DocStatus.FAILED
-        )
-        _pendding_chunks = await self.text_chunks.get_by_status_and_ids(
-            status=DocStatus.PENDING
-        )
-        if _failed_chunks:
-            _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
-        if _pendding_chunks:
-            _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
-        if not _todo_chunk_keys:
-            logger.info("All chunks have been processed or are duplicates")
-            return None
+        to_process_doc_keys: list[str] = []
+
+        # Process failes
+        to_process_docs = await self.full_docs.get_by_status(
+            status=DocStatus.FAILED
+        )
+        if to_process_docs:
+            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+
+        # Process Pending
+        to_process_docs = await self.full_docs.get_by_status(
+            status=DocStatus.PENDING
+        )
+        if to_process_docs:
+            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+
+        if not to_process_doc_keys:
+            logger.info("All documents have been processed or are duplicates")
+            return

        # Process documents in batches
        batch_size = self.addon_params.get("insert_batch_size", 10)
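
Aside: both the old and the new chunking stage slice the pending documents into groups of insert_batch_size before handing them to tqdm_async. The slicing pattern in isolation, as an illustrative helper that is not part of the codebase:

    from typing import Any

    def iter_batches(docs: dict[str, Any], batch_size: int = 10):
        # Same pattern as the pipeline: materialize the items once,
        # then yield consecutive windows of `batch_size` documents.
        items = list(docs.items())
        for i in range(0, len(items), batch_size):
            yield dict(items[i : i + batch_size])

    docs = {f"doc-{n}": {"content": f"text {n}"} for n in range(25)}
    for batch_no, batch in enumerate(iter_batches(docs, batch_size=10), start=1):
        print(f"batch {batch_no}: {len(batch)} docs")   # 10, 10, 5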
@@ -731,9 +605,9 @@ class LightRAG:
            batch_size
        )  # Control the number of tasks that are processed simultaneously

-        async def process_chunk(chunk_id):
+        async def process_chunk(chunk_id: str):
            async with semaphore:
-                chunks = {
+                chunks: dict[str, Any] = {
                    i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
                }
                # Extract and store entities and relationships
@@ -761,13 +635,13 @@ class LightRAG:
                    raise e

        with tqdm_async(
-            total=len(_todo_chunk_keys),
+            total=len(to_process_doc_keys),
            desc="\nLevel 1 - Processing chunks",
            unit="chunk",
            position=0,
        ) as progress:
-            tasks = []
-            for chunk_id in _todo_chunk_keys:
+            tasks: list[asyncio.Task[None]] = []
+            for chunk_id in to_process_doc_keys:
                task = asyncio.create_task(process_chunk(chunk_id))
                tasks.append(task)
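
Aside: the extraction stage caps concurrency with an asyncio.Semaphore sized by insert_batch_size while tqdm tracks completed chunks. A runnable sketch of that pattern under stated assumptions, with asyncio.sleep standing in for the real extract_entities call:

    import asyncio
    from tqdm.asyncio import tqdm as tqdm_async

    async def run_bounded(chunk_ids: list[str], batch_size: int = 10) -> None:
        semaphore = asyncio.Semaphore(batch_size)  # cap concurrent chunk extractions

        async def process_chunk(chunk_id: str) -> None:
            async with semaphore:
                await asyncio.sleep(0.1)           # stand-in for entity/relation extraction

        with tqdm_async(total=len(chunk_ids), desc="Processing chunks", unit="chunk") as progress:
            tasks = [asyncio.create_task(process_chunk(c)) for c in chunk_ids]
            for task in asyncio.as_completed(tasks):
                await task
                progress.update(1)

    asyncio.run(run_bounded([f"chunk-{i}" for i in range(25)]))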