Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : import time
6 1 : from pathlib import Path
7 1 : from typing import Any, Optional
8 :
9 1 : import nbformat
10 1 : from nbconvert import PythonExporter
11 1 : from PySide6.QtGui import QHideEvent, QShowEvent
12 1 : from PySide6.QtWidgets import (
13 : QFileDialog,
14 : QHBoxLayout,
15 : QMenu,
16 : QPushButton,
17 : QStyle,
18 : QToolBox,
19 : )
20 :
21 1 : import pairinteraction
22 1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
23 1 : from pairinteraction_gui.config import BaseConfig
24 1 : from pairinteraction_gui.config.ket_config import KetConfig
25 1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
26 1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
27 1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
28 :
# Module-wide logger; its name is this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)
30 :
31 :
class BasePage(WidgetV):
    """Base class for all pages in this application.

    Subclasses provide ``title`` and ``tooltip`` and may set ``icon_path``.
    While a page is shown, ``title`` is appended to the main window title.
    """

    margin = (20, 20, 20, 20)
    spacing = 15

    title: str
    tooltip: str
    icon_path: Optional[Path] = None

    def showEvent(self, event: QShowEvent) -> None:
        """Show event: set the window title to the package version followed by this page's title."""
        super().showEvent(event)
        version_prefix = f"PairInteraction v{pairinteraction.__version__} - "
        self.window().setWindowTitle(version_prefix + self.title)
46 :
47 :
class SimulationPage(BasePage):
    """Base class for all simulation pages in this application.

    Each page owns a ``QToolBox``. ``postSetupWidget`` fills it with the page's
    ``BaseConfig`` attributes. While the page is visible, the tool box is placed
    in the main window's ``dockwidget``.
    """

    ket_config: KetConfig

    plotwidget: PlotWidget

    def setupWidget(self) -> None:
        """Create the (initially empty) tool box for the configuration panels."""
        self.toolbox = QToolBox()

    def postSetupWidget(self) -> None:
        """Add every ``BaseConfig`` instance attribute to the tool box, in attribute order."""
        configs = [value for value in self.__dict__.values() if isinstance(value, BaseConfig)]
        for config in configs:
            self.toolbox.addItem(config, config.title)

    def showEvent(self, event: QShowEvent) -> None:
        """Dock this page's tool box in the main window and make the dock visible."""
        super().showEvent(event)
        dock = self.window().dockwidget
        dock.setWidget(self.toolbox)
        dock.setVisible(True)
        self.toolbox.show()

    def hideEvent(self, event: QHideEvent) -> None:
        """Hide the main window's dock widget together with this page."""
        super().hideEvent(event)
        dock = self.window().dockwidget
        dock.setVisible(False)
