diff --git a/lightrag/base.py b/lightrag/base.py
index f5a6e0c0..c3ba3e09 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
-from typing import TypedDict, Union, Literal, Generic, TypeVar
+from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
+from enum import Enum
 
 import numpy as np
 
@@ -129,3 +130,42 @@ class BaseGraphStorage(StorageNameSpace):
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
+
+
+class DocStatus(str, Enum):
+    """Document processing status enum"""
+
+    PENDING = "pending"
+    PROCESSING = "processing"
+    PROCESSED = "processed"
+    FAILED = "failed"
+
+
+@dataclass
+class DocProcessingStatus:
+    """Document processing status data structure"""
+
+    content_summary: str  # First 100 chars of document content
+    content_length: int  # Total length of document
+    status: DocStatus  # Current processing status
+    created_at: str  # ISO format timestamp
+    updated_at: str  # ISO format timestamp
+    chunks_count: Optional[int] = None  # Number of chunks after splitting
+    error: Optional[str] = None  # Error message if failed
+    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+
+
+class DocStatusStorage(BaseKVStorage):
+    """Base class for document status storage"""
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        raise NotImplementedError
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        raise NotImplementedError
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        raise NotImplementedError
diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py
index 2a97bc37..275f5775 100644
--- a/lightrag/kg/age_impl.py
+++ b/lightrag/kg/age_impl.py
@@ -1,7 +1,8 @@
 import asyncio
 import inspect
 import json
-import os, sys
+import os
+import sys
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -22,8 +23,10 @@ from ..base import BaseGraphStorage
 
 if sys.platform.startswith("win"):
     import asyncio.windows_events
+
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
+
 class AGEQueryException(Exception):
     """Exception for the AGE queries."""
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 51ac204d..992c43a4 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -4,7 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Dict
 
 from .llm import (
     gpt_4o_mini_complete,
@@ -32,12 +32,14 @@ from .base import (
     BaseVectorStorage,
     StorageNameSpace,
     QueryParam,
+    DocStatus,
 )
 
 from .storage import (
     JsonKVStorage,
     NanoVectorDBStorage,
     NetworkXStorage,
+    JsonDocStatusStorage,
 )
 
 # future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json
 
+    # Add new field for document status storage type
+    doc_status_storage: str = field(default="JsonDocStatusStorage")
+
     def __post_init__(self):
         log_file = os.path.join("lightrag.log")
         set_logger(log_file)
@@ -263,7 +268,15 @@
             )
         )
 
-    def _get_storage_class(self) -> Type[BaseGraphStorage]:
+        # Initialize document status storage
+        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
+        self.doc_status = self.doc_status_storage_cls(
+            namespace="doc_status",
+            global_config=asdict(self),
+            embedding_func=None,
+        )
+
+    def _get_storage_class(self) -> dict:
         return {
             # kv storage
             "JsonKVStorage": JsonKVStorage,
@@ -284,6 +297,7 @@
             "TiDBGraphStorage": TiDBGraphStorage,
             "GremlinStorage": GremlinStorage,
             # "ArangoDBStorage": ArangoDBStorage
+            "JsonDocStatusStorage": JsonDocStatusStorage,
         }
 
     def insert(self, string_or_strings):
@@ -291,71 +305,139 @@
         return loop.run_until_complete(self.ainsert(string_or_strings))
 
     async def ainsert(self, string_or_strings):
-        update_storage = False
-        try:
-            if isinstance(string_or_strings, str):
-                string_or_strings = [string_or_strings]
+        """Insert documents with checkpoint support
 
-            new_docs = {
-                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
-                for c in string_or_strings
+        Args:
+            string_or_strings: Single document string or list of document strings
+        """
+        if isinstance(string_or_strings, str):
+            string_or_strings = [string_or_strings]
+
+        # 1. Remove duplicate contents from the list
+        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+
+        # 2. Generate document IDs and initial status
+        new_docs = {
+            compute_mdhash_id(content, prefix="doc-"): {
+                "content": content,
+                "content_summary": self._get_content_summary(content),
+                "content_length": len(content),
+                "status": DocStatus.PENDING,
+                "created_at": datetime.now().isoformat(),
+                "updated_at": datetime.now().isoformat(),
             }
-            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-            if not len(new_docs):
-                logger.warning("All docs are already in the storage")
-                return
-            update_storage = True
-            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
+            for content in unique_contents
+        }
 
