LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/page - base_page.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 105 159 66.0 %
Date: 2025-06-06 09:09:03 Functions: 10 36 27.8 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : import time
       6           1 : from pathlib import Path
       7           1 : from typing import Any, Optional
       8             : 
       9           1 : import nbformat
      10           1 : from nbconvert import PythonExporter
      11           1 : from PySide6.QtGui import QHideEvent, QShowEvent
      12           1 : from PySide6.QtWidgets import (
      13             :     QFileDialog,
      14             :     QHBoxLayout,
      15             :     QMenu,
      16             :     QPushButton,
      17             :     QStyle,
      18             :     QToolBox,
      19             : )
      20             : 
      21           1 : import pairinteraction
      22           1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      23           1 : from pairinteraction_gui.config import BaseConfig
      24           1 : from pairinteraction_gui.config.ket_config import KetConfig
      25           1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
      26           1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
      27           1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
      28             : 
# Module-level logger named after this module; passed to show_status_tip and used by the export/abort handlers below.
logger = logging.getLogger(__name__)
      30             : 
      31             : 
      32           1 : class BasePage(WidgetV):
      33             :     """Base class for all pages in this application."""
      34             : 
      35           1 :     margin = (20, 20, 20, 20)
      36           1 :     spacing = 15
      37             : 
      38           1 :     title: str
      39           1 :     tooltip: str
      40           1 :     icon_path: Optional[Path] = None
      41             : 
      42           1 :     def showEvent(self, event: QShowEvent) -> None:
      43             :         """Show event."""
      44           1 :         super().showEvent(event)
      45           1 :         self.window().setWindowTitle(f"PairInteraction v{pairinteraction.__version__} - " + self.title)
      46             : 
      47             : 
      48           1 : class SimulationPage(BasePage):
      49             :     """Base class for all simulation pages in this application."""
      50             : 
      51           1 :     ket_config: KetConfig
      52             : 
      53           1 :     plotwidget: PlotWidget
      54             : 
      55           1 :     def setupWidget(self) -> None:
      56           1 :         self.toolbox = QToolBox()
      57             : 
      58           1 :     def postSetupWidget(self) -> None:
      59           1 :         for attr in self.__dict__.values():
      60           1 :             if isinstance(attr, BaseConfig):
      61           1 :                 self.toolbox.addItem(attr, attr.title)
      62             : 
      63           1 :     def showEvent(self, event: QShowEvent) -> None:
      64           1 :         super().showEvent(event)
      65           1 :         self.window().dockwidget.setWidget(self.toolbox)
      66           1 :         self.window().dockwidget.setVisible(True)
      67           1 :         self.toolbox.show()
      68             : 
      69           1 :     def hideEvent(self, event: QHideEvent) -> None:
      70           1 :         super().hideEvent(event)
      71           1 :         self.window().dockwidget.setVisible(False)
      72             : 
      73             : 
      74           1 : class CalculationPage(SimulationPage):
      75             :     """Base class for all pages with a calculation button."""
      76             : 
      77           1 :     plotwidget: PlotEnergies
      78           1 :     _calculation_finished = False
      79           1 :     _plot_finished = False
      80             : 
      81           1 :     def setupWidget(self) -> None:
      82           1 :         super().setupWidget()
      83             : 
      84             :         # Plot Panel
      85           1 :         self.plotwidget = PlotEnergies(self)
      86           1 :         self.layout().addWidget(self.plotwidget)
      87             : 
      88             :         # Control panel below the plot
      89           1 :         bottom_layout = QHBoxLayout()
      90             : 
      91             :         # Calculate/Abort stacked buttons
      92           1 :         self.calculate_and_abort = NamedStackedWidget[QPushButton](self)
      93             : 
      94           1 :         calculate_button = QPushButton("Calculate")
      95           1 :         calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
      96           1 :         calculate_button.clicked.connect(self.calculate_clicked)
      97           1 :         self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")
      98             : 
      99           1 :         abort_button = QPushButton("Abort")
     100           1 :         abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
     101           1 :         abort_button.clicked.connect(self.abort_clicked)
     102           1 :         self.calculate_and_abort.addNamedWidget(abort_button, "Abort")
     103             : 
     104           1 :         self.calculate_and_abort.setFixedHeight(50)
     105           1 :         bottom_layout.addWidget(self.calculate_and_abort, stretch=2)
     106             : 
     107             :         # Create export button with menu
     108           1 :         export_button = QPushButton("Export")
     109           1 :         export_button.setObjectName("Export")
     110           1 :         export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
     111           1 :         export_menu = QMenu(self)
     112           1 :         export_menu.addAction("Export as PNG", self.export_png)
     113           1 :         export_menu.addAction("Export as Python script", self.export_python)
     114           1 :         export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
     115           1 :         export_button.setMenu(export_menu)
     116           1 :         export_button.setFixedHeight(50)
     117           1 :         bottom_layout.addWidget(export_button, stretch=1)
     118             : 
     119           1 :         self.layout().addLayout(bottom_layout)
     120             : 
     121           1 :     def calculate_clicked(self) -> None:
     122           1 :         self._calculation_finished = False
     123           1 :         self._plot_finished = False
     124           1 :         self.before_calculate()
     125             : 
     126           1 :         def update_plot(
     127             :             parameters_and_results: tuple[Parameters[Any], Results],
     128             :         ) -> None:
     129           1 :             worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
     130           1 :             worker_plot.start()
     131           1 :             worker_plot.signals.finished.connect(lambda _sucess: setattr(self, "_plot_finished", True))
     132             : 
     133           1 :         worker = MultiThreadWorker(self.calculate)
     134           1 :         worker.enable_busy_indicator(self)
     135           1 :         worker.signals.result.connect(update_plot)
     136           1 :         worker.signals.finished.connect(self.after_calculate)
     137           1 :         worker.signals.finished.connect(lambda _sucess: setattr(self, "_calculation_finished", True))
     138           1 :         worker.start()
     139             : 
     140           1 :     def before_calculate(self) -> None:
     141           1 :         show_status_tip(self, "Calculating... Please wait.", logger=logger)
     142           1 :         self.calculate_and_abort.setCurrentNamedWidget("Abort")
     143           1 :         self.plotwidget.clear()
     144             : 
     145           1 :         self._start_time = time.perf_counter()
     146             : 
     147           1 :     def after_calculate(self, success: bool) -> None:
     148           1 :         time_needed = time.perf_counter() - self._start_time
     149             : 
     150           1 :         if success:
     151           1 :             show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
     152             :         else:
     153           0 :             show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)
     154             : 
     155           1 :         self.calculate_and_abort.setCurrentNamedWidget("Calculate")
     156             : 
     157           1 :     def calculate(self) -> tuple[Parameters[Any], Results]:
     158           0 :         raise NotImplementedError("Subclasses must implement this method")
     159             : 
     160           1 :     def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
     161           0 :         energies = results.energies
     162           0 :         overlaps = results.ket_overlaps
     163             : 
     164           0 :         x_values = parameters.get_x_values()
     165           0 :         x_label = parameters.get_x_label()
     166             : 
     167           0 :         self.plotwidget.plot(x_values, energies, overlaps, x_label)
     168             : 
     169           0 :         self.plotwidget.add_cursor(x_values, energies, results.state_labels)
     170             : 
     171           0 :         self.plotwidget.canvas.draw()
     172           0 :         self._plot_finished = False
     173             : 
     174           1 :     def export_png(self) -> None:
     175             :         """Export the current plot as a PNG file."""
     176           0 :         logger.debug("Exporting results as PNG...")
     177             : 
     178           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
     179             : 
     180           0 :         if filename:
     181           0 :             filename = filename.removesuffix(".png") + ".png"
     182           0 :             self.plotwidget.canvas.fig.savefig(
     183             :                 filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
     184             :             )
     185           0 :             logger.info("Plot saved as %s", filename)
     186             : 
     187           1 :     def export_python(self) -> None:
     188             :         """Export the current calculation as a Python script."""
     189           0 :         logger.debug("Exporting results as Python script...")
     190             : 
     191           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
     192             : 
     193           0 :         if filename:
     194           0 :             filename = filename.removesuffix(".py") + ".py"
     195             : 
     196           0 :             template_path = (
     197             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     198             :             )
     199           0 :             with Path(template_path).open() as f:
     200           0 :                 notebook = nbformat.read(f, as_version=4)
     201             : 
     202           0 :             exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
     203           0 :             content, _ = exporter.from_notebook_node(notebook)
     204             : 
     205           0 :             replacements = self._get_export_replacements()
     206           0 :             for key, value in replacements.items():
     207           0 :                 content = content.replace(key, str(value))
     208             : 
     209           0 :             with Path(filename).open("w") as f:
     210           0 :                 f.write(content)
     211             : 
     212           0 :             logger.info("Python script saved as %s", filename)
     213             : 
     214           1 :     def export_notebook(self) -> None:
     215             :         """Export the current calculation as a Jupyter notebook."""
     216           0 :         logger.debug("Exporting results as Jupyter notebook...")
     217             : 
     218           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
     219             : 
     220           0 :         if filename:
     221           0 :             filename = filename.removesuffix(".ipynb") + ".ipynb"
     222             : 
     223           0 :             template_path = (
     224             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     225             :             )
     226           0 :             with Path(template_path).open() as f:
     227           0 :                 notebook = nbformat.read(f, as_version=4)
     228             : 
     229           0 :             replacements = self._get_export_replacements()
     230           0 :             for cell in notebook.cells:
     231           0 :                 if cell.cell_type == "code":
     232           0 :                     source = cell.source
     233           0 :                     for key, value in replacements.items():
     234           0 :                         source = source.replace(key, str(value))
     235           0 :                     cell.source = source
     236             : 
     237           0 :             nbformat.write(notebook, filename)
     238             : 
     239           0 :             logger.info("Jupyter notebook saved as %s", filename)
     240             : 
     241           1 :     def _get_export_notebook_template_name(self) -> str:
     242           0 :         raise NotImplementedError("Subclasses must implement this method")
     243             : 
     244           1 :     def _get_export_replacements(self) -> dict[str, str]:
     245             :         # Override this method in subclasses to provide specific replacements for the export
     246           0 :         return {}
     247             : 
     248           1 :     def abort_clicked(self) -> None:
     249             :         """Handle abort button click."""
     250           0 :         logger.debug("Aborting calculation.")
     251           0 :         MultiProcessWorker.terminate_all(create_new_pool=True)
     252           0 :         MultiThreadWorker.terminate_all()
     253           0 :         self.after_calculate(False)
     254           0 :         show_status_tip(self, "Calculation aborted.", logger=logger)

Generated by: LCOV version 1.16