Add multi-process support for vector database and graph storage with lock flags

• Implement storage lock mechanism
• Add update flag handling
• Add cross-process reload detection (the assumed shared_storage helpers are sketched below)
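
The whole mechanism hangs on four names imported from .shared_storage (get_storage_lock, get_update_flag, set_all_update_flags, is_multiprocess), and that module is not part of this diff. Below is a minimal sketch of what the call sites imply — every name and internal detail is an assumption, not the actual shared_storage implementation:

import asyncio
from multiprocessing import Manager

is_multiprocess = False   # assumed: flipped on when worker processes start
_manager = None           # assumed: a Manager() instance in multiprocess mode
_update_flags = {}        # namespace -> update flags of registered instances
_lock = asyncio.Lock()    # stand-in; the real lock must span processes

def get_storage_lock():
    """Return the async context manager guarding storage reads and writes."""
    return _lock

async def get_update_flag(namespace: str):
    """Register the calling instance and return its personal update flag.
    In multiprocess mode the flag is a shared Value (checked via .value);
    in single-process mode it is a plain bool held by the caller."""
    flags = _update_flags.setdefault(namespace, [])
    flag = _manager.Value("b", False) if is_multiprocess else False
    flags.append(flag)
    return flag

async def set_all_update_flags(namespace: str):
    """Mark every registered instance of a namespace as stale after a save."""
    for flag in _update_flags.get(namespace, []):
        if is_multiprocess:
            flag.value = True
        # single-process: nothing to do — no other process to notify

Each storage instance resets only its own flag after saving, so one save notifies every peer without forcing the writer to reload its own fresh data.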
Author: yangdx
Date:   2025-03-01 10:37:05 +08:00
Parent: d704512139
Commit: d3de57c1e4
2 changed files with 175 additions and 49 deletions

View File

@@ -16,7 +16,12 @@ if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
from .shared_storage import get_storage_lock
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final
@@ -24,8 +29,9 @@ from .shared_storage import get_storage_lock
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
# Initialize basic attributes
self._storage_lock = get_storage_lock()
self._client = None
self._storage_lock = None
self.storage_updated = None
# Use global config value if specified, otherwise use default
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -41,17 +47,38 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
async def initialize(self):
"""Initialize storage data"""
async with self._storage_lock:
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
def _get_client(self):
"""Check if the shtorage should be reloaded"""
return self._client
async def _get_client(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
if (is_multiprocess and self.storage_updated.value) or \
(not is_multiprocess and self.storage_updated):
logger.info(f"Reloading storage for {self.namespace} due to update by another process")
# Reload data
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
@@ -81,7 +108,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._get_client().upsert(datas=list_data)
client = await self._get_client()
results = client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
@@ -94,7 +122,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._get_client().query(
client = await self._get_client()
results = client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
@@ -111,8 +140,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
return results
@property
def client_storage(self):
return getattr(self._get_client(), "_NanoVectorDB__storage")
async def client_storage(self):
client = await self._get_client()
return getattr(client, "_NanoVectorDB__storage")
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
@@ -121,7 +151,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted
"""
try:
self._get_client().delete(ids)
client = await self._get_client()
client.delete(ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
@@ -136,8 +167,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
# Check if the entity exists
if self._get_client().get([entity_id]):
self._get_client().delete([entity_id])
client = await self._get_client()
if client.get([entity_id]):
client.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
@@ -146,7 +178,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None:
try:
storage = getattr(self._get_client(), "_NanoVectorDB__storage")
client = await self._get_client()
storage = getattr(client, "_NanoVectorDB__storage")
relations = [
dp
for dp in storage["data"]
@@ -156,7 +189,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
self._get_client().delete(ids_to_delete)
client = await self._get_client()
client.delete(ids_to_delete)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
@@ -166,5 +200,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self) -> bool:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...")
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
client = await self._get_client()
async with self._storage_lock:
self._get_client().save()
try:
# Save data to disk
client.save()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-notification
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return True # Return success
except Exception as e:
logger.error(f"Error saving data for {self.namespace}: {e}")
return False # Return error
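
That closes out the vector-storage changes. The practical consequence for callers is an ordering requirement: initialize() must complete before anything that goes through _get_client(). A hypothetical usage sketch — the constructor arguments come from BaseVectorStorage, which is not shown in this diff, so treat them as assumptions:

# Hypothetical usage; the call order is the point.
storage = NanoVectorDBStorage(
    namespace="chunks",                # assumed value
    global_config=config,              # assumed dict, see __post_init__
    embedding_func=embed_fn,
)
await storage.initialize()             # wires up lock and update flag first
await storage.upsert({"doc-1": {"content": "hello world"}})
await storage.index_done_callback()    # saves, then flags peer processes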

View File

@@ -17,7 +17,12 @@ if not pm.is_installed("graspologic"):
import networkx as nx
from graspologic import embed
from threading import Lock as ThreadLock
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final
@@ -73,10 +78,12 @@ class NetworkXStorage(BaseGraphStorage):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
self._storage_lock = ThreadLock()
self._storage_lock = None
self.storage_updated = None
self._graph = None
with self._storage_lock:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
# Load initial graph
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
@@ -84,54 +91,83 @@ class NetworkXStorage(BaseGraphStorage):
else:
logger.info("Created new empty graph")
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
def _get_graph(self):
"""Check if the shtorage should be reloaded"""
return self._graph
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_graph(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
if (is_multiprocess and self.storage_updated.value) or \
(not is_multiprocess and self.storage_updated):
logger.info(f"Reloading graph for {self.namespace} due to update by another process")
# Reload data
self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
# Reset update flag
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._graph
async def index_done_callback(self) -> None:
with self._storage_lock:
NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
return self._get_graph().has_node(node_id)
graph = await self._get_graph()
return graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._get_graph().has_edge(source_node_id, target_node_id)
graph = await self._get_graph()
return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._get_graph().nodes.get(node_id)
graph = await self._get_graph()
return graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
return self._get_graph().degree(node_id)
graph = await self._get_graph()
return graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
graph = await self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
return self._get_graph().edges.get((source_node_id, target_node_id))
graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._get_graph().has_node(source_node_id):
return list(self._get_graph().edges(source_node_id))
graph = await self._get_graph()
if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._get_graph().add_node(node_id, **node_data)
graph = await self._get_graph()
graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
graph = await self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None:
if self._get_graph().has_node(node_id):
self._get_graph().remove_node(node_id)
graph = await self._get_graph()
if graph.has_node(node_id):
graph.remove_node(node_id)
logger.debug(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
@@ -145,7 +181,7 @@ class NetworkXStorage(BaseGraphStorage):
# TODO: NOT USED
async def _node2vec_embed(self):
graph = self._get_graph()
graph = await self._get_graph()
embeddings, nodes = embed.node2vec_embed(
graph,
**self.global_config["node2vec_params"],
@@ -153,24 +189,24 @@ class NetworkXStorage(BaseGraphStorage):
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
def remove_nodes(self, nodes: list[str]):
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to be deleted
"""
graph = self._get_graph()
graph = await self._get_graph()
for node in nodes:
if graph.has_node(node):
graph.remove_node(node)
def remove_edges(self, edges: list[tuple[str, str]]):
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
graph = self._get_graph()
graph = await self._get_graph()
for source, target in edges:
if graph.has_edge(source, target):
graph.remove_edge(source, target)
@@ -181,8 +217,9 @@ class NetworkXStorage(BaseGraphStorage):
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
graph = await self._get_graph()
labels = set()
for node in self._get_graph().nodes():
for node in graph.nodes():
labels.add(str(node)) # Add node id as a label
# Return sorted list
@@ -205,7 +242,7 @@ class NetworkXStorage(BaseGraphStorage):
seen_nodes = set()
seen_edges = set()
graph = self._get_graph()
graph = await self._get_graph()
# Handle special case for "*" label
if node_label == "*":
@@ -291,3 +328,31 @@ class NetworkXStorage(BaseGraphStorage):
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def index_done_callback(self) -> bool:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...")
self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
graph = await self._get_graph()
async with self._storage_lock:
try:
# Save data to disk
NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-notification
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return True # Return success
except Exception as e:
logger.error(f"Error saving graph for {self.namespace}: {e}")
return False # Return error
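
Both files now follow the same save-or-reload protocol. A hypothetical timeline for two processes sharing one graph namespace (the method names are from this diff; the surrounding setup is assumed):

# Process A, the writer:
await graph_a.upsert_node("n1", {"kind": "entity"})
await graph_a.index_done_callback()   # writes GraphML, sets B's flag,
                                      # then clears A's own flag so A
                                      # does not reload its own write

# Process B, some time later:
g = await graph_b._get_graph()        # sees storage_updated.value is True,
                                      # reloads the GraphML, clears the flag
assert g.has_node("n1")               # B observes A's write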