LCOV - code coverage report
Current view: top level - src/pairinteraction_gui - worker.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 108 134 80.6 %
Date: 2025-09-29 10:28:29 Functions: 14 34 41.2 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : import os
       7           1 : from functools import wraps
       8           1 : from multiprocessing.pool import Pool
       9           1 : from pathlib import Path
      10           1 : from threading import Thread
      11           1 : from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypeVar
      12             : 
      13           1 : from PySide6.QtCore import QObject, QSize, Qt, QThread, Signal
      14           1 : from PySide6.QtGui import QMovie
      15           1 : from PySide6.QtWidgets import QApplication, QLabel
      16             : 
      17             : if TYPE_CHECKING:
      18             :     from PySide6.QtWidgets import QWidget
      19             :     from typing_extensions import ParamSpec
      20             : 
      21             :     P = ParamSpec("P")
      22             :     R = TypeVar("R")
      23             : 
      24           1 : logger = logging.getLogger(__name__)
      25             : 
      26             : 
      27           1 : class WorkerSignals(QObject):
      28             :     """Signals to be used by the Worker class."""
      29             : 
      30           1 :     started = Signal()
      31           1 :     finished = Signal(bool)
      32           1 :     error = Signal(Exception)
      33           1 :     result = Signal(object)
      34             : 
      35             : 
      36           1 : class MultiThreadWorker(QThread):
      37             :     """Simple worker class to run a function in a separate thread.
      38             : 
      39             :     Example:
      40             :     worker = Worker(my_function, arg1, arg2, kwarg1=value1)
      41             :     worker.signals.result.connect(process_result)
      42             :     worker.start()
      43             : 
      44             :     """
      45             : 
      46           1 :     all_threads: ClassVar[set[MultiThreadWorker]] = set()
      47             : 
      48           1 :     def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
      49           1 :         super().__init__(QApplication.instance())
      50             : 
      51           1 :         self.all_threads.add(self)
      52             : 
      53           1 :         self.fn = fn
      54           1 :         self.args = args
      55           1 :         self.kwargs = kwargs
      56             : 
      57           1 :         self.signals = WorkerSignals()
      58           1 :         self.finished.connect(self.finish_up)
      59             : 
      60           1 :     def enable_busy_indicator(self, widget: QWidget) -> None:
      61             :         """Run a loading gif while the worker is running."""
      62           1 :         self.busy_label = QLabel(widget)
      63           1 :         gif_path = Path(__file__).parent / "images" / "loading.gif"
      64           1 :         self.busy_movie = QMovie(str(gif_path))
      65           1 :         self.busy_movie.setScaledSize(QSize(100, 100))  # Make the gif larger
      66           1 :         self.busy_label.setMovie(self.busy_movie)
      67           1 :         self.busy_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
      68           1 :         self.busy_label.setGeometry((widget.width() - 100) // 2, (widget.height() - 100) // 2, 100, 100)
      69             : 
      70           1 :         self.signals.started.connect(self._start_gif)
      71           1 :         self.signals.finished.connect(self._stop_gif)
      72             : 
      73           1 :     def _start_gif(self) -> None:
      74           1 :         self.busy_label.show()
      75           1 :         self.busy_movie.start()
      76             : 
      77           1 :     def _stop_gif(self, _success: bool = True) -> None:
      78           1 :         if hasattr(self, "busy_movie"):
      79           1 :             self.busy_movie.stop()
      80           1 :         if hasattr(self, "busy_label"):
      81           1 :             self.busy_label.hide()
      82             : 
      83           1 :     def run(self) -> None:
      84             :         """Initialise the runner function with passed args, kwargs."""
      85           0 :         logger.debug("Run on thread %s", self)
      86           0 :         success = False
      87           0 :         self.signals.started.emit()
      88           0 :         try:
      89           0 :             result = self.fn(*self.args, **self.kwargs)
      90           0 :             success = True
      91           0 :         except Exception as err:
      92           0 :             self.signals.error.emit(err)
      93             :         else:
      94           0 :             self.signals.result.emit(result)
      95             :         finally:
      96           0 :             self.signals.finished.emit(success)
      97             : 
      98           1 :     def finish_up(self) -> None:
      99             :         """Perform any final cleanup or actions before the thread exits."""
     100           1 :         logger.debug("Finishing up thread %s", self)
     101           1 :         self.all_threads.discard(self)
     102             : 
     103           1 :     @classmethod
     104           1 :     def terminate_all(cls) -> None:
     105             :         """Terminate all threads started by the application."""
     106             :         # Shallow copy to avoid error if the set is modified during the loop,
     107             :         # e.g. if the thread is finished and removes itself from the list
     108           1 :         all_threads = list(cls.all_threads)
     109           1 :         for thread in all_threads:
     110           0 :             if thread.isRunning():
     111           0 :                 logger.debug("Terminating thread %s.", thread)
     112           0 :                 thread.terminate()
     113           0 :                 thread.signals.finished.emit(False)
     114           0 :                 thread.wait()
     115             : 
     116           1 :         cls.all_threads.clear()
     117           1 :         logger.debug("All threads terminated.")
     118             : 
     119             : 
     120           1 : class MultiProcessWorker:
     121           1 :     _mp_functions_dict: ClassVar[dict[str, Callable[..., Any]]] = {}
     122           1 :     _pool: ClassVar[Pool | None] = None
     123           1 :     _async_worker: ClassVar[Thread | None] = None
     124             : 
     125           1 :     def __init__(self, fn_name: str, *args: Any, **kwargs: Any) -> None:
     126           1 :         if fn_name not in self._mp_functions_dict:
     127           0 :             raise ValueError(f"Function {fn_name} is not registered.")
     128             : 
     129           1 :         self.fn_name = fn_name
     130           1 :         self.args = args
     131           1 :         self.kwargs = kwargs
     132             : 
     133           1 :     @classmethod
     134           1 :     def create_pool(cls, n_processes: int = 1) -> None:
     135             :         """Create a pool of processes."""
     136           1 :         if cls._pool is not None or cls._async_worker is not None:
     137           0 :             raise RuntimeError(
     138             :                 "create_pool already called. Use terminate_all(create_new_pool=True) to restart the pool."
     139             :             )
     140             : 
     141           1 :         cls._async_worker = Thread(target=cls._create_pool, args=(n_processes,))
     142           1 :         cls._async_worker.start()
     143             : 
     144           1 :     @classmethod
     145           1 :     def _create_pool(cls, n_processes: int) -> None:
     146             :         """Create a pool of processes."""
     147           1 :         cls._pool = Pool(n_processes)
     148           1 :         cls._pool.apply(cls._dummy_function)  # Call the pool once, to make the next call faster
     149           1 :         cls._async_worker = None
     150           1 :         logger.debug("Pool created successfully.")
     151             : 
     152           1 :     @staticmethod
     153           1 :     def _dummy_function() -> None:
     154             :         """Do nothing.
     155             : 
     156             :         Dummy function to run after creating the pool asynchronously.
     157             :         """
     158           0 :         return
     159             : 
     160           1 :     @classmethod
     161           1 :     def register(cls, func: Callable[..., Any], name: str | None = None) -> None:
     162           1 :         name = name if name is not None else func.__name__
     163           1 :         if name in cls._mp_functions_dict:
     164           0 :             raise ValueError(f"Function {name} is already registered.")
     165           1 :         cls._mp_functions_dict[name] = func
     166             : 
     167           1 :     def start(self) -> Any:
     168           1 :         async_worker = self._async_worker
     169           1 :         if async_worker is not None:
     170           1 :             logger.debug("Waiting for creating_pool to finish.")
     171           1 :             async_worker.join()
     172           1 :             logger.debug("creating_pool finished.")
     173             : 
     174           1 :         if self._pool is None:
     175           0 :             raise RuntimeError("Pool is not created. Call create_pool() first.")
     176             : 
     177           1 :         logger.debug("Starting pool.apply")
     178           1 :         result = self._pool.apply(self.run)
     179           1 :         logger.debug("Finished pool.apply")
     180           1 :         return result
     181             : 
     182           1 :     def run(self) -> Any:
     183           0 :         logger.debug("Run on process %s", os.getpid())
     184           0 :         func = self._mp_functions_dict[self.fn_name]
     185           0 :         return func(*self.args, **self.kwargs)
     186             : 
     187           1 :     @classmethod
     188           1 :     def terminate_all(cls, create_new_pool: bool) -> None:
     189             :         """Terminate all processes."""
     190           1 :         if cls._async_worker is not None:
     191           0 :             cls._async_worker.join()
     192             : 
     193           1 :         if cls._pool is None:
     194           0 :             return
     195             : 
     196           1 :         cls._pool.terminate()
     197           1 :         cls._pool = None
     198           1 :         logger.debug("Process pool terminated.")
     199             : 
     200           1 :         if create_new_pool:
     201           0 :             cls.create_pool()
     202             : 
     203             : 
     204           1 : def run_in_other_process(func: Callable[P, R]) -> Callable[P, R]:
     205           1 :     MultiProcessWorker.register(func)
     206             : 
     207           1 :     @wraps(func)
     208           1 :     def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
     209           1 :         return MultiProcessWorker(func.__name__, *args, **kwargs).start()  # type: ignore [no-any-return]
     210             : 
     211           1 :     return wrapper

Generated by: LCOV version 1.16