LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/page - base_page.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 116 158 73.4 %
Date: 2025-09-29 10:28:29 Functions: 11 38 28.9 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : import time
       7           1 : from pathlib import Path
       8           1 : from typing import TYPE_CHECKING, Any
       9             : 
      10           1 : import nbformat
      11           1 : from nbconvert import PythonExporter
      12           1 : from PySide6.QtWidgets import (
      13             :     QFileDialog,
      14             :     QHBoxLayout,
      15             :     QMenu,
      16             :     QPushButton,
      17             :     QStyle,
      18             :     QToolBox,
      19             : )
      20             : 
      21           1 : import pairinteraction
      22           1 : from pairinteraction_gui.config import BaseConfig
      23           1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies
      24           1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
      25           1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
      26             : 
      27             : if TYPE_CHECKING:
      28             :     from PySide6.QtGui import QHideEvent, QShowEvent
      29             : 
      30             :     from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      31             :     from pairinteraction_gui.config.ket_config import KetConfig
      32             :     from pairinteraction_gui.plotwidget.plotwidget import PlotWidget
      33             : 
      34           1 : logger = logging.getLogger(__name__)
      35             : 
      36             : 
      37           1 : class BasePage(WidgetV):
      38             :     """Base class for all pages in this application."""
      39             : 
      40           1 :     margin = (20, 20, 20, 20)
      41           1 :     spacing = 15
      42             : 
      43           1 :     title: str
      44           1 :     tooltip: str
      45           1 :     icon_path: Path | None = None
      46             : 
      47           1 :     def showEvent(self, event: QShowEvent) -> None:
      48             :         """Show event."""
      49           1 :         super().showEvent(event)
      50           1 :         self.window().setWindowTitle(f"PairInteraction v{pairinteraction.__version__} - " + self.title)
      51             : 
      52             : 
      53           1 : class SimulationPage(BasePage):
      54             :     """Base class for all simulation pages in this application."""
      55             : 
      56           1 :     ket_config: KetConfig
      57             : 
      58           1 :     plotwidget: PlotWidget
      59             : 
      60           1 :     def setupWidget(self) -> None:
      61           1 :         self.toolbox = QToolBox()
      62             : 
      63           1 :     def postSetupWidget(self) -> None:
      64           1 :         for attr in self.__dict__.values():
      65           1 :             if isinstance(attr, BaseConfig):
      66           1 :                 self.toolbox.addItem(attr, attr.title)
      67             : 
      68           1 :         for i, species_combo in enumerate(self.ket_config.species_combo_list):
      69           1 :             self.ket_config.signal_species_changed.emit(i, species_combo.currentText())
      70             : 
      71           1 :     def showEvent(self, event: QShowEvent) -> None:
      72           1 :         super().showEvent(event)
      73           1 :         self.window().dockwidget.setWidget(self.toolbox)
      74           1 :         self.window().dockwidget.setVisible(True)
      75           1 :         self.toolbox.show()
      76             : 
      77           1 :     def hideEvent(self, event: QHideEvent) -> None:
      78           1 :         super().hideEvent(event)
      79           1 :         self.window().dockwidget.setVisible(False)
      80             : 
      81             : 
      82           1 : class CalculationPage(SimulationPage):
      83             :     """Base class for all pages with a calculation button."""
      84             : 
      85           1 :     plotwidget: PlotEnergies
      86           1 :     _calculation_finished = False
      87           1 :     _plot_finished = False
      88             : 
      89           1 :     def setupWidget(self) -> None:
      90           1 :         super().setupWidget()
      91             : 
      92             :         # Plot Panel
      93           1 :         self.plotwidget = PlotEnergies(self)
      94           1 :         self.layout().addWidget(self.plotwidget)
      95             : 
      96             :         # Control panel below the plot
      97           1 :         bottom_layout = QHBoxLayout()
      98           1 :         bottom_layout.setObjectName("bottomLayout")
      99             : 
     100             :         # Calculate/Abort stacked buttons
     101           1 :         self.calculate_and_abort = NamedStackedWidget[QPushButton](self)
     102             : 
     103           1 :         calculate_button = QPushButton("Calculate")
     104           1 :         calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
     105           1 :         calculate_button.clicked.connect(self.calculate_clicked)
     106           1 :         self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")
     107             : 
     108           1 :         abort_button = QPushButton("Abort")
     109           1 :         abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
     110           1 :         abort_button.clicked.connect(self.abort_clicked)
     111           1 :         self.calculate_and_abort.addNamedWidget(abort_button, "Abort")
     112             : 
     113           1 :         self.calculate_and_abort.setFixedHeight(50)
     114           1 :         bottom_layout.addWidget(self.calculate_and_abort, stretch=2)
     115             : 
     116             :         # Create export button with menu
     117           1 :         export_button = QPushButton("Export")
     118           1 :         export_button.setObjectName("Export")
     119           1 :         export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
     120           1 :         export_menu = QMenu(self)
     121           1 :         export_menu.addAction("Export as PNG", self.export_png)
     122           1 :         export_menu.addAction("Export as Python script", self.export_python)
     123           1 :         export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
     124           1 :         export_button.setMenu(export_menu)
     125           1 :         export_button.setFixedHeight(50)
     126           1 :         bottom_layout.addWidget(export_button, stretch=1)
     127             : 
     128           1 :         self.layout().addLayout(bottom_layout)
     129             : 
     130           1 :     def calculate_clicked(self) -> None:
     131           1 :         self._calculation_finished = False
     132           1 :         self._plot_finished = False
     133           1 :         self.before_calculate()
     134             : 
     135           1 :         def update_plot(
     136             :             parameters_and_results: tuple[Parameters[Any], Results],
     137             :         ) -> None:
     138           1 :             worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
     139           1 :             worker_plot.start()
     140           1 :             worker_plot.signals.finished.connect(lambda _sucess: setattr(self, "_plot_finished", True))
     141             : 
     142           1 :         worker = MultiThreadWorker(self.calculate)
     143           1 :         worker.enable_busy_indicator(self)
     144           1 :         worker.signals.result.connect(update_plot)
     145           1 :         worker.signals.finished.connect(self.after_calculate)
     146           1 :         worker.signals.finished.connect(lambda _sucess: setattr(self, "_calculation_finished", True))
     147           1 :         worker.start()
     148             : 
     149           1 :     def before_calculate(self) -> None:
     150           1 :         show_status_tip(self, "Calculating... Please wait.", logger=logger)
     151           1 :         self.calculate_and_abort.setCurrentNamedWidget("Abort")
     152           1 :         self.plotwidget.clear()
     153             : 
     154           1 :         self._start_time = time.perf_counter()
     155             : 
     156           1 :     def after_calculate(self, success: bool) -> None:
     157           1 :         time_needed = time.perf_counter() - self._start_time
     158             : 
     159           1 :         if success:
     160           1 :             show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
     161             :         else:
     162           0 :             show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)
     163             : 
     164           1 :         self.calculate_and_abort.setCurrentNamedWidget("Calculate")
     165             : 
     166           1 :     def calculate(self) -> tuple[Parameters[Any], Results]:
     167           0 :         raise NotImplementedError("Subclasses must implement this method")
     168             : 
     169           1 :     def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
     170           0 :         self.plotwidget.plot(parameters, results)
     171           0 :         self.plotwidget.add_cursor(parameters, results)
     172           0 :         self.plotwidget.canvas.draw()
     173             : 
     174           1 :     def export_png(self) -> None:
     175             :         """Export the current plot as a PNG file."""
     176           0 :         logger.debug("Exporting results as PNG...")
     177             : 
     178           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
     179             : 
     180           0 :         if filename:
     181           0 :             filename = filename.removesuffix(".png") + ".png"
     182           0 :             self.plotwidget.canvas.fig.savefig(
     183             :                 filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
     184             :             )
     185           0 :             logger.info("Plot saved as %s", filename)
     186             : 
     187           1 :     def _create_python_code(self) -> str:
     188           1 :         template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     189           1 :         with Path(template_path).open() as f:
     190           1 :             notebook = nbformat.read(f, as_version=4)
     191             : 
     192           1 :         exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
     193           1 :         content, _ = exporter.from_notebook_node(notebook)
     194             : 
     195           1 :         replacements = self._get_export_replacements()
     196           1 :         for key, value in replacements.items():
     197           1 :             content = content.replace(key, str(value))
     198             : 
     199           1 :         return content
     200             : 
     201           1 :     def export_python(self) -> None:
     202             :         """Export the current calculation as a Python script."""
     203           0 :         logger.debug("Exporting results as Python script...")
     204           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
     205           0 :         if filename:
     206           0 :             filename = filename.removesuffix(".py") + ".py"
     207           0 :             content = self._create_python_code()
     208           0 :             with Path(filename).open("w") as f:
     209           0 :                 f.write(content)
     210           0 :             logger.info("Python script saved as %s", filename)
     211             : 
     212           1 :     def export_notebook(self) -> None:
     213             :         """Export the current calculation as a Jupyter notebook."""
     214           0 :         logger.debug("Exporting results as Jupyter notebook...")
     215             : 
     216           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
     217             : 
     218           0 :         if filename:
     219           0 :             filename = filename.removesuffix(".ipynb") + ".ipynb"
     220             : 
     221           0 :             template_path = (
     222             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     223             :             )
     224           0 :             with Path(template_path).open() as f:
     225           0 :                 notebook = nbformat.read(f, as_version=4)
     226             : 
     227           0 :             replacements = self._get_export_replacements()
     228           0 :             for cell in notebook.cells:
     229           0 :                 if cell.cell_type == "code":
     230           0 :                     source = cell.source
     231           0 :                     for key, value in replacements.items():
     232           0 :                         source = source.replace(key, str(value))
     233           0 :                     cell.source = source
     234             : 
     235           0 :             nbformat.write(notebook, filename)
     236             : 
     237           0 :             logger.info("Jupyter notebook saved as %s", filename)
     238             : 
     239           1 :     def _get_export_notebook_template_name(self) -> str:
     240           0 :         raise NotImplementedError("Subclasses must implement this method")
     241             : 
     242           1 :     def _get_export_replacements(self) -> dict[str, str]:
     243             :         # Override this method in subclasses to provide specific replacements for the export
     244           0 :         return {}
     245             : 
     246           1 :     def abort_clicked(self) -> None:
     247             :         """Handle abort button click."""
     248           0 :         logger.debug("Aborting calculation.")
     249           0 :         MultiProcessWorker.terminate_all(create_new_pool=True)
     250           0 :         MultiThreadWorker.terminate_all()
     251           0 :         self.after_calculate(False)
     252           0 :         show_status_tip(self, "Calculation aborted.", logger=logger)

Generated by: LCOV version 1.16