From d4f6dcfd54963183e10f43a7a311425cbbb4f5bd Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 1 Mar 2025 12:41:30 +0800
Subject: [PATCH] Improve multi-process data synchronization and persistence in storage implementations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Remove _get_client() or _get_graph() from index_done_callback
• Add return value for index_done_callback
---
Review notes (text between "---" and the first "diff --git" is ignored by git am):
* nano_vector_db_impl.py: `self._get_client.save()` was changed to
  `self._client.save()`. `_get_client` is an async method (the removed line
  awaited it), so looking up `.save` on it raises AttributeError; the
  surrounding `except Exception` would log that and return False, meaning the
  vector store would never be written. The fix mirrors the networkx hunk,
  which writes `self._graph` directly.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Verify the context lines against the target tree, or
  apply with `git apply --ignore-whitespace`.
* The post-image blob id on the nano_vector_db_impl.py `index` line predates
  the fix above and no longer matches; do not rely on it for a 3-way apply.
* Both new log messages call `os.getpid()` and this patch adds no `import os`;
  confirm that both modules already import it.

 lightrag/kg/nano_vector_db_impl.py | 12 +++++++-----
 lightrag/kg/networkx_impl.py       | 12 +++++++-----
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index e0047a21..c17189c6 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -66,7 +66,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             # Check if data needs to be reloaded
             if (is_multiprocess and self.storage_updated.value) or \
                (not is_multiprocess and self.storage_updated):
-                logger.info(f"Reloading storage for {self.namespace} due to update by another process")
+                logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process")
                 # Reload data
                 self._client = NanoVectorDB(
                     self.embedding_func.embedding_dim,
@@ -199,7 +199,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
 
-    async def index_done_callback(self) -> None:
+    async def index_done_callback(self) -> bool:
+        """Save data to disk"""
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
@@ -213,14 +214,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
             return False  # Return error
 
         # Acquire lock and perform persistence
-        client = await self._get_client()
         async with self._storage_lock:
             try:
                 # Save data to disk
-                client.save()
+                self._client.save()
                 # Notify other processes that data has been updated
                 await set_all_update_flags(self.namespace)
-                # Reset own update flag to avoid self-notification
+                # Reset own update flag to avoid self-reloading
                 if is_multiprocess:
                     self.storage_updated.value = False
                 else:
@@ -229,3 +229,5 @@ class NanoVectorDBStorage(BaseVectorStorage):
             except Exception as e:
                 logger.error(f"Error saving data for {self.namespace}: {e}")
                 return False  # Return error
+
+        return True  # Return success
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index 37db8469..2e61e6b3 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -110,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage):
             # Check if data needs to be reloaded
             if (is_multiprocess and self.storage_updated.value) or \
                (not is_multiprocess and self.storage_updated):
-                logger.info(f"Reloading graph for {self.namespace} due to update by another process")
+                logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process")
                 # Reload data
                 self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
                 # Reset update flag
@@ -329,7 +329,8 @@ class NetworkXStorage(BaseGraphStorage):
         )
         return result
 
-    async def index_done_callback(self) -> None:
+    async def index_done_callback(self) -> bool:
+        """Save data to disk"""
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
@@ -340,14 +341,13 @@ class NetworkXStorage(BaseGraphStorage):
             return False  # Return error
 
         # Acquire lock and perform persistence
-        graph = await self._get_graph()
         async with self._storage_lock:
             try:
                 # Save data to disk
-                NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
+                NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
                 # Notify other processes that data has been updated
                 await set_all_update_flags(self.namespace)
-                # Reset own update flag to avoid self-notification
+                # Reset own update flag to avoid self-reloading
                 if is_multiprocess:
                     self.storage_updated.value = False
                 else:
@@ -356,3 +356,5 @@ class NetworkXStorage(BaseGraphStorage):
             except Exception as e:
                 logger.error(f"Error saving graph for {self.namespace}: {e}")
                 return False  # Return error
+
+        return True