diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index dccee330..fec39138 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -690,8 +690,98 @@ class Neo4JStorage(BaseGraphStorage): labels.append(record["label"]) return labels + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, + ) + ), + ) async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + """ + label = await self._ensure_label(node_id) + + async def _do_delete(tx: AsyncManagedTransaction): + query = f""" + MATCH (n:`{label}`) + DETACH DELETE n + """ + await tx.run(query) + logger.debug(f"Deleted node with label '{label}'") + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, + ) + ), + ) + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + for node in nodes: + await self.delete_node(node) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, + ) + ), + ) + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + source_label = await self._ensure_label(source) + target_label = await self._ensure_label(target) + + async def _do_delete_edge(tx: AsyncManagedTransaction): + query = f""" + MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`) + DELETE r + """ + await tx.run(query) + logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str