Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : import time
6 1 : from pathlib import Path
7 1 : from typing import Any, Optional
8 :
9 1 : import nbformat
10 1 : from nbconvert import PythonExporter
11 1 : from PySide6.QtGui import QHideEvent, QShowEvent
12 1 : from PySide6.QtWidgets import (
13 : QFileDialog,
14 : QHBoxLayout,
15 : QMenu,
16 : QPushButton,
17 : QStyle,
18 : QToolBox,
19 : )
20 :
21 1 : import pairinteraction
22 1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
23 1 : from pairinteraction_gui.config import BaseConfig
24 1 : from pairinteraction_gui.config.ket_config import KetConfig
25 1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
26 1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
27 1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
28 :
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
30 :
31 :
class BasePage(WidgetV):
    """Common base for every page of the application.

    Subclasses provide ``title`` and ``tooltip``; ``icon_path`` is optional and
    defaults to ``None``.
    """

    # Layout margins and spacing for the page — presumably consumed by WidgetV
    # when it builds its layout; confirm against WidgetV.
    margin = (20, 20, 20, 20)
    spacing = 15

    title: str
    tooltip: str
    icon_path: Optional[Path] = None

    def showEvent(self, event: QShowEvent) -> None:
        """Put the package version and this page's title into the window title when shown."""
        super().showEvent(event)
        version_prefix = f"PairInteraction v{pairinteraction.__version__} - "
        top_level = self.window()
        top_level.setWindowTitle(version_prefix + self.title)
46 :
47 :
class SimulationPage(BasePage):
    """Shared base for simulation pages whose configuration lives in the dock.

    Every :class:`BaseConfig` stored as an instance attribute is collected into a
    :class:`QToolBox`. While the page is visible, that toolbox is placed into the
    top-level window's ``dockwidget``.
    """

    ket_config: KetConfig

    plotwidget: PlotWidget

    def setupWidget(self) -> None:
        # Container for the configuration panels; filled in postSetupWidget.
        self.toolbox = QToolBox()

    def postSetupWidget(self) -> None:
        # Each BaseConfig instance attribute becomes one toolbox item labelled with
        # its title, following the insertion order of self.__dict__.
        config_widgets = (value for value in self.__dict__.values() if isinstance(value, BaseConfig))
        for config in config_widgets:
            self.toolbox.addItem(config, config.title)

    def showEvent(self, event: QShowEvent) -> None:
        super().showEvent(event)
        # Hand this page's toolbox to the window's dock and make the dock visible.
        dock = self.window().dockwidget
        dock.setWidget(self.toolbox)
        dock.setVisible(True)
        self.toolbox.show()

    def hideEvent(self, event: QHideEvent) -> None:
        super().hideEvent(event)
        # Hide the dock while this page is not shown.
        self.window().dockwidget.setVisible(False)
