fix linting

This commit is contained in:
zrguo
2025-03-04 15:53:20 +08:00
parent 3a2a636862
commit 81568f3bad
11 changed files with 394 additions and 327 deletions

View File

@@ -527,11 +527,15 @@ class PGVectorStorage(BaseVectorStorage):
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
delete_sql = (
f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
)
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
@@ -543,12 +547,11 @@ class PGVectorStorage(BaseVectorStorage):
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
delete_sql,
{"workspace": self.db.workspace, "entity_name": entity_name}
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
@@ -562,12 +565,11 @@ class PGVectorStorage(BaseVectorStorage):
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
delete_sql,
{"workspace": self.db.workspace, "entity_name": entity_name}
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
@@ -1167,7 +1169,9 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_ids (list[str]): A list of node IDs to remove.
"""
encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids]
encoded_node_ids = [
self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
]
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$
@@ -1189,7 +1193,13 @@ class PGGraphStorage(BaseGraphStorage):
Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
"""
encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges]
encoded_edges = [
(
self._encode_graph_label(src.strip('"')),
self._encode_graph_label(tgt.strip('"')),
)
for src, tgt in edges
]
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
query = """SELECT * FROM cypher('%s', $$
@@ -1211,10 +1221,13 @@ class PGGraphStorage(BaseGraphStorage):
Returns:
list[str]: A list of all labels in the graph.
"""
query = """SELECT * FROM cypher('%s', $$
query = (
"""SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
$$) AS (label text)""" % self.graph_name
$$) AS (label text)"""
% self.graph_name
)
results = await self._query(query)
labels = [self._decode_graph_label(result["label"]) for result in results]
@@ -1260,7 +1273,10 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES)
$$) AS (n agtype, r agtype, m agtype)""" % (
self.graph_name,
MAX_GRAPH_NODES,
)
else:
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$
@@ -1268,7 +1284,12 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES)
$$) AS (nodes agtype[], relationships agtype[])""" % (
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
results = await self._query(query)
@@ -1305,29 +1326,6 @@ class PGGraphStorage(BaseGraphStorage):
return kg
async def get_all_labels(self) -> list[str]:
"""
Get all node labels in the graph
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
ORDER BY label
$$) AS (label agtype)""" % (self.graph_name)
try:
results = await self._query(query)
labels = []
for record in results:
if record["label"]:
labels.append(self._decode_graph_label(record["label"]))
return labels
except Exception as e:
logger.error(f"Error getting all labels: {str(e)}")
return []
async def drop(self) -> None:
"""Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]