fix linting

This commit is contained in:
zrguo
2025-03-04 15:53:20 +08:00
parent 3a2a636862
commit 81568f3bad
11 changed files with 394 additions and 327 deletions

View File

@@ -619,7 +619,7 @@ class AGEStorage(BaseGraphStorage):
node_id: The label of the node to delete node_id: The label of the node to delete
""" """
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
query = """ query = """
MATCH (n:`{label}`) MATCH (n:`{label}`)
DETACH DELETE n DETACH DELETE n
@@ -650,18 +650,20 @@ class AGEStorage(BaseGraphStorage):
for source, target in edges: for source, target in edges:
entity_name_label_source = source.strip('"') entity_name_label_source = source.strip('"')
entity_name_label_target = target.strip('"') entity_name_label_target = target.strip('"')
query = """ query = """
MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`)
DELETE r DELETE r
""" """
params = { params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source), "src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target) "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
} }
try: try:
await self._query(query, **params) await self._query(query, **params)
logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'") logger.debug(
f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'"
)
except Exception as e: except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}") logger.error(f"Error during edge deletion: {str(e)}")
raise raise
@@ -683,7 +685,7 @@ class AGEStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
"""Get all node labels in the database """Get all node labels in the database
Returns: Returns:
["label1", "label2", ...] # Alphabetically sorted label list ["label1", "label2", ...] # Alphabetically sorted label list
""" """
@@ -692,7 +694,7 @@ class AGEStorage(BaseGraphStorage):
RETURN DISTINCT labels(n) AS node_labels RETURN DISTINCT labels(n) AS node_labels
""" """
results = await self._query(query) results = await self._query(query)
all_labels = [] all_labels = []
for record in results: for record in results:
if record and "node_labels" in record: if record and "node_labels" in record:
@@ -701,7 +703,7 @@ class AGEStorage(BaseGraphStorage):
# Decode label # Decode label
decoded_label = AGEStorage._decode_graph_label(label) decoded_label = AGEStorage._decode_graph_label(label)
all_labels.append(decoded_label) all_labels.append(decoded_label)
# Remove duplicates and sort # Remove duplicates and sort
return sorted(list(set(all_labels))) return sorted(list(set(all_labels)))
@@ -719,7 +721,7 @@ class AGEStorage(BaseGraphStorage):
Args: Args:
node_label: String to match in node labels (will match any node containing this string in its label) node_label: String to match in node labels (will match any node containing this string in its label)
max_depth: Maximum depth of the graph. Defaults to 5. max_depth: Maximum depth of the graph. Defaults to 5.
Returns: Returns:
KnowledgeGraph: Complete connected subgraph for specified node KnowledgeGraph: Complete connected subgraph for specified node
""" """
@@ -727,7 +729,7 @@ class AGEStorage(BaseGraphStorage):
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
# Handle special case for "*" label # Handle special case for "*" label
if node_label == "*": if node_label == "*":
# Query all nodes and sort by degree # Query all nodes and sort by degree
@@ -741,7 +743,7 @@ class AGEStorage(BaseGraphStorage):
""" """
params = {"max_nodes": max_graph_nodes} params = {"max_nodes": max_graph_nodes}
nodes_result = await self._query(query, **params) nodes_result = await self._query(query, **params)
# Add nodes to result # Add nodes to result
node_ids = [] node_ids = []
for record in nodes_result: for record in nodes_result:
@@ -755,12 +757,12 @@ class AGEStorage(BaseGraphStorage):
KnowledgeGraphNode( KnowledgeGraphNode(
id=node_id, id=node_id,
labels=[node_label], labels=[node_label],
properties=node_properties properties=node_properties,
) )
) )
seen_nodes.add(node_id) seen_nodes.add(node_id)
node_ids.append(node_id) node_ids.append(node_id)
# Query edges between these nodes # Query edges between these nodes
if node_ids: if node_ids:
edges_query = """ edges_query = """
@@ -770,7 +772,7 @@ class AGEStorage(BaseGraphStorage):
""" """
edges_params = {"node_ids": node_ids} edges_params = {"node_ids": node_ids}
edges_result = await self._query(edges_query, **edges_params) edges_result = await self._query(edges_query, **edges_params)
# Add edges to result # Add edges to result
for record in edges_result: for record in edges_result:
if "r" in record and "a" in record and "b" in record: if "r" in record and "a" in record and "b" in record:
@@ -785,7 +787,7 @@ class AGEStorage(BaseGraphStorage):
type="DIRECTED", type="DIRECTED",
source=source, source=source,
target=target, target=target,
properties=edge_properties properties=edge_properties,
) )
) )
seen_edges.add(edge_id) seen_edges.add(edge_id)
@@ -793,7 +795,7 @@ class AGEStorage(BaseGraphStorage):
# For specific label, use partial matching # For specific label, use partial matching
entity_name_label = node_label.strip('"') entity_name_label = node_label.strip('"')
encoded_label = AGEStorage._encode_graph_label(entity_name_label) encoded_label = AGEStorage._encode_graph_label(entity_name_label)
# Find matching start nodes # Find matching start nodes
start_query = """ start_query = """
MATCH (n:`{label}`) MATCH (n:`{label}`)
@@ -801,17 +803,14 @@ class AGEStorage(BaseGraphStorage):
""" """
start_params = {"label": encoded_label} start_params = {"label": encoded_label}
start_nodes = await self._query(start_query, **start_params) start_nodes = await self._query(start_query, **start_params)
if not start_nodes: if not start_nodes:
logger.warning(f"No nodes found with label '{entity_name_label}'!") logger.warning(f"No nodes found with label '{entity_name_label}'!")
return result return result
# Traverse graph from each start node # Traverse graph from each start node
for start_node_record in start_nodes: for start_node_record in start_nodes:
if "n" in start_node_record: if "n" in start_node_record:
start_node = start_node_record["n"]
start_id = str(start_node.get("id", ""))
# Use BFS to traverse graph # Use BFS to traverse graph
query = """ query = """
MATCH (start:`{label}`) MATCH (start:`{label}`)
@@ -823,25 +822,28 @@ class AGEStorage(BaseGraphStorage):
""" """
params = {"label": encoded_label, "max_depth": max_depth} params = {"label": encoded_label, "max_depth": max_depth}
results = await self._query(query, **params) results = await self._query(query, **params)
# Extract nodes and edges from results # Extract nodes and edges from results
for record in results: for record in results:
if "path_nodes" in record: if "path_nodes" in record:
# Process nodes # Process nodes
for node in record["path_nodes"]: for node in record["path_nodes"]:
node_id = str(node.get("id", "")) node_id = str(node.get("id", ""))
if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes: if (
node_id not in seen_nodes
and len(seen_nodes) < max_graph_nodes
):
node_properties = {k: v for k, v in node.items()} node_properties = {k: v for k, v in node.items()}
node_label = node.get("label", "") node_label = node.get("label", "")
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=node_id, id=node_id,
labels=[node_label], labels=[node_label],
properties=node_properties properties=node_properties,
) )
) )
seen_nodes.add(node_id) seen_nodes.add(node_id)
if "path_rels" in record: if "path_rels" in record:
# Process edges # Process edges
for rel in record["path_rels"]: for rel in record["path_rels"]:
@@ -856,11 +858,11 @@ class AGEStorage(BaseGraphStorage):
type=rel.get("label", "DIRECTED"), type=rel.get("label", "DIRECTED"),
source=source, source=source,
target=target, target=target,
properties=edge_properties properties=edge_properties,
) )
) )
seen_edges.add(edge_id) seen_edges.add(edge_id)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )

View File

@@ -194,7 +194,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its ID. """Delete an entity by its ID.
Args: Args:
entity_name: The ID of the entity to delete entity_name: The ID of the entity to delete
""" """
@@ -206,24 +206,26 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise raise
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete an entity and its relations by ID. """Delete an entity and its relations by ID.
In vector DB context, this is equivalent to delete_entity. In vector DB context, this is equivalent to delete_entity.
Args: Args:
entity_name: The ID of the entity to delete entity_name: The ID of the entity to delete
""" """
await self.delete_entity(entity_name) await self.delete_entity(entity_name)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs """Delete vectors with specified IDs
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
try: try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids) self._collection.delete(ids=ids)
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise raise

View File

@@ -397,12 +397,12 @@ class GremlinStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified entity_name """Delete a node with the specified entity_name
Args: Args:
node_id: The entity_name of the node to delete node_id: The entity_name of the node to delete
""" """
entity_name = GremlinStorage._fix_name(node_id) entity_name = GremlinStorage._fix_name(node_id)
query = f"""g query = f"""g
.V().has('graph', {self.graph_name}) .V().has('graph', {self.graph_name})
.has('entity_name', {entity_name}) .has('entity_name', {entity_name})
@@ -413,7 +413,7 @@ class GremlinStorage(BaseGraphStorage):
logger.debug( logger.debug(
"{%s}: Deleted node with entity_name '%s'", "{%s}: Deleted node with entity_name '%s'",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
entity_name entity_name,
) )
except Exception as e: except Exception as e:
logger.error(f"Error during node deletion: {str(e)}") logger.error(f"Error during node deletion: {str(e)}")
@@ -425,13 +425,13 @@ class GremlinStorage(BaseGraphStorage):
""" """
Embed nodes using the specified algorithm. Embed nodes using the specified algorithm.
Currently, only node2vec is supported but never called. Currently, only node2vec is supported but never called.
Args: Args:
algorithm: The name of the embedding algorithm to use algorithm: The name of the embedding algorithm to use
Returns: Returns:
A tuple of (embeddings, node_ids) A tuple of (embeddings, node_ids)
Raises: Raises:
NotImplementedError: If the specified algorithm is not supported NotImplementedError: If the specified algorithm is not supported
ValueError: If the algorithm is not supported ValueError: If the algorithm is not supported
@@ -458,7 +458,7 @@ class GremlinStorage(BaseGraphStorage):
logger.debug( logger.debug(
"{%s}: Retrieved %d labels", "{%s}: Retrieved %d labels",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
len(labels) len(labels),
) )
return labels return labels
except Exception as e: except Exception as e:
@@ -471,7 +471,7 @@ class GremlinStorage(BaseGraphStorage):
""" """
Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
Args: Args:
node_label: Entity name of the starting node node_label: Entity name of the starting node
max_depth: Maximum depth of the subgraph max_depth: Maximum depth of the subgraph
@@ -482,12 +482,12 @@ class GremlinStorage(BaseGraphStorage):
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
# Get maximum number of graph nodes from environment variable, default is 1000 # Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
entity_name = GremlinStorage._fix_name(node_label) entity_name = GremlinStorage._fix_name(node_label)
# Handle special case for "*" label # Handle special case for "*" label
if node_label == "*": if node_label == "*":
# For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES)
@@ -497,25 +497,27 @@ class GremlinStorage(BaseGraphStorage):
.elementMap() .elementMap()
""" """
nodes_result = await self._query(query) nodes_result = await self._query(query)
# Add nodes to result # Add nodes to result
for node_data in nodes_result: for node_data in nodes_result:
node_id = node_data.get('entity_name', str(node_data.get('id', ''))) node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes: if str(node_id) in seen_nodes:
continue continue
# Create node with properties # Create node with properties
node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=str(node_id), id=str(node_id),
labels=[str(node_id)], labels=[str(node_id)],
properties=node_properties properties=node_properties,
) )
) )
seen_nodes.add(str(node_id)) seen_nodes.add(str(node_id))
# Get and add edges # Get and add edges
if nodes_result: if nodes_result:
query = f"""g query = f"""g
@@ -530,30 +532,34 @@ class GremlinStorage(BaseGraphStorage):
.by(elementMap()) .by(elementMap())
""" """
edges_result = await self._query(query) edges_result = await self._query(query)
for path in edges_result: for path in edges_result:
if len(path) >= 3: # source -> edge -> target if len(path) >= 3: # source -> edge -> target
source = path[0] source = path[0]
edge_data = path[1] edge_data = path[1]
target = path[2] target = path[2]
source_id = source.get('entity_name', str(source.get('id', ''))) source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get('entity_name', str(target.get('id', ''))) target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}" edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges: if edge_id in seen_edges:
continue continue
# Create edge with properties # Create edge with properties
edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append( result.edges.append(
KnowledgeGraphEdge( KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type="DIRECTED", type="DIRECTED",
source=str(source_id), source=str(source_id),
target=str(target_id), target=str(target_id),
properties=edge_properties properties=edge_properties,
) )
) )
seen_edges.add(edge_id) seen_edges.add(edge_id)
@@ -570,30 +576,36 @@ class GremlinStorage(BaseGraphStorage):
.elementMap() .elementMap()
""" """
nodes_result = await self._query(query) nodes_result = await self._query(query)
# Add nodes to result # Add nodes to result
for node_data in nodes_result: for node_data in nodes_result:
node_id = node_data.get('entity_name', str(node_data.get('id', ''))) node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes: if str(node_id) in seen_nodes:
continue continue
# Create node with properties # Create node with properties
node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=str(node_id), id=str(node_id),
labels=[str(node_id)], labels=[str(node_id)],
properties=node_properties properties=node_properties,
) )
) )
seen_nodes.add(str(node_id)) seen_nodes.add(str(node_id))
# Get edges between the nodes in the result # Get edges between the nodes in the result
if nodes_result: if nodes_result:
node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result] node_ids = [
node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids]) n.get("entity_name", str(n.get("id", ""))) for n in nodes_result
]
node_ids_query = ", ".join(
[GremlinStorage._to_value_map(nid) for nid in node_ids]
)
query = f"""g query = f"""g
.V().has('graph', {self.graph_name}) .V().has('graph', {self.graph_name})
.has('entity_name', within({node_ids_query})) .has('entity_name', within({node_ids_query}))
@@ -606,38 +618,42 @@ class GremlinStorage(BaseGraphStorage):
.by(elementMap()) .by(elementMap())
""" """
edges_result = await self._query(query) edges_result = await self._query(query)
for path in edges_result: for path in edges_result:
if len(path) >= 3: # source -> edge -> target if len(path) >= 3: # source -> edge -> target
source = path[0] source = path[0]
edge_data = path[1] edge_data = path[1]
target = path[2] target = path[2]
source_id = source.get('entity_name', str(source.get('id', ''))) source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get('entity_name', str(target.get('id', ''))) target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}" edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges: if edge_id in seen_edges:
continue continue
# Create edge with properties # Create edge with properties
edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append( result.edges.append(
KnowledgeGraphEdge( KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type="DIRECTED", type="DIRECTED",
source=str(source_id), source=str(source_id),
target=str(target_id), target=str(target_id),
properties=edge_properties properties=edge_properties,
) )
) )
seen_edges.add(edge_id) seen_edges.add(edge_id)
logger.info( logger.info(
"Subgraph query successful | Node count: %d | Edge count: %d", "Subgraph query successful | Node count: %d | Edge count: %d",
len(result.nodes), len(result.nodes),
len(result.edges) len(result.edges),
) )
return result return result
@@ -659,7 +675,7 @@ class GremlinStorage(BaseGraphStorage):
for source, target in edges: for source, target in edges:
entity_name_source = GremlinStorage._fix_name(source) entity_name_source = GremlinStorage._fix_name(source)
entity_name_target = GremlinStorage._fix_name(target) entity_name_target = GremlinStorage._fix_name(target)
query = f"""g query = f"""g
.V().has('graph', {self.graph_name}) .V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source}) .has('entity_name', {entity_name_source})
@@ -674,7 +690,7 @@ class GremlinStorage(BaseGraphStorage):
"{%s}: Deleted edge from '%s' to '%s'", "{%s}: Deleted edge from '%s' to '%s'",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
entity_name_source, entity_name_source,
entity_name_target entity_name_target,
) )
except Exception as e: except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}") logger.error(f"Error during edge deletion: {str(e)}")

View File

@@ -125,83 +125,84 @@ class MilvusVectorDBStorage(BaseVectorStorage):
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity from the vector database """Delete an entity from the vector database
Args: Args:
entity_name: The name of the entity to delete entity_name: The name of the entity to delete
""" """
try: try:
# Compute entity ID from name # Compute entity ID from name
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity from Milvus collection # Delete the entity from Milvus collection
result = self._client.delete( result = self._client.delete(
collection_name=self.namespace, collection_name=self.namespace, pks=[entity_id]
pks=[entity_id]
) )
if result and result.get("delete_count", 0) > 0: if result and result.get("delete_count", 0) > 0:
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(f"Successfully deleted entity {entity_name}")
else: else:
logger.debug(f"Entity {entity_name} not found in storage") logger.debug(f"Entity {entity_name} not found in storage")
except Exception as e: except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}") logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
Args: Args:
entity_name: The name of the entity whose relations should be deleted entity_name: The name of the entity whose relations should be deleted
""" """
try: try:
# Search for relations where entity is either source or target # Search for relations where entity is either source or target
expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"'
# Find all relations involving this entity # Find all relations involving this entity
results = self._client.query( results = self._client.query(
collection_name=self.namespace, collection_name=self.namespace, filter=expr, output_fields=["id"]
filter=expr,
output_fields=["id"]
) )
if not results or len(results) == 0: if not results or len(results) == 0:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(f"No relations found for entity {entity_name}")
return return
# Extract IDs of relations to delete # Extract IDs of relations to delete
relation_ids = [item["id"] for item in results] relation_ids = [item["id"] for item in results]
logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations # Delete the relations
if relation_ids: if relation_ids:
delete_result = self._client.delete( delete_result = self._client.delete(
collection_name=self.namespace, collection_name=self.namespace, pks=relation_ids
pks=relation_ids
) )
logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}") logger.debug(
f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(f"Error deleting relations for {entity_name}: {e}")
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs """Delete vectors with specified IDs
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
try: try:
# Delete vectors by IDs # Delete vectors by IDs
result = self._client.delete( result = self._client.delete(collection_name=self.namespace, pks=ids)
collection_name=self.namespace,
pks=ids
)
if result and result.get("delete_count", 0) > 0: if result and result.get("delete_count", 0) > 0:
logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}") logger.debug(
f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
)
else: else:
logger.debug(f"No vectors were deleted from {self.namespace}") logger.debug(f"No vectors were deleted from {self.namespace}")
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

View File

@@ -804,16 +804,15 @@ class MongoGraphStorage(BaseGraphStorage):
logger.info(f"Deleting {len(nodes)} nodes") logger.info(f"Deleting {len(nodes)} nodes")
if not nodes: if not nodes:
return return
# 1. Remove all edges referencing these nodes (remove from edges array of other nodes) # 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
await self.collection.update_many( await self.collection.update_many(
{}, {}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
{"$pull": {"edges": {"target": {"$in": nodes}}}}
) )
# 2. Delete the node documents # 2. Delete the node documents
await self.collection.delete_many({"_id": {"$in": nodes}}) await self.collection.delete_many({"_id": {"$in": nodes}})
logger.debug(f"Successfully deleted nodes: {nodes}") logger.debug(f"Successfully deleted nodes: {nodes}")
async def remove_edges(self, edges: list[tuple[str, str]]) -> None: async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
@@ -825,20 +824,19 @@ class MongoGraphStorage(BaseGraphStorage):
logger.info(f"Deleting {len(edges)} edges") logger.info(f"Deleting {len(edges)} edges")
if not edges: if not edges:
return return
update_tasks = [] update_tasks = []
for source, target in edges: for source, target in edges:
# Remove edge pointing to target from source node's edges array # Remove edge pointing to target from source node's edges array
update_tasks.append( update_tasks.append(
self.collection.update_one( self.collection.update_one(
{"_id": source}, {"_id": source}, {"$pull": {"edges": {"target": target}}}
{"$pull": {"edges": {"target": target}}}
) )
) )
if update_tasks: if update_tasks:
await asyncio.gather(*update_tasks) await asyncio.gather(*update_tasks)
logger.debug(f"Successfully deleted edges: {edges}") logger.debug(f"Successfully deleted edges: {edges}")
@@ -987,23 +985,29 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
if not ids: if not ids:
return return
try: try:
result = await self._data.delete_many({"_id": {"$in": ids}}) result = await self._data.delete_many({"_id": {"$in": ids}})
logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}") logger.debug(
f"Successfully deleted {result.deleted_count} vectors from {self.namespace}"
)
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}") logger.error(
f"Error while deleting vectors from {self.namespace}: {str(e)}"
)
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its name """Delete an entity by its name
Args: Args:
entity_name: Name of the entity to delete entity_name: Name of the entity to delete
""" """
try: try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
result = await self._data.delete_one({"_id": entity_id}) result = await self._data.delete_one({"_id": entity_id})
if result.deleted_count > 0: if result.deleted_count > 0:
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(f"Successfully deleted entity {entity_name}")
@@ -1014,7 +1018,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
Args: Args:
entity_name: Name of the entity whose relations should be deleted entity_name: Name of the entity whose relations should be deleted
""" """
@@ -1024,15 +1028,17 @@ class MongoVectorDBStorage(BaseVectorStorage):
{"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]} {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]}
) )
relations = await relations_cursor.to_list(length=None) relations = await relations_cursor.to_list(length=None)
if not relations: if not relations:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(f"No relations found for entity {entity_name}")
return return
# Extract IDs of relations to delete # Extract IDs of relations to delete
relation_ids = [relation["_id"] for relation in relations] relation_ids = [relation["_id"] for relation in relations]
logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations # Delete the relations
result = await self._data.delete_many({"_id": {"$in": relation_ids}}) result = await self._data.delete_many({"_id": {"$in": relation_ids}})
logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}") logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}")

View File

@@ -444,27 +444,29 @@ class OracleVectorDBStorage(BaseVectorStorage):
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs """Delete vectors with specified IDs
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
if not ids: if not ids:
return return
try: try:
SQL = SQL_TEMPLATES["delete_vectors"].format( SQL = SQL_TEMPLATES["delete_vectors"].format(
ids=",".join([f"'{id}'" for id in ids]) ids=",".join([f"'{id}'" for id in ids])
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
await self.db.execute(SQL, params) await self.db.execute(SQL, params)
logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}") logger.info(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise raise
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete entity by name """Delete entity by name
Args: Args:
entity_name: Name of the entity to delete entity_name: Name of the entity to delete
""" """
@@ -479,7 +481,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations connected to an entity """Delete all relations connected to an entity
Args: Args:
entity_name: Name of the entity whose relations should be deleted entity_name: Name of the entity whose relations should be deleted
""" """
@@ -713,7 +715,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""Delete a node from the graph """Delete a node from the graph
Args: Args:
node_id: ID of the node to delete node_id: ID of the node to delete
""" """
@@ -722,33 +724,35 @@ class OracleGraphStorage(BaseGraphStorage):
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
params_relations = {"workspace": self.db.workspace, "entity_name": node_id} params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
await self.db.execute(delete_relations_sql, params_relations) await self.db.execute(delete_relations_sql, params_relations)
# Then delete the node itself # Then delete the node itself
delete_node_sql = SQL_TEMPLATES["delete_entity"] delete_node_sql = SQL_TEMPLATES["delete_entity"]
params_node = {"workspace": self.db.workspace, "entity_name": node_id} params_node = {"workspace": self.db.workspace, "entity_name": node_id}
await self.db.execute(delete_node_sql, params_node) await self.db.execute(delete_node_sql, params_node)
logger.info(f"Successfully deleted node {node_id} and all its relationships") logger.info(
f"Successfully deleted node {node_id} and all its relationships"
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting node {node_id}: {e}") logger.error(f"Error deleting node {node_id}: {e}")
raise raise
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
"""Get all unique entity types (labels) in the graph """Get all unique entity types (labels) in the graph
Returns: Returns:
List of unique entity types/labels List of unique entity types/labels
""" """
try: try:
SQL = """ SQL = """
SELECT DISTINCT entity_type SELECT DISTINCT entity_type
FROM LIGHTRAG_GRAPH_NODES FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace WHERE workspace = :workspace
ORDER BY entity_type ORDER BY entity_type
""" """
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
results = await self.db.query(SQL, params, multirows=True) results = await self.db.query(SQL, params, multirows=True)
if results: if results:
labels = [row["entity_type"] for row in results] labels = [row["entity_type"] for row in results]
return labels return labels
@@ -762,26 +766,26 @@ class OracleGraphStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:
"""Retrieve a connected subgraph starting from nodes matching the given label """Retrieve a connected subgraph starting from nodes matching the given label
Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable. Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
Prioritizes nodes by: Prioritizes nodes by:
1. Nodes matching the specified label 1. Nodes matching the specified label
2. Nodes directly connected to matching nodes 2. Nodes directly connected to matching nodes
3. Node degree (number of connections) 3. Node degree (number of connections)
Args: Args:
node_label: Label to match for starting nodes (use "*" for all nodes) node_label: Label to match for starting nodes (use "*" for all nodes)
max_depth: Maximum depth of traversal from starting nodes max_depth: Maximum depth of traversal from starting nodes
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges
""" """
result = KnowledgeGraph() result = KnowledgeGraph()
try: try:
# Define maximum number of nodes to return # Define maximum number of nodes to return
max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000)) max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
if node_label == "*": if node_label == "*":
# For "*" label, get all nodes up to the limit # For "*" label, get all nodes up to the limit
nodes_sql = """ nodes_sql = """
@@ -791,30 +795,33 @@ class OracleGraphStorage(BaseGraphStorage):
ORDER BY id ORDER BY id
FETCH FIRST :limit ROWS ONLY FETCH FIRST :limit ROWS ONLY
""" """
nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes} nodes_params = {
"workspace": self.db.workspace,
"limit": max_graph_nodes,
}
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
else: else:
# For specific label, find matching nodes and related nodes # For specific label, find matching nodes and related nodes
nodes_sql = """ nodes_sql = """
WITH matching_nodes AS ( WITH matching_nodes AS (
SELECT name SELECT name
FROM LIGHTRAG_GRAPH_NODES FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace WHERE workspace = :workspace
AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%') AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
) )
SELECT n.name, n.entity_type, n.description, n.source_chunk_id, SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
CASE CASE
WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2 WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
WHEN EXISTS ( WHEN EXISTS (
SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
WHERE workspace = :workspace WHERE workspace = :workspace
AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes)) AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes))) OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
) THEN 1 ) THEN 1
ELSE 0 ELSE 0
END AS priority, END AS priority,
(SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
WHERE workspace = :workspace WHERE workspace = :workspace
AND (e.source_name = n.name OR e.target_name = n.name)) AS degree AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
FROM LIGHTRAG_GRAPH_NODES n FROM LIGHTRAG_GRAPH_NODES n
WHERE workspace = :workspace WHERE workspace = :workspace
@@ -822,43 +829,41 @@ class OracleGraphStorage(BaseGraphStorage):
FETCH FIRST :limit ROWS ONLY FETCH FIRST :limit ROWS ONLY
""" """
nodes_params = { nodes_params = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
"node_label": node_label, "node_label": node_label,
"limit": max_graph_nodes "limit": max_graph_nodes,
} }
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
if not nodes: if not nodes:
logger.warning(f"No nodes found matching '{node_label}'") logger.warning(f"No nodes found matching '{node_label}'")
return result return result
# Create mapping of node IDs to be used to filter edges # Create mapping of node IDs to be used to filter edges
node_names = [node["name"] for node in nodes] node_names = [node["name"] for node in nodes]
# Add nodes to result # Add nodes to result
seen_nodes = set() seen_nodes = set()
for node in nodes: for node in nodes:
node_id = node["name"] node_id = node["name"]
if node_id in seen_nodes: if node_id in seen_nodes:
continue continue
# Create node properties dictionary # Create node properties dictionary
properties = { properties = {
"entity_type": node["entity_type"], "entity_type": node["entity_type"],
"description": node["description"] or "", "description": node["description"] or "",
"source_id": node["source_chunk_id"] or "" "source_id": node["source_chunk_id"] or "",
} }
# Add node to result # Add node to result
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=node_id, id=node_id, labels=[node["entity_type"]], properties=properties
labels=[node["entity_type"]],
properties=properties
) )
) )
seen_nodes.add(node_id) seen_nodes.add(node_id)
# Get edges between these nodes # Get edges between these nodes
edges_sql = """ edges_sql = """
SELECT source_name, target_name, weight, keywords, description, source_chunk_id SELECT source_name, target_name, weight, keywords, description, source_chunk_id
@@ -868,30 +873,27 @@ class OracleGraphStorage(BaseGraphStorage):
AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
ORDER BY id ORDER BY id
""" """
edges_params = { edges_params = {"workspace": self.db.workspace, "node_names": node_names}
"workspace": self.db.workspace,
"node_names": node_names
}
edges = await self.db.query(edges_sql, edges_params, multirows=True) edges = await self.db.query(edges_sql, edges_params, multirows=True)
# Add edges to result # Add edges to result
seen_edges = set() seen_edges = set()
for edge in edges: for edge in edges:
source = edge["source_name"] source = edge["source_name"]
target = edge["target_name"] target = edge["target_name"]
edge_id = f"{source}-{target}" edge_id = f"{source}-{target}"
if edge_id in seen_edges: if edge_id in seen_edges:
continue continue
# Create edge properties dictionary # Create edge properties dictionary
properties = { properties = {
"weight": edge["weight"] or 0.0, "weight": edge["weight"] or 0.0,
"keywords": edge["keywords"] or "", "keywords": edge["keywords"] or "",
"description": edge["description"] or "", "description": edge["description"] or "",
"source_id": edge["source_chunk_id"] or "" "source_id": edge["source_chunk_id"] or "",
} }
# Add edge to result # Add edge to result
result.edges.append( result.edges.append(
KnowledgeGraphEdge( KnowledgeGraphEdge(
@@ -899,18 +901,18 @@ class OracleGraphStorage(BaseGraphStorage):
type="RELATED", type="RELATED",
source=source, source=source,
target=target, target=target,
properties=properties properties=properties,
) )
) )
seen_edges.add(edge_id) seen_edges.add(edge_id)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error retrieving knowledge graph: {e}") logger.error(f"Error retrieving knowledge graph: {e}")
return result return result
@@ -1166,8 +1168,8 @@ SQL_TEMPLATES = {
"delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})", "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
"delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name", "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
"delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)", "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
"delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
MATCH (a) MATCH (a)
WHERE a.workspace=:workspace AND a.name=:node_id WHERE a.workspace=:workspace AND a.name=:node_id
ACTION DELETE a)""", ACTION DELETE a)""",
} }

View File

@@ -527,11 +527,15 @@ class PGVectorStorage(BaseVectorStorage):
return return
ids_list = ",".join([f"'{id}'" for id in ids]) ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" delete_sql = (
f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
)
try: try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace}) await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
@@ -543,12 +547,11 @@ class PGVectorStorage(BaseVectorStorage):
""" """
try: try:
# Construct SQL to delete the entity # Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
WHERE workspace=$1 AND entity_name=$2""" WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute( await self.db.execute(
delete_sql, delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
{"workspace": self.db.workspace, "entity_name": entity_name}
) )
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e: except Exception as e:
@@ -562,12 +565,11 @@ class PGVectorStorage(BaseVectorStorage):
""" """
try: try:
# Delete relations where the entity is either the source or target # Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute( await self.db.execute(
delete_sql, delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
{"workspace": self.db.workspace, "entity_name": entity_name}
) )
logger.debug(f"Successfully deleted relations for entity {entity_name}") logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e: except Exception as e:
@@ -1167,7 +1169,9 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_ids (list[str]): A list of node IDs to remove. node_ids (list[str]): A list of node IDs to remove.
""" """
encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids] encoded_node_ids = [
self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
]
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1189,7 +1193,13 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
""" """
encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges] encoded_edges = [
(
self._encode_graph_label(src.strip('"')),
self._encode_graph_label(tgt.strip('"')),
)
for src, tgt in edges
]
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1211,10 +1221,13 @@ class PGGraphStorage(BaseGraphStorage):
Returns: Returns:
list[str]: A list of all labels in the graph. list[str]: A list of all labels in the graph.
""" """
query = """SELECT * FROM cypher('%s', $$ query = (
"""SELECT * FROM cypher('%s', $$
MATCH (n:Entity) MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label RETURN DISTINCT n.node_id AS label
$$) AS (label text)""" % self.graph_name $$) AS (label text)"""
% self.graph_name
)
results = await self._query(query) results = await self._query(query)
labels = [self._decode_graph_label(result["label"]) for result in results] labels = [self._decode_graph_label(result["label"]) for result in results]
@@ -1260,7 +1273,10 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH (n)-[r]->(m:Entity) OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m RETURN n, r, m
LIMIT %d LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES) $$) AS (n agtype, r agtype, m agtype)""" % (
self.graph_name,
MAX_GRAPH_NODES,
)
else: else:
encoded_node_label = self._encode_graph_label(node_label.strip('"')) encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
@@ -1268,7 +1284,12 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m) OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES) $$) AS (nodes agtype[], relationships agtype[])""" % (
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
results = await self._query(query) results = await self._query(query)
@@ -1305,29 +1326,6 @@ class PGGraphStorage(BaseGraphStorage):
return kg return kg
async def get_all_labels(self) -> list[str]:
"""
Get all node labels in the graph
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
ORDER BY label
$$) AS (label agtype)""" % (self.graph_name)
try:
results = await self._query(query)
labels = []
for record in results:
if record["label"]:
labels.append(self._decode_graph_label(record["label"]))
return labels
except Exception as e:
logger.error(f"Error getting all labels: {str(e)}")
return []
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the storage""" """Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"] drop_sql = SQL_TEMPLATES["drop_vdb_entity"]

View File

@@ -143,7 +143,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
async def delete(self, ids: List[str]) -> None: async def delete(self, ids: List[str]) -> None:
"""Delete vectors with specified IDs """Delete vectors with specified IDs
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
@@ -156,30 +156,34 @@ class QdrantVectorDBStorage(BaseVectorStorage):
points_selector=models.PointIdsList( points_selector=models.PointIdsList(
points=qdrant_ids, points=qdrant_ids,
), ),
wait=True wait=True,
)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
) )
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name """Delete an entity by name
Args: Args:
entity_name: Name of the entity to delete entity_name: Name of the entity to delete
""" """
try: try:
# Generate the entity ID # Generate the entity ID
entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity point from the collection # Delete the entity point from the collection
self._client.delete( self._client.delete(
collection_name=self.namespace, collection_name=self.namespace,
points_selector=models.PointIdsList( points_selector=models.PointIdsList(
points=[entity_id], points=[entity_id],
), ),
wait=True wait=True,
) )
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e: except Exception as e:
@@ -187,7 +191,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
Args: Args:
entity_name: Name of the entity whose relations should be deleted entity_name: Name of the entity whose relations should be deleted
""" """
@@ -198,23 +202,21 @@ class QdrantVectorDBStorage(BaseVectorStorage):
scroll_filter=models.Filter( scroll_filter=models.Filter(
should=[ should=[
models.FieldCondition( models.FieldCondition(
key="src_id", key="src_id", match=models.MatchValue(value=entity_name)
match=models.MatchValue(value=entity_name)
), ),
models.FieldCondition( models.FieldCondition(
key="tgt_id", key="tgt_id", match=models.MatchValue(value=entity_name)
match=models.MatchValue(value=entity_name) ),
)
] ]
), ),
with_payload=True, with_payload=True,
limit=1000 # Adjust as needed for your use case limit=1000, # Adjust as needed for your use case
) )
# Extract points that need to be deleted # Extract points that need to be deleted
relation_points = results[0] relation_points = results[0]
ids_to_delete = [point.id for point in relation_points] ids_to_delete = [point.id for point in relation_points]
if ids_to_delete: if ids_to_delete:
# Delete the relations # Delete the relations
self._client.delete( self._client.delete(
@@ -222,9 +224,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
points_selector=models.PointIdsList( points_selector=models.PointIdsList(
points=ids_to_delete, points=ids_to_delete,
), ),
wait=True wait=True,
)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
) )
logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}")
else: else:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(f"No relations found for entity {entity_name}")
except Exception as e: except Exception as e:

View File

@@ -67,35 +67,39 @@ class RedisKVStorage(BaseKVStorage):
async def delete(self, ids: list[str]) -> None:
    """Remove the entries stored under the given IDs.

    Every ID is mapped to the key ``"<namespace>:<id>"``. All deletions are
    queued on one pipeline and sent with a single ``execute()`` call.

    Args:
        ids: IDs of the entries to remove. Nothing happens for an empty list.
    """
    if ids:
        batch = self._redis.pipeline()
        for entry_id in ids:
            batch.delete(f"{self.namespace}:{entry_id}")

        replies = await batch.execute()
        # The per-command replies are summed to report how many entries
        # were actually removed.
        deleted_count = sum(replies)
        logger.info(
            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
        )
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name """Delete an entity by name
Args: Args:
entity_name: Name of the entity to delete entity_name: Name of the entity to delete
""" """
try: try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity # Delete the entity
result = await self._redis.delete(f"{self.namespace}:{entity_id}") result = await self._redis.delete(f"{self.namespace}:{entity_id}")
if result: if result:
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(f"Successfully deleted entity {entity_name}")
else: else:
@@ -105,7 +109,7 @@ class RedisKVStorage(BaseKVStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
Args: Args:
entity_name: Name of the entity whose relations should be deleted entity_name: Name of the entity whose relations should be deleted
""" """
@@ -114,29 +118,32 @@ class RedisKVStorage(BaseKVStorage):
cursor = 0 cursor = 0
relation_keys = [] relation_keys = []
pattern = f"{self.namespace}:*" pattern = f"{self.namespace}:*"
while True: while True:
cursor, keys = await self._redis.scan(cursor, match=pattern) cursor, keys = await self._redis.scan(cursor, match=pattern)
# For each key, get the value and check if it's related to entity_name # For each key, get the value and check if it's related to entity_name
for key in keys: for key in keys:
value = await self._redis.get(key) value = await self._redis.get(key)
if value: if value:
data = json.loads(value) data = json.loads(value)
# Check if this is a relation involving the entity # Check if this is a relation involving the entity
if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name: if (
data.get("src_id") == entity_name
or data.get("tgt_id") == entity_name
):
relation_keys.append(key) relation_keys.append(key)
# Exit loop when cursor returns to 0 # Exit loop when cursor returns to 0
if cursor == 0: if cursor == 0:
break break
# Delete the relation keys # Delete the relation keys
if relation_keys: if relation_keys:
deleted = await self._redis.delete(*relation_keys) deleted = await self._redis.delete(*relation_keys)
logger.debug(f"Deleted {deleted} relations for {entity_name}") logger.debug(f"Deleted {deleted} relations for {entity_name}")
else: else:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(f"No relations found for entity {entity_name}")
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(f"Error deleting relations for {entity_name}: {e}")

View File

@@ -567,62 +567,68 @@ class TiDBGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
    """Remove a node from the graph together with every edge touching it.

    Args:
        node_id: The ID of the node to delete
    """
    # Order matters: the edges that reference the node are removed first,
    # then the node row itself.
    for template_key in ("delete_node_edges", "delete_node"):
        await self.db.execute(
            SQL_TEMPLATES[template_key],
            {"name": node_id, "workspace": self.db.workspace},
        )

    logger.debug(
        f"Node {node_id} and its related edges have been deleted from the graph"
    )
async def get_all_labels(self) -> list[str]:
    """Return the distinct entity types (labels) stored for this workspace.

    Returns:
        The labels in the order the query yields them (the
        ``get_all_labels`` template orders by entity type), or an empty
        list when the query returns nothing.
    """
    rows = await self.db.query(
        SQL_TEMPLATES["get_all_labels"],
        {"workspace": self.db.workspace},
        multirows=True,
    )

    # A falsy result (no rows) maps to an empty label list.
    return [row["label"] for row in rows] if rows else []
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get a connected subgraph of nodes matching the specified label Get a connected subgraph of nodes matching the specified label
Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000) Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
Args: Args:
node_label: The node label to match node_label: The node label to match
max_depth: Maximum depth of the subgraph max_depth: Maximum depth of the subgraph
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges
""" """
result = KnowledgeGraph() result = KnowledgeGraph()
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Get matching nodes # Get matching nodes
if node_label == "*": if node_label == "*":
# Handle special case, get all nodes # Handle special case, get all nodes
node_results = await self.db.query( node_results = await self.db.query(
SQL_TEMPLATES["get_all_nodes"], SQL_TEMPLATES["get_all_nodes"],
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
multirows=True multirows=True,
) )
else: else:
# Get nodes matching the label # Get nodes matching the label
@@ -630,84 +636,93 @@ class TiDBGraphStorage(BaseGraphStorage):
node_results = await self.db.query( node_results = await self.db.query(
SQL_TEMPLATES["get_matching_nodes"], SQL_TEMPLATES["get_matching_nodes"],
{"workspace": self.db.workspace, "label_pattern": label_pattern}, {"workspace": self.db.workspace, "label_pattern": label_pattern},
multirows=True multirows=True,
) )
if not node_results: if not node_results:
logger.warning(f"No nodes found matching label {node_label}") logger.warning(f"No nodes found matching label {node_label}")
return result return result
# Limit the number of returned nodes # Limit the number of returned nodes
if len(node_results) > MAX_GRAPH_NODES: if len(node_results) > MAX_GRAPH_NODES:
node_results = node_results[:MAX_GRAPH_NODES] node_results = node_results[:MAX_GRAPH_NODES]
# Extract node names for edge query # Extract node names for edge query
node_names = [node["name"] for node in node_results] node_names = [node["name"] for node in node_results]
node_names_str = ",".join([f"'{name}'" for name in node_names]) node_names_str = ",".join([f"'{name}'" for name in node_names])
# Add nodes to result # Add nodes to result
for node in node_results: for node in node_results:
node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]} node_properties = {
k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
}
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=node["name"], id=node["name"],
labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]], labels=[node["entity_type"]]
properties=node_properties if node.get("entity_type")
else [node["name"]],
properties=node_properties,
) )
) )
# Get related edges # Get related edges
edge_results = await self.db.query( edge_results = await self.db.query(
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
{"workspace": self.db.workspace}, {"workspace": self.db.workspace},
multirows=True multirows=True,
) )
if edge_results: if edge_results:
# Add edges to result # Add edges to result
for edge in edge_results: for edge in edge_results:
# Only include edges related to selected nodes # Only include edges related to selected nodes
if edge["source_name"] in node_names and edge["target_name"] in node_names: if (
edge["source_name"] in node_names
and edge["target_name"] in node_names
):
edge_id = f"{edge['source_name']}-{edge['target_name']}" edge_id = f"{edge['source_name']}-{edge['target_name']}"
edge_properties = {k: v for k, v in edge.items() edge_properties = {
if k not in ["id", "source_name", "target_name"]} k: v
for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]
}
result.edges.append( result.edges.append(
KnowledgeGraphEdge( KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type="RELATED", type="RELATED",
source=edge["source_name"], source=edge["source_name"],
target=edge["target_name"], target=edge["target_name"],
properties=edge_properties properties=edge_properties,
) )
) )
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
return result return result
async def remove_nodes(self, nodes: list[str]): async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes """Delete multiple nodes
Args: Args:
nodes: List of node IDs to delete nodes: List of node IDs to delete
""" """
for node_id in nodes: for node_id in nodes:
await self.delete_node(node_id) await self.delete_node(node_id)
async def remove_edges(self, edges: list[tuple[str, str]]): async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges """Delete multiple edges
Args: Args:
edges: List of edges to delete, each edge is a (source, target) tuple edges: List of edges to delete, each edge is a (source, target) tuple
""" """
for source, target in edges: for source, target in edges:
await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], { await self.db.execute(
"source": source, SQL_TEMPLATES["remove_multiple_edges"],
"target": target, {"source": source, "target": target, "workspace": self.db.workspace},
"workspace": self.db.workspace )
})
N_T = { N_T = {
@@ -919,26 +934,26 @@ SQL_TEMPLATES = {
source_chunk_id = VALUES(source_chunk_id) source_chunk_id = VALUES(source_chunk_id)
""", """,
"delete_node": """ "delete_node": """
DELETE FROM LIGHTRAG_GRAPH_NODES DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE name = :name AND workspace = :workspace WHERE name = :name AND workspace = :workspace
""", """,
"delete_node_edges": """ "delete_node_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
""", """,
"get_all_labels": """ "get_all_labels": """
SELECT DISTINCT entity_type as label SELECT DISTINCT entity_type as label
FROM LIGHTRAG_GRAPH_NODES FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace WHERE workspace = :workspace
ORDER BY entity_type ORDER BY entity_type
""", """,
"get_matching_nodes": """ "get_matching_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE name LIKE :label_pattern AND workspace = :workspace WHERE name LIKE :label_pattern AND workspace = :workspace
ORDER BY name ORDER BY name
""", """,
"get_all_nodes": """ "get_all_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace WHERE workspace = :workspace
ORDER BY name ORDER BY name
LIMIT :max_nodes LIMIT :max_nodes
@@ -952,5 +967,5 @@ SQL_TEMPLATES = {
DELETE FROM LIGHTRAG_GRAPH_EDGES DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :source AND target_name = :target) WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace AND workspace = :workspace
""" """,
} }

View File

@@ -1401,40 +1401,54 @@ class LightRAG:
def delete_by_relation(self, source_entity: str, target_entity: str) -> None: def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Synchronously delete a relation between two entities. """Synchronously delete a relation between two entities.
Args: Args:
source_entity: Name of the source entity source_entity: Name of the source entity
target_entity: Name of the target entity target_entity: Name of the target entity
""" """
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity)) return loop.run_until_complete(
self.adelete_by_relation(source_entity, target_entity)
)
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Asynchronously delete a relation between two entities. """Asynchronously delete a relation between two entities.
Args: Args:
source_entity: Name of the source entity source_entity: Name of the source entity
target_entity: Name of the target entity target_entity: Name of the target entity
""" """
try: try:
# Check if the relation exists # Check if the relation exists
edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity) edge_exists = await self.chunk_entity_relation_graph.has_edge(
source_entity, target_entity
)
if not edge_exists: if not edge_exists:
logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist") logger.warning(
f"Relation from '{source_entity}' to '{target_entity}' does not exist"
)
return return
# Delete relation from vector database # Delete relation from vector database
relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-") relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
await self.relationships_vdb.delete([relation_id]) await self.relationships_vdb.delete([relation_id])
# Delete relation from knowledge graph # Delete relation from knowledge graph
await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)]) await self.chunk_entity_relation_graph.remove_edges(
[(source_entity, target_entity)]
logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'") )
logger.info(
f"Successfully deleted relation from '{source_entity}' to '{target_entity}'"
)
await self._delete_relation_done() await self._delete_relation_done()
except Exception as e: except Exception as e:
logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}") logger.error(
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
)
async def _delete_relation_done(self) -> None: async def _delete_relation_done(self) -> None:
"""Callback after relation deletion is complete""" """Callback after relation deletion is complete"""
await asyncio.gather( await asyncio.gather(