Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : import time
6 1 : from pathlib import Path
7 1 : from typing import Any, Optional
8 :
9 1 : import nbformat
10 1 : from nbconvert import PythonExporter
11 1 : from PySide6.QtGui import QHideEvent, QShowEvent
12 1 : from PySide6.QtWidgets import (
13 : QFileDialog,
14 : QHBoxLayout,
15 : QMenu,
16 : QPushButton,
17 : QStyle,
18 : QToolBox,
19 : )
20 :
21 1 : import pairinteraction
22 1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
23 1 : from pairinteraction_gui.config import BaseConfig
24 1 : from pairinteraction_gui.config.ket_config import KetConfig
25 1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
26 1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
27 1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
28 :
# Module-level logger; it is also passed to show_status_tip() below so that
# status-bar messages end up in the log as well.
logger = logging.getLogger(__name__)
30 :
31 :
class BasePage(WidgetV):
    """Common base for every page of the application.

    Subclasses provide ``title`` and ``tooltip`` (and optionally ``icon_path``).
    Whenever the page becomes visible, its title is written into the title bar
    of the enclosing window together with the pairinteraction version.
    """

    # Layout settings (presumably consumed by WidgetV -- not visible here).
    margin = (20, 20, 20, 20)
    spacing = 15

    title: str
    tooltip: str
    icon_path: Optional[Path] = None

    def showEvent(self, event: QShowEvent) -> None:
        """Forward the show event, then refresh the window title for this page."""
        super().showEvent(event)
        prefix = f"PairInteraction v{pairinteraction.__version__} - "
        main_window = self.window()
        main_window.setWindowTitle(prefix + self.title)
46 :
47 :
class SimulationPage(BasePage):
    """Common base for pages that run simulations.

    Every ``BaseConfig`` instance stored as an attribute on the page is placed
    into a ``QToolBox``. That toolbox is shown in the main window's dock widget
    while the page is visible and hidden again when the page is hidden.
    """

    ket_config: KetConfig

    plotwidget: PlotWidget

    def setupWidget(self) -> None:
        """Create the (initially empty) toolbox that will hold the config panels."""
        self.toolbox = QToolBox()

    def postSetupWidget(self) -> None:
        """Add all config panels to the toolbox and broadcast the initial species selection."""
        config_panels = [value for value in self.__dict__.values() if isinstance(value, BaseConfig)]
        for panel in config_panels:
            self.toolbox.addItem(panel, panel.title)

        # Emit the current species of every combo box once, so that listeners
        # connected to signal_species_changed receive the initial state.
        ket_config = self.ket_config
        for index, combo in enumerate(ket_config.species_combo_list):
            ket_config.signal_species_changed.emit(index, combo.currentText())

    def showEvent(self, event: QShowEvent) -> None:
        """Install this page's toolbox in the main window's dock widget and show it."""
        super().showEvent(event)
        dock = self.window().dockwidget
        dock.setWidget(self.toolbox)
        dock.setVisible(True)
        self.toolbox.show()

    def hideEvent(self, event: QHideEvent) -> None:
        """Hide the main window's dock widget together with the page."""
        super().hideEvent(event)
        dock = self.window().dockwidget
        dock.setVisible(False)
