LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/page - base_page.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 118 165 71.5 %
Date: 2025-08-29 20:47:05 Functions: 11 38 28.9 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : import time
       6           1 : from pathlib import Path
       7           1 : from typing import Any, Optional
       8             : 
       9           1 : import nbformat
      10           1 : from nbconvert import PythonExporter
      11           1 : from PySide6.QtGui import QHideEvent, QShowEvent
      12           1 : from PySide6.QtWidgets import (
      13             :     QFileDialog,
      14             :     QHBoxLayout,
      15             :     QMenu,
      16             :     QPushButton,
      17             :     QStyle,
      18             :     QToolBox,
      19             : )
      20             : 
      21           1 : import pairinteraction
      22           1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      23           1 : from pairinteraction_gui.config import BaseConfig
      24           1 : from pairinteraction_gui.config.ket_config import KetConfig
      25           1 : from pairinteraction_gui.plotwidget.plotwidget import PlotEnergies, PlotWidget
      26           1 : from pairinteraction_gui.qobjects import NamedStackedWidget, WidgetV, show_status_tip
      27           1 : from pairinteraction_gui.worker import MultiProcessWorker, MultiThreadWorker
      28             : 
      29           1 : logger = logging.getLogger(__name__)
      30             : 
      31             : 
      32           1 : class BasePage(WidgetV):
      33             :     """Base class for all pages in this application."""
      34             : 
      35           1 :     margin = (20, 20, 20, 20)
      36           1 :     spacing = 15
      37             : 
      38           1 :     title: str
      39           1 :     tooltip: str
      40           1 :     icon_path: Optional[Path] = None
      41             : 
      42           1 :     def showEvent(self, event: QShowEvent) -> None:
      43             :         """Show event."""
      44           1 :         super().showEvent(event)
      45           1 :         self.window().setWindowTitle(f"PairInteraction v{pairinteraction.__version__} - " + self.title)
      46             : 
      47             : 
      48           1 : class SimulationPage(BasePage):
      49             :     """Base class for all simulation pages in this application."""
      50             : 
      51           1 :     ket_config: KetConfig
      52             : 
      53           1 :     plotwidget: PlotWidget
      54             : 
      55           1 :     def setupWidget(self) -> None:
      56           1 :         self.toolbox = QToolBox()
      57             : 
      58           1 :     def postSetupWidget(self) -> None:
      59           1 :         for attr in self.__dict__.values():
      60           1 :             if isinstance(attr, BaseConfig):
      61           1 :                 self.toolbox.addItem(attr, attr.title)
      62             : 
      63           1 :         for i, species_combo in enumerate(self.ket_config.species_combo_list):
      64           1 :             self.ket_config.signal_species_changed.emit(i, species_combo.currentText())
      65             : 
      66           1 :     def showEvent(self, event: QShowEvent) -> None:
      67           1 :         super().showEvent(event)
      68           1 :         self.window().dockwidget.setWidget(self.toolbox)
      69           1 :         self.window().dockwidget.setVisible(True)
      70           1 :         self.toolbox.show()
      71             : 
      72           1 :     def hideEvent(self, event: QHideEvent) -> None:
      73           1 :         super().hideEvent(event)
      74           1 :         self.window().dockwidget.setVisible(False)
      75             : 
      76             : 
      77           1 : class CalculationPage(SimulationPage):
      78             :     """Base class for all pages with a calculation button."""
      79             : 
      80           1 :     plotwidget: PlotEnergies
      81           1 :     _calculation_finished = False
      82           1 :     _plot_finished = False
      83             : 
      84           1 :     def setupWidget(self) -> None:
      85           1 :         super().setupWidget()
      86             : 
      87             :         # Plot Panel
      88           1 :         self.plotwidget = PlotEnergies(self)
      89           1 :         self.layout().addWidget(self.plotwidget)
      90             : 
      91             :         # Control panel below the plot
      92           1 :         bottom_layout = QHBoxLayout()
      93           1 :         bottom_layout.setObjectName("bottomLayout")
      94             : 
      95             :         # Calculate/Abort stacked buttons
      96           1 :         self.calculate_and_abort = NamedStackedWidget[QPushButton](self)
      97             : 
      98           1 :         calculate_button = QPushButton("Calculate")
      99           1 :         calculate_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload))
     100           1 :         calculate_button.clicked.connect(self.calculate_clicked)
     101           1 :         self.calculate_and_abort.addNamedWidget(calculate_button, "Calculate")
     102             : 
     103           1 :         abort_button = QPushButton("Abort")
     104           1 :         abort_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserStop))
     105           1 :         abort_button.clicked.connect(self.abort_clicked)
     106           1 :         self.calculate_and_abort.addNamedWidget(abort_button, "Abort")
     107             : 
     108           1 :         self.calculate_and_abort.setFixedHeight(50)
     109           1 :         bottom_layout.addWidget(self.calculate_and_abort, stretch=2)
     110             : 
     111             :         # Create export button with menu
     112           1 :         export_button = QPushButton("Export")
     113           1 :         export_button.setObjectName("Export")
     114           1 :         export_button.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton))
     115           1 :         export_menu = QMenu(self)
     116           1 :         export_menu.addAction("Export as PNG", self.export_png)
     117           1 :         export_menu.addAction("Export as Python script", self.export_python)
     118           1 :         export_menu.addAction("Export as Jupyter notebook", self.export_notebook)
     119           1 :         export_button.setMenu(export_menu)
     120           1 :         export_button.setFixedHeight(50)
     121           1 :         bottom_layout.addWidget(export_button, stretch=1)
     122             : 
     123           1 :         self.layout().addLayout(bottom_layout)
     124             : 
     125           1 :     def calculate_clicked(self) -> None:
     126           1 :         self._calculation_finished = False
     127           1 :         self._plot_finished = False
     128           1 :         self.before_calculate()
     129             : 
     130           1 :         def update_plot(
     131             :             parameters_and_results: tuple[Parameters[Any], Results],
     132             :         ) -> None:
     133           1 :             worker_plot = MultiThreadWorker(self.update_plot, *parameters_and_results)
     134           1 :             worker_plot.start()
     135           1 :             worker_plot.signals.finished.connect(lambda _sucess: setattr(self, "_plot_finished", True))
     136             : 
     137           1 :         worker = MultiThreadWorker(self.calculate)
     138           1 :         worker.enable_busy_indicator(self)
     139           1 :         worker.signals.result.connect(update_plot)
     140           1 :         worker.signals.finished.connect(self.after_calculate)
     141           1 :         worker.signals.finished.connect(lambda _sucess: setattr(self, "_calculation_finished", True))
     142           1 :         worker.start()
     143             : 
     144           1 :     def before_calculate(self) -> None:
     145           1 :         show_status_tip(self, "Calculating... Please wait.", logger=logger)
     146           1 :         self.calculate_and_abort.setCurrentNamedWidget("Abort")
     147           1 :         self.plotwidget.clear()
     148             : 
     149           1 :         self._start_time = time.perf_counter()
     150             : 
     151           1 :     def after_calculate(self, success: bool) -> None:
     152           1 :         time_needed = time.perf_counter() - self._start_time
     153             : 
     154           1 :         if success:
     155           1 :             show_status_tip(self, f"Calculation finished after {time_needed:.2f} seconds.", logger=logger)
     156             :         else:
     157           0 :             show_status_tip(self, f"Calculation failed after {time_needed:.2f} seconds.", logger=logger)
     158             : 
     159           1 :         self.calculate_and_abort.setCurrentNamedWidget("Calculate")
     160             : 
     161           1 :     def calculate(self) -> tuple[Parameters[Any], Results]:
     162           0 :         raise NotImplementedError("Subclasses must implement this method")
     163             : 
     164           1 :     def update_plot(self, parameters: Parameters[Any], results: Results) -> None:
     165           0 :         energies = results.energies
     166           0 :         overlaps = results.ket_overlaps
     167             : 
     168           0 :         x_values = parameters.get_x_values()
     169           0 :         x_label = parameters.get_x_label()
     170             : 
     171           0 :         self.plotwidget.plot(x_values, energies, overlaps, x_label)
     172             : 
     173           0 :         self.plotwidget.add_cursor(x_values, energies, results.state_labels)
     174             : 
     175           0 :         self.plotwidget.canvas.draw()
     176           0 :         self._plot_finished = False
     177             : 
     178           1 :     def export_png(self) -> None:
     179             :         """Export the current plot as a PNG file."""
     180           0 :         logger.debug("Exporting results as PNG...")
     181             : 
     182           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Plot", "", "PNG Files (*.png)")
     183             : 
     184           0 :         if filename:
     185           0 :             filename = filename.removesuffix(".png") + ".png"
     186           0 :             self.plotwidget.canvas.fig.savefig(
     187             :                 filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
     188             :             )
     189           0 :             logger.info("Plot saved as %s", filename)
     190             : 
     191           1 :     def _create_python_code(self) -> str:
     192           1 :         template_path = Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     193           1 :         with Path(template_path).open() as f:
     194           1 :             notebook = nbformat.read(f, as_version=4)
     195             : 
     196           1 :         exporter = PythonExporter(exclude_output_prompt=True, exclude_input_prompt=True)
     197           1 :         content, _ = exporter.from_notebook_node(notebook)
     198             : 
     199           1 :         replacements = self._get_export_replacements()
     200           1 :         for key, value in replacements.items():
     201           1 :             content = content.replace(key, str(value))
     202             : 
     203           1 :         return content
     204             : 
     205           1 :     def export_python(self) -> None:
     206             :         """Export the current calculation as a Python script."""
     207           0 :         logger.debug("Exporting results as Python script...")
     208           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Python Script", "", "Python Files (*.py)")
     209           0 :         if filename:
     210           0 :             filename = filename.removesuffix(".py") + ".py"
     211           0 :             content = self._create_python_code()
     212           0 :             with Path(filename).open("w") as f:
     213           0 :                 f.write(content)
     214           0 :             logger.info("Python script saved as %s", filename)
     215             : 
     216           1 :     def export_notebook(self) -> None:
     217             :         """Export the current calculation as a Jupyter notebook."""
     218           0 :         logger.debug("Exporting results as Jupyter notebook...")
     219             : 
     220           0 :         filename, _ = QFileDialog.getSaveFileName(self, "Save Jupyter Notebook", "", "Jupyter Notebooks (*.ipynb)")
     221             : 
     222           0 :         if filename:
     223           0 :             filename = filename.removesuffix(".ipynb") + ".ipynb"
     224             : 
     225           0 :             template_path = (
     226             :                 Path(__file__).parent.parent / "export_templates" / self._get_export_notebook_template_name()
     227             :             )
     228           0 :             with Path(template_path).open() as f:
     229           0 :                 notebook = nbformat.read(f, as_version=4)
     230             : 
     231           0 :             replacements = self._get_export_replacements()
     232           0 :             for cell in notebook.cells:
     233           0 :                 if cell.cell_type == "code":
     234           0 :                     source = cell.source
     235           0 :                     for key, value in replacements.items():
     236           0 :                         source = source.replace(key, str(value))
     237           0 :                     cell.source = source
     238             : 
     239           0 :             nbformat.write(notebook, filename)
     240             : 
     241           0 :             logger.info("Jupyter notebook saved as %s", filename)
     242             : 
     243           1 :     def _get_export_notebook_template_name(self) -> str:
     244           0 :         raise NotImplementedError("Subclasses must implement this method")
     245             : 
     246           1 :     def _get_export_replacements(self) -> dict[str, str]:
     247             :         # Override this method in subclasses to provide specific replacements for the export
     248           0 :         return {}
     249             : 
     250           1 :     def abort_clicked(self) -> None:
     251             :         """Handle abort button click."""
     252           0 :         logger.debug("Aborting calculation.")
     253           0 :         MultiProcessWorker.terminate_all(create_new_pool=True)
     254           0 :         MultiThreadWorker.terminate_all()
     255           0 :         self.after_calculate(False)
     256           0 :         show_status_tip(self, "Calculation aborted.", logger=logger)

Generated by: LCOV version 1.16