Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : import os
7 1 : from functools import wraps
8 1 : from multiprocessing.pool import Pool
9 1 : from pathlib import Path
10 1 : from threading import Thread
11 1 : from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypeVar
12 :
13 1 : from PySide6.QtCore import QObject, QSize, Qt, QThread, Signal
14 1 : from PySide6.QtGui import QMovie
15 1 : from PySide6.QtWidgets import QApplication, QLabel
16 :
if TYPE_CHECKING:
    from PySide6.QtWidgets import QWidget
    from typing_extensions import ParamSpec

    # Typing-only helpers describing the wrapped signature in ``run_in_other_process``.
    # They are never created at runtime; that is fine because the module uses
    # ``from __future__ import annotations``, so annotations are not evaluated.
    P = ParamSpec("P")
    R = TypeVar("R")

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
25 :
26 :
class WorkerSignals(QObject):
    """Container for the Qt signals that ``MultiThreadWorker`` emits."""

    # Emitted by the worker right before the wrapped function is called.
    started = Signal()
    # Emitted when the work is over: True if the function returned normally,
    # False if it raised or the worker was terminated.
    finished = Signal(bool)
    # Carries the exception instance raised by the wrapped function.
    error = Signal(Exception)
    # Carries the wrapped function's return value on success.
    result = Signal(object)
