Add multi-process support for vector database and graph storage with lock flags

• Implement storage lock mechanism
• Add update flag handling
• Add cross-process reload detection
This commit is contained in:
yangdx
2025-03-01 10:37:05 +08:00
parent d704512139
commit d3de57c1e4
2 changed files with 175 additions and 49 deletions

View File

@@ -17,7 +17,12 @@ if not pm.is_installed("graspologic"):
import networkx as nx
from graspologic import embed
from threading import Lock as ThreadLock
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final
@@ -73,10 +78,12 @@ class NetworkXStorage(BaseGraphStorage):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
self._storage_lock = ThreadLock()
self._storage_lock = None
self.storage_updated = None
self._graph = None
with self._storage_lock:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
# Load initial graph
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
@@ -84,54 +91,83 @@ class NetworkXStorage(BaseGraphStorage):
else:
logger.info("Created new empty graph")
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
def _get_graph(self):
"""Check if the shtorage should be reloaded"""
return self._graph
async def initialize(self):
    """Set up the cross-process coordination primitives for this storage.

    Must be awaited once after construction and before any graph access:
    the lock and update flag it installs are required by _get_graph().
    """
    # Shared lock serializing reads/writes of the in-memory graph
    self._storage_lock = get_storage_lock()
    # Flag that other processes flip when they persist a newer graph
    self.storage_updated = await get_update_flag(self.namespace)
async def _get_graph(self):
    """Return the in-memory graph, reloading it first when another process
    has persisted a newer copy (signalled via the shared update flag)."""
    # Serialize against concurrent readers/writers of self._graph
    async with self._storage_lock:
        # Normalize the flag: in multiprocess mode it is a shared Value,
        # otherwise a plain boolean.
        if is_multiprocess:
            stale = self.storage_updated.value
        else:
            stale = self.storage_updated
        if stale:
            logger.info(f"Reloading graph for {self.namespace} due to update by another process")
            # Pick up the copy the other process wrote to disk
            self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
            # Our view is fresh again — clear the notification
            if is_multiprocess:
                self.storage_updated.value = False
            else:
                self.storage_updated = False
        return self._graph
async def index_done_callback(self) -> None:
with self._storage_lock:
NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
return self._get_graph().has_node(node_id)
graph = await self._get_graph()
return graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._get_graph().has_edge(source_node_id, target_node_id)
graph = await self._get_graph()
return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._get_graph().nodes.get(node_id)
graph = await self._get_graph()
return graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
return self._get_graph().degree(node_id)
graph = await self._get_graph()
return graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
graph = await self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
return self._get_graph().edges.get((source_node_id, target_node_id))
graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._get_graph().has_node(source_node_id):
return list(self._get_graph().edges(source_node_id))
graph = await self._get_graph()
if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._get_graph().add_node(node_id, **node_data)
graph = await self._get_graph()
graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
graph = await self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None:
if self._get_graph().has_node(node_id):
self._get_graph().remove_node(node_id)
graph = await self._get_graph()
if graph.has_node(node_id):
graph.remove_node(node_id)
logger.debug(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
@@ -145,7 +181,7 @@ class NetworkXStorage(BaseGraphStorage):
# TODO: NOT USED
async def _node2vec_embed(self):
graph = self._get_graph()
graph = await self._get_graph()
embeddings, nodes = embed.node2vec_embed(
graph,
**self.global_config["node2vec_params"],
@@ -153,24 +189,24 @@ class NetworkXStorage(BaseGraphStorage):
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
def remove_nodes(self, nodes: list[str]):
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to be deleted
"""
graph = self._get_graph()
graph = await self._get_graph()
for node in nodes:
if graph.has_node(node):
graph.remove_node(node)
def remove_edges(self, edges: list[tuple[str, str]]):
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
graph = self._get_graph()
graph = await self._get_graph()
for source, target in edges:
if graph.has_edge(source, target):
graph.remove_edge(source, target)
@@ -181,8 +217,9 @@ class NetworkXStorage(BaseGraphStorage):
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
graph = await self._get_graph()
labels = set()
for node in self._get_graph().nodes():
for node in graph.nodes():
labels.add(str(node)) # Add node id as a label
# Return sorted list
@@ -205,7 +242,7 @@ class NetworkXStorage(BaseGraphStorage):
seen_nodes = set()
seen_edges = set()
graph = self._get_graph()
graph = await self._get_graph()
# Handle special case for "*" label
if node_label == "*":
@@ -291,3 +328,31 @@ class NetworkXStorage(BaseGraphStorage):
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def index_done_callback(self) -> bool:
    """Persist the in-memory graph to disk and notify other processes.

    If another process has already written a newer copy of the graph,
    the stale in-memory data is discarded and reloaded instead of being
    saved, so the newer on-disk copy is not clobbered.

    Returns:
        True if the graph was persisted, False if persistence was
        skipped because the data was stale or failed with an error.
    """
    # Perform the staleness check and any reload under the lock so a
    # concurrent _get_graph() cannot observe a half-reloaded graph.
    # (Previously this ran unlocked and the annotation said -> None
    # despite every path returning a bool.)
    async with self._storage_lock:
        if is_multiprocess and self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...")
            self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
            # Reset update flag
            self.storage_updated.value = False
            return False  # Return error

    # _get_graph() re-checks the flag itself, covering any update that
    # lands between the check above and this call.
    graph = await self._get_graph()
    # Acquire lock and perform persistence
    async with self._storage_lock:
        try:
            # Save data to disk
            NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
            # Notify other processes that data has been updated
            await set_all_update_flags(self.namespace)
            # Reset own update flag to avoid self-notification
            if is_multiprocess:
                self.storage_updated.value = False
            else:
                self.storage_updated = False
            return True  # Return success
        except Exception as e:
            logger.error(f"Error saving graph for {self.namespace}: {e}")
            return False  # Return error