From d462ace97836dd684ce03a881497820c3111855b Mon Sep 17 00:00:00 2001
From: PiochU19 <792954018@wp.pl>
Date: Thu, 20 Feb 2025 00:26:35 +0100
Subject: [PATCH] add support for providing ids on document insert

---
 lightrag/lightrag.py | 71 +++++++++++++++++++++++++++++++-------------
 1 file changed, 50 insertions(+), 21 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 1a8dcf5c..33c29a20 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import asyncio
-import os
 import configparser
+import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -37,11 +37,11 @@ from .utils import (
     always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    encode_string_by_tiktoken,
     lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
-    encode_string_by_tiktoken,
 )
 
 config = configparser.ConfigParser()
@@ -461,6 +461,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Sync Insert documents with checkpoint support
 
@@ -469,10 +470,11 @@ class LightRAG:
         Args:
             input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character; if a chunk is longer than chunk_token_size, it will be split again by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only; when split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs; if not provided, MD5 hash IDs will be generated
         """
         loop = always_get_an_event_loop()
         loop.run_until_complete(
-            self.ainsert(input, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only, ids)
         )
 
     async def ainsert(
@@ -480,6 +482,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Async Insert documents with checkpoint support
 
@@ -488,8 +491,9 @@ class LightRAG:
             split_by_character: if split_by_character is not None, split the string by character; if a chunk is longer than chunk_token_size, it will be split again by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only; when split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs; if not provided, MD5 hash IDs will be generated
         """
-        await self.apipeline_enqueue_documents(input)
+        await self.apipeline_enqueue_documents(input, ids)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
@@ -546,24 +550,51 @@ class LightRAG:
         if update_storage:
             await self._insert_done()
 
-    async def apipeline_enqueue_documents(self, input: str | list[str]) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None
+    ) -> None:
         """
         Pipeline for Processing Documents
 
-        1. Remove duplicate contents from the list
-        2. Generate document IDs and initial status
-        3. Filter out already processed documents
-        4. Enqueue document in status
+        1. Validate ids if provided or generate MD5 hash IDs
+        2. Remove duplicate contents
+        3. Generate document initial status
+        4. Filter out already processed documents
+        5. Enqueue document in status
         """
         if isinstance(input, str):
             input = [input]
 
-        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in input))
+        # 1. Validate ids if provided or generate MD5 hash IDs
+        if ids is not None:
+            # Check if the number of IDs matches the number of documents
+            if len(ids) != len(input):
+                raise ValueError("Number of IDs must match the number of documents")
 
-        # 2. Generate document IDs and initial status
+            # Check if IDs are unique
+            if len(ids) != len(set(ids)):
+                raise ValueError("IDs must be unique")
+
+            # Generate contents dict of IDs provided by user and documents
+            contents = {id_: doc.strip() for id_, doc in zip(ids, input)}
+        else:
+            # Generate contents dict of MD5 hash IDs and documents
+            contents = {
+                compute_mdhash_id(doc.strip(), prefix="doc-"): doc.strip()
+                for doc in input
+            }
+
+        # 2. Remove duplicate contents
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }
+
+        # 3. Generate document initial status
         new_docs: dict[str, Any] = {
-            compute_mdhash_id(content, prefix="doc-"): {
+            id_: {
                 "content": content,
                 "content_summary": self._get_content_summary(content),
                 "content_length": len(content),
@@ -571,10 +602,10 @@ class LightRAG:
                 "created_at": datetime.now().isoformat(),
                 "updated_at": datetime.now().isoformat(),
             }
-            for content in unique_contents
+            for id_, content in unique_contents.items()
         }
 
-        # 3. Filter out already processed documents
+        # 4. Filter out already processed documents
         # Get docs ids
         all_new_doc_ids = set(new_docs.keys())
         # Exclude IDs of documents that are already in progress
@@ -586,7 +617,7 @@ class LightRAG:
             logger.info("No new unique documents were found.")
             return
 
-        # 4. Store status document
+        # 5. Store status document
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")
@@ -643,8 +674,6 @@ class LightRAG:
             # 4. iterate over batch
             for doc_id_processing_status in docs_batch:
                 doc_id, status_doc = doc_id_processing_status
-                # Update status in processing
-                doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                 # Generate chunks from document
                 chunks: dict[str, Any] = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
@@ -664,7 +693,7 @@ class LightRAG:
                 tasks = [
                     self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.PROCESSING,
                                 "updated_at": datetime.now().isoformat(),
                                 "content": status_doc.content,
@@ -685,7 +714,7 @@ class LightRAG:
                     await asyncio.gather(*tasks)
                     await self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.PROCESSED,
                                 "chunks_count": len(chunks),
                                 "content": status_doc.content,
@@ -700,7 +729,7 @@ class LightRAG:
                     logger.error(f"Failed to process document {doc_id}: {str(e)}")
                     await self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.FAILED,
                                 "error": str(e),
                                 "content": status_doc.content,
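
A minimal usage sketch of the new parameter, for review context. The setup below is illustrative only (a real LightRAG instance also needs its usual working_dir, LLM, and embedding configuration, none of which this patch touches); the ids keyword and the ValueError cases come directly from the diff above:

    from lightrag import LightRAG

    # Illustrative setup; LLM/embedding configuration omitted for brevity.
    rag = LightRAG(working_dir="./rag_storage")

    docs = [
        "LightRAG combines knowledge graphs with vector retrieval.",
        "Caller-supplied document IDs are now supported on insert.",
    ]

    # New in this patch: one unique ID per document; these become the
    # doc_status keys instead of generated "doc-<md5>" hashes.
    rag.insert(docs, ids=["doc-guide-001", "doc-guide-002"])

    # Omitting ids keeps the old behavior: MD5 hash IDs are generated.
    rag.insert("Some other document")

    # Validation added in apipeline_enqueue_documents:
    try:
        rag.insert(docs, ids=["only-one-id"])  # length mismatch
    except ValueError as exc:
        print(exc)  # Number of IDs must match the number of documents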
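
One subtlety worth flagging in review: the double dict inversion in step 2 keeps exactly one ID per unique stripped content, and when two caller-supplied IDs carry identical content, the ID appearing later in the list wins. A standalone sketch of just that logic (names are illustrative):

    # Reproduction of the step-2 dedup logic outside the class.
    contents = {
        "id-a": "same text",
        "id-b": "same text",   # duplicate content under a later ID
        "id-c": "other text",
    }

    # First inversion: content -> id. Later entries overwrite earlier
    # ones, so "same text" ends up mapped to "id-b".
    content_to_id = {content: id_ for id_, content in contents.items()}

    # Second inversion: back to id -> content, one entry per unique content.
    unique_contents = {id_: content for content, id_ in content_to_id.items()}

    print(unique_contents)  # {'id-b': 'same text', 'id-c': 'other text'}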