Add multi-process support for vector database and graph storage with lock flags
• Implement storage lock mechanism
• Add update flag handling
• Add cross-process reload detection
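The diff below applies the same pattern to both NanoVectorDBStorage and NetworkXStorage: a per-process async lock around reads and writes, a shared "updated" flag raised after a process persists its data, and a lazy reload in the accessor when that flag is set. The standalone sketch here condenses that pattern for orientation; it uses asyncio.Lock and multiprocessing.Value directly instead of the repository's shared_storage helpers (get_storage_lock, get_update_flag, set_all_update_flags, is_multiprocess), so the class and method names below are illustrative only, not the project's API.

    # Minimal sketch of the reload-on-update pattern this commit introduces.
    # All names are hypothetical; the real code lives in the storage classes below.
    import asyncio
    from multiprocessing import Value


    class ReloadAwareStorage:
        def __init__(self) -> None:
            self._lock = asyncio.Lock()        # serializes reads/writes in this process
            self._updated = Value("b", False)  # stand-in for the cross-process update flag
            self._data = None                  # stand-in for NanoVectorDB / nx.Graph

        def _load_from_disk(self) -> dict:
            # Placeholder for NanoVectorDB(...) or load_nx_graph(...)
            return {}

        async def _get_data(self) -> dict:
            # Mirror of _get_client()/_get_graph(): reload lazily if the flag is set
            async with self._lock:
                if self._data is None or self._updated.value:
                    self._data = self._load_from_disk()
                    self._updated.value = False
                return self._data

        async def index_done_callback(self) -> bool:
            # If another process saved first, reload instead of overwriting its data
            if self._updated.value:
                self._data = self._load_from_disk()
                self._updated.value = False
                return False
            async with self._lock:
                # Persist here; in the real code set_all_update_flags(namespace) then
                # tells sibling processes to reload on their next access.
                return True


    if __name__ == "__main__":
        storage = ReloadAwareStorage()
        print(asyncio.run(storage._get_data()))

In the diff itself, _get_client() and _get_graph() become async precisely so this check-and-reload can run under the storage lock on every access, and index_done_callback() reloads instead of saving when it detects an update from another process.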
@@ -16,7 +16,12 @@ if not pm.is_installed("nano-vectordb"):
    pm.install("nano-vectordb")

from nano_vectordb import NanoVectorDB
from .shared_storage import get_storage_lock
from .shared_storage import (
    get_storage_lock,
    get_update_flag,
    set_all_update_flags,
    is_multiprocess,
)


@final
@@ -24,8 +29,9 @@ from .shared_storage import get_storage_lock
class NanoVectorDBStorage(BaseVectorStorage):
    def __post_init__(self):
        # Initialize basic attributes
        self._storage_lock = get_storage_lock()
        self._client = None
        self._storage_lock = None
        self.storage_updated = None

        # Use global config value if specified, otherwise use default
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -41,16 +47,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]

    async def initialize(self):
        """Initialize storage data"""
        async with self._storage_lock:
            self._client = NanoVectorDB(
                self.embedding_func.embedding_dim,
                storage_file=self._client_file_name,
            )

    def _get_client(self):
        """Check if the shtorage should be reloaded"""
    async def initialize(self):
        """Initialize storage data"""
        # Get the update flag for cross-process update notification
        self.storage_updated = await get_update_flag(self.namespace)
        # Get the storage lock for use in other methods
        self._storage_lock = get_storage_lock()

    async def _get_client(self):
        """Check if the storage should be reloaded"""
        # Acquire lock to prevent concurrent read and write
        async with self._storage_lock:
            # Check if data needs to be reloaded
            if (is_multiprocess and self.storage_updated.value) or \
                (not is_multiprocess and self.storage_updated):
                logger.info(f"Reloading storage for {self.namespace} due to update by another process")
                # Reload data
                self._client = NanoVectorDB(
                    self.embedding_func.embedding_dim,
                    storage_file=self._client_file_name,
                )
                # Reset update flag
                if is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False

        return self._client

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -81,7 +108,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
        if len(embeddings) == len(list_data):
            for i, d in enumerate(list_data):
                d["__vector__"] = embeddings[i]
            results = self._get_client().upsert(datas=list_data)
            client = await self._get_client()
            results = client.upsert(datas=list_data)
            return results
        else:
            # sometimes the embedding is not returned correctly. just log it.
@@ -94,7 +122,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
        embedding = await self.embedding_func([query])
        embedding = embedding[0]

        results = self._get_client().query(
        client = await self._get_client()
        results = client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
@@ -111,8 +140,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
        return results

    @property
    def client_storage(self):
        return getattr(self._get_client(), "_NanoVectorDB__storage")
    async def client_storage(self):
        client = await self._get_client()
        return getattr(client, "_NanoVectorDB__storage")

    async def delete(self, ids: list[str]):
        """Delete vectors with specified IDs
@@ -121,7 +151,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
            ids: List of vector IDs to be deleted
        """
        try:
            self._get_client().delete(ids)
            client = await self._get_client()
            client.delete(ids)
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
@@ -136,8 +167,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
            )

            # Check if the entity exists
            if self._get_client().get([entity_id]):
                self._get_client().delete([entity_id])
            client = await self._get_client()
            if client.get([entity_id]):
                client.delete([entity_id])
                logger.debug(f"Successfully deleted entity {entity_name}")
            else:
                logger.debug(f"Entity {entity_name} not found in storage")
@@ -146,7 +178,8 @@ class NanoVectorDBStorage(BaseVectorStorage):

    async def delete_entity_relation(self, entity_name: str) -> None:
        try:
            storage = getattr(self._get_client(), "_NanoVectorDB__storage")
            client = await self._get_client()
            storage = getattr(client, "_NanoVectorDB__storage")
            relations = [
                dp
                for dp in storage["data"]
@@ -156,7 +189,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
            ids_to_delete = [relation["__id__"] for relation in relations]

            if ids_to_delete:
                self._get_client().delete(ids_to_delete)
                client = await self._get_client()
                client.delete(ids_to_delete)
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
@@ -166,5 +200,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def index_done_callback(self) -> None:
        # Check if storage was updated by another process
        if is_multiprocess and self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...")
            self._client = NanoVectorDB(
                self.embedding_func.embedding_dim,
                storage_file=self._client_file_name,
            )
            # Reset update flag
            self.storage_updated.value = False
            return False  # Return error

        # Acquire lock and perform persistence
        client = await self._get_client()
        async with self._storage_lock:
            self._get_client().save()
            try:
                # Save data to disk
                client.save()
                # Notify other processes that data has been updated
                await set_all_update_flags(self.namespace)
                # Reset own update flag to avoid self-notification
                if is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False
                return True  # Return success
            except Exception as e:
                logger.error(f"Error saving data for {self.namespace}: {e}")
                return False  # Return error

@@ -17,7 +17,12 @@ if not pm.is_installed("graspologic"):

import networkx as nx
from graspologic import embed
from threading import Lock as ThreadLock
from .shared_storage import (
    get_storage_lock,
    get_update_flag,
    set_all_update_flags,
    is_multiprocess,
)


@final
@@ -73,9 +78,11 @@ class NetworkXStorage(BaseGraphStorage):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        self._storage_lock = ThreadLock()
        self._storage_lock = None
        self.storage_updated = None
        self._graph = None

        with self._storage_lock:
            # Load initial graph
            preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
            if preloaded_graph is not None:
                logger.info(
@@ -84,54 +91,83 @@ class NetworkXStorage(BaseGraphStorage):
            else:
                logger.info("Created new empty graph")
            self._graph = preloaded_graph or nx.Graph()

        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    def _get_graph(self):
        """Check if the shtorage should be reloaded"""
    async def initialize(self):
        """Initialize storage data"""
        # Get the update flag for cross-process update notification
        self.storage_updated = await get_update_flag(self.namespace)
        # Get the storage lock for use in other methods
        self._storage_lock = get_storage_lock()

    async def _get_graph(self):
        """Check if the storage should be reloaded"""
        # Acquire lock to prevent concurrent read and write
        async with self._storage_lock:
            # Check if data needs to be reloaded
            if (is_multiprocess and self.storage_updated.value) or \
                (not is_multiprocess and self.storage_updated):
                logger.info(f"Reloading graph for {self.namespace} due to update by another process")
                # Reload data
                self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
                # Reset update flag
                if is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False

        return self._graph

    async def index_done_callback(self) -> None:
        with self._storage_lock:
            NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._get_graph().has_node(node_id)
        graph = await self._get_graph()
        return graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._get_graph().has_edge(source_node_id, target_node_id)
        graph = await self._get_graph()
        return graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        return self._get_graph().nodes.get(node_id)
        graph = await self._get_graph()
        return graph.nodes.get(node_id)

    async def node_degree(self, node_id: str) -> int:
        return self._get_graph().degree(node_id)
        graph = await self._get_graph()
        return graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
        graph = await self._get_graph()
        return graph.degree(src_id) + graph.degree(tgt_id)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        return self._get_graph().edges.get((source_node_id, target_node_id))
        graph = await self._get_graph()
        return graph.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        if self._get_graph().has_node(source_node_id):
            return list(self._get_graph().edges(source_node_id))
        graph = await self._get_graph()
        if graph.has_node(source_node_id):
            return list(graph.edges(source_node_id))
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        self._get_graph().add_node(node_id, **node_data)
        graph = await self._get_graph()
        graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
        graph = await self._get_graph()
        graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def delete_node(self, node_id: str) -> None:
        if self._get_graph().has_node(node_id):
            self._get_graph().remove_node(node_id)
        graph = await self._get_graph()
        if graph.has_node(node_id):
            graph.remove_node(node_id)
            logger.debug(f"Node {node_id} deleted from the graph.")
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")
@@ -145,7 +181,7 @@ class NetworkXStorage(BaseGraphStorage):

    # TODO: NOT USED
    async def _node2vec_embed(self):
        graph = self._get_graph()
        graph = await self._get_graph()
        embeddings, nodes = embed.node2vec_embed(
            graph,
            **self.global_config["node2vec_params"],
@@ -153,24 +189,24 @@ class NetworkXStorage(BaseGraphStorage):
        nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

    def remove_nodes(self, nodes: list[str]):
    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node IDs to be deleted
        """
        graph = self._get_graph()
        graph = await self._get_graph()
        for node in nodes:
            if graph.has_node(node):
                graph.remove_node(node)

    def remove_edges(self, edges: list[tuple[str, str]]):
    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
        graph = self._get_graph()
        graph = await self._get_graph()
        for source, target in edges:
            if graph.has_edge(source, target):
                graph.remove_edge(source, target)
@@ -181,8 +217,9 @@ class NetworkXStorage(BaseGraphStorage):
        Returns:
            [label1, label2, ...]  # Alphabetically sorted label list
        """
        graph = await self._get_graph()
        labels = set()
        for node in self._get_graph().nodes():
        for node in graph.nodes():
            labels.add(str(node))  # Add node id as a label

        # Return sorted list
@@ -205,7 +242,7 @@ class NetworkXStorage(BaseGraphStorage):
        seen_nodes = set()
        seen_edges = set()

        graph = self._get_graph()
        graph = await self._get_graph()

        # Handle special case for "*" label
        if node_label == "*":
@@ -291,3 +328,31 @@ class NetworkXStorage(BaseGraphStorage):
            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )
        return result

    async def index_done_callback(self) -> None:
        # Check if storage was updated by another process
        if is_multiprocess and self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...")
            self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
            # Reset update flag
            self.storage_updated.value = False
            return False  # Return error

        # Acquire lock and perform persistence
        graph = await self._get_graph()
        async with self._storage_lock:
            try:
                # Save data to disk
                NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
                # Notify other processes that data has been updated
                await set_all_update_flags(self.namespace)
                # Reset own update flag to avoid self-notification
                if is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False
                return True  # Return success
            except Exception as e:
                logger.error(f"Error saving graph for {self.namespace}: {e}")
                return False  # Return error