LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/page - base_page.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 106 160 66.2 %
Date: 2025-07-05 13:02:33 Functions: 10 36 27.8 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : import time
       6           1 : from pathlib import Path
       7           1 : from typing import Any, Optional
       8             : 
       9           1 : import nbformat
      10           1 : from nbconvert import PythonExporter
      11           1 : from PySide6.QtGui import QHideEvent, QShowEvent
      12           1 : from PySide6.QtWidgets import (
      13             :     QFileDialog,
      14             :     QHBoxLayout,
      15             :     QMenu,
      16             :     QPushButton,
      17             :     QStyle,
      18             :     QToolBox,
      19             : )
      20             : 
      21           1 : import pairinteraction
      22           1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      23           1 : from pairinteraction_gui.config import BaseConfig
      24           1 : from pairinteraction_gui.config.ket_config import KetConfig
      25           1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
      26           1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
      27           1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
      28             : 
      29           1 : logger = logging.getLogger(__name__)
      30             : 
      31             : 
      32           1 : class BasePage(WidgetV):
      33             :     """Base class for all pages in this application."""
      34             : 
      35           1 :     margin = (20, 20, 20, 20)
      36           1 :     spacing = 15
      37             : 
      38           1 :     title: str
      39           1 :     tooltip: str
      40           1 :     icon_path: Optional[Path] = None
      41             : 
      42           1 :     def showEvent(self, event: QShowEvent) -> None:
      43             :         """Show event."""
      44           1 :         super().showEvent(event)
      45           1 :         self.window().setWindowTitle(f"PairInteraction v{pairinteraction.__version__} - " + self.title)
      46             : 
      47             : 
      48           1 : class SimulationPage(BasePage):
      49             :     """Base class for all simulation pages in this application."""
      50             : 
      51           1 :     ket_config: KetConfig
      52             : 
      53           1 :     plotwidget: PlotWidget
      54             : 
      55           1 :     def setupWidget(self) -> None:
      56           1 :         self.toolbox = QToolBox()
      57             : 
      58           1 :     def postSetupWidget(self) -> None:
      59           1 :         for attr in self.__dict__.values():
      60           1 :             if isinstance(attr, BaseConfig):
      61           1 :                 self.toolbox.addItem(attr, attr.title)
      62             : 
      63           1 :     def showEvent(self, event: QShowEvent) -> None:
      64           1 :         super().showEvent(event)
      65           1 :         self.window().dockwidget.setWidget(self.toolbox)
      66           1 :         self.window().dockwidget.setVisible(True)
      67           1 :         self.toolbox.show()
      68             : 
      69           1 :     def hideEvent(self, event: QHideEvent) -> None:
      70           1 :         super().hideEvent(event)
      71           1 :         self.window().dockwidget.setVisible(False)
      72             : 
      73             : 
      74           1 : class CalculationPage(SimulationPage):
      75             :     """Base class for all pages with a calculation button."""
      76             : 
      77           1 :     plotwidget: PlotEnergies
      78           1 :     _calculation_finished = False
      79           1 :     _plot_finished = False
      80             : 
      81           1 :     def setupWidget(self) -> None:
      82           1 :         super().setupWidget()
      83             : 
      84             :         # Plot Panel
      85           1 :         self.plotwidget = PlotEnergies(self)
      86           1 :         self.layout().addWidget(self.plotwidget)
      87             : 
      88             :         # Control panel below the plot
      89           1 :         bottom_layout = QHBoxLayout()
      90           1 :         bottom_layout.setObjectName("bottomLayout")
      91             : 
      92             :         # Calculate/Abort stacked buttons
      93           1 :         self.calculate_and_abort = NamedStackedWidget[QPushButton](self)
      94             : 
      95           1 :         calculate_button = QPushButton("Calculate")
      96           1 :         calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
      97           1 :         calculate_button.clicked.connect(self.calculate_clicked)
      98           1 :         self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")
      99             : 
     100           1 :         abort_button = QPushButton("Abort")
     101           1 :         abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
     102           1 :         abort_button.clicked.connect(self.abort_clicked)
     103           1 :         self.calculate_and_abort.addNamedWidget(abort_button, "Abort")
     104             : 
     105           1 :         self.calculate_and_abort.setFixedHeight(50)
     106           1 :         bottom_layout.addWidget(self.calculate_and_abort, stretch=2)
     107             : 
     108             :         # Create export button with menu
     109           1 :         export_button = QPushButton("Export")
     110           1 :         export_button.setObjectName("Export")
     111           1 :         export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
     112           1 :         export_menu = QMenu(self)
     113           1 :         export_menu.addAction("Export as PNG", self.export_png)
     114           1 :         export_menu.addAction("Export as Python script", self.export_python)
     115           1 :         export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
     116           1 :         export_button.setMenu(export_menu)
     117           1 :         export_button.setFixedHeight(50)
     118           1 :         bottom_layout.addWidget(export_button, stretch=1)
     119             : 
     120           1 :         self.layout().addLayout(bottom_layout)
     121             : 
     122           1 :     def calculate_clicked(self) -> None:
     123           1 :         self._calculation_finished = False
     124           1 :         self._plot_finished = False
     125           1 :         self.before_calculate()
     126             : 
     127           1 :         def update_plot(
     128             :             parameters_and_results: tuple[Parameters[Any], Results],
     129             :         ) -> None:
     130           1 :             worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
     131           1 :             worker_plot.start()
     132           1 :             worker_plot.signals.finished.connect(lambda _sucess: setattr(self, "_plot_finished", True))
     133             : 
     134           1 :         worker = MultiThreadWorker(self.calculate)
     135           1 :         worker.enable_busy_indicator(self)
     136           1 :         worker.signals.result.connect(update_plot)
     137           1 :         worker.signals.finished.connect(self.after_calculate)
     138           1 :         worker.signals.finished.connect(lambda _sucess: setattr(self, "_calculation_finished", True))
     139           1 :         worker.start()
     140             : 
     141           1 :     def before_calculate(self) -> None:
     142           1 :         show_status_tip(self, "Calculating... Please wait.", logger=logger)
     143           1 :         self.calculate_and_abort.setCurrentNamedWidget("Abort")
     144           1 :         self.plotwidget.clear()
     145             : 
     146           1 :         self._start_time = time.perf_counter()
     147             : 
     148           1 :     def after_calculate(self, success: bool) -> None:
     149           1 :         time_needed = time.perf_counter() - self._start_time
     150             : 
     151           1 :         if success:
     152           1 :             show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
     153             :         else:
     154           0 :             show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)
     155             : 
     156           1 :         self.calculate_and_abort.setCurrentNamedWidget("Calculate")
     157             : 
     158           1 :     def calculate(self) -> tuple[Parameters[Any], Results]:
     159           0 :         raise NotImplementedError("Subclasses must implement this method")
     160             : 
     161           1 :     def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
     162           0 :         energies = results.energies
     163           0 :         overlaps = results.ket_overlaps
     164             : 
     165           0 :         x_values = parameters.get_x_values()
     166           0 :         x_label = parameters.get_x_label()
     167             : 
     168           0 :         self.plotwidget.plot(x_values, energies, overlaps, x_label)
     169             : 
     170           0 :         self.plotwidget.add_cursor(x_values, energies, results.state_labels)
     171             : 
     172           0 :         self.plotwidget.canvas.draw()
     173           0 :         self._plot_finished = False
     174             : 
     175           1 :     def export_png(self) -> None:
     176             :         """Export the current plot as a PNG file."""
     177           0 :         logger.debug("Exporting results as PNG...")
     178             : 
     179           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
     180             : 
     181           0 :         if filename:
     182           0 :             filename = filename.removesuffix(".png") + ".png"
     183           0 :             self.plotwidget.canvas.fig.savefig(
     184             :                 filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
     185             :             )
     186           0 :             logger.info("Plot saved as %s", filename)
     187             : 
     188           1 :     def export_python(self) -> None:
     189             :         """Export the current calculation as a Python script."""
     190           0 :         logger.debug("Exporting results as Python script...")
     191             : 
     192           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
     193             : 
     194           0 :         if filename:
     195           0 :             filename = filename.removesuffix(".py") + ".py"
     196             : 
     197           0 :             template_path = (
     198             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     199             :             )
     200           0 :             with Path(template_path).open() as f:
     201           0 :                 notebook = nbformat.read(f, as_version=4)
     202             : 
     203           0 :             exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
     204           0 :             content, _ = exporter.from_notebook_node(notebook)
     205             : 
     206           0 :             replacements = self._get_export_replacements()
     207           0 :             for key, value in replacements.items():
     208           0 :                 content = content.replace(key, str(value))
     209             : 
     210           0 :             with Path(filename).open("w") as f:
     211           0 :                 f.write(content)
     212             : 
     213           0 :             logger.info("Python script saved as %s", filename)
     214             : 
     215           1 :     def export_notebook(self) -> None:
     216             :         """Export the current calculation as a Jupyter notebook."""
     217           0 :         logger.debug("Exporting results as Jupyter notebook...")
     218             : 
     219           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
     220             : 
     221           0 :         if filename:
     222           0 :             filename = filename.removesuffix(".ipynb") + ".ipynb"
     223             : 
     224           0 :             template_path = (
     225             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     226             :             )
     227           0 :             with Path(template_path).open() as f:
     228           0 :                 notebook = nbformat.read(f, as_version=4)
     229             : 
     230           0 :             replacements = self._get_export_replacements()
     231           0 :             for cell in notebook.cells:
     232           0 :                 if cell.cell_type == "code":
     233           0 :                     source = cell.source
     234           0 :                     for key, value in replacements.items():
     235           0 :                         source = source.replace(key, str(value))
     236           0 :                     cell.source = source
     237             : 
     238           0 :             nbformat.write(notebook, filename)
     239             : 
     240           0 :             logger.info("Jupyter notebook saved as %s", filename)
     241             : 
     242           1 :     def _get_export_notebook_template_name(self) -> str:
     243           0 :         raise NotImplementedError("Subclasses must implement this method")
     244             : 
     245           1 :     def _get_export_replacements(self) -> dict[str, str]:
     246             :         # Override this method in subclasses to provide specific replacements for the export
     247           0 :         return {}
     248             : 
     249           1 :     def abort_clicked(self) -> None:
     250             :         """Handle abort button click."""
     251           0 :         logger.debug("Aborting calculation.")
     252           0 :         MultiProcessWorker.terminate_all(create_new_pool=True)
     253           0 :         MultiThreadWorker.terminate_all()
     254           0 :         self.after_calculate(False)
     255           0 :         show_status_tip(self, "Calculation aborted.", logger=logger)

Generated by: LCOV version 1.16