From cee5b2fbb0ad6f6ffa244d3335874f0a1830f6e8 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 31 Dec 2024 17:15:57 +0800
Subject: [PATCH] add delete by doc id
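
Adds adelete_by_doc_id() plus supporting storage helpers (KV filter/delete,
vector delete, graph remove_nodes/remove_edges) and entity/relation
inspection utilities.

Usage sketch (the ids and working_dir below are hypothetical; assumes an
already-initialized LightRAG instance):

    rag = LightRAG(working_dir="./rag_storage")

    # Remove a document, its chunks, and any entities/relationships whose
    # only remaining source was that document.
    rag.delete_by_doc_id("doc-...")                # sync wrapper
    # or, from async code: await rag.adelete_by_doc_id("doc-...")

    # Inspect what is left behind in the graph / vector stores.
    info = rag.get_entity_info_sync("Tim Cook", include_vector_data=True)
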
---
 lightrag/lightrag.py | 273 ++++++++++++++++++++++++++++++++++++++++++-
 lightrag/storage.py  | 107 ++++++++++++++---
 2 files changed, 360 insertions(+), 20 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 9a7ebeb8..03922556 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -43,6 +43,8 @@ from .storage import (
     JsonDocStatusStorage,
 )
 
+from .prompt import GRAPH_FIELD_SEP
+
 # future KG integrations
 # from .kg.ArangoDB_impl import (
@@ -672,7 +674,7 @@ class LightRAG:
 
         try:
             await self.entities_vdb.delete_entity(entity_name)
-            await self.relationships_vdb.delete_relation(entity_name)
+            await self.relationships_vdb.delete_entity_relation(entity_name)
             await self.chunk_entity_relation_graph.delete_node(entity_name)
 
             logger.info(
@@ -716,3 +718,272 @@ class LightRAG:
             Dict with counts for each status
         """
         return await self.doc_status.get_status_counts()
+
+    async def adelete_by_doc_id(self, doc_id: str):
+        """Delete a document and all its related data
+
+        Args:
+            doc_id: Document ID to delete
+        """
+        try:
+            # 1. Get the document status and related data
+            doc_status = await self.doc_status.get(doc_id)
+            if not doc_status:
+                logger.warning(f"Document {doc_id} not found")
+                return
+
+            logger.debug(f"Starting deletion for document {doc_id}")
+
+            # 2. Get all related chunks
+            chunks = await self.text_chunks.filter(lambda x: x.get("full_doc_id") == doc_id)
+            chunk_ids = list(chunks.keys())
+            logger.debug(f"Found {len(chunk_ids)} chunks to delete")
+
+            # 3. Before deleting, log the entities and relationships related to these chunks
+            for chunk_id in chunk_ids:
+                # Check entities (source_id may hold several chunk IDs joined by GRAPH_FIELD_SEP)
+                entities = [
+                    dp for dp in self.entities_vdb.client_storage["data"]
+                    if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                ]
+                logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")
+
+                # Check relationships
+                relations = [
+                    dp for dp in self.relationships_vdb.client_storage["data"]
+                    if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                ]
+                logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")
+
+            # Continue with the original deletion process...
+
+            # 4. Delete chunks from vector database
+            if chunk_ids:
+                await self.chunks_vdb.delete(chunk_ids)
+                await self.text_chunks.delete(chunk_ids)
+
+            # 5. Find and process entities and relationships that have these chunks as source
+            # Get all nodes and edges in the graph
+            nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
+            edges = self.chunk_entity_relation_graph._graph.edges(data=True)
+
+            # Track which entities and relationships need to be deleted or updated
+            entities_to_delete = set()
+            entities_to_update = {}  # entity_name -> new_source_id
+            relationships_to_delete = set()
+            relationships_to_update = {}  # (src, tgt) -> new_source_id
+
+            # Process entities
+            for node, data in nodes:
+                if 'source_id' in data:
+                    # Split source_id using GRAPH_FIELD_SEP
+                    sources = set(data['source_id'].split(GRAPH_FIELD_SEP))
+                    sources.difference_update(chunk_ids)
+                    if not sources:
+                        entities_to_delete.add(node)
+                        logger.debug(f"Entity {node} marked for deletion - no remaining sources")
+                    else:
+                        new_source_id = GRAPH_FIELD_SEP.join(sources)
+                        entities_to_update[node] = new_source_id
+                        logger.debug(f"Entity {node} will be updated with new source_id: {new_source_id}")
+
+            # Process relationships
+            for src, tgt, data in edges:
+                if 'source_id' in data:
+                    # Split source_id using GRAPH_FIELD_SEP
+                    sources = set(data['source_id'].split(GRAPH_FIELD_SEP))
+                    sources.difference_update(chunk_ids)
+                    if not sources:
+                        relationships_to_delete.add((src, tgt))
+                        logger.debug(f"Relationship {src}-{tgt} marked for deletion - no remaining sources")
+                    else:
+                        new_source_id = GRAPH_FIELD_SEP.join(sources)
+                        relationships_to_update[(src, tgt)] = new_source_id
+                        logger.debug(f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}")
+
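+            # Apply the changes collected above only after both scans finish, so
+            # the graph is never mutated while iterating over its nodes and edges.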
" + f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships." + ) + + # Add verification step + async def verify_deletion(): + # Verify if the document has been deleted + if await self.full_docs.get_by_id(doc_id): + logger.error(f"Document {doc_id} still exists in full_docs") + + # Verify if chunks have been deleted + remaining_chunks = await self.text_chunks.filter( + lambda x: x.get("full_doc_id") == doc_id + ) + if remaining_chunks: + logger.error(f"Found {len(remaining_chunks)} remaining chunks") + + # Verify entities and relationships + for chunk_id in chunk_ids: + # Check entities + entities_with_chunk = [ + dp for dp in self.entities_vdb.client_storage["data"] + if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) + ] + if entities_with_chunk: + logger.error(f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}") + + # Check relationships + relations_with_chunk = [ + dp for dp in self.relationships_vdb.client_storage["data"] + if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) + ] + if relations_with_chunk: + logger.error(f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}") + + await verify_deletion() + + except Exception as e: + logger.error(f"Error while deleting document {doc_id}: {e}") + + def delete_by_doc_id(self, doc_id: str): + """Synchronous version of adelete""" + return asyncio.run(self.adelete_by_doc_id(doc_id)) + + async def get_entity_info(self, entity_name: str, include_vector_data: bool = False): + """Get detailed information of an entity + + Args: + entity_name: Entity name (no need for quotes) + include_vector_data: Whether to include data from the vector database + + Returns: + dict: A dictionary containing entity information, including: + - entity_name: Entity name + - source_id: Source document ID + - graph_data: Complete node data from the graph database + - vector_data: (optional) Data from the vector database + """ + entity_name = f'"{entity_name.upper()}"' + + # Get information from the graph + node_data = await self.chunk_entity_relation_graph.get_node(entity_name) + source_id = node_data.get('source_id') if node_data else None + + result = { + "entity_name": entity_name, + "source_id": source_id, + "graph_data": node_data, + } + + # Optional: Get vector database information + if include_vector_data: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + vector_data = self.entities_vdb._client.get([entity_id]) + result["vector_data"] = vector_data[0] if vector_data else None + + return result + + def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False): + """Synchronous version of getting entity information + + Args: + entity_name: Entity name (no need for quotes) + include_vector_data: Whether to include data from the vector database + """ + try: + import tracemalloc + tracemalloc.start() + return asyncio.run(self.get_entity_info(entity_name, include_vector_data)) + finally: + tracemalloc.stop() + + async def get_relation_info(self, src_entity: str, tgt_entity: str, include_vector_data: bool = False): + """Get detailed information of a relationship + + Args: + src_entity: Source entity name (no need for quotes) + tgt_entity: Target entity name (no need for quotes) + include_vector_data: Whether to include data from the vector database + + Returns: + dict: A dictionary containing relationship information, including: + - src_entity: Source entity name + - tgt_entity: Target entity name + - source_id: 
+        entity_name = f'"{entity_name.upper()}"'
+
+        # Get information from the graph
+        node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
+        source_id = node_data.get('source_id') if node_data else None
+
+        result = {
+            "entity_name": entity_name,
+            "source_id": source_id,
+            "graph_data": node_data,
+        }
+
+        # Optional: Get vector database information
+        if include_vector_data:
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+            vector_data = self.entities_vdb._client.get([entity_id])
+            result["vector_data"] = vector_data[0] if vector_data else None
+
+        return result
+
+    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
+        """Synchronous version of getting entity information
+
+        Args:
+            entity_name: Entity name (no need for quotes)
+            include_vector_data: Whether to include data from the vector database
+        """
+        try:
+            import tracemalloc
+
+            tracemalloc.start()
+            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
+        finally:
+            tracemalloc.stop()
+
+    async def get_relation_info(self, src_entity: str, tgt_entity: str, include_vector_data: bool = False):
+        """Get detailed information of a relationship
+
+        Args:
+            src_entity: Source entity name (no need for quotes)
+            tgt_entity: Target entity name (no need for quotes)
+            include_vector_data: Whether to include data from the vector database
+
+        Returns:
+            dict: A dictionary containing relationship information, including:
+                - src_entity: Source entity name
+                - tgt_entity: Target entity name
+                - source_id: Source document ID
+                - graph_data: Complete edge data from the graph database
+                - vector_data: (optional) Data from the vector database
+        """
+        src_entity = f'"{src_entity.upper()}"'
+        tgt_entity = f'"{tgt_entity.upper()}"'
+
+        # Get information from the graph
+        edge_data = await self.chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)
+        source_id = edge_data.get('source_id') if edge_data else None
+
+        result = {
+            "src_entity": src_entity,
+            "tgt_entity": tgt_entity,
+            "source_id": source_id,
+            "graph_data": edge_data,
+        }
+
+        # Optional: Get vector database information
+        if include_vector_data:
+            rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
+            vector_data = self.relationships_vdb._client.get([rel_id])
+            result["vector_data"] = vector_data[0] if vector_data else None
+
+        return result
+
+    def get_relation_info_sync(self, src_entity: str, tgt_entity: str, include_vector_data: bool = False):
+        """Synchronous version of getting relationship information
+
+        Args:
+            src_entity: Source entity name (no need for quotes)
+            tgt_entity: Target entity name (no need for quotes)
+            include_vector_data: Whether to include data from the vector database
+        """
+        try:
+            import tracemalloc
+
+            tracemalloc.start()
+            return asyncio.run(self.get_relation_info(src_entity, tgt_entity, include_vector_data))
+        finally:
+            tracemalloc.stop()
diff --git a/lightrag/storage.py b/lightrag/storage.py
index ac8a95d3..a3138568 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -32,6 +32,7 @@ class JsonKVStorage(BaseKVStorage):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._data = load_json(self._file_name) or {}
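+        # Serializes _data access for the async filter()/delete() helpers added below.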
prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + # Check if the entity exists + if self._client.get([entity_id]): + await self.delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") else: - logger.info(f"No entity found with name {entity_name}.") + logger.debug(f"Entity {entity_name} not found in storage") except Exception as e: - logger.error(f"Error while deleting entity {entity_name}: {e}") + logger.error(f"Error deleting entity {entity_name}: {e}") - async def delete_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str): try: relations = [ - dp - for dp in self.client_storage["data"] + dp for dp in self.client_storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name ] + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") ids_to_delete = [relation["__id__"] for relation in relations] - + if ids_to_delete: - self._client.delete(ids_to_delete) - logger.info( - f"All relations related to entity {entity_name} have been deleted." - ) + await self.delete(ids_to_delete) + logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}") else: - logger.info(f"No relations found for entity {entity_name}.") + logger.debug(f"No relations found for entity {entity_name}") except Exception as e: - logger.error( - f"Error while deleting relations for entity {entity_name}: {e}" - ) + logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self): self._client.save() @@ -329,6 +368,26 @@ class NetworkXStorage(BaseGraphStorage): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids + def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node IDs to be deleted + """ + for node in nodes: + if self._graph.has_node(node): + self._graph.remove_node(node) + + def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + if self._graph.has_edge(source, target): + self._graph.remove_edge(source, target) + @dataclass class JsonDocStatusStorage(DocStatusStorage): @@ -378,3 +437,13 @@ class JsonDocStatusStorage(DocStatusStorage): self._data.update(data) await self.index_done_callback() return data + + async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: + """Get document status by ID""" + return self._data.get(doc_id) + + async def delete(self, doc_ids: list[str]): + """Delete document status by IDs""" + for doc_id in doc_ids: + self._data.pop(doc_id, None) + await self.index_done_callback() \ No newline at end of file