Fix linting

This commit is contained in:
yangdx
2025-02-26 18:11:16 +08:00
parent 7d12715f09
commit 7436c06f6c
11 changed files with 205 additions and 144 deletions

View File

@@ -78,29 +78,33 @@ class NetworkXStorage(BaseGraphStorage):
with self._storage_lock:
if is_multiprocess:
if self._graph.value is None:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file
)
self._graph.value = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
else:
if self._graph is None:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file
)
self._graph = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
"node2vec": self._node2vec_embed,
}
def _get_graph(self):
"""Get the appropriate graph instance based on multiprocess mode"""
if is_multiprocess:
@@ -248,11 +252,13 @@ class NetworkXStorage(BaseGraphStorage):
with self._storage_lock:
graph = self._get_graph()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
subgraph = graph.copy() # Create a copy to avoid modifying the original graph
subgraph = (
graph.copy()
) # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id (partial match)
nodes_to_explore = []
@@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage):
if len(subgraph.nodes()) > max_graph_nodes:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes
]
top_nodes = sorted(
node_degrees.items(), key=lambda x: x[1], reverse=True
)[:max_graph_nodes]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes
subgraph = subgraph.subgraph(top_node_ids)