feat(lightrag): Add document status tracking and checkpoint support

- Add DocStatus enum and DocProcessingStatus class for document processing state management

- Implement JsonDocStatusStorage for persistent status storage

- Add document-level deduplication in batch processing

- Add checkpoint support in ainsert method for resumable document processing

- Add status query methods for monitoring processing progress

- Update LightRAG initialization to support document status tracking

Author: Magic_yuan
Date:   2024-12-28 00:11:25 +08:00
Parent: c022db4355
Commit: 650b8e38b7
4 changed files with 256 additions and 61 deletions
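
A minimal usage sketch of the flow this commit enables (hedged: only the checkpoint behaviour of insert and the new get_processing_status method come from this change; the constructor call is illustrative and omits model configuration):

    import asyncio
    from lightrag import LightRAG

    rag = LightRAG(working_dir="./rag_storage")  # illustrative setup

    # Re-running insert after a crash is safe: documents already recorded in the
    # doc_status store are skipped, so processing continues with the remaining ones.
    rag.insert(["first document ...", "second document ..."])

    # Counts per DocStatus value (pending / processing / processed / failed).
    print(asyncio.run(rag.get_processing_status()))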

View File

@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
-from typing import TypedDict, Union, Literal, Generic, TypeVar
+from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
+from enum import Enum
 import numpy as np
@@ -129,3 +130,42 @@ class BaseGraphStorage(StorageNameSpace):
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
+
+
+class DocStatus(str, Enum):
+    """Document processing status enum"""
+
+    PENDING = "pending"
+    PROCESSING = "processing"
+    PROCESSED = "processed"
+    FAILED = "failed"
+
+
+@dataclass
+class DocProcessingStatus:
+    """Document processing status data structure"""
+
+    content_summary: str  # First 100 chars of document content
+    content_length: int  # Total length of document
+    status: DocStatus  # Current processing status
+    created_at: str  # ISO format timestamp
+    updated_at: str  # ISO format timestamp
+    chunks_count: Optional[int] = None  # Number of chunks after splitting
+    error: Optional[str] = None  # Error message if failed
+    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+
+
+class DocStatusStorage(BaseKVStorage):
+    """Base class for document status storage"""
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        raise NotImplementedError
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        raise NotImplementedError
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        raise NotImplementedError
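
A hedged sketch of how the abstract interface above is meant to be consumed; any concrete DocStatusStorage works (the JsonDocStatusStorage added later in this commit is one), and the module path follows the package layout shown in the diffs:

    from lightrag.base import DocStatusStorage

    async def report(doc_status: DocStatusStorage) -> None:
        # Works against any concrete implementation of the interface.
        counts = await doc_status.get_status_counts()  # counts per status value
        failed = await doc_status.get_failed_docs()    # doc_id -> status record
        print(counts, list(failed))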

View File

@@ -1,7 +1,8 @@
 import asyncio
 import inspect
 import json
-import os, sys
+import os
+import sys
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -22,8 +23,10 @@ from ..base import BaseGraphStorage
 if sys.platform.startswith("win"):
     import asyncio.windows_events

     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


 class AGEQueryException(Exception):
     """Exception for the AGE queries."""
View File

@@ -4,7 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Dict
 from .llm import (
     gpt_4o_mini_complete,
@@ -32,12 +32,14 @@ from .base import (
     BaseVectorStorage,
     StorageNameSpace,
     QueryParam,
+    DocStatus,
 )
 from .storage import (
     JsonKVStorage,
     NanoVectorDBStorage,
     NetworkXStorage,
+    JsonDocStatusStorage,
 )

 # future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json

+    # Add new field for document status storage type
+    doc_status_storage: str = field(default="JsonDocStatusStorage")
+
     def __post_init__(self):
         log_file = os.path.join("lightrag.log")
         set_logger(log_file)
@@ -263,7 +268,15 @@
             )
         )

-    def _get_storage_class(self) -> Type[BaseGraphStorage]:
+        # Initialize document status storage
+        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
+        self.doc_status = self.doc_status_storage_cls(
+            namespace="doc_status",
+            global_config=asdict(self),
+            embedding_func=None,
+        )
+
+    def _get_storage_class(self) -> dict:
         return {
             # kv storage
             "JsonKVStorage": JsonKVStorage,
@@ -284,6 +297,7 @@
             "TiDBGraphStorage": TiDBGraphStorage,
             "GremlinStorage": GremlinStorage,
             # "ArangoDBStorage": ArangoDBStorage
+            "JsonDocStatusStorage": JsonDocStatusStorage,
         }

     def insert(self, string_or_strings):
@@ -291,71 +305,139 @@
         return loop.run_until_complete(self.ainsert(string_or_strings))

     async def ainsert(self, string_or_strings):
-        update_storage = False
-        try:
-            if isinstance(string_or_strings, str):
-                string_or_strings = [string_or_strings]
-
-            new_docs = {
-                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
-                for c in string_or_strings
-            }
-            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-            if not len(new_docs):
-                logger.warning("All docs are already in the storage")
-                return
-            update_storage = True
-            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
-
-            inserting_chunks = {}
-            for doc_key, doc in tqdm_async(
-                new_docs.items(), desc="Chunking documents", unit="doc"
-            ):
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
-                    }
-                    for dp in chunking_by_token_size(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
-                    )
-                }
-                inserting_chunks.update(chunks)
-            _add_chunk_keys = await self.text_chunks.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-            if not len(inserting_chunks):
-                logger.warning("All chunks are already in the storage")
-                return
-            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
-
-            await self.chunks_vdb.upsert(inserting_chunks)
-
-            logger.info("[Entity Extraction]...")
-            maybe_new_kg = await extract_entities(
-                inserting_chunks,
-                knowledge_graph_inst=self.chunk_entity_relation_graph,
-                entity_vdb=self.entities_vdb,
-                relationships_vdb=self.relationships_vdb,
-                global_config=asdict(self),
-            )
-            if maybe_new_kg is None:
-                logger.warning("No new entities and relationships found")
-                return
-            self.chunk_entity_relation_graph = maybe_new_kg
-
-            await self.full_docs.upsert(new_docs)
-            await self.text_chunks.upsert(inserting_chunks)
-        finally:
-            if update_storage:
-                await self._insert_done()
+        """Insert documents with checkpoint support
+
+        Args:
+            string_or_strings: Single document string or list of document strings
+        """
+        if isinstance(string_or_strings, str):
+            string_or_strings = [string_or_strings]
+
+        # 1. Remove duplicate contents from the list
+        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+
+        # 2. Generate document IDs and initial status
+        new_docs = {
+            compute_mdhash_id(content, prefix="doc-"): {
+                "content": content,
+                "content_summary": self._get_content_summary(content),
+                "content_length": len(content),
+                "status": DocStatus.PENDING,
+                "created_at": datetime.now().isoformat(),
+                "updated_at": datetime.now().isoformat(),
+            }
+            for content in unique_contents
+        }
+
+        # 3. Filter out already processed documents
+        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
+        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+
+        if not new_docs:
+            logger.info("All documents have been processed or are duplicates")
+            return
+
+        logger.info(f"Processing {len(new_docs)} new unique documents")
+
+        # Process documents in batches
+        batch_size = self.addon_params.get("insert_batch_size", 10)
+        for i in range(0, len(new_docs), batch_size):
+            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            for doc_id, doc in tqdm_async(
+                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+            ):
+                try:
+                    # Update status to processing
+                    doc_status = {
+                        "content_summary": doc["content_summary"],
+                        "content_length": doc["content_length"],
+                        "status": DocStatus.PROCESSING,
+                        "created_at": doc["created_at"],
+                        "updated_at": datetime.now().isoformat(),
+                    }
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    # Generate chunks from document
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_id,
+                        }
+                        for dp in chunking_by_token_size(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            tiktoken_model=self.tiktoken_model_name,
+                        )
+                    }
+
+                    # Update status with chunks information
+                    doc_status.update(
+                        {
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    try:
+                        # Store chunks in vector database
+                        await self.chunks_vdb.upsert(chunks)
+
+                        # Extract and store entities and relationships
+                        maybe_new_kg = await extract_entities(
+                            chunks,
+                            knowledge_graph_inst=self.chunk_entity_relation_graph,
+                            entity_vdb=self.entities_vdb,
+                            relationships_vdb=self.relationships_vdb,
+                            global_config=asdict(self),
+                        )
+                        if maybe_new_kg is None:
+                            raise Exception(
+                                "Failed to extract entities and relationships"
+                            )
+
+                        self.chunk_entity_relation_graph = maybe_new_kg
+
+                        # Store original document and chunks
+                        await self.full_docs.upsert(
+                            {doc_id: {"content": doc["content"]}}
+                        )
+                        await self.text_chunks.upsert(chunks)
+
+                        # Update status to processed
+                        doc_status.update(
+                            {
+                                "status": DocStatus.PROCESSED,
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+
+                    except Exception as e:
+                        # Mark as failed if any step fails
+                        doc_status.update(
+                            {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+                        raise e
+
+                except Exception as e:
+                    import traceback
+
+                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    logger.error(error_msg)
+                    continue
+
+                finally:
+                    # Ensure all indexes are updated after each document
+                    await self._insert_done()

     async def _insert_done(self):
         tasks = []
@@ -591,3 +673,26 @@ class LightRAG:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
+
+    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
+        """Get summary of document content
+
+        Args:
+            content: Original document content
+            max_length: Maximum length of summary
+
+        Returns:
+            Truncated content with ellipsis if needed
+        """
+        content = content.strip()
+        if len(content) <= max_length:
+            return content
+        return content[:max_length] + "..."
+
+    async def get_processing_status(self) -> Dict[str, int]:
+        """Get current document processing status counts
+
+        Returns:
+            Dict with counts for each status
+        """
+        return await self.doc_status.get_status_counts()
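
The batch size used by the insertion loop above is read from addon_params; a hedged configuration sketch (everything except addon_params and the insert call is illustrative):

    rag = LightRAG(
        working_dir="./rag_storage",
        addon_params={"insert_batch_size": 20},  # defaults to 10 when the key is absent
    )
    rag.insert(["doc one ...", "doc two ..."])  # previously recorded documents are skipped on re-run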

View File

@@ -3,7 +3,7 @@ import html
 import os
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
-from typing import Any, Union, cast
+from typing import Any, Union, cast, Dict
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB
@@ -19,6 +19,9 @@ from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
+    DocStatus,
+    DocProcessingStatus,
+    DocStatusStorage,
 )
@@ -315,3 +318,47 @@ class NetworkXStorage(BaseGraphStorage):
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
+
+
+@dataclass
+class JsonDocStatusStorage(DocStatusStorage):
+    """JSON implementation of document status storage"""
+
+    def __post_init__(self):
+        working_dir = self.global_config["working_dir"]
+        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        self._data = load_json(self._file_name) or {}
+        logger.info(f"Loaded document status storage with {len(self._data)} records")
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        """Return keys that don't exist in storage"""
+        return set([k for k in data if k not in self._data])
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        counts = {status: 0 for status in DocStatus}
+        for doc in self._data.values():
+            counts[doc["status"]] += 1
+        return counts
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+
+    async def index_done_callback(self):
+        """Save data to file after indexing"""
+        write_json(self._data, self._file_name)
+
+    async def upsert(self, data: dict[str, dict]):
+        """Update or insert document status
+
+        Args:
+            data: Dictionary of document IDs and their status data
+        """
+        self._data.update(data)
+        await self.index_done_callback()
+        return data
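
A short, hedged example exercising the new storage class directly (module paths follow the diffs above; the working directory is illustrative):

    import asyncio
    import os

    from lightrag.base import DocStatus
    from lightrag.storage import JsonDocStatusStorage

    async def demo():
        os.makedirs("./rag_storage", exist_ok=True)
        storage = JsonDocStatusStorage(
            namespace="doc_status",
            global_config={"working_dir": "./rag_storage"},
            embedding_func=None,
        )
        # upsert persists immediately via index_done_callback (write_json).
        await storage.upsert(
            {"doc-123": {"status": DocStatus.PENDING, "content_summary": "example"}}
        )
        print(await storage.get_status_counts())  # one document counted under "pending"

    asyncio.run(demo())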