Add delete method

This commit is contained in:
LarFii
2024-11-11 17:48:40 +08:00
parent 319de6fece
commit 4c0352ee2b
6 changed files with 100 additions and 4 deletions

View File

@@ -7,7 +7,13 @@ import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
from .utils import load_json, logger, write_json
from .utils import (
logger,
load_json,
write_json,
compute_mdhash_id,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
@@ -110,6 +116,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
return results
@property
def client_storage(self):
return getattr(self._client, "_NanoVectorDB__storage")
async def delete_entity(self, entity_name: str):
try:
entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
if self._client.get(entity_id):
self._client.delete(entity_id)
logger.info(f"Entity {entity_name} have been deleted.")
else:
logger.info(f"No entity found with name {entity_name}.")
except Exception as e:
logger.error(f"Error while deleting entity {entity_name}: {e}")
async def delete_relation(self, entity_name: str):
try:
relations = [
dp for dp in self.client_storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
self._client.delete(ids_to_delete)
logger.info(f"All relations related to entity {entity_name} have been deleted.")
else:
logger.info(f"No relations found for entity {entity_name}.")
except Exception as e:
logger.error(f"Error while deleting relations for entity {entity_name}: {e}")
async def index_done_callback(self):
self._client.save()
@@ -228,6 +265,18 @@ class NetworkXStorage(BaseGraphStorage):
):
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")