34 :
35 :
class MultiThreadWorker(QThread):
    """Simple worker class to run a function in a separate thread.

    Every instance is added to ``all_threads`` when it is constructed and is
    removed again by ``finish_up`` once the thread has finished, so that
    ``terminate_all`` can reach every worker that may still be running.

    Example:
        worker = MultiThreadWorker(my_function, arg1, arg2, kwarg1=value1)
        worker.signals.result.connect(process_result)
        worker.start()

    """

    # Workers that were created and have not finished yet (see finish_up / terminate_all).
    all_threads: ClassVar[set[MultiThreadWorker]] = set()

    def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Prepare a worker that calls ``fn(*args, **kwargs)`` once started.

        Args:
            fn: Callable executed on the worker thread by ``run``.
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        """
        # The QApplication instance becomes the Qt parent of this thread object.
        super().__init__(QApplication.instance())

        self.all_threads.add(self)

        self.fn = fn
        self.args = args
        self.kwargs = kwargs

        self.signals = WorkerSignals()
        # Note: this is QThread's own (argument-less) ``finished`` signal, emitted when
        # the thread stops running -- not ``self.signals.finished``.
        self.finished.connect(self.finish_up)

    def enable_busy_indicator(self, widget: QWidget) -> None:
        """Run a loading gif while the worker is running.

        A 100x100 px ``QLabel`` playing ``images/loading.gif`` is created as a child
        of ``widget`` and centred on the widget's current size. It is shown on
        ``signals.started`` and hidden on ``signals.finished``, so call this before
        ``start()``; otherwise the ``started`` emission may already have happened.

        Args:
            widget: The widget on which the loading animation is overlaid.

        """
        self.busy_label = QLabel(widget)
        gif_path = Path(__file__).parent / "images" / "loading.gif"
        self.busy_movie = QMovie(str(gif_path))
        self.busy_movie.setScaledSize(QSize(100, 100))  # Make the gif larger
        self.busy_label.setMovie(self.busy_movie)
        self.busy_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.busy_label.setGeometry((widget.width() - 100) // 2, (widget.height() - 100) // 2, 100, 100)

        self.signals.started.connect(self._start_gif)
        self.signals.finished.connect(self._stop_gif)

    def _start_gif(self) -> None:
        """Show the busy label and start the animation (slot for ``signals.started``)."""
        self.busy_label.show()
        self.busy_movie.start()

    def _stop_gif(self, _success: bool = True) -> None:
        """Stop the animation and hide the busy label (slot for ``signals.finished``).

        The ``hasattr`` guards make this a no-op if ``enable_busy_indicator`` was
        never called.
        """
        if hasattr(self, "busy_movie"):
            self.busy_movie.stop()
        if hasattr(self, "busy_label"):
            self.busy_label.hide()

    def run(self) -> None:
        """Call ``fn(*args, **kwargs)`` on the worker thread and report the outcome.

        Emits ``signals.started`` first, then either ``signals.result`` with the return
        value or ``signals.error`` with the raised exception, and always finally
        ``signals.finished`` with ``True`` on success and ``False`` otherwise.
        Exceptions are not re-raised; they are only delivered via ``signals.error``.
        """
        logger.debug("Run on thread %s", self)
        success = False
        self.signals.started.emit()
        try:
            result = self.fn(*self.args, **self.kwargs)
            success = True
        except Exception as err:
            self.signals.error.emit(err)
        else:
            self.signals.result.emit(result)
        finally:
            self.signals.finished.emit(success)

    def finish_up(self) -> None:
        """Perform any final cleanup or actions before the thread exits.

        Connected to ``QThread.finished``; drops this worker from ``all_threads``.
        """
        logger.debug("Finishing up thread %s", self)
        self.all_threads.discard(self)

    @classmethod
    def terminate_all(cls) -> None:
        """Terminate all threads started by the application.

        Each running worker is forcibly terminated, ``signals.finished(False)`` is
        emitted for it, and the call then blocks until that thread has stopped.
        Afterwards ``all_threads`` is emptied.
        """
        # Shallow copy to avoid error if the set is modified during the loop,
        # e.g. if the thread is finished and removes itself from the list
        all_threads = list(cls.all_threads)
        for thread in all_threads:
            if thread.isRunning():
                logger.debug("Terminating thread %s.", thread)
                thread.terminate()
                # Notify listeners (e.g. the busy indicator) explicitly: a terminated
                # run() may never reach its own ``finally`` emission.
                thread.signals.finished.emit(False)
                # terminate() does not block; wait() returns once the thread has stopped.
                thread.wait()

        cls.all_threads.clear()
        logger.debug("All threads terminated.")
118 :
119 :
class MultiProcessWorker:
    """Run registered functions in a shared ``multiprocessing`` pool.

    Functions are looked up by name in ``_mp_functions_dict``; only the worker
    instance (function name plus call arguments) is pickled and sent to a pool
    process, where ``run`` looks the function up again. The function must
    therefore also be registered inside the worker processes (presumably at
    import time, e.g. via ``run_in_other_process`` -- verify for custom uses).

    Example:
        MultiProcessWorker.register(my_function)
        MultiProcessWorker.create_pool(n_processes=2)
        result = MultiProcessWorker("my_function", arg1, kwarg1=value1).start()

    """

    # Registry mapping a name to the function executed for it in the pool.
    _mp_functions_dict: ClassVar[dict[str, Callable[..., Any]]] = {}
    # The shared pool, or None if it was never created or has been terminated.
    _pool: ClassVar[Pool | None] = None
    # Background thread that is currently creating the pool, or None when idle.
    _async_worker: ClassVar[Thread | None] = None
    # Process count requested by the last create_pool() call; reused when
    # terminate_all(create_new_pool=True) restarts the pool.
    _n_processes: ClassVar[int] = 1

    def __init__(self, fn_name: str, *args: Any, **kwargs: Any) -> None:
        """Prepare a call of the function registered as ``fn_name``.

        Args:
            fn_name: Name under which the function was registered.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Raises:
            ValueError: If no function is registered under ``fn_name``.

        """
        if fn_name not in self._mp_functions_dict:
            raise ValueError(f"Function {fn_name} is not registered.")

        self.fn_name = fn_name
        self.args = args
        self.kwargs = kwargs

    @classmethod
    def create_pool(cls, n_processes: int = 1) -> None:
        """Start creating a pool of ``n_processes`` processes in a background thread.

        Returns immediately; ``start`` and ``terminate_all`` wait for the creation
        to finish.

        Raises:
            RuntimeError: If a pool already exists or is still being created.

        """
        if cls._pool is not None or cls._async_worker is not None:
            raise RuntimeError(
                "create_pool already called. Use terminate_all(create_new_pool=True) to restart the pool."
            )

        cls._n_processes = n_processes
        cls._async_worker = Thread(target=cls._create_pool, args=(n_processes,))
        cls._async_worker.start()

    @classmethod
    def _create_pool(cls, n_processes: int) -> None:
        """Create the pool; runs on the thread started by ``create_pool``."""
        try:
            cls._pool = Pool(n_processes)
            cls._pool.apply(cls._dummy_function)  # Call the pool once, to make the next call faster
            logger.debug("Pool created successfully.")
        finally:
            # Clear the marker even if creation failed; otherwise a stale dead thread
            # would make every later create_pool() call raise "already called".
            cls._async_worker = None

    @staticmethod
    def _dummy_function() -> None:
        """Do nothing.

        Dummy function to run after creating the pool asynchronously.
        """
        return

    @classmethod
    def register(cls, func: Callable[..., Any], name: str | None = None) -> None:
        """Register ``func`` under ``name`` (defaults to ``func.__name__``).

        Raises:
            ValueError: If a function is already registered under that name.

        """
        name = name if name is not None else func.__name__
        if name in cls._mp_functions_dict:
            raise ValueError(f"Function {name} is already registered.")
        cls._mp_functions_dict[name] = func

    def start(self) -> Any:
        """Run the registered function in the pool and return its result.

        Blocks until a pending pool creation has finished and then until the
        function has returned in a worker process. Exceptions raised by the
        function are re-raised here by ``Pool.apply``.

        Raises:
            RuntimeError: If no pool exists.

        """
        # Read the class attribute once: the creation thread resets it to None when
        # it is done, so checking it and then re-reading it could race.
        async_worker = self._async_worker
        if async_worker is not None:
            logger.debug("Waiting for creating_pool to finish.")
            async_worker.join()
            logger.debug("creating_pool finished.")

        pool = self._pool
        if pool is None:
            raise RuntimeError("Pool is not created. Call create_pool() first.")

        logger.debug("Starting pool.apply")
        result = pool.apply(self.run)
        logger.debug("Finished pool.apply")
        return result

    def run(self) -> Any:
        """Look up the registered function and call it; executed in a pool process."""
        logger.debug("Run on process %s", os.getpid())
        func = self._mp_functions_dict[self.fn_name]
        return func(*self.args, **self.kwargs)

    @classmethod
    def terminate_all(cls, create_new_pool: bool) -> None:
        """Terminate all processes.

        Waits for a pending pool creation, then terminates the pool (if any).

        Args:
            create_new_pool: If True and a pool was terminated, start creating a new
                pool with the process count of the last ``create_pool`` call.

        """
        # Copy to a local first: the creation thread sets ``_async_worker`` to None
        # when it finishes, so re-reading it after the None check could call
        # ``None.join()``.
        async_worker = cls._async_worker
        if async_worker is not None:
            async_worker.join()
        # The creation thread is done at this point; drop any stale reference.
        cls._async_worker = None

        if cls._pool is None:
            return

        cls._pool.terminate()
        cls._pool = None
        logger.debug("Process pool terminated.")

        if create_new_pool:
            cls.create_pool(cls._n_processes)
202 :
203 :
def run_in_other_process(func: Callable[P, R]) -> Callable[P, R]:
    """Decorate ``func`` so that every call is executed in the shared process pool.

    At decoration time ``func`` is registered with ``MultiProcessWorker`` under its
    ``__name__``. Calling the returned wrapper builds a ``MultiProcessWorker`` for
    that name and blocks in ``start()`` until the result is available.

    Raises:
        ValueError: At decoration time, if a function with the same name is
            already registered.

    """
    MultiProcessWorker.register(func)

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        pending_call = MultiProcessWorker(func.__name__, *args, **kwargs)
        outcome = pending_call.start()
        return outcome  # type: ignore [no-any-return]

    return wrapper
|