Fix shared storage update status handling for in-memory storage

yangdx
2025-03-25 10:48:15 +08:00
parent 91f32dc561
commit 15e060f854
4 changed files with 39 additions and 74 deletions
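The pattern is identical across all three in-memory storages: update flags used to be Manager values (with a .value attribute) in multiprocess mode but plain booleans in single-process mode, forcing every consumer to branch on the mode. After this commit the flags expose .value in both modes, so the branches collapse. In outline, using the FAISS hunks below:

    # Before: every flag check branched on the process mode
    if (is_multiprocess and self.storage_updated.value) or (
        not is_multiprocess and self.storage_updated
    ):
        ...  # reload

    # After: flags uniformly expose .value, so the check is one line
    if self.storage_updated.value:
        ...  # reload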

View File

@@ -19,7 +19,6 @@ from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
-is_multiprocess,
)
@@ -73,9 +72,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if storage was updated by another process
-if (is_multiprocess and self.storage_updated.value) or (
-    not is_multiprocess and self.storage_updated
-):
+if self.storage_updated.value:
logger.info(
f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
)
@@ -83,10 +80,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
-if is_multiprocess:
-    self.storage_updated.value = False
-else:
-    self.storage_updated = False
+self.storage_updated.value = False
return self._index
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -345,7 +339,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self) -> None:
async with self._storage_lock:
# Check if storage was updated by another process
-if is_multiprocess and self.storage_updated.value:
+if self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
@@ -365,10 +359,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
-if is_multiprocess:
-    self.storage_updated.value = False
-else:
-    self.storage_updated = False
+self.storage_updated.value = False
except Exception as e:
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
return False # Return error
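Condensed, the guarded accessor that all three storages now share looks like the sketch below (method and attribute names follow the hunks above; an outline, not the verbatim source):

    async def _get_index(self):
        # Serialize readers and writers within this process
        async with self._storage_lock:
            # Another process saved new data: rebuild before serving reads
            if self.storage_updated.value:
                self._index = faiss.IndexFlatIP(self._dim)
                self._id_to_meta = {}
                self._load_faiss_index()
                self.storage_updated.value = False  # reset only our own flag
            return self._index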

View File

@@ -20,7 +20,6 @@ from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
-is_multiprocess,
)
@@ -57,16 +56,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
-self._storage_lock = get_storage_lock()
+self._storage_lock = get_storage_lock(enable_logging=False)
async def _get_client(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
-if (is_multiprocess and self.storage_updated.value) or (
-    not is_multiprocess and self.storage_updated
-):
+if self.storage_updated.value:
logger.info(
f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
)
@@ -76,10 +73,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
storage_file=self._client_file_name,
)
# Reset update flag
-if is_multiprocess:
-    self.storage_updated.value = False
-else:
-    self.storage_updated = False
+self.storage_updated.value = False
return self._client
@@ -208,7 +202,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
"""Save data to disk"""
async with self._storage_lock:
# Check if storage was updated by another process
-if is_multiprocess and self.storage_updated.value:
+if self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for {self.namespace} was updated by another process, reloading..."
@@ -229,10 +223,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
-if is_multiprocess:
-    self.storage_updated.value = False
-else:
-    self.storage_updated = False
+self.storage_updated.value = False
return True # Return success
except Exception as e:
logger.error(f"Error saving data for {self.namespace}: {e}")
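The save path is the mirror image: whoever persists must first check that nobody else already did. A sketch of the flow, with the reload step abstracted into a hypothetical _reload_from_disk() helper and error handling elided:

    async def index_done_callback(self) -> bool:
        async with self._storage_lock:
            if self.storage_updated.value:
                # Lost the race: another process saved first, so drop our
                # in-memory view and reload instead of overwriting its data.
                self._reload_from_disk()            # hypothetical helper
                self.storage_updated.value = False
                return False
            self._client.save()                         # persist our state
            await set_all_update_flags(self.namespace)  # flag other processes
            self.storage_updated.value = False          # no self-reload
            return True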

View File

@@ -21,7 +21,6 @@ from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
-is_multiprocess,
)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@@ -110,9 +109,7 @@ class NetworkXStorage(BaseGraphStorage):
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
-if (is_multiprocess and self.storage_updated.value) or (
-    not is_multiprocess and self.storage_updated
-):
+if self.storage_updated.value:
logger.info(
f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
)
@@ -121,10 +118,7 @@ class NetworkXStorage(BaseGraphStorage):
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
-if is_multiprocess:
-    self.storage_updated.value = False
-else:
-    self.storage_updated = False
+self.storage_updated.value = False
return self._graph
@@ -403,7 +397,7 @@ class NetworkXStorage(BaseGraphStorage):
"""Save data to disk"""
async with self._storage_lock:
# Check if storage was updated by another process
-if is_multiprocess and self.storage_updated.value:
+if self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Graph for {self.namespace} was updated by another process, reloading..."
@@ -423,10 +417,7 @@ class NetworkXStorage(BaseGraphStorage):
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
-if is_multiprocess:
-    self.storage_updated.value = False
-else:
-    self.storage_updated = False
+self.storage_updated.value = False
return True # Return success
except Exception as e:
logger.error(f"Error saving graph for {self.namespace}: {e}")
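Taken together, the protocol is: a writer persists, calls set_all_update_flags(namespace) to flip every process's flag, then clears its own; readers reload lazily the next time they see their flag set. A minimal single-process demonstration of the flag mechanics (assuming the module path lightrag.kg.shared_storage):

    import asyncio

    from lightrag.kg.shared_storage import (
        initialize_share_data,
        get_update_flag,
        set_all_update_flags,
    )

    async def main():
        initialize_share_data(workers=1)        # single-process mode
        mine = await get_update_flag("demo")    # this consumer's flag
        theirs = await get_update_flag("demo")  # another consumer's flag
        await set_all_update_flags("demo")      # a writer notifies everyone
        assert mine.value and theirs.value      # both expose .value uniformly
        mine.value = False                      # reset only our own flag
        assert theirs.value                     # the other flag stays set

    asyncio.run(main())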

View File

@@ -24,7 +24,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
T = TypeVar("T")
LockType = Union[ProcessLock, asyncio.Lock]
-is_multiprocess = None
+_is_multiprocess = None
_workers = None
_manager = None
_initialized = None
@@ -218,10 +218,10 @@ class UnifiedLock(Generic[T]):
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
-async_lock = _async_locks.get("internal_lock") if is_multiprocess else None
+async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_internal_lock,
-is_async=not is_multiprocess,
+is_async=not _is_multiprocess,
name="internal_lock",
enable_logging=enable_logging,
async_lock=async_lock,
@@ -230,10 +230,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
-async_lock = _async_locks.get("storage_lock") if is_multiprocess else None
+async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_storage_lock,
-is_async=not is_multiprocess,
+is_async=not _is_multiprocess,
name="storage_lock",
enable_logging=enable_logging,
async_lock=async_lock,
@@ -242,10 +242,10 @@ def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
-async_lock = _async_locks.get("pipeline_status_lock") if is_multiprocess else None
+async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_pipeline_status_lock,
-is_async=not is_multiprocess,
+is_async=not _is_multiprocess,
name="pipeline_status_lock",
enable_logging=enable_logging,
async_lock=async_lock,
@@ -254,10 +254,10 @@ def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified graph database lock for ensuring atomic operations"""
-async_lock = _async_locks.get("graph_db_lock") if is_multiprocess else None
+async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_graph_db_lock,
-is_async=not is_multiprocess,
+is_async=not _is_multiprocess,
name="graph_db_lock",
enable_logging=enable_logging,
async_lock=async_lock,
@@ -266,10 +266,10 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
-async_lock = _async_locks.get("data_init_lock") if is_multiprocess else None
+async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_data_init_lock,
-is_async=not is_multiprocess,
+is_async=not _is_multiprocess,
name="data_init_lock",
enable_logging=enable_logging,
async_lock=async_lock,
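All five getters follow the same recipe: in multiprocess mode, pair the shared Manager lock with a process-local asyncio lock; otherwise use a bare asyncio lock. Callers are mode-agnostic either way:

    async def demo():
        lock = get_storage_lock(enable_logging=False)
        async with lock:  # same call shape in both modes
            ...           # read or modify shared state here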
@@ -297,7 +297,7 @@ def initialize_share_data(workers: int = 1):
global \
_manager, \
_workers, \
-is_multiprocess, \
+_is_multiprocess, \
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
@@ -312,14 +312,14 @@ def initialize_share_data(workers: int = 1):
# Check if already initialized
if _initialized:
direct_log(
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={_is_multiprocess})"
)
return
_workers = workers
if workers > 1:
-is_multiprocess = True
+_is_multiprocess = True
_manager = Manager()
_internal_lock = _manager.Lock()
_storage_lock = _manager.Lock()
@@ -343,7 +343,7 @@ def initialize_share_data(workers: int = 1):
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
)
else:
-is_multiprocess = False
+_is_multiprocess = False
_internal_lock = asyncio.Lock()
_storage_lock = asyncio.Lock()
_pipeline_status_lock = asyncio.Lock()
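Condensed, initialize_share_data() chooses between two families of primitives (only one lock shown here; the other four follow suit):

    if workers > 1:
        _is_multiprocess = True
        _manager = Manager()             # proxies shareable across processes
        _storage_lock = _manager.Lock()
    else:
        _is_multiprocess = False
        _storage_lock = asyncio.Lock()   # plain asyncio primitives suffice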
@@ -372,7 +372,7 @@ async def initialize_pipeline_status():
return
# Create a shared list object for history_messages
-history_messages = _manager.list() if is_multiprocess else []
+history_messages = _manager.list() if _is_multiprocess else []
pipeline_namespace.update(
{
"autoscanned": False, # Auto-scan started
@@ -401,7 +401,7 @@ async def get_update_flag(namespace: str):
async with get_internal_lock():
if namespace not in _update_flags:
-if is_multiprocess and _manager is not None:
+if _is_multiprocess and _manager is not None:
_update_flags[namespace] = _manager.list()
else:
_update_flags[namespace] = []
@@ -409,7 +409,7 @@ async def get_update_flag(namespace: str):
f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
)
-if is_multiprocess and _manager is not None:
+if _is_multiprocess and _manager is not None:
new_update_flag = _manager.Value("b", False)
else:
# Create a simple mutable object to store boolean value for compatibility with multiprocess
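In single-process mode there is no Manager, so a plain object with a .value attribute stands in for Manager().Value("b", False); this is what keeps the flag interface uniform. A sketch of the idea (the class name is illustrative, not taken from the source):

    class _MutableFlag:
        """Single-process stand-in for Manager().Value("b", False)."""

        def __init__(self, value: bool = False):
            self.value = value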
@@ -434,10 +434,6 @@ async def set_all_update_flags(namespace: str):
raise ValueError(f"Namespace {namespace} not found in update flags")
# Update flags for both modes
for i in range(len(_update_flags[namespace])):
-if is_multiprocess:
-    _update_flags[namespace][i].value = True
-else:
-    # Use .value attribute instead of direct assignment
-    _update_flags[namespace][i].value = True
+_update_flags[namespace][i].value = True
@@ -452,10 +448,6 @@ async def clear_all_update_flags(namespace: str):
raise ValueError(f"Namespace {namespace} not found in update flags")
# Update flags for both modes
for i in range(len(_update_flags[namespace])):
-if is_multiprocess:
-    _update_flags[namespace][i].value = False
-else:
-    # Use .value attribute instead of direct assignment
-    _update_flags[namespace][i].value = False
+_update_flags[namespace][i].value = False
@@ -474,7 +466,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
for namespace, flags in _update_flags.items():
worker_statuses = []
for flag in flags:
-if is_multiprocess:
+if _is_multiprocess:
worker_statuses.append(flag.value)
else:
worker_statuses.append(flag)
@@ -518,7 +510,7 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
async with get_internal_lock():
if namespace not in _shared_dicts:
-if is_multiprocess and _manager is not None:
+if _is_multiprocess and _manager is not None:
_shared_dicts[namespace] = _manager.dict()
else:
_shared_dicts[namespace] = {}
@@ -538,7 +530,7 @@ def finalize_share_data():
"""
global \
_manager, \
-is_multiprocess, \
+_is_multiprocess, \
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
@@ -558,11 +550,11 @@ def finalize_share_data():
return
direct_log(
f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
f"Process {os.getpid()} finalizing storage data (multiprocess={_is_multiprocess})"
)
# In multi-process mode, shut down the Manager
-if is_multiprocess and _manager is not None:
+if _is_multiprocess and _manager is not None:
try:
# Clear shared resources before shutting down Manager
if _shared_dicts is not None:
@@ -604,7 +596,7 @@ def finalize_share_data():
# Reset global variables
_manager = None
_initialized = None
-is_multiprocess = None
+_is_multiprocess = None
_shared_dicts = None
_init_flags = None
_storage_lock = None