Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : import os
6 1 : from functools import wraps
7 1 : from multiprocessing.pool import Pool
8 1 : from pathlib import Path
9 1 : from threading import Thread
10 1 : from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar
11 :
12 1 : from PySide6.QtCore import QObject, QSize, Qt, QThread, Signal
13 1 : from PySide6.QtGui import QMovie
14 1 : from PySide6.QtWidgets import QApplication, QLabel, QWidget
15 :
16 : if TYPE_CHECKING:
17 : from typing_extensions import ParamSpec
18 :
19 : P = ParamSpec("P")
20 : R = TypeVar("R")
21 :
22 1 : logger = logging.getLogger(__name__)
23 :
24 :
class WorkerSignals(QObject):
    """Bundle of Qt signals through which a worker reports its life cycle.

    Signals:
        started: emitted without a payload.
        finished: carries a ``bool`` payload.
        error: carries the ``Exception`` instance as payload.
        result: carries an arbitrary Python ``object`` as payload.

    """

    # No payload.
    started = Signal()
    # Payload: a single bool.
    finished = Signal(bool)
    # Payload: the exception object itself.
    error = Signal(Exception)
    # Payload: any Python object.
    result = Signal(object)
32 :
33 :
class MultiThreadWorker(QThread):
    """Simple worker class to run a function in a separate thread.

    Each instance adds itself to the class-level ``all_threads`` set when it is
    constructed and removes itself again in ``finish_up``, which is connected to
    the thread's ``finished`` signal. ``terminate_all`` uses that set to stop
    every worker that is still running.

    Example:
        worker = MultiThreadWorker(my_function, arg1, arg2, kwarg1=value1)
        worker.signals.result.connect(process_result)
        worker.start()

    """

    # Every worker that has been constructed and has not yet finished or been terminated.
    all_threads: ClassVar[set["MultiThreadWorker"]] = set()

    def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Store the callable and its arguments and register the worker.

        Args:
            fn: The callable executed by ``run``.
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        """
        super().__init__(QApplication.instance())

        self.all_threads.add(self)

        self.fn = fn
        self.args = args
        self.kwargs = kwargs

        self.signals = WorkerSignals()
        # QThread.finished (not signals.finished) triggers the removal from all_threads.
        self.finished.connect(self.finish_up)

    def enable_busy_indicator(self, widget: "QWidget", size: int = 100) -> None:
        """Run a loading gif while the worker is running.

        Args:
            widget: Parent widget of the indicator; the gif is centered on it using
                the widget's width and height at the time of this call.
            size: Edge length in pixels of the square gif and of its label.

        """
        self.busy_label = QLabel(widget)
        gif_path = Path(__file__).parent / "images" / "loading.gif"
        self.busy_movie = QMovie(str(gif_path))
        self.busy_movie.setScaledSize(QSize(size, size))
        self.busy_label.setMovie(self.busy_movie)
        self.busy_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.busy_label.setGeometry((widget.width() - size) // 2, (widget.height() - size) // 2, size, size)

        self.signals.started.connect(self._start_gif)
        self.signals.finished.connect(self._stop_gif)

    def _start_gif(self) -> None:
        """Show the busy label and start the animation."""
        self.busy_label.show()
        self.busy_movie.start()

    def _stop_gif(self, _success: bool = True) -> None:
        """Stop the animation and hide the busy label, if they exist.

        The ``_success`` argument only exists so that the method can be used as a
        slot for ``signals.finished``; it is ignored.
        """
        if hasattr(self, "busy_movie"):
            self.busy_movie.stop()
        if hasattr(self, "busy_label"):
            self.busy_label.hide()

    def run(self) -> None:
        """Call the stored function and report the outcome through ``self.signals``.

        Emits ``started`` first. Afterwards either ``result`` (with the return value)
        or ``error`` (with the caught exception) is emitted, and in both cases
        ``finished`` is emitted last with a flag telling whether the call succeeded.
        """
        logger.debug("Run on thread %s", self)
        success = False
        self.signals.started.emit()
        try:
            result = self.fn(*self.args, **self.kwargs)
            success = True
        except Exception as err:
            # The exception is not re-raised; it is handed to the listeners of the error signal.
            self.signals.error.emit(err)
        else:
            self.signals.result.emit(result)
        finally:
            self.signals.finished.emit(success)

    def finish_up(self) -> None:
        """Perform any final cleanup or actions before the thread exits."""
        logger.debug("Finishing up thread %s", self)
        self.all_threads.discard(self)

    @classmethod
    def terminate_all(cls) -> None:
        """Terminate all threads started by the application."""
        # Shallow copy to avoid error if the set is modified during the loop,
        # e.g. if the thread is finished and removes itself from the list
        all_threads = list(cls.all_threads)
        for thread in all_threads:
            if thread.isRunning():
                logger.debug("Terminating thread %s.", thread)
                thread.terminate()
                # NOTE(review): presumably emitted by hand because a terminated thread
                # does not reach the ``finally`` clause of ``run`` - confirm.
                thread.signals.finished.emit(False)
                thread.wait()

        cls.all_threads.clear()
        logger.debug("All threads terminated.")
116 :
117 :
class MultiProcessWorker:
    """Run a registered function in a shared pool of worker processes.

    Functions are registered by name in a class-level dictionary. An instance stores
    such a name together with the call arguments; ``start`` submits the instance's
    ``run`` method to the pool and blocks until the result is available.

    The pool is created in a background thread, so ``create_pool`` returns
    immediately; ``start`` and ``terminate_all`` wait for that thread if needed.
    """

    # Registry of callables that may be executed in the pool, keyed by name.
    _mp_functions_dict: ClassVar[dict[str, Callable[..., Any]]] = {}
    # The process pool, or None while no pool exists (yet).
    _pool: ClassVar[Optional[Pool]] = None
    # Background thread that is currently creating the pool, or None.
    _async_worker: ClassVar[Optional[Thread]] = None
    # Size requested in the most recent create_pool call; reused when the pool is restarted.
    _n_processes: ClassVar[int] = 1

    def __init__(self, fn_name: str, *args: Any, **kwargs: Any) -> None:
        """Store the name of a registered function and the arguments for its call.

        Args:
            fn_name: Name under which the function was registered.
            *args: Positional arguments forwarded to the function.
            **kwargs: Keyword arguments forwarded to the function.

        Raises:
            ValueError: If no function is registered under ``fn_name``.

        """
        if fn_name not in self._mp_functions_dict:
            raise ValueError(f"Function {fn_name} is not registered.")

        self.fn_name = fn_name
        self.args = args
        self.kwargs = kwargs

    @classmethod
    def create_pool(cls, n_processes: int = 1) -> None:
        """Create a pool of processes.

        The pool is built in a background thread, so this method does not block.

        Args:
            n_processes: Number of worker processes of the pool.

        Raises:
            RuntimeError: If a pool already exists or is currently being created.

        """
        if cls._pool is not None or cls._async_worker is not None:
            raise RuntimeError(
                "create_pool already called. Use terminate_all(create_new_pool=True) to restart the pool."
            )

        # Remembered so that terminate_all(create_new_pool=True) restores the same size.
        cls._n_processes = n_processes
        worker = Thread(target=cls._create_pool, args=(n_processes,))
        cls._async_worker = worker
        worker.start()

    @classmethod
    def _create_pool(cls, n_processes: int) -> None:
        """Create a pool of processes.

        Target of the background thread started by ``create_pool``. ``_async_worker``
        is reset even if the creation fails, so that a failed attempt does not block
        every later ``create_pool`` call.
        """
        try:
            pool = Pool(n_processes)
            try:
                pool.apply(cls._dummy_function)  # Call the pool once, to make the next call faster
            except BaseException:
                # Do not leave the processes of a pool behind that is never published.
                pool.terminate()
                raise
            # Publish the pool before clearing _async_worker, so nobody observes "no worker, no pool".
            cls._pool = pool
        finally:
            cls._async_worker = None
        logger.debug("Pool created successfully.")

    @staticmethod
    def _dummy_function() -> None:
        """Do nothing.

        Dummy function to run after creating the pool asynchronously.
        """
        return

    @classmethod
    def register(cls, func: Callable[..., Any], name: Optional[str] = None) -> None:
        """Register a function so that it can be executed in the pool.

        Args:
            func: The callable to register.
            name: Name to register it under; defaults to ``func.__name__``.

        Raises:
            ValueError: If the name is already taken.

        """
        name = name if name is not None else func.__name__
        if name in cls._mp_functions_dict:
            raise ValueError(f"Function {name} is already registered.")
        cls._mp_functions_dict[name] = func

    def start(self) -> Any:
        """Execute the function in the pool and return its result.

        Waits for a pending pool creation first and then blocks until the call in
        the worker process has finished.

        Raises:
            RuntimeError: If no pool exists.

        """
        # Local copy: the background thread sets the class attribute to None when it is done.
        async_worker = self._async_worker
        if async_worker is not None:
            logger.debug("Waiting for creating_pool to finish.")
            async_worker.join()
            logger.debug("creating_pool finished.")

        pool = self._pool
        if pool is None:
            raise RuntimeError("Pool is not created. Call create_pool() first.")

        logger.debug("Starting pool.apply")
        result = pool.apply(self.run)
        logger.debug("Finished pool.apply")
        return result

    def run(self) -> Any:
        """Look up the registered function by name and call it with the stored arguments."""
        logger.debug("Run on process %s", os.getpid())
        func = self._mp_functions_dict[self.fn_name]
        return func(*self.args, **self.kwargs)

    @classmethod
    def terminate_all(cls, create_new_pool: bool) -> None:
        """Terminate all processes.

        Waits for a pending pool creation first. If no pool exists afterwards, the
        method returns without doing anything, even if ``create_new_pool`` is true.

        Args:
            create_new_pool: Whether to create a new pool, with the most recently
                requested number of processes, after terminating the current one.

        """
        # Local copy: the background thread may set the class attribute to None
        # between the check and the join.
        async_worker = cls._async_worker
        if async_worker is not None:
            async_worker.join()

        pool = cls._pool
        if pool is None:
            return

        pool.terminate()
        cls._pool = None
        logger.debug("Process pool terminated.")

        if create_new_pool:
            cls.create_pool(cls._n_processes)
200 :
201 :
def run_in_other_process(func: Callable["P", "R"]) -> Callable["P", "R"]:
    """Decorate a function so that calling it executes it via ``MultiProcessWorker``.

    The function is registered with ``MultiProcessWorker`` (under its ``__name__``)
    when the decorator is applied. The returned wrapper builds a worker for that name
    with the call arguments and returns whatever the worker's ``start`` returns.
    """
    MultiProcessWorker.register(func)

    @wraps(func)
    def dispatch(*args: "P.args", **kwargs: "P.kwargs") -> "R":
        worker = MultiProcessWorker(func.__name__, *args, **kwargs)
        return worker.start()  # type: ignore [no-any-return]

    return dispatch
|