LCOV - code coverage report
Current view: top level - src/pairinteraction_gui - worker.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 108 133 81.2 %
Date: 2025-06-06 09:09:03 Functions: 14 34 41.2 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : import os
       6           1 : from functools import wraps
       7           1 : from multiprocessing.pool import Pool
       8           1 : from pathlib import Path
       9           1 : from threading import Thread
      10           1 : from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar
      11             : 
      12           1 : from PySide6.QtCore import QObject, QSize, Qt, QThread, Signal
      13           1 : from PySide6.QtGui import QMovie
      14           1 : from PySide6.QtWidgets import QApplication, QLabel, QWidget
      15             : 
      16             : if TYPE_CHECKING:
      17             :     from typing_extensions import ParamSpec
      18             : 
      19             :     P = ParamSpec("P")
      20             :     R = TypeVar("R")
      21             : 
      22           1 : logger = logging.getLogger(__name__)
      23             : 
      24             : 
      25           1 : class WorkerSignals(QObject):
      26             :     """Signals to be used by the Worker class."""
      27             : 
      28           1 :     started = Signal()
      29           1 :     finished = Signal(bool)
      30           1 :     error = Signal(Exception)
      31           1 :     result = Signal(object)
      32             : 
      33             : 
      34           1 : class MultiThreadWorker(QThread):
      35             :     """Simple worker class to run a function in a separate thread.
      36             : 
      37             :     Example:
      38             :     worker = Worker(my_function, arg1, arg2, kwarg1=value1)
      39             :     worker.signals.result.connect(process_result)
      40             :     worker.start()
      41             : 
      42             :     """
      43             : 
      44           1 :     all_threads: ClassVar[set["MultiThreadWorker"]] = set()
      45             : 
      46           1 :     def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
      47           1 :         super().__init__(QApplication.instance())
      48             : 
      49           1 :         self.all_threads.add(self)
      50             : 
      51           1 :         self.fn = fn
      52           1 :         self.args = args
      53           1 :         self.kwargs = kwargs
      54             : 
      55           1 :         self.signals = WorkerSignals()
      56           1 :         self.finished.connect(self.finish_up)
      57             : 
      58           1 :     def enable_busy_indicator(self, widget: "QWidget") -> None:
      59             :         """Run a loading gif while the worker is running."""
      60           1 :         self.busy_label = QLabel(widget)
      61           1 :         gif_path = Path(__file__).parent / "images" / "loading.gif"
      62           1 :         self.busy_movie = QMovie(str(gif_path))
      63           1 :         self.busy_movie.setScaledSize(QSize(100, 100))  # Make the gif larger
      64           1 :         self.busy_label.setMovie(self.busy_movie)
      65           1 :         self.busy_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
      66           1 :         self.busy_label.setGeometry((widget.width() - 100) // 2, (widget.height() - 100) // 2, 100, 100)
      67             : 
      68           1 :         self.signals.started.connect(self._start_gif)
      69           1 :         self.signals.finished.connect(self._stop_gif)
      70             : 
      71           1 :     def _start_gif(self) -> None:
      72           1 :         self.busy_label.show()
      73           1 :         self.busy_movie.start()
      74             : 
      75           1 :     def _stop_gif(self, _success: bool = True) -> None:
      76           1 :         if hasattr(self, "busy_movie"):
      77           1 :             self.busy_movie.stop()
      78           1 :         if hasattr(self, "busy_label"):
      79           1 :             self.busy_label.hide()
      80             : 
      81           1 :     def run(self) -> None:
      82             :         """Initialise the runner function with passed args, kwargs."""
      83           0 :         logger.debug("Run on thread %s", self)
      84           0 :         success = False
      85           0 :         self.signals.started.emit()
      86           0 :         try:
      87           0 :             result = self.fn(*self.args, **self.kwargs)
      88           0 :             success = True
      89           0 :         except Exception as err:
      90           0 :             self.signals.error.emit(err)
      91             :         else:
      92           0 :             self.signals.result.emit(result)
      93             :         finally:
      94           0 :             self.signals.finished.emit(success)
      95             : 
      96           1 :     def finish_up(self) -> None:
      97             :         """Perform any final cleanup or actions before the thread exits."""
      98           1 :         logger.debug("Finishing up thread %s", self)
      99           1 :         self.all_threads.discard(self)
     100             : 
     101           1 :     @classmethod
     102           1 :     def terminate_all(cls) -> None:
     103             :         """Terminate all threads started by the application."""
     104             :         # Shallow copy to avoid error if the set is modified during the loop,
     105             :         # e.g. if the thread is finished and removes itself from the list
     106           1 :         all_threads = list(cls.all_threads)
     107           1 :         for thread in all_threads:
     108           0 :             if thread.isRunning():
     109           0 :                 logger.debug("Terminating thread %s.", thread)
     110           0 :                 thread.terminate()
     111           0 :                 thread.signals.finished.emit(False)
     112           0 :                 thread.wait()
     113             : 
     114           1 :         cls.all_threads.clear()
     115           1 :         logger.debug("All threads terminated.")
     116             : 
     117             : 
     118           1 : class MultiProcessWorker:
     119           1 :     _mp_functions_dict: ClassVar[dict[str, Callable[..., Any]]] = {}
     120           1 :     _pool: ClassVar[Optional[Pool]] = None
     121           1 :     _async_worker: ClassVar[Optional[Thread]] = None
     122             : 
     123           1 :     def __init__(self, fn_name: str, *args: Any, **kwargs: Any) -> None:
     124           1 :         if fn_name not in self._mp_functions_dict:
     125           0 :             raise ValueError(f"Function {fn_name} is not registered.")
     126             : 
     127           1 :         self.fn_name = fn_name
     128           1 :         self.args = args
     129           1 :         self.kwargs = kwargs
     130             : 
     131           1 :     @classmethod
     132           1 :     def create_pool(cls, n_processes: int = 1) -> None:
     133             :         """Create a pool of processes."""
     134           1 :         if cls._pool is not None or cls._async_worker is not None:
     135           0 :             raise RuntimeError(
     136             :                 "create_pool already called. Use terminate_all(create_new_pool=True) to restart the pool."
     137             :             )
     138             : 
     139           1 :         cls._async_worker = Thread(target=cls._create_pool, args=(n_processes,))
     140           1 :         cls._async_worker.start()
     141             : 
     142           1 :     @classmethod
     143           1 :     def _create_pool(cls, n_processes: int) -> None:
     144             :         """Create a pool of processes."""
     145           1 :         cls._pool = Pool(n_processes)
     146           1 :         cls._pool.apply(cls._dummy_function)  # Call the pool once, to make the next call faster
     147           1 :         cls._async_worker = None
     148           1 :         logger.debug("Pool created successfully.")
     149             : 
     150           1 :     @staticmethod
     151           1 :     def _dummy_function() -> None:
     152             :         """Do nothing.
     153             : 
     154             :         Dummy function to run after creating the pool asynchronously.
     155             :         """
     156           0 :         return
     157             : 
     158           1 :     @classmethod
     159           1 :     def register(cls, func: Callable[..., Any], name: Optional[str] = None) -> None:
     160           1 :         name = name if name is not None else func.__name__
     161           1 :         if name in cls._mp_functions_dict:
     162           0 :             raise ValueError(f"Function {name} is already registered.")
     163           1 :         cls._mp_functions_dict[name] = func
     164             : 
     165           1 :     def start(self) -> Any:
     166           1 :         async_worker = self._async_worker
     167           1 :         if async_worker is not None:
     168           1 :             logger.debug("Waiting for creating_pool to finish.")
     169           1 :             async_worker.join()
     170           1 :             logger.debug("creating_pool finished.")
     171             : 
     172           1 :         if self._pool is None:
     173           0 :             raise RuntimeError("Pool is not created. Call create_pool() first.")
     174             : 
     175           1 :         logger.debug("Starting pool.apply")
     176           1 :         result = self._pool.apply(self.run)
     177           1 :         logger.debug("Finished pool.apply")
     178           1 :         return result
     179             : 
     180           1 :     def run(self) -> Any:
     181           0 :         logger.debug("Run on process %s", os.getpid())
     182           0 :         func = self._mp_functions_dict[self.fn_name]
     183           0 :         return func(*self.args, **self.kwargs)
     184             : 
     185           1 :     @classmethod
     186           1 :     def terminate_all(cls, create_new_pool: bool) -> None:
     187             :         """Terminate all processes."""
     188           1 :         if cls._async_worker is not None:
     189           1 :             cls._async_worker.join()
     190             : 
     191           1 :         if cls._pool is None:
     192           0 :             return
     193             : 
     194           1 :         cls._pool.terminate()
     195           1 :         cls._pool = None
     196           1 :         logger.debug("Process pool terminated.")
     197             : 
     198           1 :         if create_new_pool:
     199           0 :             cls.create_pool()
     200             : 
     201             : 
     202           1 : def run_in_other_process(func: Callable["P", "R"]) -> Callable["P", "R"]:
     203           1 :     MultiProcessWorker.register(func)
     204             : 
     205           1 :     @wraps(func)
     206           1 :     def wrapper_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
     207           1 :         return MultiProcessWorker(func.__name__, *args, **kwargs).start()  # type: ignore [no-any-return]
     208             : 
     209           1 :     return wrapper_func

Generated by: LCOV version 1.16