add support of providing ids for documents insert
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import configparser
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
@@ -37,11 +37,11 @@ from .utils import (
|
||||
always_get_an_event_loop,
|
||||
compute_mdhash_id,
|
||||
convert_response_to_json,
|
||||
encode_string_by_tiktoken,
|
||||
lazy_external_import,
|
||||
limit_async_func_call,
|
||||
logger,
|
||||
set_logger,
|
||||
encode_string_by_tiktoken,
|
||||
)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
@@ -461,6 +461,7 @@ class LightRAG:
|
||||
input: str | list[str],
|
||||
split_by_character: str | None = None,
|
||||
split_by_character_only: bool = False,
|
||||
ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Sync Insert documents with checkpoint support
|
||||
|
||||
@@ -469,10 +470,11 @@ class LightRAG:
|
||||
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
||||
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
||||
split_by_character is None, this parameter is ignored.
|
||||
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
loop.run_until_complete(
|
||||
self.ainsert(input, split_by_character, split_by_character_only)
|
||||
self.ainsert(input, split_by_character, split_by_character_only, ids)
|
||||
)
|
||||
|
||||
async def ainsert(
|
||||
@@ -480,6 +482,7 @@ class LightRAG:
|
||||
input: str | list[str],
|
||||
split_by_character: str | None = None,
|
||||
split_by_character_only: bool = False,
|
||||
ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Async Insert documents with checkpoint support
|
||||
|
||||
@@ -488,8 +491,9 @@ class LightRAG:
|
||||
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
||||
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
||||
split_by_character is None, this parameter is ignored.
|
||||
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
|
||||
"""
|
||||
await self.apipeline_enqueue_documents(input)
|
||||
await self.apipeline_enqueue_documents(input, ids)
|
||||
await self.apipeline_process_enqueue_documents(
|
||||
split_by_character, split_by_character_only
|
||||
)
|
||||
@@ -546,24 +550,51 @@ class LightRAG:
|
||||
if update_storage:
|
||||
await self._insert_done()
|
||||
|
||||
async def apipeline_enqueue_documents(self, input: str | list[str]) -> None:
|
||||
async def apipeline_enqueue_documents(
|
||||
self, input: str | list[str], ids: list[str] | None
|
||||
) -> None:
|
||||
"""
|
||||
Pipeline for Processing Documents
|
||||
|
||||
1. Remove duplicate contents from the list
|
||||
2. Generate document IDs and initial status
|
||||
3. Filter out already processed documents
|
||||
4. Enqueue document in status
|
||||
1. Validate ids if provided or generate MD5 hash IDs
|
||||
2. Remove duplicate contents
|
||||
3. Generate document initial status
|
||||
4. Filter out already processed documents
|
||||
5. Enqueue document in status
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
# 1. Remove duplicate contents from the list
|
||||
unique_contents = list(set(doc.strip() for doc in input))
|
||||
# 1. Validate ids if provided or generate MD5 hash IDs
|
||||
if ids is not None:
|
||||
# Check if the number of IDs matches the number of documents
|
||||
if len(ids) != len(input):
|
||||
raise ValueError("Number of IDs must match the number of documents")
|
||||
|
||||
# 2. Generate document IDs and initial status
|
||||
# Check if IDs are unique
|
||||
if len(ids) != len(set(ids)):
|
||||
raise ValueError("IDs must be unique")
|
||||
|
||||
# Generate contents dict of IDs provided by user and documents
|
||||
contents = {id_: doc.strip() for id_, doc in zip(ids, input)}
|
||||
else:
|
||||
# Generate contents dict of MD5 hash IDs and documents
|
||||
contents = {
|
||||
compute_mdhash_id(doc.strip(), prefix="doc-"): doc.strip()
|
||||
for doc in input
|
||||
}
|
||||
|
||||
# 2. Remove duplicate contents
|
||||
unique_contents = {
|
||||
id_: content
|
||||
for content, id_ in {
|
||||
content: id_ for id_, content in contents.items()
|
||||
}.items()
|
||||
}
|
||||
|
||||
# 3. Generate document initial status
|
||||
new_docs: dict[str, Any] = {
|
||||
compute_mdhash_id(content, prefix="doc-"): {
|
||||
id_: {
|
||||
"content": content,
|
||||
"content_summary": self._get_content_summary(content),
|
||||
"content_length": len(content),
|
||||
@@ -571,10 +602,10 @@ class LightRAG:
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
}
|
||||
for content in unique_contents
|
||||
for id_, content in unique_contents.items()
|
||||
}
|
||||
|
||||
# 3. Filter out already processed documents
|
||||
# 4. Filter out already processed documents
|
||||
# Get docs ids
|
||||
all_new_doc_ids = set(new_docs.keys())
|
||||
# Exclude IDs of documents that are already in progress
|
||||
@@ -586,7 +617,7 @@ class LightRAG:
|
||||
logger.info("No new unique documents were found.")
|
||||
return
|
||||
|
||||
# 4. Store status document
|
||||
# 5. Store status document
|
||||
await self.doc_status.upsert(new_docs)
|
||||
logger.info(f"Stored {len(new_docs)} new unique documents")
|
||||
|
||||
@@ -643,8 +674,6 @@ class LightRAG:
|
||||
# 4. iterate over batch
|
||||
for doc_id_processing_status in docs_batch:
|
||||
doc_id, status_doc = doc_id_processing_status
|
||||
# Update status in processing
|
||||
doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
|
||||
# Generate chunks from document
|
||||
chunks: dict[str, Any] = {
|
||||
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
||||
@@ -664,7 +693,7 @@ class LightRAG:
|
||||
tasks = [
|
||||
self.doc_status.upsert(
|
||||
{
|
||||
doc_status_id: {
|
||||
doc_id: {
|
||||
"status": DocStatus.PROCESSING,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"content": status_doc.content,
|
||||
@@ -685,7 +714,7 @@ class LightRAG:
|
||||
await asyncio.gather(*tasks)
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_status_id: {
|
||||
doc_id: {
|
||||
"status": DocStatus.PROCESSED,
|
||||
"chunks_count": len(chunks),
|
||||
"content": status_doc.content,
|
||||
@@ -700,7 +729,7 @@ class LightRAG:
|
||||
logger.error(f"Failed to process document {doc_id}: {str(e)}")
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_status_id: {
|
||||
doc_id: {
|
||||
"status": DocStatus.FAILED,
|
||||
"error": str(e),
|
||||
"content": status_doc.content,
|
||||
|
Reference in New Issue
Block a user