Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : import time
7 1 : from pathlib import Path
8 1 : from typing import TYPE_CHECKING, Any
9 :
10 1 : import nbformat
11 1 : from nbconvert import PythonExporter
12 1 : from PySide6.QtWidgets import (
13 : QFileDialog,
14 : QHBoxLayout,
15 : QMenu,
16 : QPushButton,
17 : QStyle,
18 : QToolBox,
19 : )
20 :
21 1 : import pairinteraction
22 1 : from pairinteraction_gui.config import BaseConfig
23 1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies
24 1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
25 1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
26 :
27 : if TYPE_CHECKING:
28 : from PySide6.QtGui import QHideEvent, QShowEvent
29 :
30 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
31 : from pairinteraction_gui.config.ket_config import KetConfig
32 : from pairinteraction_gui.plotwidget.plotwidget import PlotWidget
33 :
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
35 :
36 :
class BasePage(WidgetV):
    """Common base class of every page shown in the application window.

    Subclasses set ``title`` (put into the main window title while the page is shown),
    ``tooltip`` and, optionally, ``icon_path``.
    """

    # Layout margins (left, top, right, bottom) and spacing; presumably applied by the
    # WidgetV base layout — confirm in pairinteraction_gui.qobjects.
    margin = (20, 20, 20, 20)
    spacing = 15

    title: str
    tooltip: str
    icon_path: Path | None = None

    def showEvent(self, event: QShowEvent) -> None:
        """Retitle the main window after this page whenever the page becomes visible."""
        super().showEvent(event)
        version = pairinteraction.__version__
        window_title = f"PairInteraction v{version} - {self.title}"
        self.window().setWindowTitle(window_title)
51 :
52 :
class SimulationPage(BasePage):
    """Base class for pages whose configuration widgets live in the main window's dock."""

    ket_config: KetConfig

    plotwidget: PlotWidget

    def setupWidget(self) -> None:
        # Container for the configuration panels; it is filled in postSetupWidget.
        self.toolbox = QToolBox()

    def postSetupWidget(self) -> None:
        """Put every BaseConfig attribute into the toolbox and emit the initial species choices."""
        config_widgets = [value for value in self.__dict__.values() if isinstance(value, BaseConfig)]
        for config in config_widgets:
            self.toolbox.addItem(config, config.title)

        # Emit the currently selected species of each combo box so listeners start in sync.
        species_combos = self.ket_config.species_combo_list
        for index, combo in enumerate(species_combos):
            self.ket_config.signal_species_changed.emit(index, combo.currentText())

    def showEvent(self, event: QShowEvent) -> None:
        """Install this page's toolbox into the main window's dock widget and show it."""
        super().showEvent(event)
        dock = self.window().dockwidget
        dock.setWidget(self.toolbox)
        dock.setVisible(True)
        self.toolbox.show()

    def hideEvent(self, event: QHideEvent) -> None:
        """Hide the main window's dock widget while this page is hidden."""
        super().hideEvent(event)
        self.window().dockwidget.setVisible(False)
