fix linting

This commit is contained in:
zrguo
2025-03-04 15:53:20 +08:00
parent 3a2a636862
commit 81568f3bad
11 changed files with 394 additions and 327 deletions

View File

@@ -657,11 +657,13 @@ class AGEStorage(BaseGraphStorage):
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target)
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
try:
await self._query(query, **params)
logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'")
logger.debug(
f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'"
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
@@ -755,7 +757,7 @@ class AGEStorage(BaseGraphStorage):
KnowledgeGraphNode(
id=node_id,
labels=[node_label],
properties=node_properties
properties=node_properties,
)
)
seen_nodes.add(node_id)
@@ -785,7 +787,7 @@ class AGEStorage(BaseGraphStorage):
type="DIRECTED",
source=source,
target=target,
properties=edge_properties
properties=edge_properties,
)
)
seen_edges.add(edge_id)
@@ -809,9 +811,6 @@ class AGEStorage(BaseGraphStorage):
# Traverse graph from each start node
for start_node_record in start_nodes:
if "n" in start_node_record:
start_node = start_node_record["n"]
start_id = str(start_node.get("id", ""))
# Use BFS to traverse graph
query = """
MATCH (start:`{label}`)
@@ -830,14 +829,17 @@ class AGEStorage(BaseGraphStorage):
# Process nodes
for node in record["path_nodes"]:
node_id = str(node.get("id", ""))
if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes:
if (
node_id not in seen_nodes
and len(seen_nodes) < max_graph_nodes
):
node_properties = {k: v for k, v in node.items()}
node_label = node.get("label", "")
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_label],
properties=node_properties
properties=node_properties,
)
)
seen_nodes.add(node_id)
@@ -856,7 +858,7 @@ class AGEStorage(BaseGraphStorage):
type=rel.get("label", "DIRECTED"),
source=source,
target=target,
properties=edge_properties
properties=edge_properties,
)
)
seen_edges.add(edge_id)

View File

@@ -223,7 +223,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids)
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise

View File

@@ -413,7 +413,7 @@ class GremlinStorage(BaseGraphStorage):
logger.debug(
"{%s}: Deleted node with entity_name '%s'",
inspect.currentframe().f_code.co_name,
entity_name
entity_name,
)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
@@ -458,7 +458,7 @@ class GremlinStorage(BaseGraphStorage):
logger.debug(
"{%s}: Retrieved %d labels",
inspect.currentframe().f_code.co_name,
len(labels)
len(labels),
)
return labels
except Exception as e:
@@ -500,18 +500,20 @@ class GremlinStorage(BaseGraphStorage):
# Add nodes to result
for node_data in nodes_result:
node_id = node_data.get('entity_name', str(node_data.get('id', '')))
node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes:
continue
# Create node with properties
node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']}
node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append(
KnowledgeGraphNode(
id=str(node_id),
labels=[str(node_id)],
properties=node_properties
properties=node_properties,
)
)
seen_nodes.add(str(node_id))
@@ -537,15 +539,19 @@ class GremlinStorage(BaseGraphStorage):
edge_data = path[1]
target = path[2]
source_id = source.get('entity_name', str(source.get('id', '')))
target_id = target.get('entity_name', str(target.get('id', '')))
source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges:
continue
# Create edge with properties
edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']}
edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append(
KnowledgeGraphEdge(
@@ -553,7 +559,7 @@ class GremlinStorage(BaseGraphStorage):
type="DIRECTED",
source=str(source_id),
target=str(target_id),
properties=edge_properties
properties=edge_properties,
)
)
seen_edges.add(edge_id)
@@ -573,26 +579,32 @@ class GremlinStorage(BaseGraphStorage):
# Add nodes to result
for node_data in nodes_result:
node_id = node_data.get('entity_name', str(node_data.get('id', '')))
node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes:
continue
# Create node with properties
node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']}
node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append(
KnowledgeGraphNode(
id=str(node_id),
labels=[str(node_id)],
properties=node_properties
properties=node_properties,
)
)
seen_nodes.add(str(node_id))
# Get edges between the nodes in the result
if nodes_result:
node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result]
node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids])
node_ids = [
n.get("entity_name", str(n.get("id", ""))) for n in nodes_result
]
node_ids_query = ", ".join(
[GremlinStorage._to_value_map(nid) for nid in node_ids]
)
query = f"""g
.V().has('graph', {self.graph_name})
@@ -613,15 +625,19 @@ class GremlinStorage(BaseGraphStorage):
edge_data = path[1]
target = path[2]
source_id = source.get('entity_name', str(source.get('id', '')))
target_id = target.get('entity_name', str(target.get('id', '')))
source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges:
continue
# Create edge with properties
edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']}
edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append(
KnowledgeGraphEdge(
@@ -629,7 +645,7 @@ class GremlinStorage(BaseGraphStorage):
type="DIRECTED",
source=str(source_id),
target=str(target_id),
properties=edge_properties
properties=edge_properties,
)
)
seen_edges.add(edge_id)
@@ -637,7 +653,7 @@ class GremlinStorage(BaseGraphStorage):
logger.info(
"Subgraph query successful | Node count: %d | Edge count: %d",
len(result.nodes),
len(result.edges)
len(result.edges),
)
return result
@@ -674,7 +690,7 @@ class GremlinStorage(BaseGraphStorage):
"{%s}: Deleted edge from '%s' to '%s'",
inspect.currentframe().f_code.co_name,
entity_name_source,
entity_name_target
entity_name_target,
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")

