diff --git a/README.md b/README.md index 950c5c5a..dd570608 100644 --- a/README.md +++ b/README.md @@ -799,6 +799,7 @@ if __name__ == "__main__": | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | | **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | +|**log\_dir** | `str` | Directory to store logs. | `./` | ### Error Handling diff --git a/extra/VisualizationTOol/GraphVisualizer.py b/extra/VisualizationTOol/GraphVisualizer.py deleted file mode 100644 index 3a933430..00000000 --- a/extra/VisualizationTOol/GraphVisualizer.py +++ /dev/null @@ -1,666 +0,0 @@ -""" -3D GraphML Viewer -Author: LoLLMs -Description: An interactive 3D GraphML viewer using PyQt5 and pyqtgraph -Version: 2.2 -""" - -from pathlib import Path -from typing import Optional, Tuple, Dict, List, Any -import pipmaster as pm - -# Install all required dependencies -REQUIRED_PACKAGES = [ - "PyQt5", - "pyqtgraph", - "numpy", - "PyOpenGL", - "PyOpenGL_accelerate", - "networkx", - "matplotlib", - "python-louvain", - "ascii_colors", -] - - -def setup_dependencies(): - """ - Ensure all required packages are installed - """ - for package in REQUIRED_PACKAGES: - if not pm.is_installed(package): - print(f"Installing {package}...") - pm.install(package) - - -# Install dependencies -setup_dependencies() - -import networkx as nx -import numpy as np -import matplotlib.pyplot as plt -import community -from PyQt5.QtWidgets import ( - QApplication, - QMainWindow, - QWidget, - QVBoxLayout, - QHBoxLayout, - QPushButton, - QFileDialog, - QLabel, - QMessageBox, - QSpinBox, - QComboBox, - QCheckBox, - QTableWidget, - QTableWidgetItem, - QSplitter, - QDockWidget, - QTextEdit, -) -from PyQt5.QtCore import Qt -import pyqtgraph.opengl as gl -from ascii_colors import trace_exception - - -class Point: - """Simple point class to handle coordinates""" - - def __init__(self, x: float, y: float): - self.x = x - self.y = y - - -class NodeState: - """Data class for node visual state""" - - NORMAL_SCALE = 1.0 - HOVER_SCALE = 1.2 - SELECTED_SCALE = 1.3 - - NORMAL_OPACITY = 0.8 - HOVER_OPACITY = 1.0 - SELECTED_OPACITY = 1.0 - - # Increase base node size (was 0.05) - BASE_SIZE = 0.2 - - SELECTED_COLOR = (1.0, 1.0, 0.0, 1.0) - HOVER_COLOR = (1.0, 0.8, 0.0, 1.0) - - -class Node3D: - """Class representing a 3D node in the graph""" - - def __init__( - self, - position: np.ndarray, - color: Tuple[float, float, float, float], - label: str, - node_type: str, - size: float, - ): - self.position = position - self.base_color = color - self.color = color - self.label = label - self.node_type = node_type - self.size = size - self.mesh_item = None - self.label_item = None - self.is_highlighted = False - self.is_selected = False - - def highlight(self): - """Highlight the node""" - if not self.is_highlighted and not self.is_selected: - self.color = NodeState.HOVER_COLOR - self.update_appearance(NodeState.HOVER_SCALE) - self.is_highlighted = True - - def unhighlight(self): - """Remove highlight from node""" - if self.is_highlighted and not self.is_selected: - self.color = self.base_color - self.update_appearance(NodeState.NORMAL_SCALE) - self.is_highlighted = False - - def select(self): - """Select the node""" - self.is_selected = True - self.color = NodeState.SELECTED_COLOR - self.update_appearance(NodeState.SELECTED_SCALE) - - def deselect(self): - """Deselect the node""" - self.is_selected = False - self.color = self.base_color - self.update_appearance(NodeState.NORMAL_SCALE) - - def update_appearance(self, scale: float = 1.0): - """Update node visual appearance""" - if self.mesh_item: - self.mesh_item.setData( - color=np.array([self.color]), size=np.array([self.size * scale * 5]) - ) - - -class NodeDetailsWidget(QWidget): - """Widget to display node details""" - - def __init__(self, parent=None): - super().__init__(parent) - self.init_ui() - - def init_ui(self): - """Initialize the UI""" - layout = QVBoxLayout(self) - - # Properties text edit - self.properties = QTextEdit() - self.properties.setReadOnly(True) - layout.addWidget(QLabel("Properties:")) - layout.addWidget(self.properties) - - # Connections table - self.connections = QTableWidget() - self.connections.setColumnCount(3) - self.connections.setHorizontalHeaderLabels( - ["Connected Node", "Relationship", "Direction"] - ) - layout.addWidget(QLabel("Connections:")) - layout.addWidget(self.connections) - - def update_node_info(self, node_data: Dict, connections: Dict): - """Update the display with node information""" - # Update properties - properties_text = "Node Properties:\n" - for key, value in node_data.items(): - properties_text += f"{key}: {value}\n" - self.properties.setText(properties_text) - - # Update connections - self.connections.setRowCount(len(connections)) - for idx, (neighbor, edge_data) in enumerate(connections.items()): - self.connections.setItem(idx, 0, QTableWidgetItem(str(neighbor))) - self.connections.setItem( - idx, 1, QTableWidgetItem(edge_data.get("relationship", "unknown")) - ) - self.connections.setItem(idx, 2, QTableWidgetItem("outgoing")) - - -class GraphMLViewer3D(QMainWindow): - """Main window class for 3D GraphML visualization""" - - def __init__(self): - super().__init__() - - self.graph: Optional[nx.Graph] = None - self.nodes: Dict[str, Node3D] = {} - self.edges: List[gl.GLLinePlotItem] = [] - self.edge_labels: List[gl.GLTextItem] = [] - self.selected_node = None - self.communities = None - self.community_colors = None - - self.mouse_pos_last = None - self.mouse_buttons_pressed = set() - self.distance = 20 # Initial camera distance - self.center = np.array([0, 0, 0]) # View center point - self.elevation = 30 # Initial camera elevation - self.azimuth = 45 # Initial camera azimuth - - self.init_ui() - - def init_ui(self): - """Initialize the user interface""" - self.setWindowTitle("3D GraphML Viewer") - self.setGeometry(100, 100, 1600, 900) - - # Create main splitter - self.main_splitter = QSplitter(Qt.Horizontal) - self.setCentralWidget(self.main_splitter) - - # Create left panel for 3D view - left_widget = QWidget() - left_layout = QVBoxLayout(left_widget) - - # Create controls - self.create_toolbar(left_layout) - - # Create 3D view - self.view = gl.GLViewWidget() - self.view.setMouseTracking(True) - - # Connect mouse events - self.view.mousePressEvent = self.on_mouse_press - self.view.mouseMoveEvent = self.on_mouse_move - left_layout.addWidget(self.view) - - self.main_splitter.addWidget(left_widget) - - # Create details widget - self.details = NodeDetailsWidget() - details_dock = QDockWidget("Node Details", self) - details_dock.setWidget(self.details) - self.addDockWidget(Qt.RightDockWidgetArea, details_dock) - - # Add status bar - self.statusBar().showMessage("Ready") - - # Add initial grid - grid = gl.GLGridItem() - grid.setSize(x=20, y=20, z=20) - grid.setSpacing(x=1, y=1, z=1) - self.view.addItem(grid) - - # Set initial camera position - self.view.setCameraPosition( - distance=self.distance, elevation=self.elevation, azimuth=self.azimuth - ) - - # Connect all mouse events - self.view.mousePressEvent = self.on_mouse_press - self.view.mouseReleaseEvent = self.on_mouse_release - self.view.mouseMoveEvent = self.on_mouse_move - self.view.wheelEvent = self.on_mouse_wheel - - def calculate_node_sizes(self) -> Dict[str, float]: - """Calculate node sizes based on number of connections""" - if not self.graph: - return {} - - # Get degree (number of connections) for each node - degrees = dict(self.graph.degree()) - - # Calculate size scaling - max_degree = max(degrees.values()) - min_degree = min(degrees.values()) - - # Normalize sizes between 0.5 and 2.0 - sizes = {} - for node, degree in degrees.items(): - if max_degree == min_degree: - sizes[node] = 1.0 - else: - # Normalize and scale size - normalized = (degree - min_degree) / (max_degree - min_degree) - sizes[node] = 0.5 + normalized * 1.5 - - return sizes - - def create_toolbar(self, layout: QVBoxLayout): - """Create the toolbar with controls""" - toolbar = QHBoxLayout() - - # Load button - load_btn = QPushButton("Load GraphML") - load_btn.clicked.connect(self.load_graphml) - toolbar.addWidget(load_btn) - - # Reset view button - reset_btn = QPushButton("Reset View") - reset_btn.clicked.connect(lambda: self.view.setCameraPosition(distance=20)) - toolbar.addWidget(reset_btn) - - # Layout selector - self.layout_combo = QComboBox() - self.layout_combo.addItems(["Spring", "Circular", "Shell", "Random"]) - self.layout_combo.currentTextChanged.connect(self.refresh_layout) - toolbar.addWidget(QLabel("Layout:")) - toolbar.addWidget(self.layout_combo) - - # Node size control - self.node_size = QSpinBox() - self.node_size.setRange(1, 100) - self.node_size.setValue(20) - self.node_size.valueChanged.connect(self.refresh_layout) - toolbar.addWidget(QLabel("Node Size:")) - toolbar.addWidget(self.node_size) - - # Show labels checkbox - self.show_labels = QCheckBox("Show Labels") - self.show_labels.setChecked(True) - self.show_labels.stateChanged.connect(self.refresh_layout) - toolbar.addWidget(self.show_labels) - - layout.addLayout(toolbar) - # Reset view button - reset_btn = QPushButton("Reset View") - reset_btn.clicked.connect(self.reset_view) # Use the new reset_view method - toolbar.addWidget(reset_btn) - - def load_graphml(self) -> None: - """Load and visualize a GraphML file""" - try: - file_path, _ = QFileDialog.getOpenFileName( - self, "Open GraphML file", "", "GraphML files (*.graphml)" - ) - - if file_path: - self.graph = nx.read_graphml(Path(file_path)) - self.refresh_layout() - self.statusBar().showMessage(f"Loaded: {file_path}") - except Exception as e: - trace_exception(e) - QMessageBox.critical(self, "Error", f"Error loading file: {str(e)}") - - def calculate_layout(self) -> Dict[str, np.ndarray]: - """Calculate node positions based on selected layout""" - layout_type = self.layout_combo.currentText().lower() - - # Detect communities for coloring - self.communities = community.best_partition(self.graph) - num_communities = len(set(self.communities.values())) - self.community_colors = plt.cm.rainbow(np.linspace(0, 1, num_communities)) - - if layout_type == "spring": - pos = nx.spring_layout( - self.graph, dim=3, k=2.0, iterations=100, weight=None - ) - elif layout_type == "circular": - pos_2d = nx.circular_layout(self.graph) - pos = {node: np.array([x, y, 0.0]) for node, (x, y) in pos_2d.items()} - elif layout_type == "shell": - comm_lists = [[] for _ in range(num_communities)] - for node, comm in self.communities.items(): - comm_lists[comm].append(node) - pos_2d = nx.shell_layout(self.graph, comm_lists) - pos = {node: np.array([x, y, 0.0]) for node, (x, y) in pos_2d.items()} - else: # random - pos = {node: np.random.rand(3) * 2 - 1 for node in self.graph.nodes()} - - # Scale positions - positions = np.array(list(pos.values())) - if len(positions) > 0: - scale = 10.0 / max(1.0, np.max(np.abs(positions))) - return {node: coords * scale for node, coords in pos.items()} - return pos - - def get_node_color(self, node_id: str) -> Tuple[float, float, float, float]: - """Get RGBA color based on community""" - if hasattr(self, "communities") and node_id in self.communities: - comm_id = self.communities[node_id] - color = self.community_colors[comm_id] - return tuple(color) - return (0.5, 0.5, 0.5, 0.8) - - def create_node(self, node_id: str, position: np.ndarray, node_type: str) -> Node3D: - """Create a 3D node with interaction capabilities""" - color = self.get_node_color(node_id) - - # Get size multiplier based on connections - size_multiplier = self.node_sizes.get(node_id, 1.0) - size = NodeState.BASE_SIZE * self.node_size.value() / 50.0 * size_multiplier - - node = Node3D(position, color, str(node_id), node_type, size) - - node.mesh_item = gl.GLScatterPlotItem( - pos=np.array([position]), - size=np.array([size * 8]), - color=np.array([color]), - pxMode=False, - ) - - # Enable picking and set node ID - node.mesh_item.setGLOptions("translucent") - node.mesh_item.node_id = node_id - - if self.show_labels.isChecked(): - node.label_item = gl.GLTextItem( - pos=position, - text=str(node_id), - color=(1, 1, 1, 1), - ) - - return node - - def mapToView(self, pos) -> Point: - """Convert screen coordinates to world coordinates""" - # Get the viewport size - width = self.view.width() - height = self.view.height() - - # Normalize coordinates - x = (pos.x() / width - 0.5) * 20 # Scale factor of 20 matches the grid size - y = -(pos.y() / height - 0.5) * 20 - - return Point(x, y) - - def on_mouse_move(self, event): - """Handle mouse movement for pan, rotate and hover""" - if self.mouse_pos_last is None: - self.mouse_pos_last = event.pos() - return - - pos = event.pos() - dx = pos.x() - self.mouse_pos_last.x() - dy = pos.y() - self.mouse_pos_last.y() - - # Handle right button drag for panning - if Qt.RightButton in self.mouse_buttons_pressed: - # Scale the pan amount based on view distance - scale = self.distance / 1000.0 - - # Calculate pan in view coordinates - right = np.cross([0, 0, 1], self.view.cameraPosition()) - right = right / np.linalg.norm(right) - up = np.cross(self.view.cameraPosition(), right) - up = up / np.linalg.norm(up) - - pan = -right * dx * scale + up * dy * scale - self.center += pan - self.view.pan(dx, dy, 0) - - # Handle middle button drag for rotation - elif Qt.MiddleButton in self.mouse_buttons_pressed: - self.azimuth += dx * 0.5 # Adjust rotation speed as needed - self.elevation -= dy * 0.5 - - # Clamp elevation to prevent gimbal lock - self.elevation = np.clip(self.elevation, -89, 89) - - self.view.setCameraPosition( - distance=self.distance, elevation=self.elevation, azimuth=self.azimuth - ) - - # Handle hover events when no buttons are pressed - elif not self.mouse_buttons_pressed: - # Get the mouse position in world coordinates - mouse_pos = self.mapToView(pos) - - # Check for hover - min_dist = float("inf") - hovered_node = None - - for node_id, node in self.nodes.items(): - # Calculate distance to mouse in world coordinates - dx = mouse_pos.x - node.position[0] - dy = mouse_pos.y - node.position[1] - dist = np.sqrt(dx * dx + dy * dy) - - if dist < min_dist and dist < 0.5: # Adjust threshold as needed - min_dist = dist - hovered_node = node_id - - # Update hover states - for node_id, node in self.nodes.items(): - if node_id == hovered_node: - node.highlight() - self.statusBar().showMessage(f"Node: {node_id} ({node.node_type})") - else: - if not node.is_selected: - node.unhighlight() - self.mouse_pos_last = pos - - def on_mouse_press(self, event): - """Handle mouse press events""" - self.mouse_pos_last = event.pos() - self.mouse_buttons_pressed.add(event.button()) - - # Handle left click for node selection - if event.button() == Qt.LeftButton: - pos = event.pos() - mouse_pos = self.mapToView(pos) - - # Find closest node - min_dist = float("inf") - clicked_node = None - - for node_id, node in self.nodes.items(): - dx = mouse_pos.x - node.position[0] - dy = mouse_pos.y - node.position[1] - dist = np.sqrt(dx * dx + dy * dy) - - if dist < min_dist and dist < 0.5: # Adjust threshold as needed - min_dist = dist - clicked_node = node_id - - # Handle selection - if clicked_node: - if self.selected_node and self.selected_node in self.nodes: - self.nodes[self.selected_node].deselect() - - self.nodes[clicked_node].select() - self.selected_node = clicked_node - - if self.graph: - self.details.update_node_info( - self.graph.nodes[clicked_node], self.graph[clicked_node] - ) - - def on_mouse_release(self, event): - """Handle mouse release events""" - self.mouse_buttons_pressed.discard(event.button()) - self.mouse_pos_last = None - - def on_mouse_wheel(self, event): - """Handle mouse wheel for zooming""" - delta = event.angleDelta().y() - - # Adjust zoom speed based on current distance - zoom_speed = self.distance / 100.0 - - # Update distance with limits - self.distance -= delta * zoom_speed - self.distance = np.clip(self.distance, 1.0, 100.0) - - self.view.setCameraPosition( - distance=self.distance, elevation=self.elevation, azimuth=self.azimuth - ) - - def reset_view(self): - """Reset camera to default position""" - self.distance = 20 - self.elevation = 30 - self.azimuth = 45 - self.center = np.array([0, 0, 0]) - - self.view.setCameraPosition( - distance=self.distance, elevation=self.elevation, azimuth=self.azimuth - ) - - def create_edge( - self, - start_pos: np.ndarray, - end_pos: np.ndarray, - color: Tuple[float, float, float, float] = (0.3, 0.3, 0.3, 0.2), - ) -> gl.GLLinePlotItem: - """Create a 3D edge between nodes""" - return gl.GLLinePlotItem( - pos=np.array([start_pos, end_pos]), - color=color, - width=1, - antialias=True, - mode="lines", - ) - - def handle_node_hover(self, event: Any, node_id: str) -> None: - """Handle node hover events""" - if node_id in self.nodes: - node = self.nodes[node_id] - if event.isEnter(): - node.highlight() - self.statusBar().showMessage(f"Node: {node_id} ({node.node_type})") - elif event.isExit(): - node.unhighlight() - self.statusBar().showMessage("") - - def handle_node_click(self, event: Any, node_id: str) -> None: - """Handle node click events""" - if event.button() != Qt.LeftButton or node_id not in self.nodes: - return - - if self.selected_node and self.selected_node in self.nodes: - self.nodes[self.selected_node].deselect() - - node = self.nodes[node_id] - node.select() - self.selected_node = node_id - - if self.graph: - self.details.update_node_info( - self.graph.nodes[node_id], self.graph[node_id] - ) - - def refresh_layout(self) -> None: - """Refresh the graph visualization""" - if not self.graph: - return - - self.positions = self.calculate_layout() - self.node_sizes = self.calculate_node_sizes() - - self.view.clear() - self.nodes.clear() - self.edges.clear() - self.edge_labels.clear() - - grid = gl.GLGridItem() - grid.setSize(x=20, y=20, z=20) - grid.setSpacing(x=1, y=1, z=1) - self.view.addItem(grid) - - positions = self.calculate_layout() - - for node_id in self.graph.nodes(): - node_type = self.graph.nodes[node_id].get("type", "default") - node = self.create_node(node_id, positions[node_id], node_type) - - self.view.addItem(node.mesh_item) - if node.label_item: - self.view.addItem(node.label_item) - - self.nodes[node_id] = node - - for source, target in self.graph.edges(): - edge = self.create_edge(positions[source], positions[target]) - self.view.addItem(edge) - self.edges.append(edge) - - if self.show_labels.isChecked(): - mid_point = (positions[source] + positions[target]) / 2 - relationship = self.graph.edges[source, target].get("relationship", "") - if relationship: - label = gl.GLTextItem( - pos=mid_point, - text=relationship, - color=(0.8, 0.8, 0.8, 0.8), - ) - self.view.addItem(label) - self.edge_labels.append(label) - - -def main(): - """Application entry point""" - import sys - - app = QApplication(sys.argv) - viewer = GraphMLViewer3D() - viewer.show() - sys.exit(app.exec_()) - - -if __name__ == "__main__": - main() diff --git a/extra/VisualizationTool/README-zh.md b/extra/VisualizationTool/README-zh.md new file mode 100644 index 00000000..d76c93a2 --- /dev/null +++ b/extra/VisualizationTool/README-zh.md @@ -0,0 +1,95 @@ +# 3D GraphML Viewer + +一个基于 Dear ImGui 和 ModernGL 的交互式 3D 图可视化工具。 + +## 功能特点 + +- **3D 交互式可视化**: 使用 ModernGL 实现高性能的 3D 图形渲染 +- **多种布局算法**: 支持多种图布局方式 + - Spring 布局 + - Circular 布局 + - Shell 布局 + - Random 布局 +- **社区检测**: 支持图社区结构的自动检测和可视化 +- **交互控制**: + - WASD + QE 键控制相机移动 + - 鼠标右键拖拽控制视角 + - 节点选择和高亮 + - 可调节节点大小和边宽度 + - 可控制标签显示 + - 可在节点的Connections间快速跳转 +- **社区检测**: 支持图社区结构的自动检测和可视化 +- **交互控制**: + - WASD + QE 键控制相机移动 + - 鼠标右键拖拽控制视角 + - 节点选择和高亮 + - 可调节节点大小和边宽度 + - 可控制标签显示 + +## 技术栈 + +- **imgui_bundle**: 用户界面 +- **ModernGL**: OpenGL 图形渲染 +- **NetworkX**: 图数据结构和算法 +- **NumPy**: 数值计算 +- **community**: 社区检测 + +## 使用方法 + +1. **启动程序**: + ```bash + python -m pip install -r requirements.txt + python graph_visualizer.py + ``` + +2. **加载字体**: + - 将中文字体文件 `font.ttf` 放置在 `assets` 目录下 + - 或者修改 `CUSTOM_FONT` 常量来使用其他字体文件 + +3. **加载图文件**: + - 点击界面上的 "Load GraphML" 按钮 + - 选择 GraphML 格式的图文件 + +4. **交互控制**: + - **相机移动**: + - W: 前进 + - S: 后退 + - A: 左移 + - D: 右移 + - Q: 上升 + - E: 下降 + - **视角控制**: + - 按住鼠标右键拖动来旋转视角 + - **节点交互**: + - 鼠标悬停可高亮节点 + - 点击可选中节点 + +5. **可视化设置**: + - 可通过 UI 控制面板调整: + - 布局类型 + - 节点大小 + - 边的宽度 + - 标签显示 + - 标签大小 + - 背景颜色 + +## 自定义设置 + +- **节点缩放**: 通过 `node_scale` 参数调整节点大小 +- **边宽度**: 通过 `edge_width` 参数调整边的宽度 +- **标签显示**: 可通过 `show_labels` 开关标签显示 +- **标签大小**: 使用 `label_size` 调整标签大小 +- **标签颜色**: 通过 `label_color` 设置标签颜色 +- **视距控制**: 使用 `label_culling_distance` 控制标签显示的最大距离 + +## 性能优化 + +- 使用 ModernGL 进行高效的图形渲染 +- 视距裁剪优化标签显示 +- 社区检测算法优化大规模图的可视化效果 + +## 系统要求 + +- Python 3.10+ +- OpenGL 3.3+ 兼容的显卡 +- 支持的操作系统:Windows/Linux/MacOS diff --git a/extra/VisualizationTool/README.md b/extra/VisualizationTool/README.md new file mode 100644 index 00000000..a2581703 --- /dev/null +++ b/extra/VisualizationTool/README.md @@ -0,0 +1,88 @@ +# 3D GraphML Viewer + +An interactive 3D graph visualization tool based on Dear ImGui and ModernGL. + +## Features + +- **3D Interactive Visualization**: High-performance 3D graphics rendering using ModernGL +- **Multiple Layout Algorithms**: Support for various graph layouts + - Spring layout + - Circular layout + - Shell layout + - Random layout +- **Community Detection**: Automatic detection and visualization of graph community structures +- **Interactive Controls**: + - WASD + QE keys for camera movement + - Right mouse drag for view angle control + - Node selection and highlighting + - Adjustable node size and edge width + - Configurable label display + - Quick navigation between node connections + +## Tech Stack + +- **imgui_bundle**: User interface +- **ModernGL**: OpenGL graphics rendering +- **NetworkX**: Graph data structures and algorithms +- **NumPy**: Numerical computations +- **community**: Community detection + +## Usage + +1. **Launch the Program**: + ```bash + python -m pip install -r requirements.txt + python graph_visualizer.py + ``` + +2. **Load Font**: + - Place the font file `font.ttf` in the `assets` directory + - Or modify the `CUSTOM_FONT` constant to use a different font file + +3. **Load Graph File**: + - Click the "Load GraphML" button in the interface + - Select a graph file in GraphML format + +4. **Interactive Controls**: + - **Camera Movement**: + - W: Move forward + - S: Move backward + - A: Move left + - D: Move right + - Q: Move up + - E: Move down + - **View Control**: + - Hold right mouse button and drag to rotate view + - **Node Interaction**: + - Hover mouse to highlight nodes + - Click to select nodes + +5. **Visualization Settings**: + - Adjustable via UI control panel: + - Layout type + - Node size + - Edge width + - Label visibility + - Label size + - Background color + +## Customization Options + +- **Node Scaling**: Adjust node size via `node_scale` parameter +- **Edge Width**: Modify edge width using `edge_width` parameter +- **Label Display**: Toggle label visibility with `show_labels` +- **Label Size**: Adjust label size using `label_size` +- **Label Color**: Set label color through `label_color` +- **View Distance**: Control maximum label display distance with `label_culling_distance` + +## Performance Optimizations + +- Efficient graphics rendering using ModernGL +- View distance culling for label display optimization +- Community detection algorithms for optimized visualization of large-scale graphs + +## System Requirements + +- Python 3.10+ +- Graphics card with OpenGL 3.3+ support +- Supported Operating Systems: Windows/Linux/MacOS diff --git a/extra/VisualizationTool/assets/place_font_here b/extra/VisualizationTool/assets/place_font_here new file mode 100644 index 00000000..e69de29b diff --git a/extra/VisualizationTool/graph_visualizer.py b/extra/VisualizationTool/graph_visualizer.py new file mode 100644 index 00000000..9bbd6235 --- /dev/null +++ b/extra/VisualizationTool/graph_visualizer.py @@ -0,0 +1,1149 @@ +""" +3D GraphML Viewer using Dear ImGui and ModernGL +Author: LoLLMs, ArnoChen +Description: An interactive 3D GraphML viewer using imgui_bundle and ModernGL +Version: 2.0 +""" + +from typing import Optional, Tuple, Dict, List +import numpy as np +import networkx as nx +import moderngl +from imgui_bundle import imgui, immapp, hello_imgui +import community +import glm +import tkinter as tk +from tkinter import filedialog +import traceback +import colorsys +import os + +CUSTOM_FONT = "font.ttf" + + +class Node3D: + """Class representing a 3D node in the graph""" + + def __init__( + self, position: glm.vec3, color: glm.vec3, label: str, size: float, idx: int + ): + self.position = position + self.color = color + self.label = label + self.size = size + self.idx = idx + + +class GraphViewer: + """Main class for 3D graph visualization""" + + def __init__(self): + self.glctx = None # ModernGL context + self.graph: Optional[nx.Graph] = None + self.nodes: List[Node3D] = [] + self.id_node_map: Dict[str, Node3D] = {} + self.communities = None + self.community_colors = None + + # Window dimensions + self.window_width = 1280 + self.window_height = 720 + + # Camera parameters + self.position = glm.vec3(0.0, -10.0, 0.0) # Initial camera position + self.front = glm.vec3(0.0, 1.0, 0.0) # Direction camera is facing + self.up = glm.vec3(0.0, 0.0, 1.0) # Up vector + self.yaw = 90.0 # Horizontal rotation (around Z axis) + self.pitch = 0.0 # Vertical rotation + self.move_speed = 0.05 + self.mouse_sensitivity = 0.15 + + # Graph visualization settings + self.layout_type = "Spring" + self.node_scale = 0.2 + self.edge_width = 0.5 + self.show_labels = True + self.label_size = 2 + self.label_color = (1.0, 1.0, 1.0, 1.0) + self.label_culling_distance = 10.0 + self.available_layouts = ("Spring", "Circular", "Shell", "Random") + self.background_color = (0.05, 0.05, 0.05, 1.0) + + # Mouse interaction + self.last_mouse_pos = None + self.mouse_pressed = False + self.mouse_button = -1 + self.first_mouse = True + + # File dialog state + self.show_load_error = False + self.error_message = "" + + # Selection state + self.selected_node: Optional[Node3D] = None + self.highlighted_node: Optional[Node3D] = None + + # Node id map + self.node_id_fbo = None + self.node_id_texture = None + self.node_id_depth = None + self.node_id_texture_np: np.ndarray = None + + # Static data + self.sphere_data = create_sphere() + + # Initialization flag + self.initialized = False + + def setup(self): + self.setup_render_context() + self.setup_shaders() + self.setup_buffers() + self.initialized = True + + def handle_keyboard_input(self): + """Handle WASD keyboard input for camera movement""" + io = imgui.get_io() + + if io.want_capture_keyboard: + return + + # Calculate camera vectors + right = glm.normalize(glm.cross(self.front, self.up)) + + # Get movement direction from WASD keys + if imgui.is_key_down(imgui.Key.w): # Forward + self.position += self.front * self.move_speed * 0.1 + if imgui.is_key_down(imgui.Key.s): # Backward + self.position -= self.front * self.move_speed * 0.1 + if imgui.is_key_down(imgui.Key.a): # Left + self.position -= right * self.move_speed * 0.1 + if imgui.is_key_down(imgui.Key.d): # Right + self.position += right * self.move_speed * 0.1 + if imgui.is_key_down(imgui.Key.q): # Up + self.position += self.up * self.move_speed * 0.1 + if imgui.is_key_down(imgui.Key.e): # Down + self.position -= self.up * self.move_speed * 0.1 + + def handle_mouse_interaction(self): + """Handle mouse interaction for camera control and node selection""" + if ( + imgui.is_any_item_active() + or imgui.is_any_item_hovered() + or imgui.is_any_item_focused() + ): + return + + io = imgui.get_io() + mouse_pos = (io.mouse_pos.x, io.mouse_pos.y) + if ( + mouse_pos[0] < 0 + or mouse_pos[1] < 0 + or mouse_pos[0] >= self.window_width + or mouse_pos[1] >= self.window_height + ): + return + + # Handle first mouse input + if self.first_mouse: + self.last_mouse_pos = mouse_pos + self.first_mouse = False + return + + # Handle mouse movement for camera rotation + if self.mouse_pressed and self.mouse_button == 1: # Right mouse button + dx = self.last_mouse_pos[0] - mouse_pos[0] + dy = self.last_mouse_pos[1] - mouse_pos[1] # Reversed for intuitive control + + dx *= self.mouse_sensitivity + dy *= self.mouse_sensitivity + + self.yaw += dx + self.pitch += dy + + # Limit pitch to avoid flipping + self.pitch = np.clip(self.pitch, -89.0, 89.0) + + # Update front vector + self.front = glm.normalize( + glm.vec3( + np.cos(np.radians(self.yaw)) * np.cos(np.radians(self.pitch)), + np.sin(np.radians(self.yaw)) * np.cos(np.radians(self.pitch)), + np.sin(np.radians(self.pitch)), + ) + ) + if io.mouse_wheel != 0: + self.move_speed += io.mouse_wheel * 0.05 + self.move_speed = np.max([self.move_speed, 0.01]) + + # Handle mouse press/release + for button in range(3): + if imgui.is_mouse_clicked(button): + self.mouse_pressed = True + self.mouse_button = button + if button == 0 and self.highlighted_node: # Left click for selection + self.selected_node = self.highlighted_node + + if imgui.is_mouse_released(button) and self.mouse_button == button: + self.mouse_pressed = False + self.mouse_button = -1 + + # Handle node hovering + if not self.mouse_pressed: + hovered = self.find_node_at((int(mouse_pos[0]), int(mouse_pos[1]))) + self.highlighted_node = hovered + + # Update last mouse position + self.last_mouse_pos = mouse_pos + + def update_layout(self): + """Update the graph layout""" + pos = nx.spring_layout( + self.graph, + dim=3, + pos={ + node_id: list(node.position) + for node_id, node in self.id_node_map.items() + }, + k=2.0, + iterations=100, + weight=None, + ) + + # Update node positions + for node_id, position in pos.items(): + self.id_node_map[node_id].position = glm.vec3(position) + self.update_buffers() + + def render_node_details(self): + """Render node details window""" + if self.selected_node and imgui.begin("Node Details"): + imgui.text(f"ID: {self.selected_node.label}") + + if self.graph: + node_data = self.graph.nodes[self.selected_node.label] + imgui.text(f"Type: {node_data.get('type', 'default')}") + + degree = self.graph.degree[self.selected_node.label] + imgui.text(f"Degree: {degree}") + + for key, value in node_data.items(): + if key != "type": + imgui.text(f"{key}: {value}") + if value and imgui.is_item_hovered(): + imgui.set_tooltip(str(value)) + + imgui.separator() + + connections = self.graph[self.selected_node.label] + if connections: + imgui.text("Connections:") + keys = next(iter(connections.values())).keys() + if imgui.begin_table( + "Connections", + len(keys) + 1, + imgui.TableFlags_.borders + | imgui.TableFlags_.row_bg + | imgui.TableFlags_.resizable + | imgui.TableFlags_.hideable, + ): + imgui.table_setup_column("Node") + for key in keys: + imgui.table_setup_column(key) + imgui.table_headers_row() + + for neighbor, edge_data in connections.items(): + imgui.table_next_row() + imgui.table_set_column_index(0) + if imgui.selectable(str(neighbor), True)[0]: + # Select neighbor node + self.selected_node = self.id_node_map[neighbor] + self.position = self.selected_node.position - self.front + for idx, key in enumerate(keys): + imgui.table_set_column_index(idx + 1) + value = str(edge_data.get(key, "")) + imgui.text(value) + if value and imgui.is_item_hovered(): + imgui.set_tooltip(value) + imgui.end_table() + + imgui.end() + + def setup_render_context(self): + """Initialize ModernGL context""" + self.glctx = moderngl.create_context() + self.glctx.enable(moderngl.DEPTH_TEST | moderngl.CULL_FACE) + self.glctx.clear_color = self.background_color + + def setup_shaders(self): + """Setup vertex and fragment shaders for node and edge rendering""" + # Node shader program + self.node_prog = self.glctx.program( + vertex_shader=""" + #version 330 + + uniform mat4 mvp; + uniform vec3 camera; + uniform int selected_node; + uniform int highlighted_node; + uniform float scale; + + in vec3 in_position; + in vec3 in_instance_position; + in vec3 in_instance_color; + in float in_instance_size; + + out vec3 frag_color; + out vec3 frag_normal; + out vec3 frag_view_dir; + + void main() { + vec3 pos = in_position * in_instance_size * scale + in_instance_position; + gl_Position = mvp * vec4(pos, 1.0); + + frag_normal = normalize(in_position); + frag_view_dir = normalize(camera - pos); + + if (selected_node == gl_InstanceID) { + frag_color = vec3(1.0, 0.5, 0.0); + } + else if (highlighted_node == gl_InstanceID) { + frag_color = vec3(1.0, 0.8, 0.2); + } + else { + frag_color = in_instance_color; + } + } + """, + fragment_shader=""" + #version 330 + + in vec3 frag_color; + in vec3 frag_normal; + in vec3 frag_view_dir; + + out vec4 outColor; + + void main() { + // Edge detection based on normal-view angle + float edge = 1.0 - abs(dot(frag_normal, frag_view_dir)); + + // Create sharp outline + float outline = smoothstep(0.8, 0.9, edge); + + // Mix the sphere color with outline + vec3 final_color = mix(frag_color, vec3(0.0), outline); + + outColor = vec4(final_color, 1.0); + } + """, + ) + + # Edge shader program with wide lines using geometry shader + self.edge_prog = self.glctx.program( + vertex_shader=""" + #version 330 + + uniform mat4 mvp; + + in vec3 in_position; + in vec3 in_color; + + out vec3 v_color; + out vec4 v_position; + + void main() { + v_position = mvp * vec4(in_position, 1.0); + gl_Position = v_position; + v_color = in_color; + } + """, + geometry_shader=""" + #version 330 + + layout(lines) in; + layout(triangle_strip, max_vertices = 4) out; + + uniform float edge_width; + uniform vec2 viewport_size; + + in vec3 v_color[]; + in vec4 v_position[]; + out vec3 g_color; + out float edge_coord; + + void main() { + // Get the two vertices of the line + vec4 p1 = v_position[0]; + vec4 p2 = v_position[1]; + + // Perspective division + vec4 p1_ndc = p1 / p1.w; + vec4 p2_ndc = p2 / p2.w; + + // Calculate line direction in screen space + vec2 dir = normalize((p2_ndc.xy - p1_ndc.xy) * viewport_size); + vec2 normal = vec2(-dir.y, dir.x); + + // Calculate half width based on screen space + float half_width = edge_width * 0.5; + vec2 offset = normal * (half_width / viewport_size); + + // Emit vertices with proper depth + gl_Position = vec4(p1_ndc.xy + offset, p1_ndc.z, 1.0); + gl_Position *= p1.w; // Restore perspective + g_color = v_color[0]; + edge_coord = 1.0; + EmitVertex(); + + gl_Position = vec4(p1_ndc.xy - offset, p1_ndc.z, 1.0); + gl_Position *= p1.w; + g_color = v_color[0]; + edge_coord = -1.0; + EmitVertex(); + + gl_Position = vec4(p2_ndc.xy + offset, p2_ndc.z, 1.0); + gl_Position *= p2.w; + g_color = v_color[1]; + edge_coord = 1.0; + EmitVertex(); + + gl_Position = vec4(p2_ndc.xy - offset, p2_ndc.z, 1.0); + gl_Position *= p2.w; + g_color = v_color[1]; + edge_coord = -1.0; + EmitVertex(); + + EndPrimitive(); + } + """, + fragment_shader=""" + #version 330 + + in vec3 g_color; + in float edge_coord; + + out vec4 fragColor; + + void main() { + // Edge outline parameters + float outline_width = 0.2; // Width of the outline relative to edge + float edge_softness = 0.1; // Softness of the edge + float edge_dist = abs(edge_coord); + + // Calculate outline + float outline_factor = smoothstep(1.0 - outline_width - edge_softness, + 1.0 - outline_width, + edge_dist); + + // Mix edge color with outline (black) + vec3 final_color = mix(g_color, vec3(0.0), outline_factor); + + // Calculate alpha for anti-aliasing + float alpha = 1.0 - smoothstep(1.0 - edge_softness, 1.0, edge_dist); + + fragColor = vec4(final_color, alpha); + } + """, + ) + + # Id framebuffer shader program + self.node_id_prog = self.glctx.program( + vertex_shader=""" + #version 330 + + uniform mat4 mvp; + uniform float scale; + + in vec3 in_position; + in vec3 in_instance_position; + in float in_instance_size; + + out vec3 frag_color; + + vec3 int_to_rgb(int value) { + float R = float((value >> 16) & 0xFF); + float G = float((value >> 8) & 0xFF); + float B = float(value & 0xFF); + // normalize to [0, 1] + return vec3(R / 255.0, G / 255.0, B / 255.0); + } + + void main() { + vec3 pos = in_position * in_instance_size * scale + in_instance_position; + gl_Position = mvp * vec4(pos, 1.0); + frag_color = int_to_rgb(gl_InstanceID); + } + """, + fragment_shader=""" + #version 330 + in vec3 frag_color; + out vec4 outColor; + void main() { + outColor = vec4(frag_color, 1.0); + } + """, + ) + + def setup_buffers(self): + """Setup vertex buffers for nodes and edges""" + # We'll create these when loading the graph + self.node_vbo = None + self.node_color_vbo = None + self.node_size_vbo = None + self.edge_vbo = None + self.edge_color_vbo = None + self.node_vao = None + self.edge_vao = None + self.node_id_vao = None + self.sphere_pos_vbo = None + self.sphere_index_buffer = None + + def load_file(self, filepath: str): + """Load a GraphML file with error handling""" + try: + self.graph = nx.read_graphml(filepath) + self.calculate_layout() + self.update_buffers() + self.show_load_error = False + self.error_message = "" + except Exception as _: + self.show_load_error = True + self.error_message = traceback.format_exc() + print(self.error_message) + + def calculate_layout(self): + """Calculate 3D layout for the graph""" + if not self.graph: + return + + # Detect communities for coloring + self.communities = community.best_partition(self.graph) + num_communities = len(set(self.communities.values())) + self.community_colors = generate_colors(num_communities) + + # Calculate layout based on selected type + if self.layout_type == "Spring": + pos = nx.spring_layout( + self.graph, dim=3, k=2.0, iterations=100, weight=None + ) + elif self.layout_type == "Circular": + pos_2d = nx.circular_layout(self.graph) + pos = {node: np.array((x, 0.0, y)) for node, (x, y) in pos_2d.items()} + elif self.layout_type == "Shell": + # Group nodes by community for shell layout + comm_lists = [[] for _ in range(num_communities)] + for node, comm in self.communities.items(): + comm_lists[comm].append(node) + pos_2d = nx.shell_layout(self.graph, comm_lists) + pos = {node: np.array((x, 0.0, y)) for node, (x, y) in pos_2d.items()} + else: # Random + pos = {node: np.random.rand(3) * 2 - 1 for node in self.graph.nodes()} + + # Scale positions + positions = np.array(list(pos.values())) + if len(positions) > 0: + scale = 10.0 / max(1.0, np.max(np.abs(positions))) + pos = {node: coords * scale for node, coords in pos.items()} + + # Calculate degree-based sizes + degrees = dict(self.graph.degree()) + max_degree = max(degrees.values()) if degrees else 1 + min_degree = min(degrees.values()) if degrees else 1 + + idx = 0 + # Create nodes with community colors + for node_id in self.graph.nodes(): + position = glm.vec3(pos[node_id]) + color = self.get_node_color(node_id) + + # Normalize sizes between 0.5 and 2.0 + size = 1.0 + if max_degree != min_degree: + # Normalize and scale size + normalized = (degrees[node_id] - min_degree) / (max_degree - min_degree) + size = 0.5 + normalized * 1.5 + + if node_id in self.id_node_map: + node = self.id_node_map[node_id] + node.position = position + node.base_color = color + node.color = color + node.size = size + else: + node = Node3D(position, color, str(node_id), size, idx) + self.id_node_map[node_id] = node + self.nodes.append(node) + idx += 1 + + self.update_buffers() + + def get_node_color(self, node_id: str) -> glm.vec3: + """Get RGBA color based on community""" + if self.communities and node_id in self.communities: + comm_id = self.communities[node_id] + color = self.community_colors[comm_id] + return color + return glm.vec3(0.5, 0.5, 0.5) + + def update_buffers(self): + """Update vertex buffers with current node and edge data using batch rendering""" + if not self.graph: + return + + # Update node buffers + node_positions = [] + node_colors = [] + node_sizes = [] + + for node in self.nodes: + node_positions.append(node.position) + node_colors.append(node.color) # Only use RGB components + node_sizes.append(node.size) + + if node_positions: + node_positions = np.array(node_positions, dtype=np.float32) + node_colors = np.array(node_colors, dtype=np.float32) + node_sizes = np.array(node_sizes, dtype=np.float32) + + self.node_vbo = self.glctx.buffer(node_positions.tobytes()) + self.node_color_vbo = self.glctx.buffer(node_colors.tobytes()) + self.node_size_vbo = self.glctx.buffer(node_sizes.tobytes()) + self.sphere_pos_vbo = self.glctx.buffer(self.sphere_data[0].tobytes()) + self.sphere_index_buffer = self.glctx.buffer(self.sphere_data[1].tobytes()) + + self.node_vao = self.glctx.vertex_array( + self.node_prog, + [ + (self.sphere_pos_vbo, "3f", "in_position"), + (self.node_vbo, "3f /i", "in_instance_position"), + (self.node_color_vbo, "3f /i", "in_instance_color"), + (self.node_size_vbo, "f /i", "in_instance_size"), + ], + index_buffer=self.sphere_index_buffer, + index_element_size=4, + ) + self.node_vao.instances = len(self.nodes) + + self.node_id_vao = self.glctx.vertex_array( + self.node_id_prog, + [ + (self.sphere_pos_vbo, "3f", "in_position"), + (self.node_vbo, "3f /i", "in_instance_position"), + (self.node_size_vbo, "f /i", "in_instance_size"), + ], + index_buffer=self.sphere_index_buffer, + index_element_size=4, + ) + self.node_id_vao.instances = len(self.nodes) + + # Update edge buffers + edge_positions = [] + edge_colors = [] + + for edge in self.graph.edges(): + start_node = self.id_node_map[edge[0]] + end_node = self.id_node_map[edge[1]] + + edge_positions.append(start_node.position) + edge_colors.append(start_node.color) + + edge_positions.append(end_node.position) + edge_colors.append(end_node.color) + + if edge_positions: + edge_positions = np.array(edge_positions, dtype=np.float32) + edge_colors = np.array(edge_colors, dtype=np.float32) + + self.edge_vbo = self.glctx.buffer(edge_positions.tobytes()) + self.edge_color_vbo = self.glctx.buffer(edge_colors.tobytes()) + + self.edge_vao = self.glctx.vertex_array( + self.edge_prog, + [ + (self.edge_vbo, "3f", "in_position"), + (self.edge_color_vbo, "3f", "in_color"), + ], + ) + + def update_view_proj_matrix(self): + """Update view matrix based on camera parameters""" + self.view_matrix = glm.lookAt( + self.position, self.position + self.front, self.up + ) + + io = imgui.get_io() + self.window_width = int(io.display_size.x) + self.window_height = int(io.display_size.y) + + aspect_ratio = self.window_width / self.window_height + self.proj_matrix = glm.perspective( + glm.radians(60.0), # FOV + aspect_ratio, # Aspect ratio + 0.001, # Near plane + 1000.0, # Far plane + ) + + def find_node_at(self, screen_pos: Tuple[int, int]) -> Optional[Node3D]: + """Find the node at a specific screen position""" + if ( + self.node_id_texture_np is None + or self.node_id_texture_np.shape[1] != self.window_width + or self.node_id_texture_np.shape[0] != self.window_height + or screen_pos[0] < 0 + or screen_pos[1] < 0 + or screen_pos[0] >= self.window_width + or screen_pos[1] >= self.window_height + ): + return None + + x = screen_pos[0] + y = self.window_height - screen_pos[1] - 1 + pixel = self.node_id_texture_np[y, x] + + if pixel[3] == 0: + return None + + R = int(round(pixel[0] * 255)) + G = int(round(pixel[1] * 255)) + B = int(round(pixel[2] * 255)) + index = (R << 16) | (G << 8) | B + + if index > len(self.nodes): + return None + return self.nodes[index] + + def is_node_visible_at(self, screen_pos: Tuple[int, int], node_idx: int) -> bool: + """Check if a node exists at a specific screen position""" + node = self.find_node_at(screen_pos) + return node is not None and node.idx == node_idx + + def render_settings(self): + """Render settings window""" + if imgui.begin("Graph Settings"): + # Layout type combo + changed, value = imgui.combo( + "Layout", + self.available_layouts.index(self.layout_type), + self.available_layouts, + ) + if changed: + self.layout_type = self.available_layouts[value] + self.calculate_layout() # Recalculate layout when changed + + # Node size slider + changed, value = imgui.slider_float("Node Scale", self.node_scale, 0.01, 10) + if changed: + self.node_scale = value + + # Edge width slider + changed, value = imgui.slider_float("Edge Width", self.edge_width, 0, 20) + if changed: + self.edge_width = value + + # Show labels checkbox + changed, value = imgui.checkbox("Show Labels", self.show_labels) + + if changed: + self.show_labels = value + + if self.show_labels: + # Label size slider + changed, value = imgui.slider_float( + "Label Size", self.label_size, 0.5, 10.0 + ) + if changed: + self.label_size = value + + # Label color picker + changed, value = imgui.color_edit4( + "Label Color", + self.label_color, + imgui.ColorEditFlags_.picker_hue_wheel, + ) + if changed: + self.label_color = (value[0], value[1], value[2], value[3]) + + # Label culling distance slider + changed, value = imgui.slider_float( + "Label Culling Distance", self.label_culling_distance, 0.1, 100.0 + ) + if changed: + self.label_culling_distance = value + + # Background color picker + changed, value = imgui.color_edit4( + "Background Color", + self.background_color, + imgui.ColorEditFlags_.picker_hue_wheel, + ) + if changed: + self.background_color = (value[0], value[1], value[2], value[3]) + + imgui.end() + + def save_node_id_texture_to_png(self, filename): + # Convert to a PIL Image and save as PNG + from PIL import Image + + scaled_array = self.node_id_texture_np * 255 + img = Image.fromarray( + scaled_array.astype(np.uint8), + "RGBA", + ) + img = img.transpose(method=Image.FLIP_TOP_BOTTOM) + img.save(filename) + + def render_id_map(self, mvp: glm.mat4): + """Render an offscreen id map where each node is drawn with a unique id color.""" + # Lazy initialization of id framebuffer + if self.node_id_texture is not None: + if ( + self.node_id_texture.width != self.window_width + or self.node_id_texture.height != self.window_height + ): + self.node_id_fbo = None + self.node_id_texture = None + self.node_id_texture_np = None + self.node_id_depth = None + + if self.node_id_texture is None: + self.node_id_texture = self.glctx.texture( + (self.window_width, self.window_height), components=4, dtype="f4" + ) + self.node_id_depth = self.glctx.depth_renderbuffer( + size=(self.window_width, self.window_height) + ) + self.node_id_fbo = self.glctx.framebuffer( + color_attachments=[self.node_id_texture], + depth_attachment=self.node_id_depth, + ) + self.node_id_texture_np = np.zeros( + (self.window_height, self.window_width, 4), dtype=np.float32 + ) + + # Bind the offscreen framebuffer + self.node_id_fbo.use() + self.glctx.clear(0, 0, 0, 0) + + # Render nodes + if self.node_id_vao: + self.node_id_prog["mvp"].write(mvp.to_bytes()) + self.node_id_prog["scale"].write(np.float32(self.node_scale).tobytes()) + self.node_id_vao.render(moderngl.TRIANGLES) + + # Revert to default framebuffer + self.glctx.screen.use() + self.node_id_texture.read_into(self.node_id_texture_np.data) + + def render(self): + """Render the graph""" + # Clear screen + self.glctx.clear(*self.background_color) + + if not self.graph: + return + + # Enable blending for transparency + self.glctx.enable(moderngl.BLEND) + self.glctx.blend_func = moderngl.SRC_ALPHA, moderngl.ONE_MINUS_SRC_ALPHA + + # Update view and projection matrices + self.update_view_proj_matrix() + mvp = self.proj_matrix * self.view_matrix + + # Render edges first (under nodes) + if self.edge_vao: + self.edge_prog["mvp"].write(mvp.to_bytes()) + self.edge_prog["edge_width"].value = ( + float(self.edge_width) * 2.0 + ) # Double the width for better visibility + self.edge_prog["viewport_size"].value = ( + float(self.window_width), + float(self.window_height), + ) + self.edge_vao.render(moderngl.LINES) + + # Render nodes + if self.node_vao: + self.node_prog["mvp"].write(mvp.to_bytes()) + self.node_prog["camera"].write(self.position.to_bytes()) + self.node_prog["selected_node"].write( + np.int32(self.selected_node.idx).tobytes() + if self.selected_node + else np.int32(-1).tobytes() + ) + self.node_prog["highlighted_node"].write( + np.int32(self.highlighted_node.idx).tobytes() + if self.highlighted_node + else np.int32(-1).tobytes() + ) + self.node_prog["scale"].write(np.float32(self.node_scale).tobytes()) + self.node_vao.render(moderngl.TRIANGLES) + + self.glctx.disable(moderngl.BLEND) + + # Render id map + self.render_id_map(mvp) + + # Render labels if enabled + if self.show_labels: + # Save current font scale + original_scale = imgui.get_font_size() + + for node in self.nodes: + # Project node position to screen space + pos = mvp * glm.vec4( + node.position[0], node.position[1], node.position[2], 1.0 + ) + + # Check if node is behind camera + if pos.w > 0 and pos.w < self.label_culling_distance: + screen_x = (pos.x / pos.w + 1) * self.window_width / 2 + screen_y = (-pos.y / pos.w + 1) * self.window_height / 2 + + if self.is_node_visible_at( + (int(screen_x), int(screen_y)), node.idx + ): + # Set font scale + imgui.set_window_font_scale(float(self.label_size) * node.size) + + # Calculate label size + label_size = imgui.calc_text_size(node.label) + + # Adjust position to center the label + screen_x -= label_size.x / 2 + screen_y -= label_size.y / 2 + + # Set text color with calculated alpha + imgui.push_style_color(imgui.Col_.text, self.label_color) + + # Draw label using ImGui + imgui.set_cursor_pos((screen_x, screen_y)) + imgui.text(node.label) + + # Restore text color + imgui.pop_style_color() + + # Restore original font scale + imgui.set_window_font_scale(original_scale) + + def reset_view(self): + """Reset camera view to default""" + self.position = glm.vec3(0.0, -10.0, 0.0) + self.front = glm.vec3(0.0, 1.0, 0.0) + self.yaw = 90.0 + self.pitch = 0.0 + + +def generate_colors(n: int) -> List[glm.vec3]: + """Generate n distinct colors using HSV color space""" + colors = [] + for i in range(n): + # Use golden ratio to generate well-distributed hues + hue = (i * 0.618033988749895) % 1.0 + # Fixed saturation and value for vibrant colors + saturation = 0.8 + value = 0.95 + # Convert HSV to RGB + rgb = colorsys.hsv_to_rgb(hue, saturation, value) + # Add alpha channel + colors.append(glm.vec3(rgb)) + return colors + + +def show_file_dialog() -> Optional[str]: + """Show a file dialog for selecting GraphML files""" + root = tk.Tk() + root.withdraw() # Hide the main window + file_path = filedialog.askopenfilename( + title="Select GraphML File", + filetypes=[("GraphML files", "*.graphml"), ("All files", "*.*")], + ) + root.destroy() + return file_path if file_path else None + + +def create_sphere(sectors: int = 32, rings: int = 16) -> Tuple: + """ + Creates a sphere. + """ + R = 1.0 / (rings - 1) + S = 1.0 / (sectors - 1) + + # Use those names as normals and uvs are part of the API + vertices_l = [0.0] * (rings * sectors * 3) + # normals_l = [0.0] * (rings * sectors * 3) + uvs_l = [0.0] * (rings * sectors * 2) + + v, n, t = 0, 0, 0 + for r in range(rings): + for s in range(sectors): + y = np.sin(-np.pi / 2 + np.pi * r * R) + x = np.cos(2 * np.pi * s * S) * np.sin(np.pi * r * R) + z = np.sin(2 * np.pi * s * S) * np.sin(np.pi * r * R) + + uvs_l[t] = s * S + uvs_l[t + 1] = r * R + + vertices_l[v] = x + vertices_l[v + 1] = y + vertices_l[v + 2] = z + + t += 2 + v += 3 + n += 3 + + indices = [0] * rings * sectors * 6 + i = 0 + for r in range(rings - 1): + for s in range(sectors - 1): + indices[i] = r * sectors + s + indices[i + 1] = (r + 1) * sectors + (s + 1) + indices[i + 2] = r * sectors + (s + 1) + + indices[i + 3] = r * sectors + s + indices[i + 4] = (r + 1) * sectors + s + indices[i + 5] = (r + 1) * sectors + (s + 1) + i += 6 + + vbo_vertices = np.array(vertices_l, dtype=np.float32) + vbo_elements = np.array(indices, dtype=np.uint32) + + return (vbo_vertices, vbo_elements) + + +def main(): + """Main application entry point""" + viewer = GraphViewer() + + def gui(): + if not viewer.initialized: + viewer.setup() + + # Handle keyboard and mouse input + viewer.handle_keyboard_input() + viewer.handle_mouse_interaction() + + style = imgui.get_style() + window_bg_color = style.color_(imgui.Col_.window_bg.value) + + window_bg_color.w = 0.8 + style.set_color_(imgui.Col_.window_bg.value, window_bg_color) + + # Main control window + imgui.begin("Graph Controls") + + if imgui.button("Load GraphML"): + filepath = show_file_dialog() + if filepath: + viewer.load_file(filepath) + + # Show error message if loading failed + if viewer.show_load_error: + imgui.push_style_color(imgui.Col_.text, (1.0, 0.0, 0.0, 1.0)) + imgui.text(f"Error loading file: {viewer.error_message}") + imgui.pop_style_color() + + imgui.separator() + + # Camera controls help + imgui.text("Camera Controls:") + imgui.bullet_text("Hold Right Mouse - Look around") + imgui.bullet_text("W/S - Move forward/backward") + imgui.bullet_text("A/D - Move left/right") + imgui.bullet_text("Q/E - Move up/down") + imgui.bullet_text("Left Mouse - Select node") + imgui.bullet_text("Wheel - Change the movement speed") + + imgui.separator() + + # Camera settings + _, viewer.move_speed = imgui.slider_float( + "Movement Speed", viewer.move_speed, 0.01, 2.0 + ) + _, viewer.mouse_sensitivity = imgui.slider_float( + "Mouse Sensitivity", viewer.mouse_sensitivity, 0.01, 0.5 + ) + + imgui.separator() + + imgui.begin_horizontal("buttons") + + if imgui.button("Reset Camera"): + viewer.reset_view() + + if imgui.button("Update Layout") and viewer.graph: + viewer.update_layout() + + # if imgui.button("Save Node ID Texture"): + # viewer.save_node_id_texture_to_png("node_id_texture.png") + + imgui.end_horizontal() + + imgui.end() + + # Render node details window if a node is selected + viewer.render_node_details() + + # Render graph settings window + viewer.render_settings() + + window_bg_color.w = 0.0 + style.set_color_(imgui.Col_.window_bg.value, window_bg_color) + + # Render the graph + viewer.render() + + runner_params = hello_imgui.RunnerParams() + runner_params.app_window_params.window_geometry.size = ( + viewer.window_width, + viewer.window_height, + ) + runner_params.app_window_params.window_title = "3D GraphML Viewer" + runner_params.callbacks.show_gui = gui + addons = immapp.AddOnsParams() + addons.with_markdown = True + + def load_font(): + io = imgui.get_io() + io.fonts.add_font_default() + + # Load font for Chinese character support + # You will need to provide it yourself, or use another font. + font_filename = CUSTOM_FONT + + if not os.path.exists("assets/" + font_filename): + return + + # Get the full Chinese character range for ImGui + # This includes all Chinese characters supported by ImGui + cn_glyph_ranges_imgui = imgui.get_io().fonts.get_glyph_ranges_chinese_full() + + # Set up font loading parameters with Chinese character support + font_loading_params = hello_imgui.FontLoadingParams() + font_loading_params.glyph_ranges = hello_imgui.translate_common_glyph_ranges( + cn_glyph_ranges_imgui + ) + custom_font = hello_imgui.load_font(font_filename, 16.0, font_loading_params) + + # # Merge with default font + # font_config = imgui.ImFontConfig() + # font_config.merge_mode = True + # custom_font = io.fonts.add_font_from_file_ttf( + # filename= "assets/" + font_filename, + # size_pixels=16.0, + # font_cfg=font_config, + # glyph_ranges_as_int_list=cn_glyph_ranges_imgui, + # ) + + io.fonts.tex_desired_width = 4096 # Larger texture for better CJK font quality + io.font_default = custom_font + + runner_params.callbacks.load_additional_fonts = load_font + + immapp.run(runner_params, addons) + + +if __name__ == "__main__": + main() diff --git a/extra/VisualizationTool/requirements.txt b/extra/VisualizationTool/requirements.txt new file mode 100644 index 00000000..59f45627 --- /dev/null +++ b/extra/VisualizationTool/requirements.txt @@ -0,0 +1,8 @@ +imgui_bundle +moderngl +networkx +numpy +pyglm +python-louvain +scipy +tk diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 288ff79c..66b3a10c 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -82,14 +82,19 @@ We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate Light A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include: +``` /local /global /hybrid /naive /mix +/bypass +``` For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。 +"/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the LightRAG query results. (If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix) + #### Connect Open WebUI to LightRAG After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index fa192f9c..2680a2ec 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -599,6 +599,7 @@ class SearchMode(str, Enum): global_ = "global" hybrid = "hybrid" mix = "mix" + bypass = "bypass" class OllamaMessage(BaseModel): @@ -1476,7 +1477,7 @@ def create_app(args): @app.get("/api/tags") async def get_tags(): - """Get available models""" + """Retrun available models acting as an Ollama server""" return OllamaTagResponse( models=[ { @@ -1507,6 +1508,7 @@ def create_app(args): "/naive ": SearchMode.naive, "/hybrid ": SearchMode.hybrid, "/mix ": SearchMode.mix, + "/bypass ": SearchMode.bypass, } for prefix, mode in mode_map.items(): @@ -1519,7 +1521,7 @@ def create_app(args): @app.post("/api/generate") async def generate(raw_request: Request, request: OllamaGenerateRequest): - """Handle generate completion requests + """Handle generate completion requests acting as an Ollama model For compatiblity purpuse, the request is not processed by LightRAG, and will be handled by underlying LLM model. """ @@ -1661,7 +1663,7 @@ def create_app(args): @app.post("/api/chat") async def chat(raw_request: Request, request: OllamaChatRequest): - """Process chat completion requests. + """Process chat completion requests acting as an Ollama model Routes user queries through LightRAG by selecting query mode based on prefix indicators. Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. """ @@ -1700,9 +1702,20 @@ def create_app(args): if request.stream: from fastapi.responses import StreamingResponse - response = await rag.aquery( # Need await to get async generator - cleaned_query, param=query_param - ) + # Determine if the request is prefix with "/bypass" + if mode == SearchMode.bypass: + if request.system: + rag.llm_model_kwargs["system_prompt"] = request.system + response = await rag.llm_model_func( + cleaned_query, + stream=True, + history_messages=conversation_history, + **rag.llm_model_kwargs, + ) + else: + response = await rag.aquery( # Need await to get async generator + cleaned_query, param=query_param + ) async def stream_generator(): try: @@ -1804,16 +1817,19 @@ def create_app(args): else: first_chunk_time = time.time_ns() - # Determine if the request is from Open WebUI's session title and session keyword generation task + # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task match_result = re.search( r"\n\nUSER:", cleaned_query, re.MULTILINE ) - if match_result: + if match_result or mode == SearchMode.bypass: if request.system: rag.llm_model_kwargs["system_prompt"] = request.system response_text = await rag.llm_model_func( - cleaned_query, stream=False, **rag.llm_model_kwargs + cleaned_query, + stream=False, + history_messages=conversation_history, + **rag.llm_model_kwargs, ) else: response_text = await rag.aquery(cleaned_query, param=query_param) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index f83c9e38..3014f737 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -126,8 +126,10 @@ class LightRAG: vector_storage: str = field(default="NanoVectorDBStorage") graph_storage: str = field(default="NetworkXStorage") + # logging current_log_level = logger.level log_level: str = field(default=current_log_level) + log_dir: str = field(default=os.getcwd()) # text chunking chunk_token_size: int = 1200 @@ -182,10 +184,11 @@ class LightRAG: chunking_func_kwargs: dict = field(default_factory=dict) def __post_init__(self): - log_file = os.path.join("lightrag.log") + os.makedirs(self.log_dir, exist_ok=True) + log_file = os.path.join(self.log_dir, "lightrag.log") set_logger(log_file) - logger.setLevel(self.log_level) + logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") @@ -649,7 +652,7 @@ class LightRAG: } chunk_cnt += len(chunks) await self.text_chunks.upsert(chunks) - await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED) + await self.text_chunks.change_status(doc_id, DocStatus.PROCESSING) try: # Store chunks in vector database diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index d7c03641..8cb633ba 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -123,18 +123,20 @@ class TestStats: def make_request( - url: str, data: Dict[str, Any], stream: bool = False + url: str, data: Dict[str, Any], stream: bool = False, check_status: bool = True ) -> requests.Response: """Send an HTTP request with retry mechanism Args: url: Request URL data: Request data stream: Whether to use streaming response + check_status: Whether to check HTTP status code (default: True) Returns: requests.Response: Response object Raises: requests.exceptions.RequestException: Request failed after all retries + requests.exceptions.HTTPError: HTTP status code is not 200 (when check_status is True) """ server_config = CONFIG["server"] max_retries = server_config["max_retries"] @@ -144,6 +146,8 @@ def make_request( for attempt in range(max_retries): try: response = requests.post(url, json=data, stream=stream, timeout=timeout) + if check_status and response.status_code != 200: + response.raise_for_status() return response except requests.exceptions.RequestException as e: if attempt == max_retries - 1: # Last retry @@ -433,7 +437,7 @@ def test_stream_error_handling() -> None: if OutputControl.is_verbose(): print("\n--- Testing empty message list (streaming) ---") data = create_error_test_data("empty_messages") - response = make_request(url, data, stream=True) + response = make_request(url, data, stream=True, check_status=False) print(f"Status code: {response.status_code}") if response.status_code != 200: print_json_response(response.json(), "Error message") @@ -443,7 +447,7 @@ def test_stream_error_handling() -> None: if OutputControl.is_verbose(): print("\n--- Testing invalid role field (streaming) ---") data = create_error_test_data("invalid_role") - response = make_request(url, data, stream=True) + response = make_request(url, data, stream=True, check_status=False) print(f"Status code: {response.status_code}") if response.status_code != 200: print_json_response(response.json(), "Error message") @@ -453,7 +457,7 @@ def test_stream_error_handling() -> None: if OutputControl.is_verbose(): print("\n--- Testing missing content field (streaming) ---") data = create_error_test_data("missing_content") - response = make_request(url, data, stream=True) + response = make_request(url, data, stream=True, check_status=False) print(f"Status code: {response.status_code}") if response.status_code != 200: print_json_response(response.json(), "Error message") @@ -484,7 +488,7 @@ def test_error_handling() -> None: print("\n--- Testing empty message list ---") data = create_error_test_data("empty_messages") data["stream"] = False # Change to non-streaming mode - response = make_request(url, data) + response = make_request(url, data, check_status=False) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") @@ -493,7 +497,7 @@ def test_error_handling() -> None: print("\n--- Testing invalid role field ---") data = create_error_test_data("invalid_role") data["stream"] = False # Change to non-streaming mode - response = make_request(url, data) + response = make_request(url, data, check_status=False) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") @@ -502,7 +506,7 @@ def test_error_handling() -> None: print("\n--- Testing missing content field ---") data = create_error_test_data("missing_content") data["stream"] = False # Change to non-streaming mode - response = make_request(url, data) + response = make_request(url, data, check_status=False) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") @@ -609,7 +613,7 @@ def test_generate_error_handling() -> None: if OutputControl.is_verbose(): print("\n=== Testing empty prompt ===") data = create_generate_request_data("", stream=False) - response = make_request(url, data) + response = make_request(url, data, check_status=False) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") @@ -621,7 +625,7 @@ def test_generate_error_handling() -> None: options={"invalid_option": "value"}, stream=False, ) - response = make_request(url, data) + response = make_request(url, data, check_status=False) print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") @@ -642,6 +646,8 @@ def test_generate_concurrent() -> None: data = create_generate_request_data(prompt, stream=False) try: async with session.post(url, json=data) as response: + if response.status != 200: + response.raise_for_status() return await response.json() except Exception as e: return {"error": str(e)} @@ -742,7 +748,7 @@ Configuration file (config.json): nargs="+", choices=list(get_test_cases().keys()) + ["all"], default=["all"], - help="Test cases to run, options: %(choices)s. Use 'all' to run all tests", + help="Test cases to run, options: %(choices)s. Use 'all' to run all tests (except error tests)", ) return parser.parse_args() @@ -766,21 +772,18 @@ if __name__ == "__main__": try: if "all" in args.tests: - # Run all tests + # Run all tests except error handling tests if OutputControl.is_verbose(): print("\n【Chat API Tests】") run_test(test_non_stream_chat, "Non-streaming Chat Test") run_test(test_stream_chat, "Streaming Chat Test") run_test(test_query_modes, "Chat Query Mode Test") - run_test(test_error_handling, "Chat Error Handling Test") - run_test(test_stream_error_handling, "Chat Streaming Error Test") if OutputControl.is_verbose(): print("\n【Generate API Tests】") run_test(test_non_stream_generate, "Non-streaming Generate Test") run_test(test_stream_generate, "Streaming Generate Test") run_test(test_generate_with_system, "Generate with System Prompt Test") - run_test(test_generate_error_handling, "Generate Error Handling Test") run_test(test_generate_concurrent, "Generate Concurrent Test") else: # Run specified tests