Line data Source code
# SPDX-FileCopyrightText: 2025 Pairinteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
from PySide6.QtCore import Qt, QTimer
from PySide6.QtGui import QWheelEvent
from PySide6.QtWidgets import QWidget


class MatplotlibCanvas(FigureCanvasQTAgg):
    """Canvas for matplotlib figures with mouse-wheel zoom and a view reset."""

    # Zoom factor per wheel notch: one notch (angleDelta of 120) shrinks the visible
    # range to 90 %. It is applied as ``_ZOOM_BASE ** notches`` so the factor stays
    # positive for any number of accumulated notches, and zooming in and back out by
    # the same number of notches returns exactly to the previous view.
    _ZOOM_BASE = 0.9

    def __init__(self, parent: Optional[QWidget] = None) -> None:
        """Initialize the canvas with a figure containing a single axes.

        Args:
            parent: Optional Qt parent widget. When given, the canvas is reparented to it.
        """
        # NOTE(review): plt.subplots() registers the figure with pyplot's global figure
        # manager, which keeps it alive until plt.close(fig) is called. Consider
        # matplotlib.figure.Figure() instead if nothing relies on pyplot state.
        self.fig, self.ax = plt.subplots()
        super().__init__(self.fig)
        if parent is not None:
            self.setParent(parent)

        # Record the initial limits so reset_view() works even before the first draw().
        self._xlim_orig = self.ax.get_xlim()
        self._ylim_orig = self.ax.get_ylim()

        self.setup_zoom()

    def reset_view(self) -> None:
        """Restore the axis limits recorded at the last ``draw(new_data=True)``."""
        # ax.relim() + ax.autoscale() does not work for scatter-only plots,
        # so the limits captured at draw time are restored explicitly instead.
        self.ax.set_xlim(self._xlim_orig)
        self.ax.set_ylim(self._ylim_orig)
        self.draw(False)

    def setup_zoom(self) -> None:
        """Set up mouse wheel zoom functionality.

        Wheel events are accumulated, and a single-shot timer applies them together,
        so a burst of wheel notches results in a single zoom step and redraw.
        """
        self.wheel_accumulation: float = 0
        # Wheel positions in matplotlib display coordinates
        # (physical pixels, origin at the bottom-left of the canvas).
        self.last_wheel_pos: list[list[float]] = []
        self.wheel_timer = QTimer(self)
        self.wheel_timer.setSingleShot(True)
        self.wheel_timer.timeout.connect(self.apply_accumulated_zoom)
        self.setFocusPolicy(Qt.FocusPolicy.StrongFocus)
        self.setMouseTracking(True)

    def wheelEvent(self, event: QWheelEvent) -> None:
        """Accumulate a wheel event. The zoom is applied after 100 ms without wheel input."""
        delta_y = event.angleDelta().y()
        if not delta_y:
            # Purely horizontal scroll: nothing to zoom, so don't record the position.
            return
        self.wheel_accumulation += delta_y / 120  # 120 == one standard wheel notch
        # Convert the Qt position (logical pixels, origin top-left) into matplotlib
        # display coordinates (physical pixels, origin bottom-left). This accounts for
        # the device pixel ratio on HiDPI screens, so no manual y-mirroring is needed.
        x_disp, y_disp = self.mouseEventCoords(event.position())
        self.last_wheel_pos.append([x_disp, y_disp])
        self.wheel_timer.start(100)  # Apply zoom after 100ms of inactivity

    def apply_accumulated_zoom(self) -> None:
        """Apply the accumulated zoom from wheel events, centred on the mean wheel position.

        The axes zoomed depend on where the wheel was used:
        - over the axes: both axes are zoomed;
        - below the axes: only x is zoomed;
        - left of the axes: only y is zoomed.
        Positions to the right of or above the axes, or in the bottom-left corner,
        are ignored. The checks assume ascending (non-inverted) axis limits.
        """
        notches = self.wheel_accumulation
        positions = self.last_wheel_pos
        # Always clear the accumulators, even when no zoom is applied,
        # so stale positions never skew the next zoom centre.
        self.wheel_accumulation = 0
        self.last_wheel_pos = []

        if not notches or not positions:
            return

        x_min, x_max = self.ax.get_xlim()
        y_min, y_max = self.ax.get_ylim()

        # Always > 0: < 1 zooms in (wheel forward), > 1 zooms out (wheel backward).
        scale_factor = self._ZOOM_BASE**notches

        # Mean wheel position, converted from display to data coordinates.
        wheel_pos_mean = np.mean(positions, axis=0)
        x_data, y_data = self.ax.transData.inverted().transform(wheel_pos_mean)

        if x_data > x_max or y_data > y_max or (x_data < x_min and y_data < y_min):
            return

        if x_min <= x_data <= x_max:
            x_min_new = x_data - (x_data - x_min) * scale_factor
            x_max_new = x_data + (x_max - x_data) * scale_factor
            self.ax.set_xlim(x_min_new, x_max_new)

        if y_min <= y_data <= y_max:
            y_min_new = y_data - (y_data - y_min) * scale_factor
            y_max_new = y_data + (y_max - y_data) * scale_factor
            self.ax.set_ylim(y_min_new, y_max_new)

        # Redraw without overwriting the recorded "original" view.
        self.draw(False)

    def draw(self, new_data: bool = True) -> None:
        """Draw the canvas.

        Args:
            new_data: If True, record the current axis limits as the view that
                reset_view() restores. Pass False for redraws that only change the view.
        """
        # NOTE(review): matplotlib's Qt backend appears to call draw() with no arguments
        # from draw_idle() (e.g. on resize). That re-records the limits even after a
        # zoom, so confirm this is intended.
        super().draw()
        if new_data:
            self._xlim_orig = self.ax.get_xlim()
            self._ylim_orig = self.ax.get_ylim()