diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index 86381379..e0047a21 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -16,7 +16,12 @@
 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")
 from nano_vectordb import NanoVectorDB
-from .shared_storage import get_storage_lock
+from .shared_storage import (
+    get_storage_lock,
+    get_update_flag,
+    set_all_update_flags,
+    is_multiprocess,
+)
 
 
 @final
@@ -24,8 +29,9 @@ from .shared_storage import get_storage_lock
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize basic attributes
-        self._storage_lock = get_storage_lock()
         self._client = None
+        self._storage_lock = None
+        self.storage_updated = None
 
         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -41,17 +47,38 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
+        self._client = NanoVectorDB(
+            self.embedding_func.embedding_dim,
+            storage_file=self._client_file_name,
+        )
+
     async def initialize(self):
         """Initialize storage data"""
-        async with self._storage_lock:
-            self._client = NanoVectorDB(
-                self.embedding_func.embedding_dim,
-                storage_file=self._client_file_name,
-            )
+        # Get the update flag for cross-process update notification
+        self.storage_updated = await get_update_flag(self.namespace)
+        # Get the storage lock for use in other methods
+        self._storage_lock = get_storage_lock()
 
-    def _get_client(self):
-        """Check if the shtorage should be reloaded"""
-        return self._client
+    async def _get_client(self):
+        """Reload the storage if another process has updated it, then return the client"""
+        # Acquire lock to prevent concurrent read and write
+        async with self._storage_lock:
+            # Check if data needs to be reloaded
+            if (is_multiprocess and self.storage_updated.value) or \
+               (not is_multiprocess and self.storage_updated):
+                logger.info(f"Reloading storage for {self.namespace} due to update by another process")
+                # Reload data
+                self._client = NanoVectorDB(
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+                # Reset update flag
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+
+        return self._client
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
@@ -81,7 +108,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            results = self._get_client().upsert(datas=list_data)
+            client = await self._get_client()
+            results = client.upsert(datas=list_data)
             return results
         else:
             # sometimes the embedding is not returned correctly. just log it.
@@ -94,7 +122,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
 
-        results = self._get_client().query(
+        client = await self._get_client()
+        results = client.query(
             query=embedding,
             top_k=top_k,
             better_than_threshold=self.cosine_better_than_threshold,
@@ -111,8 +140,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         return results
 
     @property
-    def client_storage(self):
-        return getattr(self._get_client(), "_NanoVectorDB__storage")
+    async def client_storage(self):
+        client = await self._get_client()
+        return getattr(client, "_NanoVectorDB__storage")
 
     async def delete(self, ids: list[str]):
         """Delete vectors with specified IDs
@@ -121,7 +151,8 @@
             ids: List of vector IDs to be deleted
         """
         try:
-            self._get_client().delete(ids)
+            client = await self._get_client()
+            client.delete(ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
             )
@@ -136,8 +167,9 @@
             )
 
             # Check if the entity exists
-            if self._get_client().get([entity_id]):
-                self._get_client().delete([entity_id])
+            client = await self._get_client()
+            if client.get([entity_id]):
+                client.delete([entity_id])
                 logger.debug(f"Successfully deleted entity {entity_name}")
             else:
                 logger.debug(f"Entity {entity_name} not found in storage")
@@ -146,7 +178,8 @@
 
     async def delete_entity_relation(self, entity_name: str) -> None:
         try:
-            storage = getattr(self._get_client(), "_NanoVectorDB__storage")
+            client = await self._get_client()
+            storage = getattr(client, "_NanoVectorDB__storage")
             relations = [
                 dp
                 for dp in storage["data"]
@@ -156,7 +189,8 @@
             ids_to_delete = [relation["__id__"] for relation in relations]
 
             if ids_to_delete:
-                self._get_client().delete(ids_to_delete)
+                client = await self._get_client()
+                client.delete(ids_to_delete)
                 logger.debug(
                     f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                 )
@@ -166,5 +200,32 @@
             logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self) -> None:
+        # Check if storage was updated by another process
+        if is_multiprocess and self.storage_updated.value:
+            # Storage was updated by another process, reload data instead of saving
+            logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...")
+            self._client = NanoVectorDB(
+                self.embedding_func.embedding_dim,
+                storage_file=self._client_file_name,
+            )
+            # Reset update flag
+            self.storage_updated.value = False
+            return False  # Return error
+
+        # Acquire lock and perform persistence
+        client = await self._get_client()
         async with self._storage_lock:
-            self._get_client().save()
+            try:
+                # Save data to disk
+                client.save()
+                # Notify other processes that data has been updated
+                await set_all_update_flags(self.namespace)
+                # Reset own update flag to avoid self-notification
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+                return True  # Return success
+            except Exception as e:
+                logger.error(f"Error saving data for {self.namespace}: {e}")
+                return False  # Return error
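Taken together, the nano_vector_db_impl.py changes replace eager loading under a lock with a reload-on-read pattern: every accessor awaits _get_client(), which checks a cross-process update flag under the storage lock and re-reads the file only when another worker has persisted newer data. Below is a minimal, self-contained sketch of that pattern; the multiprocessing.Value flag and module-level asyncio.Lock are stand-ins for the shared_storage primitives (get_update_flag, get_storage_lock), whose real implementations are not shown in this diff:

    import asyncio
    from multiprocessing import Value

    update_flag = Value("b", False)  # peers set this to True after saving (stand-in)
    storage_lock = asyncio.Lock()    # guards reloads against concurrent reads/saves
    _client = None                   # cached in-memory snapshot


    def load_from_disk():
        # Stand-in for NanoVectorDB(embedding_dim, storage_file=...)
        return {"fresh": True}


    async def get_client():
        """Return the cached snapshot, reloading it if a peer flagged an update."""
        global _client
        async with storage_lock:
            if _client is None or update_flag.value:
                _client = load_from_disk()
                update_flag.value = False  # consume the notification
        return _client

Resetting the flag inside the same critical section as the reload means a notification is only consumed once the file has actually been re-read. The second file in the patch applies the identical pattern to the graph storage: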
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index ccf85855..37db8469 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -17,7 +17,12 @@
 if not pm.is_installed("graspologic"):
 import networkx as nx
 from graspologic import embed
-from threading import Lock as ThreadLock
+from .shared_storage import (
+    get_storage_lock,
+    get_update_flag,
+    set_all_update_flags,
+    is_multiprocess,
+)
 
 
 @final
@@ -73,10 +78,12 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
-        self._storage_lock = ThreadLock()
+        self._storage_lock = None
+        self.storage_updated = None
+        self._graph = None
 
-        with self._storage_lock:
-            preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+        # Load initial graph
+        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
         if preloaded_graph is not None:
             logger.info(
                 f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
@@ -84,54 +91,83 @@ class NetworkXStorage(BaseGraphStorage):
         else:
             logger.info("Created new empty graph")
         self._graph = preloaded_graph or nx.Graph()
+
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }
 
-    def _get_graph(self):
-        """Check if the shtorage should be reloaded"""
-        return self._graph
+    async def initialize(self):
+        """Initialize storage data"""
+        # Get the update flag for cross-process update notification
+        self.storage_updated = await get_update_flag(self.namespace)
+        # Get the storage lock for use in other methods
+        self._storage_lock = get_storage_lock()
+
+    async def _get_graph(self):
+        """Reload the graph if another process has updated it, then return the graph"""
+        # Acquire lock to prevent concurrent read and write
+        async with self._storage_lock:
+            # Check if data needs to be reloaded
+            if (is_multiprocess and self.storage_updated.value) or \
+               (not is_multiprocess and self.storage_updated):
+                logger.info(f"Reloading graph for {self.namespace} due to update by another process")
+                # Reload data
+                self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+                # Reset update flag
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+
+        return self._graph
 
-    async def index_done_callback(self) -> None:
-        with self._storage_lock:
-            NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
 
     async def has_node(self, node_id: str) -> bool:
-        return self._get_graph().has_node(node_id)
+        graph = await self._get_graph()
+        return graph.has_node(node_id)
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        return self._get_graph().has_edge(source_node_id, target_node_id)
+        graph = await self._get_graph()
+        return graph.has_edge(source_node_id, target_node_id)
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        return self._get_graph().nodes.get(node_id)
+        graph = await self._get_graph()
+        return graph.nodes.get(node_id)
 
     async def node_degree(self, node_id: str) -> int:
-        return self._get_graph().degree(node_id)
+        graph = await self._get_graph()
+        return graph.degree(node_id)
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
+        graph = await self._get_graph()
+        return graph.degree(src_id) + graph.degree(tgt_id)
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        return self._get_graph().edges.get((source_node_id, target_node_id))
+        graph = await self._get_graph()
+        return graph.edges.get((source_node_id, target_node_id))
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        if self._get_graph().has_node(source_node_id):
-            return list(self._get_graph().edges(source_node_id))
+        graph = await self._get_graph()
+        if graph.has_node(source_node_id):
+            return list(graph.edges(source_node_id))
         return None
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        self._get_graph().add_node(node_id, **node_data)
+        graph = await self._get_graph()
+        graph.add_node(node_id, **node_data)
 
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
+        graph = await self._get_graph()
+        graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     async def delete_node(self, node_id: str) -> None:
-        if self._get_graph().has_node(node_id):
-            self._get_graph().remove_node(node_id)
+        graph = await self._get_graph()
+        if graph.has_node(node_id):
+            graph.remove_node(node_id)
             logger.debug(f"Node {node_id} deleted from the graph.")
         else:
             logger.warning(f"Node {node_id} not found in the graph for deletion.")
@@ -145,7 +181,7 @@ class NetworkXStorage(BaseGraphStorage):
 
     # TODO: NOT USED
     async def _node2vec_embed(self):
-        graph = self._get_graph()
+        graph = await self._get_graph()
         embeddings, nodes = embed.node2vec_embed(
             graph,
             **self.global_config["node2vec_params"],
@@ -153,24 +189,24 @@ class NetworkXStorage(BaseGraphStorage):
         nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
 
-    def remove_nodes(self, nodes: list[str]):
+    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes
 
         Args:
             nodes: List of node IDs to be deleted
         """
-        graph = self._get_graph()
+        graph = await self._get_graph()
         for node in nodes:
             if graph.has_node(node):
                 graph.remove_node(node)
 
-    def remove_edges(self, edges: list[tuple[str, str]]):
+    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges
 
         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
-        graph = self._get_graph()
+        graph = await self._get_graph()
         for source, target in edges:
             if graph.has_edge(source, target):
                 graph.remove_edge(source, target)
@@ -181,8 +217,9 @@ class NetworkXStorage(BaseGraphStorage):
 
         Returns:
             [label1, label2, ...]  # Alphabetically sorted label list
         """
+        graph = await self._get_graph()
         labels = set()
-        for node in self._get_graph().nodes():
+        for node in graph.nodes():
             labels.add(str(node))  # Add node id as a label
 
         # Return sorted list
@@ -205,7 +242,7 @@ class NetworkXStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()
 
-        graph = self._get_graph()
+        graph = await self._get_graph()
 
         # Handle special case for "*" label
         if node_label == "*":
@@ -291,3 +328,31 @@ class NetworkXStorage(BaseGraphStorage):
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
         )
         return result
+
+    async def index_done_callback(self) -> None:
+        # Check if storage was updated by another process
+        if is_multiprocess and self.storage_updated.value:
+            # Storage was updated by another process, reload data instead of saving
+            logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...")
+            self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+            # Reset update flag
+            self.storage_updated.value = False
+            return False  # Return error
+
+        # Acquire lock and perform persistence
+        graph = await self._get_graph()
+        async with self._storage_lock:
+            try:
+                # Save data to disk
+                NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
+                # Notify other processes that data has been updated
+                await set_all_update_flags(self.namespace)
+                # Reset own update flag to avoid self-notification
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+                return True  # Return success
+            except Exception as e:
+                logger.error(f"Error saving graph for {self.namespace}: {e}")
+                return False  # Return error
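The persistence side in both files mirrors _get_client()/_get_graph(): index_done_callback refuses to save when a peer has already written newer data (it reloads and reports failure instead); otherwise it saves under the lock, raises every other worker's flag via set_all_update_flags(), and clears its own. A sketch of that decision under the same assumed flag semantics (my_flag, peer_flags, save_fn, and reload_fn are hypothetical stand-ins, not the actual shared_storage API):

    import asyncio
    from multiprocessing import Value

    my_flag = Value("b", False)                         # this worker's flag
    peer_flags = [Value("b", False) for _ in range(3)]  # other workers' flags (stand-ins)
    storage_lock = asyncio.Lock()


    async def index_done_callback(save_fn, reload_fn) -> bool:
        if my_flag.value:
            # A peer persisted first; saving now would clobber its write,
            # so refresh the in-memory copy and report failure instead.
            reload_fn()
            my_flag.value = False
            return False
        async with storage_lock:
            try:
                save_fn()                 # persist this worker's snapshot
                for flag in peer_flags:   # like set_all_update_flags(namespace)
                    flag.value = True
                my_flag.value = False     # do not notify ourselves
                return True
            except Exception:
                return False

The bool return mirrors the return True / return False paths added in both index_done_callback implementations above.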