View File

@@ -132,12 +132,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
try:
# Compute entity ID from name
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity from Milvus collection
result = self._client.delete(
collection_name=self.namespace,
pks=[entity_id]
collection_name=self.namespace, pks=[entity_id]
)
if result and result.get("delete_count", 0) > 0:
@@ -160,9 +161,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Find all relations involving this entity
results = self._client.query(
collection_name=self.namespace,
filter=expr,
output_fields=["id"]
collection_name=self.namespace, filter=expr, output_fields=["id"]
)
if not results or len(results) == 0:
@@ -171,16 +170,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Extract IDs of relations to delete
relation_ids = [item["id"] for item in results]
logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
if relation_ids:
delete_result = self._client.delete(
collection_name=self.namespace,
pks=relation_ids
collection_name=self.namespace, pks=relation_ids
)
logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}")
logger.debug(
f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
)
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
@@ -193,13 +195,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"""
try:
# Delete vectors by IDs
result = self._client.delete(
collection_name=self.namespace,
pks=ids
)
result = self._client.delete(collection_name=self.namespace, pks=ids)
if result and result.get("delete_count", 0) > 0:
logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}")
logger.debug(
f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
)
else:
logger.debug(f"No vectors were deleted from {self.namespace}")

View File

@@ -807,8 +807,7 @@ class MongoGraphStorage(BaseGraphStorage):
# 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
await self.collection.update_many(
{},
{"$pull": {"edges": {"target": {"$in": nodes}}}}
{}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
)
# 2. Delete the node documents
@@ -831,8 +830,7 @@ class MongoGraphStorage(BaseGraphStorage):
# Remove edge pointing to target from source node's edges array
update_tasks.append(
self.collection.update_one(
{"_id": source},
{"$pull": {"edges": {"target": target}}}
{"_id": source}, {"$pull": {"edges": {"target": target}}}
)
)
@@ -990,9 +988,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
try:
result = await self._data.delete_many({"_id": {"$in": ids}})
logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}")
logger.debug(
f"Successfully deleted {result.deleted_count} vectors from {self.namespace}"
)
except PyMongoError as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}")
logger.error(
f"Error while deleting vectors from {self.namespace}: {str(e)}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its name
@@ -1002,7 +1004,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
"""
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
result = await self._data.delete_one({"_id": entity_id})
if result.deleted_count > 0:
@@ -1031,7 +1035,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
# Extract IDs of relations to delete
relation_ids = [relation["_id"] for relation in relations]
logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
result = await self._data.delete_many({"_id": {"$in": relation_ids}})