-            inserting_chunks = {}
-            for doc_key, doc in tqdm_async(
-                new_docs.items(), desc="Chunking documents", unit="doc"
+        # 3. Filter out already processed documents
+        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
+        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+
+        if not new_docs:
+            logger.info("All documents have been processed or are duplicates")
+            return
+
+        logger.info(f"Processing {len(new_docs)} new unique documents")
+
+        # Process documents in batches
+        batch_size = self.addon_params.get("insert_batch_size", 10)
+        for i in range(0, len(new_docs), batch_size):
+            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+
+            for doc_id, doc in tqdm_async(
+                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
             ):
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
+                try:
+                    # Update status to processing
+                    doc_status = {
+                        "content_summary": doc["content_summary"],
+                        "content_length": doc["content_length"],
+                        "status": DocStatus.PROCESSING,
+                        "created_at": doc["created_at"],
+                        "updated_at": datetime.now().isoformat(),
                     }
-                    for dp in chunking_by_token_size(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    # Generate chunks from document
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_id,
+                        }
+                        for dp in chunking_by_token_size(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            tiktoken_model=self.tiktoken_model_name,
+                        )
+                    }
+
+                    # Update status with chunks information
+                    doc_status.update(
+                        {
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
                     )
-                }
-                inserting_chunks.update(chunks)
-            _add_chunk_keys = await self.text_chunks.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-            if not len(inserting_chunks):
-                logger.warning("All chunks are already in the storage")
-                return
-            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
+                    await self.doc_status.upsert({doc_id: doc_status})
 
-            await self.chunks_vdb.upsert(inserting_chunks)
+                    try:
+                        # Store chunks in vector database
+                        await self.chunks_vdb.upsert(chunks)
 
-            logger.info("[Entity Extraction]...")
-            maybe_new_kg = await extract_entities(
-                inserting_chunks,
-                knowledge_graph_inst=self.chunk_entity_relation_graph,
-                entity_vdb=self.entities_vdb,
-                relationships_vdb=self.relationships_vdb,
-                global_config=asdict(self),
-            )
-            if maybe_new_kg is None:
-                logger.warning("No new entities and relationships found")
-                return
-            self.chunk_entity_relation_graph = maybe_new_kg
+                        # Extract and store entities and relationships
+                        maybe_new_kg = await extract_entities(
+                            chunks,
+                            knowledge_graph_inst=self.chunk_entity_relation_graph,
+                            entity_vdb=self.entities_vdb,
+                            relationships_vdb=self.relationships_vdb,
+                            global_config=asdict(self),
+                        )
 
-            await self.full_docs.upsert(new_docs)
-            await self.text_chunks.upsert(inserting_chunks)
-        finally:
-            if update_storage:
-                await self._insert_done()
+                        if maybe_new_kg is None:
+                            raise Exception(
+                                "Failed to extract entities and relationships"
+                            )
+
+                        self.chunk_entity_relation_graph = maybe_new_kg
+
+                        # Store original document and chunks
+                        await self.full_docs.upsert(
+                            {doc_id: {"content": doc["content"]}}
+                        )
+                        await self.text_chunks.upsert(chunks)
+
+                        # Update status to processed
+                        doc_status.update(
+                            {
+                                "status": DocStatus.PROCESSED,
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+
+                    except Exception as e:
+                        # Mark as failed if any step fails
+                        doc_status.update(
+                            {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+                        raise e
+
+                except Exception as e:
+                    import traceback
+
+                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    logger.error(error_msg)
+                    continue
+
+                finally:
+                    # Ensure all indexes are updated after each document
+                    await self._insert_done()
 
     async def _insert_done(self):
         tasks = []
@@ -591,3 +673,26 @@
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
+
+    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
+        """Get summary of document content
+
+        Args:
+            content: Original document content
+            max_length: Maximum length of summary
+
+        Returns:
+            Truncated content with ellipsis if needed
+        """
+        content = content.strip()
+        if len(content) <= max_length:
+            return content
+        return content[:max_length] + "..."
+
+    async def get_processing_status(self) -> Dict[str, int]:
+        """Get current document processing status counts
+
+        Returns:
+            Dict with counts for each status
+        """
+        return await self.doc_status.get_status_counts()
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 0c880bb7..0f65d09c 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -3,7 +3,7 @@ import html
 import os
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
-from typing import Any, Union, cast
+from typing import Any, Union, cast, Dict
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB
@@ -19,6 +19,9 @@ from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
+    DocStatus,
+    DocProcessingStatus,
+    DocStatusStorage,
 )
 
 
@@ -315,3 +318,47 @@ class NetworkXStorage(BaseGraphStorage):
 
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
+
+
+@dataclass
+class JsonDocStatusStorage(DocStatusStorage):
+    """JSON implementation of document status storage"""
+
+    def __post_init__(self):
+        working_dir = self.global_config["working_dir"]
+        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        self._data = load_json(self._file_name) or {}
+        logger.info(f"Loaded document status storage with {len(self._data)} records")
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        """Return keys that don't exist in storage"""
+        return set([k for k in data if k not in self._data])
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        counts = {status: 0 for status in DocStatus}
+        for doc in self._data.values():
+            counts[doc["status"]] += 1
+        return counts
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+
+    async def index_done_callback(self):
+        """Save data to file after indexing"""
+        write_json(self._data, self._file_name)
+
+    async def upsert(self, data: dict[str, dict]):
+        """Update or insert document status
+
+        Args:
+            data: Dictionary of document IDs and their status data
+        """
+        self._data.update(data)
+        await self.index_done_callback()
+        return data