72 :
73 :
class CalculationPage(SimulationPage):
    """Base class for all pages with a calculation button.

    Subclasses must implement :meth:`calculate` and
    :meth:`_get_export_notebook_template_name`. They may override
    :meth:`_get_export_replacements` to fill placeholders in the exported
    Python script / Jupyter notebook.
    """

    plotwidget: PlotEnergies
    # Completion flags. Both are reset at the start of every calculation and set to
    # True by the corresponding worker's ``finished`` signal.
    _calculation_finished = False
    _plot_finished = False

    def setupWidget(self) -> None:
        """Create the plot, the Calculate/Abort buttons and the Export menu."""
        super().setupWidget()

        # Plot Panel
        self.plotwidget = PlotEnergies(self)
        self.layout().addWidget(self.plotwidget)

        # Control panel below the plot
        bottom_layout = QHBoxLayout()
        bottom_layout.setObjectName("bottomLayout")

        # Calculate/Abort stacked buttons
        self.calculate_and_abort = NamedStackedWidget[QPushButton](self)

        calculate_button = QPushButton("Calculate")
        calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
        calculate_button.clicked.connect(self.calculate_clicked)
        self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")

        abort_button = QPushButton("Abort")
        abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
        abort_button.clicked.connect(self.abort_clicked)
        self.calculate_and_abort.addNamedWidget(abort_button, "Abort")

        self.calculate_and_abort.setFixedHeight(50)
        bottom_layout.addWidget(self.calculate_and_abort, stretch=2)

        # Create export button with menu
        export_button = QPushButton("Export")
        export_button.setObjectName("Export")
        export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
        export_menu = QMenu(self)
        export_menu.addAction("Export as PNG", self.export_png)
        export_menu.addAction("Export as Python script", self.export_python)
        export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
        export_button.setMenu(export_menu)
        export_button.setFixedHeight(50)
        bottom_layout.addWidget(export_button, stretch=1)

        self.layout().addLayout(bottom_layout)

    def calculate_clicked(self) -> None:
        """Run :meth:`calculate` in a worker thread, then plot its result in a second worker."""
        self._calculation_finished = False
        self._plot_finished = False
        self.before_calculate()

        def start_plot_worker(
            parameters_and_results: tuple[Parameters[Any], Results],
        ) -> None:
            worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
            # Connect before start(): otherwise the worker could emit ``finished``
            # before the slot is attached, and _plot_finished would never be set.
            worker_plot.signals.finished.connect(lambda _success: setattr(self, "_plot_finished", True))
            worker_plot.start()

        worker = MultiThreadWorker(self.calculate)
        worker.enable_busy_indicator(self)
        worker.signals.result.connect(start_plot_worker)
        worker.signals.finished.connect(self.after_calculate)
        worker.signals.finished.connect(lambda _success: setattr(self, "_calculation_finished", True))
        worker.start()

    def before_calculate(self) -> None:
        """Announce the calculation, switch to the Abort button, clear the plot and start the timer."""
        show_status_tip(self, "Calculating... Please wait.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Abort")
        self.plotwidget.clear()

        self._start_time = time.perf_counter()

    def after_calculate(self, success: bool) -> None:
        """Report the elapsed time via ``show_status_tip`` and switch back to the Calculate button.

        Args:
            success: Whether the calculation finished successfully.
        """
        time_needed = time.perf_counter() - self._start_time

        if success:
            show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
        else:
            show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)

        self.calculate_and_abort.setCurrentNamedWidget("Calculate")

    def calculate(self) -> tuple[Parameters[Any], Results]:
        """Perform the calculation; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Plot the energies and ket overlaps of ``results`` against the x values of ``parameters``."""
        energies = results.energies
        overlaps = results.ket_overlaps

        x_values = parameters.get_x_values()
        x_label = parameters.get_x_label()

        self.plotwidget.plot(x_values, energies, overlaps, x_label)
        self.plotwidget.add_cursor(x_values, energies, results.state_labels)
        self.plotwidget.canvas.draw()
        # _plot_finished is deliberately not touched here: calculate_clicked sets it
        # from the plot worker's ``finished`` signal.

    @staticmethod
    def _ensure_suffix(filename: str, suffix: str) -> str:
        """Return ``filename`` ending in ``suffix``, appending it only if missing (case-insensitive)."""
        if filename.lower().endswith(suffix.lower()):
            return filename
        return filename + suffix

    @staticmethod
    def _apply_replacements(text: str, replacements: dict[str, str]) -> str:
        """Replace every key of ``replacements`` in ``text`` with ``str(value)``, in dict order."""
        for key, value in replacements.items():
            text = text.replace(key, str(value))
        return text

    def _read_export_template(self) -> Any:
        """Load this page's export template as an nbformat v4 notebook."""
        template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
        # Notebook files are UTF-8 JSON; do not depend on the platform's default encoding.
        with template_path.open(encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    def export_png(self) -> None:
        """Export the current plot as a PNG file."""
        logger.debug("Exporting results as PNG...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
        if not filename:  # dialog cancelled
            return

        filename = self._ensure_suffix(filename, ".png")
        self.plotwidget.canvas.fig.savefig(filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none")
        logger.info("Plot saved as %s", filename)

    def export_python(self) -> None:
        """Export the current calculation as a Python script."""
        logger.debug("Exporting results as Python script...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
        if not filename:  # dialog cancelled
            return

        filename = self._ensure_suffix(filename, ".py")

        notebook = self._read_export_template()
        exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
        content, _ = exporter.from_notebook_node(notebook)
        content = self._apply_replacements(content, self._get_export_replacements())

        with Path(filename).open("w", encoding="utf-8") as f:
            f.write(content)

        logger.info("Python script saved as %s", filename)

    def export_notebook(self) -> None:
        """Export the current calculation as a Jupyter notebook."""
        logger.debug("Exporting results as Jupyter notebook...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
        if not filename:  # dialog cancelled
            return

        filename = self._ensure_suffix(filename, ".ipynb")

        notebook = self._read_export_template()
        replacements = self._get_export_replacements()
        # Placeholders are only substituted in code cells; markdown cells are left as-is.
        for cell in notebook.cells:
            if cell.cell_type == "code":
                cell.source = self._apply_replacements(cell.source, replacements)

        nbformat.write(notebook, filename)

        logger.info("Jupyter notebook saved as %s", filename)

    def _get_export_notebook_template_name(self) -> str:
        """Return the template file name inside ``export_templates``; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def _get_export_replacements(self) -> dict[str, str]:
        """Return placeholder -> value substitutions for the exports (none by default)."""
        # Override this method in subclasses to provide specific replacements for the export
        return {}

    def abort_clicked(self) -> None:
        """Handle abort button click.

        Terminates all multiprocess workers (with ``create_new_pool=True``) and all
        thread workers, then reports the calculation as failed/aborted.
        """
        logger.debug("Aborting calculation.")
        MultiProcessWorker.terminate_all(create_new_pool=True)
        MultiThreadWorker.terminate_all()
        self.after_calculate(False)
        show_status_tip(self, "Calculation aborted.", logger=logger)
|