View File

@@ -457,7 +457,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
)
params = {"workspace": self.db.workspace}
await self.db.execute(SQL, params)
logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
logger.info(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise
@@ -728,7 +730,9 @@ class OracleGraphStorage(BaseGraphStorage):
params_node = {"workspace": self.db.workspace, "entity_name": node_id}
await self.db.execute(delete_node_sql, params_node)
logger.info(f"Successfully deleted node {node_id} and all its relationships")
logger.info(
f"Successfully deleted node {node_id} and all its relationships"
)
except Exception as e:
logger.error(f"Error deleting node {node_id}: {e}")
raise
@@ -791,7 +795,10 @@ class OracleGraphStorage(BaseGraphStorage):
ORDER BY id
FETCH FIRST :limit ROWS ONLY
"""
nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes}
nodes_params = {
"workspace": self.db.workspace,
"limit": max_graph_nodes,
}
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
else:
# For specific label, find matching nodes and related nodes
@@ -824,7 +831,7 @@ class OracleGraphStorage(BaseGraphStorage):
nodes_params = {
"workspace": self.db.workspace,
"node_label": node_label,
"limit": max_graph_nodes
"limit": max_graph_nodes,
}
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
@@ -846,15 +853,13 @@ class OracleGraphStorage(BaseGraphStorage):
properties = {
"entity_type": node["entity_type"],
"description": node["description"] or "",
"source_id": node["source_chunk_id"] or ""
"source_id": node["source_chunk_id"] or "",
}
# Add node to result
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node["entity_type"]],
properties=properties
id=node_id, labels=[node["entity_type"]], properties=properties
)
)
seen_nodes.add(node_id)
@@ -868,10 +873,7 @@ class OracleGraphStorage(BaseGraphStorage):
AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
ORDER BY id
"""
edges_params = {
"workspace": self.db.workspace,
"node_names": node_names
}
edges_params = {"workspace": self.db.workspace, "node_names": node_names}
edges = await self.db.query(edges_sql, edges_params, multirows=True)
# Add edges to result
@@ -889,7 +891,7 @@ class OracleGraphStorage(BaseGraphStorage):
"weight": edge["weight"] or 0.0,
"keywords": edge["keywords"] or "",
"description": edge["description"] or "",
"source_id": edge["source_chunk_id"] or ""
"source_id": edge["source_chunk_id"] or "",
}
# Add edge to result
@@ -899,7 +901,7 @@ class OracleGraphStorage(BaseGraphStorage):
type="RELATED",
source=source,
target=target,
properties=properties
properties=properties,
)
)
seen_edges.add(edge_id)

View File

