diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index b5249540..c33059ad 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -16,6 +16,9 @@ from .shared_storage import ( get_namespace_data, get_storage_lock, get_data_init_lock, + get_update_flag, + set_all_update_flags, + clear_all_update_flags, try_initialize_namespace, ) @@ -29,10 +32,13 @@ class JsonDocStatusStorage(DocStatusStorage): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._data = None + self._storage_lock = None + self.storage_updated = None async def initialize(self): """Initialize storage data""" self._storage_lock = get_storage_lock() + self.storage_updated = await get_update_flag(self.namespace) async with get_data_init_lock(): # check need_init must before get_namespace_data need_init = await try_initialize_namespace(self.namespace) @@ -89,11 +95,13 @@ class JsonDocStatusStorage(DocStatusStorage): async def index_done_callback(self) -> None: async with self._storage_lock: - data_dict = ( - dict(self._data) if hasattr(self._data, "_getvalue") else self._data - ) - logger.info(f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}") - write_json(data_dict, self._file_name) + if self.storage_updated: + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) + logger.info(f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}") + write_json(data_dict, self._file_name) + await clear_all_update_flags(self.namespace) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: @@ -101,6 +109,7 @@ class JsonDocStatusStorage(DocStatusStorage): logger.info(f"Inserting {len(data)} to {self.namespace}") async with self._storage_lock: self._data.update(data) + await set_all_update_flags(self.namespace) await 
self.index_done_callback() @@ -112,9 +121,12 @@ class JsonDocStatusStorage(DocStatusStorage): async with self._storage_lock: for doc_id in doc_ids: self._data.pop(doc_id, None) + await set_all_update_flags(self.namespace) await self.index_done_callback() async def drop(self) -> None: """Drop the storage""" async with self._storage_lock: self._data.clear() + await set_all_update_flags(self.namespace) + await self.index_done_callback() diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 81439151..c69b53ec 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -14,6 +14,9 @@ from .shared_storage import ( get_namespace_data, get_storage_lock, get_data_init_lock, + get_update_flag, + set_all_update_flags, + clear_all_update_flags, try_initialize_namespace, ) @@ -25,10 +28,13 @@ class JsonKVStorage(BaseKVStorage): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._data = None + self._storage_lock = None + self.storage_updated = None async def initialize(self): """Initialize storage data""" self._storage_lock = get_storage_lock() + self.storage_updated = await get_update_flag(self.namespace) async with get_data_init_lock(): # check need_init must before get_namespace_data need_init = await try_initialize_namespace(self.namespace) @@ -51,21 +57,24 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self) -> None: async with self._storage_lock: - data_dict = ( - dict(self._data) if hasattr(self._data, "_getvalue") else self._data - ) - - # Calculate data count based on namespace - if self.namespace.endswith("cache"): - # # For cache namespaces, sum the cache entries across all cache types - data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values() - if isinstance(first_level_dict, dict)) - else: - # For non-cache namespaces, use the original count method - data_count = len(data_dict) - - 
 logger.info(f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}") - write_json(data_dict, self._file_name) + if self.storage_updated: + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) + + # Calculate data count based on namespace + if self.namespace.endswith("cache"): + # # For cache namespaces, sum the cache entries across all cache types + data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values() + if isinstance(first_level_dict, dict)) + else: + # For non-cache namespaces, use the original count method + data_count = len(data_dict) + + logger.info(f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}") + write_json(data_dict, self._file_name) + await clear_all_update_flags(self.namespace) + async def get_all(self) -> dict[str, Any]: """Get all data from storage @@ -101,9 +110,11 @@ class JsonKVStorage(BaseKVStorage): logger.info(f"Inserting {len(data)} to {self.namespace}") async with self._storage_lock: self._data.update(data) + await set_all_update_flags(self.namespace) async def delete(self, ids: list[str]) -> None: async with self._storage_lock: for doc_id in ids: self._data.pop(doc_id, None) + await set_all_update_flags(self.namespace) await self.index_done_callback() diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index e3c25d34..9ce04d23 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -344,6 +344,21 @@ async def set_all_update_flags(namespace: str): else: _update_flags[namespace][i] = True +async def clear_all_update_flags(namespace: str): + """Clear all update flags of namespace, indicating no worker has a pending data update""" + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + async with get_internal_lock(): + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not 
found in update flags") + # Update flags for both modes + for i in range(len(_update_flags[namespace])): + if is_multiprocess: + _update_flags[namespace][i].value = False + else: + _update_flags[namespace][i] = False async def get_all_update_flags_status() -> Dict[str, list]: """