LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/plotwidget - plotwidget.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 58 145 40.0 %
Date: 2025-08-29 20:47:05 Functions: 6 20 30.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : from collections.abc import Sequence
       6           1 : from typing import TYPE_CHECKING, Any, Callable
       7             : 
       8           1 : import matplotlib as mpl
       9           1 : import matplotlib.pyplot as plt
      10           1 : import mplcursors
      11           1 : import numpy as np
      12           1 : from matplotlib.colors import Normalize
      13           1 : from PySide6.QtWidgets import QHBoxLayout
      14           1 : from scipy.optimize import curve_fit
      15             : 
      16           1 : from pairinteraction.visualization.colormaps import alphamagma
      17           1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
      18           1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
      19           1 : from pairinteraction_gui.qobjects import WidgetV
      20           1 : from pairinteraction_gui.theme import plot_widget_theme
      21             : 
      22             : if TYPE_CHECKING:
      23             :     from numpy.typing import NDArray
      24             : 
      25             :     from pairinteraction_gui.page import SimulationPage
      26             : 
      27           1 : logger = logging.getLogger(__name__)
      28             : 
      29             : 
      30           1 : class PlotWidget(WidgetV):
      31             :     """Widget for displaying plots with controls."""
      32             : 
      33           1 :     margin = (0, 0, 0, 0)
      34           1 :     spacing = 15
      35             : 
      36           1 :     def __init__(self, parent: "SimulationPage") -> None:
      37             :         """Initialize the base section."""
      38           1 :         mpl.use("Qt5Agg")
      39             : 
      40           1 :         self.page = parent
      41           1 :         super().__init__(parent)
      42             : 
      43           1 :         self.setStyleSheet(plot_widget_theme)
      44             : 
      45           1 :     def setupWidget(self) -> None:
      46           1 :         self.canvas = MatplotlibCanvas(self)
      47           1 :         self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)
      48             : 
      49           1 :         top_layout = QHBoxLayout()
      50           1 :         top_layout.addStretch(1)
      51           1 :         top_layout.addWidget(self.navigation_toolbar)
      52           1 :         self.layout().addLayout(top_layout)
      53             : 
      54           1 :         self.layout().addWidget(self.canvas, stretch=1)
      55             : 
      56           1 :     def clear(self) -> None:
      57           1 :         self.canvas.ax.clear()
      58           1 :         self.canvas.draw()
      59             : 
      60             : 
      61           1 : class PlotEnergies(PlotWidget):
      62             :     """Plotwidget for plotting energy levels."""
      63             : 
      64           1 :     x_list: "NDArray[Any]"
      65           1 :     energies_list: Sequence["NDArray[Any]"]
      66           1 :     overlaps_list: Sequence["NDArray[Any]"]
      67           1 :     fit_idx: int = 0
      68           1 :     fit_type: str = ""
      69           1 :     fit_data_highlight: mpl.collections.PathCollection
      70           1 :     fit_curve: Sequence[mpl.lines.Line2D]
      71             : 
      72           1 :     def setupWidget(self) -> None:
      73           1 :         super().setupWidget()
      74             : 
      75           1 :         mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
      76           1 :         self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
      77           1 :         self.canvas.fig.tight_layout()
      78             : 
      79           1 :     def plot(
      80             :         self,
      81             :         x_list: Sequence[float],
      82             :         energies_list: Sequence["NDArray[Any]"],
      83             :         overlaps_list: Sequence["NDArray[Any]"],
      84             :         xlabel: str,
      85             :     ) -> None:
      86             :         # store data to allow fitting later on
      87           0 :         self.x_list = np.array(x_list)
      88           0 :         self.energies_list = energies_list
      89           0 :         self.overlaps_list = overlaps_list
      90           0 :         self.reset_fit()
      91             : 
      92           0 :         ax = self.canvas.ax
      93           0 :         ax.clear()
      94             : 
      95           0 :         try:
      96           0 :             ax.plot(x_list, np.array(energies_list), c="0.75", lw=0.25, zorder=-10)
      97           0 :         except ValueError as err:
      98           0 :             if "inhomogeneous shape" in str(err):
      99           0 :                 for x_value, es in zip(x_list, energies_list):
     100           0 :                     ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
     101             :             else:
     102           0 :                 raise err
     103             : 
     104             :         # Flatten the arrays for scatter plot and repeat x value for each energy
     105             :         # (dont use numpy.flatten, etc. to also handle inhomogeneous shapes)
     106           0 :         x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_list, energies_list)])
     107           0 :         energies_flattened = np.hstack(energies_list)
     108           0 :         overlaps_flattened = np.hstack(overlaps_list)
     109             : 
     110           0 :         min_overlap = 1e-4
     111           0 :         inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
     112           0 :         inds = inds[np.argsort(overlaps_flattened[inds])]
     113             : 
     114           0 :         if len(inds) > 0:
     115           0 :             ax.scatter(
     116             :                 x_repeated[inds],
     117             :                 energies_flattened[inds],
     118             :                 c=overlaps_flattened[inds],
     119             :                 s=15,
     120             :                 vmin=0,
     121             :                 vmax=1,
     122             :                 cmap=alphamagma,
     123             :             )
     124             : 
     125           0 :         ax.set_xlabel(xlabel)
     126           0 :         ax.set_ylabel("Energy [GHz]")
     127             : 
     128           0 :         self.canvas.fig.tight_layout()
     129             : 
     130           1 :     def add_cursor(
     131             :         self, x_value: list[float], energies: list["NDArray[Any]"], state_labels: dict[int, list[str]]
     132             :     ) -> None:
     133             :         # Remove any existing cursors to avoid duplicates
     134           0 :         if hasattr(self, "mpl_cursor"):
     135           0 :             if hasattr(self.mpl_cursor, "remove"):  # type: ignore
     136           0 :                 self.mpl_cursor.remove()  # type: ignore
     137           0 :             del self.mpl_cursor  # type: ignore
     138             : 
     139           0 :         ax = self.canvas.ax
     140             : 
     141           0 :         artists = []
     142           0 :         for idx, labels in state_labels.items():
     143           0 :             x = x_value[idx]
     144           0 :             for energy, label in zip(energies[idx], labels):
     145           0 :                 artist = ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
     146           0 :                 artists.extend(artist)
     147             : 
     148           0 :         self.mpl_cursor = mplcursors.cursor(
     149             :             artists,
     150             :             hover=False,
     151             :             annotation_kwargs={
     152             :                 "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
     153             :                 "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
     154             :             },
     155             :         )
     156             : 
     157           0 :         @self.mpl_cursor.connect("add")
     158           0 :         def on_add(sel: mplcursors.Selection) -> None:
     159           0 :             label = sel.artist.get_label()
     160           0 :             sel.annotation.set_text(label.replace(" + ", "\n + "))
     161             : 
     162           1 :     def fit(self, fit_type: str = "c6") -> None:  # noqa: PLR0912, C901
     163             :         """Fits a potential curve and displays the fit values.
     164             : 
     165             :         Args:
     166             :             fit_type: Type of fit to perform. Options are:
     167             :               c6: E = E0 + C6 * r^6
     168             :               c3: E = E0 + C3 * r^3
     169             :               c3+c6: E = E0 + C3 * r^3 + C6 * r^6
     170             : 
     171             :         Iterative calls will iterate through the potential curves
     172             : 
     173             :         """
     174             :         fit_func: Callable[..., NDArray[Any]]
     175           0 :         if fit_type == "c6":
     176           0 :             fit_func = lambda x, e0, c6: e0 + c6 / x**6  # noqa: E731
     177           0 :             fitlabel = "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"
     178           0 :         elif fit_type == "c3":
     179           0 :             fit_func = lambda x, e0, c3: e0 + c3 / x**3  # noqa: E731
     180           0 :             fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"
     181           0 :         elif fit_type == "c3+c6":
     182           0 :             fit_func = lambda x, e0, c3, c6: e0 + c3 / x**3 + c6 / x**6  # noqa: E731
     183           0 :             fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6"
     184             :         else:
     185           0 :             raise ValueError(f"Unknown fit type: {fit_type}")
     186             : 
     187             :         # first see if we actually have data to fit
     188           0 :         if not (hasattr(self, "x_list") and hasattr(self, "energies_list") and hasattr(self, "overlaps_list")):
     189           0 :             return
     190             : 
     191             :         # increase the selected potential curve by one if we use the same fit type
     192           0 :         if self.fit_type == fit_type:
     193           0 :             self.fit_idx = (self.fit_idx + 1) % len(self.energies_list)
     194             :         else:
     195           0 :             self.fit_idx = 1
     196           0 :         self.fit_type = fit_type
     197             : 
     198             :         # We want to follow the potential curves. The ordering of energies is just by value, so we
     199             :         # need to follow the curve somehow. We go right to left, start at the nth largest value, keep our
     200             :         # index as long as the difference in overlap is less than a factor 2 or less than 5% total difference.
     201             :         # Otherwise, we search until we find an overlap that is less than a factor 2 different.
     202             :         # This is of course a simple heuristic, a more sophisticated approach would do some global optimization
     203             :         # of the curves. This approach is simple, fast and robust, but curves may e.g. merge.
     204             :         # This does not at all take into account the line shapes of the curves. There is also no trade-off
     205             :         # between overlap being close and not doing jumps.
     206           0 :         idxs = [np.argpartition(self.overlaps_list[0], -self.fit_idx)[-self.fit_idx]]
     207           0 :         last_overlap = self.overlaps_list[0][idxs[-1]]
     208           0 :         for overlaps in self.overlaps_list[1:]:
     209           0 :             idx = idxs[-1]
     210           0 :             overlap = overlaps[idx]
     211           0 :             if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
     212             :                 # we keep the current index
     213           0 :                 idxs.append(idx)
     214           0 :                 last_overlap = overlap
     215             :             else:
     216             :                 # we search until we find an overlap that is less than a factor 2 different
     217           0 :                 possible_options = np.argwhere(
     218             :                     np.logical_and(overlaps > 0.5 * last_overlap, overlaps < 2 * last_overlap)
     219             :                 ).flatten()
     220           0 :                 if len(possible_options) == 0:
     221             :                     # there is no state in that range - our best bet is to keep the current index
     222           0 :                     idxs.append(idx)
     223           0 :                     last_overlap = overlap
     224             :                 else:
     225             :                     # we select the closest possible option
     226           0 :                     best_option = np.argmin(np.abs(possible_options - idx))
     227           0 :                     idxs.append(possible_options[best_option])
     228           0 :                     last_overlap = overlaps[idxs[-1]]
     229             : 
     230             :         # this could be a call to np.take_along_axis if the sizes match, but the handling of inhomogeneous shapes
     231             :         # in the plot() function makes me worry they won't, so I go for a slower python for loop...
     232           0 :         energies = np.array([energy[idx] for energy, idx in zip(self.energies_list, idxs)])
     233             : 
     234             :         # stop highlighting the previous fit
     235           0 :         if hasattr(self, "fit_data_highlight"):
     236           0 :             self.fit_data_highlight.remove()
     237           0 :         if hasattr(self, "fit_curve"):
     238           0 :             for curve in self.fit_curve:
     239           0 :                 curve.remove()
     240             : 
     241           0 :         self.fit_data_highlight = self.canvas.ax.scatter(self.x_list, energies, c="green", s=5)
     242             : 
     243           0 :         try:
     244           0 :             fit_params = curve_fit(fit_func, self.x_list, energies)[0]
     245           0 :         except (RuntimeError, TypeError):
     246           0 :             logger.warning("Curve fit failed.")
     247             :         else:
     248           0 :             self.fit_curve = self.canvas.ax.plot(
     249             :                 self.x_list,
     250             :                 fit_func(self.x_list, *fit_params),
     251             :                 c="green",
     252             :                 linestyle="dashed",
     253             :                 lw=2,
     254             :                 label=fitlabel.format(*fit_params),
     255             :             )
     256           0 :             self.canvas.ax.legend()
     257             : 
     258           0 :         self.canvas.draw()
     259             : 
     260           1 :     def clear(self) -> None:
     261           1 :         super().clear()
     262           1 :         self.reset_fit()
     263             : 
     264           1 :     def reset_fit(self) -> None:
     265             :         """Clear fit output and reset fit index."""
     266             :         # restart at first potential curve
     267           1 :         self.fit_idx = 0
     268             :         # and also remove any previous highlighting/fit display
     269           1 :         if hasattr(self, "fit_data_highlight"):
     270           0 :             del self.fit_data_highlight
     271           1 :         if hasattr(self, "fit_curve"):
     272           0 :             del self.fit_curve

Generated by: LCOV version 1.16