72 :
73 :
class CalculationPage(SimulationPage):
    """Base class for all pages with a calculation button.

    Below the energy plot it adds a Calculate/Abort button pair and an Export menu
    (PNG, Python script, Jupyter notebook).

    Subclasses must implement ``calculate`` and ``_get_export_notebook_template_name``.
    They may override ``update_plot`` and ``_get_export_replacements``.
    """

    plotwidget: PlotEnergies
    # Completion flags: reset by ``calculate_clicked`` and set to True once the
    # calculation worker / plot worker emit ``finished``.
    _calculation_finished = False
    _plot_finished = False

    def setupWidget(self) -> None:
        """Build the plot panel and the control row (Calculate/Abort and Export buttons)."""
        super().setupWidget()

        # Plot Panel
        self.plotwidget = PlotEnergies(self)
        self.layout().addWidget(self.plotwidget)

        # Control panel below the plot
        bottom_layout = QHBoxLayout()

        # Calculate/Abort stacked buttons; switched via ``setCurrentNamedWidget``
        self.calculate_and_abort = NamedStackedWidget[QPushButton](self)

        calculate_button = QPushButton("Calculate")
        calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
        calculate_button.clicked.connect(self.calculate_clicked)
        self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")

        abort_button = QPushButton("Abort")
        abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
        abort_button.clicked.connect(self.abort_clicked)
        self.calculate_and_abort.addNamedWidget(abort_button, "Abort")

        self.calculate_and_abort.setFixedHeight(50)
        bottom_layout.addWidget(self.calculate_and_abort, stretch=2)

        # Create export button with menu
        export_button = QPushButton("Export")
        export_button.setObjectName("Export")
        export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
        export_menu = QMenu(self)
        export_menu.addAction("Export as PNG", self.export_png)
        export_menu.addAction("Export as Python script", self.export_python)
        export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
        export_button.setMenu(export_menu)
        export_button.setFixedHeight(50)
        bottom_layout.addWidget(export_button, stretch=1)

        self.layout().addLayout(bottom_layout)

    def calculate_clicked(self) -> None:
        """Run ``calculate`` in a worker thread, then ``update_plot`` with its result in a second worker.

        ``_calculation_finished`` and ``_plot_finished`` are reset here. Each is set to True
        when the corresponding worker emits ``finished``.
        """
        self._calculation_finished = False
        self._plot_finished = False
        self.before_calculate()

        def start_plot_worker(
            parameters_and_results: tuple[Parameters[Any], Results],
        ) -> None:
            worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
            # Connect before starting: a worker that finishes quickly could otherwise
            # emit ``finished`` before anyone is listening, leaving the flag False.
            worker_plot.signals.finished.connect(lambda _success: setattr(self, "_plot_finished", True))
            worker_plot.start()

        worker = MultiThreadWorker(self.calculate)
        worker.enable_busy_indicator(self)
        worker.signals.result.connect(start_plot_worker)
        worker.signals.finished.connect(self.after_calculate)
        worker.signals.finished.connect(lambda _success: setattr(self, "_calculation_finished", True))
        worker.start()

    def before_calculate(self) -> None:
        """Show a status tip, switch to the Abort button, clear the plot and record the start time."""
        show_status_tip(self, "Calculating... Please wait.", logger=logger)
        self.calculate_and_abort.setCurrentNamedWidget("Abort")
        self.plotwidget.clear()

        self._start_time = time.perf_counter()

    def after_calculate(self, success: bool) -> None:
        """Report the elapsed time since ``before_calculate`` and switch back to the Calculate button."""
        time_needed = time.perf_counter() - self._start_time

        if success:
            show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
        else:
            show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)

        self.calculate_and_abort.setCurrentNamedWidget("Calculate")

    def calculate(self) -> tuple[Parameters[Any], Results]:
        """Run the calculation and return ``(parameters, results)``; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Plot ``results.energies`` against the parameters' x values and add a cursor with state labels.

        ``_plot_finished`` is deliberately not modified here. ``calculate_clicked`` sets it
        once the worker running this method emits ``finished``.
        """
        energies = results.energies
        overlaps = results.ket_overlaps

        x_values = parameters.get_x_values()
        x_label = parameters.get_x_label()

        self.plotwidget.plot(x_values, energies, overlaps, x_label)

        self.plotwidget.add_cursor(x_values, energies, results.state_labels)

        self.plotwidget.canvas.draw()

    def export_png(self) -> None:
        """Export the current plot as a PNG file."""
        logger.debug("Exporting results as PNG...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
        if not filename:
            return

        filename = filename.removesuffix(".png") + ".png"
        self.plotwidget.canvas.fig.savefig(
            filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
        )
        logger.info("Plot saved as %s", filename)

    def export_python(self) -> None:
        """Export the current calculation as a Python script."""
        logger.debug("Exporting results as Python script...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
        if not filename:
            return
        filename = filename.removesuffix(".py") + ".py"

        notebook = self._read_export_template()

        exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
        content, _ = exporter.from_notebook_node(notebook)
        content = self._apply_export_replacements(content, self._get_export_replacements())

        # Write UTF-8 explicitly. If the template contains non-ASCII characters, the
        # locale default (e.g. cp1252 on Windows) might not be able to encode them.
        Path(filename).write_text(content, encoding="utf-8")

        logger.info("Python script saved as %s", filename)

    def export_notebook(self) -> None:
        """Export the current calculation as a Jupyter notebook."""
        logger.debug("Exporting results as Jupyter notebook...")

        filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
        if not filename:
            return
        filename = filename.removesuffix(".ipynb") + ".ipynb"

        notebook = self._read_export_template()

        replacements = self._get_export_replacements()
        # Only code cells are templated; other cells are left untouched.
        for cell in notebook.cells:
            if cell.cell_type == "code":
                cell.source = self._apply_export_replacements(cell.source, replacements)

        nbformat.write(notebook, filename)

        logger.info("Jupyter notebook saved as %s", filename)

    def _read_export_template(self) -> "nbformat.NotebookNode":
        """Read this page's notebook template from the ``export_templates`` directory.

        The file is decoded as UTF-8, the encoding of ``.ipynb`` files, so the result
        does not depend on the platform's locale encoding.
        """
        template_path = (
            Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
        )
        with template_path.open(encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    @staticmethod
    def _apply_export_replacements(text: str, replacements: dict[str, str]) -> str:
        """Return ``text`` with each key of ``replacements`` replaced by ``str(value)``, applied in dict order."""
        for key, value in replacements.items():
            text = text.replace(key, str(value))
        return text

    def _get_export_notebook_template_name(self) -> str:
        """Return the file name of the template inside ``export_templates``; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def _get_export_replacements(self) -> dict[str, str]:
        # Override this method in subclasses to provide specific replacements for the export
        return {}

    def abort_clicked(self) -> None:
        """Handle abort button click: terminate all workers and report the calculation as aborted."""
        logger.debug("Aborting calculation.")
        MultiProcessWorker.terminate_all(create_new_pool=True)
        MultiThreadWorker.terminate_all()
        self.after_calculate(False)
        show_status_tip(self, "Calculation aborted.", logger=logger)
|