Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 0 : import logging
5 0 : import time
6 0 : from pathlib import Path
7 0 : from typing import Any, Optional
8 :
9 0 : import nbformat
10 0 : from nbconvert import PythonExporter
11 0 : from PySide6.QtCore import QSize, Qt
12 0 : from PySide6.QtGui import QHideEvent, QMovie, QShowEvent
13 0 : from PySide6.QtWidgets import (
14 : QFileDialog,
15 : QHBoxLayout,
16 : QLabel,
17 : QMenu,
18 : QPushButton,
19 : QStyle,
20 : QToolBox,
21 : )
22 :
23 0 : from pairinteraction_gui.app import Application
24 0 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
25 0 : from pairinteraction_gui.config import BaseConfig
26 0 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
27 0 : from pairinteraction_gui.qobjects import WidgetV
28 0 : from pairinteraction_gui.qobjects.events import show_status_tip
29 0 : from pairinteraction_gui.qobjects.named_stacked_widget import NamedStackedWidget
30 0 : from pairinteraction_gui.worker import Worker
31 :
32 0 : logger = logging.getLogger(__name__)
33 :
34 :
class BasePage(WidgetV):
    """Common base class of every page shown in the application.

    Subclasses are expected to provide ``title`` and ``tooltip``; ``icon_path`` is optional.
    """

    # Layout settings; presumably consumed by WidgetV when it builds the layout (verify there).
    margin = (20, 20, 20, 20)
    spacing = 15

    title: str  # page title, appended to the window title when the page is shown
    tooltip: str
    icon_path: Optional[Path] = None

    def showEvent(self, event: QShowEvent) -> None:
        """Forward the show event and put the page title into the window title."""
        super().showEvent(event)
        top_level_window = self.window()
        top_level_window.setWindowTitle("Pairinteraction - " + self.title)
49 :
50 :
class SimulationPage(BasePage):
    """Base class of all simulation pages.

    A simulation page owns a ``QToolBox`` holding its configuration widgets. The toolbox is
    placed into the dock widget of the top-level window while the page is visible.
    """

    plotwidget: PlotWidget

    # Stylesheet shared by the push buttons of the page.
    _button_style = """
        QPushButton {
            padding: 8px 16px;
            background-color: #343a40;
            color: #ffffff;
            border: none;
            border-radius: 4px;
            font-weight: bold;
            font-size: 14px;
        }
        QPushButton:hover {
            background-color: #495057;
        }
        QPushButton:pressed {
            background-color: #000000;
        }
    """

    # Stylesheet for the drop-down menus attached to buttons.
    _button_menu_style = """
        QMenu {
            background-color: #ffffff;
            border: 1px solid #ffffff;
            border-radius: 4px;
            padding: 4px;
        }
        QMenu::item {
            padding: 6px 24px;
            color: #000000;
            font-size: 14px;
        }
        QMenu::item:selected {
            background-color: #ffffff;
        }
    """

    def setupWidget(self) -> None:
        """Create the (initially empty) toolbox for the configuration widgets."""
        self.toolbox = QToolBox()

    def postSetupWidget(self) -> None:
        """Add every ``BaseConfig`` instance attribute of this page to the toolbox."""
        configs = [value for value in vars(self).values() if isinstance(value, BaseConfig)]
        for config in configs:
            self.toolbox.addItem(config, config.title)

    def showEvent(self, event: QShowEvent) -> None:
        """Show the page and display its toolbox inside the window's dock widget."""
        super().showEvent(event)
        dock = self.window().dockwidget
        dock.setWidget(self.toolbox)
        dock.setVisible(True)
        self.toolbox.show()

    def hideEvent(self, event: QHideEvent) -> None:
        """Hide the page and the window's dock widget along with it."""
        super().hideEvent(event)
        dock = self.window().dockwidget
        dock.setVisible(False)
