add support for providing ids for document insert

commit d462ace978
parent 1c3a4944d3
Author: PiochU19
Date:   2025-02-20 00:26:35 +01:00


@@ -1,8 +1,8 @@
 from __future__ import annotations

 import asyncio
-import os
 import configparser
+import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -37,11 +37,11 @@ from .utils import (
     always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    encode_string_by_tiktoken,
     lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
-    encode_string_by_tiktoken,
 )

 config = configparser.ConfigParser()
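
For reference, compute_mdhash_id (imported above from .utils) is the helper the rest of this diff relies on. A minimal sketch of what it presumably does; the real implementation lives in lightrag/utils.py:

    from hashlib import md5

    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # Deterministic: identical content always yields the same "doc-"/"chunk-" key.
        return prefix + md5(content.encode()).hexdigest()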
@@ -461,6 +461,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Sync Insert documents with checkpoint support
@@ -469,10 +470,11 @@ class LightRAG:
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             chunk_token_size, it will be split again by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
         loop = always_get_an_event_loop()
         loop.run_until_complete(
-            self.ainsert(input, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only, ids)
         )
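
A minimal usage sketch of the new keyword; the constructor arguments and document strings here are illustrative, not part of the commit:

    rag = LightRAG(working_dir="./rag_storage")  # hypothetical setup

    # One unique ID per document; omit ids to fall back to MD5-hash IDs.
    rag.insert(
        ["Kubernetes networking notes.", "Postgres tuning notes."],
        ids=["doc-k8s-net", "doc-pg-tuning"],
    )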
     async def ainsert(
@@ -480,6 +482,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Async Insert documents with checkpoint support
@@ -488,8 +491,9 @@ class LightRAG:
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             chunk_token_size, it will be split again by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
-        await self.apipeline_enqueue_documents(input)
+        await self.apipeline_enqueue_documents(input, ids)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
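
The async path takes the same call shape; a sketch, assuming an already configured rag instance:

    import asyncio

    async def main() -> None:
        # ids is forwarded to apipeline_enqueue_documents before processing starts.
        await rag.ainsert(["some document text"], ids=["doc-custom-1"])

    asyncio.run(main())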
@@ -546,24 +550,51 @@ class LightRAG:
         if update_storage:
             await self._insert_done()

-    async def apipeline_enqueue_documents(self, input: str | list[str]) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None
+    ) -> None:
         """
         Pipeline for Processing Documents
-        1. Remove duplicate contents from the list
-        2. Generate document IDs and initial status
-        3. Filter out already processed documents
-        4. Enqueue document in status
+        1. Validate ids if provided or generate MD5 hash IDs
+        2. Remove duplicate contents
+        3. Generate document initial status
+        4. Filter out already processed documents
+        5. Enqueue document in status
         """
         if isinstance(input, str):
             input = [input]

-        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in input))
+        # 1. Validate ids if provided or generate MD5 hash IDs
+        if ids is not None:
+            # Check if the number of IDs matches the number of documents
+            if len(ids) != len(input):
+                raise ValueError("Number of IDs must match the number of documents")
+
+            # Check if IDs are unique
+            if len(ids) != len(set(ids)):
+                raise ValueError("IDs must be unique")
+
+            # Generate contents dict of IDs provided by user and documents
+            contents = {id_: doc.strip() for id_, doc in zip(ids, input)}
+        else:
+            # Generate contents dict of MD5 hash IDs and documents
+            contents = {
+                compute_mdhash_id(doc.strip(), prefix="doc-"): doc.strip()
+                for doc in input
+            }
+
+        # 2. Remove duplicate contents
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }

-        # 2. Generate document IDs and initial status
+        # 3. Generate document initial status
         new_docs: dict[str, Any] = {
-            compute_mdhash_id(content, prefix="doc-"): {
+            id_: {
                 "content": content,
                 "content_summary": self._get_content_summary(content),
                 "content_length": len(content),
@@ -571,10 +602,10 @@ class LightRAG:
                 "created_at": datetime.now().isoformat(),
                 "updated_at": datetime.now().isoformat(),
             }
-            for content in unique_contents
+            for id_, content in unique_contents.items()
         }

-        # 3. Filter out already processed documents
+        # 4. Filter out already processed documents
         # Get docs ids
         all_new_doc_ids = set(new_docs.keys())
         # Exclude IDs of documents that are already in progress
@@ -586,7 +617,7 @@ class LightRAG:
             logger.info("No new unique documents were found.")
             return

-        # 4. Store status document
+        # 5. Store status document
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")
@@ -643,8 +674,6 @@ class LightRAG:
             # 4. iterate over batch
             for doc_id_processing_status in docs_batch:
                 doc_id, status_doc = doc_id_processing_status
-                # Update status in processing
-                doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                 # Generate chunks from document
                 chunks: dict[str, Any] = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
@@ -664,7 +693,7 @@ class LightRAG:
                 tasks = [
                     self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.PROCESSING,
                                 "updated_at": datetime.now().isoformat(),
                                 "content": status_doc.content,
@@ -685,7 +714,7 @@ class LightRAG:
                 await asyncio.gather(*tasks)
                 await self.doc_status.upsert(
                     {
-                        doc_status_id: {
+                        doc_id: {
                             "status": DocStatus.PROCESSED,
                             "chunks_count": len(chunks),
                             "content": status_doc.content,
@@ -700,7 +729,7 @@ class LightRAG:
                 logger.error(f"Failed to process document {doc_id}: {str(e)}")
                 await self.doc_status.upsert(
                     {
-                        doc_status_id: {
+                        doc_id: {
                             "status": DocStatus.FAILED,
                             "error": str(e),
                             "content": status_doc.content,