Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : import math
7 1 : from collections.abc import Callable
8 1 : from typing import TYPE_CHECKING, Any, ClassVar
9 :
10 1 : from PySide6.QtCore import QObject, Qt, QThread, QTimer, Signal
11 1 : from PySide6.QtGui import QColor, QPainter
12 1 : from PySide6.QtWidgets import QApplication, QLabel, QWidget
13 :
14 1 : from pairinteraction import _backend
15 :
16 : if TYPE_CHECKING:
17 : from collections.abc import Callable
18 :
# Module-level logger named after this module; used below for debug tracing of worker threads.
logger = logging.getLogger(__name__)
20 :
21 :
class SpinnerWidget(QWidget):
    """Spinning wheel indicator similar to the OS busy cursor, but larger.

    The widget consists of a ring of dots with graded opacity that rotates while
    the internal timer is running, plus a text label underneath the ring. It is
    created hidden; call :meth:`start` to show and animate it and :meth:`stop`
    to hide it again.
    """

    def __init__(
        self,
        parent: QWidget,
        circle_size: int = 80,
        n_dots: int = 12,
        dot_radius: int = 5,
        label_width: int = 220,
        label_height: int = 20,
    ) -> None:
        """Create the (initially hidden) spinner centred inside ``parent``.

        Args:
            parent: Widget the spinner is placed on. Its size at construction
                time is used to centre the spinner.
            circle_size: Side length of the square area holding the dot ring.
            n_dots: Number of dots drawn on the ring.
            dot_radius: Radius of each dot.
            label_width: Width of the text label below the ring.
            label_height: Height of the text label below the ring.

        """
        super().__init__(parent)
        self._step = 0
        self._n_dots = n_dots
        self._dot_radius = dot_radius
        self._circle_size = circle_size

        # Each timeout advances the animation by one dot position.
        self._timer = QTimer(self)
        self._timer.setInterval(80)
        self._timer.timeout.connect(self._advance)

        # Vertical gap between the dot ring and the label.
        label_gap = 6

        self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
        self.setFixedSize(max(circle_size, label_width), circle_size + label_gap + label_height)
        # Centre the spinner inside the parent, based on the parent's current size.
        x, y = (parent.width() - self.width()) // 2, (parent.height() - self.height()) // 2
        self.setGeometry(x, y, self.width(), self.height())

        self._label = QLabel("", self)
        self._label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self._label.setGeometry(
            (self.width() - label_width) // 2, circle_size + label_gap, label_width, label_height
        )

        self.hide()

    def _advance(self) -> None:
        """Move the animation on by one dot and schedule a repaint."""
        # Guard against a spinner configured without dots: the modulo below
        # would otherwise raise ZeroDivisionError on every timer tick.
        if self._n_dots > 0:
            self._step = (self._step + 1) % self._n_dots
        self.update()

    def start(self) -> None:
        """Start the animation timer and show the widget."""
        self._timer.start()
        self.show()

    def stop(self) -> None:
        """Stop the animation timer and hide the widget."""
        self._timer.stop()
        self.hide()

    def set_diagonalization_progress(self, done: int, total: int | None) -> None:
        """Update the label with the diagonalization progress.

        Args:
            done: Number of systems processed so far.
            total: Total number of systems, or ``None`` if unknown, in which
                case only ``done`` is shown.

        """
        if total is not None:
            self._label.setText(f"Diagonalizing systems {done}/{total}...")
        else:
            self._label.setText(f"Diagonalizing systems {done}...")

    def paintEvent(self, event: Any) -> None:
        """Draw the ring of dots; opacity depends on each dot's offset from the current step."""
        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.RenderHint.Antialiasing)
            # The pen never changes between dots, so set it once.
            painter.setPen(Qt.PenStyle.NoPen)

            n_dots = self._n_dots
            diameter = self._dot_radius * 2
            circle_radius = self._circle_size / 2
            # The ring is horizontally centred in the widget, which may be wider than the ring.
            offset_x = (self.width() - self._circle_size) / 2
            orbit = circle_radius - self._dot_radius - 4

            for i in range(n_dots):
                rank = (i - self._step) % n_dots
                # With a single dot there is no gradient to compute (and the
                # original divisor n_dots - 1 would be zero): draw it fully opaque.
                opacity = rank / (n_dots - 1) if n_dots > 1 else 1.0
                painter.setBrush(QColor(128, 128, 128, int(opacity * 220)))
                angle = 2 * math.pi * i / n_dots
                x = offset_x + circle_radius + orbit * math.sin(angle)
                y = circle_radius - orbit * math.cos(angle)
                painter.drawEllipse(
                    int(x - self._dot_radius),
                    int(y - self._dot_radius),
                    diameter,
                    diameter,
                )
        finally:
            # Release the paint device explicitly instead of relying on garbage
            # collection of the Python wrapper.
            painter.end()
92 :
93 :
class WorkerSignals(QObject):
    """Collection of Qt signals through which a MultiThreadWorker reports back.

    The payload types are fixed by the ``Signal(...)`` declarations below; the
    comments describe how ``MultiThreadWorker`` uses each signal.
    """

    # Lifecycle: emitted at the beginning of MultiThreadWorker.run() and, with a
    # human-readable status string, at its very end.
    started = Signal()
    finished = Signal(str)

    # Outcome: the return value of the wrapped function, or the exception it raised.
    result = Signal(object)
    error = Signal(Exception)

    # Progress: textual task information and the count of finished diagonalizations.
    progress = Signal(str)
    diag_progress = Signal(int)