75 :
76 :
class CalculationPage(SimulationPage):
    """Base class for all pages with a calculation button.

    Adds an energy plot (``PlotEnergies``), a Calculate/Abort button pair and an
    Export menu (PNG, Python script, Jupyter notebook) below the plot.

    Subclasses must implement :meth:`calculate` and
    :meth:`_get_export_notebook_template_name`. They may override
    :meth:`_get_export_replacements` to substitute placeholders in the exported
    code.
    """

    plotwidget: PlotEnergies
    # Completion flags: reset in calculate_clicked() and set to True by the
    # ``finished`` signals of the calculation / plot workers (success or not).
    _calculation_finished = False
    _plot_finished = False

    def setupWidget(self) -> None:
        """Build the plot panel and the Calculate/Abort + Export controls below it."""
        super().setupWidget()

        # Plot Panel
        self.plotwidget = PlotEnergies(self)
        self.layout().addWidget(self.plotwidget)

        # Control panel below the plot
        bottom_layout = QHBoxLayout()
        bottom_layout.setObjectName("bottomLayout")

        # Calculate/Abort stacked buttons (only one is visible at a time)
        self.calculate_and_abort = NamedStackedWidget[QPushButton](self)

        calculate_button = QPushButton("Calculate")
        calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
        calculate_button.clicked.connect(self.calculate_clicked)
        self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")

        abort_button = QPushButton("Abort")
        abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
        abort_button.clicked.connect(self.abort_clicked)
        self.calculate_and_abort.addNamedWidget(abort_button, "Abort")

        self.calculate_and_abort.setFixedHeight(50)
        bottom_layout.addWidget(self.calculate_and_abort, stretch=2)

        # Create export button with menu
        export_button = QPushButton("Export")
        export_button.setObjectName("Export")
        export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
        export_menu = QMenu(self)
        export_menu.addAction("Export as PNG", self.export_png)
        export_menu.addAction("Export as Python script", self.export_python)
        export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
        export_button.setMenu(export_menu)
        export_button.setFixedHeight(50)
        bottom_layout.addWidget(export_button, stretch=1)

        self.layout().addLayout(bottom_layout)

    def calculate_clicked(self) -> None:
        """Run :meth:`calculate` in a worker thread, then plot its result in a second worker."""
        self._calculation_finished = False
        self._plot_finished = False
        self.before_calculate()

        def update_plot(
            parameters_and_results: tuple[Parameters[Any], Results],
        ) -> None:
            worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
            # Connect *before* starting: if the worker finished before the
            # connection existed, the signal would be lost and the flag would
            # never be set.
            worker_plot.signals.finished.connect(lambda _success: setattr(self, "_plot_finished", True))
            worker_plot.start()

        worker = MultiThreadWorker(self.calculate)
        worker.enable_busy_indicator(self)
        worker.signals.result.connect(update_plot)
        worker.signals.finished.connect(self.after_calculate)
        worker.signals.finished.connect(lambda _success: setattr(self, "_calculation_finished", True))
        worker.start()

    def before_calculate(self) -> None:
        """Show the busy status, switch to the Abort button, clear the plot and start the timer."""
        show_status_tip(self, "Calculating... Please wait.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Abort")
        self.plotwidget.clear()

        self._start_time = time.perf_counter()

    def after_calculate(self, success: bool) -> None:
        """Report the elapsed time and restore the Calculate button.

        Args:
            success: Whether the calculation completed without error.
        """
        time_needed = time.perf_counter() - self._start_time

        if success:
            show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
        else:
            show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)

        self.calculate_and_abort.setCurrentNamedWidget("Calculate")

    def calculate(self) -> tuple[Parameters[Any], Results]:
        """Perform the calculation; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Plot the energies/overlaps from ``results`` against the x values of ``parameters``."""
        energies = results.energies
        overlaps = results.ket_overlaps

        x_values = parameters.get_x_values()
        x_label = parameters.get_x_label()

        self.plotwidget.plot(x_values, energies, overlaps, x_label)

        self.plotwidget.add_cursor(x_values, energies, results.state_labels)

        self.plotwidget.canvas.draw()
        # NOTE: no ``_plot_finished`` reset here; the flag is reset in
        # calculate_clicked() and set to True by the plot worker's signal.

    def export_png(self) -> None:
        """Export the current plot as a PNG file."""
        logger.debug("Exporting results as PNG...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")

        if filename:
            filename = filename.removesuffix(".png") + ".png"
            self.plotwidget.canvas.fig.savefig(
                filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
            )
            logger.info("Plot saved as %s", filename)

    def _load_export_notebook_template(self) -> "nbformat.NotebookNode":
        """Read the subclass's export template notebook (nbformat v4)."""
        template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
        # Notebooks are JSON and therefore UTF-8; do not depend on the locale.
        with template_path.open(encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    @staticmethod
    def _apply_export_replacements(text: str, replacements: dict[str, str]) -> str:
        """Return ``text`` with every key of ``replacements`` replaced by its (stringified) value."""
        for key, value in replacements.items():
            text = text.replace(key, str(value))
        return text

    def _create_python_code(self) -> str:
        """Convert the export template notebook to a Python script with replacements applied."""
        notebook = self._load_export_notebook_template()

        exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
        content, _ = exporter.from_notebook_node(notebook)

        return self._apply_export_replacements(content, self._get_export_replacements())

    def export_python(self) -> None:
        """Export the current calculation as a Python script."""
        logger.debug("Exporting results as Python script...")
        filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
        if filename:
            filename = filename.removesuffix(".py") + ".py"
            content = self._create_python_code()
            # Write UTF-8 explicitly so the script is portable across locales.
            with Path(filename).open("w", encoding="utf-8") as f:
                f.write(content)
            logger.info("Python script saved as %s", filename)

    def export_notebook(self) -> None:
        """Export the current calculation as a Jupyter notebook."""
        logger.debug("Exporting results as Jupyter notebook...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")

        if filename:
            filename = filename.removesuffix(".ipynb") + ".ipynb"

            notebook = self._load_export_notebook_template()

            # Only code cells contain placeholders; markdown cells are kept verbatim.
            replacements = self._get_export_replacements()
            for cell in notebook.cells:
                if cell.cell_type == "code":
                    cell.source = self._apply_export_replacements(cell.source, replacements)

            nbformat.write(notebook, filename)

            logger.info("Jupyter notebook saved as %s", filename)

    def _get_export_notebook_template_name(self) -> str:
        """Return the file name of the template inside ``export_templates``; subclasses must implement."""
        raise NotImplementedError("Subclasses must implement this method")

    def _get_export_replacements(self) -> dict[str, str]:
        """Return placeholder -> value substitutions for the exported code (none by default)."""
        # Override this method in subclasses to provide specific replacements for the export
        return {}

    def abort_clicked(self) -> None:
        """Handle abort button click."""
        logger.debug("Aborting calculation.")
        MultiProcessWorker.terminate_all(create_new_pool=True)
        MultiThreadWorker.terminate_all()
        self.after_calculate(False)
        show_status_tip(self, "Calculation aborted.", logger=logger)
|