From 3a2a6368628fd2d54851ed6b1de8026cdf3cf608 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 15:50:53 +0800 Subject: [PATCH] Implement the missing methods. --- lightrag/kg/age_impl.py | 248 ++++++++++++++++++++++++++++++- lightrag/kg/chroma_impl.py | 34 ++++- lightrag/kg/gremlin_impl.py | 279 ++++++++++++++++++++++++++++++++++- lightrag/kg/milvus_impl.py | 83 ++++++++++- lightrag/kg/mongo_impl.py | 109 +++++++++++++- lightrag/kg/oracle_impl.py | 255 +++++++++++++++++++++++++++++++- lightrag/kg/postgres_impl.py | 243 +++++++++++++++++++++++++++++- lightrag/kg/qdrant_impl.py | 89 ++++++++++- lightrag/kg/redis_impl.py | 78 +++++++++- lightrag/kg/tidb_impl.py | 180 +++++++++++++++++++++- lightrag/lightrag.py | 48 ++++++ 11 files changed, 1603 insertions(+), 43 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 97b3825d..c6b98221 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np import pipmaster as pm -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from tenacity import ( retry, @@ -613,20 +613,258 @@ class AGEStorage(BaseGraphStorage): await self._driver.putconn(connection) async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + """ + entity_name_label = node_id.strip('"') + + query = """ + MATCH (n:`{label}`) + DETACH DELETE n + """ + params = {"label": AGEStorage._encode_graph_label(entity_name_label)} + try: + await self._query(query, **params) + logger.debug(f"Deleted node with label '{entity_name_label}'") + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + entity_name_label_source = source.strip('"') + entity_name_label_target = target.strip('"') + + query = """ + MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) + DELETE r + """ + params = { + "src_label": AGEStorage._encode_graph_label(entity_name_label_source), + "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target) + } + try: + await self._query(query, **params) + logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'") + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError + """Embed nodes using the specified algorithm + + Args: + algorithm: Name of the embedding algorithm + + Returns: + tuple: (embedding matrix, list of node identifiers) + """ + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() async def get_all_labels(self) -> list[str]: - raise NotImplementedError + """Get all node labels in the database + + Returns: + ["label1", "label2", ...] 
# Alphabetically sorted label list + """ + query = """ + MATCH (n) + RETURN DISTINCT labels(n) AS node_labels + """ + results = await self._query(query) + + all_labels = [] + for record in results: + if record and "node_labels" in record: + for label in record["node_labels"]: + if label: + # Decode label + decoded_label = AGEStorage._decode_graph_label(label) + all_labels.append(decoded_label) + + # Remove duplicates and sort + return sorted(list(set(all_labels))) async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'. + Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence (nodes containing the specified label string) + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes + + Args: + node_label: String to match in node labels (will match any node containing this string in its label) + max_depth: Maximum depth of the graph. Defaults to 5. + + Returns: + KnowledgeGraph: Complete connected subgraph for specified node + """ + max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000)) + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + # Handle special case for "*" label + if node_label == "*": + # Query all nodes and sort by degree + query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, count(r) AS degree + ORDER BY degree DESC + LIMIT {max_nodes} + RETURN n, degree + """ + params = {"max_nodes": max_graph_nodes} + nodes_result = await self._query(query, **params) + + # Add nodes to result + node_ids = [] + for record in nodes_result: + if "n" in record: + node = record["n"] + node_id = str(node.get("id", "")) + if node_id not in seen_nodes: + node_properties = {k: v for k, v in node.items()} + node_label = node.get("label", "") + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_label], + properties=node_properties + ) + ) + seen_nodes.add(node_id) + node_ids.append(node_id) + + # Query edges between these nodes + if node_ids: + edges_query = """ + MATCH (a)-[r]->(b) + WHERE a.id IN {node_ids} AND b.id IN {node_ids} + RETURN a, r, b + """ + edges_params = {"node_ids": node_ids} + edges_result = await self._query(edges_query, **edges_params) + + # Add edges to result + for record in edges_result: + if "r" in record and "a" in record and "b" in record: + source = record["a"].get("id", "") + target = record["b"].get("id", "") + edge_id = f"{source}-{target}" + if edge_id not in seen_edges: + edge_properties = {k: v for k, v in record["r"].items()} + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=source, + target=target, + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + else: + # For specific label, use partial matching + entity_name_label = node_label.strip('"') + encoded_label = AGEStorage._encode_graph_label(entity_name_label) + + # Find matching start nodes + start_query = """ + MATCH (n:`{label}`) + RETURN n + """ + start_params = {"label": encoded_label} + start_nodes = await self._query(start_query, **start_params) + + if not start_nodes: + logger.warning(f"No nodes found with label '{entity_name_label}'!") + return result + + # Traverse graph from each start node + for start_node_record in start_nodes: + if 
"n" in start_node_record: + start_node = start_node_record["n"] + start_id = str(start_node.get("id", "")) + + # Use BFS to traverse graph + query = """ + MATCH (start:`{label}`) + CALL { + MATCH path = (start)-[*0..{max_depth}]->(n) + RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels + } + RETURN DISTINCT path_nodes, path_rels + """ + params = {"label": encoded_label, "max_depth": max_depth} + results = await self._query(query, **params) + + # Extract nodes and edges from results + for record in results: + if "path_nodes" in record: + # Process nodes + for node in record["path_nodes"]: + node_id = str(node.get("id", "")) + if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes: + node_properties = {k: v for k, v in node.items()} + node_label = node.get("label", "") + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_label], + properties=node_properties + ) + ) + seen_nodes.add(node_id) + + if "path_rels" in record: + # Process edges + for rel in record["path_rels"]: + source = str(rel.get("start_id", "")) + target = str(rel.get("end_id", "")) + edge_id = f"{source}-{target}" + if edge_id not in seen_edges: + edge_properties = {k: v for k, v in rel.items()} + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=rel.get("label", "DIRECTED"), + source=source, + target=target, + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result async def index_done_callback(self) -> None: # AGES handles persistence automatically diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 3b726c8b..d36e6d7c 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -193,7 +193,37 @@ class ChromaVectorDBStorage(BaseVectorStorage): pass async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by its ID. + + Args: + entity_name: The ID of the entity to delete + """ + try: + logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}") + self._collection.delete(ids=[entity_name]) + except Exception as e: + logger.error(f"Error during entity deletion: {str(e)}") + raise async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity and its relations by ID. + In vector DB context, this is equivalent to delete_entity. 
+ + Args: + entity_name: The ID of the entity to delete + """ + await self.delete_entity(entity_name) + + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") + self._collection.delete(ids=ids) + logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + raise diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 3a26401d..4d343bb5 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -16,7 +16,7 @@ from tenacity import ( wait_exponential, ) -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from ..base import BaseGraphStorage @@ -396,17 +396,286 @@ class GremlinStorage(BaseGraphStorage): print("Implemented but never called.") async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node with the specified entity_name + + Args: + node_id: The entity_name of the node to delete + """ + entity_name = GremlinStorage._fix_name(node_id) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) + .drop() + """ + try: + await self._query(query) + logger.debug( + "{%s}: Deleted node with entity_name '%s'", + inspect.currentframe().f_code.co_name, + entity_name + ) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError + """ + Embed nodes using the specified algorithm. + Currently, only node2vec is supported but never called. + + Args: + algorithm: The name of the embedding algorithm to use + + Returns: + A tuple of (embeddings, node_ids) + + Raises: + NotImplementedError: If the specified algorithm is not supported + ValueError: If the algorithm is not supported + """ + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() async def get_all_labels(self) -> list[str]: - raise NotImplementedError + """ + Get all node entity_names in the graph + Returns: + [entity_name1, entity_name2, ...] # Alphabetically sorted entity_name list + """ + query = f"""g + .V().has('graph', {self.graph_name}) + .values('entity_name') + .dedup() + .order() + """ + try: + result = await self._query(query) + labels = result if result else [] + logger.debug( + "{%s}: Retrieved %d labels", + inspect.currentframe().f_code.co_name, + len(labels) + ) + return labels + except Exception as e: + logger.error(f"Error retrieving labels: {str(e)}") + return [] async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). 
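+        Traversal starts from the matching vertex and repeats both().simplePath()
+        up to max_depth times, so only nodes within max_depth hops are returned.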
+ + Args: + node_label: Entity name of the starting node + max_depth: Maximum depth of the subgraph + + Returns: + KnowledgeGraph object containing nodes and edges + """ + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + # Get maximum number of graph nodes from environment variable, default is 1000 + MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + + entity_name = GremlinStorage._fix_name(node_label) + + # Handle special case for "*" label + if node_label == "*": + # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) + query = f"""g + .V().has('graph', {self.graph_name}) + .limit({MAX_GRAPH_NODES}) + .elementMap() + """ + nodes_result = await self._query(query) + + # Add nodes to result + for node_data in nodes_result: + node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + if str(node_id) in seen_nodes: + continue + + # Create node with properties + node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node_id), + labels=[str(node_id)], + properties=node_properties + ) + ) + seen_nodes.add(str(node_id)) + + # Get and add edges + if nodes_result: + query = f"""g + .V().has('graph', {self.graph_name}) + .limit({MAX_GRAPH_NODES}) + .outE() + .inV().has('graph', {self.graph_name}) + .limit({MAX_GRAPH_NODES}) + .path() + .by(elementMap()) + .by(elementMap()) + .by(elementMap()) + """ + edges_result = await self._query(query) + + for path in edges_result: + if len(path) >= 3: # source -> edge -> target + source = path[0] + edge_data = path[1] + target = path[2] + + source_id = source.get('entity_name', str(source.get('id', ''))) + target_id = target.get('entity_name', str(target.get('id', ''))) + + edge_id = f"{source_id}-{target_id}" + if edge_id in seen_edges: + continue + + # Create edge with properties + edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source_id), + target=str(target_id), + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + else: + # Search for specific node and get its neighborhood + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) + .repeat(__.both().simplePath().dedup()) + .times({max_depth}) + .emit() + .dedup() + .limit({MAX_GRAPH_NODES}) + .elementMap() + """ + nodes_result = await self._query(query) + + # Add nodes to result + for node_data in nodes_result: + node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + if str(node_id) in seen_nodes: + continue + + # Create node with properties + node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node_id), + labels=[str(node_id)], + properties=node_properties + ) + ) + seen_nodes.add(str(node_id)) + + # Get edges between the nodes in the result + if nodes_result: + node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result] + node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids]) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', within({node_ids_query})) + .outE() + .where(inV().has('graph', {self.graph_name}) + .has('entity_name', within({node_ids_query}))) + .path() + .by(elementMap()) + .by(elementMap()) + .by(elementMap()) + """ + edges_result = await self._query(query) + + for path in edges_result: + if len(path) >= 
3: # source -> edge -> target + source = path[0] + edge_data = path[1] + target = path[2] + + source_id = source.get('entity_name', str(source.get('id', ''))) + target_id = target.get('entity_name', str(target.get('id', ''))) + + edge_id = f"{source_id}-{target_id}" + if edge_id in seen_edges: + continue + + # Create edge with properties + edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source_id), + target=str(target_id), + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + + logger.info( + "Subgraph query successful | Node count: %d | Edge count: %d", + len(result.nodes), + len(result.edges) + ) + return result + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node entity_names to be deleted + """ + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + entity_name_source = GremlinStorage._fix_name(source) + entity_name_target = GremlinStorage._fix_name(target) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_source}) + .outE() + .where(inV().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_target})) + .drop() + """ + try: + await self._query(query) + logger.debug( + "{%s}: Deleted edge from '%s' to '%s'", + inspect.currentframe().f_code.co_name, + entity_name_source, + entity_name_target + ) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 33a5c12b..2ad4da18 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -3,7 +3,7 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np -from lightrag.utils import logger +from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage import pipmaster as pm @@ -124,7 +124,84 @@ class MilvusVectorDBStorage(BaseVectorStorage): pass async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity from the vector database + + Args: + entity_name: The name of the entity to delete + """ + try: + # Compute entity ID from name + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + # Delete the entity from Milvus collection + result = self._client.delete( + collection_name=self.namespace, + pks=[entity_id] + ) + + if result and result.get("delete_count", 0) > 0: + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity + + Args: + entity_name: The name of the entity whose relations should be deleted + """ + try: + # Search for relations where entity is either source or target + expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' + + # Find all relations involving this entity + results = self._client.query( + collection_name=self.namespace, + filter=expr, + 
+                output_fields=["id"]
+            )
+
+            if not results or len(results) == 0:
+                logger.debug(f"No relations found for entity {entity_name}")
+                return
+
+            # Extract IDs of relations to delete
+            relation_ids = [item["id"] for item in results]
+            logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
+
+            # Delete the relations
+            if relation_ids:
+                delete_result = self._client.delete(
+                    collection_name=self.namespace,
+                    pks=relation_ids
+                )
+
+                logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}")
+
+        except Exception as e:
+            logger.error(f"Error deleting relations for {entity_name}: {e}")
+
+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        try:
+            # Delete vectors by IDs
+            result = self._client.delete(
+                collection_name=self.namespace,
+                pks=ids
+            )
+
+            if result and result.get("delete_count", 0) > 0:
+                logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}")
+            else:
+                logger.debug(f"No vectors were deleted from {self.namespace}")
+
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 0048b384..3afd2b44 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -15,7 +15,7 @@ from ..base import (
     DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
-from ..utils import logger
+from ..utils import logger, compute_mdhash_id
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 
 import pipmaster as pm
@@ -333,7 +333,7 @@ class MongoGraphStorage(BaseGraphStorage):
     Check if there's a direct single-hop edge from source_node_id to target_node_id.
 
     We'll do a $graphLookup with maxDepth=0 from the source node—meaning
-    “Look up zero expansions.” Actually, for a direct edge check, we can do maxDepth=1
+    "Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1
     and then see if the target node is in the "reachableNodes" at depth=0.
 
     But typically for a direct edge, we might just do a find_one.
@@ -795,6 +795,52 @@ class MongoGraphStorage(BaseGraphStorage):
         # Mongo handles persistence automatically
         pass
 
+    async def remove_nodes(self, nodes: list[str]) -> None:
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node IDs to be deleted
+        """
+        logger.info(f"Deleting {len(nodes)} nodes")
+        if not nodes:
+            return
+
+        # 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
+        await self.collection.update_many(
+            {},
+            {"$pull": {"edges": {"target": {"$in": nodes}}}}
+        )
+
+        # 2. Delete the node documents
+        await self.collection.delete_many({"_id": {"$in": nodes}})
+
+        logger.debug(f"Successfully deleted nodes: {nodes}")
+
+    async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to be deleted, each edge is a (source, target) tuple
+        """
+        logger.info(f"Deleting {len(edges)} edges")
+        if not edges:
+            return
+
+        update_tasks = []
+        for source, target in edges:
+            # Remove edge pointing to target from source node's edges array
+            update_tasks.append(
+                self.collection.update_one(
+                    {"_id": source},
+                    {"$pull": {"edges": {"target": target}}}
+                )
+            )
+
+        if update_tasks:
+            await asyncio.gather(*update_tasks)
+
+        logger.debug(f"Successfully deleted edges: {edges}")
+
 
 @final
 @dataclass
@@ -932,11 +978,66 @@
         # Mongo handles persistence automatically
         pass
 
+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+        if not ids:
+            return
+
+        try:
+            result = await self._data.delete_many({"_id": {"$in": ids}})
+            logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}")
+        except PyMongoError as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}")
+
     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity by its name
+
+        Args:
+            entity_name: Name of the entity to delete
+        """
+        try:
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+
+            result = await self._data.delete_one({"_id": entity_id})
+            if result.deleted_count > 0:
+                logger.debug(f"Successfully deleted entity {entity_name}")
+            else:
+                logger.debug(f"Entity {entity_name} not found in storage")
+        except PyMongoError as e:
+            logger.error(f"Error deleting entity {entity_name}: {str(e)}")
 
     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete all relations associated with an entity
+
+        Args:
+            entity_name: Name of the entity whose relations should be deleted
+        """
+        try:
+            # Find relations where entity appears as source or target
+            relations_cursor = self._data.find(
+                {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]}
+            )
+            relations = await relations_cursor.to_list(length=None)
+
+            if not relations:
+                logger.debug(f"No relations found for entity {entity_name}")
+                return
+
+            # Extract IDs of relations to delete
+            relation_ids = [relation["_id"] for relation in relations]
+            logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
+
+            # Delete the relations
+            result = await self._data.delete_many({"_id": {"$in": relation_ids}})
+            logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}")
+        except PyMongoError as e:
+            logger.error(f"Error deleting relations for {entity_name}: {str(e)}")
 
 
 async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index af2ededb..d189679e 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -8,7 +8,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser
 
-from lightrag.types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 
 from ..base import (
BaseGraphStorage, @@ -442,11 +442,55 @@ class OracleVectorDBStorage(BaseVectorStorage): # Oracles handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + if not ids: + return + + try: + SQL = SQL_TEMPLATES["delete_vectors"].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + await self.db.execute(SQL, params) + logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + raise + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete entity by name + + Args: + entity_name: Name of the entity to delete + """ + try: + SQL = SQL_TEMPLATES["delete_entity"] + params = {"workspace": self.db.workspace, "entity_name": entity_name} + await self.db.execute(SQL, params) + logger.info(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + raise async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations connected to an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + SQL = SQL_TEMPLATES["delete_entity_relations"] + params = {"workspace": self.db.workspace, "entity_name": entity_name} + await self.db.execute(SQL, params) + logger.info(f"Successfully deleted relations for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for entity {entity_name}: {e}") + raise @final @@ -668,15 +712,206 @@ class OracleGraphStorage(BaseGraphStorage): return res async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node from the graph + + Args: + node_id: ID of the node to delete + """ + try: + # First delete all relations connected to this node + delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] + params_relations = {"workspace": self.db.workspace, "entity_name": node_id} + await self.db.execute(delete_relations_sql, params_relations) + + # Then delete the node itself + delete_node_sql = SQL_TEMPLATES["delete_entity"] + params_node = {"workspace": self.db.workspace, "entity_name": node_id} + await self.db.execute(delete_node_sql, params_node) + + logger.info(f"Successfully deleted node {node_id} and all its relationships") + except Exception as e: + logger.error(f"Error deleting node {node_id}: {e}") + raise async def get_all_labels(self) -> list[str]: - raise NotImplementedError + """Get all unique entity types (labels) in the graph + + Returns: + List of unique entity types/labels + """ + try: + SQL = """ + SELECT DISTINCT entity_type + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY entity_type + """ + params = {"workspace": self.db.workspace} + results = await self.db.query(SQL, params, multirows=True) + + if results: + labels = [row["entity_type"] for row in results] + return labels + else: + return [] + except Exception as e: + logger.error(f"Error retrieving entity types: {e}") + return [] async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """Retrieve a connected subgraph starting from nodes matching the given label + + Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable. 
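+        The node limit defaults to 1000 when the variable is unset.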
+ Prioritizes nodes by: + 1. Nodes matching the specified label + 2. Nodes directly connected to matching nodes + 3. Node degree (number of connections) + + Args: + node_label: Label to match for starting nodes (use "*" for all nodes) + max_depth: Maximum depth of traversal from starting nodes + + Returns: + KnowledgeGraph object containing nodes and edges + """ + result = KnowledgeGraph() + + try: + # Define maximum number of nodes to return + max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000)) + + if node_label == "*": + # For "*" label, get all nodes up to the limit + nodes_sql = """ + SELECT name, entity_type, description, source_chunk_id + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY id + FETCH FIRST :limit ROWS ONLY + """ + nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes} + nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) + else: + # For specific label, find matching nodes and related nodes + nodes_sql = """ + WITH matching_nodes AS ( + SELECT name + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%') + ) + SELECT n.name, n.entity_type, n.description, n.source_chunk_id, + CASE + WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2 + WHEN EXISTS ( + SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e + WHERE workspace = :workspace + AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes)) + OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes))) + ) THEN 1 + ELSE 0 + END AS priority, + (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e + WHERE workspace = :workspace + AND (e.source_name = n.name OR e.target_name = n.name)) AS degree + FROM LIGHTRAG_GRAPH_NODES n + WHERE workspace = :workspace + ORDER BY priority DESC, degree DESC + FETCH FIRST :limit ROWS ONLY + """ + nodes_params = { + "workspace": self.db.workspace, + "node_label": node_label, + "limit": max_graph_nodes + } + nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) + + if not nodes: + logger.warning(f"No nodes found matching '{node_label}'") + return result + + # Create mapping of node IDs to be used to filter edges + node_names = [node["name"] for node in nodes] + + # Add nodes to result + seen_nodes = set() + for node in nodes: + node_id = node["name"] + if node_id in seen_nodes: + continue + + # Create node properties dictionary + properties = { + "entity_type": node["entity_type"], + "description": node["description"] or "", + "source_id": node["source_chunk_id"] or "" + } + + # Add node to result + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node["entity_type"]], + properties=properties + ) + ) + seen_nodes.add(node_id) + + # Get edges between these nodes + edges_sql = """ + SELECT source_name, target_name, weight, keywords, description, source_chunk_id + FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace + AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) + AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) + ORDER BY id + """ + edges_params = { + "workspace": self.db.workspace, + "node_names": node_names + } + edges = await self.db.query(edges_sql, edges_params, multirows=True) + + # Add edges to result + seen_edges = set() + for edge in edges: + source = edge["source_name"] + target = edge["target_name"] + edge_id = f"{source}-{target}" + + if edge_id in seen_edges: + 
continue + + # Create edge properties dictionary + properties = { + "weight": edge["weight"] or 0.0, + "keywords": edge["keywords"] or "", + "description": edge["description"] or "", + "source_id": edge["source_chunk_id"] or "" + } + + # Add edge to result + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="RELATED", + source=source, + target=target, + properties=properties + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except Exception as e: + logger.error(f"Error retrieving knowledge graph: {e}") + + return result N_T = { @@ -927,4 +1162,12 @@ SQL_TEMPLATES = { select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) )""", + # SQL for deletion + "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})", + "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name", + "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)", + "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph + MATCH (a) + WHERE a.workspace=:workspace AND a.name=:node_id + ACTION DELETE a)""", } diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 51044be5..7ce2b427 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -7,7 +7,7 @@ from typing import Any, Union, final import numpy as np import configparser -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import sys from tenacity import ( @@ -512,11 +512,66 @@ class PGVectorStorage(BaseVectorStorage): # PG handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs from the storage. + + Args: + ids: List of vector IDs to be deleted + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for vector deletion: {self.namespace}") + return + + ids_list = ",".join([f"'{id}'" for id in ids]) + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" + + try: + await self.db.execute(delete_sql, {"workspace": self.db.workspace}) + logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by its name from the vector storage. + + Args: + entity_name: The name of the entity to delete + """ + try: + # Construct SQL to delete the entity + delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + WHERE workspace=$1 AND entity_name=$2""" + + await self.db.execute( + delete_sql, + {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity. 
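+        Rows where the entity appears as either source_id or target_id are removed.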
+ + Args: + entity_name: The name of the entity whose relations should be deleted + """ + try: + # Delete relations where the entity is either the source or target + delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" + + await self.db.execute( + delete_sql, + {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted relations for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for entity {entity_name}: {e}") @final @@ -1086,20 +1141,192 @@ class PGGraphStorage(BaseGraphStorage): print("Implemented but never called.") async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """ + Delete a node from the graph. + + Args: + node_id (str): The ID of the node to delete. + """ + label = self._encode_graph_label(node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + DETACH DELETE n + $$) AS (n agtype)""" % (self.graph_name, label) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during node deletion: {%s}", e) + raise + + async def remove_nodes(self, node_ids: list[str]) -> None: + """ + Remove multiple nodes from the graph. + + Args: + node_ids (list[str]): A list of node IDs to remove. + """ + encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids] + node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + WHERE n.node_id IN [%s] + DETACH DELETE n + $$) AS (n agtype)""" % (self.graph_name, node_id_list) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during node removal: {%s}", e) + raise + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """ + Remove multiple edges from the graph. + + Args: + edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). + """ + encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges] + edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity)-[r]->(b:Entity) + WHERE [a.node_id, b.node_id] IN [%s] + DELETE r + $$) AS (r agtype)""" % (self.graph_name, edge_list) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during edge removal: {%s}", e) + raise + + async def get_all_labels(self) -> list[str]: + """ + Get all labels (node IDs) in the graph. + + Returns: + list[str]: A list of all labels in the graph. + """ + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + RETURN DISTINCT n.node_id AS label + $$) AS (label text)""" % self.graph_name + + results = await self._query(query) + labels = [self._decode_graph_label(result["label"]) for result in results] + + return labels async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError + """ + Generate node embeddings using the specified algorithm. - async def get_all_labels(self) -> list[str]: - raise NotImplementedError + Args: + algorithm (str): The name of the embedding algorithm to use. + + Returns: + tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs. 
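+
+        Raises:
+            ValueError: If the requested algorithm is not registered in
+                self._node_embed_algorithms.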
+ """ + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Unsupported embedding algorithm: {algorithm}") + + embed_func = self._node_embed_algorithms[algorithm] + return await embed_func() async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. + + Args: + node_label (str): The label of the node to start from. If "*", the entire graph is returned. + max_depth (int): The maximum depth to traverse from the starting node. + + Returns: + KnowledgeGraph: The retrieved subgraph. + """ + MAX_GRAPH_NODES = 1000 + + if node_label == "*": + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + OPTIONAL MATCH (n)-[r]->(m:Entity) + RETURN n, r, m + LIMIT %d + $$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES) + else: + encoded_node_label = self._encode_graph_label(node_label.strip('"')) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + OPTIONAL MATCH p = (n)-[*..%d]-(m) + RETURN nodes(p) AS nodes, relationships(p) AS relationships + LIMIT %d + $$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES) + + results = await self._query(query) + + nodes = set() + edges = [] + + for result in results: + if node_label == "*": + if result["n"]: + node = result["n"] + nodes.add(self._decode_graph_label(node["node_id"])) + if result["m"]: + node = result["m"] + nodes.add(self._decode_graph_label(node["node_id"])) + if result["r"]: + edge = result["r"] + src_id = self._decode_graph_label(edge["start_id"]) + tgt_id = self._decode_graph_label(edge["end_id"]) + edges.append((src_id, tgt_id)) + else: + if result["nodes"]: + for node in result["nodes"]: + nodes.add(self._decode_graph_label(node["node_id"])) + if result["relationships"]: + for edge in result["relationships"]: + src_id = self._decode_graph_label(edge["start_id"]) + tgt_id = self._decode_graph_label(edge["end_id"]) + edges.append((src_id, tgt_id)) + + kg = KnowledgeGraph( + nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes], + edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges], + ) + + return kg + + async def get_all_labels(self) -> list[str]: + """ + Get all node labels in the graph + Returns: + [label1, label2, ...] 
# Alphabetically sorted label list + """ + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + RETURN DISTINCT n.node_id AS label + ORDER BY label + $$) AS (label agtype)""" % (self.graph_name) + + try: + results = await self._query(query) + labels = [] + for record in results: + if record["label"]: + labels.append(self._decode_graph_label(record["label"])) + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + return [] async def drop(self) -> None: """Drop the storage""" diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index b08f0b62..e3488caa 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Any, final +from typing import Any, final, List from dataclasses import dataclass import numpy as np import hashlib @@ -141,8 +141,91 @@ class QdrantVectorDBStorage(BaseVectorStorage): # Qdrant handles persistence automatically pass + async def delete(self, ids: List[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + # Convert regular ids to Qdrant compatible ids + qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids] + # Delete points from the collection + self._client.delete( + collection_name=self.namespace, + points_selector=models.PointIdsList( + points=qdrant_ids, + ), + wait=True + ) + logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by name + + Args: + entity_name: Name of the entity to delete + """ + try: + # Generate the entity ID + entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + # Delete the entity point from the collection + self._client.delete( + collection_name=self.namespace, + points_selector=models.PointIdsList( + points=[entity_id], + ), + wait=True + ) + logger.debug(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + # Find relations where the entity is either source or target + results = self._client.scroll( + collection_name=self.namespace, + scroll_filter=models.Filter( + should=[ + models.FieldCondition( + key="src_id", + match=models.MatchValue(value=entity_name) + ), + models.FieldCondition( + key="tgt_id", + match=models.MatchValue(value=entity_name) + ) + ] + ), + with_payload=True, + limit=1000 # Adjust as needed for your use case + ) + + # Extract points that need to be deleted + relation_points = results[0] + ids_to_delete = [point.id for point in relation_points] + + if ids_to_delete: + # Delete the relations + self._client.delete( + collection_name=self.namespace, + points_selector=models.PointIdsList( + points=ids_to_delete, + ), + wait=True + ) + logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}") + else: + logger.debug(f"No relations found for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") diff 
--git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 7e177346..bb42b367 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -9,7 +9,7 @@ if not pm.is_installed("redis"): # aioredis is a depricated library, replaced with redis from redis.asyncio import Redis -from lightrag.utils import logger +from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseKVStorage import json @@ -64,3 +64,79 @@ class RedisKVStorage(BaseKVStorage): async def index_done_callback(self) -> None: # Redis handles persistence automatically pass + + async def delete(self, ids: list[str]) -> None: + """Delete entries with specified IDs + + Args: + ids: List of entry IDs to be deleted + """ + if not ids: + return + + pipe = self._redis.pipeline() + for id in ids: + pipe.delete(f"{self.namespace}:{id}") + + results = await pipe.execute() + deleted_count = sum(results) + logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}") + + async def delete_entity(self, entity_name: str) -> None: + """Delete an entity by name + + Args: + entity_name: Name of the entity to delete + """ + + try: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + # Delete the entity + result = await self._redis.delete(f"{self.namespace}:{entity_id}") + + if result: + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete all relations associated with an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + # Get all keys in this namespace + cursor = 0 + relation_keys = [] + pattern = f"{self.namespace}:*" + + while True: + cursor, keys = await self._redis.scan(cursor, match=pattern) + + # For each key, get the value and check if it's related to entity_name + for key in keys: + value = await self._redis.get(key) + if value: + data = json.loads(value) + # Check if this is a relation involving the entity + if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name: + relation_keys.append(key) + + # Exit loop when cursor returns to 0 + if cursor == 0: + break + + # Delete the relation keys + if relation_keys: + deleted = await self._redis.delete(*relation_keys) + logger.debug(f"Deleted {deleted} relations for {entity_name}") + else: + logger.debug(f"No relations found for entity {entity_name}") + + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 51d1c365..f791d401 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -5,7 +5,7 @@ from typing import Any, Union, final import numpy as np -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage @@ -566,15 +566,148 @@ class TiDBGraphStorage(BaseGraphStorage): pass async def delete_node(self, node_id: str) -> None: - raise NotImplementedError - + """Delete a node and all its related edges + + Args: + node_id: The ID of the node to delete + """ + # First delete all edges related to this node + await self.db.execute(SQL_TEMPLATES["delete_node_edges"], + {"name": node_id, 
"workspace": self.db.workspace}) + + # Then delete the node itself + await self.db.execute(SQL_TEMPLATES["delete_node"], + {"name": node_id, "workspace": self.db.workspace}) + + logger.debug(f"Node {node_id} and its related edges have been deleted from the graph") + async def get_all_labels(self) -> list[str]: - raise NotImplementedError - + """Get all entity types (labels) in the database + + Returns: + List of labels sorted alphabetically + """ + result = await self.db.query( + SQL_TEMPLATES["get_all_labels"], + {"workspace": self.db.workspace}, + multirows=True + ) + + if not result: + return [] + + # Extract all labels + return [item["label"] for item in result] + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Get a connected subgraph of nodes matching the specified label + Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000) + + Args: + node_label: The node label to match + max_depth: Maximum depth of the subgraph + + Returns: + KnowledgeGraph object containing nodes and edges + """ + result = KnowledgeGraph() + MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + + # Get matching nodes + if node_label == "*": + # Handle special case, get all nodes + node_results = await self.db.query( + SQL_TEMPLATES["get_all_nodes"], + {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, + multirows=True + ) + else: + # Get nodes matching the label + label_pattern = f"%{node_label}%" + node_results = await self.db.query( + SQL_TEMPLATES["get_matching_nodes"], + {"workspace": self.db.workspace, "label_pattern": label_pattern}, + multirows=True + ) + + if not node_results: + logger.warning(f"No nodes found matching label {node_label}") + return result + + # Limit the number of returned nodes + if len(node_results) > MAX_GRAPH_NODES: + node_results = node_results[:MAX_GRAPH_NODES] + + # Extract node names for edge query + node_names = [node["name"] for node in node_results] + node_names_str = ",".join([f"'{name}'" for name in node_names]) + + # Add nodes to result + for node in node_results: + node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]} + result.nodes.append( + KnowledgeGraphNode( + id=node["name"], + labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]], + properties=node_properties + ) + ) + + # Get related edges + edge_results = await self.db.query( + SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), + {"workspace": self.db.workspace}, + multirows=True + ) + + if edge_results: + # Add edges to result + for edge in edge_results: + # Only include edges related to selected nodes + if edge["source_name"] in node_names and edge["target_name"] in node_names: + edge_id = f"{edge['source_name']}-{edge['target_name']}" + edge_properties = {k: v for k, v in edge.items() + if k not in ["id", "source_name", "target_name"]} + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="RELATED", + source=edge["source_name"], + target=edge["target_name"], + properties=edge_properties + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node IDs to delete + """ + for node_id in nodes: + await self.delete_node(node_id) + + async def remove_edges(self, edges: list[tuple[str, str]]): + 
"""Delete multiple edges + + Args: + edges: List of edges to delete, each edge is a (source, target) tuple + """ + for source, target in edges: + await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], { + "source": source, + "target": target, + "workspace": self.db.workspace + }) N_T = { @@ -785,4 +918,39 @@ SQL_TEMPLATES = { weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description), source_chunk_id = VALUES(source_chunk_id) """, + "delete_node": """ + DELETE FROM LIGHTRAG_GRAPH_NODES + WHERE name = :name AND workspace = :workspace + """, + "delete_node_edges": """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace + """, + "get_all_labels": """ + SELECT DISTINCT entity_type as label + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY entity_type + """, + "get_matching_nodes": """ + SELECT * FROM LIGHTRAG_GRAPH_NODES + WHERE name LIKE :label_pattern AND workspace = :workspace + ORDER BY name + """, + "get_all_nodes": """ + SELECT * FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY name + LIMIT :max_nodes + """, + "get_related_edges": """ + SELECT * FROM LIGHTRAG_GRAPH_EDGES + WHERE (source_name IN (:node_names) OR target_name IN (:node_names)) + AND workspace = :workspace + """, + "remove_multiple_edges": """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE (source_name = :source AND target_name = :target) + AND workspace = :workspace + """ } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6f42003d..eeed8a70 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1399,6 +1399,54 @@ class LightRAG: ] ) + def delete_by_relation(self, source_entity: str, target_entity: str) -> None: + """Synchronously delete a relation between two entities. + + Args: + source_entity: Name of the source entity + target_entity: Name of the target entity + """ + loop = always_get_an_event_loop() + return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity)) + + async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: + """Asynchronously delete a relation between two entities. + + Args: + source_entity: Name of the source entity + target_entity: Name of the target entity + """ + try: + # Check if the relation exists + edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity) + if not edge_exists: + logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist") + return + + # Delete relation from vector database + relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-") + await self.relationships_vdb.delete([relation_id]) + + # Delete relation from knowledge graph + await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)]) + + logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'") + await self._delete_relation_done() + except Exception as e: + logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}") + + async def _delete_relation_done(self) -> None: + """Callback after relation deletion is complete""" + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + ) + def _get_content_summary(self, content: str, max_length: int = 100) -> str: """Get summary of document content