fix linting

Author: zrguo
Date: 2025-03-04 15:53:20 +08:00
parent 3a2a636862
commit 81568f3bad

11 changed files with 394 additions and 327 deletions
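All of the changes below are mechanical formatting fixes: long calls and f-strings wrapped across lines, trailing commas added to multi-line argument lists, single quotes normalized to double quotes, long boolean conditions parenthesized, plus removal of two unused variables and one duplicated method. The commit does not say which tool produced them; as a rough sketch, assuming Ruff (which provides both an auto-fixer and a Black-compatible formatter) is the linter in use, the same kind of cleanup could be reproduced with:

    # Hypothetical helper, not part of this commit. Assumes the `ruff` CLI is
    # installed; substitute the project's actual lint/format commands if they differ.
    import subprocess

    def apply_lint_fixes(path: str = ".") -> None:
        # Auto-fix lint violations (e.g. unused variables), then reformat.
        subprocess.run(["ruff", "check", "--fix", path], check=True)
        subprocess.run(["ruff", "format", path], check=True)

    if __name__ == "__main__":
        apply_lint_fixes()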

View File

@@ -657,11 +657,13 @@ class AGEStorage(BaseGraphStorage):
         """
         params = {
             "src_label": AGEStorage._encode_graph_label(entity_name_label_source),
-            "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target)
+            "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
         }
         try:
             await self._query(query, **params)
-            logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'")
+            logger.debug(
+                f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'"
+            )
         except Exception as e:
             logger.error(f"Error during edge deletion: {str(e)}")
             raise
@@ -755,7 +757,7 @@ class AGEStorage(BaseGraphStorage):
                     KnowledgeGraphNode(
                         id=node_id,
                         labels=[node_label],
-                        properties=node_properties
+                        properties=node_properties,
                     )
                 )
                 seen_nodes.add(node_id)
@@ -785,7 +787,7 @@ class AGEStorage(BaseGraphStorage):
                         type="DIRECTED",
                         source=source,
                         target=target,
-                        properties=edge_properties
+                        properties=edge_properties,
                     )
                 )
                 seen_edges.add(edge_id)
@@ -809,9 +811,6 @@ class AGEStorage(BaseGraphStorage):
             # Traverse graph from each start node
             for start_node_record in start_nodes:
                 if "n" in start_node_record:
-                    start_node = start_node_record["n"]
-                    start_id = str(start_node.get("id", ""))
-
                     # Use BFS to traverse graph
                     query = """
                         MATCH (start:`{label}`)
@@ -830,14 +829,17 @@ class AGEStorage(BaseGraphStorage):
                     # Process nodes
                     for node in record["path_nodes"]:
                         node_id = str(node.get("id", ""))
-                        if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes:
+                        if (
+                            node_id not in seen_nodes
+                            and len(seen_nodes) < max_graph_nodes
+                        ):
                             node_properties = {k: v for k, v in node.items()}
                             node_label = node.get("label", "")
                             result.nodes.append(
                                 KnowledgeGraphNode(
                                     id=node_id,
                                     labels=[node_label],
-                                    properties=node_properties
+                                    properties=node_properties,
                                 )
                             )
                             seen_nodes.add(node_id)
@@ -856,7 +858,7 @@ class AGEStorage(BaseGraphStorage):
                                     type=rel.get("label", "DIRECTED"),
                                     source=source,
                                     target=target,
-                                    properties=edge_properties
+                                    properties=edge_properties,
                                 )
                             )
                             seen_edges.add(edge_id)

View File

@@ -223,7 +223,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         try:
             logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
             self._collection.delete(ids=ids)
-            logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+            logger.debug(
+                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
+            )
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
             raise

View File

@@ -413,7 +413,7 @@ class GremlinStorage(BaseGraphStorage):
             logger.debug(
                 "{%s}: Deleted node with entity_name '%s'",
                 inspect.currentframe().f_code.co_name,
-                entity_name
+                entity_name,
             )
         except Exception as e:
             logger.error(f"Error during node deletion: {str(e)}")
@@ -458,7 +458,7 @@ class GremlinStorage(BaseGraphStorage):
             logger.debug(
                 "{%s}: Retrieved %d labels",
                 inspect.currentframe().f_code.co_name,
-                len(labels)
+                len(labels),
             )
             return labels
         except Exception as e:
@@ -500,18 +500,20 @@ class GremlinStorage(BaseGraphStorage):
             # Add nodes to result
             for node_data in nodes_result:
-                node_id = node_data.get('entity_name', str(node_data.get('id', '')))
+                node_id = node_data.get("entity_name", str(node_data.get("id", "")))

                 if str(node_id) in seen_nodes:
                     continue

                 # Create node with properties
-                node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']}
+                node_properties = {
+                    k: v for k, v in node_data.items() if k not in ["id", "label"]
+                }

                 result.nodes.append(
                     KnowledgeGraphNode(
                         id=str(node_id),
                         labels=[str(node_id)],
-                        properties=node_properties
+                        properties=node_properties,
                     )
                 )
                 seen_nodes.add(str(node_id))
@@ -537,15 +539,19 @@ class GremlinStorage(BaseGraphStorage):
                     edge_data = path[1]
                     target = path[2]

-                    source_id = source.get('entity_name', str(source.get('id', '')))
-                    target_id = target.get('entity_name', str(target.get('id', '')))
+                    source_id = source.get("entity_name", str(source.get("id", "")))
+                    target_id = target.get("entity_name", str(target.get("id", "")))

                     edge_id = f"{source_id}-{target_id}"
                     if edge_id in seen_edges:
                         continue

                     # Create edge with properties
-                    edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']}
+                    edge_properties = {
+                        k: v
+                        for k, v in edge_data.items()
+                        if k not in ["id", "label"]
+                    }

                     result.edges.append(
                         KnowledgeGraphEdge(
@@ -553,7 +559,7 @@ class GremlinStorage(BaseGraphStorage):
                             type="DIRECTED",
                             source=str(source_id),
                             target=str(target_id),
-                            properties=edge_properties
+                            properties=edge_properties,
                         )
                     )
                     seen_edges.add(edge_id)
@@ -573,26 +579,32 @@ class GremlinStorage(BaseGraphStorage):
             # Add nodes to result
             for node_data in nodes_result:
-                node_id = node_data.get('entity_name', str(node_data.get('id', '')))
+                node_id = node_data.get("entity_name", str(node_data.get("id", "")))

                 if str(node_id) in seen_nodes:
                     continue

                 # Create node with properties
-                node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']}
+                node_properties = {
+                    k: v for k, v in node_data.items() if k not in ["id", "label"]
+                }

                 result.nodes.append(
                     KnowledgeGraphNode(
                         id=str(node_id),
                         labels=[str(node_id)],
-                        properties=node_properties
+                        properties=node_properties,
                     )
                 )
                 seen_nodes.add(str(node_id))

             # Get edges between the nodes in the result
             if nodes_result:
-                node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result]
-                node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids])
+                node_ids = [
+                    n.get("entity_name", str(n.get("id", ""))) for n in nodes_result
+                ]
+                node_ids_query = ", ".join(
+                    [GremlinStorage._to_value_map(nid) for nid in node_ids]
+                )

                 query = f"""g
                     .V().has('graph', {self.graph_name})
@@ -613,15 +625,19 @@ class GremlinStorage(BaseGraphStorage):
                     edge_data = path[1]
                     target = path[2]

-                    source_id = source.get('entity_name', str(source.get('id', '')))
-                    target_id = target.get('entity_name', str(target.get('id', '')))
+                    source_id = source.get("entity_name", str(source.get("id", "")))
+                    target_id = target.get("entity_name", str(target.get("id", "")))

                     edge_id = f"{source_id}-{target_id}"
                     if edge_id in seen_edges:
                         continue

                     # Create edge with properties
-                    edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']}
+                    edge_properties = {
+                        k: v
+                        for k, v in edge_data.items()
+                        if k not in ["id", "label"]
+                    }

                     result.edges.append(
                         KnowledgeGraphEdge(
@@ -629,7 +645,7 @@ class GremlinStorage(BaseGraphStorage):
                             type="DIRECTED",
                             source=str(source_id),
                             target=str(target_id),
-                            properties=edge_properties
+                            properties=edge_properties,
                         )
                     )
                     seen_edges.add(edge_id)
@@ -637,7 +653,7 @@ class GremlinStorage(BaseGraphStorage):
             logger.info(
                 "Subgraph query successful | Node count: %d | Edge count: %d",
                 len(result.nodes),
-                len(result.edges)
+                len(result.edges),
             )

             return result
@@ -674,7 +690,7 @@ class GremlinStorage(BaseGraphStorage):
                 "{%s}: Deleted edge from '%s' to '%s'",
                 inspect.currentframe().f_code.co_name,
                 entity_name_source,
-                entity_name_target
+                entity_name_target,
             )
         except Exception as e:
             logger.error(f"Error during edge deletion: {str(e)}")

View File

@@ -132,12 +132,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         try:
             # Compute entity ID from name
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
-            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+            logger.debug(
+                f"Attempting to delete entity {entity_name} with ID {entity_id}"
+            )

             # Delete the entity from Milvus collection
             result = self._client.delete(
-                collection_name=self.namespace,
-                pks=[entity_id]
+                collection_name=self.namespace, pks=[entity_id]
             )

             if result and result.get("delete_count", 0) > 0:
@@ -160,9 +161,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             # Find all relations involving this entity
             results = self._client.query(
-                collection_name=self.namespace,
-                filter=expr,
-                output_fields=["id"]
+                collection_name=self.namespace, filter=expr, output_fields=["id"]
             )

             if not results or len(results) == 0:
@@ -171,16 +170,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             # Extract IDs of relations to delete
             relation_ids = [item["id"] for item in results]
-            logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
+            logger.debug(
+                f"Found {len(relation_ids)} relations for entity {entity_name}"
+            )

             # Delete the relations
             if relation_ids:
                 delete_result = self._client.delete(
-                    collection_name=self.namespace,
-                    pks=relation_ids
+                    collection_name=self.namespace, pks=relation_ids
                 )
-                logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}")
+                logger.debug(
+                    f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
+                )

         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
@@ -193,13 +195,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         """
         try:
             # Delete vectors by IDs
-            result = self._client.delete(
-                collection_name=self.namespace,
-                pks=ids
-            )
+            result = self._client.delete(collection_name=self.namespace, pks=ids)

             if result and result.get("delete_count", 0) > 0:
-                logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}")
+                logger.debug(
+                    f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
+                )
             else:
                 logger.debug(f"No vectors were deleted from {self.namespace}")

View File

@@ -807,8 +807,7 @@ class MongoGraphStorage(BaseGraphStorage):
         # 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
         await self.collection.update_many(
-            {},
-            {"$pull": {"edges": {"target": {"$in": nodes}}}}
+            {}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
         )

         # 2. Delete the node documents
@@ -831,8 +830,7 @@ class MongoGraphStorage(BaseGraphStorage):
             # Remove edge pointing to target from source node's edges array
             update_tasks.append(
                 self.collection.update_one(
-                    {"_id": source},
-                    {"$pull": {"edges": {"target": target}}}
+                    {"_id": source}, {"$pull": {"edges": {"target": target}}}
                 )
             )
@@ -990,9 +988,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
         try:
             result = await self._data.delete_many({"_id": {"$in": ids}})
-            logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}")
+            logger.debug(
+                f"Successfully deleted {result.deleted_count} vectors from {self.namespace}"
+            )
         except PyMongoError as e:
-            logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}")
+            logger.error(
+                f"Error while deleting vectors from {self.namespace}: {str(e)}"
+            )

     async def delete_entity(self, entity_name: str) -> None:
         """Delete an entity by its name
@@ -1002,7 +1004,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
         """
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
-            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+            logger.debug(
+                f"Attempting to delete entity {entity_name} with ID {entity_id}"
+            )

             result = await self._data.delete_one({"_id": entity_id})
             if result.deleted_count > 0:
@@ -1031,7 +1035,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
             # Extract IDs of relations to delete
             relation_ids = [relation["_id"] for relation in relations]
-            logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}")
+            logger.debug(
+                f"Found {len(relation_ids)} relations for entity {entity_name}"
+            )

             # Delete the relations
             result = await self._data.delete_many({"_id": {"$in": relation_ids}})

View File

@@ -457,7 +457,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
             )
             params = {"workspace": self.db.workspace}
             await self.db.execute(SQL, params)
-            logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+            logger.info(
+                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
+            )
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
             raise
@@ -728,7 +730,9 @@ class OracleGraphStorage(BaseGraphStorage):
             params_node = {"workspace": self.db.workspace, "entity_name": node_id}
             await self.db.execute(delete_node_sql, params_node)
-            logger.info(f"Successfully deleted node {node_id} and all its relationships")
+            logger.info(
+                f"Successfully deleted node {node_id} and all its relationships"
+            )
         except Exception as e:
             logger.error(f"Error deleting node {node_id}: {e}")
             raise
@@ -791,7 +795,10 @@ class OracleGraphStorage(BaseGraphStorage):
                     ORDER BY id
                     FETCH FIRST :limit ROWS ONLY
                 """
-                nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes}
+                nodes_params = {
+                    "workspace": self.db.workspace,
+                    "limit": max_graph_nodes,
+                }
                 nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
             else:
                 # For specific label, find matching nodes and related nodes
@@ -824,7 +831,7 @@ class OracleGraphStorage(BaseGraphStorage):
                 nodes_params = {
                     "workspace": self.db.workspace,
                     "node_label": node_label,
-                    "limit": max_graph_nodes
+                    "limit": max_graph_nodes,
                 }
                 nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
@@ -846,15 +853,13 @@ class OracleGraphStorage(BaseGraphStorage):
                     properties = {
                         "entity_type": node["entity_type"],
                         "description": node["description"] or "",
-                        "source_id": node["source_chunk_id"] or ""
+                        "source_id": node["source_chunk_id"] or "",
                     }

                     # Add node to result
                     result.nodes.append(
                         KnowledgeGraphNode(
-                            id=node_id,
-                            labels=[node["entity_type"]],
-                            properties=properties
+                            id=node_id, labels=[node["entity_type"]], properties=properties
                         )
                     )
                     seen_nodes.add(node_id)
@@ -868,10 +873,7 @@ class OracleGraphStorage(BaseGraphStorage):
                     AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
                     ORDER BY id
                 """
-                edges_params = {
-                    "workspace": self.db.workspace,
-                    "node_names": node_names
-                }
+                edges_params = {"workspace": self.db.workspace, "node_names": node_names}
                 edges = await self.db.query(edges_sql, edges_params, multirows=True)

                 # Add edges to result
@@ -889,7 +891,7 @@ class OracleGraphStorage(BaseGraphStorage):
                         "weight": edge["weight"] or 0.0,
                         "keywords": edge["keywords"] or "",
                         "description": edge["description"] or "",
-                        "source_id": edge["source_chunk_id"] or ""
+                        "source_id": edge["source_chunk_id"] or "",
                     }

                     # Add edge to result
@@ -899,7 +901,7 @@ class OracleGraphStorage(BaseGraphStorage):
                             type="RELATED",
                             source=source,
                             target=target,
-                            properties=properties
+                            properties=properties,
                         )
                     )
                     seen_edges.add(edge_id)

View File

@@ -527,11 +527,15 @@ class PGVectorStorage(BaseVectorStorage):
             return

         ids_list = ",".join([f"'{id}'" for id in ids])
-        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
+        delete_sql = (
+            f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
+        )

         try:
             await self.db.execute(delete_sql, {"workspace": self.db.workspace})
-            logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+            logger.debug(
+                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
+            )
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
@@ -547,8 +551,7 @@ class PGVectorStorage(BaseVectorStorage):
                              WHERE workspace=$1 AND entity_name=$2"""
             await self.db.execute(
-                delete_sql,
-                {"workspace": self.db.workspace, "entity_name": entity_name}
+                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
             )
             logger.debug(f"Successfully deleted entity {entity_name}")
         except Exception as e:
@@ -566,8 +569,7 @@ class PGVectorStorage(BaseVectorStorage):
                              WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
             await self.db.execute(
-                delete_sql,
-                {"workspace": self.db.workspace, "entity_name": entity_name}
+                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
             )
             logger.debug(f"Successfully deleted relations for entity {entity_name}")
         except Exception as e:
@@ -1167,7 +1169,9 @@ class PGGraphStorage(BaseGraphStorage):
         Args:
             node_ids (list[str]): A list of node IDs to remove.
         """
-        encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids]
+        encoded_node_ids = [
+            self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
+        ]
         node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])

         query = """SELECT * FROM cypher('%s', $$
@@ -1189,7 +1193,13 @@ class PGGraphStorage(BaseGraphStorage):
         Args:
             edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
         """
-        encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges]
+        encoded_edges = [
+            (
+                self._encode_graph_label(src.strip('"')),
+                self._encode_graph_label(tgt.strip('"')),
+            )
+            for src, tgt in edges
+        ]
         edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])

         query = """SELECT * FROM cypher('%s', $$
@@ -1211,10 +1221,13 @@ class PGGraphStorage(BaseGraphStorage):
         Returns:
             list[str]: A list of all labels in the graph.
         """
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (n:Entity)
-            RETURN DISTINCT n.node_id AS label
-        $$) AS (label text)""" % self.graph_name
+        query = (
+            """SELECT * FROM cypher('%s', $$
+            MATCH (n:Entity)
+            RETURN DISTINCT n.node_id AS label
+        $$) AS (label text)"""
+            % self.graph_name
+        )

         results = await self._query(query)
         labels = [self._decode_graph_label(result["label"]) for result in results]
@@ -1260,7 +1273,10 @@ class PGGraphStorage(BaseGraphStorage):
                          OPTIONAL MATCH (n)-[r]->(m:Entity)
                          RETURN n, r, m
                          LIMIT %d
-                      $$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES)
+                      $$) AS (n agtype, r agtype, m agtype)""" % (
+                self.graph_name,
+                MAX_GRAPH_NODES,
+            )
         else:
             encoded_node_label = self._encode_graph_label(node_label.strip('"'))
             query = """SELECT * FROM cypher('%s', $$
@@ -1268,7 +1284,12 @@ class PGGraphStorage(BaseGraphStorage):
                          OPTIONAL MATCH p = (n)-[*..%d]-(m)
                          RETURN nodes(p) AS nodes, relationships(p) AS relationships
                          LIMIT %d
-                      $$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES)
+                      $$) AS (nodes agtype[], relationships agtype[])""" % (
+                self.graph_name,
+                encoded_node_label,
+                max_depth,
+                MAX_GRAPH_NODES,
+            )

         results = await self._query(query)
@@ -1305,29 +1326,6 @@ class PGGraphStorage(BaseGraphStorage):
         return kg

-    async def get_all_labels(self) -> list[str]:
-        """
-        Get all node labels in the graph
-        Returns:
-            [label1, label2, ...]  # Alphabetically sorted label list
-        """
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (n:Entity)
-            RETURN DISTINCT n.node_id AS label
-            ORDER BY label
-        $$) AS (label agtype)""" % (self.graph_name)
-
-        try:
-            results = await self._query(query)
-            labels = []
-            for record in results:
-                if record["label"]:
-                    labels.append(self._decode_graph_label(record["label"]))
-            return labels
-        except Exception as e:
-            logger.error(f"Error getting all labels: {str(e)}")
-            return []
-
     async def drop(self) -> None:
         """Drop the storage"""
         drop_sql = SQL_TEMPLATES["drop_vdb_entity"]

View File

@@ -156,9 +156,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                 points_selector=models.PointIdsList(
                     points=qdrant_ids,
                 ),
-                wait=True
+                wait=True,
             )
-            logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
+            logger.debug(
+                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
+            )
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
@@ -171,7 +173,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         try:
             # Generate the entity ID
             entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
-            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+            logger.debug(
+                f"Attempting to delete entity {entity_name} with ID {entity_id}"
+            )

             # Delete the entity point from the collection
             self._client.delete(
@@ -179,7 +183,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                 points_selector=models.PointIdsList(
                     points=[entity_id],
                 ),
-                wait=True
+                wait=True,
             )
             logger.debug(f"Successfully deleted entity {entity_name}")
         except Exception as e:
@@ -198,17 +202,15 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                 scroll_filter=models.Filter(
                     should=[
                         models.FieldCondition(
-                            key="src_id",
-                            match=models.MatchValue(value=entity_name)
+                            key="src_id", match=models.MatchValue(value=entity_name)
                         ),
                         models.FieldCondition(
-                            key="tgt_id",
-                            match=models.MatchValue(value=entity_name)
-                        )
+                            key="tgt_id", match=models.MatchValue(value=entity_name)
+                        ),
                     ]
                 ),
                 with_payload=True,
-                limit=1000  # Adjust as needed for your use case
+                limit=1000,  # Adjust as needed for your use case
             )

             # Extract points that need to be deleted
@@ -222,9 +224,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                     points_selector=models.PointIdsList(
                         points=ids_to_delete,
                     ),
-                    wait=True
+                    wait=True,
                 )
-                logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}")
+                logger.debug(
+                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
+                )
             else:
                 logger.debug(f"No relations found for entity {entity_name}")
         except Exception as e:

View File

@@ -80,7 +80,9 @@ class RedisKVStorage(BaseKVStorage):
             results = await pipe.execute()

         deleted_count = sum(results)
-        logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}")
+        logger.info(
+            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
+        )

     async def delete_entity(self, entity_name: str) -> None:
         """Delete an entity by name
@@ -91,7 +93,9 @@ class RedisKVStorage(BaseKVStorage):
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
-            logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+            logger.debug(
+                f"Attempting to delete entity {entity_name} with ID {entity_id}"
+            )

             # Delete the entity
             result = await self._redis.delete(f"{self.namespace}:{entity_id}")
@@ -124,7 +128,10 @@ class RedisKVStorage(BaseKVStorage):
                     if value:
                         data = json.loads(value)
                         # Check if this is a relation involving the entity
-                        if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name:
+                        if (
+                            data.get("src_id") == entity_name
+                            or data.get("tgt_id") == entity_name
+                        ):
                             relation_keys.append(key)

                 # Exit loop when cursor returns to 0

View File

@@ -572,14 +572,20 @@ class TiDBGraphStorage(BaseGraphStorage):
            node_id: The ID of the node to delete
        """
        # First delete all edges related to this node
-        await self.db.execute(SQL_TEMPLATES["delete_node_edges"],
-                              {"name": node_id, "workspace": self.db.workspace})
+        await self.db.execute(
+            SQL_TEMPLATES["delete_node_edges"],
+            {"name": node_id, "workspace": self.db.workspace},
+        )

        # Then delete the node itself
-        await self.db.execute(SQL_TEMPLATES["delete_node"],
-                              {"name": node_id, "workspace": self.db.workspace})
+        await self.db.execute(
+            SQL_TEMPLATES["delete_node"],
+            {"name": node_id, "workspace": self.db.workspace},
+        )

-        logger.debug(f"Node {node_id} and its related edges have been deleted from the graph")
+        logger.debug(
+            f"Node {node_id} and its related edges have been deleted from the graph"
+        )

    async def get_all_labels(self) -> list[str]:
        """Get all entity types (labels) in the database
@@ -590,7 +596,7 @@ class TiDBGraphStorage(BaseGraphStorage):
        result = await self.db.query(
            SQL_TEMPLATES["get_all_labels"],
            {"workspace": self.db.workspace},
-            multirows=True
+            multirows=True,
        )

        if not result:
@@ -622,7 +628,7 @@ class TiDBGraphStorage(BaseGraphStorage):
            node_results = await self.db.query(
                SQL_TEMPLATES["get_all_nodes"],
                {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
-                multirows=True
+                multirows=True,
            )
        else:
            # Get nodes matching the label
@@ -630,7 +636,7 @@ class TiDBGraphStorage(BaseGraphStorage):
            node_results = await self.db.query(
                SQL_TEMPLATES["get_matching_nodes"],
                {"workspace": self.db.workspace, "label_pattern": label_pattern},
-                multirows=True
+                multirows=True,
            )

            if not node_results:
@@ -647,12 +653,16 @@ class TiDBGraphStorage(BaseGraphStorage):
        # Add nodes to result
        for node in node_results:
-            node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]}
+            node_properties = {
+                k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
+            }

            result.nodes.append(
                KnowledgeGraphNode(
                    id=node["name"],
-                    labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]],
-                    properties=node_properties
+                    labels=[node["entity_type"]]
+                    if node.get("entity_type")
+                    else [node["name"]],
+                    properties=node_properties,
                )
            )
@@ -660,17 +670,23 @@ class TiDBGraphStorage(BaseGraphStorage):
            edge_results = await self.db.query(
                SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
                {"workspace": self.db.workspace},
-                multirows=True
+                multirows=True,
            )

            if edge_results:
                # Add edges to result
                for edge in edge_results:
                    # Only include edges related to selected nodes
-                    if edge["source_name"] in node_names and edge["target_name"] in node_names:
+                    if (
+                        edge["source_name"] in node_names
+                        and edge["target_name"] in node_names
+                    ):
                        edge_id = f"{edge['source_name']}-{edge['target_name']}"
-                        edge_properties = {k: v for k, v in edge.items()
-                                           if k not in ["id", "source_name", "target_name"]}
+                        edge_properties = {
+                            k: v
+                            for k, v in edge.items()
+                            if k not in ["id", "source_name", "target_name"]
+                        }

                        result.edges.append(
                            KnowledgeGraphEdge(
@@ -678,7 +694,7 @@ class TiDBGraphStorage(BaseGraphStorage):
                                type="RELATED",
                                source=edge["source_name"],
                                target=edge["target_name"],
-                                properties=edge_properties
+                                properties=edge_properties,
                            )
                        )
@@ -703,11 +719,10 @@ class TiDBGraphStorage(BaseGraphStorage):
            edges: List of edges to delete, each edge is a (source, target) tuple
        """
        for source, target in edges:
-            await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], {
-                "source": source,
-                "target": target,
-                "workspace": self.db.workspace
-            })
+            await self.db.execute(
+                SQL_TEMPLATES["remove_multiple_edges"],
+                {"source": source, "target": target, "workspace": self.db.workspace},
+            )


N_T = {
@@ -952,5 +967,5 @@ SQL_TEMPLATES = {
        DELETE FROM LIGHTRAG_GRAPH_EDGES
        WHERE (source_name = :source AND target_name = :target)
        AND workspace = :workspace
-    """
+    """,
}

View File

@@ -1407,7 +1407,9 @@ class LightRAG:
            target_entity: Name of the target entity
        """
        loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity))
+        return loop.run_until_complete(
+            self.adelete_by_relation(source_entity, target_entity)
+        )

    async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
        """Asynchronously delete a relation between two entities.
@@ -1418,22 +1420,34 @@ class LightRAG:
        """
        try:
            # Check if the relation exists
-            edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity)
+            edge_exists = await self.chunk_entity_relation_graph.has_edge(
+                source_entity, target_entity
+            )
            if not edge_exists:
-                logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist")
+                logger.warning(
+                    f"Relation from '{source_entity}' to '{target_entity}' does not exist"
+                )
                return

            # Delete relation from vector database
-            relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-")
+            relation_id = compute_mdhash_id(
+                source_entity + target_entity, prefix="rel-"
+            )
            await self.relationships_vdb.delete([relation_id])

            # Delete relation from knowledge graph
-            await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)])
+            await self.chunk_entity_relation_graph.remove_edges(
+                [(source_entity, target_entity)]
+            )

-            logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'")
+            logger.info(
+                f"Successfully deleted relation from '{source_entity}' to '{target_entity}'"
+            )
            await self._delete_relation_done()
        except Exception as e:
-            logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}")
+            logger.error(
+                f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
+            )

    async def _delete_relation_done(self) -> None:
        """Callback after relation deletion is complete"""