Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : import time
7 1 : from pathlib import Path
8 1 : from typing import TYPE_CHECKING, Any
9 :
10 1 : import nbformat
11 1 : from nbconvert import PythonExporter
12 1 : from PySide6.QtCore import Qt
13 1 : from PySide6.QtGui import QIcon, QPixmap
14 1 : from PySide6.QtWidgets import (
15 : QFileDialog,
16 : QHBoxLayout,
17 : QMenu,
18 : QPushButton,
19 : QStyle,
20 : QToolBox,
21 : )
22 :
23 1 : import pairinteraction
24 1 : from pairinteraction_gui.config import BaseConfig
25 1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies
26 1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
27 1 : from pairinteraction_gui.worker import MultiThreadWorker
28 :
29 : if TYPE_CHECKING:
30 : from collections.abc import Callable
31 :
32 : from PySide6.QtGui import QHideEvent, QShowEvent
33 :
34 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
35 : from pairinteraction_gui.config.calculation_config import CalculationConfig
36 : from pairinteraction_gui.config.ket_config import KetConfig
37 : from pairinteraction_gui.plotwidget.plotwidget import PlotWidget
38 :
# Module-level logger, named after this module so its records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
40 :
41 :
class BasePage(WidgetV):
    """Common base class for every page of the application window.

    Subclasses provide a ``title`` and ``tooltip`` and may point ``icon_path``
    at an icon file.
    """

    # Layout settings (presumably consumed by WidgetV when building the layout).
    margin = (20, 20, 20, 20)
    spacing = 15

    title: str
    tooltip: str
    icon_path: Path | None = None

    def showEvent(self, event: QShowEvent) -> None:
        """Put the page title (with newlines flattened) into the window title when shown."""
        super().showEvent(event)
        window_title = f"PairInteraction v{pairinteraction.__version__} - " + self.title.replace("\n", " ")
        self.window().setWindowTitle(window_title)
58 :
59 :
class SimulationPage(BasePage):
    """Base class for simulation pages whose configuration lives in a toolbox.

    The toolbox is placed into the main window's dock widget while the page
    is visible, and the dock widget is hidden again when the page is hidden.
    """

    ket_config: KetConfig

    plotwidget: PlotWidget

    def setupWidget(self) -> None:
        """Create the (still empty) toolbox and the icon used for its tabs."""
        self.toolbox = QToolBox()

        # The tabs get a 1x1 transparent icon so that their height can be adjusted,
        # see https://stackoverflow.com/questions/48503645/customizing-qtoolbox-tab-height
        transparent_pixmap = QPixmap(1, 1)
        transparent_pixmap.fill(Qt.GlobalColor.transparent)
        self._toolbox_dummy_icon = QIcon(transparent_pixmap)

    def postSetupWidget(self) -> None:
        """Fill the toolbox and announce the initially selected species.

        Every instance attribute that is a ``BaseConfig`` becomes one toolbox
        page, titled with the config's ``title``. Afterwards
        ``signal_species_changed`` is emitted once per species combo box with
        its index and current text.
        """
        config_widgets = (value for value in self.__dict__.values() if isinstance(value, BaseConfig))
        for config_widget in config_widgets:
            self.toolbox.addItem(config_widget, self._toolbox_dummy_icon, config_widget.title)

        species_combos = self.ket_config.species_combo_list
        for combo_index, combo in enumerate(species_combos):
            self.ket_config.signal_species_changed.emit(combo_index, combo.currentText())

    def showEvent(self, event: QShowEvent) -> None:
        """Install this page's toolbox in the main window's dock widget and show it."""
        super().showEvent(event)
        dock = self.window().dockwidget
        dock.setWidget(self.toolbox)
        dock.setVisible(True)
        self.toolbox.show()

    def hideEvent(self, event: QHideEvent) -> None:
        """Hide the main window's dock widget while this page is not visible."""
        super().hideEvent(event)
        self.window().dockwidget.setVisible(False)
