diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 5dee1143..754c3491 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -737,6 +737,64 @@ class OracleGraphStorage(BaseGraphStorage): logger.error(f"Error deleting node {node_id}: {e}") raise + async def remove_nodes(self, nodes: list[str]) -> None: + """Delete multiple nodes from the graph + + Args: + nodes: List of node IDs to be deleted + """ + if not nodes: + return + + try: + for node in nodes: + # For each node, first delete all its relationships + delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] + params_relations = {"workspace": self.db.workspace, "entity_name": node} + await self.db.execute(delete_relations_sql, params_relations) + + # Then delete the node itself + delete_node_sql = SQL_TEMPLATES["delete_entity"] + params_node = {"workspace": self.db.workspace, "entity_name": node} + await self.db.execute(delete_node_sql, params_node) + + logger.info(f"Successfully deleted {len(nodes)} nodes and their relationships") + except Exception as e: + logger.error(f"Error during batch node deletion: {e}") + raise + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """Delete multiple edges from the graph + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + if not edges: + return + + try: + for source, target in edges: + # Check if the edge exists before attempting to delete + if await self.has_edge(source, target): + # Delete the edge using a SQL query that matches both source and target + delete_edge_sql = """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace + AND source_name = :source_name + AND target_name = :target_name + """ + params = { + "workspace": self.db.workspace, + "source_name": source, + "target_name": target + } + await self.db.execute(delete_edge_sql, params) + + logger.info(f"Successfully deleted {len(edges)} edges from the graph") + except Exception as e: + logger.error(f"Error during batch edge deletion: {e}") + raise + async def get_all_labels(self) -> list[str]: """Get all unique entity types (labels) in the graph