feat(lightrag): Add document status tracking and checkpoint support


- Add DocStatus enum and DocProcessingStatus class for document processing state management (see the sketch after this list)

- Implement JsonDocStatusStorage for persistent status storage

- Add document-level deduplication in batch processing

- Add checkpoint support in ainsert method for resumable document processing

- Add status query methods for monitoring processing progress

- Update LightRAG initialization to support document status tracking
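
The DocStatus enum and DocProcessingStatus class mentioned in the first bullet are defined in lightrag/base.py (the diff below imports DocStatus from .base), but that file's changes are not shown on this page. The following is a minimal sketch consistent with how the fields are used in ainsert() below; only the member and field names appear in the diff, so the string values, defaults, and exact types are assumptions:

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class DocStatus(str, Enum):
    """Document processing lifecycle used by ainsert()."""
    PENDING = "pending"          # value strings are an assumption
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"


@dataclass
class DocProcessingStatus:
    """Per-document record kept in the doc status storage."""
    content_summary: str                  # first ~100 characters of the document
    content_length: int                   # length of the original content
    status: DocStatus                     # PENDING -> PROCESSING -> PROCESSED or FAILED
    created_at: str                       # ISO-8601 timestamp
    updated_at: str                       # ISO-8601 timestamp
    chunks_count: Optional[int] = None    # filled in after chunking
    error: Optional[str] = None           # filled in when status is FAILED

JsonDocStatusStorage (from .storage) is only exercised through filter_keys(), upsert(), and get_status_counts() in the diff below, presumably persisting these records as JSON in the working directory like the existing JsonKVStorage.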
Author: Magic_yuan
Date:   2024-12-28 00:11:25 +08:00
parent c022db4355
commit 650b8e38b7
4 changed files with 256 additions and 61 deletions


@@ -4,7 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Dict
 from .llm import (
     gpt_4o_mini_complete,
@@ -32,12 +32,14 @@ from .base import (
     BaseVectorStorage,
     StorageNameSpace,
     QueryParam,
+    DocStatus,
 )

 from .storage import (
     JsonKVStorage,
     NanoVectorDBStorage,
     NetworkXStorage,
+    JsonDocStatusStorage,
 )

 # future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json

+    # Add new field for document status storage type
+    doc_status_storage: str = field(default="JsonDocStatusStorage")
+
     def __post_init__(self):
         log_file = os.path.join("lightrag.log")
         set_logger(log_file)
@@ -263,7 +268,15 @@ class LightRAG:
             )
         )

-    def _get_storage_class(self) -> Type[BaseGraphStorage]:
+        # Initialize document status storage
+        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
+        self.doc_status = self.doc_status_storage_cls(
+            namespace="doc_status",
+            global_config=asdict(self),
+            embedding_func=None,
+        )
+
+    def _get_storage_class(self) -> dict:
         return {
             # kv storage
             "JsonKVStorage": JsonKVStorage,
@@ -284,6 +297,7 @@
             "TiDBGraphStorage": TiDBGraphStorage,
             "GremlinStorage": GremlinStorage,
             # "ArangoDBStorage": ArangoDBStorage
+            "JsonDocStatusStorage": JsonDocStatusStorage,
         }

     def insert(self, string_or_strings):
@@ -291,71 +305,139 @@
         return loop.run_until_complete(self.ainsert(string_or_strings))

     async def ainsert(self, string_or_strings):
-        update_storage = False
-        try:
-            if isinstance(string_or_strings, str):
-                string_or_strings = [string_or_strings]
-
-            new_docs = {
-                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
-                for c in string_or_strings
-            }
-            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-            if not len(new_docs):
-                logger.warning("All docs are already in the storage")
-                return
-            update_storage = True
-            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
-
-            inserting_chunks = {}
-            for doc_key, doc in tqdm_async(
-                new_docs.items(), desc="Chunking documents", unit="doc"
-            ):
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
-                    }
-                    for dp in chunking_by_token_size(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
-                    )
-                }
-                inserting_chunks.update(chunks)
-            _add_chunk_keys = await self.text_chunks.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-            if not len(inserting_chunks):
-                logger.warning("All chunks are already in the storage")
-                return
-            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
-
-            await self.chunks_vdb.upsert(inserting_chunks)
-
-            logger.info("[Entity Extraction]...")
-            maybe_new_kg = await extract_entities(
-                inserting_chunks,
-                knowledge_graph_inst=self.chunk_entity_relation_graph,
-                entity_vdb=self.entities_vdb,
-                relationships_vdb=self.relationships_vdb,
-                global_config=asdict(self),
-            )
-            if maybe_new_kg is None:
-                logger.warning("No new entities and relationships found")
-                return
-            self.chunk_entity_relation_graph = maybe_new_kg
-
-            await self.full_docs.upsert(new_docs)
-            await self.text_chunks.upsert(inserting_chunks)
-        finally:
-            if update_storage:
-                await self._insert_done()
+        """Insert documents with checkpoint support
+
+        Args:
+            string_or_strings: Single document string or list of document strings
+        """
+        if isinstance(string_or_strings, str):
+            string_or_strings = [string_or_strings]
+
+        # 1. Remove duplicate contents from the list
+        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+
+        # 2. Generate document IDs and initial status
+        new_docs = {
+            compute_mdhash_id(content, prefix="doc-"): {
+                "content": content,
+                "content_summary": self._get_content_summary(content),
+                "content_length": len(content),
+                "status": DocStatus.PENDING,
+                "created_at": datetime.now().isoformat(),
+                "updated_at": datetime.now().isoformat(),
+            }
+            for content in unique_contents
+        }
+
+        # 3. Filter out already processed documents
+        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
+        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+        if not new_docs:
+            logger.info("All documents have been processed or are duplicates")
+            return
+
+        logger.info(f"Processing {len(new_docs)} new unique documents")
+
+        # Process documents in batches
+        batch_size = self.addon_params.get("insert_batch_size", 10)
+        for i in range(0, len(new_docs), batch_size):
+            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            for doc_id, doc in tqdm_async(
+                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+            ):
+                try:
+                    # Update status to processing
+                    doc_status = {
+                        "content_summary": doc["content_summary"],
+                        "content_length": doc["content_length"],
+                        "status": DocStatus.PROCESSING,
+                        "created_at": doc["created_at"],
+                        "updated_at": datetime.now().isoformat(),
+                    }
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    # Generate chunks from document
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_id,
+                        }
+                        for dp in chunking_by_token_size(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            tiktoken_model=self.tiktoken_model_name,
+                        )
+                    }
+
+                    # Update status with chunks information
+                    doc_status.update(
+                        {
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    try:
+                        # Store chunks in vector database
+                        await self.chunks_vdb.upsert(chunks)
+
+                        # Extract and store entities and relationships
+                        maybe_new_kg = await extract_entities(
+                            chunks,
+                            knowledge_graph_inst=self.chunk_entity_relation_graph,
+                            entity_vdb=self.entities_vdb,
+                            relationships_vdb=self.relationships_vdb,
+                            global_config=asdict(self),
+                        )
+                        if maybe_new_kg is None:
+                            raise Exception(
+                                "Failed to extract entities and relationships"
+                            )
+                        self.chunk_entity_relation_graph = maybe_new_kg
+
+                        # Store original document and chunks
+                        await self.full_docs.upsert(
+                            {doc_id: {"content": doc["content"]}}
+                        )
+                        await self.text_chunks.upsert(chunks)
+
+                        # Update status to processed
+                        doc_status.update(
+                            {
+                                "status": DocStatus.PROCESSED,
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+
+                    except Exception as e:
+                        # Mark as failed if any step fails
+                        doc_status.update(
+                            {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+                        raise e
+
+                except Exception as e:
+                    import traceback
+
+                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    logger.error(error_msg)
+                    continue
+
+                finally:
+                    # Ensure all indexes are updated after each document
+                    await self._insert_done()

     async def _insert_done(self):
         tasks = []
@@ -591,3 +673,26 @@ class LightRAG:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
+
+    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
+        """Get summary of document content
+
+        Args:
+            content: Original document content
+            max_length: Maximum length of summary
+
+        Returns:
+            Truncated content with ellipsis if needed
+        """
+        content = content.strip()
+        if len(content) <= max_length:
+            return content
+        return content[:max_length] + "..."
+
+    async def get_processing_status(self) -> Dict[str, int]:
+        """Get current document processing status counts
+
+        Returns:
+            Dict with counts for each status
+        """
+        return await self.doc_status.get_status_counts()
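
The new code paths above can be exercised roughly as follows. This is a minimal usage sketch, assuming the usual LightRAG setup from the project README; working_dir and the LLM function are placeholders, while doc_status_storage and the insert_batch_size addon parameter are the knobs introduced by this commit:

import asyncio

from lightrag import LightRAG
from lightrag.llm import gpt_4o_mini_complete

rag = LightRAG(
    working_dir="./rag_storage",                 # placeholder path
    llm_model_func=gpt_4o_mini_complete,         # requires an OpenAI API key
    doc_status_storage="JsonDocStatusStorage",   # new dataclass field added above
    addon_params={"insert_batch_size": 5},       # read by ainsert(), defaults to 10
)

async def main():
    docs = ["First document ...", "Second document ..."]

    # First run: contents are deduplicated, given doc- IDs, chunked in batches,
    # and tracked through PENDING -> PROCESSING -> PROCESSED in doc_status.
    await rag.ainsert(docs)

    # Second run: doc_status.filter_keys() drops every ID that already has a
    # status record, so a repeated or resumed insert only processes documents
    # that have no record yet. That is the checkpoint behaviour; documents
    # marked FAILED keep their error message instead of being retried silently.
    await rag.ainsert(docs)

    # Monitor progress: per-status counts from get_status_counts().
    print(await rag.get_processing_status())

asyncio.run(main())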