linting fix

This commit is contained in:
Saifeddine ALOUI
2025-01-22 00:40:39 +01:00
parent f5fd8d5eac
commit 6db8b5bf79

View File

@@ -8,6 +8,7 @@ Version: 2.2
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Dict, List, Any from typing import Optional, Tuple, Dict, List, Any
import pipmaster as pm import pipmaster as pm
# Install all required dependencies # Install all required dependencies
REQUIRED_PACKAGES = [ REQUIRED_PACKAGES = [
"PyQt5", "PyQt5",
@@ -18,10 +19,9 @@ REQUIRED_PACKAGES = [
"networkx", "networkx",
"matplotlib", "matplotlib",
"python-louvain", "python-louvain",
"ascii_colors" "ascii_colors",
] ]
from ascii_colors import ASCIIColors, trace_exception
def setup_dependencies(): def setup_dependencies():
""" """
@@ -32,6 +32,7 @@ def setup_dependencies():
print(f"Installing {package}...") print(f"Installing {package}...")
pm.install(package) pm.install(package)
# Install dependencies # Install dependencies
setup_dependencies() setup_dependencies()
@@ -40,18 +41,32 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import community import community
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import (
QApplication, QMainWindow, QWidget, QVBoxLayout, QApplication,
QHBoxLayout, QPushButton, QFileDialog, QLabel, QMainWindow,
QMessageBox, QSpinBox, QComboBox, QCheckBox, QWidget,
QTableWidget, QTableWidgetItem, QSplitter, QDockWidget, QVBoxLayout,
QTextEdit QHBoxLayout,
QPushButton,
QFileDialog,
QLabel,
QMessageBox,
QSpinBox,
QComboBox,
QCheckBox,
QTableWidget,
QTableWidgetItem,
QSplitter,
QDockWidget,
QTextEdit,
) )
from PyQt5.QtCore import Qt from PyQt5.QtCore import Qt
import pyqtgraph.opengl as gl import pyqtgraph.opengl as gl
from ascii_colors import trace_exception
class Point: class Point:
"""Simple point class to handle coordinates""" """Simple point class to handle coordinates"""
def __init__(self, x: float, y: float): def __init__(self, x: float, y: float):
self.x = x self.x = x
self.y = y self.y = y
@@ -59,25 +74,33 @@ class Point:
class NodeState: class NodeState:
"""Data class for node visual state""" """Data class for node visual state"""
NORMAL_SCALE = 1.0 NORMAL_SCALE = 1.0
HOVER_SCALE = 1.2 HOVER_SCALE = 1.2
SELECTED_SCALE = 1.3 SELECTED_SCALE = 1.3
NORMAL_OPACITY = 0.8 NORMAL_OPACITY = 0.8
HOVER_OPACITY = 1.0 HOVER_OPACITY = 1.0
SELECTED_OPACITY = 1.0 SELECTED_OPACITY = 1.0
# Increase base node size (was 0.05) # Increase base node size (was 0.05)
BASE_SIZE = 0.2 BASE_SIZE = 0.2
SELECTED_COLOR = (1.0, 1.0, 0.0, 1.0) SELECTED_COLOR = (1.0, 1.0, 0.0, 1.0)
HOVER_COLOR = (1.0, 0.8, 0.0, 1.0) HOVER_COLOR = (1.0, 0.8, 0.0, 1.0)
class Node3D: class Node3D:
"""Class representing a 3D node in the graph""" """Class representing a 3D node in the graph"""
def __init__(self, position: np.ndarray, color: Tuple[float, float, float, float],
label: str, node_type: str, size: float): def __init__(
self,
position: np.ndarray,
color: Tuple[float, float, float, float],
label: str,
node_type: str,
size: float,
):
self.position = position self.position = position
self.base_color = color self.base_color = color
self.color = color self.color = color
@@ -119,26 +142,27 @@ class Node3D:
"""Update node visual appearance""" """Update node visual appearance"""
if self.mesh_item: if self.mesh_item:
self.mesh_item.setData( self.mesh_item.setData(
color=np.array([self.color]), color=np.array([self.color]), size=np.array([self.size * scale * 5])
size=np.array([self.size * scale * 5])
) )
class NodeDetailsWidget(QWidget): class NodeDetailsWidget(QWidget):
"""Widget to display node details""" """Widget to display node details"""
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self.init_ui() self.init_ui()
def init_ui(self): def init_ui(self):
"""Initialize the UI""" """Initialize the UI"""
layout = QVBoxLayout(self) layout = QVBoxLayout(self)
# Properties text edit # Properties text edit
self.properties = QTextEdit() self.properties = QTextEdit()
self.properties.setReadOnly(True) self.properties.setReadOnly(True)
layout.addWidget(QLabel("Properties:")) layout.addWidget(QLabel("Properties:"))
layout.addWidget(self.properties) layout.addWidget(self.properties)
# Connections table # Connections table
self.connections = QTableWidget() self.connections = QTableWidget()
self.connections.setColumnCount(3) self.connections.setColumnCount(3)
@@ -155,22 +179,23 @@ class NodeDetailsWidget(QWidget):
for key, value in node_data.items(): for key, value in node_data.items():
properties_text += f"{key}: {value}\n" properties_text += f"{key}: {value}\n"
self.properties.setText(properties_text) self.properties.setText(properties_text)
# Update connections # Update connections
self.connections.setRowCount(len(connections)) self.connections.setRowCount(len(connections))
for idx, (neighbor, edge_data) in enumerate(connections.items()): for idx, (neighbor, edge_data) in enumerate(connections.items()):
self.connections.setItem(idx, 0, QTableWidgetItem(str(neighbor))) self.connections.setItem(idx, 0, QTableWidgetItem(str(neighbor)))
self.connections.setItem( self.connections.setItem(
idx, 1, idx, 1, QTableWidgetItem(edge_data.get("relationship", "unknown"))
QTableWidgetItem(edge_data.get('relationship', 'unknown'))
) )
self.connections.setItem(idx, 2, QTableWidgetItem("outgoing")) self.connections.setItem(idx, 2, QTableWidgetItem("outgoing"))
class GraphMLViewer3D(QMainWindow): class GraphMLViewer3D(QMainWindow):
"""Main window class for 3D GraphML visualization""" """Main window class for 3D GraphML visualization"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.graph: Optional[nx.Graph] = None self.graph: Optional[nx.Graph] = None
self.nodes: Dict[str, Node3D] = {} self.nodes: Dict[str, Node3D] = {}
self.edges: List[gl.GLLinePlotItem] = [] self.edges: List[gl.GLLinePlotItem] = []
@@ -178,7 +203,7 @@ class GraphMLViewer3D(QMainWindow):
self.selected_node = None self.selected_node = None
self.communities = None self.communities = None
self.community_colors = None self.community_colors = None
self.mouse_pos_last = None self.mouse_pos_last = None
self.mouse_buttons_pressed = set() self.mouse_buttons_pressed = set()
self.distance = 20 # Initial camera distance self.distance = 20 # Initial camera distance
@@ -186,45 +211,44 @@ class GraphMLViewer3D(QMainWindow):
self.elevation = 30 # Initial camera elevation self.elevation = 30 # Initial camera elevation
self.azimuth = 45 # Initial camera azimuth self.azimuth = 45 # Initial camera azimuth
self.init_ui() self.init_ui()
def init_ui(self): def init_ui(self):
"""Initialize the user interface""" """Initialize the user interface"""
self.setWindowTitle("3D GraphML Viewer") self.setWindowTitle("3D GraphML Viewer")
self.setGeometry(100, 100, 1600, 900) self.setGeometry(100, 100, 1600, 900)
# Create main splitter # Create main splitter
self.main_splitter = QSplitter(Qt.Horizontal) self.main_splitter = QSplitter(Qt.Horizontal)
self.setCentralWidget(self.main_splitter) self.setCentralWidget(self.main_splitter)
# Create left panel for 3D view # Create left panel for 3D view
left_widget = QWidget() left_widget = QWidget()
left_layout = QVBoxLayout(left_widget) left_layout = QVBoxLayout(left_widget)
# Create controls # Create controls
self.create_toolbar(left_layout) self.create_toolbar(left_layout)
# Create 3D view # Create 3D view
self.view = gl.GLViewWidget() self.view = gl.GLViewWidget()
self.view.setMouseTracking(True) self.view.setMouseTracking(True)
# Connect mouse events # Connect mouse events
self.view.mousePressEvent = self.on_mouse_press self.view.mousePressEvent = self.on_mouse_press
self.view.mouseMoveEvent = self.on_mouse_move self.view.mouseMoveEvent = self.on_mouse_move
left_layout.addWidget(self.view) left_layout.addWidget(self.view)
self.main_splitter.addWidget(left_widget) self.main_splitter.addWidget(left_widget)
# Create details widget # Create details widget
self.details = NodeDetailsWidget() self.details = NodeDetailsWidget()
details_dock = QDockWidget("Node Details", self) details_dock = QDockWidget("Node Details", self)
details_dock.setWidget(self.details) details_dock.setWidget(self.details)
self.addDockWidget(Qt.RightDockWidgetArea, details_dock) self.addDockWidget(Qt.RightDockWidgetArea, details_dock)
# Add status bar # Add status bar
self.statusBar().showMessage("Ready") self.statusBar().showMessage("Ready")
# Add initial grid # Add initial grid
grid = gl.GLGridItem() grid = gl.GLGridItem()
grid.setSize(x=20, y=20, z=20) grid.setSize(x=20, y=20, z=20)
@@ -233,11 +257,9 @@ class GraphMLViewer3D(QMainWindow):
# Set initial camera position # Set initial camera position
self.view.setCameraPosition( self.view.setCameraPosition(
distance=self.distance, distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
elevation=self.elevation,
azimuth=self.azimuth
) )
# Connect all mouse events # Connect all mouse events
self.view.mousePressEvent = self.on_mouse_press self.view.mousePressEvent = self.on_mouse_press
self.view.mouseReleaseEvent = self.on_mouse_release self.view.mouseReleaseEvent = self.on_mouse_release
@@ -248,14 +270,14 @@ class GraphMLViewer3D(QMainWindow):
"""Calculate node sizes based on number of connections""" """Calculate node sizes based on number of connections"""
if not self.graph: if not self.graph:
return {} return {}
# Get degree (number of connections) for each node # Get degree (number of connections) for each node
degrees = dict(self.graph.degree()) degrees = dict(self.graph.degree())
# Calculate size scaling # Calculate size scaling
max_degree = max(degrees.values()) max_degree = max(degrees.values())
min_degree = min(degrees.values()) min_degree = min(degrees.values())
# Normalize sizes between 0.5 and 2.0 # Normalize sizes between 0.5 and 2.0
sizes = {} sizes = {}
for node, degree in degrees.items(): for node, degree in degrees.items():
@@ -265,31 +287,30 @@ class GraphMLViewer3D(QMainWindow):
# Normalize and scale size # Normalize and scale size
normalized = (degree - min_degree) / (max_degree - min_degree) normalized = (degree - min_degree) / (max_degree - min_degree)
sizes[node] = 0.5 + normalized * 1.5 sizes[node] = 0.5 + normalized * 1.5
return sizes
return sizes
def create_toolbar(self, layout: QVBoxLayout): def create_toolbar(self, layout: QVBoxLayout):
"""Create the toolbar with controls""" """Create the toolbar with controls"""
toolbar = QHBoxLayout() toolbar = QHBoxLayout()
# Load button # Load button
load_btn = QPushButton("Load GraphML") load_btn = QPushButton("Load GraphML")
load_btn.clicked.connect(self.load_graphml) load_btn.clicked.connect(self.load_graphml)
toolbar.addWidget(load_btn) toolbar.addWidget(load_btn)
# Reset view button # Reset view button
reset_btn = QPushButton("Reset View") reset_btn = QPushButton("Reset View")
reset_btn.clicked.connect(lambda: self.view.setCameraPosition(distance=20)) reset_btn.clicked.connect(lambda: self.view.setCameraPosition(distance=20))
toolbar.addWidget(reset_btn) toolbar.addWidget(reset_btn)
# Layout selector # Layout selector
self.layout_combo = QComboBox() self.layout_combo = QComboBox()
self.layout_combo.addItems(["Spring", "Circular", "Shell", "Random"]) self.layout_combo.addItems(["Spring", "Circular", "Shell", "Random"])
self.layout_combo.currentTextChanged.connect(self.refresh_layout) self.layout_combo.currentTextChanged.connect(self.refresh_layout)
toolbar.addWidget(QLabel("Layout:")) toolbar.addWidget(QLabel("Layout:"))
toolbar.addWidget(self.layout_combo) toolbar.addWidget(self.layout_combo)
# Node size control # Node size control
self.node_size = QSpinBox() self.node_size = QSpinBox()
self.node_size.setRange(1, 100) self.node_size.setRange(1, 100)
@@ -297,28 +318,26 @@ class GraphMLViewer3D(QMainWindow):
self.node_size.valueChanged.connect(self.refresh_layout) self.node_size.valueChanged.connect(self.refresh_layout)
toolbar.addWidget(QLabel("Node Size:")) toolbar.addWidget(QLabel("Node Size:"))
toolbar.addWidget(self.node_size) toolbar.addWidget(self.node_size)
# Show labels checkbox # Show labels checkbox
self.show_labels = QCheckBox("Show Labels") self.show_labels = QCheckBox("Show Labels")
self.show_labels.setChecked(True) self.show_labels.setChecked(True)
self.show_labels.stateChanged.connect(self.refresh_layout) self.show_labels.stateChanged.connect(self.refresh_layout)
toolbar.addWidget(self.show_labels) toolbar.addWidget(self.show_labels)
layout.addLayout(toolbar) layout.addLayout(toolbar)
# Reset view button # Reset view button
reset_btn = QPushButton("Reset View") reset_btn = QPushButton("Reset View")
reset_btn.clicked.connect(self.reset_view) # Use the new reset_view method reset_btn.clicked.connect(self.reset_view) # Use the new reset_view method
toolbar.addWidget(reset_btn) toolbar.addWidget(reset_btn)
def load_graphml(self) -> None: def load_graphml(self) -> None:
"""Load and visualize a GraphML file""" """Load and visualize a GraphML file"""
try: try:
file_path, _ = QFileDialog.getOpenFileName( file_path, _ = QFileDialog.getOpenFileName(
self, "Open GraphML file", "", "GraphML files (*.graphml)" self, "Open GraphML file", "", "GraphML files (*.graphml)"
) )
if file_path: if file_path:
self.graph = nx.read_graphml(Path(file_path)) self.graph = nx.read_graphml(Path(file_path))
self.refresh_layout() self.refresh_layout()
@@ -330,19 +349,15 @@ class GraphMLViewer3D(QMainWindow):
def calculate_layout(self) -> Dict[str, np.ndarray]: def calculate_layout(self) -> Dict[str, np.ndarray]:
"""Calculate node positions based on selected layout""" """Calculate node positions based on selected layout"""
layout_type = self.layout_combo.currentText().lower() layout_type = self.layout_combo.currentText().lower()
# Detect communities for coloring # Detect communities for coloring
self.communities = community.best_partition(self.graph) self.communities = community.best_partition(self.graph)
num_communities = len(set(self.communities.values())) num_communities = len(set(self.communities.values()))
self.community_colors = plt.cm.rainbow(np.linspace(0, 1, num_communities)) self.community_colors = plt.cm.rainbow(np.linspace(0, 1, num_communities))
if layout_type == "spring": if layout_type == "spring":
pos = nx.spring_layout( pos = nx.spring_layout(
self.graph, self.graph, dim=3, k=2.0, iterations=100, weight=None
dim=3,
k=2.0,
iterations=100,
weight=None
) )
elif layout_type == "circular": elif layout_type == "circular":
pos_2d = nx.circular_layout(self.graph) pos_2d = nx.circular_layout(self.graph)
@@ -355,7 +370,7 @@ class GraphMLViewer3D(QMainWindow):
pos = {node: np.array([x, y, 0.0]) for node, (x, y) in pos_2d.items()} pos = {node: np.array([x, y, 0.0]) for node, (x, y) in pos_2d.items()}
else: # random else: # random
pos = {node: np.random.rand(3) * 2 - 1 for node in self.graph.nodes()} pos = {node: np.random.rand(3) * 2 - 1 for node in self.graph.nodes()}
# Scale positions # Scale positions
positions = np.array(list(pos.values())) positions = np.array(list(pos.values()))
if len(positions) > 0: if len(positions) > 0:
@@ -365,56 +380,52 @@ class GraphMLViewer3D(QMainWindow):
def get_node_color(self, node_id: str) -> Tuple[float, float, float, float]: def get_node_color(self, node_id: str) -> Tuple[float, float, float, float]:
"""Get RGBA color based on community""" """Get RGBA color based on community"""
if hasattr(self, 'communities') and node_id in self.communities: if hasattr(self, "communities") and node_id in self.communities:
comm_id = self.communities[node_id] comm_id = self.communities[node_id]
color = self.community_colors[comm_id] color = self.community_colors[comm_id]
return tuple(color) return tuple(color)
return (0.5, 0.5, 0.5, 0.8) return (0.5, 0.5, 0.5, 0.8)
def create_node(self, node_id: str, position: np.ndarray, node_type: str) -> Node3D: def create_node(self, node_id: str, position: np.ndarray, node_type: str) -> Node3D:
"""Create a 3D node with interaction capabilities""" """Create a 3D node with interaction capabilities"""
color = self.get_node_color(node_id) color = self.get_node_color(node_id)
# Get size multiplier based on connections # Get size multiplier based on connections
size_multiplier = self.node_sizes.get(node_id, 1.0) size_multiplier = self.node_sizes.get(node_id, 1.0)
size = NodeState.BASE_SIZE * self.node_size.value() / 50.0 * size_multiplier size = NodeState.BASE_SIZE * self.node_size.value() / 50.0 * size_multiplier
node = Node3D(position, color, str(node_id), node_type, size) node = Node3D(position, color, str(node_id), node_type, size)
node.mesh_item = gl.GLScatterPlotItem( node.mesh_item = gl.GLScatterPlotItem(
pos=np.array([position]), pos=np.array([position]),
size=np.array([size * 8]), size=np.array([size * 8]),
color=np.array([color]), color=np.array([color]),
pxMode=False pxMode=False,
) )
# Enable picking and set node ID # Enable picking and set node ID
node.mesh_item.setGLOptions('translucent') node.mesh_item.setGLOptions("translucent")
node.mesh_item.node_id = node_id node.mesh_item.node_id = node_id
if self.show_labels.isChecked(): if self.show_labels.isChecked():
node.label_item = gl.GLTextItem( node.label_item = gl.GLTextItem(
pos=position, pos=position,
text=str(node_id), text=str(node_id),
color=(1, 1, 1, 1), color=(1, 1, 1, 1),
) )
return node return node
def mapToView(self, pos) -> Point: def mapToView(self, pos) -> Point:
"""Convert screen coordinates to world coordinates""" """Convert screen coordinates to world coordinates"""
# Get the viewport size # Get the viewport size
width = self.view.width() width = self.view.width()
height = self.view.height() height = self.view.height()
# Normalize coordinates # Normalize coordinates
x = (pos.x() / width - 0.5) * 20 # Scale factor of 20 matches the grid size x = (pos.x() / width - 0.5) * 20 # Scale factor of 20 matches the grid size
y = -(pos.y() / height - 0.5) * 20 y = -(pos.y() / height - 0.5) * 20
return Point(x, y) return Point(x, y)
def on_mouse_move(self, event): def on_mouse_move(self, event):
@@ -422,59 +433,57 @@ class GraphMLViewer3D(QMainWindow):
if self.mouse_pos_last is None: if self.mouse_pos_last is None:
self.mouse_pos_last = event.pos() self.mouse_pos_last = event.pos()
return return
pos = event.pos() pos = event.pos()
dx = pos.x() - self.mouse_pos_last.x() dx = pos.x() - self.mouse_pos_last.x()
dy = pos.y() - self.mouse_pos_last.y() dy = pos.y() - self.mouse_pos_last.y()
# Handle right button drag for panning # Handle right button drag for panning
if Qt.RightButton in self.mouse_buttons_pressed: if Qt.RightButton in self.mouse_buttons_pressed:
# Scale the pan amount based on view distance # Scale the pan amount based on view distance
scale = self.distance / 1000.0 scale = self.distance / 1000.0
# Calculate pan in view coordinates # Calculate pan in view coordinates
right = np.cross([0, 0, 1], self.view.cameraPosition()) right = np.cross([0, 0, 1], self.view.cameraPosition())
right = right / np.linalg.norm(right) right = right / np.linalg.norm(right)
up = np.cross(self.view.cameraPosition(), right) up = np.cross(self.view.cameraPosition(), right)
up = up / np.linalg.norm(up) up = up / np.linalg.norm(up)
pan = -right * dx * scale + up * dy * scale pan = -right * dx * scale + up * dy * scale
self.center += pan self.center += pan
self.view.pan(dx, dy, 0) self.view.pan(dx, dy, 0)
# Handle middle button drag for rotation # Handle middle button drag for rotation
elif Qt.MiddleButton in self.mouse_buttons_pressed: elif Qt.MiddleButton in self.mouse_buttons_pressed:
self.azimuth += dx * 0.5 # Adjust rotation speed as needed self.azimuth += dx * 0.5 # Adjust rotation speed as needed
self.elevation -= dy * 0.5 self.elevation -= dy * 0.5
# Clamp elevation to prevent gimbal lock # Clamp elevation to prevent gimbal lock
self.elevation = np.clip(self.elevation, -89, 89) self.elevation = np.clip(self.elevation, -89, 89)
self.view.setCameraPosition( self.view.setCameraPosition(
distance=self.distance, distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
elevation=self.elevation,
azimuth=self.azimuth
) )
# Handle hover events when no buttons are pressed # Handle hover events when no buttons are pressed
elif not self.mouse_buttons_pressed: elif not self.mouse_buttons_pressed:
# Get the mouse position in world coordinates # Get the mouse position in world coordinates
mouse_pos = self.mapToView(pos) mouse_pos = self.mapToView(pos)
# Check for hover # Check for hover
min_dist = float('inf') min_dist = float("inf")
hovered_node = None hovered_node = None
for node_id, node in self.nodes.items(): for node_id, node in self.nodes.items():
# Calculate distance to mouse in world coordinates # Calculate distance to mouse in world coordinates
dx = mouse_pos.x - node.position[0] dx = mouse_pos.x - node.position[0]
dy = mouse_pos.y - node.position[1] dy = mouse_pos.y - node.position[1]
dist = np.sqrt(dx*dx + dy*dy) dist = np.sqrt(dx * dx + dy * dy)
if dist < min_dist and dist < 0.5: # Adjust threshold as needed if dist < min_dist and dist < 0.5: # Adjust threshold as needed
min_dist = dist min_dist = dist
hovered_node = node_id hovered_node = node_id
# Update hover states # Update hover states
for node_id, node in self.nodes.items(): for node_id, node in self.nodes.items():
if node_id == hovered_node: if node_id == hovered_node:
@@ -483,91 +492,88 @@ class GraphMLViewer3D(QMainWindow):
else: else:
if not node.is_selected: if not node.is_selected:
node.unhighlight() node.unhighlight()
self.mouse_pos_last = pos self.mouse_pos_last = pos
def on_mouse_press(self, event): def on_mouse_press(self, event):
"""Handle mouse press events""" """Handle mouse press events"""
self.mouse_pos_last = event.pos() self.mouse_pos_last = event.pos()
self.mouse_buttons_pressed.add(event.button()) self.mouse_buttons_pressed.add(event.button())
# Handle left click for node selection # Handle left click for node selection
if event.button() == Qt.LeftButton: if event.button() == Qt.LeftButton:
pos = event.pos() pos = event.pos()
mouse_pos = self.mapToView(pos) mouse_pos = self.mapToView(pos)
# Find closest node # Find closest node
min_dist = float('inf') min_dist = float("inf")
clicked_node = None clicked_node = None
for node_id, node in self.nodes.items(): for node_id, node in self.nodes.items():
dx = mouse_pos.x - node.position[0] dx = mouse_pos.x - node.position[0]
dy = mouse_pos.y - node.position[1] dy = mouse_pos.y - node.position[1]
dist = np.sqrt(dx*dx + dy*dy) dist = np.sqrt(dx * dx + dy * dy)
if dist < min_dist and dist < 0.5: # Adjust threshold as needed if dist < min_dist and dist < 0.5: # Adjust threshold as needed
min_dist = dist min_dist = dist
clicked_node = node_id clicked_node = node_id
# Handle selection # Handle selection
if clicked_node: if clicked_node:
if self.selected_node and self.selected_node in self.nodes: if self.selected_node and self.selected_node in self.nodes:
self.nodes[self.selected_node].deselect() self.nodes[self.selected_node].deselect()
self.nodes[clicked_node].select() self.nodes[clicked_node].select()
self.selected_node = clicked_node self.selected_node = clicked_node
if self.graph: if self.graph:
self.details.update_node_info( self.details.update_node_info(
self.graph.nodes[clicked_node], self.graph.nodes[clicked_node], self.graph[clicked_node]
self.graph[clicked_node]
) )
def on_mouse_release(self, event): def on_mouse_release(self, event):
"""Handle mouse release events""" """Handle mouse release events"""
self.mouse_buttons_pressed.discard(event.button()) self.mouse_buttons_pressed.discard(event.button())
self.mouse_pos_last = None self.mouse_pos_last = None
def on_mouse_wheel(self, event): def on_mouse_wheel(self, event):
"""Handle mouse wheel for zooming""" """Handle mouse wheel for zooming"""
delta = event.angleDelta().y() delta = event.angleDelta().y()
# Adjust zoom speed based on current distance # Adjust zoom speed based on current distance
zoom_speed = self.distance / 100.0 zoom_speed = self.distance / 100.0
# Update distance with limits # Update distance with limits
self.distance -= delta * zoom_speed self.distance -= delta * zoom_speed
self.distance = np.clip(self.distance, 1.0, 100.0) self.distance = np.clip(self.distance, 1.0, 100.0)
self.view.setCameraPosition( self.view.setCameraPosition(
distance=self.distance, distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
elevation=self.elevation,
azimuth=self.azimuth
) )
def reset_view(self): def reset_view(self):
"""Reset camera to default position""" """Reset camera to default position"""
self.distance = 20 self.distance = 20
self.elevation = 30 self.elevation = 30
self.azimuth = 45 self.azimuth = 45
self.center = np.array([0, 0, 0]) self.center = np.array([0, 0, 0])
self.view.setCameraPosition( self.view.setCameraPosition(
distance=self.distance, distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
elevation=self.elevation,
azimuth=self.azimuth
) )
def create_edge(
def create_edge(self, start_pos: np.ndarray, end_pos: np.ndarray, self,
color: Tuple[float, float, float, float] = (0.3, 0.3, 0.3, 0.2) start_pos: np.ndarray,
) -> gl.GLLinePlotItem: end_pos: np.ndarray,
color: Tuple[float, float, float, float] = (0.3, 0.3, 0.3, 0.2),
) -> gl.GLLinePlotItem:
"""Create a 3D edge between nodes""" """Create a 3D edge between nodes"""
return gl.GLLinePlotItem( return gl.GLLinePlotItem(
pos=np.array([start_pos, end_pos]), pos=np.array([start_pos, end_pos]),
color=color, color=color,
width=1, width=1,
antialias=True, antialias=True,
mode='lines' mode="lines",
) )
def handle_node_hover(self, event: Any, node_id: str) -> None: def handle_node_hover(self, event: Any, node_id: str) -> None:
@@ -585,58 +591,57 @@ class GraphMLViewer3D(QMainWindow):
"""Handle node click events""" """Handle node click events"""
if event.button() != Qt.LeftButton or node_id not in self.nodes: if event.button() != Qt.LeftButton or node_id not in self.nodes:
return return
if self.selected_node and self.selected_node in self.nodes: if self.selected_node and self.selected_node in self.nodes:
self.nodes[self.selected_node].deselect() self.nodes[self.selected_node].deselect()
node = self.nodes[node_id] node = self.nodes[node_id]
node.select() node.select()
self.selected_node = node_id self.selected_node = node_id
if self.graph: if self.graph:
self.details.update_node_info( self.details.update_node_info(
self.graph.nodes[node_id], self.graph.nodes[node_id], self.graph[node_id]
self.graph[node_id]
) )
def refresh_layout(self) -> None: def refresh_layout(self) -> None:
"""Refresh the graph visualization""" """Refresh the graph visualization"""
if not self.graph: if not self.graph:
return return
self.positions = self.calculate_layout() self.positions = self.calculate_layout()
self.node_sizes = self.calculate_node_sizes() self.node_sizes = self.calculate_node_sizes()
self.view.clear() self.view.clear()
self.nodes.clear() self.nodes.clear()
self.edges.clear() self.edges.clear()
self.edge_labels.clear() self.edge_labels.clear()
grid = gl.GLGridItem() grid = gl.GLGridItem()
grid.setSize(x=20, y=20, z=20) grid.setSize(x=20, y=20, z=20)
grid.setSpacing(x=1, y=1, z=1) grid.setSpacing(x=1, y=1, z=1)
self.view.addItem(grid) self.view.addItem(grid)
positions = self.calculate_layout() positions = self.calculate_layout()
for node_id in self.graph.nodes(): for node_id in self.graph.nodes():
node_type = self.graph.nodes[node_id].get('type', 'default') node_type = self.graph.nodes[node_id].get("type", "default")
node = self.create_node(node_id, positions[node_id], node_type) node = self.create_node(node_id, positions[node_id], node_type)
self.view.addItem(node.mesh_item) self.view.addItem(node.mesh_item)
if node.label_item: if node.label_item:
self.view.addItem(node.label_item) self.view.addItem(node.label_item)
self.nodes[node_id] = node self.nodes[node_id] = node
for source, target in self.graph.edges(): for source, target in self.graph.edges():
edge = self.create_edge(positions[source], positions[target]) edge = self.create_edge(positions[source], positions[target])
self.view.addItem(edge) self.view.addItem(edge)
self.edges.append(edge) self.edges.append(edge)
if self.show_labels.isChecked(): if self.show_labels.isChecked():
mid_point = (positions[source] + positions[target]) / 2 mid_point = (positions[source] + positions[target]) / 2
relationship = self.graph.edges[source, target].get('relationship', '') relationship = self.graph.edges[source, target].get("relationship", "")
if relationship: if relationship:
label = gl.GLTextItem( label = gl.GLTextItem(
pos=mid_point, pos=mid_point,
@@ -646,14 +651,16 @@ class GraphMLViewer3D(QMainWindow):
self.view.addItem(label) self.view.addItem(label)
self.edge_labels.append(label) self.edge_labels.append(label)
def main(): def main():
"""Application entry point""" """Application entry point"""
import sys import sys
app = QApplication(sys.argv) app = QApplication(sys.argv)
viewer = GraphMLViewer3D() viewer = GraphMLViewer3D()
viewer.show() viewer.show()
sys.exit(app.exec_()) sys.exit(app.exec_())
if __name__ == "__main__": if __name__ == "__main__":
main() main()