From 7a866cbe216a46b97de9a8f931d1338aebe763c9 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 11:48:43 +0100
Subject: [PATCH 01/21] Update run_with_gunicorn.py

---
 lightrag/api/run_with_gunicorn.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index 903c5c17..f27ddd61 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -9,7 +9,10 @@ import signal
 import pipmaster as pm
 from lightrag.api.utils_api import parse_args, display_splash_screen
 from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
+from dotenv import load_dotenv

+# Load environment variables from .env file
+load_dotenv()

 def check_and_install_dependencies():
     """Check and install required dependencies"""

From ff3f29d2406c27d1d28e74ec96a8311858009bda Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 12:13:01 +0100
Subject: [PATCH 02/21] Update run_with_gunicorn.py

---
 lightrag/api/run_with_gunicorn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index f27ddd61..e7143a39 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -12,6 +12,7 @@ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_dat
 from dotenv import load_dotenv

 # Load environment variables from .env file
+print("Current folder: {}".format(os.getcwd()))
 load_dotenv()

 def check_and_install_dependencies():

From e87feb76bc86a94c97d7407dda7f329455936e0c Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 12:21:15 +0100
Subject: [PATCH 03/21] Update run_with_gunicorn.py

---
 lightrag/api/run_with_gunicorn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index e7143a39..50dd195d 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -12,8 +12,9 @@ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_dat
 from dotenv import load_dotenv

 # Load environment variables from .env file
-print("Current folder: {}".format(os.getcwd()))
+print(f"Current folder: {os.getcwd()}")
 load_dotenv()
+print(f"Check: {os.getenv('LLM_MODEL')}")

 def check_and_install_dependencies():
     """Check and install required dependencies"""

From bda931e1d2abc934b7418e0c2ad4e73734d2366d Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 12:21:50 +0100
Subject: [PATCH 04/21] Update run_with_gunicorn.py

---
 lightrag/api/run_with_gunicorn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index 50dd195d..71844fe0 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -13,7 +13,7 @@ from dotenv import load_dotenv

 # Load environment variables from .env file
 print(f"Current folder: {os.getcwd()}")
-load_dotenv()
+load_dotenv(".env")
 print(f"Check: {os.getenv('LLM_MODEL')}")

 def check_and_install_dependencies():

From 52bedc9118892d1ad214cda7ed6164d06f27e574 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 12:22:37 +0100
Subject: [PATCH 05/21] Update run_with_gunicorn.py

---
 lightrag/api/run_with_gunicorn.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index 71844fe0..4e5353bf 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -11,10 +11,8 @@ from lightrag.api.utils_api import parse_args, display_splash_screen
 from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
 from dotenv import load_dotenv

-# Load environment variables from .env file
-print(f"Current folder: {os.getcwd()}")
+# Updated to use the .env that is inside the current folder
 load_dotenv(".env")
-print(f"Check: {os.getenv('LLM_MODEL')}")

 def check_and_install_dependencies():
     """Check and install required dependencies"""

From 7b3e39473065935570466537a0c4b7139fb3d176 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 12:23:47 +0100
Subject: [PATCH 06/21] Update run_with_gunicorn.py

---
 lightrag/api/run_with_gunicorn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index 4e5353bf..231a1727 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -12,6 +12,7 @@ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_dat
 from dotenv import load_dotenv

 # Updated to use the .env that is inside the current folder
+# This update allows the user to put a different .env file for each lightrag folder
 load_dotenv(".env")

 def check_and_install_dependencies():

From 5680e9ef11ef8403f65823af59a7d52c29548d15 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Mon, 3 Mar 2025 12:24:49 +0100
Subject: [PATCH 07/21] Update lightrag_server.py

---
 lightrag/api/lightrag_server.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 5f2c437f..637595d3 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -20,7 +20,7 @@ from ascii_colors import ASCIIColors
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from dotenv import load_dotenv
-from .utils_api import (
+from lightrag.api.utils_api import (
     get_api_key_dependency,
     parse_args,
     get_default_host,
@@ -30,14 +30,14 @@ from lightrag import LightRAG
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from .routers.document_routes import (
+from lightrag.api.routers.document_routes import (
     DocumentManager,
     create_document_routes,
     run_scanning_process,
 )
-from .routers.query_routes import create_query_routes
-from .routers.graph_routes import create_graph_routes
-from .routers.ollama_api import OllamaAPI
+from lightrag.api.routers.query_routes import create_query_routes
+from lightrag.api.routers.graph_routes import create_graph_routes
+from lightrag.api.routers.ollama_api import OllamaAPI
 from lightrag.utils import logger, set_verbose_debug

 from lightrag.kg.shared_storage import (
@@ -48,7 +48,9 @@ from lightrag.kg.shared_storage import (
 )

 # Load environment variables
-load_dotenv(override=True)
+# Updated to use the .env that is inside the current folder
+# This update allows the user to put a different .env file for each lightrag folder
+load_dotenv(".env", override=True)

 # Initialize config parser
 config = configparser.ConfigParser()
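Patches 05-07 make configuration resolution relative to the process working directory. A minimal sketch of the behaviour they rely on, with illustrative paths and variable names:

    import os
    from dotenv import load_dotenv

    os.chdir("/srv/lightrag-instance-a")  # hypothetical per-instance folder
    load_dotenv(".env")                   # reads /srv/lightrag-instance-a/.env
    print(os.getenv("LLM_MODEL"))         # value specific to this instance

This is what lets each lightrag folder carry its own .env, as the commit comments describe.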
From 0679ca4055d36dfd53afcb9ab87ea5d4c056cd31 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 14:20:55 +0800
Subject: [PATCH 08/21] Update neo4j_impl.py

---
 lightrag/kg/neo4j_impl.py | 92 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 91 insertions(+), 1 deletion(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index dccee330..fec39138 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -690,8 +690,98 @@ class Neo4JStorage(BaseGraphStorage):
                 labels.append(record["label"])
             return labels

+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node with the specified label
+
+        Args:
+            node_id: The label of the node to delete
+        """
+        label = await self._ensure_label(node_id)
+
+        async def _do_delete(tx: AsyncManagedTransaction):
+            query = f"""
+            MATCH (n:`{label}`)
+            DETACH DELETE n
+            """
+            await tx.run(query)
+            logger.debug(f"Deleted node with label '{label}'")
+
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                await session.execute_write(_do_delete)
+        except Exception as e:
+            logger.error(f"Error during node deletion: {str(e)}")
+            raise
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
+    async def remove_nodes(self, nodes: list[str]):
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node labels to be deleted
+        """
+        for node in nodes:
+            await self.delete_node(node)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
+    async def remove_edges(self, edges: list[tuple[str, str]]):
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to be deleted, each edge is a (source, target) tuple
+        """
+        for source, target in edges:
+            source_label = await self._ensure_label(source)
+            target_label = await self._ensure_label(target)
+
+            async def _do_delete_edge(tx: AsyncManagedTransaction):
+                query = f"""
+                MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
+                DELETE r
+                """
+                await tx.run(query)
+                logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
+
+            try:
+                async with self._driver.session(database=self._DATABASE) as session:
+                    await session.execute_write(_do_delete_edge)
+            except Exception as e:
+                logger.error(f"Error during edge deletion: {str(e)}")
+                raise

     async def embed_nodes(
         self, algorithm: str

From 3a2a6368628fd2d54851ed6b1de8026cdf3cf608 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 15:50:53 +0800
Subject: [PATCH 09/21] Implement the missing methods.
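This patch replaces the NotImplementedError stubs across the storage backends with working deletion methods. The call pattern is uniform across backends; a minimal usage sketch, assuming an already-initialized graph storage instance (entity names illustrative):

    async def cleanup(storage):
        # delete one entity together with all of its attached edges
        await storage.delete_node("ENTITY_A")
        # the batch helpers delegate to the single-item operations
        await storage.remove_nodes(["ENTITY_B", "ENTITY_C"])
        await storage.remove_edges([("ENTITY_B", "ENTITY_C")])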
---
 lightrag/kg/age_impl.py | 248 ++++++++++++++++++++++++++++++-
 lightrag/kg/chroma_impl.py | 34 ++++-
 lightrag/kg/gremlin_impl.py | 279 ++++++++++++++++++++++++++++++++++-
 lightrag/kg/milvus_impl.py | 83 ++++++++++-
 lightrag/kg/mongo_impl.py | 109 +++++++++++++-
 lightrag/kg/oracle_impl.py | 255 +++++++++++++++++++++++++++++++-
 lightrag/kg/postgres_impl.py | 243 +++++++++++++++++++++++++++++-
 lightrag/kg/qdrant_impl.py | 89 ++++++++++-
 lightrag/kg/redis_impl.py | 78 +++++++++-
 lightrag/kg/tidb_impl.py | 180 +++++++++++++++++++++-
 lightrag/lightrag.py | 48 ++++++
 11 files changed, 1603 insertions(+), 43 deletions(-)

diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py
index 97b3825d..c6b98221 100644
--- a/lightrag/kg/age_impl.py
+++ b/lightrag/kg/age_impl.py
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, Union, final
 import numpy as np
 import pipmaster as pm
-from lightrag.types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

 from tenacity import (
     retry,
@@ -613,20 +613,258 @@ class AGEStorage(BaseGraphStorage):
                 await self._driver.putconn(connection)

     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node with the specified label
+
+        Args:
+            node_id: The label of the node to delete
+        """
+        entity_name_label = node_id.strip('"')
+
+        query = """
+        MATCH (n:`{label}`)
+        DETACH DELETE n
+        """
+        params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
+        try:
+            await self._query(query, **params)
+            logger.debug(f"Deleted node with label '{entity_name_label}'")
+        except Exception as e:
+            logger.error(f"Error during node deletion: {str(e)}")
+            raise
+
+    async def remove_nodes(self, nodes: list[str]):
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node labels to be deleted
+        """
+        for node in nodes:
+            await self.delete_node(node)
+
+    async def remove_edges(self, edges: list[tuple[str, str]]):
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to be deleted, each edge is a (source, target) tuple
+        """
+        for source, target in edges:
+            entity_name_label_source = source.strip('"')
+            entity_name_label_target = target.strip('"')
+
+            query = """
+            MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`)
+            DELETE r
+            """
+            params = {
+                "src_label": AGEStorage._encode_graph_label(entity_name_label_source),
+                "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target)
+            }
+            try:
+                await self._query(query, **params)
+                logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'")
+            except Exception as e:
+                logger.error(f"Error during edge deletion: {str(e)}")
+                raise

     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
-        raise NotImplementedError
+        """Embed nodes using the specified algorithm
+
+        Args:
+            algorithm: Name of the embedding algorithm
+
+        Returns:
+            tuple: (embedding matrix, list of node identifiers)
+        """
+        if algorithm not in self._node_embed_algorithms:
+            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
+        return await self._node_embed_algorithms[algorithm]()

     async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError
+        """Get all node labels in the database
+
+        Returns:
+            ["label1", "label2", ...]  # Alphabetically sorted label list
+        """
+        query = """
+        MATCH (n)
+        RETURN DISTINCT labels(n) AS node_labels
+        """
+        results = await self._query(query)
+
+        all_labels = []
+        for record in results:
+            if record and "node_labels" in record:
+                for label in record["node_labels"]:
+                    if label:
+                        # Decode label
+                        decoded_label = AGEStorage._decode_graph_label(label)
+                        all_labels.append(decoded_label)
+
+        # Remove duplicates and sort
+        return sorted(list(set(all_labels)))

     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """
+        Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'.
+        Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000).
+        When reducing the number of nodes, the prioritization criteria are as follows:
+            1. Label matching nodes take precedence (nodes containing the specified label string)
+            2. Followed by nodes directly connected to the matching nodes
+            3. Finally, the degree of the nodes
+
+        Args:
+            node_label: String to match in node labels (will match any node containing this string in its label)
+            max_depth: Maximum depth of the graph. Defaults to 5.
+
+        Returns:
+            KnowledgeGraph: Complete connected subgraph for specified node
+        """
+        max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000))
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        # Handle special case for "*" label
+        if node_label == "*":
+            # Query all nodes and sort by degree
+            query = """
+            MATCH (n)
+            OPTIONAL MATCH (n)-[r]-()
+            WITH n, count(r) AS degree
+            ORDER BY degree DESC
+            LIMIT {max_nodes}
+            RETURN n, degree
+            """
+            params = {"max_nodes": max_graph_nodes}
+            nodes_result = await self._query(query, **params)
+
+            # Add nodes to result
+            node_ids = []
+            for record in nodes_result:
+                if "n" in record:
+                    node = record["n"]
+                    node_id = str(node.get("id", ""))
+                    if node_id not in seen_nodes:
+                        node_properties = {k: v for k, v in node.items()}
+                        node_label = node.get("label", "")
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_label],
+                                properties=node_properties
+                            )
+                        )
+                        seen_nodes.add(node_id)
+                        node_ids.append(node_id)
+
+            # Query edges between these nodes
+            if node_ids:
+                edges_query = """
+                MATCH (a)-[r]->(b)
+                WHERE a.id IN {node_ids} AND b.id IN {node_ids}
+                RETURN a, r, b
+                """
+                edges_params = {"node_ids": node_ids}
+                edges_result = await self._query(edges_query, **edges_params)
+
+                # Add edges to result
+                for record in edges_result:
+                    if "r" in record and "a" in record and "b" in record:
+                        source = record["a"].get("id", "")
+                        target = record["b"].get("id", "")
+                        edge_id = f"{source}-{target}"
+                        if edge_id not in seen_edges:
+                            edge_properties = {k: v for k, v in record["r"].items()}
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type="DIRECTED",
+                                    source=source,
+                                    target=target,
+                                    properties=edge_properties
+                                )
+                            )
+                            seen_edges.add(edge_id)
+        else:
+            # For specific label, use partial matching
+            entity_name_label = node_label.strip('"')
+            encoded_label = AGEStorage._encode_graph_label(entity_name_label)
+
+            # Find matching start nodes
+            start_query = """
+            MATCH (n:`{label}`)
+            RETURN n
+            """
+            start_params = {"label": encoded_label}
+            start_nodes = await self._query(start_query, **start_params)
+
+            if not start_nodes:
+                logger.warning(f"No nodes found with label '{entity_name_label}'!")
+                return result
+
+            # Traverse graph from each start node
+            for start_node_record in start_nodes:
+                if "n" in start_node_record:
+                    start_node = start_node_record["n"]
+                    start_id = str(start_node.get("id", ""))
+
+                    # Use BFS to traverse graph
+                    query = """
+                    MATCH (start:`{label}`)
+                    CALL {
+                        MATCH path = (start)-[*0..{max_depth}]->(n)
+                        RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels
+                    }
+                    RETURN DISTINCT path_nodes, path_rels
+                    """
+                    params = {"label": encoded_label, "max_depth": max_depth}
+                    results = await self._query(query, **params)
+
+                    # Extract nodes and edges from results
+                    for record in results:
+                        if "path_nodes" in record:
+                            # Process nodes
+                            for node in record["path_nodes"]:
+                                node_id = str(node.get("id", ""))
+                                if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes:
+                                    node_properties = {k: v for k, v in node.items()}
+                                    node_label = node.get("label", "")
+                                    result.nodes.append(
+                                        KnowledgeGraphNode(
+                                            id=node_id,
+                                            labels=[node_label],
+                                            properties=node_properties
+                                        )
+                                    )
+                                    seen_nodes.add(node_id)
+
+                        if "path_rels" in record:
+                            # Process edges
+                            for rel in record["path_rels"]:
+                                source = str(rel.get("start_id", ""))
+                                target = str(rel.get("end_id", ""))
+                                edge_id = f"{source}-{target}"
+                                if edge_id not in seen_edges:
+                                    edge_properties = {k: v for k, v in rel.items()}
+                                    result.edges.append(
+                                        KnowledgeGraphEdge(
+                                            id=edge_id,
+                                            type=rel.get("label", "DIRECTED"),
+                                            source=source,
+                                            target=target,
+                                            properties=edge_properties
+                                        )
+                                    )
+                                    seen_edges.add(edge_id)
+
+        logger.info(
+            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+        )
+        return result

     async def index_done_callback(self) -> None:
         # AGES handles persistence automatically

diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py
index 3b726c8b..d36e6d7c 100644
--- a/lightrag/kg/chroma_impl.py
+++ b/lightrag/kg/chroma_impl.py
@@ -193,7 +193,37 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         pass

     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity by its ID.
+
+        Args:
+            entity_name: The ID of the entity to delete
+        """
+        try:
+            logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}")
+            self._collection.delete(ids=[entity_name])
+        except Exception as e:
+            logger.error(f"Error during entity deletion: {str(e)}")
+            raise

     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity and its relations by ID.
+        In vector DB context, this is equivalent to delete_entity.
+
+        Args:
+            entity_name: The ID of the entity to delete
+        """
+        await self.delete_entity(entity_name)
+
+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        try:
+            logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+            self._collection.delete(ids=ids)
+            logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
+            raise
diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py
index 3a26401d..4d343bb5 100644
--- a/lightrag/kg/gremlin_impl.py
+++ b/lightrag/kg/gremlin_impl.py
@@ -16,7 +16,7 @@ from tenacity import (
     wait_exponential,
 )

-from lightrag.types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from ..base import BaseGraphStorage

@@ -396,17 +396,286 @@ class GremlinStorage(BaseGraphStorage):
         print("Implemented but never called.")

     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node with the specified entity_name
+
+        Args:
+            node_id: The entity_name of the node to delete
+        """
+        entity_name = GremlinStorage._fix_name(node_id)
+
+        query = f"""g
+            .V().has('graph', {self.graph_name})
+            .has('entity_name', {entity_name})
+            .drop()
+        """
+        try:
+            await self._query(query)
+            logger.debug(
+                "{%s}: Deleted node with entity_name '%s'",
+                inspect.currentframe().f_code.co_name,
+                entity_name
+            )
+        except Exception as e:
+            logger.error(f"Error during node deletion: {str(e)}")
+            raise

     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
-        raise NotImplementedError
+        """
+        Embed nodes using the specified algorithm.
+        Currently, only node2vec is supported but never called.
+
+        Args:
+            algorithm: The name of the embedding algorithm to use
+
+        Returns:
+            A tuple of (embeddings, node_ids)
+
+        Raises:
+            ValueError: If the algorithm is not supported
+        """
+        if algorithm not in self._node_embed_algorithms:
+            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
+        return await self._node_embed_algorithms[algorithm]()

     async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError
+        """
+        Get all node entity_names in the graph
+        Returns:
+            [entity_name1, entity_name2, ...]  # Alphabetically sorted entity_name list
+        """
+        query = f"""g
+            .V().has('graph', {self.graph_name})
+            .values('entity_name')
+            .dedup()
+            .order()
+        """
+        try:
+            result = await self._query(query)
+            labels = result if result else []
+            logger.debug(
+                "{%s}: Retrieved %d labels",
+                inspect.currentframe().f_code.co_name,
+                len(labels)
+            )
+            return labels
+        except Exception as e:
+            logger.error(f"Error retrieving labels: {str(e)}")
+            return []

     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """
+        Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`.
+        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
+
+        Args:
+            node_label: Entity name of the starting node
+            max_depth: Maximum depth of the subgraph
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges
+        """
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        # Get maximum number of graph nodes from environment variable, default is 1000
+        MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
+        entity_name = GremlinStorage._fix_name(node_label)
+
+        # Handle special case for "*" label
+        if node_label == "*":
+            # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES)
+            query = f"""g
+                .V().has('graph', {self.graph_name})
+                .limit({MAX_GRAPH_NODES})
+                .elementMap()
+            """
+            nodes_result = await self._query(query)
+
+            # Add nodes to result
+            for node_data in nodes_result:
+                node_id = node_data.get('entity_name', str(node_data.get('id', '')))
+                if str(node_id) in seen_nodes:
+                    continue
+
+                # Create node with properties
+                node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']}
+
+                result.nodes.append(
+                    KnowledgeGraphNode(
+                        id=str(node_id),
+                        labels=[str(node_id)],
+                        properties=node_properties
+                    )
+                )
+                seen_nodes.add(str(node_id))
+
+            # Get and add edges
+            if nodes_result:
+                query = f"""g
+                    .V().has('graph', {self.graph_name})
+                    .limit({MAX_GRAPH_NODES})
+                    .outE()
+                    .inV().has('graph', {self.graph_name})
+                    .limit({MAX_GRAPH_NODES})
+                    .path()
+                    .by(elementMap())
+                    .by(elementMap())
+                    .by(elementMap())
+                """
+                edges_result = await self._query(query)
+
+                for path in edges_result:
+                    if len(path) >= 3:  # source -> edge -> target
+                        source = path[0]
+                        edge_data = path[1]
+                        target = path[2]
+
+                        source_id = source.get('entity_name', str(source.get('id', '')))
+                        target_id = target.get('entity_name', str(target.get('id', '')))
+
+                        edge_id = f"{source_id}-{target_id}"
+                        if edge_id in seen_edges:
+                            continue
+
+                        # Create edge with properties
+                        edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']}
+
+                        result.edges.append(
+                            KnowledgeGraphEdge(
+                                id=edge_id,
+                                type="DIRECTED",
+                                source=str(source_id),
+                                target=str(target_id),
+                                properties=edge_properties
+                            )
+                        )
+                        seen_edges.add(edge_id)
+        else:
+            # Search for specific node and get its neighborhood
+            query = f"""g
+                .V().has('graph', {self.graph_name})
+                .has('entity_name', {entity_name})
+                .repeat(__.both().simplePath().dedup())
+                .times({max_depth})
+                .emit()
+                .dedup()
+                .limit({MAX_GRAPH_NODES})
+                .elementMap()
+            """
+            nodes_result = await self._query(query)
+
+            # Add nodes to result
+            for node_data in nodes_result:
+                node_id = node_data.get('entity_name', str(node_data.get('id', '')))
+                if str(node_id) in seen_nodes:
+                    continue
+
+                # Create node with properties
+                node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']}
+
+                result.nodes.append(
+                    KnowledgeGraphNode(
+                        id=str(node_id),
+                        labels=[str(node_id)],
+                        properties=node_properties
+                    )
+                )
+                seen_nodes.add(str(node_id))
+
+            # Get edges between the nodes in the result
+            if nodes_result:
+                node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result]
+                node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids])
+
+                query = f"""g
+                    .V().has('graph', {self.graph_name})
+                    .has('entity_name', within({node_ids_query}))
+                    .outE()
+                    .where(inV().has('graph', {self.graph_name})
+                        .has('entity_name', within({node_ids_query})))
+                    .path()
+                    .by(elementMap())
+                    .by(elementMap())
+                    .by(elementMap())
+                """
+                edges_result = await self._query(query)
+
+                for path in edges_result:
+                    if len(path) >= 3:  # source -> edge -> target
+                        source = path[0]
+                        edge_data = path[1]
+                        target = path[2]
+
+                        source_id = source.get('entity_name', str(source.get('id', '')))
+                        target_id = target.get('entity_name', str(target.get('id', '')))
+
+                        edge_id = f"{source_id}-{target_id}"
+                        if edge_id in seen_edges:
+                            continue
+
+                        # Create edge with properties
+                        edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']}
+
+                        result.edges.append(
+                            KnowledgeGraphEdge(
+                                id=edge_id,
+                                type="DIRECTED",
+                                source=str(source_id),
+                                target=str(target_id),
+                                properties=edge_properties
+                            )
+                        )
+                        seen_edges.add(edge_id)
+
+        logger.info(
+            "Subgraph query successful | Node count: %d | Edge count: %d",
+            len(result.nodes),
+            len(result.edges)
+        )
+        return result
+
+    async def remove_nodes(self, nodes: list[str]):
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node entity_names to be deleted
+        """
+        for node in nodes:
+            await self.delete_node(node)
+
+    async def remove_edges(self, edges: list[tuple[str, str]]):
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to be deleted, each edge is a (source, target) tuple
+        """
+        for source, target in edges:
+            entity_name_source = GremlinStorage._fix_name(source)
+            entity_name_target = GremlinStorage._fix_name(target)
+
+            query = f"""g
+                .V().has('graph', {self.graph_name})
+                .has('entity_name', {entity_name_source})
+                .outE()
+                .where(inV().has('graph', {self.graph_name})
+                    .has('entity_name', {entity_name_target}))
+                .drop()
+            """
+            try:
+                await self._query(query)
+                logger.debug(
+                    "{%s}: Deleted edge from '%s' to '%s'",
+                    inspect.currentframe().f_code.co_name,
+                    entity_name_source,
+                    entity_name_target
+                )
+            except Exception as e:
+                logger.error(f"Error during edge deletion: {str(e)}")
+                raise

diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 33a5c12b..2ad4da18 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -3,7 +3,7 @@ import os
 from typing import Any, final
 from dataclasses import dataclass
 import numpy as np
-from lightrag.utils import logger
+from lightrag.utils import logger, compute_mdhash_id
 from ..base import BaseVectorStorage
 import pipmaster as pm

@@ -124,7 +124,84 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         pass

     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity from the vector database
+
+        Args:
+            entity_name: The name of the entity to delete
+        """
+        try:
+            # Compute entity ID from name
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+
+            # Delete the entity from Milvus collection
+            result = self._client.delete(
+                collection_name=self.namespace,
+                pks=[entity_id]
+            )
+
+            if result and result.get("delete_count", 0) > 0:
+                logger.debug(f"Successfully deleted entity {entity_name}")
+            else:
+                logger.debug(f"Entity {entity_name} not found in storage")
+
+        except Exception as e:
+            logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete all relations associated with an entity
+
+        Args:
+            entity_name: The name of the entity whose relations should be deleted
+        """
+        try:
+            # Search for relations where entity is either source or target
+            expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"'
+
+            # Find all relations involving this entity
+            results = self._client.query(
+                collection_name=self.namespace,
+                filter=expr,
+                output_fields=["id"]
+            )
+
+            if not results or len(results) == 0:
+                logger.debug(f"No relations found for entity {entity_name}")
+                return
+
+            # Extract IDs of relations to delete
+            relation_ids = [item["id"] for item in results]
+            logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
+
+            # Delete the relations
+            if relation_ids:
+                delete_result = self._client.delete(
+                    collection_name=self.namespace,
+                    pks=relation_ids
+                )
+
+                logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}")
+
+        except Exception as e:
+            logger.error(f"Error deleting relations for {entity_name}: {e}")
+
+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        try:
+            # Delete vectors by IDs
+            result = self._client.delete(
+                collection_name=self.namespace,
+                pks=ids
+            )
+
+            if result and result.get("delete_count", 0) > 0:
+                logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}")
+            else:
+                logger.debug(f"No vectors were deleted from {self.namespace}")
+
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
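Several of the vector backends derive the point ID from the entity name via compute_mdhash_id, so delete-by-name needs no lookup. A sketch of the shape of that helper (the real one lives in lightrag.utils; this mirrors its contract rather than its exact implementation):

    from hashlib import md5

    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # stable content-addressed ID: the same entity name always
        # maps to the same key, e.g. "ent-" + 32 hex characters
        return prefix + md5(content.encode()).hexdigest()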
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 0048b384..3afd2b44 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -15,7 +15,7 @@ from ..base import (
     DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
-from ..utils import logger
+from ..utils import logger, compute_mdhash_id
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 import pipmaster as pm

@@ -333,7 +333,7 @@ class MongoGraphStorage(BaseGraphStorage):
     Check if there's a direct single-hop edge from source_node_id to target_node_id.

     We'll do a $graphLookup with maxDepth=0 from the source node—meaning
-    “Look up zero expansions.” Actually, for a direct edge check, we can do maxDepth=1
+    "Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1
     and then see if the target node is in the "reachableNodes" at depth=0.

     But typically for a direct edge, we might just do a find_one.
@@ -795,6 +795,52 @@ class MongoGraphStorage(BaseGraphStorage):
         # Mongo handles persistence automatically
         pass

+    async def remove_nodes(self, nodes: list[str]) -> None:
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node IDs to be deleted
+        """
+        logger.info(f"Deleting {len(nodes)} nodes")
+        if not nodes:
+            return
+
+        # 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
+        await self.collection.update_many(
+            {},
+            {"$pull": {"edges": {"target": {"$in": nodes}}}}
+        )
+
+        # 2. Delete the node documents
+        await self.collection.delete_many({"_id": {"$in": nodes}})
+
+        logger.debug(f"Successfully deleted nodes: {nodes}")
+
+    async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to be deleted, each edge is a (source, target) tuple
+        """
+        logger.info(f"Deleting {len(edges)} edges")
+        if not edges:
+            return
+
+        update_tasks = []
+        for source, target in edges:
+            # Remove edge pointing to target from source node's edges array
+            update_tasks.append(
+                self.collection.update_one(
+                    {"_id": source},
+                    {"$pull": {"edges": {"target": target}}}
+                )
+            )
+
+        if update_tasks:
+            await asyncio.gather(*update_tasks)
+
+        logger.debug(f"Successfully deleted edges: {edges}")
+

 @final
 @dataclass
@@ -932,11 +978,66 @@ class MongoVectorDBStorage(BaseVectorStorage):
         # Mongo handles persistence automatically
         pass

+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+        if not ids:
+            return
+
+        try:
+            result = await self._data.delete_many({"_id": {"$in": ids}})
+            logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}")
+        except PyMongoError as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}")
+
     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity by its name
+
+        Args:
+            entity_name: Name of the entity to delete
+        """
+        try:
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+
+            result = await self._data.delete_one({"_id": entity_id})
+            if result.deleted_count > 0:
+                logger.debug(f"Successfully deleted entity {entity_name}")
+            else:
+                logger.debug(f"Entity {entity_name} not found in storage")
+        except PyMongoError as e:
+            logger.error(f"Error deleting entity {entity_name}: {str(e)}")

     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete all relations associated with an entity
+
+        Args:
+            entity_name: Name of the entity whose relations should be deleted
+        """
+        try:
+            # Find relations where entity appears as source or target
+            relations_cursor = self._data.find(
+                {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]}
+            )
+            relations = await relations_cursor.to_list(length=None)
+
+            if not relations:
+                logger.debug(f"No relations found for entity {entity_name}")
+                return
+
+            # Extract IDs of relations to delete
+            relation_ids = [relation["_id"] for relation in relations]
+            logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
+
+            # Delete the relations
+            result = await self._data.delete_many({"_id": {"$in": relation_ids}})
+            logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}")
+        except PyMongoError as e:
+            logger.error(f"Error deleting relations for {entity_name}: {str(e)}")


 async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
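The Mongo graph model stores outgoing edges as an embedded array on each node document, so removing an edge is an array update rather than a document delete. A minimal sketch of the two-step pattern remove_nodes uses above, assuming a motor collection and illustrative IDs:

    # detach: strip dangling references from every other node's edges array
    await collection.update_many({}, {"$pull": {"edges": {"target": {"$in": ["ENTITY_A"]}}}})
    # delete: remove the node documents themselves
    await collection.delete_many({"_id": {"$in": ["ENTITY_A"]}})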
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index af2ededb..d189679e 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -8,7 +8,7 @@ from typing import Any, Union, final

 import numpy as np
 import configparser

-from lightrag.types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

 from ..base import (
     BaseGraphStorage,
@@ -442,11 +442,55 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # Oracles handles persistence automatically
         pass

+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        if not ids:
+            return
+
+        try:
+            SQL = SQL_TEMPLATES["delete_vectors"].format(
+                ids=",".join([f"'{id}'" for id in ids])
+            )
+            params = {"workspace": self.db.workspace}
+            await self.db.execute(SQL, params)
+            logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
+            raise
+
     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete entity by name
+
+        Args:
+            entity_name: Name of the entity to delete
+        """
+        try:
+            SQL = SQL_TEMPLATES["delete_entity"]
+            params = {"workspace": self.db.workspace, "entity_name": entity_name}
+            await self.db.execute(SQL, params)
+            logger.info(f"Successfully deleted entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting entity {entity_name}: {e}")
+            raise

     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete all relations connected to an entity
+
+        Args:
+            entity_name: Name of the entity whose relations should be deleted
+        """
+        try:
+            SQL = SQL_TEMPLATES["delete_entity_relations"]
+            params = {"workspace": self.db.workspace, "entity_name": entity_name}
+            await self.db.execute(SQL, params)
+            logger.info(f"Successfully deleted relations for entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting relations for entity {entity_name}: {e}")
+            raise


 @final
@@ -668,15 +712,206 @@ class OracleGraphStorage(BaseGraphStorage):
         return res

     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node from the graph
+
+        Args:
+            node_id: ID of the node to delete
+        """
+        try:
+            # First delete all relations connected to this node
+            delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
+            params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
+            await self.db.execute(delete_relations_sql, params_relations)
+
+            # Then delete the node itself
+            delete_node_sql = SQL_TEMPLATES["delete_entity"]
+            params_node = {"workspace": self.db.workspace, "entity_name": node_id}
+            await self.db.execute(delete_node_sql, params_node)
+
+            logger.info(f"Successfully deleted node {node_id} and all its relationships")
+        except Exception as e:
+            logger.error(f"Error deleting node {node_id}: {e}")
+            raise

     async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError
+        """Get all unique entity types (labels) in the graph
+
+        Returns:
+            List of unique entity types/labels
+        """
+        try:
+            SQL = """
+            SELECT DISTINCT entity_type
+            FROM LIGHTRAG_GRAPH_NODES
+            WHERE workspace = :workspace
+            ORDER BY entity_type
+            """
+            params = {"workspace": self.db.workspace}
+            results = await self.db.query(SQL, params, multirows=True)
+
+            if results:
+                labels = [row["entity_type"] for row in results]
+                return labels
+            else:
+                return []
+        except Exception as e:
+            logger.error(f"Error retrieving entity types: {e}")
+            return []

     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """Retrieve a connected subgraph starting from nodes matching the given label
+
+        Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
+        Prioritizes nodes by:
+        1. Nodes matching the specified label
+        2. Nodes directly connected to matching nodes
+        3. Node degree (number of connections)
+
+        Args:
+            node_label: Label to match for starting nodes (use "*" for all nodes)
+            max_depth: Maximum depth of traversal from starting nodes
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges
+        """
+        result = KnowledgeGraph()
+
+        try:
+            # Define maximum number of nodes to return
+            max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
+
+            if node_label == "*":
+                # For "*" label, get all nodes up to the limit
+                nodes_sql = """
+                SELECT name, entity_type, description, source_chunk_id
+                FROM LIGHTRAG_GRAPH_NODES
+                WHERE workspace = :workspace
+                ORDER BY id
+                FETCH FIRST :limit ROWS ONLY
+                """
+                nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes}
+                nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
+            else:
+                # For specific label, find matching nodes and related nodes
+                nodes_sql = """
+                WITH matching_nodes AS (
+                    SELECT name
+                    FROM LIGHTRAG_GRAPH_NODES
+                    WHERE workspace = :workspace
+                    AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
+                )
+                SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
+                       CASE
+                           WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
+                           WHEN EXISTS (
+                               SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
+                               WHERE workspace = :workspace
+                               AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
+                                    OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
+                           ) THEN 1
+                           ELSE 0
+                       END AS priority,
+                       (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
+                        WHERE workspace = :workspace
+                        AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
+                FROM LIGHTRAG_GRAPH_NODES n
+                WHERE workspace = :workspace
+                ORDER BY priority DESC, degree DESC
+                FETCH FIRST :limit ROWS ONLY
+                """
+                nodes_params = {
+                    "workspace": self.db.workspace,
+                    "node_label": node_label,
+                    "limit": max_graph_nodes
+                }
+                nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)

+            if not nodes:
+                logger.warning(f"No nodes found matching '{node_label}'")
+                return result
+
+            # Create mapping of node IDs to be used to filter edges
+            node_names = [node["name"] for node in nodes]
+
+            # Add nodes to result
+            seen_nodes = set()
+            for node in nodes:
+                node_id = node["name"]
+                if node_id in seen_nodes:
+                    continue
+
+                # Create node properties dictionary
+                properties = {
+                    "entity_type": node["entity_type"],
+                    "description": node["description"] or "",
+                    "source_id": node["source_chunk_id"] or ""
+                }
+
+                # Add node to result
+                result.nodes.append(
+                    KnowledgeGraphNode(
+                        id=node_id,
+                        labels=[node["entity_type"]],
+                        properties=properties
+                    )
+                )
+                seen_nodes.add(node_id)
+
+            # Get edges between these nodes
+            edges_sql = """
+            SELECT source_name, target_name, weight, keywords, description, source_chunk_id
+            FROM LIGHTRAG_GRAPH_EDGES
+            WHERE workspace = :workspace
+            AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
+            AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
+            ORDER BY id
+            """
+            edges_params = {
+                "workspace": self.db.workspace,
+                "node_names": node_names
+            }
+            edges = await self.db.query(edges_sql, edges_params, multirows=True)
+
+            # Add edges to result
+            seen_edges = set()
+            for edge in edges:
+                source = edge["source_name"]
+                target = edge["target_name"]
+                edge_id = f"{source}-{target}"
+
+                if edge_id in seen_edges:
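+                    # this source-target pair was already collected; skip the duplicate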
+                    continue
+
+                # Create edge properties dictionary
+                properties = {
+                    "weight": edge["weight"] or 0.0,
+                    "keywords": edge["keywords"] or "",
+                    "description": edge["description"] or "",
+                    "source_id": edge["source_chunk_id"] or ""
+                }
+
+                # Add edge to result
+                result.edges.append(
+                    KnowledgeGraphEdge(
+                        id=edge_id,
+                        type="RELATED",
+                        source=source,
+                        target=target,
+                        properties=properties
+                    )
+                )
+                seen_edges.add(edge_id)
+
+            logger.info(
+                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+            )
+
+        except Exception as e:
+            logger.error(f"Error retrieving knowledge graph: {e}")
+
+        return result


 N_T = {
@@ -927,4 +1162,12 @@ SQL_TEMPLATES = {
         select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
         MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
     )""",
+    # SQL for deletion
+    "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
+    "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
+    "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
+    "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
+        MATCH (a)
+        WHERE a.workspace=:workspace AND a.name=:node_id
+        ACTION DELETE a)""",
 }

diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 51044be5..7ce2b427 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -7,7 +7,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser

-from lightrag.types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

 import sys
 from tenacity import (
@@ -512,11 +512,66 @@ class PGVectorStorage(BaseVectorStorage):
         # PG handles persistence automatically
         pass

+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs from the storage.
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        if not ids:
+            return
+
+        table_name = namespace_to_table_name(self.namespace)
+        if not table_name:
+            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
+            return
+
+        ids_list = ",".join([f"'{id}'" for id in ids])
+        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
+
+        try:
+            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
+            logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
+
     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity by its name from the vector storage.
+
+        Args:
+            entity_name: The name of the entity to delete
+        """
+        try:
+            # Construct SQL to delete the entity
+            delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
+                            WHERE workspace=$1 AND entity_name=$2"""
+
+            await self.db.execute(
+                delete_sql,
+                {"workspace": self.db.workspace, "entity_name": entity_name}
+            )
+            logger.debug(f"Successfully deleted entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete all relations associated with an entity.
+
+        Args:
+            entity_name: The name of the entity whose relations should be deleted
+        """
+        try:
+            # Delete relations where the entity is either the source or target
+            delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
+                            WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
+
+            await self.db.execute(
+                delete_sql,
+                {"workspace": self.db.workspace, "entity_name": entity_name}
+            )
+            logger.debug(f"Successfully deleted relations for entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting relations for entity {entity_name}: {e}")
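A design note on delete() above: it splices the ID list into the SQL string, which works but bypasses parameter binding. A sketch of an equivalent fully-bound form, assuming an asyncpg-style driver (the table name still has to be interpolated, since identifiers cannot be bound):

    await conn.execute(
        f"DELETE FROM {table_name} WHERE workspace = $1 AND id = ANY($2::varchar[])",
        workspace,
        ids,
    )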

 @final
 @dataclass
@@ -1086,20 +1141,192 @@ class PGGraphStorage(BaseGraphStorage):
         print("Implemented but never called.")

     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """
+        Delete a node from the graph.
+
+        Args:
+            node_id (str): The ID of the node to delete.
+        """
+        label = self._encode_graph_label(node_id.strip('"'))
+
+        query = """SELECT * FROM cypher('%s', $$
+                     MATCH (n:Entity {node_id: "%s"})
+                     DETACH DELETE n
+                   $$) AS (n agtype)""" % (self.graph_name, label)
+
+        try:
+            await self._query(query, readonly=False)
+        except Exception as e:
+            logger.error("Error during node deletion: {%s}", e)
+            raise
+
+    async def remove_nodes(self, node_ids: list[str]) -> None:
+        """
+        Remove multiple nodes from the graph.
+
+        Args:
+            node_ids (list[str]): A list of node IDs to remove.
+        """
+        encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids]
+        node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
+
+        query = """SELECT * FROM cypher('%s', $$
+                     MATCH (n:Entity)
+                     WHERE n.node_id IN [%s]
+                     DETACH DELETE n
+                   $$) AS (n agtype)""" % (self.graph_name, node_id_list)
+
+        try:
+            await self._query(query, readonly=False)
+        except Exception as e:
+            logger.error("Error during node removal: {%s}", e)
+            raise
+
+    async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
+        """
+        Remove multiple edges from the graph.
+
+        Args:
+            edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
+        """
+        encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges]
+        edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
+
+        query = """SELECT * FROM cypher('%s', $$
+                     MATCH (a:Entity)-[r]->(b:Entity)
+                     WHERE [a.node_id, b.node_id] IN [%s]
+                     DELETE r
+                   $$) AS (r agtype)""" % (self.graph_name, edge_list)
+
+        try:
+            await self._query(query, readonly=False)
+        except Exception as e:
+            logger.error("Error during edge removal: {%s}", e)
+            raise

     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
-        raise NotImplementedError
+        """
+        Generate node embeddings using the specified algorithm.

-    async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError
+        Args:
+            algorithm (str): The name of the embedding algorithm to use.
+
+        Returns:
+            tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
+        """
+        if algorithm not in self._node_embed_algorithms:
+            raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
+
+        embed_func = self._node_embed_algorithms[algorithm]
+        return await embed_func()

     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """
+        Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
+
+        Args:
+            node_label (str): The label of the node to start from. If "*", the entire graph is returned.
+            max_depth (int): The maximum depth to traverse from the starting node.
+
+        Returns:
+            KnowledgeGraph: The retrieved subgraph.
+        """
+        MAX_GRAPH_NODES = 1000
+
+        if node_label == "*":
+            query = """SELECT * FROM cypher('%s', $$
+                         MATCH (n:Entity)
+                         OPTIONAL MATCH (n)-[r]->(m:Entity)
+                         RETURN n, r, m
+                         LIMIT %d
+                       $$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES)
+        else:
+            encoded_node_label = self._encode_graph_label(node_label.strip('"'))
+            query = """SELECT * FROM cypher('%s', $$
+                         MATCH (n:Entity {node_id: "%s"})
+                         OPTIONAL MATCH p = (n)-[*..%d]-(m)
+                         RETURN nodes(p) AS nodes, relationships(p) AS relationships
+                         LIMIT %d
+                       $$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES)
+
+        results = await self._query(query)
+
+        nodes = set()
+        edges = []
+
+        for result in results:
+            if node_label == "*":
+                if result["n"]:
+                    node = result["n"]
+                    nodes.add(self._decode_graph_label(node["node_id"]))
+                if result["m"]:
+                    node = result["m"]
+                    nodes.add(self._decode_graph_label(node["node_id"]))
+                if result["r"]:
+                    edge = result["r"]
+                    src_id = self._decode_graph_label(edge["start_id"])
+                    tgt_id = self._decode_graph_label(edge["end_id"])
+                    edges.append((src_id, tgt_id))
+            else:
+                if result["nodes"]:
+                    for node in result["nodes"]:
+                        nodes.add(self._decode_graph_label(node["node_id"]))
+                if result["relationships"]:
+                    for edge in result["relationships"]:
+                        src_id = self._decode_graph_label(edge["start_id"])
+                        tgt_id = self._decode_graph_label(edge["end_id"])
+                        edges.append((src_id, tgt_id))
+
+        kg = KnowledgeGraph(
+            nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
+            edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
+        )
+
+        return kg
+
+    async def get_all_labels(self) -> list[str]:
+        """
+        Get all node labels in the graph
+        Returns:
+            [label1, label2, ...]  # Alphabetically sorted label list
+        """
+        query = """SELECT * FROM cypher('%s', $$
+                     MATCH (n:Entity)
+                     RETURN DISTINCT n.node_id AS label
+                     ORDER BY label
+                   $$) AS (label agtype)""" % (self.graph_name)
+
+        try:
+            results = await self._query(query)
+            labels = []
+            for record in results:
+                if record["label"]:
+                    labels.append(self._decode_graph_label(record["label"]))
+            return labels
+        except Exception as e:
+            logger.error(f"Error getting all labels: {str(e)}")
+            return []

     async def drop(self) -> None:
         """Drop the storage"""

diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index b08f0b62..e3488caa 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Any, final
+from typing import Any, final, List
 from dataclasses import dataclass
 import numpy as np
 import hashlib
@@ -141,8 +141,91 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         # Qdrant handles persistence automatically
         pass

+    async def delete(self, ids: List[str]) -> None:
+        """Delete vectors with specified IDs
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        try:
+            # Convert regular ids to Qdrant compatible ids
+            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
+            # Delete points from the collection
+            self._client.delete(
+                collection_name=self.namespace,
+                points_selector=models.PointIdsList(
+                    points=qdrant_ids,
+                ),
+                wait=True
+            )
+            logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
+
     async def delete_entity(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete an entity by name
+
+        Args:
+            entity_name: Name of the entity to delete
+        """
+        try:
+            # Generate the entity ID
+            entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
+            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+
+            # Delete the entity point from the collection
+            self._client.delete(
+                collection_name=self.namespace,
+                points_selector=models.PointIdsList(
+                    points=[entity_id],
+                ),
+                wait=True
+            )
+            logger.debug(f"Successfully deleted entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
-        raise NotImplementedError
+        """Delete all relations associated with an entity
+
+        Args:
+            entity_name: Name of the entity whose relations should be deleted
+        """
+        try:
+            # Find relations where the entity is either source or target
+            results = self._client.scroll(
+                collection_name=self.namespace,
+                scroll_filter=models.Filter(
+                    should=[
+                        models.FieldCondition(
+                            key="src_id",
+                            match=models.MatchValue(value=entity_name)
+                        ),
+                        models.FieldCondition(
+                            key="tgt_id",
+                            match=models.MatchValue(value=entity_name)
+                        )
+                    ]
+                ),
+                with_payload=True,
+                limit=1000  # Adjust as needed for your use case
+            )
+
+            # Extract points that need to be deleted
+            relation_points = results[0]
+            ids_to_delete = [point.id for point in relation_points]
+
+            if ids_to_delete:
+                # Delete the relations
+                self._client.delete(
+                    collection_name=self.namespace,
+                    points_selector=models.PointIdsList(
+                        points=ids_to_delete,
+                    ),
+                    wait=True
+                )
+                logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}")
+            else:
+                logger.debug(f"No relations found for entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting relations for {entity_name}: {e}")
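A caveat on the scroll above: one scroll call returns at most `limit` points plus a continuation offset, so an entity with more than 1000 relations would need the paging loop spelled out. A sketch of the exhaustive form, assuming the qdrant-client (points, next_offset) return convention:

    offset = None
    ids_to_delete = []
    while True:
        points, offset = client.scroll(
            collection_name=namespace,
            scroll_filter=relation_filter,  # same Filter as above
            limit=1000,
            offset=offset,
        )
        ids_to_delete.extend(point.id for point in points)
        if offset is None:  # no further pages
            break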
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 7e177346..bb42b367 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -9,7 +9,7 @@ if not pm.is_installed("redis"):

 # aioredis is a depricated library, replaced with redis
 from redis.asyncio import Redis
-from lightrag.utils import logger
+from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseKVStorage
 import json

@@ -64,3 +64,79 @@ class RedisKVStorage(BaseKVStorage):
     async def index_done_callback(self) -> None:
         # Redis handles persistence automatically
         pass
+
+    async def delete(self, ids: list[str]) -> None:
+        """Delete entries with specified IDs
+
+        Args:
+            ids: List of entry IDs to be deleted
+        """
+        if not ids:
+            return
+
+        pipe = self._redis.pipeline()
+        for id in ids:
+            pipe.delete(f"{self.namespace}:{id}")
+
+        results = await pipe.execute()
+        deleted_count = sum(results)
+        logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}")
+
+    async def delete_entity(self, entity_name: str) -> None:
+        """Delete an entity by name
+
+        Args:
+            entity_name: Name of the entity to delete
+        """
+
+        try:
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+
+            # Delete the entity
+            result = await self._redis.delete(f"{self.namespace}:{entity_id}")
+
+            if result:
+                logger.debug(f"Successfully deleted entity {entity_name}")
+            else:
+                logger.debug(f"Entity {entity_name} not found in storage")
+        except Exception as e:
+            logger.error(f"Error deleting entity {entity_name}: {e}")
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        """Delete all relations associated with an entity
+
+        Args:
+            entity_name: Name of the entity whose relations should be deleted
+        """
+        try:
+            # Get all keys in this namespace
+            cursor = 0
+            relation_keys = []
+            pattern = f"{self.namespace}:*"
+
+            while True:
+                cursor, keys = await self._redis.scan(cursor, match=pattern)
+
+                # For each key, get the value and check if it's related to entity_name
+                for key in keys:
+                    value = await self._redis.get(key)
+                    if value:
+                        data = json.loads(value)
+                        # Check if this is a relation involving the entity
+                        if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name:
+                            relation_keys.append(key)
+
+                # Exit loop when cursor returns to 0
+                if cursor == 0:
+                    break
+
+            # Delete the relation keys
+            if relation_keys:
+                deleted = await self._redis.delete(*relation_keys)
+                logger.debug(f"Deleted {deleted} relations for {entity_name}")
+            else:
+                logger.debug(f"No relations found for entity {entity_name}")
+
+        except Exception as e:
+            logger.error(f"Error deleting relations for {entity_name}: {e}")

diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index 51d1c365..f791d401 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -5,7 +5,7 @@ from typing import Any, Union, final

 import numpy as np

-from lightrag.types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

 from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage

@@ -566,15 +566,148 @@ class TiDBGraphStorage(BaseGraphStorage):
         pass

     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node and all its related edges
+
+        Args:
+            node_id: The ID of the node to delete
+        """
+        # First delete all edges related to this node
+        await self.db.execute(SQL_TEMPLATES["delete_node_edges"],
+                              {"name": node_id, "workspace": self.db.workspace})
+
+        # Then delete the node itself
+        await self.db.execute(SQL_TEMPLATES["delete_node"],
+                              {"name": node_id, "workspace": self.db.workspace})
+
+        logger.debug(f"Node {node_id} and its related edges have been deleted from the graph")

     async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError
+        """Get all entity types (labels) in the database
+
+        Returns:
+            List of labels sorted alphabetically
+        """
+        result = await self.db.query(
+            SQL_TEMPLATES["get_all_labels"],
+            {"workspace": self.db.workspace},
+            multirows=True
+        )
+
+        if not result:
+            return []
+
+        # Extract all labels
+        return [item["label"] for item in result]

     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """
+        Get a connected subgraph of nodes matching the specified label
+        Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
+
+        Args:
+            node_label: The node label to match
+            max_depth: Maximum depth of the subgraph
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges
+        """
+        result = KnowledgeGraph()
+        MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
+        # Get matching nodes
+        if node_label == "*":
+            # Handle special case, get all nodes
+            node_results = await self.db.query(
+                SQL_TEMPLATES["get_all_nodes"],
+                {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
+                multirows=True
+            )
+        else:
+            # Get nodes matching the label
+            label_pattern = f"%{node_label}%"
+            node_results = await self.db.query(
+                SQL_TEMPLATES["get_matching_nodes"],
+                {"workspace": self.db.workspace, "label_pattern": label_pattern},
+                multirows=True
+            )
+
+        if not node_results:
+            logger.warning(f"No nodes found matching label {node_label}")
+            return result
+
+        # Limit the number of returned nodes
+        if len(node_results) > MAX_GRAPH_NODES:
+            node_results = node_results[:MAX_GRAPH_NODES]
+
+        # Extract node names for edge query
+        node_names = [node["name"] for node in node_results]
+        node_names_str = ",".join([f"'{name}'" for name in node_names])
+
+        # Add nodes to result
+        for node in node_results:
+            node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]}
+            result.nodes.append(
+                KnowledgeGraphNode(
+                    id=node["name"],
+                    labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]],
+                    properties=node_properties
+                )
+            )
+
+        # Get related edges
+        edge_results = await self.db.query(
+            SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
+            {"workspace": self.db.workspace},
+            multirows=True
+        )
+
+        if edge_results:
+            # Add edges to result
+            for edge in edge_results:
+                # Only include edges related to selected nodes
+                if edge["source_name"] in node_names and edge["target_name"] in node_names:
+                    edge_id = f"{edge['source_name']}-{edge['target_name']}"
+                    edge_properties = {k: v for k, v in edge.items()
+                                       if k not in ["id", "source_name", "target_name"]}
+
+                    result.edges.append(
+                        KnowledgeGraphEdge(
+                            id=edge_id,
+                            type="RELATED",
+                            source=edge["source_name"],
+                            target=edge["target_name"],
+                            properties=edge_properties
+                        )
+                    )
+
+        logger.info(
+            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+        )
+        return result
+
+    async def remove_nodes(self, nodes: list[str]):
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node IDs to delete
+        """
+        for node_id in nodes:
+            await self.delete_node(node_id)
+
+    async def remove_edges(self, edges: list[tuple[str, str]]):
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to delete, each edge is a (source, target) tuple
+        """
+        for source, target in edges:
+            await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], {
+                "source": source,
+                "target": target,
+                "workspace": self.db.workspace
+            })


 N_T = {
@@ -785,4 +918,39 @@ SQL_TEMPLATES = {
         weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
         source_chunk_id = VALUES(source_chunk_id)
     """,
+    "delete_node": """
+        DELETE FROM LIGHTRAG_GRAPH_NODES
+        WHERE name = :name AND workspace = :workspace
+    """,
+    "delete_node_edges": """
+        DELETE FROM LIGHTRAG_GRAPH_EDGES
+        WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
+    """,
+    "get_all_labels": """
+        SELECT DISTINCT entity_type as label
+        FROM LIGHTRAG_GRAPH_NODES
+        WHERE workspace = :workspace
+        ORDER BY entity_type
+    """,
+    "get_matching_nodes": """
+        SELECT * FROM LIGHTRAG_GRAPH_NODES
+        WHERE name LIKE :label_pattern AND workspace = :workspace
+        ORDER BY name
+    """,
+    "get_all_nodes": """
+        SELECT * FROM LIGHTRAG_GRAPH_NODES
+        WHERE workspace = :workspace
+        ORDER BY name
+        LIMIT :max_nodes
+    """,
+    "get_related_edges": """
+        SELECT * FROM LIGHTRAG_GRAPH_EDGES
+        WHERE (source_name IN ({node_names}) OR target_name IN ({node_names}))
+        AND workspace = :workspace
+    """,
+    "remove_multiple_edges": """
+        DELETE FROM LIGHTRAG_GRAPH_EDGES
+        WHERE (source_name = :source AND target_name = :target)
+        AND workspace = :workspace
+    """
 }

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 6f42003d..eeed8a70 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1399,6 +1399,54 @@ class LightRAG:
             ]
         )

+    def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
+        """Synchronously delete a relation between two entities.
+
+        Args:
+            source_entity: Name of the source entity
+            target_entity: Name of the target entity
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity))
+
+    async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
+        """Asynchronously delete a relation between two entities.
+ + Args: + source_entity: Name of the source entity + target_entity: Name of the target entity + """ + try: + # Check if the relation exists + edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity) + if not edge_exists: + logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist") + return + + # Delete relation from vector database + relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-") + await self.relationships_vdb.delete([relation_id]) + + # Delete relation from knowledge graph + await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)]) + + logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'") + await self._delete_relation_done() + except Exception as e: + logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}") + + async def _delete_relation_done(self) -> None: + """Callback after relation deletion is complete""" + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + ) + def _get_content_summary(self, content: str, max_length: int = 100) -> str: """Get summary of document content From 81568f3badbba85294ae0fc2d759a6f7f1715706 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 15:53:20 +0800 Subject: [PATCH 10/21] fix linting --- lightrag/kg/age_impl.py | 54 ++++++++-------- lightrag/kg/chroma_impl.py | 14 ++-- lightrag/kg/gremlin_impl.py | 120 +++++++++++++++++++--------------- lightrag/kg/milvus_impl.py | 59 ++++++++--------- lightrag/kg/mongo_impl.py | 48 ++++++++------ lightrag/kg/oracle_impl.py | 108 ++++++++++++++++--------------- lightrag/kg/postgres_impl.py | 78 +++++++++++----------- lightrag/kg/qdrant_impl.py | 40 ++++++------ lightrag/kg/redis_impl.py | 39 ++++++----- lightrag/kg/tidb_impl.py | 121 ++++++++++++++++++++--------------- lightrag/lightrag.py | 40 ++++++++---- 11 files changed, 394 insertions(+), 327 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index c6b98221..22951554 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -619,7 +619,7 @@ class AGEStorage(BaseGraphStorage): node_id: The label of the node to delete """ entity_name_label = node_id.strip('"') - + query = """ MATCH (n:`{label}`) DETACH DELETE n @@ -650,18 +650,20 @@ class AGEStorage(BaseGraphStorage): for source, target in edges: entity_name_label_source = source.strip('"') entity_name_label_target = target.strip('"') - + query = """ MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) DELETE r """ params = { "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target) + "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), } try: await self._query(query, **params) - logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'") + logger.debug( + f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'" + ) except Exception as e: logger.error(f"Error during edge deletion: {str(e)}") raise @@ -683,7 +685,7 @@ class AGEStorage(BaseGraphStorage): async def get_all_labels(self) -> list[str]: """Get all node labels in the database - + Returns: ["label1", "label2", ...] 
# Alphabetically sorted label list """ @@ -692,7 +694,7 @@ class AGEStorage(BaseGraphStorage): RETURN DISTINCT labels(n) AS node_labels """ results = await self._query(query) - + all_labels = [] for record in results: if record and "node_labels" in record: @@ -701,7 +703,7 @@ class AGEStorage(BaseGraphStorage): # Decode label decoded_label = AGEStorage._decode_graph_label(label) all_labels.append(decoded_label) - + # Remove duplicates and sort return sorted(list(set(all_labels))) @@ -719,7 +721,7 @@ class AGEStorage(BaseGraphStorage): Args: node_label: String to match in node labels (will match any node containing this string in its label) max_depth: Maximum depth of the graph. Defaults to 5. - + Returns: KnowledgeGraph: Complete connected subgraph for specified node """ @@ -727,7 +729,7 @@ class AGEStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - + # Handle special case for "*" label if node_label == "*": # Query all nodes and sort by degree @@ -741,7 +743,7 @@ class AGEStorage(BaseGraphStorage): """ params = {"max_nodes": max_graph_nodes} nodes_result = await self._query(query, **params) - + # Add nodes to result node_ids = [] for record in nodes_result: @@ -755,12 +757,12 @@ class AGEStorage(BaseGraphStorage): KnowledgeGraphNode( id=node_id, labels=[node_label], - properties=node_properties + properties=node_properties, ) ) seen_nodes.add(node_id) node_ids.append(node_id) - + # Query edges between these nodes if node_ids: edges_query = """ @@ -770,7 +772,7 @@ class AGEStorage(BaseGraphStorage): """ edges_params = {"node_ids": node_ids} edges_result = await self._query(edges_query, **edges_params) - + # Add edges to result for record in edges_result: if "r" in record and "a" in record and "b" in record: @@ -785,7 +787,7 @@ class AGEStorage(BaseGraphStorage): type="DIRECTED", source=source, target=target, - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) @@ -793,7 +795,7 @@ class AGEStorage(BaseGraphStorage): # For specific label, use partial matching entity_name_label = node_label.strip('"') encoded_label = AGEStorage._encode_graph_label(entity_name_label) - + # Find matching start nodes start_query = """ MATCH (n:`{label}`) @@ -801,17 +803,14 @@ class AGEStorage(BaseGraphStorage): """ start_params = {"label": encoded_label} start_nodes = await self._query(start_query, **start_params) - + if not start_nodes: logger.warning(f"No nodes found with label '{entity_name_label}'!") return result - + # Traverse graph from each start node for start_node_record in start_nodes: if "n" in start_node_record: - start_node = start_node_record["n"] - start_id = str(start_node.get("id", "")) - # Use BFS to traverse graph query = """ MATCH (start:`{label}`) @@ -823,25 +822,28 @@ class AGEStorage(BaseGraphStorage): """ params = {"label": encoded_label, "max_depth": max_depth} results = await self._query(query, **params) - + # Extract nodes and edges from results for record in results: if "path_nodes" in record: # Process nodes for node in record["path_nodes"]: node_id = str(node.get("id", "")) - if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes: + if ( + node_id not in seen_nodes + and len(seen_nodes) < max_graph_nodes + ): node_properties = {k: v for k, v in node.items()} node_label = node.get("label", "") result.nodes.append( KnowledgeGraphNode( id=node_id, labels=[node_label], - properties=node_properties + properties=node_properties, ) ) seen_nodes.add(node_id) - + if "path_rels" in record: # Process edges 
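# Each relationship becomes one KnowledgeGraphEdge, deduplicated by its "source-target" edge id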
for rel in record["path_rels"]: @@ -856,11 +858,11 @@ class AGEStorage(BaseGraphStorage): type=rel.get("label", "DIRECTED"), source=source, target=target, - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) - + logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index d36e6d7c..ea4b31a1 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -194,7 +194,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): async def delete_entity(self, entity_name: str) -> None: """Delete an entity by its ID. - + Args: entity_name: The ID of the entity to delete """ @@ -206,24 +206,26 @@ class ChromaVectorDBStorage(BaseVectorStorage): raise async def delete_entity_relation(self, entity_name: str) -> None: - """Delete an entity and its relations by ID. + """Delete an entity and its relations by ID. In vector DB context, this is equivalent to delete_entity. - + Args: entity_name: The ID of the entity to delete """ await self.delete_entity(entity_name) - + async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ try: logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") self._collection.delete(ids=ids) - logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") raise diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 4d343bb5..ddb7559f 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -397,12 +397,12 @@ class GremlinStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """Delete a node with the specified entity_name - + Args: node_id: The entity_name of the node to delete """ entity_name = GremlinStorage._fix_name(node_id) - + query = f"""g .V().has('graph', {self.graph_name}) .has('entity_name', {entity_name}) @@ -413,7 +413,7 @@ class GremlinStorage(BaseGraphStorage): logger.debug( "{%s}: Deleted node with entity_name '%s'", inspect.currentframe().f_code.co_name, - entity_name + entity_name, ) except Exception as e: logger.error(f"Error during node deletion: {str(e)}") @@ -425,13 +425,13 @@ class GremlinStorage(BaseGraphStorage): """ Embed nodes using the specified algorithm. Currently, only node2vec is supported but never called. - + Args: algorithm: The name of the embedding algorithm to use - + Returns: A tuple of (embeddings, node_ids) - + Raises: NotImplementedError: If the specified algorithm is not supported ValueError: If the algorithm is not supported @@ -458,7 +458,7 @@ class GremlinStorage(BaseGraphStorage): logger.debug( "{%s}: Retrieved %d labels", inspect.currentframe().f_code.co_name, - len(labels) + len(labels), ) return labels except Exception as e: @@ -471,7 +471,7 @@ class GremlinStorage(BaseGraphStorage): """ Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). 
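Passing "*" as node_label returns the whole graph rather than a name match, still capped at MAX_GRAPH_NODES.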
- + Args: node_label: Entity name of the starting node max_depth: Maximum depth of the subgraph @@ -482,12 +482,12 @@ class GremlinStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - + # Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - + entity_name = GremlinStorage._fix_name(node_label) - + # Handle special case for "*" label if node_label == "*": # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) @@ -497,25 +497,27 @@ class GremlinStorage(BaseGraphStorage): .elementMap() """ nodes_result = await self._query(query) - + # Add nodes to result for node_data in nodes_result: - node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + node_id = node_data.get("entity_name", str(node_data.get("id", ""))) if str(node_id) in seen_nodes: continue - + # Create node with properties - node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} - + node_properties = { + k: v for k, v in node_data.items() if k not in ["id", "label"] + } + result.nodes.append( KnowledgeGraphNode( id=str(node_id), - labels=[str(node_id)], - properties=node_properties + labels=[str(node_id)], + properties=node_properties, ) ) seen_nodes.add(str(node_id)) - + # Get and add edges if nodes_result: query = f"""g @@ -530,30 +532,34 @@ class GremlinStorage(BaseGraphStorage): .by(elementMap()) """ edges_result = await self._query(query) - + for path in edges_result: if len(path) >= 3: # source -> edge -> target source = path[0] edge_data = path[1] target = path[2] - - source_id = source.get('entity_name', str(source.get('id', ''))) - target_id = target.get('entity_name', str(target.get('id', ''))) - + + source_id = source.get("entity_name", str(source.get("id", ""))) + target_id = target.get("entity_name", str(target.get("id", ""))) + edge_id = f"{source_id}-{target_id}" if edge_id in seen_edges: continue - + # Create edge with properties - edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} - + edge_properties = { + k: v + for k, v in edge_data.items() + if k not in ["id", "label"] + } + result.edges.append( KnowledgeGraphEdge( id=edge_id, type="DIRECTED", source=str(source_id), target=str(target_id), - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) @@ -570,30 +576,36 @@ class GremlinStorage(BaseGraphStorage): .elementMap() """ nodes_result = await self._query(query) - + # Add nodes to result for node_data in nodes_result: - node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + node_id = node_data.get("entity_name", str(node_data.get("id", ""))) if str(node_id) in seen_nodes: continue - + # Create node with properties - node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} - + node_properties = { + k: v for k, v in node_data.items() if k not in ["id", "label"] + } + result.nodes.append( KnowledgeGraphNode( id=str(node_id), - labels=[str(node_id)], - properties=node_properties + labels=[str(node_id)], + properties=node_properties, ) ) seen_nodes.add(str(node_id)) - + # Get edges between the nodes in the result if nodes_result: - node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result] - node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids]) - + node_ids = [ + n.get("entity_name", str(n.get("id", ""))) for n in nodes_result + ] + node_ids_query = ", ".join( + 
[GremlinStorage._to_value_map(nid) for nid in node_ids] + ) + query = f"""g .V().has('graph', {self.graph_name}) .has('entity_name', within({node_ids_query})) @@ -606,38 +618,42 @@ class GremlinStorage(BaseGraphStorage): .by(elementMap()) """ edges_result = await self._query(query) - + for path in edges_result: if len(path) >= 3: # source -> edge -> target source = path[0] edge_data = path[1] target = path[2] - - source_id = source.get('entity_name', str(source.get('id', ''))) - target_id = target.get('entity_name', str(target.get('id', ''))) - + + source_id = source.get("entity_name", str(source.get("id", ""))) + target_id = target.get("entity_name", str(target.get("id", ""))) + edge_id = f"{source_id}-{target_id}" if edge_id in seen_edges: continue - + # Create edge with properties - edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} - + edge_properties = { + k: v + for k, v in edge_data.items() + if k not in ["id", "label"] + } + result.edges.append( KnowledgeGraphEdge( id=edge_id, type="DIRECTED", source=str(source_id), target=str(target_id), - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) - + logger.info( "Subgraph query successful | Node count: %d | Edge count: %d", len(result.nodes), - len(result.edges) + len(result.edges), ) return result @@ -659,7 +675,7 @@ class GremlinStorage(BaseGraphStorage): for source, target in edges: entity_name_source = GremlinStorage._fix_name(source) entity_name_target = GremlinStorage._fix_name(target) - + query = f"""g .V().has('graph', {self.graph_name}) .has('entity_name', {entity_name_source}) @@ -674,7 +690,7 @@ class GremlinStorage(BaseGraphStorage): "{%s}: Deleted edge from '%s' to '%s'", inspect.currentframe().f_code.co_name, entity_name_source, - entity_name_target + entity_name_target, ) except Exception as e: logger.error(f"Error during edge deletion: {str(e)}") diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 2ad4da18..7242f03d 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -125,83 +125,84 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def delete_entity(self, entity_name: str) -> None: """Delete an entity from the vector database - + Args: entity_name: The name of the entity to delete """ try: # Compute entity ID from name entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Delete the entity from Milvus collection result = self._client.delete( - collection_name=self.namespace, - pks=[entity_id] + collection_name=self.namespace, pks=[entity_id] ) - + if result and result.get("delete_count", 0) > 0: logger.debug(f"Successfully deleted entity {entity_name}") else: logger.debug(f"Entity {entity_name} not found in storage") - + except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: The name of the entity whose relations should be deleted """ try: # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' - + # Find all relations involving this entity results = self._client.query( - collection_name=self.namespace, - filter=expr, - output_fields=["id"] + collection_name=self.namespace, filter=expr, 
output_fields=["id"] ) - + if not results or len(results) == 0: logger.debug(f"No relations found for entity {entity_name}") return - + # Extract IDs of relations to delete relation_ids = [item["id"] for item in results] - logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") - + logger.debug( + f"Found {len(relation_ids)} relations for entity {entity_name}" + ) + # Delete the relations if relation_ids: delete_result = self._client.delete( - collection_name=self.namespace, - pks=relation_ids + collection_name=self.namespace, pks=relation_ids ) - - logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}") - + + logger.debug( + f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}" + ) + except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ try: # Delete vectors by IDs - result = self._client.delete( - collection_name=self.namespace, - pks=ids - ) - + result = self._client.delete(collection_name=self.namespace, pks=ids) + if result and result.get("delete_count", 0) > 0: - logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}" + ) else: logger.debug(f"No vectors were deleted from {self.namespace}") - + except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 3afd2b44..c2957502 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -804,16 +804,15 @@ class MongoGraphStorage(BaseGraphStorage): logger.info(f"Deleting {len(nodes)} nodes") if not nodes: return - + # 1. Remove all edges referencing these nodes (remove from edges array of other nodes) await self.collection.update_many( - {}, - {"$pull": {"edges": {"target": {"$in": nodes}}}} + {}, {"$pull": {"edges": {"target": {"$in": nodes}}}} ) - + # 2. 
Delete the node documents await self.collection.delete_many({"_id": {"$in": nodes}}) - + logger.debug(f"Successfully deleted nodes: {nodes}") async def remove_edges(self, edges: list[tuple[str, str]]) -> None: @@ -825,20 +824,19 @@ class MongoGraphStorage(BaseGraphStorage): logger.info(f"Deleting {len(edges)} edges") if not edges: return - + update_tasks = [] for source, target in edges: # Remove edge pointing to target from source node's edges array update_tasks.append( self.collection.update_one( - {"_id": source}, - {"$pull": {"edges": {"target": target}}} + {"_id": source}, {"$pull": {"edges": {"target": target}}} ) ) - + if update_tasks: await asyncio.gather(*update_tasks) - + logger.debug(f"Successfully deleted edges: {edges}") @@ -987,23 +985,29 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") if not ids: return - + try: result = await self._data.delete_many({"_id": {"$in": ids}}) - logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {result.deleted_count} vectors from {self.namespace}" + ) except PyMongoError as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}") + logger.error( + f"Error while deleting vectors from {self.namespace}: {str(e)}" + ) async def delete_entity(self, entity_name: str) -> None: """Delete an entity by its name - + Args: entity_name: Name of the entity to delete """ try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + result = await self._data.delete_one({"_id": entity_id}) if result.deleted_count > 0: logger.debug(f"Successfully deleted entity {entity_name}") @@ -1014,7 +1018,7 @@ class MongoVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -1024,15 +1028,17 @@ class MongoVectorDBStorage(BaseVectorStorage): {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]} ) relations = await relations_cursor.to_list(length=None) - + if not relations: logger.debug(f"No relations found for entity {entity_name}") return - + # Extract IDs of relations to delete relation_ids = [relation["_id"] for relation in relations] - logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") - + logger.debug( + f"Found {len(relation_ids)} relations for entity {entity_name}" + ) + # Delete the relations result = await self._data.delete_many({"_id": {"$in": relation_ids}}) logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}") diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index d189679e..5dee1143 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -444,27 +444,29 @@ class OracleVectorDBStorage(BaseVectorStorage): async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ if not ids: return - + try: SQL = SQL_TEMPLATES["delete_vectors"].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} await self.db.execute(SQL, params) - logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + logger.info( + f"Successfully deleted 
{len(ids)} vectors from {self.namespace}" + ) except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") raise async def delete_entity(self, entity_name: str) -> None: """Delete entity by name - + Args: entity_name: Name of the entity to delete """ @@ -479,7 +481,7 @@ class OracleVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations connected to an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -713,7 +715,7 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """Delete a node from the graph - + Args: node_id: ID of the node to delete """ @@ -722,33 +724,35 @@ class OracleGraphStorage(BaseGraphStorage): delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] params_relations = {"workspace": self.db.workspace, "entity_name": node_id} await self.db.execute(delete_relations_sql, params_relations) - + # Then delete the node itself delete_node_sql = SQL_TEMPLATES["delete_entity"] params_node = {"workspace": self.db.workspace, "entity_name": node_id} await self.db.execute(delete_node_sql, params_node) - - logger.info(f"Successfully deleted node {node_id} and all its relationships") + + logger.info( + f"Successfully deleted node {node_id} and all its relationships" + ) except Exception as e: logger.error(f"Error deleting node {node_id}: {e}") raise async def get_all_labels(self) -> list[str]: """Get all unique entity types (labels) in the graph - + Returns: List of unique entity types/labels """ try: SQL = """ - SELECT DISTINCT entity_type - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace + SELECT DISTINCT entity_type + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace ORDER BY entity_type """ params = {"workspace": self.db.workspace} results = await self.db.query(SQL, params, multirows=True) - + if results: labels = [row["entity_type"] for row in results] return labels @@ -762,26 +766,26 @@ class OracleGraphStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """Retrieve a connected subgraph starting from nodes matching the given label - + Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable. Prioritizes nodes by: 1. Nodes matching the specified label 2. Nodes directly connected to matching nodes 3. 
Node degree (number of connections) - + Args: node_label: Label to match for starting nodes (use "*" for all nodes) max_depth: Maximum depth of traversal from starting nodes - + Returns: KnowledgeGraph object containing nodes and edges """ result = KnowledgeGraph() - + try: # Define maximum number of nodes to return max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000)) - + if node_label == "*": # For "*" label, get all nodes up to the limit nodes_sql = """ @@ -791,30 +795,33 @@ class OracleGraphStorage(BaseGraphStorage): ORDER BY id FETCH FIRST :limit ROWS ONLY """ - nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes} + nodes_params = { + "workspace": self.db.workspace, + "limit": max_graph_nodes, + } nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) else: # For specific label, find matching nodes and related nodes nodes_sql = """ WITH matching_nodes AS ( - SELECT name + SELECT name FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace + WHERE workspace = :workspace AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%') ) SELECT n.name, n.entity_type, n.description, n.source_chunk_id, CASE WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2 WHEN EXISTS ( - SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e + SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e WHERE workspace = :workspace AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes)) OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes))) ) THEN 1 ELSE 0 END AS priority, - (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e - WHERE workspace = :workspace + (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e + WHERE workspace = :workspace AND (e.source_name = n.name OR e.target_name = n.name)) AS degree FROM LIGHTRAG_GRAPH_NODES n WHERE workspace = :workspace @@ -822,43 +829,41 @@ class OracleGraphStorage(BaseGraphStorage): FETCH FIRST :limit ROWS ONLY """ nodes_params = { - "workspace": self.db.workspace, + "workspace": self.db.workspace, "node_label": node_label, - "limit": max_graph_nodes + "limit": max_graph_nodes, } nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) - + if not nodes: logger.warning(f"No nodes found matching '{node_label}'") return result - + # Create mapping of node IDs to be used to filter edges node_names = [node["name"] for node in nodes] - + # Add nodes to result seen_nodes = set() for node in nodes: node_id = node["name"] if node_id in seen_nodes: continue - + # Create node properties dictionary properties = { "entity_type": node["entity_type"], "description": node["description"] or "", - "source_id": node["source_chunk_id"] or "" + "source_id": node["source_chunk_id"] or "", } - + # Add node to result result.nodes.append( KnowledgeGraphNode( - id=node_id, - labels=[node["entity_type"]], - properties=properties + id=node_id, labels=[node["entity_type"]], properties=properties ) ) seen_nodes.add(node_id) - + # Get edges between these nodes edges_sql = """ SELECT source_name, target_name, weight, keywords, description, source_chunk_id @@ -868,30 +873,27 @@ class OracleGraphStorage(BaseGraphStorage): AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) ORDER BY id """ - edges_params = { - "workspace": self.db.workspace, - "node_names": node_names - } + edges_params = {"workspace": self.db.workspace, "node_names": node_names} edges = await self.db.query(edges_sql, edges_params, multirows=True) - + # Add edges to result seen_edges = set() for edge 
in edges: source = edge["source_name"] target = edge["target_name"] edge_id = f"{source}-{target}" - + if edge_id in seen_edges: continue - + # Create edge properties dictionary properties = { "weight": edge["weight"] or 0.0, "keywords": edge["keywords"] or "", "description": edge["description"] or "", - "source_id": edge["source_chunk_id"] or "" + "source_id": edge["source_chunk_id"] or "", } - + # Add edge to result result.edges.append( KnowledgeGraphEdge( @@ -899,18 +901,18 @@ class OracleGraphStorage(BaseGraphStorage): type="RELATED", source=source, target=target, - properties=properties + properties=properties, ) ) seen_edges.add(edge_id) - + logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) - + except Exception as e: logger.error(f"Error retrieving knowledge graph: {e}") - + return result @@ -1166,8 +1168,8 @@ SQL_TEMPLATES = { "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})", "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name", "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)", - "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph - MATCH (a) - WHERE a.workspace=:workspace AND a.name=:node_id + "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph + MATCH (a) + WHERE a.workspace=:workspace AND a.name=:node_id ACTION DELETE a)""", } diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7ce2b427..54a59f5d 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -527,11 +527,15 @@ class PGVectorStorage(BaseVectorStorage): return ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" - + delete_sql = ( + f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" + ) + try: await self.db.execute(delete_sql, {"workspace": self.db.workspace}) - logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") @@ -543,12 +547,11 @@ class PGVectorStorage(BaseVectorStorage): """ try: # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY WHERE workspace=$1 AND entity_name=$2""" - + await self.db.execute( - delete_sql, - {"workspace": self.db.workspace, "entity_name": entity_name} + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} ) logger.debug(f"Successfully deleted entity {entity_name}") except Exception as e: @@ -562,12 +565,11 @@ class PGVectorStorage(BaseVectorStorage): """ try: # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" - + await self.db.execute( - delete_sql, - {"workspace": self.db.workspace, "entity_name": entity_name} + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} ) logger.debug(f"Successfully deleted relations for entity {entity_name}") except Exception as e: @@ -1167,7 +1169,9 @@ class PGGraphStorage(BaseGraphStorage): Args: node_ids (list[str]): A list of node IDs to remove. 
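Example (illustrative; assumes an initialized PGGraphStorage named `storage`):
    await storage.remove_nodes(["Alice", "Bob"])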
""" - encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids] + encoded_node_ids = [ + self._encode_graph_label(node_id.strip('"')) for node_id in node_ids + ] node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) query = """SELECT * FROM cypher('%s', $$ @@ -1189,7 +1193,13 @@ class PGGraphStorage(BaseGraphStorage): Args: edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). """ - encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges] + encoded_edges = [ + ( + self._encode_graph_label(src.strip('"')), + self._encode_graph_label(tgt.strip('"')), + ) + for src, tgt in edges + ] edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) query = """SELECT * FROM cypher('%s', $$ @@ -1211,10 +1221,13 @@ class PGGraphStorage(BaseGraphStorage): Returns: list[str]: A list of all labels in the graph. """ - query = """SELECT * FROM cypher('%s', $$ + query = ( + """SELECT * FROM cypher('%s', $$ MATCH (n:Entity) RETURN DISTINCT n.node_id AS label - $$) AS (label text)""" % self.graph_name + $$) AS (label text)""" + % self.graph_name + ) results = await self._query(query) labels = [self._decode_graph_label(result["label"]) for result in results] @@ -1260,7 +1273,10 @@ class PGGraphStorage(BaseGraphStorage): OPTIONAL MATCH (n)-[r]->(m:Entity) RETURN n, r, m LIMIT %d - $$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES) + $$) AS (n agtype, r agtype, m agtype)""" % ( + self.graph_name, + MAX_GRAPH_NODES, + ) else: encoded_node_label = self._encode_graph_label(node_label.strip('"')) query = """SELECT * FROM cypher('%s', $$ @@ -1268,7 +1284,12 @@ class PGGraphStorage(BaseGraphStorage): OPTIONAL MATCH p = (n)-[*..%d]-(m) RETURN nodes(p) AS nodes, relationships(p) AS relationships LIMIT %d - $$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES) + $$) AS (nodes agtype[], relationships agtype[])""" % ( + self.graph_name, + encoded_node_label, + max_depth, + MAX_GRAPH_NODES, + ) results = await self._query(query) @@ -1305,29 +1326,6 @@ class PGGraphStorage(BaseGraphStorage): return kg - async def get_all_labels(self) -> list[str]: - """ - Get all node labels in the graph - Returns: - [label1, label2, ...] 
# Alphabetically sorted label list - """ - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - RETURN DISTINCT n.node_id AS label - ORDER BY label - $$) AS (label agtype)""" % (self.graph_name) - - try: - results = await self._query(query) - labels = [] - for record in results: - if record["label"]: - labels.append(self._decode_graph_label(record["label"])) - return labels - except Exception as e: - logger.error(f"Error getting all labels: {str(e)}") - return [] - async def drop(self) -> None: """Drop the storage""" drop_sql = SQL_TEMPLATES["drop_vdb_entity"] diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index e3488caa..c7d346e6 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -143,7 +143,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def delete(self, ids: List[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ @@ -156,30 +156,34 @@ class QdrantVectorDBStorage(BaseVectorStorage): points_selector=models.PointIdsList( points=qdrant_ids, ), - wait=True + wait=True, + ) + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) - logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") async def delete_entity(self, entity_name: str) -> None: """Delete an entity by name - + Args: entity_name: Name of the entity to delete """ try: # Generate the entity ID entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Delete the entity point from the collection self._client.delete( collection_name=self.namespace, points_selector=models.PointIdsList( points=[entity_id], ), - wait=True + wait=True, ) logger.debug(f"Successfully deleted entity {entity_name}") except Exception as e: @@ -187,7 +191,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -198,23 +202,21 @@ class QdrantVectorDBStorage(BaseVectorStorage): scroll_filter=models.Filter( should=[ models.FieldCondition( - key="src_id", - match=models.MatchValue(value=entity_name) + key="src_id", match=models.MatchValue(value=entity_name) ), models.FieldCondition( - key="tgt_id", - match=models.MatchValue(value=entity_name) - ) + key="tgt_id", match=models.MatchValue(value=entity_name) + ), ] ), with_payload=True, - limit=1000 # Adjust as needed for your use case + limit=1000, # Adjust as needed for your use case ) - + # Extract points that need to be deleted relation_points = results[0] ids_to_delete = [point.id for point in relation_points] - + if ids_to_delete: # Delete the relations self._client.delete( @@ -222,9 +224,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): points_selector=models.PointIdsList( points=ids_to_delete, ), - wait=True + wait=True, + ) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" ) - logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}") else: logger.debug(f"No relations found for entity {entity_name}") except Exception as e: diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index bb42b367..3feb4985 100644 --- 
a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -67,35 +67,39 @@ class RedisKVStorage(BaseKVStorage): async def delete(self, ids: list[str]) -> None: """Delete entries with specified IDs - + Args: ids: List of entry IDs to be deleted """ if not ids: return - + pipe = self._redis.pipeline() for id in ids: pipe.delete(f"{self.namespace}:{id}") - + results = await pipe.execute() deleted_count = sum(results) - logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}") + logger.info( + f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" + ) async def delete_entity(self, entity_name: str) -> None: """Delete an entity by name - + Args: entity_name: Name of the entity to delete """ - + try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Delete the entity result = await self._redis.delete(f"{self.namespace}:{entity_id}") - + if result: logger.debug(f"Successfully deleted entity {entity_name}") else: @@ -105,7 +109,7 @@ class RedisKVStorage(BaseKVStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -114,29 +118,32 @@ class RedisKVStorage(BaseKVStorage): cursor = 0 relation_keys = [] pattern = f"{self.namespace}:*" - + while True: cursor, keys = await self._redis.scan(cursor, match=pattern) - + # For each key, get the value and check if it's related to entity_name for key in keys: value = await self._redis.get(key) if value: data = json.loads(value) # Check if this is a relation involving the entity - if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name: + if ( + data.get("src_id") == entity_name + or data.get("tgt_id") == entity_name + ): relation_keys.append(key) - + # Exit loop when cursor returns to 0 if cursor == 0: break - + # Delete the relation keys if relation_keys: deleted = await self._redis.delete(*relation_keys) logger.debug(f"Deleted {deleted} relations for {entity_name}") else: logger.debug(f"No relations found for entity {entity_name}") - + except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index f791d401..684c30d7 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -567,62 +567,68 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """Delete a node and all its related edges - + Args: node_id: The ID of the node to delete """ # First delete all edges related to this node - await self.db.execute(SQL_TEMPLATES["delete_node_edges"], - {"name": node_id, "workspace": self.db.workspace}) - + await self.db.execute( + SQL_TEMPLATES["delete_node_edges"], + {"name": node_id, "workspace": self.db.workspace}, + ) + # Then delete the node itself - await self.db.execute(SQL_TEMPLATES["delete_node"], - {"name": node_id, "workspace": self.db.workspace}) - - logger.debug(f"Node {node_id} and its related edges have been deleted from the graph") - + await self.db.execute( + SQL_TEMPLATES["delete_node"], + {"name": node_id, "workspace": self.db.workspace}, + ) + + logger.debug( + f"Node {node_id} and its related edges have been deleted from the graph" + ) + async def get_all_labels(self) -> list[str]: """Get all entity types 
(labels) in the database - + Returns: List of labels sorted alphabetically """ result = await self.db.query( - SQL_TEMPLATES["get_all_labels"], - {"workspace": self.db.workspace}, - multirows=True + SQL_TEMPLATES["get_all_labels"], + {"workspace": self.db.workspace}, + multirows=True, ) - + if not result: return [] - + # Extract all labels return [item["label"] for item in result] - + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ Get a connected subgraph of nodes matching the specified label Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000) - + Args: node_label: The node label to match max_depth: Maximum depth of the subgraph - + Returns: KnowledgeGraph object containing nodes and edges """ result = KnowledgeGraph() MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - + # Get matching nodes if node_label == "*": # Handle special case, get all nodes node_results = await self.db.query( SQL_TEMPLATES["get_all_nodes"], {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, - multirows=True + multirows=True, ) else: # Get nodes matching the label @@ -630,84 +636,93 @@ class TiDBGraphStorage(BaseGraphStorage): node_results = await self.db.query( SQL_TEMPLATES["get_matching_nodes"], {"workspace": self.db.workspace, "label_pattern": label_pattern}, - multirows=True + multirows=True, ) - + if not node_results: logger.warning(f"No nodes found matching label {node_label}") return result - + # Limit the number of returned nodes if len(node_results) > MAX_GRAPH_NODES: node_results = node_results[:MAX_GRAPH_NODES] - + # Extract node names for edge query node_names = [node["name"] for node in node_results] node_names_str = ",".join([f"'{name}'" for name in node_names]) - + # Add nodes to result for node in node_results: - node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]} + node_properties = { + k: v for k, v in node.items() if k not in ["id", "name", "entity_type"] + } result.nodes.append( KnowledgeGraphNode( id=node["name"], - labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]], - properties=node_properties + labels=[node["entity_type"]] + if node.get("entity_type") + else [node["name"]], + properties=node_properties, ) ) - + # Get related edges edge_results = await self.db.query( SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), {"workspace": self.db.workspace}, - multirows=True + multirows=True, ) - + if edge_results: # Add edges to result for edge in edge_results: # Only include edges related to selected nodes - if edge["source_name"] in node_names and edge["target_name"] in node_names: + if ( + edge["source_name"] in node_names + and edge["target_name"] in node_names + ): edge_id = f"{edge['source_name']}-{edge['target_name']}" - edge_properties = {k: v for k, v in edge.items() - if k not in ["id", "source_name", "target_name"]} - + edge_properties = { + k: v + for k, v in edge.items() + if k not in ["id", "source_name", "target_name"] + } + result.edges.append( KnowledgeGraphEdge( id=edge_id, type="RELATED", source=edge["source_name"], target=edge["target_name"], - properties=edge_properties + properties=edge_properties, ) ) - + logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result - + async def remove_nodes(self, nodes: list[str]): """Delete multiple nodes - + Args: nodes: List of node IDs to delete """ for node_id in nodes: await 
self.delete_node(node_id) - + async def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges - + Args: edges: List of edges to delete, each edge is a (source, target) tuple """ for source, target in edges: - await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], { - "source": source, - "target": target, - "workspace": self.db.workspace - }) + await self.db.execute( + SQL_TEMPLATES["remove_multiple_edges"], + {"source": source, "target": target, "workspace": self.db.workspace}, + ) N_T = { @@ -919,26 +934,26 @@ SQL_TEMPLATES = { source_chunk_id = VALUES(source_chunk_id) """, "delete_node": """ - DELETE FROM LIGHTRAG_GRAPH_NODES + DELETE FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace """, "delete_node_edges": """ - DELETE FROM LIGHTRAG_GRAPH_EDGES + DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace """, "get_all_labels": """ - SELECT DISTINCT entity_type as label - FROM LIGHTRAG_GRAPH_NODES + SELECT DISTINCT entity_type as label + FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace ORDER BY entity_type """, "get_matching_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES + SELECT * FROM LIGHTRAG_GRAPH_NODES WHERE name LIKE :label_pattern AND workspace = :workspace ORDER BY name """, "get_all_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES + SELECT * FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace ORDER BY name LIMIT :max_nodes @@ -952,5 +967,5 @@ SQL_TEMPLATES = { DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE (source_name = :source AND target_name = :target) AND workspace = :workspace - """ + """, } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index eeed8a70..a8034ddd 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1401,40 +1401,54 @@ class LightRAG: def delete_by_relation(self, source_entity: str, target_entity: str) -> None: """Synchronously delete a relation between two entities. - + Args: source_entity: Name of the source entity target_entity: Name of the target entity """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity)) + return loop.run_until_complete( + self.adelete_by_relation(source_entity, target_entity) + ) async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: """Asynchronously delete a relation between two entities. 
- + Args: source_entity: Name of the source entity target_entity: Name of the target entity """ try: # Check if the relation exists - edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity) + edge_exists = await self.chunk_entity_relation_graph.has_edge( + source_entity, target_entity + ) if not edge_exists: - logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist") + logger.warning( + f"Relation from '{source_entity}' to '{target_entity}' does not exist" + ) return - + # Delete relation from vector database - relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-") + relation_id = compute_mdhash_id( + source_entity + target_entity, prefix="rel-" + ) await self.relationships_vdb.delete([relation_id]) - + # Delete relation from knowledge graph - await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)]) - - logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'") + await self.chunk_entity_relation_graph.remove_edges( + [(source_entity, target_entity)] + ) + + logger.info( + f"Successfully deleted relation from '{source_entity}' to '{target_entity}'" + ) await self._delete_relation_done() except Exception as e: - logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}") - + logger.error( + f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}" + ) + async def _delete_relation_done(self) -> None: """Callback after relation deletion is complete""" await asyncio.gather( From 4ebaf8026b85e5794c132cf4aad92ad582b05605 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 16:11:13 +0800 Subject: [PATCH 11/21] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 5dee1143..754c3491 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -737,6 +737,64 @@ class OracleGraphStorage(BaseGraphStorage): logger.error(f"Error deleting node {node_id}: {e}") raise + async def remove_nodes(self, nodes: list[str]) -> None: + """Delete multiple nodes from the graph + + Args: + nodes: List of node IDs to be deleted + """ + if not nodes: + return + + try: + for node in nodes: + # For each node, first delete all its relationships + delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] + params_relations = {"workspace": self.db.workspace, "entity_name": node} + await self.db.execute(delete_relations_sql, params_relations) + + # Then delete the node itself + delete_node_sql = SQL_TEMPLATES["delete_entity"] + params_node = {"workspace": self.db.workspace, "entity_name": node} + await self.db.execute(delete_node_sql, params_node) + + logger.info(f"Successfully deleted {len(nodes)} nodes and their relationships") + except Exception as e: + logger.error(f"Error during batch node deletion: {e}") + raise + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """Delete multiple edges from the graph + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + if not edges: + return + + try: + for source, target in edges: + # Check if the edge exists before attempting to delete + if await self.has_edge(source, target): + # Delete the edge using a SQL query that matches both source and target + delete_edge_sql = """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace + 
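+ -- match the exact directed edge from source to target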
AND source_name = :source_name + AND target_name = :target_name + """ + params = { + "workspace": self.db.workspace, + "source_name": source, + "target_name": target + } + await self.db.execute(delete_edge_sql, params) + + logger.info(f"Successfully deleted {len(edges)} edges from the graph") + except Exception as e: + logger.error(f"Error during batch edge deletion: {e}") + raise + async def get_all_labels(self) -> list[str]: """Get all unique entity types (labels) in the graph From 1ee6c23a53a3418be88e2ecdb5bc8637271712b4 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 16:12:27 +0800 Subject: [PATCH 12/21] fix linting --- lightrag/kg/oracle_impl.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 754c3491..d105aa54 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -739,57 +739,59 @@ class OracleGraphStorage(BaseGraphStorage): async def remove_nodes(self, nodes: list[str]) -> None: """Delete multiple nodes from the graph - + Args: nodes: List of node IDs to be deleted """ if not nodes: return - + try: for node in nodes: # For each node, first delete all its relationships delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] params_relations = {"workspace": self.db.workspace, "entity_name": node} await self.db.execute(delete_relations_sql, params_relations) - + # Then delete the node itself delete_node_sql = SQL_TEMPLATES["delete_entity"] params_node = {"workspace": self.db.workspace, "entity_name": node} await self.db.execute(delete_node_sql, params_node) - - logger.info(f"Successfully deleted {len(nodes)} nodes and their relationships") + + logger.info( + f"Successfully deleted {len(nodes)} nodes and their relationships" + ) except Exception as e: logger.error(f"Error during batch node deletion: {e}") raise async def remove_edges(self, edges: list[tuple[str, str]]) -> None: """Delete multiple edges from the graph - + Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ if not edges: return - + try: for source, target in edges: # Check if the edge exists before attempting to delete if await self.has_edge(source, target): # Delete the edge using a SQL query that matches both source and target delete_edge_sql = """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE workspace = :workspace - AND source_name = :source_name + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace + AND source_name = :source_name AND target_name = :target_name """ params = { "workspace": self.db.workspace, "source_name": source, - "target_name": target + "target_name": target, } await self.db.execute(delete_edge_sql, params) - + logger.info(f"Successfully deleted {len(edges)} edges from the graph") except Exception as e: logger.error(f"Error during batch edge deletion: {e}") From 4e59a293fe792f3733a0ff836b6b7fe714ba5526 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 4 Mar 2025 16:19:23 +0800 Subject: [PATCH 13/21] Update __init__.py --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 2d660928..e4cb3e63 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.2.3" +__version__ = "1.2.4" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 
3264f6a118f572467f84986868217f258a299dfb Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 16:36:58 +0800 Subject: [PATCH 14/21] Update delete_by_doc_id --- lightrag/lightrag.py | 92 +++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 40 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a8034ddd..e8e468af 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1555,51 +1555,57 @@ class LightRAG: await self.text_chunks.delete(chunk_ids) # 5. Find and process entities and relationships that have these chunks as source - # Get all nodes in the graph - nodes = self.chunk_entity_relation_graph._graph.nodes(data=True) - edges = self.chunk_entity_relation_graph._graph.edges(data=True) - - # Track which entities and relationships need to be deleted or updated + # Get all nodes and edges from the graph storage using storage-agnostic methods entities_to_delete = set() entities_to_update = {} # entity_name -> new_source_id relationships_to_delete = set() relationships_to_update = {} # (src, tgt) -> new_source_id - # Process entities - for node, data in nodes: - if "source_id" in data: + # Process entities - use storage-agnostic methods + all_labels = await self.chunk_entity_relation_graph.get_all_labels() + for node_label in all_labels: + node_data = await self.chunk_entity_relation_graph.get_node(node_label) + if node_data and "source_id" in node_data: # Split source_id using GRAPH_FIELD_SEP - sources = set(data["source_id"].split(GRAPH_FIELD_SEP)) + sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) sources.difference_update(chunk_ids) if not sources: - entities_to_delete.add(node) + entities_to_delete.add(node_label) logger.debug( - f"Entity {node} marked for deletion - no remaining sources" + f"Entity {node_label} marked for deletion - no remaining sources" ) else: new_source_id = GRAPH_FIELD_SEP.join(sources) - entities_to_update[node] = new_source_id + entities_to_update[node_label] = new_source_id logger.debug( - f"Entity {node} will be updated with new source_id: {new_source_id}" + f"Entity {node_label} will be updated with new source_id: {new_source_id}" ) # Process relationships - for src, tgt, data in edges: - if "source_id" in data: - # Split source_id using GRAPH_FIELD_SEP - sources = set(data["source_id"].split(GRAPH_FIELD_SEP)) - sources.difference_update(chunk_ids) - if not sources: - relationships_to_delete.add((src, tgt)) - logger.debug( - f"Relationship {src}-{tgt} marked for deletion - no remaining sources" - ) - else: - new_source_id = GRAPH_FIELD_SEP.join(sources) - relationships_to_update[(src, tgt)] = new_source_id - logger.debug( - f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}" + for node_label in all_labels: + node_edges = await self.chunk_entity_relation_graph.get_node_edges( + node_label + ) + if node_edges: + for src, tgt in node_edges: + edge_data = await self.chunk_entity_relation_graph.get_edge( + src, tgt ) + if edge_data and "source_id" in edge_data: + # Split source_id using GRAPH_FIELD_SEP + sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) + sources.difference_update(chunk_ids) + if not sources: + relationships_to_delete.add((src, tgt)) + logger.debug( + f"Relationship {src}-{tgt} marked for deletion - no remaining sources" + ) + else: + new_source_id = GRAPH_FIELD_SEP.join(sources) + relationships_to_update[(src, tgt)] = new_source_id + logger.debug( + f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}" + ) # 
Delete entities if entities_to_delete: @@ -1613,12 +1619,15 @@ class LightRAG: # Update entities for entity, new_source_id in entities_to_update.items(): - node_data = self.chunk_entity_relation_graph._graph.nodes[entity] - node_data["source_id"] = new_source_id - await self.chunk_entity_relation_graph.upsert_node(entity, node_data) - logger.debug( - f"Updated entity {entity} with new source_id: {new_source_id}" - ) + node_data = await self.chunk_entity_relation_graph.get_node(entity) + if node_data: + node_data["source_id"] = new_source_id + await self.chunk_entity_relation_graph.upsert_node( + entity, node_data + ) + logger.debug( + f"Updated entity {entity} with new source_id: {new_source_id}" + ) # Delete relationships if relationships_to_delete: @@ -1636,12 +1645,15 @@ class LightRAG: # Update relationships for (src, tgt), new_source_id in relationships_to_update.items(): - edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt] - edge_data["source_id"] = new_source_id - await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data) - logger.debug( - f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}" - ) + edge_data = await self.chunk_entity_relation_graph.get_edge(src, tgt) + if edge_data: + edge_data["source_id"] = new_source_id + await self.chunk_entity_relation_graph.upsert_edge( + src, tgt, edge_data + ) + logger.debug( + f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}" + ) # 6. Delete original document and status await self.full_docs.delete([doc_id]) From a688b8822a8a9eb3853781bb5c71029a22aa5396 Mon Sep 17 00:00:00 2001 From: Brocowlee Date: Tue, 4 Mar 2025 10:09:47 +0000 Subject: [PATCH 15/21] [EVO] Add language configuration to environment and argument parsing --- env.example | 2 +- lightrag/api/lightrag_server.py | 3 +++ lightrag/api/utils_api.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/env.example b/env.example index 112676c6..d0c03a05 100644 --- a/env.example +++ b/env.example @@ -47,7 +47,7 @@ # CHUNK_OVERLAP_SIZE=100 # MAX_TOKENS=32768 # Max tokens send to LLM for summarization # MAX_TOKEN_SUMMARY=500 # Max tokens for entity or relations summary -# SUMMARY_LANGUAGE=English +# LANGUAGE=English # MAX_EMBED_TOKENS=8192 ### LLM Configuration (Use valid host. 
For local services installed with docker, you can use host.docker.internal) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f2c437f..93201a20 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -331,6 +331,9 @@ def create_app(args): }, log_level=args.log_level, namespace_prefix=args.namespace_prefix, + addon_params={ + "language": args.language, + }, auto_manage_storages_states=False, ) else: # azure_openai diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index ed1250d4..f865682b 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -340,6 +340,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: # Inject chunk configuration args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) + args.language = get_env_value("LANGUAGE", "English") ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name From f1ad55244abb482bef83ef2a9b340ef305e025de Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Tue, 4 Mar 2025 14:44:12 +0100 Subject: [PATCH 16/21] linting --- lightrag/api/run_with_gunicorn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index 231a1727..cf9b3b91 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -15,6 +15,7 @@ from dotenv import load_dotenv # This update allows the user to put a different.env file for each lightrag folder load_dotenv(".env") + def check_and_install_dependencies(): """Check and install required dependencies""" required_packages = [ From 5e7ef39998c4cc6a7ca0b8bd3fd8300c12a2edee Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 5 Mar 2025 15:12:01 +0800 Subject: [PATCH 17/21] Update operate.py --- lightrag/operate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lightrag/operate.py b/lightrag/operate.py index 7db42284..30983145 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1242,9 +1242,11 @@ async def _find_most_related_text_unit_from_entities( all_text_units_lookup = {} tasks = [] + for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for c_id in this_text_units: if c_id not in all_text_units_lookup: + all_text_units_lookup[c_id] = index tasks.append((c_id, index, this_edges)) results = await asyncio.gather( From 06fc65d9a0168f01ba657548fea55f54c35c2a6b Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 5 Mar 2025 16:26:28 +0800 Subject: [PATCH 18/21] Revert "[EVO] Add language configuration to environment and argument parsing" This reverts commit a688b8822a8a9eb3853781bb5c71029a22aa5396. --- env.example | 2 +- lightrag/api/lightrag_server.py | 3 --- lightrag/api/utils_api.py | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/env.example b/env.example index 294a5b68..99909ac6 100644 --- a/env.example +++ b/env.example @@ -48,7 +48,7 @@ # CHUNK_OVERLAP_SIZE=100 # MAX_TOKENS=32768 # Max tokens send to LLM for summarization # MAX_TOKEN_SUMMARY=500 # Max tokens for entity or relations summary -# LANGUAGE=English +# SUMMARY_LANGUAGE=English # MAX_EMBED_TOKENS=8192 ### LLM Configuration (Use valid host. 
For local services installed with docker, you can use host.docker.internal) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 775dc5e3..8ad232f0 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -331,9 +331,6 @@ def create_app(args): "use_llm_check": False, }, namespace_prefix=args.namespace_prefix, - addon_params={ - "language": args.language, - }, auto_manage_storages_states=False, ) else: # azure_openai diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index f865682b..ed1250d4 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -340,7 +340,6 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: # Inject chunk configuration args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) - args.language = get_env_value("LANGUAGE", "English") ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name From 649164c3e692659aeabc966e53e0af415ca15863 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 5 Mar 2025 16:55:09 +0800 Subject: [PATCH 19/21] Update lightrag.py --- lightrag/lightrag.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index e8e468af..f81ade0b 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1945,6 +1945,9 @@ class LightRAG: new_entity_name, new_node_data ) + # Store relationships that need to be updated + relations_to_update = [] + # Get all edges related to the original entity edges = await self.chunk_entity_relation_graph.get_node_edges( entity_name @@ -1960,10 +1963,12 @@ class LightRAG: await self.chunk_entity_relation_graph.upsert_edge( new_entity_name, target, edge_data ) + relations_to_update.append((new_entity_name, target, edge_data)) else: # target == entity_name await self.chunk_entity_relation_graph.upsert_edge( source, new_entity_name, edge_data ) + relations_to_update.append((source, new_entity_name, edge_data)) # Delete old entity await self.chunk_entity_relation_graph.delete_node(entity_name) @@ -1972,6 +1977,35 @@ class LightRAG: old_entity_id = compute_mdhash_id(entity_name, prefix="ent-") await self.entities_vdb.delete([old_entity_id]) + # Update relationship vector representations + for src, tgt, edge_data in relations_to_update: + description = edge_data.get("description", "") + keywords = edge_data.get("keywords", "") + source_id = edge_data.get("source_id", "") + weight = float(edge_data.get("weight", 1.0)) + + # Create new content for embedding + content = f"{src}\t{tgt}\n{keywords}\n{description}" + + # Calculate relationship ID + relation_id = compute_mdhash_id(src + tgt, prefix="rel-") + + # Prepare data for vector database update + relation_data = { + relation_id: { + "content": content, + "src_id": src, + "tgt_id": tgt, + "source_id": source_id, + "description": description, + "keywords": keywords, + "weight": weight, + } + } + + # Update vector database + await self.relationships_vdb.upsert(relation_data) + # Update working entity name to new name entity_name = new_entity_name else: @@ -2082,7 +2116,7 @@ class LightRAG: weight = float(new_edge_data.get("weight", 1.0)) # Create content for embedding - content = f"{keywords}\t{source_entity}\n{target_entity}\n{description}" + content = f"{source_entity}\t{target_entity}\n{keywords}\n{description}" # Calculate relation ID relation_id = 
compute_mdhash_id( From 917dc39334d99120f98ede2c7cb13c70bdf78e24 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:00:01 +0800 Subject: [PATCH 20/21] fix linting --- lightrag/lightrag.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index f81ade0b..5c060658 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1963,12 +1963,16 @@ class LightRAG: await self.chunk_entity_relation_graph.upsert_edge( new_entity_name, target, edge_data ) - relations_to_update.append((new_entity_name, target, edge_data)) + relations_to_update.append( + (new_entity_name, target, edge_data) + ) else: # target == entity_name await self.chunk_entity_relation_graph.upsert_edge( source, new_entity_name, edge_data ) - relations_to_update.append((source, new_entity_name, edge_data)) + relations_to_update.append( + (source, new_entity_name, edge_data) + ) # Delete old entity await self.chunk_entity_relation_graph.delete_node(entity_name) @@ -1983,13 +1987,13 @@ class LightRAG: keywords = edge_data.get("keywords", "") source_id = edge_data.get("source_id", "") weight = float(edge_data.get("weight", 1.0)) - + # Create new content for embedding content = f"{src}\t{tgt}\n{keywords}\n{description}" - + # Calculate relationship ID relation_id = compute_mdhash_id(src + tgt, prefix="rel-") - + # Prepare data for vector database update relation_data = { relation_id: { @@ -2002,7 +2006,7 @@ class LightRAG: "weight": weight, } } - + # Update vector database await self.relationships_vdb.upsert(relation_data) From ec0450f7121e82f7f5aee2fd6e96903187a33283 Mon Sep 17 00:00:00 2001 From: zrguo Date: Thu, 6 Mar 2025 00:53:23 +0800 Subject: [PATCH 21/21] Add merge entities --- README.md | 70 ++++++++ lightrag/lightrag.py | 389 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 459 insertions(+) diff --git a/README.md b/README.md index 57563a1f..00da54fb 100644 --- a/README.md +++ b/README.md @@ -849,6 +849,76 @@ All operations are available in both synchronous and asynchronous versions. The These operations maintain data consistency across both the graph database and vector database components, ensuring your knowledge graph remains coherent. +## Entity Merging + +
+<details>
+  <summary> <b>Merge Entities and Their Relationships</b> </summary>
+
+LightRAG now supports merging multiple entities into a single entity, automatically handling all relationships:
+
+```python
+# Basic entity merging
+rag.merge_entities(
+    source_entities=["Artificial Intelligence", "AI", "Machine Intelligence"],
+    target_entity="AI Technology"
+)
+```
+
+With a custom merge strategy:
+
+```python
+# Define custom merge strategy for different fields
+rag.merge_entities(
+    source_entities=["John Smith", "Dr. Smith", "J. Smith"],
+    target_entity="John Smith",
+    merge_strategy={
+        "description": "concatenate",  # Combine all descriptions
+        "entity_type": "keep_first",   # Keep the entity type from the first entity
+        "source_id": "join_unique"     # Combine all unique source IDs
+    }
+)
+```
+
+With custom target entity data:
+
+```python
+# Specify exact values for the merged entity
+rag.merge_entities(
+    source_entities=["New York", "NYC", "Big Apple"],
+    target_entity="New York City",
+    target_entity_data={
+        "entity_type": "LOCATION",
+        "description": "New York City is the most populous city in the United States.",
+    }
+)
+```
+
+Advanced usage combining both approaches:
+
+```python
+# Merge company entities with both strategy and custom data
+rag.merge_entities(
+    source_entities=["Microsoft Corp", "Microsoft Corporation", "MSFT"],
+    target_entity="Microsoft",
+    merge_strategy={
+        "description": "concatenate",  # Combine all descriptions
+        "source_id": "join_unique"     # Combine source IDs
+    },
+    target_entity_data={
+        "entity_type": "ORGANIZATION",
+    }
+)
+```
+
+When merging entities:
+* All relationships from source entities are redirected to the target entity
+* Duplicate relationships are intelligently merged
+* Self-relationships (loops) are prevented
+* Source entities are removed after merging
+* Relationship weights and attributes are preserved
+
+</details>

## Cache
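The merge strategies documented above have simple per-field semantics. As a minimal illustrative sketch — independent of the diff that follows, and assuming `GRAPH_FIELD_SEP` is LightRAG's `<SEP>` delimiter for multi-valued fields — each strategy reduces a list of field values like this:

```python
# Illustrative re-implementation of the documented merge strategies.
# GRAPH_FIELD_SEP and the sample values are assumptions for this sketch,
# not code taken from the patch.
GRAPH_FIELD_SEP = "<SEP>"

def merge_field(values: list[str], strategy: str) -> str:
    """Reduce one attribute across several entities (assumes values is non-empty)."""
    if strategy == "concatenate":
        return "\n\n".join(values)  # keep every value, e.g. all descriptions
    if strategy == "keep_last":
        return values[-1]
    if strategy == "join_unique":
        unique: set[str] = set()  # dedupe delimiter-separated items
        for value in values:
            unique.update(value.split(GRAPH_FIELD_SEP))
        return GRAPH_FIELD_SEP.join(unique)
    return values[0]  # "keep_first" and the documented default

print(merge_field(["An AI research lab.", "A machine intelligence company."], "concatenate"))
# -> both descriptions, separated by a blank line
print(merge_field(["chunk-1<SEP>chunk-2", "chunk-2<SEP>chunk-3"], "join_unique"))
# -> chunk-1, chunk-2, chunk-3 joined by <SEP> (set-based, so order varies)
```

The `_merge_entity_attributes` helper in the diff below implements the same per-field dispatch, with `keep_first` doubling as the fallback for unknown strategy names.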
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index e8c2bb9c..fa3a3a2b 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -2420,3 +2420,392 @@ class LightRAG:
         return loop.run_until_complete(
             self.acreate_relation(source_entity, target_entity, relation_data)
         )
+
+    async def amerge_entities(
+        self,
+        source_entities: list[str],
+        target_entity: str,
+        merge_strategy: dict[str, str] = None,
+        target_entity_data: dict[str, Any] = None,
+    ) -> dict[str, Any]:
+        """Asynchronously merge multiple entities into one entity.
+
+        Merges multiple source entities into a target entity, handling all relationships,
+        and updating both the knowledge graph and vector database.
+
+        Args:
+            source_entities: List of source entity names to merge
+            target_entity: Name of the target entity after merging
+            merge_strategy: Merge strategy configuration, e.g. {"description": "concatenate", "entity_type": "keep_first"}
+                Supported strategies:
+                - "concatenate": Concatenate all values (for text fields)
+                - "keep_first": Keep the first non-empty value
+                - "keep_last": Keep the last non-empty value
+                - "join_unique": Join all unique values (for fields separated by delimiter)
+            target_entity_data: Dictionary of specific values to set for the target entity,
+                overriding any merged values, e.g. {"description": "custom description", "entity_type": "PERSON"}
+
+        Returns:
+            Dictionary containing the merged entity information
+        """
+        try:
+            # Default merge strategy
+            default_strategy = {
+                "description": "concatenate",
+                "entity_type": "keep_first",
+                "source_id": "join_unique",
+            }
+
+            merge_strategy = (
+                default_strategy
+                if merge_strategy is None
+                else {**default_strategy, **merge_strategy}
+            )
+            target_entity_data = (
+                {} if target_entity_data is None else target_entity_data
+            )
+
+            # 1. Check if all source entities exist
+            source_entities_data = {}
+            for entity_name in source_entities:
+                node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
+                if not node_data:
+                    raise ValueError(f"Source entity '{entity_name}' does not exist")
+                source_entities_data[entity_name] = node_data
+
+            # 2. Check if target entity exists; keep any existing node data in a
+            # separate variable so the caller-supplied target_entity_data
+            # (applied after merging) is not overwritten
+            target_exists = await self.chunk_entity_relation_graph.has_node(
+                target_entity
+            )
+            existing_node_data = {}
+            if target_exists:
+                existing_node_data = await self.chunk_entity_relation_graph.get_node(
+                    target_entity
+                )
+                logger.info(
+                    f"Target entity '{target_entity}' already exists, will merge data"
+                )
+
+            # 3. Merge entity data
+            merged_entity_data = self._merge_entity_attributes(
+                list(source_entities_data.values())
+                + ([existing_node_data] if target_exists else []),
+                merge_strategy,
+            )
+
+            # Apply any explicitly provided target entity data (overrides merged data)
+            for key, value in target_entity_data.items():
+                merged_entity_data[key] = value
+
+            # 4. 
Get all relationships of the source entities + all_relations = [] + for entity_name in source_entities: + # Get all relationships where this entity is the source + outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges( + entity_name + ) + if outgoing_edges: + for src, tgt in outgoing_edges: + # Ensure src is the current entity + if src == entity_name: + edge_data = await self.chunk_entity_relation_graph.get_edge( + src, tgt + ) + all_relations.append(("outgoing", src, tgt, edge_data)) + + # Get all relationships where this entity is the target + incoming_edges = [] + all_labels = await self.chunk_entity_relation_graph.get_all_labels() + for label in all_labels: + if label == entity_name: + continue + node_edges = await self.chunk_entity_relation_graph.get_node_edges( + label + ) + for src, tgt in node_edges or []: + if tgt == entity_name: + incoming_edges.append((src, tgt)) + + for src, tgt in incoming_edges: + edge_data = await self.chunk_entity_relation_graph.get_edge( + src, tgt + ) + all_relations.append(("incoming", src, tgt, edge_data)) + + # 5. Create or update the target entity + if not target_exists: + await self.chunk_entity_relation_graph.upsert_node( + target_entity, merged_entity_data + ) + logger.info(f"Created new target entity '{target_entity}'") + else: + await self.chunk_entity_relation_graph.upsert_node( + target_entity, merged_entity_data + ) + logger.info(f"Updated existing target entity '{target_entity}'") + + # 6. Recreate all relationships, pointing to the target entity + relation_updates = {} # Track relationships that need to be merged + + for rel_type, src, tgt, edge_data in all_relations: + new_src = target_entity if src in source_entities else src + new_tgt = target_entity if tgt in source_entities else tgt + + # Skip relationships between source entities to avoid self-loops + if new_src == new_tgt: + logger.info( + f"Skipping relationship between source entities: {src} -> {tgt} to avoid self-loop" + ) + continue + + # Check if the same relationship already exists + relation_key = f"{new_src}|{new_tgt}" + if relation_key in relation_updates: + # Merge relationship data + existing_data = relation_updates[relation_key]["data"] + merged_relation = self._merge_relation_attributes( + [existing_data, edge_data], + { + "description": "concatenate", + "keywords": "join_unique", + "source_id": "join_unique", + "weight": "max", + }, + ) + relation_updates[relation_key]["data"] = merged_relation + logger.info( + f"Merged duplicate relationship: {new_src} -> {new_tgt}" + ) + else: + relation_updates[relation_key] = { + "src": new_src, + "tgt": new_tgt, + "data": edge_data.copy(), + } + + # Apply relationship updates + for rel_data in relation_updates.values(): + await self.chunk_entity_relation_graph.upsert_edge( + rel_data["src"], rel_data["tgt"], rel_data["data"] + ) + logger.info( + f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}" + ) + + # 7. Update entity vector representation + description = merged_entity_data.get("description", "") + source_id = merged_entity_data.get("source_id", "") + entity_type = merged_entity_data.get("entity_type", "") + content = target_entity + "\n" + description + + entity_id = compute_mdhash_id(target_entity, prefix="ent-") + entity_data_for_vdb = { + entity_id: { + "content": content, + "entity_name": target_entity, + "source_id": source_id, + "description": description, + "entity_type": entity_type, + } + } + + await self.entities_vdb.upsert(entity_data_for_vdb) + + # 8. 
Update relationship vector representations
+            for rel_data in relation_updates.values():
+                src = rel_data["src"]
+                tgt = rel_data["tgt"]
+                edge_data = rel_data["data"]
+
+                description = edge_data.get("description", "")
+                keywords = edge_data.get("keywords", "")
+                source_id = edge_data.get("source_id", "")
+                weight = float(edge_data.get("weight", 1.0))
+
+                # Keep the src/tgt-first embedding content layout standardized earlier in the series
+                content = f"{src}\t{tgt}\n{keywords}\n{description}"
+                relation_id = compute_mdhash_id(src + tgt, prefix="rel-")
+
+                relation_data_for_vdb = {
+                    relation_id: {
+                        "content": content,
+                        "src_id": src,
+                        "tgt_id": tgt,
+                        "source_id": source_id,
+                        "description": description,
+                        "keywords": keywords,
+                        "weight": weight,
+                    }
+                }
+
+                await self.relationships_vdb.upsert(relation_data_for_vdb)
+
+            # 9. Delete source entities
+            for entity_name in source_entities:
+                # Delete entity node
+                await self.chunk_entity_relation_graph.delete_node(entity_name)
+                # Delete record from vector database
+                entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+                await self.entities_vdb.delete([entity_id])
+                logger.info(f"Deleted source entity '{entity_name}'")
+
+            # 10. Save changes
+            await self._merge_entities_done()
+
+            logger.info(
+                f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
+            )
+            return await self.get_entity_info(target_entity, include_vector_data=True)
+
+        except Exception as e:
+            logger.error(f"Error merging entities: {e}")
+            raise
+
+    def merge_entities(
+        self,
+        source_entities: list[str],
+        target_entity: str,
+        merge_strategy: dict[str, str] = None,
+        target_entity_data: dict[str, Any] = None,
+    ) -> dict[str, Any]:
+        """Synchronously merge multiple entities into one entity.
+
+        Merges multiple source entities into a target entity, handling all relationships,
+        and updating both the knowledge graph and vector database.
+
+        Args:
+            source_entities: List of source entity names to merge
+            target_entity: Name of the target entity after merging
+            merge_strategy: Merge strategy configuration, e.g. {"description": "concatenate", "entity_type": "keep_first"}
+            target_entity_data: Dictionary of specific values to set for the target entity,
+                overriding any merged values, e.g. {"description": "custom description", "entity_type": "PERSON"}
+
+        Returns:
+            Dictionary containing the merged entity information
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.amerge_entities(
+                source_entities, target_entity, merge_strategy, target_entity_data
+            )
+        )
+
+    def _merge_entity_attributes(
+        self, entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
+    ) -> dict[str, Any]:
+        """Merge attributes from multiple entities. 
+ + Args: + entity_data_list: List of dictionaries containing entity data + merge_strategy: Merge strategy for each field + + Returns: + Dictionary containing merged entity data + """ + merged_data = {} + + # Collect all possible keys + all_keys = set() + for data in entity_data_list: + all_keys.update(data.keys()) + + # Merge values for each key + for key in all_keys: + # Get all values for this key + values = [data.get(key) for data in entity_data_list if data.get(key)] + + if not values: + continue + + # Merge values according to strategy + strategy = merge_strategy.get(key, "keep_first") + + if strategy == "concatenate": + merged_data[key] = "\n\n".join(values) + elif strategy == "keep_first": + merged_data[key] = values[0] + elif strategy == "keep_last": + merged_data[key] = values[-1] + elif strategy == "join_unique": + # Handle fields separated by GRAPH_FIELD_SEP + unique_items = set() + for value in values: + items = value.split(GRAPH_FIELD_SEP) + unique_items.update(items) + merged_data[key] = GRAPH_FIELD_SEP.join(unique_items) + else: + # Default strategy + merged_data[key] = values[0] + + return merged_data + + def _merge_relation_attributes( + self, relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str] + ) -> dict[str, Any]: + """Merge attributes from multiple relationships. + + Args: + relation_data_list: List of dictionaries containing relationship data + merge_strategy: Merge strategy for each field + + Returns: + Dictionary containing merged relationship data + """ + merged_data = {} + + # Collect all possible keys + all_keys = set() + for data in relation_data_list: + all_keys.update(data.keys()) + + # Merge values for each key + for key in all_keys: + # Get all values for this key + values = [ + data.get(key) + for data in relation_data_list + if data.get(key) is not None + ] + + if not values: + continue + + # Merge values according to strategy + strategy = merge_strategy.get(key, "keep_first") + + if strategy == "concatenate": + merged_data[key] = "\n\n".join(str(v) for v in values) + elif strategy == "keep_first": + merged_data[key] = values[0] + elif strategy == "keep_last": + merged_data[key] = values[-1] + elif strategy == "join_unique": + # Handle fields separated by GRAPH_FIELD_SEP + unique_items = set() + for value in values: + items = str(value).split(GRAPH_FIELD_SEP) + unique_items.update(items) + merged_data[key] = GRAPH_FIELD_SEP.join(unique_items) + elif strategy == "max": + # For numeric fields like weight + try: + merged_data[key] = max(float(v) for v in values) + except (ValueError, TypeError): + merged_data[key] = values[0] + else: + # Default strategy + merged_data[key] = values[0] + + return merged_data + + async def _merge_entities_done(self) -> None: + """Callback after entity merging is complete, ensures updates are persisted""" + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.entities_vdb, + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + )
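Taken together, the series ends with a coherent graph-editing API. The sketch below shows how the new calls compose; it is a hedged example rather than a verbatim excerpt — the working directory, document ID, and entity names are placeholders, and a real `LightRAG` instance also needs LLM and embedding functions configured:

```python
# Hedged end-to-end sketch of the APIs added in this patch series;
# working_dir, "doc-123", and the entity names are illustrative only.
from lightrag import LightRAG

# A real instance also requires llm_model_func / embedding_func and storage setup.
rag = LightRAG(working_dir="./rag_storage")

# PATCH 21: merge duplicate entities; relationships are redirected,
# deduplicated, and re-embedded in the relationships vector store.
rag.merge_entities(
    source_entities=["OpenAI Inc", "OpenAI Corporation"],
    target_entity="OpenAI",
    merge_strategy={"description": "concatenate", "source_id": "join_unique"},
)

# Earlier in the series: drop a single relation from both the knowledge
# graph and the relationships vector store.
rag.delete_by_relation("OpenAI", "Microsoft")

# PATCH 14: delete a document plus any entities/relations that only it
# supported, now via storage-agnostic graph traversal.
rag.delete_by_doc_id("doc-123")
```

Each synchronous wrapper resolves an event loop via `always_get_an_event_loop()` and drives its async counterpart, so these calls work from plain scripts as well as async applications.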