103 :
104 :
class MultiThreadWorker(QThread):
    """Simple worker class to run a function in a separate thread.

    The outcome of the function is reported through the signals on
    ``self.signals`` (see ``WorkerSignals``). While the thread is running, a
    timer periodically polls the backend for task information and progress
    and forwards changes through ``signals.progress`` / ``signals.diag_progress``.

    Example:
        worker = MultiThreadWorker(my_function, arg1, arg2, kwarg1=value1)
        worker.signals.result.connect(process_result)
        worker.start()

    """

    # Registry of live workers: filled in __init__, emptied in finish_up() and
    # terminate_all().
    all_threads: ClassVar[set[MultiThreadWorker]] = set()

    def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Store the callable and its arguments and wire up progress polling.

        Args:
            fn: Function executed in :meth:`run`.
            *args: Positional arguments passed to ``fn``.
            **kwargs: Keyword arguments passed to ``fn``.

        """
        super().__init__(QApplication.instance())

        self.all_threads.add(self)

        self.fn = fn
        self.args = args
        self.kwargs = kwargs

        self.signals = WorkerSignals()

        # Last values forwarded to listeners, used to suppress duplicate emissions.
        self._last_task_info = ""
        self._last_diagonalization_status: int = 0

        # Poll the backend only while the thread is running.
        self._progress_timer = QTimer(self)
        self._progress_timer.setInterval(75)
        self._progress_timer.timeout.connect(self._poll_progress)
        self.started.connect(self._progress_timer.start)
        self.finished.connect(self._progress_timer.stop)
        self.finished.connect(self.finish_up)

    @classmethod
    def current_worker(cls) -> MultiThreadWorker | None:
        """Return the worker owning the calling thread, or ``None`` if there is none."""
        current_thread = QThread.currentThread()
        if isinstance(current_thread, cls):
            return current_thread
        return None

    def _poll_progress(self) -> None:
        """Forward changed task information and progress counts from the backend.

        This is a slot of ``self._progress_timer``. The timer is parented to this
        ``QThread`` object, so by Qt's thread-affinity rules the slot is expected
        to run in the thread owning the object rather than inside :meth:`run`.
        An exception raised here therefore cannot reach the ``try`` block in
        :meth:`run`; once an interruption has been requested, polling is simply
        skipped (the abort itself is signalled by :meth:`request_abort`).
        """
        if self.isInterruptionRequested():
            return

        task_info = _backend.get_task_info()
        if task_info not in ("", self._last_task_info):
            self._last_task_info = task_info
            self.signals.progress.emit(task_info)

        done = _backend.get_progress_count()
        if done not in (0, self._last_diagonalization_status):
            self._last_diagonalization_status = done
            self.signals.diag_progress.emit(done)

    def enable_busy_indicator(
        self, widget: QWidget, *, add_progress_label: bool = False, number_of_steps: int | None = None
    ) -> None:
        """Show a spinning wheel overlay while the worker is running.

        Args:
            widget: Widget on which the spinner overlay is created.
            add_progress_label: If true, the spinner label is updated whenever
                ``signals.diag_progress`` is emitted.
            number_of_steps: Total number of steps shown in the label, or
                ``None`` to show only the number of finished steps.

        """
        spinner = SpinnerWidget(widget)
        self.busy_spinner = spinner

        self.signals.started.connect(spinner.start)
        self.signals.finished.connect(spinner.stop)
        if add_progress_label:
            # Bind the spinner created by this call, so that a later call to
            # enable_busy_indicator() cannot redirect this connection.
            self.signals.diag_progress.connect(
                lambda done: spinner.set_diagonalization_progress(done, number_of_steps)
            )

    def run(self) -> None:
        """Initialise the runner function with passed args, kwargs.

        Emits ``signals.started`` first, then either ``signals.result`` (on
        success) or ``signals.error`` (on any ``Exception`` other than
        ``TaskAbortedError``), and finally always resets the backend task status
        and emits ``signals.finished`` with a status message.
        """
        logger.debug("Run on thread %s", self)
        # Pessimistic default; overwritten on success or cooperative abort.
        status = "Calculation failed"
        self.signals.started.emit()
        try:
            result = self.fn(*self.args, **self.kwargs)
            status = "Calculation succeeded"
        except _backend.TaskAbortedError:
            logger.debug("Calculation thread %s aborted.", self)
            status = "Calculation aborted"
        except Exception as err:
            self.signals.error.emit(err)
        else:
            self.signals.result.emit(result)
        finally:
            _backend.reset_task_status()
            self.signals.finished.emit(status)

    def request_abort(self) -> None:
        """Request a cooperative abort for the running task."""
        logger.debug("Requesting abort for thread %s.", self)
        self.requestInterruption()
        _backend.request_task_abort()

    def finish_up(self) -> None:
        """Perform any final cleanup or actions before the thread exits."""
        logger.debug("Finishing up thread %s", self)
        self.all_threads.discard(self)

    @classmethod
    def terminate_all(cls) -> None:
        """Terminate all threads started by the application."""
        # Shallow copy to avoid error if the set is modified during the loop,
        # e.g. if the thread is finished and removes itself from the list
        all_threads = list(cls.all_threads)

        # First ask every running thread to abort, then wait for each of them,
        # so that the threads can wind down concurrently.
        for thread in all_threads:
            if thread.isRunning():
                thread.request_abort()

        for thread in all_threads:
            if thread.isRunning():
                logger.debug("Waiting for thread %s to be aborted.", thread)
                thread.wait()

        cls.all_threads.clear()
        logger.debug("All threads terminated.")
|