LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/plotwidget - plotwidget.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 139 156 89.1 %
Date: 2026-03-03 11:15:30 Functions: 12 13 92.3 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : from typing import TYPE_CHECKING, Any, Callable
       7             : 
       8           1 : import matplotlib as mpl
       9           1 : import matplotlib.pyplot as plt
      10           1 : import mplcursors
      11           1 : import numpy as np
      12           1 : from matplotlib.colors import Normalize
      13           1 : from PySide6.QtWidgets import QHBoxLayout
      14           1 : from scipy.optimize import curve_fit
      15             : 
      16           1 : from pairinteraction.visualization.colormaps import alphamagma
      17           1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
      18           1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
      19           1 : from pairinteraction_gui.qobjects import WidgetV
      20           1 : from pairinteraction_gui.theme import plot_widget_theme
      21             : 
      22             : if TYPE_CHECKING:
      23             :     from collections.abc import Sequence
      24             : 
      25             :     from numpy.typing import NDArray
      26             :     from typing_extensions import Concatenate
      27             : 
      28             :     from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      29             :     from pairinteraction_gui.page import SimulationPage
      30             : 
      31           1 : logger = logging.getLogger(__name__)
      32             : 
      33             : 
      34           1 : class PlotWidget(WidgetV):
      35             :     """Widget for displaying plots with controls."""
      36             : 
      37           1 :     margin = (0, 0, 0, 0)
      38           1 :     spacing = 15
      39             : 
      40           1 :     def __init__(self, parent: SimulationPage) -> None:
      41             :         """Initialize the base section."""
      42           1 :         mpl.use("Qt5Agg")
      43             : 
      44           1 :         self.page = parent
      45           1 :         super().__init__(parent)
      46             : 
      47           1 :         self.setStyleSheet(plot_widget_theme)
      48             : 
      49           1 :     def setupWidget(self) -> None:
      50           1 :         self.canvas = MatplotlibCanvas(self)
      51           1 :         self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)
      52             : 
      53           1 :         top_layout = QHBoxLayout()
      54           1 :         top_layout.addStretch(1)
      55           1 :         top_layout.addWidget(self.navigation_toolbar)
      56           1 :         self.layout().addLayout(top_layout)
      57             : 
      58           1 :         self.layout().addWidget(self.canvas, stretch=1)
      59             : 
      60           1 :     def clear(self) -> None:
      61           1 :         self.canvas.ax.clear()
      62           1 :         self.canvas.draw()
      63             : 
      64             : 
      65           1 : class PlotEnergies(PlotWidget):
      66             :     """Plotwidget for plotting energy levels."""
      67             : 
      68           1 :     parameters: Parameters[Any] | None = None
      69           1 :     results: Results | None = None
      70             : 
      71           1 :     fit_idx: int = 0
      72           1 :     fit_type: str = ""
      73           1 :     fit_data_highlight: mpl.collections.PathCollection
      74           1 :     fit_curve: Sequence[mpl.lines.Line2D]
      75             : 
      76           1 :     def setupWidget(self) -> None:
      77           1 :         super().setupWidget()
      78             : 
      79           1 :         mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
      80           1 :         self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
      81           1 :         self.canvas.fig.tight_layout()
      82             : 
      83           1 :     def plot(self, parameters: Parameters[Any], results: Results) -> None:
      84           1 :         self.reset_fit()
      85           1 :         ax = self.canvas.ax
      86           1 :         ax.clear()
      87             : 
      88             :         # store data to allow fitting later on
      89           1 :         self.parameters = parameters
      90           1 :         self.results = results
      91             : 
      92           1 :         x_values = parameters.get_x_values()
      93           1 :         energies = results.energies
      94             : 
      95           1 :         try:
      96           1 :             ax.plot(x_values, np.array(energies), c="0.75", lw=0.25, zorder=-10)
      97           0 :         except ValueError as err:
      98           0 :             if "inhomogeneous shape" in str(err):
      99           0 :                 for x_value, es in zip(x_values, energies):
     100           0 :                     ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
     101             :             else:
     102           0 :                 raise err
     103             : 
     104             :         # Flatten the arrays for scatter plot and repeat x value for each energy
     105             :         # (dont use numpy.flatten, etc. to also handle inhomogeneous shapes)
     106           1 :         x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_values, energies)])
     107           1 :         energies_flattened = np.hstack(energies)
     108           1 :         overlaps_flattened = np.hstack(results.ket_overlaps)
     109             : 
     110           1 :         min_overlap = 1e-4
     111           1 :         inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
     112           1 :         inds = inds[np.argsort(overlaps_flattened[inds])]
     113             : 
     114           1 :         if len(inds) > 0:
     115           1 :             ax.scatter(
     116             :                 x_repeated[inds],
     117             :                 energies_flattened[inds],
     118             :                 c=overlaps_flattened[inds],
     119             :                 s=15,
     120             :                 vmin=0,
     121             :                 vmax=1,
     122             :                 cmap=alphamagma,
     123             :             )
     124             : 
     125           1 :         ax.set_xlabel(parameters.get_x_label())
     126           1 :         ax.set_ylabel("Energy [GHz]")
     127             : 
     128           1 :         self.canvas.fig.tight_layout()
     129             : 
     130           1 :     def add_cursor(self, parameters: Parameters[Any], results: Results) -> None:
     131             :         # Remove any existing cursors to avoid duplicates
     132           1 :         if hasattr(self, "mpl_cursor"):
     133           0 :             if hasattr(self.mpl_cursor, "remove"):  # type: ignore
     134           0 :                 self.mpl_cursor.remove()  # type: ignore
     135           0 :             del self.mpl_cursor  # type: ignore
     136             : 
     137           1 :         energies = results.energies
     138           1 :         x_values = parameters.get_x_values()
     139             : 
     140           1 :         artists = []
     141           1 :         for idx, labels in results.state_labels.items():
     142           1 :             x = x_values[idx]
     143           1 :             for energy, label in zip(energies[idx], labels):
     144           1 :                 artist = self.canvas.ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
     145           1 :                 artists.extend(artist)
     146             : 
     147           1 :         self.mpl_cursor = mplcursors.cursor(
     148             :             artists,
     149             :             hover=False,
     150             :             annotation_kwargs={
     151             :                 "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
     152             :                 "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
     153             :             },
     154             :         )
     155             : 
     156           1 :         @self.mpl_cursor.connect("add")
     157           1 :         def on_add(sel: mplcursors.Selection) -> None:
     158           0 :             label = sel.artist.get_label()
     159           0 :             sel.annotation.set_text(label.replace(" + ", "\n + "))
     160             : 
     161           1 :     def fit(self, fit_type: str = "c6") -> None:  # noqa: PLR0912, PLR0915, C901
     162             :         """Fits a potential curve and displays the fit values.
     163             : 
     164             :         Args:
     165             :             fit_type: Type of fit to perform. Options are:
     166             :               c6: E = E0 + C6 * r^6
     167             :               c3: E = E0 + C3 * r^3
     168             :               c3+c6: E = E0 + C3 * r^3 + C6 * r^6
     169             : 
     170             :         Iterative calls will iterate through the potential curves
     171             : 
     172             :         """
     173           1 :         if self.parameters is None or self.results is None:
     174           0 :             logger.warning("No data to fit.")
     175           0 :             return
     176             : 
     177           1 :         energies = self.results.energies
     178           1 :         x_values = np.array(self.parameters.get_x_values())
     179           1 :         overlaps_list = self.results.ket_overlaps
     180             : 
     181             :         fit_func: Callable[Concatenate[NDArray[Any], ...], NDArray[Any]]
     182           1 :         if fit_type == "c6":
     183           1 :             fit_func = fit_c6
     184           1 :             fitlabel = "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"
     185           1 :         elif fit_type == "c3":
     186           1 :             fit_func = fit_c3
     187           1 :             fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"
     188           1 :         elif fit_type == "c3+c6":
     189           1 :             fit_func = fit_c3_c6
     190           1 :             fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6"
     191             :         else:
     192           0 :             raise ValueError(f"Unknown fit type: {fit_type}")
     193             : 
     194             :         # increase the selected potential curve by one if we use the same fit type
     195           1 :         if self.fit_type == fit_type:
     196           1 :             self.fit_idx = (self.fit_idx + 1) % len(energies)
     197             :         else:
     198           1 :             self.fit_idx = 1
     199           1 :         self.fit_type = fit_type
     200             : 
     201             :         # We want to follow the potential curves. The ordering of energies is just by value, so we
     202             :         # need to follow the curve somehow. We go right to left, start at the nth largest value, keep our
     203             :         # index as long as the difference in overlap is less than a factor 2 or less than 5% total difference.
     204             :         # Otherwise, we search until we find an overlap that is less than a factor 2 different.
     205             :         # This is of course a simple heuristic, a more sophisticated approach would do some global optimization
     206             :         # of the curves. This approach is simple, fast and robust, but curves may e.g. merge.
     207             :         # This does not at all take into account the line shapes of the curves. There is also no trade-off
     208             :         # between overlap being close and not doing jumps.
     209           1 :         idxs = [np.argpartition(overlaps_list[0], -self.fit_idx)[-self.fit_idx]]
     210           1 :         last_overlap = overlaps_list[0][idxs[-1]]
     211           1 :         for overlaps in overlaps_list[1:]:
     212           1 :             idx = idxs[-1]
     213           1 :             overlap = overlaps[idx]
     214           1 :             if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
     215             :                 # we keep the current index
     216           1 :                 idxs.append(idx)
     217           1 :                 last_overlap = overlap
     218             :             else:
     219             :                 # we search until we find an overlap that is less than a factor 2 different
     220           1 :                 possible_options = np.argwhere(
     221             :                     np.logical_and(overlaps > 0.5 * last_overlap, overlaps < 2 * last_overlap)
     222             :                 ).flatten()
     223           1 :                 if len(possible_options) == 0:
     224             :                     # there is no state in that range - our best bet is to keep the current index
     225           1 :                     idxs.append(idx)
     226           1 :                     last_overlap = overlap
     227             :                 else:
     228             :                     # we select the closest possible option
     229           1 :                     best_option = np.argmin(np.abs(possible_options - idx))
     230           1 :                     idxs.append(possible_options[best_option])
     231           1 :                     last_overlap = overlaps[idxs[-1]]
     232             : 
     233             :         # this could be a call to np.take_along_axis if the sizes match, but the handling of inhomogeneous shapes
     234             :         # in the plot() function makes me worry they won't, so I go for a slower python for loop...
     235           1 :         energies_fit = np.array([energy[idx] for energy, idx in zip(energies, idxs)])
     236             : 
     237             :         # stop highlighting the previous fit
     238           1 :         if hasattr(self, "fit_data_highlight"):
     239           1 :             self.fit_data_highlight.remove()
     240           1 :         if hasattr(self, "fit_curve"):
     241           1 :             for curve in self.fit_curve:
     242           1 :                 curve.remove()
     243             : 
     244           1 :         self.fit_data_highlight = self.canvas.ax.scatter(x_values, energies_fit, c="green", s=5)
     245             : 
     246           1 :         try:
     247           1 :             fit_params = curve_fit(fit_func, x_values, energies_fit)[0]
     248           0 :         except (RuntimeError, TypeError):
     249           0 :             logger.warning("Curve fit failed.")
     250             :         else:
     251           1 :             self.fit_curve = self.canvas.ax.plot(
     252             :                 x_values,
     253             :                 fit_func(x_values, *fit_params),
     254             :                 c="green",
     255             :                 linestyle="dashed",
     256             :                 lw=2,
     257             :                 label=fitlabel.format(*fit_params),
     258             :             )
     259           1 :             self.canvas.ax.legend()
     260             : 
     261           1 :         self.canvas.draw()
     262             : 
     263           1 :     def clear(self) -> None:
     264           1 :         super().clear()
     265           1 :         self.reset_fit()
     266             : 
     267           1 :     def reset_fit(self) -> None:
     268             :         """Clear fit output and reset fit index."""
     269             :         # restart at first potential curve
     270           1 :         self.fit_idx = 0
     271             :         # and also remove any previous highlighting/fit display
     272           1 :         if hasattr(self, "fit_data_highlight"):
     273           0 :             del self.fit_data_highlight
     274           1 :         if hasattr(self, "fit_curve"):
     275           0 :             del self.fit_curve
     276             : 
     277             : 
     278           1 : def fit_c3(x: NDArray[Any], /, e0: float, c3: float) -> NDArray[Any]:
     279           1 :     return e0 + c3 / x**3
     280             : 
     281             : 
     282           1 : def fit_c6(x: NDArray[Any], /, e0: float, c6: float) -> NDArray[Any]:
     283           1 :     return e0 + c6 / x**6
     284             : 
     285             : 
     286           1 : def fit_c3_c6(x: NDArray[Any], /, e0: float, c3: float, c6: float) -> NDArray[Any]:
     287           1 :     return e0 + c3 / x**3 + c6 / x**6

Generated by: LCOV version 1.16