LCOV - code coverage report
Current view: top level - src/pairinteraction_gui - worker.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 116 138 84.1 %
Date: 2026-04-17 09:29:39 Functions: 10 14 71.4 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : import math
       7           1 : from collections.abc import Callable
       8           1 : from typing import TYPE_CHECKING, Any, ClassVar
       9             : 
      10           1 : from PySide6.QtCore import QObject, Qt, QThread, QTimer, Signal
      11           1 : from PySide6.QtGui import QColor, QPainter
      12           1 : from PySide6.QtWidgets import QApplication, QLabel, QWidget
      13             : 
      14           1 : from pairinteraction import _backend
      15             : 
      16             : if TYPE_CHECKING:
      17             :     from collections.abc import Callable
      18             : 
      19           1 : logger = logging.getLogger(__name__)
      20             : 
      21             : 
      22           1 : class SpinnerWidget(QWidget):
      23             :     """Spinning wheel indicator similar to the OS busy cursor, but larger."""
      24             : 
      25           1 :     def __init__(
      26             :         self,
      27             :         parent: QWidget,
      28             :         circle_size: int = 80,
      29             :         n_dots: int = 12,
      30             :         dot_radius: int = 5,
      31             :         label_width: int = 220,
      32             :         label_height: int = 20,
      33             :     ) -> None:
      34           1 :         super().__init__(parent)
      35           1 :         self._step = 0
      36           1 :         self._n_dots = n_dots
      37           1 :         self._dot_radius = dot_radius
      38           1 :         self._circle_size = circle_size
      39             : 
      40           1 :         self._timer = QTimer(self)
      41           1 :         self._timer.setInterval(80)
      42           1 :         self._timer.timeout.connect(self._advance)
      43             : 
      44           1 :         self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
      45           1 :         self.setFixedSize(max(circle_size, label_width), circle_size + 6 + label_height)
      46           1 :         x, y = (parent.width() - self.width()) // 2, (parent.height() - self.height()) // 2
      47           1 :         self.setGeometry(x, y, self.width(), self.height())
      48             : 
      49           1 :         self._label = QLabel("", self)
      50           1 :         self._label.setAlignment(Qt.AlignmentFlag.AlignCenter)
      51           1 :         self._label.setGeometry((self.width() - label_width) // 2, circle_size + 6, label_width, label_height)
      52             : 
      53           1 :         self.hide()
      54             : 
      55           1 :     def _advance(self) -> None:
      56           0 :         self._step = (self._step + 1) % self._n_dots
      57           0 :         self.update()
      58             : 
      59           1 :     def start(self) -> None:
      60           1 :         self._timer.start()
      61           1 :         self.show()
      62             : 
      63           1 :     def stop(self) -> None:
      64           1 :         self._timer.stop()
      65           1 :         self.hide()
      66             : 
      67           1 :     def set_diagonalization_progress(self, done: int, total: int | None) -> None:
      68           0 :         if total is not None:
      69           0 :             self._label.setText(f"Diagonalizing systems {done}/{total}...")
      70             :         else:
      71           0 :             self._label.setText(f"Diagonalizing systems {done}...")
      72             : 
      73           1 :     def paintEvent(self, event: Any) -> None:
      74           1 :         painter = QPainter(self)
      75           1 :         painter.setRenderHint(QPainter.RenderHint.Antialiasing)
      76           1 :         circle_radius = self._circle_size / 2
      77           1 :         offset_x = (self.width() - self._circle_size) / 2
      78           1 :         orbit = circle_radius - self._dot_radius - 4
      79           1 :         for i in range(self._n_dots):
      80           1 :             opacity = ((i - self._step) % self._n_dots) / (self._n_dots - 1)
      81           1 :             painter.setBrush(QColor(128, 128, 128, int(opacity * 220)))
      82           1 :             painter.setPen(Qt.PenStyle.NoPen)
      83           1 :             angle = 2 * math.pi * i / self._n_dots
      84           1 :             x = offset_x + circle_radius + orbit * math.sin(angle)
      85           1 :             y = circle_radius - orbit * math.cos(angle)
      86           1 :             painter.drawEllipse(
      87             :                 int(x - self._dot_radius),
      88             :                 int(y - self._dot_radius),
      89             :                 self._dot_radius * 2,
      90             :                 self._dot_radius * 2,
      91             :             )
      92             : 
      93             : 
      94           1 : class WorkerSignals(QObject):
      95             :     """Signals to be used by the Worker class."""
      96             : 
      97           1 :     started = Signal()
      98           1 :     finished = Signal(str)
      99           1 :     error = Signal(Exception)
     100           1 :     progress = Signal(str)
     101           1 :     diag_progress = Signal(int)
     102           1 :     result = Signal(object)
     103             : 
     104             : 
     105           1 : class MultiThreadWorker(QThread):
     106             :     """Simple worker class to run a function in a separate thread.
     107             : 
     108             :     Example:
     109             :     worker = Worker(my_function, arg1, arg2, kwarg1=value1)
     110             :     worker.signals.result.connect(process_result)
     111             :     worker.start()
     112             : 
     113             :     """
     114             : 
     115           1 :     all_threads: ClassVar[set[MultiThreadWorker]] = set()
     116             : 
     117           1 :     def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
     118           1 :         super().__init__(QApplication.instance())
     119             : 
     120           1 :         self.all_threads.add(self)
     121             : 
     122           1 :         self.fn = fn
     123           1 :         self.args = args
     124           1 :         self.kwargs = kwargs
     125             : 
     126           1 :         self.signals = WorkerSignals()
     127             : 
     128           1 :         self._last_task_info = ""
     129           1 :         self._last_diagonalization_status: int = 0
     130           1 :         self._progress_timer = QTimer(self)
     131           1 :         self._progress_timer.setInterval(75)
     132           1 :         self._progress_timer.timeout.connect(self._poll_progress)
     133           1 :         self.started.connect(self._progress_timer.start)
     134           1 :         self.finished.connect(self._progress_timer.stop)
     135           1 :         self.finished.connect(self.finish_up)
     136             : 
     137           1 :     @classmethod
     138           1 :     def current_worker(cls) -> MultiThreadWorker | None:
     139           0 :         current_thread = QThread.currentThread()
     140           0 :         if isinstance(current_thread, cls):
     141           0 :             return current_thread
     142           0 :         return None
     143             : 
     144           1 :     def _poll_progress(self) -> None:
     145           1 :         if self.isInterruptionRequested():
     146           0 :             raise _backend.TaskAbortedError
     147             : 
     148           1 :         task_info = _backend.get_task_info()
     149           1 :         if task_info not in ("", self._last_task_info):
     150           0 :             self._last_task_info = task_info
     151           0 :             self.signals.progress.emit(task_info)
     152             : 
     153           1 :         done = _backend.get_progress_count()
     154           1 :         if done not in (0, self._last_diagonalization_status):
     155           0 :             self._last_diagonalization_status = done
     156           0 :             self.signals.diag_progress.emit(done)
     157             : 
     158           1 :     def enable_busy_indicator(
     159             :         self, widget: QWidget, *, add_progress_label: bool = False, number_of_steps: int | None = None
     160             :     ) -> None:
     161             :         """Show a spinning wheel overlay while the worker is running."""
     162           1 :         self.busy_spinner = SpinnerWidget(widget)
     163             : 
     164           1 :         self.signals.started.connect(self.busy_spinner.start)
     165           1 :         self.signals.finished.connect(self.busy_spinner.stop)
     166           1 :         if add_progress_label:
     167           1 :             self.signals.diag_progress.connect(
     168             :                 lambda done: self.busy_spinner.set_diagonalization_progress(done, number_of_steps)
     169             :             )
     170             : 
     171           1 :     def run(self) -> None:
     172             :         """Initialise the runner function with passed args, kwargs."""
     173           1 :         logger.debug("Run on thread %s", self)
     174           1 :         status = None
     175           1 :         self.signals.started.emit()
     176           1 :         try:
     177           1 :             result = self.fn(*self.args, **self.kwargs)
     178           1 :             status = "Calculation succeeded"
     179           1 :         except _backend.TaskAbortedError:
     180           0 :             logger.debug("Calculation thread %s aborted.", self)
     181           0 :             status = "Calculation aborted"
     182           1 :         except Exception as err:
     183           1 :             self.signals.error.emit(err)
     184             :         else:
     185           1 :             self.signals.result.emit(result)
     186             :         finally:
     187           1 :             if status is None:
     188           1 :                 status = "Calculation failed"
     189           1 :             _backend.reset_task_status()
     190           1 :             self.signals.finished.emit(status)
     191             : 
     192           1 :     def request_abort(self) -> None:
     193             :         """Request a cooperative abort for the running task."""
     194           0 :         logger.debug("Requesting abort for thread %s.", self)
     195           0 :         self.requestInterruption()
     196           0 :         _backend.request_task_abort()
     197             : 
     198           1 :     def finish_up(self) -> None:
     199             :         """Perform any final cleanup or actions before the thread exits."""
     200           1 :         logger.debug("Finishing up thread %s", self)
     201           1 :         self.all_threads.discard(self)
     202             : 
     203           1 :     @classmethod
     204           1 :     def terminate_all(cls) -> None:
     205             :         """Terminate all threads started by the application."""
     206             :         # Shallow copy to avoid error if the set is modified during the loop,
     207             :         # e.g. if the thread is finished and removes itself from the list
     208           1 :         all_threads = list(cls.all_threads)
     209           1 :         for thread in all_threads:
     210           1 :             if thread.isRunning():
     211           0 :                 thread.request_abort()
     212             : 
     213           1 :         for thread in all_threads:
     214           1 :             if thread.isRunning():
     215           0 :                 logger.debug("Waiting for thread %s to be aborted.", thread)
     216           0 :                 thread.wait()
     217             : 
     218           1 :         cls.all_threads.clear()
     219           1 :         logger.debug("All threads terminated.")

Generated by: LCOV version 1.16