Merge pull request #521 from magicyuan876/main

feat(lightrag): add document status tracking and checkpoint/resume support
This commit is contained in:
zrguo
2024-12-28 00:40:34 +08:00
committed by GitHub
5 changed files with 273 additions and 63 deletions

View File

@@ -278,10 +278,25 @@ class QueryParam:
### Batch Insert
```python
# Basic Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])

# Batch Insert with custom batch size configuration
rag = LightRAG(
    working_dir=WORKING_DIR,
    addon_params={
        "insert_batch_size": 20  # Process 20 documents per batch
    }
)
rag.insert(["TEXT1", "TEXT2", "TEXT3", ...])  # Documents will be processed in batches of 20
```
The `insert_batch_size` parameter in `addon_params` controls how many documents are processed in each batch during insertion (default: 10 if not specified). This is useful for:
- Managing memory usage with large document collections
- Optimizing processing speed
- Providing better progress tracking (see the status-check sketch below)
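
A minimal status-check sketch, assuming a plain synchronous script and the `get_processing_status()` helper added to `LightRAG` in this commit:

```python
import asyncio

rag.insert(["TEXT1", "TEXT2", "TEXT3"])

# get_processing_status() is async; drive it from synchronous code like this.
status_counts = asyncio.run(rag.get_processing_status())
print(status_counts)  # per-status document counts: pending / processing / processed / failed
```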
### Incremental Insert
```python
@@ -594,7 +609,7 @@ if __name__ == "__main__":
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
from enum import Enum

import numpy as np
@@ -129,3 +130,42 @@ class BaseGraphStorage(StorageNameSpace):
    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        raise NotImplementedError("Node embedding is not used in lightrag.")


class DocStatus(str, Enum):
    """Document processing status enum"""

    PENDING = "pending"
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"


@dataclass
class DocProcessingStatus:
    """Document processing status data structure"""

    content_summary: str  # First 100 chars of document content
    content_length: int  # Total length of document
    status: DocStatus  # Current processing status
    created_at: str  # ISO format timestamp
    updated_at: str  # ISO format timestamp
    chunks_count: Optional[int] = None  # Number of chunks after splitting
    error: Optional[str] = None  # Error message if failed
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata


class DocStatusStorage(BaseKVStorage):
    """Base class for document status storage"""

    async def get_status_counts(self) -> Dict[str, int]:
        """Get counts of documents in each status"""
        raise NotImplementedError

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all failed documents"""
        raise NotImplementedError

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all pending documents"""
        raise NotImplementedError
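
For illustration only, a minimal in-memory backend sketched against this interface; the class name is hypothetical, it assumes the `DocStatus`/`DocProcessingStatus`/`DocStatusStorage` definitions above, and it mirrors the `@dataclass`/`__post_init__` pattern that `JsonDocStatusStorage` uses later in this commit:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class InMemoryDocStatusStorage(DocStatusStorage):
    """Illustrative backend: keeps document status in a dict, persists nothing."""

    def __post_init__(self):
        self._data: dict[str, dict] = {}

    async def filter_keys(self, data: list[str]) -> set[str]:
        # IDs not seen before are the ones ainsert still needs to process
        return {k for k in data if k not in self._data}

    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)
        return data

    async def get_status_counts(self) -> Dict[str, int]:
        counts = {status: 0 for status in DocStatus}
        for doc in self._data.values():
            counts[doc["status"]] += 1
        return counts

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}

    async def index_done_callback(self):
        pass  # in-memory only, nothing to flush
```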

View File

@@ -1,7 +1,8 @@
import asyncio
import inspect
import json
import os
import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -22,8 +23,10 @@ from ..base import BaseGraphStorage
if sys.platform.startswith("win"):
    import asyncio.windows_events

    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


class AGEQueryException(Exception):
    """Exception for the AGE queries."""

View File

@@ -4,7 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Dict

from .llm import (
    gpt_4o_mini_complete,
@@ -32,12 +32,14 @@ from .base import (
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
    DocStatus,
)

from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
    JsonDocStatusStorage,
)

# future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    # Add new field for document status storage type
    doc_status_storage: str = field(default="JsonDocStatusStorage")

    def __post_init__(self):
        log_file = os.path.join("lightrag.log")
        set_logger(log_file)
@@ -263,7 +268,15 @@ class LightRAG:
            )
        )

        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
        self.doc_status = self.doc_status_storage_cls(
            namespace="doc_status",
            global_config=asdict(self),
            embedding_func=None,
        )

    def _get_storage_class(self) -> dict:
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
@@ -284,6 +297,7 @@ class LightRAG:
"TiDBGraphStorage": TiDBGraphStorage, "TiDBGraphStorage": TiDBGraphStorage,
"GremlinStorage": GremlinStorage, "GremlinStorage": GremlinStorage,
# "ArangoDBStorage": ArangoDBStorage # "ArangoDBStorage": ArangoDBStorage
"JsonDocStatusStorage": JsonDocStatusStorage,
} }
def insert(self, string_or_strings): def insert(self, string_or_strings):
@@ -291,71 +305,139 @@ class LightRAG:
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Insert documents with checkpoint support

        Args:
            string_or_strings: Single document string or list of document strings
        """
        if isinstance(string_or_strings, str):
            string_or_strings = [string_or_strings]

        # 1. Remove duplicate contents from the list
        unique_contents = list(set(doc.strip() for doc in string_or_strings))

        # 2. Generate document IDs and initial status
        new_docs = {
            compute_mdhash_id(content, prefix="doc-"): {
                "content": content,
                "content_summary": self._get_content_summary(content),
                "content_length": len(content),
                "status": DocStatus.PENDING,
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat(),
            }
            for content in unique_contents
        }

        # 3. Filter out already processed documents
        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

        if not new_docs:
            logger.info("All documents have been processed or are duplicates")
            return

        logger.info(f"Processing {len(new_docs)} new unique documents")

        # Process documents in batches
        batch_size = self.addon_params.get("insert_batch_size", 10)
        for i in range(0, len(new_docs), batch_size):
            batch_docs = dict(list(new_docs.items())[i : i + batch_size])

            for doc_id, doc in tqdm_async(
                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
            ):
                try:
                    # Update status to processing
                    doc_status = {
                        "content_summary": doc["content_summary"],
                        "content_length": doc["content_length"],
                        "status": DocStatus.PROCESSING,
                        "created_at": doc["created_at"],
                        "updated_at": datetime.now().isoformat(),
                    }
                    await self.doc_status.upsert({doc_id: doc_status})

                    # Generate chunks from document
                    chunks = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                        for dp in chunking_by_token_size(
                            doc["content"],
                            overlap_token_size=self.chunk_overlap_token_size,
                            max_token_size=self.chunk_token_size,
                            tiktoken_model=self.tiktoken_model_name,
                        )
                    }

                    # Update status with chunks information
                    doc_status.update(
                        {
                            "chunks_count": len(chunks),
                            "updated_at": datetime.now().isoformat(),
                        }
                    )
                    await self.doc_status.upsert({doc_id: doc_status})

                    try:
                        # Store chunks in vector database
                        await self.chunks_vdb.upsert(chunks)

                        # Extract and store entities and relationships
                        maybe_new_kg = await extract_entities(
                            chunks,
                            knowledge_graph_inst=self.chunk_entity_relation_graph,
                            entity_vdb=self.entities_vdb,
                            relationships_vdb=self.relationships_vdb,
                            global_config=asdict(self),
                        )

                        if maybe_new_kg is None:
                            raise Exception(
                                "Failed to extract entities and relationships"
                            )

                        self.chunk_entity_relation_graph = maybe_new_kg

                        # Store original document and chunks
                        await self.full_docs.upsert(
                            {doc_id: {"content": doc["content"]}}
                        )
                        await self.text_chunks.upsert(chunks)

                        # Update status to processed
                        doc_status.update(
                            {
                                "status": DocStatus.PROCESSED,
                                "updated_at": datetime.now().isoformat(),
                            }
                        )
                        await self.doc_status.upsert({doc_id: doc_status})

                    except Exception as e:
                        # Mark as failed if any step fails
                        doc_status.update(
                            {
                                "status": DocStatus.FAILED,
                                "error": str(e),
                                "updated_at": datetime.now().isoformat(),
                            }
                        )
                        await self.doc_status.upsert({doc_id: doc_status})
                        raise e

                except Exception as e:
                    import traceback

                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
                    logger.error(error_msg)
                    continue

                finally:
                    # Ensure all indexes are updated after each document
                    await self._insert_done()

    async def _insert_done(self):
        tasks = []
@@ -591,3 +673,26 @@ class LightRAG:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
        """Get summary of document content

        Args:
            content: Original document content
            max_length: Maximum length of summary

        Returns:
            Truncated content with ellipsis if needed
        """
        content = content.strip()
        if len(content) <= max_length:
            return content
        return content[:max_length] + "..."

    async def get_processing_status(self) -> Dict[str, int]:
        """Get current document processing status counts

        Returns:
            Dict with counts for each status
        """
        return await self.doc_status.get_status_counts()
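
A hedged usage sketch of the checkpoint behaviour (not part of the diff): because `ainsert` filters incoming document IDs through `doc_status.filter_keys`, a re-run after an interruption skips every document whose status was already recorded (anything that at least reached `PROCESSING`) and only re-processes documents that never started:

```python
docs = ["TEXT1", "TEXT2", "TEXT3"]

# First attempt: suppose the process is stopped partway through a batch.
try:
    rag.insert(docs)
except KeyboardInterrupt:
    pass  # illustrative interruption

# Second attempt: IDs already tracked in doc_status (persisted to
# kv_store_doc_status.json by JsonDocStatusStorage) are filtered out,
# so only documents that never began processing are handled again.
rag.insert(docs)
```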

View File

@@ -3,7 +3,7 @@ import html
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
from typing import Any, Union, cast, Dict
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
@@ -19,6 +19,9 @@ from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    DocStatus,
    DocProcessingStatus,
    DocStatusStorage,
)
@@ -315,3 +318,47 @@ class NetworkXStorage(BaseGraphStorage):
        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids


@dataclass
class JsonDocStatusStorage(DocStatusStorage):
    """JSON implementation of document status storage"""

    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info(f"Loaded document status storage with {len(self._data)} records")

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return keys that don't exist in storage"""
        return set([k for k in data if k not in self._data])

    async def get_status_counts(self) -> Dict[str, int]:
        """Get counts of documents in each status"""
        counts = {status: 0 for status in DocStatus}
        for doc in self._data.values():
            counts[doc["status"]] += 1
        return counts

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all failed documents"""
        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all pending documents"""
        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}

    async def index_done_callback(self):
        """Save data to file after indexing"""
        write_json(self._data, self._file_name)

    async def upsert(self, data: dict[str, dict]):
        """Update or insert document status

        Args:
            data: Dictionary of document IDs and their status data
        """
        self._data.update(data)
        await self.index_done_callback()
        return data
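
A small usage sketch, assuming `rag.doc_status` is the `JsonDocStatusStorage` instance created in `LightRAG.__post_init__` and that status entries are the plain dicts written by `ainsert`:

```python
import asyncio


async def report_failed(rag):
    # Failed documents keep the error message recorded by ainsert.
    failed = await rag.doc_status.get_failed_docs()
    for doc_id, status in failed.items():
        print(doc_id, status.get("error"), status.get("content_summary"))

    # Aggregate view across all statuses.
    print(await rag.doc_status.get_status_counts())


asyncio.run(report_failed(rag))
```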