LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/plotwidget - plotwidget.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 42 78 53.8 %
Date: 2025-06-06 09:09:03 Functions: 4 14 28.6 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : from collections.abc import Sequence
       6           1 : from typing import TYPE_CHECKING, Any
       7             : 
       8           1 : import matplotlib as mpl
       9           1 : import matplotlib.pyplot as plt
      10           1 : import mplcursors
      11           1 : import numpy as np
      12           1 : from matplotlib.colors import Normalize
      13           1 : from PySide6.QtWidgets import QHBoxLayout
      14             : 
      15           1 : from pairinteraction.visualization.colormaps import alphamagma
      16           1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
      17           1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
      18           1 : from pairinteraction_gui.qobjects import WidgetV
      19           1 : from pairinteraction_gui.theme import plot_widget_theme
      20             : 
      21             : if TYPE_CHECKING:
      22             :     from numpy.typing import NDArray
      23             : 
      24             :     from pairinteraction_gui.page import SimulationPage
      25             : 
      26           1 : logger = logging.getLogger(__name__)
      27             : 
      28             : 
      29           1 : class PlotWidget(WidgetV):
      30             :     """Widget for displaying plots with controls."""
      31             : 
      32           1 :     margin = (0, 0, 0, 0)
      33           1 :     spacing = 15
      34             : 
      35           1 :     def __init__(self, parent: "SimulationPage") -> None:
      36             :         """Initialize the base section."""
      37           1 :         mpl.use("Qt5Agg")
      38             : 
      39           1 :         self.page = parent
      40           1 :         super().__init__(parent)
      41             : 
      42           1 :         self.setStyleSheet(plot_widget_theme)
      43             : 
      44           1 :     def setupWidget(self) -> None:
      45           1 :         self.canvas = MatplotlibCanvas(self)
      46           1 :         self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)
      47             : 
      48           1 :         top_layout = QHBoxLayout()
      49           1 :         top_layout.addStretch(1)
      50           1 :         top_layout.addWidget(self.navigation_toolbar)
      51           1 :         self.layout().addLayout(top_layout)
      52             : 
      53           1 :         self.layout().addWidget(self.canvas, stretch=1)
      54             : 
      55           1 :     def clear(self) -> None:
      56           1 :         self.canvas.ax.clear()
      57           1 :         self.canvas.draw()
      58             : 
      59             : 
      60           1 : class PlotEnergies(PlotWidget):
      61             :     """Plotwidget for plotting energy levels."""
      62             : 
      63           1 :     def setupWidget(self) -> None:
      64           1 :         super().setupWidget()
      65             : 
      66           1 :         mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
      67           1 :         self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
      68           1 :         self.canvas.fig.tight_layout()
      69             : 
      70           1 :     def plot(
      71             :         self,
      72             :         x_list: Sequence[float],
      73             :         energies_list: Sequence["NDArray[Any]"],
      74             :         overlaps_list: Sequence["NDArray[Any]"],
      75             :         xlabel: str,
      76             :     ) -> None:
      77           0 :         ax = self.canvas.ax
      78           0 :         ax.clear()
      79             : 
      80           0 :         try:
      81           0 :             ax.plot(x_list, np.array(energies_list), c="0.75", lw=0.25, zorder=-10)
      82           0 :         except ValueError as err:
      83           0 :             if "inhomogeneous shape" in str(err):
      84           0 :                 for x_value, es in zip(x_list, energies_list):
      85           0 :                     ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
      86             :             else:
      87           0 :                 raise err
      88             : 
      89             :         # Flatten the arrays for scatter plot and repeat x value for each energy
      90             :         # (dont use numpy.flatten, etc. to also handle inhomogeneous shapes)
      91           0 :         x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_list, energies_list)])
      92           0 :         energies_flattend = np.hstack(energies_list)
      93           0 :         overlaps_flattend = np.hstack(overlaps_list)
      94             : 
      95           0 :         min_overlap = 1e-4
      96           0 :         inds: NDArray[Any] = np.argwhere(overlaps_flattend > min_overlap).flatten()
      97           0 :         inds = inds[np.argsort(overlaps_flattend[inds])]
      98             : 
      99           0 :         if len(inds) > 0:
     100           0 :             ax.scatter(
     101             :                 x_repeated[inds],
     102             :                 energies_flattend[inds],
     103             :                 c=overlaps_flattend[inds],
     104             :                 s=15,
     105             :                 vmin=0,
     106             :                 vmax=1,
     107             :                 cmap=alphamagma,
     108             :             )
     109             : 
     110           0 :         ax.set_xlabel(xlabel)
     111           0 :         ax.set_ylabel("Energy [GHz]")
     112             : 
     113           0 :         self.canvas.fig.tight_layout()
     114             : 
     115           1 :     def add_cursor(
     116             :         self, x_value: list[float], energies: list["NDArray[Any]"], state_labels: dict[int, list[str]]
     117             :     ) -> None:
     118             :         # Remove any existing cursors to avoid duplicates
     119           0 :         if hasattr(self, "mpl_cursor"):
     120           0 :             if hasattr(self.mpl_cursor, "remove"):  # type: ignore
     121           0 :                 self.mpl_cursor.remove()  # type: ignore
     122           0 :             del self.mpl_cursor  # type: ignore
     123             : 
     124           0 :         ax = self.canvas.ax
     125             : 
     126           0 :         artists = []
     127           0 :         for idx, labels in state_labels.items():
     128           0 :             x = x_value[idx]
     129           0 :             for energy, label in zip(energies[idx], labels):
     130           0 :                 artist = ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
     131           0 :                 artists.extend(artist)
     132             : 
     133           0 :         self.mpl_cursor = mplcursors.cursor(
     134             :             artists,
     135             :             hover=False,
     136             :             annotation_kwargs={
     137             :                 "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
     138             :                 "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
     139             :             },
     140             :         )
     141             : 
     142           0 :         @self.mpl_cursor.connect("add")
     143           0 :         def on_add(sel: mplcursors.Selection) -> None:
     144           0 :             label = sel.artist.get_label()
     145           0 :             sel.annotation.set_text(label.replace(" + ", "\n + "))

Generated by: LCOV version 1.16