80 :
81 :
class CalculationPage(SimulationPage):
    """Base class for all pages with a calculation button.

    The page shows an energy plot above a control row with a Calculate/Abort button and an
    Export menu (PNG, Python script, Jupyter notebook). Subclasses must implement
    :meth:`calculate` and :meth:`_get_export_notebook_template_name`, and may override
    :meth:`update_plot` and :meth:`_get_export_replacements`.
    """

    plotwidget: PlotEnergies
    # Completion flags of the calculation worker and of the follow-up plotting worker;
    # both are reset at the start of every calculation in ``calculate_clicked``.
    _calculation_finished = False
    _plot_finished = False

    def setupWidget(self) -> None:
        """Add the plot and, below it, the Calculate/Abort and Export buttons."""
        super().setupWidget()

        # Plot Panel
        self.plotwidget = PlotEnergies(self)
        self.layout().addWidget(self.plotwidget)

        # Control panel below the plot
        bottom_layout = QHBoxLayout()
        bottom_layout.setObjectName("bottomLayout")

        # Calculate/Abort stacked buttons
        self.calculate_and_abort = NamedStackedWidget[QPushButton](self)

        calculate_button = QPushButton("Calculate")
        calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
        calculate_button.clicked.connect(self.calculate_clicked)
        self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")

        abort_button = QPushButton("Abort")
        abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
        abort_button.clicked.connect(self.abort_clicked)
        self.calculate_and_abort.addNamedWidget(abort_button, "Abort")

        self.calculate_and_abort.setFixedHeight(50)
        bottom_layout.addWidget(self.calculate_and_abort, stretch=2)

        # Create export button with menu
        export_button = QPushButton("Export")
        export_button.setObjectName("Export")
        export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
        export_menu = QMenu(self)
        export_menu.addAction("Export as PNG", self.export_png)
        export_menu.addAction("Export as Python script", self.export_python)
        export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
        export_button.setMenu(export_menu)
        export_button.setFixedHeight(50)
        bottom_layout.addWidget(export_button, stretch=1)

        self.layout().addLayout(bottom_layout)

    def calculate_clicked(self) -> None:
        """Run ``calculate`` in a worker thread, then plot its result in a second worker."""
        self._calculation_finished = False
        self._plot_finished = False
        self.before_calculate()

        def on_calculation_result(
            parameters_and_results: tuple[Parameters[Any], Results],
        ) -> None:
            worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
            # Connect before starting: a signal emitted before the connection exists is not
            # delivered, which would leave ``_plot_finished`` False forever.
            worker_plot.signals.finished.connect(lambda _success: setattr(self, "_plot_finished", True))
            worker_plot.start()

        worker = MultiThreadWorker(self.calculate)
        worker.enable_busy_indicator(self)
        worker.signals.result.connect(on_calculation_result)
        worker.signals.finished.connect(self.after_calculate)
        worker.signals.finished.connect(lambda _success: setattr(self, "_calculation_finished", True))
        worker.start()

    def before_calculate(self) -> None:
        """Show the Abort button, clear the plot and start the timer read by ``after_calculate``."""
        show_status_tip(self, "Calculating... Please wait.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Abort")
        self.plotwidget.clear()

        self._start_time = time.perf_counter()

    def after_calculate(self, success: bool) -> None:
        """Report the outcome and elapsed time, then show the Calculate button again.

        Connected to the calculation worker's ``finished`` signal, and also called with
        ``False`` by ``abort_clicked``.
        """
        time_needed = time.perf_counter() - self._start_time

        if success:
            show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
        else:
            show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)

        self.calculate_and_abort.setCurrentNamedWidget("Calculate")

    def calculate(self) -> tuple[Parameters[Any], Results]:
        """Run the calculation and return ``(parameters, results)``; executed in a worker thread.

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("Subclasses must implement this method")

    def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Plot *results*, attach the cursor and redraw the canvas."""
        self.plotwidget.plot(parameters, results)
        self.plotwidget.add_cursor(parameters, results)
        self.plotwidget.canvas.draw()

    def export_png(self) -> None:
        """Export the current plot as a PNG file."""
        logger.debug("Exporting results as PNG...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")

        if filename:
            filename = filename.removesuffix(".png") + ".png"
            self.plotwidget.canvas.fig.savefig(
                filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
            )
            logger.info("Plot saved as %s", filename)

    def _read_export_template(self) -> nbformat.NotebookNode:
        """Read this page's export template notebook from the ``export_templates`` directory."""
        template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
        # Notebook files are UTF-8 JSON; do not rely on the platform's default text encoding.
        with template_path.open(encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    @staticmethod
    def _apply_export_replacements(text: str, replacements: dict[str, str]) -> str:
        """Return *text* with each key of *replacements* replaced by ``str(value)``, in dict order."""
        for key, value in replacements.items():
            text = text.replace(key, str(value))
        return text

    def _create_python_code(self) -> str:
        """Convert the export template to a Python script and apply the export replacements."""
        notebook = self._read_export_template()

        exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
        content, _ = exporter.from_notebook_node(notebook)

        return self._apply_export_replacements(content, self._get_export_replacements())

    def export_python(self) -> None:
        """Export the current calculation as a Python script."""
        logger.debug("Exporting results as Python script...")
        filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
        if filename:
            filename = filename.removesuffix(".py") + ".py"
            content = self._create_python_code()
            # Write UTF-8 explicitly so the script is readable regardless of the platform encoding.
            with Path(filename).open("w", encoding="utf-8") as f:
                f.write(content)
            logger.info("Python script saved as %s", filename)

    def export_notebook(self) -> None:
        """Export the current calculation as a Jupyter notebook."""
        logger.debug("Exporting results as Jupyter notebook...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")

        if filename:
            filename = filename.removesuffix(".ipynb") + ".ipynb"

            notebook = self._read_export_template()

            replacements = self._get_export_replacements()
            for cell in notebook.cells:
                # Substitute placeholders in code cells only; markdown cells are written unchanged.
                if cell.cell_type == "code":
                    cell.source = self._apply_export_replacements(cell.source, replacements)

            nbformat.write(notebook, filename)

            logger.info("Jupyter notebook saved as %s", filename)

    def _get_export_notebook_template_name(self) -> str:
        """Return the file name of the template notebook inside ``export_templates``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("Subclasses must implement this method")

    def _get_export_replacements(self) -> dict[str, str]:
        """Return placeholder -> value pairs substituted into the exported code.

        Override this method in subclasses to provide specific replacements for the export;
        the default performs no replacements.
        """
        return {}

    def abort_clicked(self) -> None:
        """Handle abort button click: stop all workers and restore the Calculate button."""
        logger.debug("Aborting calculation.")
        # Terminate process workers (recreating their pool for the next run) and thread workers.
        MultiProcessWorker.terminate_all(create_new_pool=True)
        MultiThreadWorker.terminate_all()
        self.after_calculate(False)
        show_status_tip(self, "Calculation aborted.", logger=logger)
|