Merge branch 'HKUDS:main' into main
This commit is contained in:
@@ -799,6 +799,7 @@ if __name__ == "__main__":
|
||||
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
|
||||
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
|
||||
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
|
||||
|**log\_dir** | `str` | Directory to store logs. | `./` |
|
||||
|
||||
### Error Handling
|
||||
|
||||
|
@@ -1,666 +0,0 @@
|
||||
"""
|
||||
3D GraphML Viewer
|
||||
Author: LoLLMs
|
||||
Description: An interactive 3D GraphML viewer using PyQt5 and pyqtgraph
|
||||
Version: 2.2
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Dict, List, Any
|
||||
import pipmaster as pm
|
||||
|
||||
# Install all required dependencies
|
||||
REQUIRED_PACKAGES = [
|
||||
"PyQt5",
|
||||
"pyqtgraph",
|
||||
"numpy",
|
||||
"PyOpenGL",
|
||||
"PyOpenGL_accelerate",
|
||||
"networkx",
|
||||
"matplotlib",
|
||||
"python-louvain",
|
||||
"ascii_colors",
|
||||
]
|
||||
|
||||
|
||||
def setup_dependencies():
|
||||
"""
|
||||
Ensure all required packages are installed
|
||||
"""
|
||||
for package in REQUIRED_PACKAGES:
|
||||
if not pm.is_installed(package):
|
||||
print(f"Installing {package}...")
|
||||
pm.install(package)
|
||||
|
||||
|
||||
# Install dependencies
|
||||
setup_dependencies()
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import community
|
||||
from PyQt5.QtWidgets import (
|
||||
QApplication,
|
||||
QMainWindow,
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QPushButton,
|
||||
QFileDialog,
|
||||
QLabel,
|
||||
QMessageBox,
|
||||
QSpinBox,
|
||||
QComboBox,
|
||||
QCheckBox,
|
||||
QTableWidget,
|
||||
QTableWidgetItem,
|
||||
QSplitter,
|
||||
QDockWidget,
|
||||
QTextEdit,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
import pyqtgraph.opengl as gl
|
||||
from ascii_colors import trace_exception
|
||||
|
||||
|
||||
class Point:
|
||||
"""Simple point class to handle coordinates"""
|
||||
|
||||
def __init__(self, x: float, y: float):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
|
||||
class NodeState:
|
||||
"""Data class for node visual state"""
|
||||
|
||||
NORMAL_SCALE = 1.0
|
||||
HOVER_SCALE = 1.2
|
||||
SELECTED_SCALE = 1.3
|
||||
|
||||
NORMAL_OPACITY = 0.8
|
||||
HOVER_OPACITY = 1.0
|
||||
SELECTED_OPACITY = 1.0
|
||||
|
||||
# Increase base node size (was 0.05)
|
||||
BASE_SIZE = 0.2
|
||||
|
||||
SELECTED_COLOR = (1.0, 1.0, 0.0, 1.0)
|
||||
HOVER_COLOR = (1.0, 0.8, 0.0, 1.0)
|
||||
|
||||
|
||||
class Node3D:
|
||||
"""Class representing a 3D node in the graph"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
position: np.ndarray,
|
||||
color: Tuple[float, float, float, float],
|
||||
label: str,
|
||||
node_type: str,
|
||||
size: float,
|
||||
):
|
||||
self.position = position
|
||||
self.base_color = color
|
||||
self.color = color
|
||||
self.label = label
|
||||
self.node_type = node_type
|
||||
self.size = size
|
||||
self.mesh_item = None
|
||||
self.label_item = None
|
||||
self.is_highlighted = False
|
||||
self.is_selected = False
|
||||
|
||||
def highlight(self):
|
||||
"""Highlight the node"""
|
||||
if not self.is_highlighted and not self.is_selected:
|
||||
self.color = NodeState.HOVER_COLOR
|
||||
self.update_appearance(NodeState.HOVER_SCALE)
|
||||
self.is_highlighted = True
|
||||
|
||||
def unhighlight(self):
|
||||
"""Remove highlight from node"""
|
||||
if self.is_highlighted and not self.is_selected:
|
||||
self.color = self.base_color
|
||||
self.update_appearance(NodeState.NORMAL_SCALE)
|
||||
self.is_highlighted = False
|
||||
|
||||
def select(self):
|
||||
"""Select the node"""
|
||||
self.is_selected = True
|
||||
self.color = NodeState.SELECTED_COLOR
|
||||
self.update_appearance(NodeState.SELECTED_SCALE)
|
||||
|
||||
def deselect(self):
|
||||
"""Deselect the node"""
|
||||
self.is_selected = False
|
||||
self.color = self.base_color
|
||||
self.update_appearance(NodeState.NORMAL_SCALE)
|
||||
|
||||
def update_appearance(self, scale: float = 1.0):
|
||||
"""Update node visual appearance"""
|
||||
if self.mesh_item:
|
||||
self.mesh_item.setData(
|
||||
color=np.array([self.color]), size=np.array([self.size * scale * 5])
|
||||
)
|
||||
|
||||
|
||||
class NodeDetailsWidget(QWidget):
|
||||
"""Widget to display node details"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
"""Initialize the UI"""
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
# Properties text edit
|
||||
self.properties = QTextEdit()
|
||||
self.properties.setReadOnly(True)
|
||||
layout.addWidget(QLabel("Properties:"))
|
||||
layout.addWidget(self.properties)
|
||||
|
||||
# Connections table
|
||||
self.connections = QTableWidget()
|
||||
self.connections.setColumnCount(3)
|
||||
self.connections.setHorizontalHeaderLabels(
|
||||
["Connected Node", "Relationship", "Direction"]
|
||||
)
|
||||
layout.addWidget(QLabel("Connections:"))
|
||||
layout.addWidget(self.connections)
|
||||
|
||||
def update_node_info(self, node_data: Dict, connections: Dict):
|
||||
"""Update the display with node information"""
|
||||
# Update properties
|
||||
properties_text = "Node Properties:\n"
|
||||
for key, value in node_data.items():
|
||||
properties_text += f"{key}: {value}\n"
|
||||
self.properties.setText(properties_text)
|
||||
|
||||
# Update connections
|
||||
self.connections.setRowCount(len(connections))
|
||||
for idx, (neighbor, edge_data) in enumerate(connections.items()):
|
||||
self.connections.setItem(idx, 0, QTableWidgetItem(str(neighbor)))
|
||||
self.connections.setItem(
|
||||
idx, 1, QTableWidgetItem(edge_data.get("relationship", "unknown"))
|
||||
)
|
||||
self.connections.setItem(idx, 2, QTableWidgetItem("outgoing"))
|
||||
|
||||
|
||||
class GraphMLViewer3D(QMainWindow):
|
||||
"""Main window class for 3D GraphML visualization"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.graph: Optional[nx.Graph] = None
|
||||
self.nodes: Dict[str, Node3D] = {}
|
||||
self.edges: List[gl.GLLinePlotItem] = []
|
||||
self.edge_labels: List[gl.GLTextItem] = []
|
||||
self.selected_node = None
|
||||
self.communities = None
|
||||
self.community_colors = None
|
||||
|
||||
self.mouse_pos_last = None
|
||||
self.mouse_buttons_pressed = set()
|
||||
self.distance = 20 # Initial camera distance
|
||||
self.center = np.array([0, 0, 0]) # View center point
|
||||
self.elevation = 30 # Initial camera elevation
|
||||
self.azimuth = 45 # Initial camera azimuth
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
"""Initialize the user interface"""
|
||||
self.setWindowTitle("3D GraphML Viewer")
|
||||
self.setGeometry(100, 100, 1600, 900)
|
||||
|
||||
# Create main splitter
|
||||
self.main_splitter = QSplitter(Qt.Horizontal)
|
||||
self.setCentralWidget(self.main_splitter)
|
||||
|
||||
# Create left panel for 3D view
|
||||
left_widget = QWidget()
|
||||
left_layout = QVBoxLayout(left_widget)
|
||||
|
||||
# Create controls
|
||||
self.create_toolbar(left_layout)
|
||||
|
||||
# Create 3D view
|
||||
self.view = gl.GLViewWidget()
|
||||
self.view.setMouseTracking(True)
|
||||
|
||||
# Connect mouse events
|
||||
self.view.mousePressEvent = self.on_mouse_press
|
||||
self.view.mouseMoveEvent = self.on_mouse_move
|
||||
left_layout.addWidget(self.view)
|
||||
|
||||
self.main_splitter.addWidget(left_widget)
|
||||
|
||||
# Create details widget
|
||||
self.details = NodeDetailsWidget()
|
||||
details_dock = QDockWidget("Node Details", self)
|
||||
details_dock.setWidget(self.details)
|
||||
self.addDockWidget(Qt.RightDockWidgetArea, details_dock)
|
||||
|
||||
# Add status bar
|
||||
self.statusBar().showMessage("Ready")
|
||||
|
||||
# Add initial grid
|
||||
grid = gl.GLGridItem()
|
||||
grid.setSize(x=20, y=20, z=20)
|
||||
grid.setSpacing(x=1, y=1, z=1)
|
||||
self.view.addItem(grid)
|
||||
|
||||
# Set initial camera position
|
||||
self.view.setCameraPosition(
|
||||
distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
|
||||
)
|
||||
|
||||
# Connect all mouse events
|
||||
self.view.mousePressEvent = self.on_mouse_press
|
||||
self.view.mouseReleaseEvent = self.on_mouse_release
|
||||
self.view.mouseMoveEvent = self.on_mouse_move
|
||||
self.view.wheelEvent = self.on_mouse_wheel
|
||||
|
||||
def calculate_node_sizes(self) -> Dict[str, float]:
|
||||
"""Calculate node sizes based on number of connections"""
|
||||
if not self.graph:
|
||||
return {}
|
||||
|
||||
# Get degree (number of connections) for each node
|
||||
degrees = dict(self.graph.degree())
|
||||
|
||||
# Calculate size scaling
|
||||
max_degree = max(degrees.values())
|
||||
min_degree = min(degrees.values())
|
||||
|
||||
# Normalize sizes between 0.5 and 2.0
|
||||
sizes = {}
|
||||
for node, degree in degrees.items():
|
||||
if max_degree == min_degree:
|
||||
sizes[node] = 1.0
|
||||
else:
|
||||
# Normalize and scale size
|
||||
normalized = (degree - min_degree) / (max_degree - min_degree)
|
||||
sizes[node] = 0.5 + normalized * 1.5
|
||||
|
||||
return sizes
|
||||
|
||||
def create_toolbar(self, layout: QVBoxLayout):
|
||||
"""Create the toolbar with controls"""
|
||||
toolbar = QHBoxLayout()
|
||||
|
||||
# Load button
|
||||
load_btn = QPushButton("Load GraphML")
|
||||
load_btn.clicked.connect(self.load_graphml)
|
||||
toolbar.addWidget(load_btn)
|
||||
|
||||
# Reset view button
|
||||
reset_btn = QPushButton("Reset View")
|
||||
reset_btn.clicked.connect(lambda: self.view.setCameraPosition(distance=20))
|
||||
toolbar.addWidget(reset_btn)
|
||||
|
||||
# Layout selector
|
||||
self.layout_combo = QComboBox()
|
||||
self.layout_combo.addItems(["Spring", "Circular", "Shell", "Random"])
|
||||
self.layout_combo.currentTextChanged.connect(self.refresh_layout)
|
||||
toolbar.addWidget(QLabel("Layout:"))
|
||||
toolbar.addWidget(self.layout_combo)
|
||||
|
||||
# Node size control
|
||||
self.node_size = QSpinBox()
|
||||
self.node_size.setRange(1, 100)
|
||||
self.node_size.setValue(20)
|
||||
self.node_size.valueChanged.connect(self.refresh_layout)
|
||||
toolbar.addWidget(QLabel("Node Size:"))
|
||||
toolbar.addWidget(self.node_size)
|
||||
|
||||
# Show labels checkbox
|
||||
self.show_labels = QCheckBox("Show Labels")
|
||||
self.show_labels.setChecked(True)
|
||||
self.show_labels.stateChanged.connect(self.refresh_layout)
|
||||
toolbar.addWidget(self.show_labels)
|
||||
|
||||
layout.addLayout(toolbar)
|
||||
# Reset view button
|
||||
reset_btn = QPushButton("Reset View")
|
||||
reset_btn.clicked.connect(self.reset_view) # Use the new reset_view method
|
||||
toolbar.addWidget(reset_btn)
|
||||
|
||||
def load_graphml(self) -> None:
|
||||
"""Load and visualize a GraphML file"""
|
||||
try:
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self, "Open GraphML file", "", "GraphML files (*.graphml)"
|
||||
)
|
||||
|
||||
if file_path:
|
||||
self.graph = nx.read_graphml(Path(file_path))
|
||||
self.refresh_layout()
|
||||
self.statusBar().showMessage(f"Loaded: {file_path}")
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
QMessageBox.critical(self, "Error", f"Error loading file: {str(e)}")
|
||||
|
||||
def calculate_layout(self) -> Dict[str, np.ndarray]:
|
||||
"""Calculate node positions based on selected layout"""
|
||||
layout_type = self.layout_combo.currentText().lower()
|
||||
|
||||
# Detect communities for coloring
|
||||
self.communities = community.best_partition(self.graph)
|
||||
num_communities = len(set(self.communities.values()))
|
||||
self.community_colors = plt.cm.rainbow(np.linspace(0, 1, num_communities))
|
||||
|
||||
if layout_type == "spring":
|
||||
pos = nx.spring_layout(
|
||||
self.graph, dim=3, k=2.0, iterations=100, weight=None
|
||||
)
|
||||
elif layout_type == "circular":
|
||||
pos_2d = nx.circular_layout(self.graph)
|
||||
pos = {node: np.array([x, y, 0.0]) for node, (x, y) in pos_2d.items()}
|
||||
elif layout_type == "shell":
|
||||
comm_lists = [[] for _ in range(num_communities)]
|
||||
for node, comm in self.communities.items():
|
||||
comm_lists[comm].append(node)
|
||||
pos_2d = nx.shell_layout(self.graph, comm_lists)
|
||||
pos = {node: np.array([x, y, 0.0]) for node, (x, y) in pos_2d.items()}
|
||||
else: # random
|
||||
pos = {node: np.random.rand(3) * 2 - 1 for node in self.graph.nodes()}
|
||||
|
||||
# Scale positions
|
||||
positions = np.array(list(pos.values()))
|
||||
if len(positions) > 0:
|
||||
scale = 10.0 / max(1.0, np.max(np.abs(positions)))
|
||||
return {node: coords * scale for node, coords in pos.items()}
|
||||
return pos
|
||||
|
||||
def get_node_color(self, node_id: str) -> Tuple[float, float, float, float]:
|
||||
"""Get RGBA color based on community"""
|
||||
if hasattr(self, "communities") and node_id in self.communities:
|
||||
comm_id = self.communities[node_id]
|
||||
color = self.community_colors[comm_id]
|
||||
return tuple(color)
|
||||
return (0.5, 0.5, 0.5, 0.8)
|
||||
|
||||
def create_node(self, node_id: str, position: np.ndarray, node_type: str) -> Node3D:
|
||||
"""Create a 3D node with interaction capabilities"""
|
||||
color = self.get_node_color(node_id)
|
||||
|
||||
# Get size multiplier based on connections
|
||||
size_multiplier = self.node_sizes.get(node_id, 1.0)
|
||||
size = NodeState.BASE_SIZE * self.node_size.value() / 50.0 * size_multiplier
|
||||
|
||||
node = Node3D(position, color, str(node_id), node_type, size)
|
||||
|
||||
node.mesh_item = gl.GLScatterPlotItem(
|
||||
pos=np.array([position]),
|
||||
size=np.array([size * 8]),
|
||||
color=np.array([color]),
|
||||
pxMode=False,
|
||||
)
|
||||
|
||||
# Enable picking and set node ID
|
||||
node.mesh_item.setGLOptions("translucent")
|
||||
node.mesh_item.node_id = node_id
|
||||
|
||||
if self.show_labels.isChecked():
|
||||
node.label_item = gl.GLTextItem(
|
||||
pos=position,
|
||||
text=str(node_id),
|
||||
color=(1, 1, 1, 1),
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
def mapToView(self, pos) -> Point:
|
||||
"""Convert screen coordinates to world coordinates"""
|
||||
# Get the viewport size
|
||||
width = self.view.width()
|
||||
height = self.view.height()
|
||||
|
||||
# Normalize coordinates
|
||||
x = (pos.x() / width - 0.5) * 20 # Scale factor of 20 matches the grid size
|
||||
y = -(pos.y() / height - 0.5) * 20
|
||||
|
||||
return Point(x, y)
|
||||
|
||||
def on_mouse_move(self, event):
|
||||
"""Handle mouse movement for pan, rotate and hover"""
|
||||
if self.mouse_pos_last is None:
|
||||
self.mouse_pos_last = event.pos()
|
||||
return
|
||||
|
||||
pos = event.pos()
|
||||
dx = pos.x() - self.mouse_pos_last.x()
|
||||
dy = pos.y() - self.mouse_pos_last.y()
|
||||
|
||||
# Handle right button drag for panning
|
||||
if Qt.RightButton in self.mouse_buttons_pressed:
|
||||
# Scale the pan amount based on view distance
|
||||
scale = self.distance / 1000.0
|
||||
|
||||
# Calculate pan in view coordinates
|
||||
right = np.cross([0, 0, 1], self.view.cameraPosition())
|
||||
right = right / np.linalg.norm(right)
|
||||
up = np.cross(self.view.cameraPosition(), right)
|
||||
up = up / np.linalg.norm(up)
|
||||
|
||||
pan = -right * dx * scale + up * dy * scale
|
||||
self.center += pan
|
||||
self.view.pan(dx, dy, 0)
|
||||
|
||||
# Handle middle button drag for rotation
|
||||
elif Qt.MiddleButton in self.mouse_buttons_pressed:
|
||||
self.azimuth += dx * 0.5 # Adjust rotation speed as needed
|
||||
self.elevation -= dy * 0.5
|
||||
|
||||
# Clamp elevation to prevent gimbal lock
|
||||
self.elevation = np.clip(self.elevation, -89, 89)
|
||||
|
||||
self.view.setCameraPosition(
|
||||
distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
|
||||
)
|
||||
|
||||
# Handle hover events when no buttons are pressed
|
||||
elif not self.mouse_buttons_pressed:
|
||||
# Get the mouse position in world coordinates
|
||||
mouse_pos = self.mapToView(pos)
|
||||
|
||||
# Check for hover
|
||||
min_dist = float("inf")
|
||||
hovered_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
# Calculate distance to mouse in world coordinates
|
||||
dx = mouse_pos.x - node.position[0]
|
||||
dy = mouse_pos.y - node.position[1]
|
||||
dist = np.sqrt(dx * dx + dy * dy)
|
||||
|
||||
if dist < min_dist and dist < 0.5: # Adjust threshold as needed
|
||||
min_dist = dist
|
||||
hovered_node = node_id
|
||||
|
||||
# Update hover states
|
||||
for node_id, node in self.nodes.items():
|
||||
if node_id == hovered_node:
|
||||
node.highlight()
|
||||
self.statusBar().showMessage(f"Node: {node_id} ({node.node_type})")
|
||||
else:
|
||||
if not node.is_selected:
|
||||
node.unhighlight()
|
||||
self.mouse_pos_last = pos
|
||||
|
||||
def on_mouse_press(self, event):
|
||||
"""Handle mouse press events"""
|
||||
self.mouse_pos_last = event.pos()
|
||||
self.mouse_buttons_pressed.add(event.button())
|
||||
|
||||
# Handle left click for node selection
|
||||
if event.button() == Qt.LeftButton:
|
||||
pos = event.pos()
|
||||
mouse_pos = self.mapToView(pos)
|
||||
|
||||
# Find closest node
|
||||
min_dist = float("inf")
|
||||
clicked_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
dx = mouse_pos.x - node.position[0]
|
||||
dy = mouse_pos.y - node.position[1]
|
||||
dist = np.sqrt(dx * dx + dy * dy)
|
||||
|
||||
if dist < min_dist and dist < 0.5: # Adjust threshold as needed
|
||||
min_dist = dist
|
||||
clicked_node = node_id
|
||||
|
||||
# Handle selection
|
||||
if clicked_node:
|
||||
if self.selected_node and self.selected_node in self.nodes:
|
||||
self.nodes[self.selected_node].deselect()
|
||||
|
||||
self.nodes[clicked_node].select()
|
||||
self.selected_node = clicked_node
|
||||
|
||||
if self.graph:
|
||||
self.details.update_node_info(
|
||||
self.graph.nodes[clicked_node], self.graph[clicked_node]
|
||||
)
|
||||
|
||||
def on_mouse_release(self, event):
|
||||
"""Handle mouse release events"""
|
||||
self.mouse_buttons_pressed.discard(event.button())
|
||||
self.mouse_pos_last = None
|
||||
|
||||
def on_mouse_wheel(self, event):
|
||||
"""Handle mouse wheel for zooming"""
|
||||
delta = event.angleDelta().y()
|
||||
|
||||
# Adjust zoom speed based on current distance
|
||||
zoom_speed = self.distance / 100.0
|
||||
|
||||
# Update distance with limits
|
||||
self.distance -= delta * zoom_speed
|
||||
self.distance = np.clip(self.distance, 1.0, 100.0)
|
||||
|
||||
self.view.setCameraPosition(
|
||||
distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
|
||||
)
|
||||
|
||||
def reset_view(self):
|
||||
"""Reset camera to default position"""
|
||||
self.distance = 20
|
||||
self.elevation = 30
|
||||
self.azimuth = 45
|
||||
self.center = np.array([0, 0, 0])
|
||||
|
||||
self.view.setCameraPosition(
|
||||
distance=self.distance, elevation=self.elevation, azimuth=self.azimuth
|
||||
)
|
||||
|
||||
def create_edge(
|
||||
self,
|
||||
start_pos: np.ndarray,
|
||||
end_pos: np.ndarray,
|
||||
color: Tuple[float, float, float, float] = (0.3, 0.3, 0.3, 0.2),
|
||||
) -> gl.GLLinePlotItem:
|
||||
"""Create a 3D edge between nodes"""
|
||||
return gl.GLLinePlotItem(
|
||||
pos=np.array([start_pos, end_pos]),
|
||||
color=color,
|
||||
width=1,
|
||||
antialias=True,
|
||||
mode="lines",
|
||||
)
|
||||
|
||||
def handle_node_hover(self, event: Any, node_id: str) -> None:
|
||||
"""Handle node hover events"""
|
||||
if node_id in self.nodes:
|
||||
node = self.nodes[node_id]
|
||||
if event.isEnter():
|
||||
node.highlight()
|
||||
self.statusBar().showMessage(f"Node: {node_id} ({node.node_type})")
|
||||
elif event.isExit():
|
||||
node.unhighlight()
|
||||
self.statusBar().showMessage("")
|
||||
|
||||
def handle_node_click(self, event: Any, node_id: str) -> None:
|
||||
"""Handle node click events"""
|
||||
if event.button() != Qt.LeftButton or node_id not in self.nodes:
|
||||
return
|
||||
|
||||
if self.selected_node and self.selected_node in self.nodes:
|
||||
self.nodes[self.selected_node].deselect()
|
||||
|
||||
node = self.nodes[node_id]
|
||||
node.select()
|
||||
self.selected_node = node_id
|
||||
|
||||
if self.graph:
|
||||
self.details.update_node_info(
|
||||
self.graph.nodes[node_id], self.graph[node_id]
|
||||
)
|
||||
|
||||
def refresh_layout(self) -> None:
|
||||
"""Refresh the graph visualization"""
|
||||
if not self.graph:
|
||||
return
|
||||
|
||||
self.positions = self.calculate_layout()
|
||||
self.node_sizes = self.calculate_node_sizes()
|
||||
|
||||
self.view.clear()
|
||||
self.nodes.clear()
|
||||
self.edges.clear()
|
||||
self.edge_labels.clear()
|
||||
|
||||
grid = gl.GLGridItem()
|
||||
grid.setSize(x=20, y=20, z=20)
|
||||
grid.setSpacing(x=1, y=1, z=1)
|
||||
self.view.addItem(grid)
|
||||
|
||||
positions = self.calculate_layout()
|
||||
|
||||
for node_id in self.graph.nodes():
|
||||
node_type = self.graph.nodes[node_id].get("type", "default")
|
||||
node = self.create_node(node_id, positions[node_id], node_type)
|
||||
|
||||
self.view.addItem(node.mesh_item)
|
||||
if node.label_item:
|
||||
self.view.addItem(node.label_item)
|
||||
|
||||
self.nodes[node_id] = node
|
||||
|
||||
for source, target in self.graph.edges():
|
||||
edge = self.create_edge(positions[source], positions[target])
|
||||
self.view.addItem(edge)
|
||||
self.edges.append(edge)
|
||||
|
||||
if self.show_labels.isChecked():
|
||||
mid_point = (positions[source] + positions[target]) / 2
|
||||
relationship = self.graph.edges[source, target].get("relationship", "")
|
||||
if relationship:
|
||||
label = gl.GLTextItem(
|
||||
pos=mid_point,
|
||||
text=relationship,
|
||||
color=(0.8, 0.8, 0.8, 0.8),
|
||||
)
|
||||
self.view.addItem(label)
|
||||
self.edge_labels.append(label)
|
||||
|
||||
|
||||
def main():
|
||||
"""Application entry point"""
|
||||
import sys
|
||||
|
||||
app = QApplication(sys.argv)
|
||||
viewer = GraphMLViewer3D()
|
||||
viewer.show()
|
||||
sys.exit(app.exec_())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
95
extra/VisualizationTool/README-zh.md
Normal file
95
extra/VisualizationTool/README-zh.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# 3D GraphML Viewer
|
||||
|
||||
一个基于 Dear ImGui 和 ModernGL 的交互式 3D 图可视化工具。
|
||||
|
||||
## 功能特点
|
||||
|
||||
- **3D 交互式可视化**: 使用 ModernGL 实现高性能的 3D 图形渲染
|
||||
- **多种布局算法**: 支持多种图布局方式
|
||||
- Spring 布局
|
||||
- Circular 布局
|
||||
- Shell 布局
|
||||
- Random 布局
|
||||
- **社区检测**: 支持图社区结构的自动检测和可视化
|
||||
- **交互控制**:
|
||||
- WASD + QE 键控制相机移动
|
||||
- 鼠标右键拖拽控制视角
|
||||
- 节点选择和高亮
|
||||
- 可调节节点大小和边宽度
|
||||
- 可控制标签显示
|
||||
- 可在节点的Connections间快速跳转
|
||||
- **社区检测**: 支持图社区结构的自动检测和可视化
|
||||
- **交互控制**:
|
||||
- WASD + QE 键控制相机移动
|
||||
- 鼠标右键拖拽控制视角
|
||||
- 节点选择和高亮
|
||||
- 可调节节点大小和边宽度
|
||||
- 可控制标签显示
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **imgui_bundle**: 用户界面
|
||||
- **ModernGL**: OpenGL 图形渲染
|
||||
- **NetworkX**: 图数据结构和算法
|
||||
- **NumPy**: 数值计算
|
||||
- **community**: 社区检测
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. **启动程序**:
|
||||
```bash
|
||||
python -m pip install -r requirements.txt
|
||||
python graph_visualizer.py
|
||||
```
|
||||
|
||||
2. **加载字体**:
|
||||
- 将中文字体文件 `font.ttf` 放置在 `assets` 目录下
|
||||
- 或者修改 `CUSTOM_FONT` 常量来使用其他字体文件
|
||||
|
||||
3. **加载图文件**:
|
||||
- 点击界面上的 "Load GraphML" 按钮
|
||||
- 选择 GraphML 格式的图文件
|
||||
|
||||
4. **交互控制**:
|
||||
- **相机移动**:
|
||||
- W: 前进
|
||||
- S: 后退
|
||||
- A: 左移
|
||||
- D: 右移
|
||||
- Q: 上升
|
||||
- E: 下降
|
||||
- **视角控制**:
|
||||
- 按住鼠标右键拖动来旋转视角
|
||||
- **节点交互**:
|
||||
- 鼠标悬停可高亮节点
|
||||
- 点击可选中节点
|
||||
|
||||
5. **可视化设置**:
|
||||
- 可通过 UI 控制面板调整:
|
||||
- 布局类型
|
||||
- 节点大小
|
||||
- 边的宽度
|
||||
- 标签显示
|
||||
- 标签大小
|
||||
- 背景颜色
|
||||
|
||||
## 自定义设置
|
||||
|
||||
- **节点缩放**: 通过 `node_scale` 参数调整节点大小
|
||||
- **边宽度**: 通过 `edge_width` 参数调整边的宽度
|
||||
- **标签显示**: 可通过 `show_labels` 开关标签显示
|
||||
- **标签大小**: 使用 `label_size` 调整标签大小
|
||||
- **标签颜色**: 通过 `label_color` 设置标签颜色
|
||||
- **视距控制**: 使用 `label_culling_distance` 控制标签显示的最大距离
|
||||
|
||||
## 性能优化
|
||||
|
||||
- 使用 ModernGL 进行高效的图形渲染
|
||||
- 视距裁剪优化标签显示
|
||||
- 社区检测算法优化大规模图的可视化效果
|
||||
|
||||
## 系统要求
|
||||
|
||||
- Python 3.10+
|
||||
- OpenGL 3.3+ 兼容的显卡
|
||||
- 支持的操作系统:Windows/Linux/MacOS
|
88
extra/VisualizationTool/README.md
Normal file
88
extra/VisualizationTool/README.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# 3D GraphML Viewer
|
||||
|
||||
An interactive 3D graph visualization tool based on Dear ImGui and ModernGL.
|
||||
|
||||
## Features
|
||||
|
||||
- **3D Interactive Visualization**: High-performance 3D graphics rendering using ModernGL
|
||||
- **Multiple Layout Algorithms**: Support for various graph layouts
|
||||
- Spring layout
|
||||
- Circular layout
|
||||
- Shell layout
|
||||
- Random layout
|
||||
- **Community Detection**: Automatic detection and visualization of graph community structures
|
||||
- **Interactive Controls**:
|
||||
- WASD + QE keys for camera movement
|
||||
- Right mouse drag for view angle control
|
||||
- Node selection and highlighting
|
||||
- Adjustable node size and edge width
|
||||
- Configurable label display
|
||||
- Quick navigation between node connections
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **imgui_bundle**: User interface
|
||||
- **ModernGL**: OpenGL graphics rendering
|
||||
- **NetworkX**: Graph data structures and algorithms
|
||||
- **NumPy**: Numerical computations
|
||||
- **community**: Community detection
|
||||
|
||||
## Usage
|
||||
|
||||
1. **Launch the Program**:
|
||||
```bash
|
||||
python -m pip install -r requirements.txt
|
||||
python graph_visualizer.py
|
||||
```
|
||||
|
||||
2. **Load Font**:
|
||||
- Place the font file `font.ttf` in the `assets` directory
|
||||
- Or modify the `CUSTOM_FONT` constant to use a different font file
|
||||
|
||||
3. **Load Graph File**:
|
||||
- Click the "Load GraphML" button in the interface
|
||||
- Select a graph file in GraphML format
|
||||
|
||||
4. **Interactive Controls**:
|
||||
- **Camera Movement**:
|
||||
- W: Move forward
|
||||
- S: Move backward
|
||||
- A: Move left
|
||||
- D: Move right
|
||||
- Q: Move up
|
||||
- E: Move down
|
||||
- **View Control**:
|
||||
- Hold right mouse button and drag to rotate view
|
||||
- **Node Interaction**:
|
||||
- Hover mouse to highlight nodes
|
||||
- Click to select nodes
|
||||
|
||||
5. **Visualization Settings**:
|
||||
- Adjustable via UI control panel:
|
||||
- Layout type
|
||||
- Node size
|
||||
- Edge width
|
||||
- Label visibility
|
||||
- Label size
|
||||
- Background color
|
||||
|
||||
## Customization Options
|
||||
|
||||
- **Node Scaling**: Adjust node size via `node_scale` parameter
|
||||
- **Edge Width**: Modify edge width using `edge_width` parameter
|
||||
- **Label Display**: Toggle label visibility with `show_labels`
|
||||
- **Label Size**: Adjust label size using `label_size`
|
||||
- **Label Color**: Set label color through `label_color`
|
||||
- **View Distance**: Control maximum label display distance with `label_culling_distance`
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
- Efficient graphics rendering using ModernGL
|
||||
- View distance culling for label display optimization
|
||||
- Community detection algorithms for optimized visualization of large-scale graphs
|
||||
|
||||
## System Requirements
|
||||
|
||||
- Python 3.10+
|
||||
- Graphics card with OpenGL 3.3+ support
|
||||
- Supported Operating Systems: Windows/Linux/MacOS
|
0
extra/VisualizationTool/assets/place_font_here
Normal file
0
extra/VisualizationTool/assets/place_font_here
Normal file
1149
extra/VisualizationTool/graph_visualizer.py
Normal file
1149
extra/VisualizationTool/graph_visualizer.py
Normal file
File diff suppressed because it is too large
Load Diff
8
extra/VisualizationTool/requirements.txt
Normal file
8
extra/VisualizationTool/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
imgui_bundle
|
||||
moderngl
|
||||
networkx
|
||||
numpy
|
||||
pyglm
|
||||
python-louvain
|
||||
scipy
|
||||
tk
|
@@ -82,14 +82,19 @@ We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate Light
|
||||
|
||||
A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include:
|
||||
|
||||
```
|
||||
/local
|
||||
/global
|
||||
/hybrid
|
||||
/naive
|
||||
/mix
|
||||
/bypass
|
||||
```
|
||||
|
||||
For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。
|
||||
|
||||
"/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the LightRAG query results. (If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix)
|
||||
|
||||
#### Connect Open WebUI to LightRAG
|
||||
|
||||
After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface.
|
||||
|
@@ -599,6 +599,7 @@ class SearchMode(str, Enum):
|
||||
global_ = "global"
|
||||
hybrid = "hybrid"
|
||||
mix = "mix"
|
||||
bypass = "bypass"
|
||||
|
||||
|
||||
class OllamaMessage(BaseModel):
|
||||
@@ -1476,7 +1477,7 @@ def create_app(args):
|
||||
|
||||
@app.get("/api/tags")
|
||||
async def get_tags():
|
||||
"""Get available models"""
|
||||
"""Retrun available models acting as an Ollama server"""
|
||||
return OllamaTagResponse(
|
||||
models=[
|
||||
{
|
||||
@@ -1507,6 +1508,7 @@ def create_app(args):
|
||||
"/naive ": SearchMode.naive,
|
||||
"/hybrid ": SearchMode.hybrid,
|
||||
"/mix ": SearchMode.mix,
|
||||
"/bypass ": SearchMode.bypass,
|
||||
}
|
||||
|
||||
for prefix, mode in mode_map.items():
|
||||
@@ -1519,7 +1521,7 @@ def create_app(args):
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||
"""Handle generate completion requests
|
||||
"""Handle generate completion requests acting as an Ollama model
|
||||
For compatiblity purpuse, the request is not processed by LightRAG,
|
||||
and will be handled by underlying LLM model.
|
||||
"""
|
||||
@@ -1661,7 +1663,7 @@ def create_app(args):
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||
"""Process chat completion requests.
|
||||
"""Process chat completion requests acting as an Ollama model
|
||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
||||
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
||||
"""
|
||||
@@ -1700,9 +1702,20 @@ def create_app(args):
|
||||
if request.stream:
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
response = await rag.aquery( # Need await to get async generator
|
||||
cleaned_query, param=query_param
|
||||
)
|
||||
# Determine if the request is prefix with "/bypass"
|
||||
if mode == SearchMode.bypass:
|
||||
if request.system:
|
||||
rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
response = await rag.llm_model_func(
|
||||
cleaned_query,
|
||||
stream=True,
|
||||
history_messages=conversation_history,
|
||||
**rag.llm_model_kwargs,
|
||||
)
|
||||
else:
|
||||
response = await rag.aquery( # Need await to get async generator
|
||||
cleaned_query, param=query_param
|
||||
)
|
||||
|
||||
async def stream_generator():
|
||||
try:
|
||||
@@ -1804,16 +1817,19 @@ def create_app(args):
|
||||
else:
|
||||
first_chunk_time = time.time_ns()
|
||||
|
||||
# Determine if the request is from Open WebUI's session title and session keyword generation task
|
||||
# Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
|
||||
match_result = re.search(
|
||||
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
||||
)
|
||||
if match_result:
|
||||
if match_result or mode == SearchMode.bypass:
|
||||
if request.system:
|
||||
rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
|
||||
response_text = await rag.llm_model_func(
|
||||
cleaned_query, stream=False, **rag.llm_model_kwargs
|
||||
cleaned_query,
|
||||
stream=False,
|
||||
history_messages=conversation_history,
|
||||
**rag.llm_model_kwargs,
|
||||
)
|
||||
else:
|
||||
response_text = await rag.aquery(cleaned_query, param=query_param)
|
||||
|
@@ -126,8 +126,10 @@ class LightRAG:
|
||||
vector_storage: str = field(default="NanoVectorDBStorage")
|
||||
graph_storage: str = field(default="NetworkXStorage")
|
||||
|
||||
# logging
|
||||
current_log_level = logger.level
|
||||
log_level: str = field(default=current_log_level)
|
||||
log_dir: str = field(default=os.getcwd())
|
||||
|
||||
# text chunking
|
||||
chunk_token_size: int = 1200
|
||||
@@ -182,10 +184,11 @@ class LightRAG:
|
||||
chunking_func_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
log_file = os.path.join("lightrag.log")
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
log_file = os.path.join(self.log_dir, "lightrag.log")
|
||||
set_logger(log_file)
|
||||
logger.setLevel(self.log_level)
|
||||
|
||||
logger.setLevel(self.log_level)
|
||||
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
||||
if not os.path.exists(self.working_dir):
|
||||
logger.info(f"Creating working directory {self.working_dir}")
|
||||
@@ -649,7 +652,7 @@ class LightRAG:
|
||||
}
|
||||
chunk_cnt += len(chunks)
|
||||
await self.text_chunks.upsert(chunks)
|
||||
await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
|
||||
await self.text_chunks.change_status(doc_id, DocStatus.PROCESSING)
|
||||
|
||||
try:
|
||||
# Store chunks in vector database
|
||||
|
@@ -123,18 +123,20 @@ class TestStats:
|
||||
|
||||
|
||||
def make_request(
|
||||
url: str, data: Dict[str, Any], stream: bool = False
|
||||
url: str, data: Dict[str, Any], stream: bool = False, check_status: bool = True
|
||||
) -> requests.Response:
|
||||
"""Send an HTTP request with retry mechanism
|
||||
Args:
|
||||
url: Request URL
|
||||
data: Request data
|
||||
stream: Whether to use streaming response
|
||||
check_status: Whether to check HTTP status code (default: True)
|
||||
Returns:
|
||||
requests.Response: Response object
|
||||
|
||||
Raises:
|
||||
requests.exceptions.RequestException: Request failed after all retries
|
||||
requests.exceptions.HTTPError: HTTP status code is not 200 (when check_status is True)
|
||||
"""
|
||||
server_config = CONFIG["server"]
|
||||
max_retries = server_config["max_retries"]
|
||||
@@ -144,6 +146,8 @@ def make_request(
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.post(url, json=data, stream=stream, timeout=timeout)
|
||||
if check_status and response.status_code != 200:
|
||||
response.raise_for_status()
|
||||
return response
|
||||
except requests.exceptions.RequestException as e:
|
||||
if attempt == max_retries - 1: # Last retry
|
||||
@@ -433,7 +437,7 @@ def test_stream_error_handling() -> None:
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- Testing empty message list (streaming) ---")
|
||||
data = create_error_test_data("empty_messages")
|
||||
response = make_request(url, data, stream=True)
|
||||
response = make_request(url, data, stream=True, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print_json_response(response.json(), "Error message")
|
||||
@@ -443,7 +447,7 @@ def test_stream_error_handling() -> None:
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- Testing invalid role field (streaming) ---")
|
||||
data = create_error_test_data("invalid_role")
|
||||
response = make_request(url, data, stream=True)
|
||||
response = make_request(url, data, stream=True, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print_json_response(response.json(), "Error message")
|
||||
@@ -453,7 +457,7 @@ def test_stream_error_handling() -> None:
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- Testing missing content field (streaming) ---")
|
||||
data = create_error_test_data("missing_content")
|
||||
response = make_request(url, data, stream=True)
|
||||
response = make_request(url, data, stream=True, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print_json_response(response.json(), "Error message")
|
||||
@@ -484,7 +488,7 @@ def test_error_handling() -> None:
|
||||
print("\n--- Testing empty message list ---")
|
||||
data = create_error_test_data("empty_messages")
|
||||
data["stream"] = False # Change to non-streaming mode
|
||||
response = make_request(url, data)
|
||||
response = make_request(url, data, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
@@ -493,7 +497,7 @@ def test_error_handling() -> None:
|
||||
print("\n--- Testing invalid role field ---")
|
||||
data = create_error_test_data("invalid_role")
|
||||
data["stream"] = False # Change to non-streaming mode
|
||||
response = make_request(url, data)
|
||||
response = make_request(url, data, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
@@ -502,7 +506,7 @@ def test_error_handling() -> None:
|
||||
print("\n--- Testing missing content field ---")
|
||||
data = create_error_test_data("missing_content")
|
||||
data["stream"] = False # Change to non-streaming mode
|
||||
response = make_request(url, data)
|
||||
response = make_request(url, data, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
@@ -609,7 +613,7 @@ def test_generate_error_handling() -> None:
|
||||
if OutputControl.is_verbose():
|
||||
print("\n=== Testing empty prompt ===")
|
||||
data = create_generate_request_data("", stream=False)
|
||||
response = make_request(url, data)
|
||||
response = make_request(url, data, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
@@ -621,7 +625,7 @@ def test_generate_error_handling() -> None:
|
||||
options={"invalid_option": "value"},
|
||||
stream=False,
|
||||
)
|
||||
response = make_request(url, data)
|
||||
response = make_request(url, data, check_status=False)
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
@@ -642,6 +646,8 @@ def test_generate_concurrent() -> None:
|
||||
data = create_generate_request_data(prompt, stream=False)
|
||||
try:
|
||||
async with session.post(url, json=data) as response:
|
||||
if response.status != 200:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
@@ -742,7 +748,7 @@ Configuration file (config.json):
|
||||
nargs="+",
|
||||
choices=list(get_test_cases().keys()) + ["all"],
|
||||
default=["all"],
|
||||
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
|
||||
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests (except error tests)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -766,21 +772,18 @@ if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
if "all" in args.tests:
|
||||
# Run all tests
|
||||
# Run all tests except error handling tests
|
||||
if OutputControl.is_verbose():
|
||||
print("\n【Chat API Tests】")
|
||||
run_test(test_non_stream_chat, "Non-streaming Chat Test")
|
||||
run_test(test_stream_chat, "Streaming Chat Test")
|
||||
run_test(test_query_modes, "Chat Query Mode Test")
|
||||
run_test(test_error_handling, "Chat Error Handling Test")
|
||||
run_test(test_stream_error_handling, "Chat Streaming Error Test")
|
||||
|
||||
if OutputControl.is_verbose():
|
||||
print("\n【Generate API Tests】")
|
||||
run_test(test_non_stream_generate, "Non-streaming Generate Test")
|
||||
run_test(test_stream_generate, "Streaming Generate Test")
|
||||
run_test(test_generate_with_system, "Generate with System Prompt Test")
|
||||
run_test(test_generate_error_handling, "Generate Error Handling Test")
|
||||
run_test(test_generate_concurrent, "Generate Concurrent Test")
|
||||
else:
|
||||
# Run specified tests
|
||||
|
Reference in New Issue
Block a user