108 :
109 :
class CalculationPage(SimulationPage):
    """Base class for all pages with a calculation button.

    The page shows an energy plot with a control panel below it: a stacked
    "Calculate"/"Abort" button and an "Export" button offering PNG, Python script and
    Jupyter notebook export.

    Subclasses must implement :meth:`calculate` and
    :meth:`_get_export_notebook_template_name`; they may override
    :meth:`_get_export_replacements`.
    """

    plotwidget: PlotEnergies

    def setupWidget(self) -> None:
        """Create the plot widget, the loading animation and the control panel."""
        super().setupWidget()

        self.plotwidget = PlotEnergies(self)
        self.layout().addWidget(self.plotwidget)

        # Setup loading animation
        self.loading_label = QLabel(self)
        gif_path = Path(__file__).parent.parent / "images" / "loading.gif"
        self.loading_movie = QMovie(str(gif_path))
        self.loading_movie.setScaledSize(QSize(100, 100))  # Make the gif larger
        self.loading_label.setMovie(self.loading_movie)
        self.loading_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.loading_label.hide()

        # Control panel below the plot
        control_layout = QHBoxLayout()

        # Calculate/Abort stacked buttons
        self.calculate_and_abort = NamedStackedWidget[QPushButton](self)

        calculate_button = QPushButton("Calculate")
        calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
        calculate_button.clicked.connect(self.calculate_clicked)
        calculate_button.setStyleSheet(self._button_style)
        self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")

        abort_button = QPushButton("Abort")
        abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
        abort_button.clicked.connect(self.abort_clicked)
        abort_button.setStyleSheet(self._button_style)
        self.calculate_and_abort.addNamedWidget(abort_button, "Abort")

        self.calculate_and_abort.setFixedHeight(50)
        control_layout.addWidget(self.calculate_and_abort, stretch=2)

        # Create export button with menu
        export_button = QPushButton("Export")
        export_button.setObjectName("Export")
        export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
        export_button.setStyleSheet(self._button_style)
        export_menu = QMenu(self)
        export_menu.setStyleSheet(self._button_menu_style)
        export_menu.addAction("Export as PNG", self.export_png)
        export_menu.addAction("Export as Python script", self.export_python)
        export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
        export_button.setMenu(export_menu)
        export_button.setFixedHeight(50)
        control_layout.addWidget(export_button, stretch=1)

        self.layout().addLayout(control_layout)

    def calculate_clicked(self) -> None:
        """Handle a click on the calculate button.

        Prepares the page via :meth:`before_calculate`, then runs :meth:`calculate` in a
        ``Worker``. The worker's ``result`` signal starts a second ``Worker`` that runs
        :meth:`update_plot`; its ``finished`` signal is connected to :meth:`after_calculate`.
        """
        self.before_calculate()

        def update_plot(
            parameters_and_results: tuple[Parameters[Any], Results],
        ) -> None:
            worker_plot = Worker(self.update_plot, *parameters_and_results)
            worker_plot.start()

        worker = Worker(self.calculate)
        worker.signals.result.connect(update_plot)
        worker.signals.finished.connect(self.after_calculate)
        worker.start()

    def before_calculate(self) -> None:
        """Switch the page into its "calculating" state and start the timer."""
        show_status_tip(self, "Calculating... Please wait.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Abort")
        self.plotwidget.clear()

        # run loading gif (a 100x100 label centered on the page)
        self.loading_label.setGeometry((self.width() - 100) // 2, (self.height() - 100) // 2, 100, 100)
        self.loading_label.show()
        self.loading_movie.start()

        self._start_time = time.perf_counter()

    def after_calculate(self, success: bool) -> None:
        """Leave the "calculating" state and report the elapsed time.

        Args:
            success: Whether the calculation succeeded; selects the status message.

        """
        time_needed = time.perf_counter() - self._start_time

        # stop loading gif
        self.loading_movie.stop()
        self.loading_label.hide()

        if success:
            show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
        else:
            show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)

        self.calculate_and_abort.setCurrentNamedWidget("Calculate")

    def calculate(self) -> tuple[Parameters[Any], Results]:
        """Run the calculation and return the used parameters together with the results.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.

        """
        raise NotImplementedError("Subclasses must implement this method")

    def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Plot the calculated energies and place a cursor on the plot.

        Args:
            parameters: Parameters of the calculation; provide the x values and the x label.
            results: Results of the calculation; provide energies, overlaps and state labels.

        """
        energies = results.energies
        overlaps = results.ket_overlaps

        x_values = parameters.get_x_values()
        x_label = parameters.get_x_label()

        self.plotwidget.plot(x_values, energies, overlaps, x_label)

        # The cursor starts at the first step for a single atom and at the last step otherwise.
        ind = 0 if parameters.n_atoms == 1 else -1
        self.plotwidget.add_cursor(x_values[ind], energies[ind], results.state_labels_0)

        self.plotwidget.canvas.draw()

    def export_png(self) -> None:
        """Export the current plot as a PNG file."""
        logger.debug("Exporting results as PNG...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
        if not filename:
            # The dialog was cancelled.
            return

        # Make sure the file name ends with exactly one ".png" suffix.
        filename = filename.removesuffix(".png") + ".png"
        self.plotwidget.canvas.fig.savefig(filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none")
        logger.info("Plot saved as %s", filename)

    def export_python(self) -> None:
        """Export the current calculation as a Python script.

        The export template notebook is converted to Python code, the placeholders from
        :meth:`_get_export_replacements` are substituted and the result is written as UTF-8.
        """
        logger.debug("Exporting results as Python script...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
        if not filename:
            # The dialog was cancelled.
            return

        filename = filename.removesuffix(".py") + ".py"

        notebook = self._load_export_template()

        exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
        content, _ = exporter.from_notebook_node(notebook)

        for key, value in self._get_export_replacements().items():
            content = content.replace(key, str(value))

        # Write with an explicit encoding instead of the platform default, so that
        # non-ASCII characters in the script survive on every operating system.
        with Path(filename).open("w", encoding="utf-8") as f:
            f.write(content)

        logger.info("Python script saved as %s", filename)

    def export_notebook(self) -> None:
        """Export the current calculation as a Jupyter notebook.

        The placeholders from :meth:`_get_export_replacements` are substituted in the code
        cells of the export template notebook; other cell types are left untouched.
        """
        logger.debug("Exporting results as Jupyter notebook...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
        if not filename:
            # The dialog was cancelled.
            return

        filename = filename.removesuffix(".ipynb") + ".ipynb"

        notebook = self._load_export_template()

        replacements = self._get_export_replacements()
        for cell in notebook.cells:
            if cell.cell_type != "code":
                continue
            source = cell.source
            for key, value in replacements.items():
                source = source.replace(key, str(value))
            cell.source = source

        nbformat.write(notebook, filename)

        logger.info("Jupyter notebook saved as %s", filename)

    def _load_export_template(self) -> Any:
        """Read the export template notebook of this page (as nbformat version 4)."""
        template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
        # Notebook files are UTF-8 encoded JSON; do not rely on the platform default encoding.
        with template_path.open(encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    def _get_export_notebook_template_name(self) -> str:
        """Return the file name of the template notebook inside the "export_templates" folder.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.

        """
        raise NotImplementedError("Subclasses must implement this method")

    def _get_export_replacements(self) -> dict[str, str]:
        """Return the placeholder -> replacement mapping applied to the exported code."""
        # Override this method in subclasses to provide specific replacements for the export
        return {}

    def abort_clicked(self) -> None:
        """Handle abort button click."""
        logger.debug("Aborting calculation.")
        Application.terminate_all_processes()
        Application.terminate_all_threads()
        # Reset the page state first; the "aborted" status tip is emitted last, after the
        # "failed" message produced by after_calculate.
        self.after_calculate(False)
        show_status_tip(self, "Calculation aborted.", logger=logger)
|