93 :
94 :
class CalculationPage(SimulationPage):
    """Base class for all pages with a calculation button.

    Subclasses must implement :meth:`calculate` and
    :meth:`_get_export_notebook_template_name`. They may override
    :meth:`_plot_function`, :meth:`_get_export_replacements`,
    :meth:`_create_plot_widget` and :meth:`_get_export_actions`.
    """

    plotwidget: PlotEnergies
    # Reset at the start of each calculation and set to True by the "finished"
    # signals of the calculation worker and the plotting worker, respectively.
    _calculation_finished = False
    _plot_finished = False

    def setupWidget(self) -> None:
        """Build the plot panel and the Calculate/Abort and Export buttons below it."""
        super().setupWidget()

        # Plot Panel
        self.plotwidget = self._create_plot_widget()
        self.layout().addWidget(self.plotwidget)

        # Control panel below the plot
        bottom_layout = QHBoxLayout()
        bottom_layout.setObjectName("bottomLayout")

        # Calculate/Abort stacked buttons, switched via setCurrentNamedWidget
        self.calculate_and_abort = NamedStackedWidget[QPushButton](self)

        calculate_button = QPushButton("Calculate")
        calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
        calculate_button.clicked.connect(self.calculate_clicked)
        self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")

        abort_button = QPushButton("Abort")
        abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
        abort_button.clicked.connect(self.abort_clicked)
        self.calculate_and_abort.addNamedWidget(abort_button, "Abort")

        self.calculate_and_abort.setFixedHeight(50)
        bottom_layout.addWidget(self.calculate_and_abort, stretch=2)

        # Create export button with menu
        export_button = QPushButton("Export")
        export_button.setObjectName("Export")
        export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
        export_menu = QMenu(self)
        for label, handler in self._get_export_actions():
            export_menu.addAction(label, handler)
        export_button.setMenu(export_menu)
        export_button.setFixedHeight(50)
        bottom_layout.addWidget(export_button, stretch=1)

        self.layout().addLayout(bottom_layout)

    def calculate_clicked(self) -> None:
        """Run :meth:`calculate` in a ``MultiThreadWorker`` and plot its result.

        The ``(parameters, results)`` tuple returned by :meth:`calculate` is
        handed to :meth:`update_plot`, which runs in a second ``MultiThreadWorker``.
        """
        self._calculation_finished = False
        self._plot_finished = False
        self.before_calculate()

        # Named differently from the update_plot method to avoid confusing the two.
        def start_plot_worker(
            parameters_and_results: tuple[Parameters[Any], Results],
        ) -> None:
            worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
            worker_plot.signals.progress.connect(lambda message: show_status_tip(self, message))
            worker_plot.signals.finished.connect(lambda _: setattr(self, "_plot_finished", True))
            worker_plot.start()

        worker = MultiThreadWorker(self.calculate)
        if hasattr(self, "calculation_config"):
            # Pages with a calculation config get a progress label sized by the number of steps.
            calculation_config: CalculationConfig = self.calculation_config
            number_of_steps = calculation_config.steps.value()
            worker.enable_busy_indicator(self.plotwidget, add_progress_label=True, number_of_steps=number_of_steps)
        else:
            worker.enable_busy_indicator(self.plotwidget)
        worker.signals.progress.connect(lambda message: show_status_tip(self, message))
        worker.signals.result.connect(start_plot_worker)
        worker.signals.finished.connect(self.after_calculate)
        worker.signals.finished.connect(lambda _: setattr(self, "_calculation_finished", True))
        worker.start()

    def before_calculate(self) -> None:
        """Show the Abort button, clear the plot and start the timer."""
        show_status_tip(self, "Calculating... Please wait.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Abort")
        self.plotwidget.clear()

        self._start_time = time.perf_counter()

    def after_calculate(self, status: str) -> None:
        """Report ``status`` with the elapsed time and show the Calculate button again.

        Args:
            status: Message from the calculation worker's ``finished`` signal,
                or "Calculation aborted." when called from :meth:`abort_clicked`.
        """
        time_needed = time.perf_counter() - self._start_time
        show_status_tip(self, f"{status} after {time_needed:.2f} seconds.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Calculate")

    def calculate(self) -> tuple[Parameters[Any], Results]:
        """Perform the calculation; must be implemented by subclasses.

        Returns:
            The ``(parameters, results)`` pair that is passed on to :meth:`update_plot`.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Redraw the plot widget from the calculation ``parameters`` and ``results``."""
        self.plotwidget.canvas.draw()  # draw once before, to avoid displaying artifacts during plotting
        self.plotwidget.plot(parameters, results)
        self.plotwidget.setup_annotations(parameters, results)
        self._plot_function(parameters, results)
        self.plotwidget.canvas.draw()
        self.plotwidget.navigation_toolbar.reset_home_view()
        show_status_tip(self, "Finished updating plot. Tip: Click on the plot to see state information.", logger=logger)

    def _plot_function(self, parameters: Parameters[Any], results: Results) -> None:
        """Hook for subclasses to add custom plotting.

        Called after the default plotting and before the final canvas draw.
        The default implementation does nothing.
        """

    def export_png(self) -> None:
        """Export the current plot as a PNG file."""
        logger.debug("Exporting results as PNG...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")

        if filename:
            # Ensure exactly one ".png" suffix.
            filename = filename.removesuffix(".png") + ".png"
            self.plotwidget.canvas.fig.savefig(
                filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
            )
            logger.info("Plot saved as %s", filename)

    def _load_export_notebook_template(self) -> nbformat.NotebookNode:
        """Read the page's export template notebook (converted to nbformat version 4)."""
        template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
        # Notebook files are UTF-8 JSON; do not depend on the platform's locale encoding.
        with template_path.open(encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    def _create_python_code(self) -> str:
        """Convert the export template notebook to a Python script with replacements applied."""
        notebook = self._load_export_notebook_template()

        exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
        content, _ = exporter.from_notebook_node(notebook)

        for key, value in self._get_export_replacements().items():
            content = content.replace(key, str(value))

        return content

    def export_python(self) -> None:
        """Export the current calculation as a Python script."""
        logger.debug("Exporting results as Python script...")
        filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
        if filename:
            filename = filename.removesuffix(".py") + ".py"
            content = self._create_python_code()
            # Python source files are expected to be UTF-8 (PEP 3120), independent of the locale.
            with Path(filename).open("w", encoding="utf-8") as f:
                f.write(content)
            logger.info("Python script saved as %s", filename)

    def export_notebook(self) -> None:
        """Export the current calculation as a Jupyter notebook.

        The replacements from :meth:`_get_export_replacements` are applied to
        code cells only.
        """
        logger.debug("Exporting results as Jupyter notebook...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")

        if filename:
            filename = filename.removesuffix(".ipynb") + ".ipynb"

            notebook = self._load_export_notebook_template()

            replacements = self._get_export_replacements()
            for cell in notebook.cells:
                if cell.cell_type == "code":
                    source = cell.source
                    for key, value in replacements.items():
                        source = source.replace(key, str(value))
                    cell.source = source

            nbformat.write(notebook, filename)

            logger.info("Jupyter notebook saved as %s", filename)

    def _get_export_notebook_template_name(self) -> str:
        """Return the file name of the template in ``export_templates``; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def _get_export_replacements(self) -> dict[str, str]:
        """Return a mapping of template placeholder text to replacement text.

        Override this method in subclasses to provide specific replacements
        for the export. The default returns no replacements.
        """
        return {}

    def abort_clicked(self) -> None:
        """Handle abort button click."""
        logger.debug("Aborting calculation.")
        # NOTE(review): presumably stops every running MultiThreadWorker, including plot workers.
        MultiThreadWorker.terminate_all()
        self.after_calculate("Calculation aborted.")

    def _create_plot_widget(self) -> PlotEnergies:
        """Create the plot widget shown on this page."""
        return PlotEnergies(self)

    def _get_export_actions(self) -> list[tuple[str, Callable[[], None]]]:
        """Return ``(label, handler)`` pairs for the entries of the Export menu."""
        return [
            ("Export as PNG", self.export_png),
            ("Export as Python script", self.export_python),
            ("Export as Jupyter notebook", self.export_notebook),
        ]
|