LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/plotwidget - plotwidget.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 57 150 38.0 %
Date: 2025-09-29 10:28:29 Functions: 6 20 30.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : from typing import TYPE_CHECKING, Any, Callable
       7             : 
       8           1 : import matplotlib as mpl
       9           1 : import matplotlib.pyplot as plt
      10           1 : import mplcursors
      11           1 : import numpy as np
      12           1 : from matplotlib.colors import Normalize
      13           1 : from PySide6.QtWidgets import QHBoxLayout
      14           1 : from scipy.optimize import curve_fit
      15             : 
      16           1 : from pairinteraction.visualization.colormaps import alphamagma
      17           1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
      18           1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
      19           1 : from pairinteraction_gui.qobjects import WidgetV
      20           1 : from pairinteraction_gui.theme import plot_widget_theme
      21             : 
      22             : if TYPE_CHECKING:
      23             :     from collections.abc import Sequence
      24             : 
      25             :     from numpy.typing import NDArray
      26             : 
      27             :     from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      28             :     from pairinteraction_gui.page import SimulationPage
      29             : 
      30           1 : logger = logging.getLogger(__name__)
      31             : 
      32             : 
      33           1 : class PlotWidget(WidgetV):
      34             :     """Widget for displaying plots with controls."""
      35             : 
      36           1 :     margin = (0, 0, 0, 0)
      37           1 :     spacing = 15
      38             : 
      39           1 :     def __init__(self, parent: SimulationPage) -> None:
      40             :         """Initialize the base section."""
      41           1 :         mpl.use("Qt5Agg")
      42             : 
      43           1 :         self.page = parent
      44           1 :         super().__init__(parent)
      45             : 
      46           1 :         self.setStyleSheet(plot_widget_theme)
      47             : 
      48           1 :     def setupWidget(self) -> None:
      49           1 :         self.canvas = MatplotlibCanvas(self)
      50           1 :         self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)
      51             : 
      52           1 :         top_layout = QHBoxLayout()
      53           1 :         top_layout.addStretch(1)
      54           1 :         top_layout.addWidget(self.navigation_toolbar)
      55           1 :         self.layout().addLayout(top_layout)
      56             : 
      57           1 :         self.layout().addWidget(self.canvas, stretch=1)
      58             : 
      59           1 :     def clear(self) -> None:
      60           1 :         self.canvas.ax.clear()
      61           1 :         self.canvas.draw()
      62             : 
      63             : 
      64           1 : class PlotEnergies(PlotWidget):
      65             :     """Plotwidget for plotting energy levels."""
      66             : 
      67           1 :     parameters: Parameters[Any] | None = None
      68           1 :     results: Results | None = None
      69             : 
      70           1 :     fit_idx: int = 0
      71           1 :     fit_type: str = ""
      72           1 :     fit_data_highlight: mpl.collections.PathCollection
      73           1 :     fit_curve: Sequence[mpl.lines.Line2D]
      74             : 
      75           1 :     def setupWidget(self) -> None:
      76           1 :         super().setupWidget()
      77             : 
      78           1 :         mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
      79           1 :         self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
      80           1 :         self.canvas.fig.tight_layout()
      81             : 
      82           1 :     def plot(self, parameters: Parameters[Any], results: Results) -> None:
      83           0 :         self.reset_fit()
      84           0 :         ax = self.canvas.ax
      85           0 :         ax.clear()
      86             : 
      87             :         # store data to allow fitting later on
      88           0 :         self.parameters = parameters
      89           0 :         self.results = results
      90             : 
      91           0 :         x_values = parameters.get_x_values()
      92           0 :         energies = results.energies
      93             : 
      94           0 :         try:
      95           0 :             ax.plot(x_values, np.array(energies), c="0.75", lw=0.25, zorder=-10)
      96           0 :         except ValueError as err:
      97           0 :             if "inhomogeneous shape" in str(err):
      98           0 :                 for x_value, es in zip(x_values, energies):
      99           0 :                     ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
     100             :             else:
     101           0 :                 raise err
     102             : 
     103             :         # Flatten the arrays for scatter plot and repeat x value for each energy
     104             :         # (dont use numpy.flatten, etc. to also handle inhomogeneous shapes)
     105           0 :         x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_values, energies)])
     106           0 :         energies_flattened = np.hstack(energies)
     107           0 :         overlaps_flattened = np.hstack(results.ket_overlaps)
     108             : 
     109           0 :         min_overlap = 1e-4
     110           0 :         inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
     111           0 :         inds = inds[np.argsort(overlaps_flattened[inds])]
     112             : 
     113           0 :         if len(inds) > 0:
     114           0 :             ax.scatter(
     115             :                 x_repeated[inds],
     116             :                 energies_flattened[inds],
     117             :                 c=overlaps_flattened[inds],
     118             :                 s=15,
     119             :                 vmin=0,
     120             :                 vmax=1,
     121             :                 cmap=alphamagma,
     122             :             )
     123             : 
     124           0 :         ax.set_xlabel(parameters.get_x_label())
     125           0 :         ax.set_ylabel("Energy [GHz]")
     126             : 
     127           0 :         self.canvas.fig.tight_layout()
     128             : 
     129           1 :     def add_cursor(self, parameters: Parameters[Any], results: Results) -> None:
     130             :         # Remove any existing cursors to avoid duplicates
     131           0 :         if hasattr(self, "mpl_cursor"):
     132           0 :             if hasattr(self.mpl_cursor, "remove"):  # type: ignore
     133           0 :                 self.mpl_cursor.remove()  # type: ignore
     134           0 :             del self.mpl_cursor  # type: ignore
     135             : 
     136           0 :         energies = results.energies
     137           0 :         x_values = parameters.get_x_values()
     138             : 
     139           0 :         artists = []
     140           0 :         for idx, labels in results.state_labels.items():
     141           0 :             x = x_values[idx]
     142           0 :             for energy, label in zip(energies[idx], labels):
     143           0 :                 artist = self.canvas.ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
     144           0 :                 artists.extend(artist)
     145             : 
     146           0 :         self.mpl_cursor = mplcursors.cursor(
     147             :             artists,
     148             :             hover=False,
     149             :             annotation_kwargs={
     150             :                 "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
     151             :                 "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
     152             :             },
     153             :         )
     154             : 
     155           0 :         @self.mpl_cursor.connect("add")
     156           0 :         def on_add(sel: mplcursors.Selection) -> None:
     157           0 :             label = sel.artist.get_label()
     158           0 :             sel.annotation.set_text(label.replace(" + ", "\n + "))
     159             : 
     160           1 :     def fit(self, fit_type: str = "c6") -> None:  # noqa: PLR0912, PLR0915, C901
     161             :         """Fits a potential curve and displays the fit values.
     162             : 
     163             :         Args:
     164             :             fit_type: Type of fit to perform. Options are:
     165             :               c6: E = E0 + C6 * r^6
     166             :               c3: E = E0 + C3 * r^3
     167             :               c3+c6: E = E0 + C3 * r^3 + C6 * r^6
     168             : 
     169             :         Iterative calls will iterate through the potential curves
     170             : 
     171             :         """
     172           0 :         if self.parameters is None or self.results is None:
     173           0 :             logger.warning("No data to fit.")
     174           0 :             return
     175             : 
     176           0 :         energies = self.results.energies
     177           0 :         x_values = self.parameters.get_x_values()
     178           0 :         overlaps_list = self.results.ket_overlaps
     179             : 
     180             :         fit_func: Callable[..., NDArray[Any]]
     181           0 :         if fit_type == "c6":
     182           0 :             fit_func = lambda x, e0, c6: e0 + c6 / x**6  # noqa: E731
     183           0 :             fitlabel = "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"
     184           0 :         elif fit_type == "c3":
     185           0 :             fit_func = lambda x, e0, c3: e0 + c3 / x**3  # noqa: E731
     186           0 :             fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"
     187           0 :         elif fit_type == "c3+c6":
     188           0 :             fit_func = lambda x, e0, c3, c6: e0 + c3 / x**3 + c6 / x**6  # noqa: E731
     189           0 :             fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6"
     190             :         else:
     191           0 :             raise ValueError(f"Unknown fit type: {fit_type}")
     192             : 
     193             :         # increase the selected potential curve by one if we use the same fit type
     194           0 :         if self.fit_type == fit_type:
     195           0 :             self.fit_idx = (self.fit_idx + 1) % len(energies)
     196             :         else:
     197           0 :             self.fit_idx = 1
     198           0 :         self.fit_type = fit_type
     199             : 
     200             :         # We want to follow the potential curves. The ordering of energies is just by value, so we
     201             :         # need to follow the curve somehow. We go right to left, start at the nth largest value, keep our
     202             :         # index as long as the difference in overlap is less than a factor 2 or less than 5% total difference.
     203             :         # Otherwise, we search until we find an overlap that is less than a factor 2 different.
     204             :         # This is of course a simple heuristic, a more sophisticated approach would do some global optimization
     205             :         # of the curves. This approach is simple, fast and robust, but curves may e.g. merge.
     206             :         # This does not at all take into account the line shapes of the curves. There is also no trade-off
     207             :         # between overlap being close and not doing jumps.
     208           0 :         idxs = [np.argpartition(overlaps_list[0], -self.fit_idx)[-self.fit_idx]]
     209           0 :         last_overlap = overlaps_list[0][idxs[-1]]
     210           0 :         for overlaps in overlaps_list[1:]:
     211           0 :             idx = idxs[-1]
     212           0 :             overlap = overlaps[idx]
     213           0 :             if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
     214             :                 # we keep the current index
     215           0 :                 idxs.append(idx)
     216           0 :                 last_overlap = overlap
     217             :             else:
     218             :                 # we search until we find an overlap that is less than a factor 2 different
     219           0 :                 possible_options = np.argwhere(
     220             :                     np.logical_and(overlaps > 0.5 * last_overlap, overlaps < 2 * last_overlap)
     221             :                 ).flatten()
     222           0 :                 if len(possible_options) == 0:
     223             :                     # there is no state in that range - our best bet is to keep the current index
     224           0 :                     idxs.append(idx)
     225           0 :                     last_overlap = overlap
     226             :                 else:
     227             :                     # we select the closest possible option
     228           0 :                     best_option = np.argmin(np.abs(possible_options - idx))
     229           0 :                     idxs.append(possible_options[best_option])
     230           0 :                     last_overlap = overlaps[idxs[-1]]
     231             : 
     232             :         # this could be a call to np.take_along_axis if the sizes match, but the handling of inhomogeneous shapes
     233             :         # in the plot() function makes me worry they won't, so I go for a slower python for loop...
     234           0 :         energies_fit = np.array([energy[idx] for energy, idx in zip(energies, idxs)])
     235             : 
     236             :         # stop highlighting the previous fit
     237           0 :         if hasattr(self, "fit_data_highlight"):
     238           0 :             self.fit_data_highlight.remove()
     239           0 :         if hasattr(self, "fit_curve"):
     240           0 :             for curve in self.fit_curve:
     241           0 :                 curve.remove()
     242             : 
     243           0 :         self.fit_data_highlight = self.canvas.ax.scatter(x_values, energies_fit, c="green", s=5)
     244             : 
     245           0 :         try:
     246           0 :             fit_params = curve_fit(fit_func, x_values, energies_fit)[0]
     247           0 :         except (RuntimeError, TypeError):
     248           0 :             logger.warning("Curve fit failed.")
     249             :         else:
     250           0 :             self.fit_curve = self.canvas.ax.plot(
     251             :                 x_values,
     252             :                 fit_func(x_values, *fit_params),
     253             :                 c="green",
     254             :                 linestyle="dashed",
     255             :                 lw=2,
     256             :                 label=fitlabel.format(*fit_params),
     257             :             )
     258           0 :             self.canvas.ax.legend()
     259             : 
     260           0 :         self.canvas.draw()
     261             : 
     262           1 :     def clear(self) -> None:
     263           1 :         super().clear()
     264           1 :         self.reset_fit()
     265             : 
     266           1 :     def reset_fit(self) -> None:
     267             :         """Clear fit output and reset fit index."""
     268             :         # restart at first potential curve
     269           1 :         self.fit_idx = 0
     270             :         # and also remove any previous highlighting/fit display
     271           1 :         if hasattr(self, "fit_data_highlight"):
     272           0 :             del self.fit_data_highlight
     273           1 :         if hasattr(self, "fit_curve"):
     274           0 :             del self.fit_curve

Generated by: LCOV version 1.16