LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/page - base_page.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 136 174 78.2 %
Date: 2026-04-17 09:29:39 Functions: 14 22 63.6 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : import time
       7           1 : from pathlib import Path
       8           1 : from typing import TYPE_CHECKING, Any
       9             : 
      10           1 : import nbformat
      11           1 : from nbconvert import PythonExporter
      12           1 : from PySide6.QtCore import Qt
      13           1 : from PySide6.QtGui import QIcon, QPixmap
      14           1 : from PySide6.QtWidgets import (
      15             :     QFileDialog,
      16             :     QHBoxLayout,
      17             :     QMenu,
      18             :     QPushButton,
      19             :     QStyle,
      20             :     QToolBox,
      21             : )
      22             : 
      23           1 : import pairinteraction
      24           1 : from pairinteraction_gui.config import BaseConfig
      25           1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies
      26           1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
      27           1 : from pairinteraction_gui.worker import MultiThreadWorker
      28             : 
      29             : if TYPE_CHECKING:
      30             :     from collections.abc import Callable
      31             : 
      32             :     from PySide6.QtGui import QHideEvent, QShowEvent
      33             : 
      34             :     from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      35             :     from pairinteraction_gui.config.calculation_config import CalculationConfig
      36             :     from pairinteraction_gui.config.ket_config import KetConfig
      37             :     from pairinteraction_gui.plotwidget.plotwidget import PlotWidget
      38             : 
      39           1 : logger = logging.getLogger(__name__)
      40             : 
      41             : 
      42           1 : class BasePage(WidgetV):
      43             :     """Base class for all pages in this application."""
      44             : 
      45           1 :     margin = (20, 20, 20, 20)
      46           1 :     spacing = 15
      47             : 
      48           1 :     title: str
      49           1 :     tooltip: str
      50           1 :     icon_path: Path | None = None
      51             : 
      52           1 :     def showEvent(self, event: QShowEvent) -> None:
      53             :         """Show event."""
      54           1 :         super().showEvent(event)
      55           1 :         self.window().setWindowTitle(
      56             :             f"PairInteraction v{pairinteraction.__version__} - " + self.title.replace("\n", " ")
      57             :         )
      58             : 
      59             : 
      60           1 : class SimulationPage(BasePage):
      61             :     """Base class for all simulation pages in this application."""
      62             : 
      63           1 :     ket_config: KetConfig
      64             : 
      65           1 :     plotwidget: PlotWidget
      66             : 
      67           1 :     def setupWidget(self) -> None:
      68           1 :         self.toolbox = QToolBox()
      69             : 
      70             :         # Create a dummy icon to allow adjusting the height of the toolbox tabs,
      71             :         # see https://stackoverflow.com/questions/48503645/customizing-qtoolbox-tab-height
      72           1 :         px = QPixmap(1, 1)
      73           1 :         px.fill(Qt.GlobalColor.transparent)
      74           1 :         self._toolbox_dummy_icon = QIcon(px)
      75             : 
      76           1 :     def postSetupWidget(self) -> None:
      77           1 :         for attr in self.__dict__.values():
      78           1 :             if isinstance(attr, BaseConfig):
      79           1 :                 self.toolbox.addItem(attr, self._toolbox_dummy_icon, attr.title)
      80             : 
      81           1 :         for i, species_combo in enumerate(self.ket_config.species_combo_list):
      82           1 :             self.ket_config.signal_species_changed.emit(i, species_combo.currentText())
      83             : 
      84           1 :     def showEvent(self, event: QShowEvent) -> None:
      85           1 :         super().showEvent(event)
      86           1 :         self.window().dockwidget.setWidget(self.toolbox)
      87           1 :         self.window().dockwidget.setVisible(True)
      88           1 :         self.toolbox.show()
      89             : 
      90           1 :     def hideEvent(self, event: QHideEvent) -> None:
      91           1 :         super().hideEvent(event)
      92           1 :         self.window().dockwidget.setVisible(False)
      93             : 
      94             : 
      95           1 : class CalculationPage(SimulationPage):
      96             :     """Base class for all pages with a calculation button."""
      97             : 
      98           1 :     plotwidget: PlotEnergies
      99           1 :     _calculation_finished = False
     100           1 :     _plot_finished = False
     101             : 
     102           1 :     def setupWidget(self) -> None:
     103           1 :         super().setupWidget()
     104             : 
     105             :         # Plot Panel
     106           1 :         self.plotwidget = self._create_plot_widget()
     107           1 :         self.layout().addWidget(self.plotwidget)
     108             : 
     109             :         # Control panel below the plot
     110           1 :         bottom_layout = QHBoxLayout()
     111           1 :         bottom_layout.setObjectName("bottomLayout")
     112             : 
     113             :         # Calculate/Abort stacked buttons
     114           1 :         self.calculate_and_abort = NamedStackedWidget[QPushButton](self)
     115             : 
     116           1 :         calculate_button = QPushButton("Calculate")
     117           1 :         calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
     118           1 :         calculate_button.clicked.connect(self.calculate_clicked)
     119           1 :         self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")
     120             : 
     121           1 :         abort_button = QPushButton("Abort")
     122           1 :         abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
     123           1 :         abort_button.clicked.connect(self.abort_clicked)
     124           1 :         self.calculate_and_abort.addNamedWidget(abort_button, "Abort")
     125             : 
     126           1 :         self.calculate_and_abort.setFixedHeight(50)
     127           1 :         bottom_layout.addWidget(self.calculate_and_abort, stretch=2)
     128             : 
     129             :         # Create export button with menu
     130           1 :         export_button = QPushButton("Export")
     131           1 :         export_button.setObjectName("Export")
     132           1 :         export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
     133           1 :         export_menu = QMenu(self)
     134           1 :         for label, handler in self._get_export_actions():
     135           1 :             export_menu.addAction(label, handler)
     136           1 :         export_button.setMenu(export_menu)
     137           1 :         export_button.setFixedHeight(50)
     138           1 :         bottom_layout.addWidget(export_button, stretch=1)
     139             : 
     140           1 :         self.layout().addLayout(bottom_layout)
     141             : 
     142           1 :     def calculate_clicked(self) -> None:
     143           1 :         self._calculation_finished = False
     144           1 :         self._plot_finished = False
     145           1 :         self.before_calculate()
     146             : 
     147           1 :         def update_plot(
     148             :             parameters_and_results: tuple[Parameters[Any], Results],
     149             :         ) -> None:
     150           1 :             worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
     151           1 :             worker_plot.signals.progress.connect(lambda message: show_status_tip(self, message))
     152           1 :             worker_plot.signals.finished.connect(lambda _: setattr(self, "_plot_finished", True))
     153           1 :             worker_plot.start()
     154             : 
     155           1 :         worker = MultiThreadWorker(self.calculate)
     156           1 :         if hasattr(self, "calculation_config"):
     157           1 :             calculation_config: CalculationConfig = self.calculation_config
     158           1 :             number_of_steps = calculation_config.steps.value()
     159           1 :             worker.enable_busy_indicator(self.plotwidget, add_progress_label=True, number_of_steps=number_of_steps)
     160             :         else:
     161           0 :             worker.enable_busy_indicator(self.plotwidget)
     162           1 :         worker.signals.progress.connect(lambda message: show_status_tip(self, message))
     163           1 :         worker.signals.result.connect(update_plot)
     164           1 :         worker.signals.finished.connect(self.after_calculate)
     165           1 :         worker.signals.finished.connect(lambda _: setattr(self, "_calculation_finished", True))
     166           1 :         worker.start()
     167             : 
     168           1 :     def before_calculate(self) -> None:
     169           1 :         show_status_tip(self, "Calculating... Please wait.", logger=logger)
     170           1 :         self.calculate_and_abort.setCurrentNamedWidget("Abort")
     171           1 :         self.plotwidget.clear()
     172             : 
     173           1 :         self._start_time = time.perf_counter()
     174             : 
     175           1 :     def after_calculate(self, status: str) -> None:
     176           1 :         time_needed = time.perf_counter() - self._start_time
     177           1 :         show_status_tip(self, f"{status} after {time_needed:.2f} seconds.", logger=logger)
     178           1 :         self.calculate_and_abort.setCurrentNamedWidget("Calculate")
     179             : 
     180           1 :     def calculate(self) -> tuple[Parameters[Any], Results]:
     181           0 :         raise NotImplementedError("Subclasses must implement this method")
     182             : 
     183           1 :     def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
     184           1 :         self.plotwidget.canvas.draw()  # draw once before, to avoid displaying artifacts during plotting
     185           1 :         self.plotwidget.plot(parameters, results)
     186           1 :         self.plotwidget.setup_annotations(parameters, results)
     187           1 :         self._plot_function(parameters, results)
     188           1 :         self.plotwidget.canvas.draw()
     189           1 :         self.plotwidget.navigation_toolbar.reset_home_view()
     190           1 :         show_status_tip(self, "Finished updating plot. Tip: Click on the plot to see state information.", logger=logger)
     191             : 
     192           1 :     def _plot_function(self, parameters: Parameters[Any], results: Results) -> None:
     193             :         # This method can be overridden by subclasses to provide a custom plotting function
     194             :         # that is called after the default plotting and before drawing the canvas.
     195           0 :         pass
     196             : 
     197           1 :     def export_png(self) -> None:
     198             :         """Export the current plot as a PNG file."""
     199           0 :         logger.debug("Exporting results as PNG...")
     200             : 
     201           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
     202             : 
     203           0 :         if filename:
     204           0 :             filename = filename.removesuffix(".png") + ".png"
     205           0 :             self.plotwidget.canvas.fig.savefig(
     206             :                 filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
     207             :             )
     208           0 :             logger.info("Plot saved as %s", filename)
     209             : 
     210           1 :     def _create_python_code(self) -> str:
     211           1 :         template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     212           1 :         with Path(template_path).open() as f:
     213           1 :             notebook = nbformat.read(f, as_version=4)
     214             : 
     215           1 :         exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
     216           1 :         content, _ = exporter.from_notebook_node(notebook)
     217             : 
     218           1 :         replacements = self._get_export_replacements()
     219           1 :         for key, value in replacements.items():
     220           1 :             content = content.replace(key, str(value))
     221             : 
     222           1 :         return content
     223             : 
     224           1 :     def export_python(self) -> None:
     225             :         """Export the current calculation as a Python script."""
     226           0 :         logger.debug("Exporting results as Python script...")
     227           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
     228           0 :         if filename:
     229           0 :             filename = filename.removesuffix(".py") + ".py"
     230           0 :             content = self._create_python_code()
     231           0 :             with Path(filename).open("w") as f:
     232           0 :                 f.write(content)
     233           0 :             logger.info("Python script saved as %s", filename)
     234             : 
     235           1 :     def export_notebook(self) -> None:
     236             :         """Export the current calculation as a Jupyter notebook."""
     237           0 :         logger.debug("Exporting results as Jupyter notebook...")
     238             : 
     239           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
     240             : 
     241           0 :         if filename:
     242           0 :             filename = filename.removesuffix(".ipynb") + ".ipynb"
     243             : 
     244           0 :             template_path = (
     245             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     246             :             )
     247           0 :             with Path(template_path).open() as f:
     248           0 :                 notebook = nbformat.read(f, as_version=4)
     249             : 
     250           0 :             replacements = self._get_export_replacements()
     251           0 :             for cell in notebook.cells:
     252           0 :                 if cell.cell_type == "code":
     253           0 :                     source = cell.source
     254           0 :                     for key, value in replacements.items():
     255           0 :                         source = source.replace(key, str(value))
     256           0 :                     cell.source = source
     257             : 
     258           0 :             nbformat.write(notebook, filename)
     259             : 
     260           0 :             logger.info("Jupyter notebook saved as %s", filename)
     261             : 
     262           1 :     def _get_export_notebook_template_name(self) -> str:
     263           0 :         raise NotImplementedError("Subclasses must implement this method")
     264             : 
     265           1 :     def _get_export_replacements(self) -> dict[str, str]:
     266             :         # Override this method in subclasses to provide specific replacements for the export
     267           0 :         return {}
     268             : 
     269           1 :     def abort_clicked(self) -> None:
     270             :         """Handle abort button click."""
     271           0 :         logger.debug("Aborting calculation.")
     272           0 :         MultiThreadWorker.terminate_all()
     273           0 :         self.after_calculate("Calculation aborted.")
     274             : 
     275           1 :     def _create_plot_widget(self) -> PlotEnergies:
     276           1 :         return PlotEnergies(self)
     277             : 
     278           1 :     def _get_export_actions(self) -> list[tuple[str, Callable[[], None]]]:
     279           1 :         return [
     280             :             ("Export as PNG", self.export_png),
     281             :             ("Export as Python script", self.export_python),
     282             :             ("Export as Jupyter notebook", self.export_notebook),
     283             :         ]

Generated by: LCOV version 1.16