@@ -527,11 +527,15 @@ class PGVectorStorage(BaseVectorStorage):
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
delete_sql = (
f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
)
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
@@ -547,8 +551,7 @@ class PGVectorStorage(BaseVectorStorage):
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
delete_sql,
{"workspace": self.db.workspace, "entity_name": entity_name}
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
@@ -566,8 +569,7 @@ class PGVectorStorage(BaseVectorStorage):
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
delete_sql,
{"workspace": self.db.workspace, "entity_name": entity_name}
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
@@ -1167,7 +1169,9 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_ids (list[str]): A list of node IDs to remove.
"""
encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids]
encoded_node_ids = [
self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
]
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$
@@ -1189,7 +1193,13 @@ class PGGraphStorage(BaseGraphStorage):
Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
"""
encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges]
encoded_edges = [
(
self._encode_graph_label(src.strip('"')),
self._encode_graph_label(tgt.strip('"')),
)
for src, tgt in edges
]
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
query = """SELECT * FROM cypher('%s', $$
@@ -1211,10 +1221,13 @@ class PGGraphStorage(BaseGraphStorage):
Returns:
list[str]: A list of all labels in the graph.
"""
query = """SELECT * FROM cypher('%s', $$
query = (
"""SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
$$) AS (label text)""" % self.graph_name
$$) AS (label text)"""
% self.graph_name
)
results = await self._query(query)
labels = [self._decode_graph_label(result["label"]) for result in results]
@@ -1260,7 +1273,10 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES)
$$) AS (n agtype, r agtype, m agtype)""" % (
self.graph_name,
MAX_GRAPH_NODES,
)
else:
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$
@@ -1268,7 +1284,12 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES)
$$) AS (nodes agtype[], relationships agtype[])""" % (
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
results = await self._query(query)
@@ -1305,29 +1326,6 @@ class PGGraphStorage(BaseGraphStorage):
return kg
async def get_all_labels(self) -> list[str]:
"""
Get all node labels in the graph
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
ORDER BY label
$$) AS (label agtype)""" % (self.graph_name)
try:
results = await self._query(query)
labels = []
for record in results:
if record["label"]:
labels.append(self._decode_graph_label(record["label"]))
return labels
except Exception as e:
logger.error(f"Error getting all labels: {str(e)}")
return []
async def drop(self) -> None:
"""Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]

View File

@@ -156,9 +156,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
points_selector=models.PointIdsList(
points=qdrant_ids,
),
wait=True
wait=True,
)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
@@ -171,7 +173,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
try:
# Generate the entity ID
entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity point from the collection
self._client.delete(
@@ -179,7 +183,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
points_selector=models.PointIdsList(
points=[entity_id],
),
wait=True
wait=True,
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
@@ -198,17 +202,15 @@ class QdrantVectorDBStorage(BaseVectorStorage):
scroll_filter=models.Filter(
should=[
models.FieldCondition(
key="src_id",
match=models.MatchValue(value=entity_name)
key="src_id", match=models.MatchValue(value=entity_name)
),
models.FieldCondition(
key="tgt_id",
match=models.MatchValue(value=entity_name)
)
key="tgt_id", match=models.MatchValue(value=entity_name)
),
]
),
with_payload=True,
limit=1000 # Adjust as needed for your use case
limit=1000, # Adjust as needed for your use case
)
# Extract points that need to be deleted
@@ -222,9 +224,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
points_selector=models.PointIdsList(
points=ids_to_delete,
),
wait=True
wait=True,
)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}")
else:
logger.debug(f"No relations found for entity {entity_name}")
except Exception as e:

View File

@@ -80,7 +80,9 @@ class RedisKVStorage(BaseKVStorage):
results = await pipe.execute()
deleted_count = sum(results)
logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}")
logger.info(
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name
@@ -91,7 +93,9 @@ class RedisKVStorage(BaseKVStorage):
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity
result = await self._redis.delete(f"{self.namespace}:{entity_id}")
@@ -124,7 +128,10 @@ class RedisKVStorage(BaseKVStorage):
if value:
data = json.loads(value)
# Check if this is a relation involving the entity
if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name:
if (
data.get("src_id") == entity_name
or data.get("tgt_id") == entity_name
):
relation_keys.append(key)
# Exit loop when cursor returns to 0

View File

@@ -572,14 +572,20 @@ class TiDBGraphStorage(BaseGraphStorage):
node_id: The ID of the node to delete
"""
# First delete all edges related to this node
await self.db.execute(SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace})
await self.db.execute(
SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace},
)
# Then delete the node itself
await self.db.execute(SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace})
await self.db.execute(
SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace},
)
logger.debug(f"Node {node_id} and its related edges have been deleted from the graph")
logger.debug(
f"Node {node_id} and its related edges have been deleted from the graph"
)
async def get_all_labels(self) -> list[str]:
"""Get all entity types (labels) in the database
@@ -590,7 +596,7 @@ class TiDBGraphStorage(BaseGraphStorage):
result = await self.db.query(
SQL_TEMPLATES["get_all_labels"],
{"workspace": self.db.workspace},
multirows=True
multirows=True,
)
if not result:
@@ -622,7 +628,7 @@ class TiDBGraphStorage(BaseGraphStorage):
node_results = await self.db.query(
SQL_TEMPLATES["get_all_nodes"],
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
multirows=True
multirows=True,
)
else:
# Get nodes matching the label
@@ -630,7 +636,7 @@ class TiDBGraphStorage(BaseGraphStorage):
node_results = await self.db.query(
SQL_TEMPLATES["get_matching_nodes"],
{"workspace": self.db.workspace, "label_pattern": label_pattern},
multirows=True
multirows=True,
)
if not node_results:
@@ -647,12 +653,16 @@ class TiDBGraphStorage(BaseGraphStorage):
# Add nodes to result
for node in node_results:
node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]}
node_properties = {
k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
}
result.nodes.append(
KnowledgeGraphNode(
id=node["name"],
labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]],
properties=node_properties
labels=[node["entity_type"]]
if node.get("entity_type")
else [node["name"]],
properties=node_properties,
)
)
@@ -660,17 +670,23 @@ class TiDBGraphStorage(BaseGraphStorage):
edge_results = await self.db.query(
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
{"workspace": self.db.workspace},
multirows=True
multirows=True,
)
if edge_results:
# Add edges to result
for edge in edge_results:
# Only include edges related to selected nodes
if edge["source_name"] in node_names and edge["target_name"] in node_names:
if (
edge["source_name"] in node_names
and edge["target_name"] in node_names
):
edge_id = f"{edge['source_name']}-{edge['target_name']}"
edge_properties = {k: v for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]}
edge_properties = {
k: v
for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]
}
result.edges.append(
KnowledgeGraphEdge(
@@ -678,7 +694,7 @@ class TiDBGraphStorage(BaseGraphStorage):
type="RELATED",
source=edge["source_name"],
target=edge["target_name"],
properties=edge_properties
properties=edge_properties,
)
)
@@ -703,11 +719,10 @@ class TiDBGraphStorage(BaseGraphStorage):
edges: List of edges to delete, each edge is a (source, target) tuple
"""
for source, target in edges:
await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], {
"source": source,
"target": target,
"workspace": self.db.workspace
})
await self.db.execute(
SQL_TEMPLATES["remove_multiple_edges"],
{"source": source, "target": target, "workspace": self.db.workspace},
)
N_T = {
@@ -952,5 +967,5 @@ SQL_TEMPLATES = {
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
"""
""",
}

View File

@@ -1407,7 +1407,9 @@ class LightRAG:
target_entity: Name of the target entity
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity))
return loop.run_until_complete(
self.adelete_by_relation(source_entity, target_entity)
)
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Asynchronously delete a relation between two entities.
@@ -1418,22 +1420,34 @@ class LightRAG:
"""
try:
# Check if the relation exists
edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity)
edge_exists = await self.chunk_entity_relation_graph.has_edge(
source_entity, target_entity
)
if not edge_exists:
logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist")
logger.warning(
f"Relation from '{source_entity}' to '{target_entity}' does not exist"
)
return
# Delete relation from vector database
relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-")
relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
await self.relationships_vdb.delete([relation_id])
# Delete relation from knowledge graph
await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)])
await self.chunk_entity_relation_graph.remove_edges(
[(source_entity, target_entity)]
)
logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'")
logger.info(
f"Successfully deleted relation from '{source_entity}' to '{target_entity}'"
)
await self._delete_relation_done()
except Exception as e:
logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}")
logger.error(
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
)
async def _delete_relation_done(self) -> None:
"""Callback after relation deletion is complete"""