Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : from collections.abc import Sequence
6 1 : from typing import TYPE_CHECKING, Any, Callable
7 :
8 1 : import matplotlib as mpl
9 1 : import matplotlib.pyplot as plt
10 1 : import mplcursors
11 1 : import numpy as np
12 1 : from matplotlib.colors import Normalize
13 1 : from PySide6.QtWidgets import QHBoxLayout
14 1 : from scipy.optimize import curve_fit
15 :
16 1 : from pairinteraction.visualization.colormaps import alphamagma
17 1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
18 1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
19 1 : from pairinteraction_gui.qobjects import WidgetV
20 1 : from pairinteraction_gui.theme import plot_widget_theme
21 :
22 : if TYPE_CHECKING:
23 : from numpy.typing import NDArray
24 :
25 : from pairinteraction_gui.page import SimulationPage
26 :
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
28 :
29 :
class PlotWidget(WidgetV):
    """Widget hosting a matplotlib canvas together with its navigation toolbar."""

    margin = (0, 0, 0, 0)
    spacing = 15

    def __init__(self, parent: "SimulationPage") -> None:
        """Create the plot widget as a child of ``parent``.

        Args:
            parent: The simulation page owning this widget; it is kept as ``self.page``.

        """
        # select the matplotlib backend before any canvas is created
        mpl.use("Qt5Agg")

        # NOTE(review): ``page`` is assigned before the base-class constructor runs, presumably
        # because WidgetV.__init__ triggers setupWidget(), which may rely on it - confirm
        self.page = parent
        super().__init__(parent)

        self.setStyleSheet(plot_widget_theme)

    def setupWidget(self) -> None:
        """Create the canvas and the toolbar and add both to this widget's layout."""
        self.canvas = MatplotlibCanvas(self)
        self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)

        main_layout = self.layout()

        # toolbar row: the stretch in front of the toolbar pushes it to the right
        toolbar_row = QHBoxLayout()
        toolbar_row.addStretch(1)
        toolbar_row.addWidget(self.navigation_toolbar)
        main_layout.addLayout(toolbar_row)

        # the canvas gets all remaining space
        main_layout.addWidget(self.canvas, stretch=1)

    def clear(self) -> None:
        """Remove everything from the axes and redraw the canvas."""
        axes = self.canvas.ax
        axes.clear()
        self.canvas.draw()
59 :
60 :
class PlotEnergies(PlotWidget):
    """Plotwidget for plotting energy levels."""

    x_list: "NDArray[Any]"
    energies_list: Sequence["NDArray[Any]"]
    overlaps_list: Sequence["NDArray[Any]"]
    fit_idx: int = 0
    fit_type: str = ""
    fit_data_highlight: mpl.collections.PathCollection
    fit_curve: Sequence[mpl.lines.Line2D]

    def setupWidget(self) -> None:
        """Set up canvas and toolbar and add a colorbar for the overlap color scale (0 to 1)."""
        super().setupWidget()

        mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
        self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
        self.canvas.fig.tight_layout()

    def plot(
        self,
        x_list: Sequence[float],
        energies_list: Sequence["NDArray[Any]"],
        overlaps_list: Sequence["NDArray[Any]"],
        xlabel: str,
    ) -> None:
        """Plot the energies of all steps, colored by their overlap with the state of interest.

        The data is stored on the widget so that ``fit`` can use it later; any previous fit is reset.

        Args:
            x_list: One x value per step.
            energies_list: For every step, the energies to plot (the y axis is labelled in GHz).
            overlaps_list: For every step, one overlap value per energy in the matching entry of
                ``energies_list``.
            xlabel: Label of the x axis.

        """
        # store data to allow fitting later on
        self.x_list = np.array(x_list)
        self.energies_list = energies_list
        self.overlaps_list = overlaps_list
        self.reset_fit()

        ax = self.canvas.ax
        ax.clear()

        try:
            ax.plot(x_list, np.array(energies_list), c="0.75", lw=0.25, zorder=-10)
        except ValueError as err:
            # steps with different numbers of energies cannot be stacked into one array;
            # fall back to plotting each step as individual points
            if "inhomogeneous shape" not in str(err):
                raise
            for x_value, es in zip(x_list, energies_list):
                ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)

        # np.hstack raises on an empty sequence, so only build the scatter plot if there is data
        if len(energies_list) > 0:
            # Flatten the arrays for scatter plot and repeat x value for each energy
            # (don't use numpy.flatten, etc. to also handle inhomogeneous shapes)
            x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_list, energies_list)])
            energies_flattened = np.hstack(energies_list)
            overlaps_flattened = np.hstack(overlaps_list)

            # only show states with a non-negligible overlap; sort them by ascending overlap so that
            # the states with the largest overlap are drawn last (on top)
            min_overlap = 1e-4
            inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
            inds = inds[np.argsort(overlaps_flattened[inds])]

            if len(inds) > 0:
                ax.scatter(
                    x_repeated[inds],
                    energies_flattened[inds],
                    c=overlaps_flattened[inds],
                    s=15,
                    vmin=0,
                    vmax=1,
                    cmap=alphamagma,
                )

        ax.set_xlabel(xlabel)
        ax.set_ylabel("Energy [GHz]")

        self.canvas.fig.tight_layout()

    def add_cursor(
        self, x_value: list[float], energies: list["NDArray[Any]"], state_labels: dict[int, list[str]]
    ) -> None:
        """Add clickable markers that show a state label in an annotation.

        Args:
            x_value: x values of all steps.
            energies: Energies of all steps.
            state_labels: Maps a step index to labels; the i-th label is attached to
                ``energies[step][i]``.

        """
        # Remove any existing cursors to avoid duplicates
        if hasattr(self, "mpl_cursor"):
            if hasattr(self.mpl_cursor, "remove"):  # type: ignore
                self.mpl_cursor.remove()  # type: ignore
            del self.mpl_cursor  # type: ignore

        ax = self.canvas.ax

        artists = []
        for idx, labels in state_labels.items():
            x = x_value[idx]
            for energy, label in zip(energies[idx], labels):
                artist = ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
                artists.extend(artist)

        self.mpl_cursor = mplcursors.cursor(
            artists,
            hover=False,
            annotation_kwargs={
                "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
                "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
            },
        )

        @self.mpl_cursor.connect("add")
        def on_add(sel: mplcursors.Selection) -> None:
            # show the marker's label, starting a new line before each " + "
            label = sel.artist.get_label()
            sel.annotation.set_text(label.replace(" + ", "\n + "))

    @staticmethod
    def _fit_function(fit_type: str) -> tuple[Callable[..., "NDArray[Any]"], str]:
        """Return the model function and the legend label template for ``fit_type``.

        Raises:
            ValueError: If ``fit_type`` is not one of "c6", "c3" or "c3+c6".

        """
        if fit_type == "c6":
            return (lambda x, e0, c6: e0 + c6 / x**6), "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"
        if fit_type == "c3":
            return (lambda x, e0, c3: e0 + c3 / x**3), "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"
        if fit_type == "c3+c6":
            return (
                lambda x, e0, c3, c6: e0 + c3 / x**3 + c6 / x**6,
                "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6",
            )
        raise ValueError(f"Unknown fit type: {fit_type}")

    def _follow_potential_curve(self, start_rank: int) -> list[int]:
        """Return, for every step, the index of the state that continues the selected potential curve.

        The ordering of energies is just by value, so the curve has to be followed somehow. We walk
        through the steps in their stored order, starting at the state with the ``start_rank``-th
        largest overlap at the first step. We keep our index as long as the overlap changes by less
        than a factor 2 or by less than 0.05 in absolute terms. Otherwise, we search for the closest
        index whose overlap is less than a factor 2 different.
        This is of course a simple heuristic; a more sophisticated approach would do some global
        optimization of the curves. It is simple, fast and robust, but curves may e.g. merge. It does
        not take the line shapes of the curves into account, and there is no trade-off between the
        overlap being close and not doing jumps.

        Args:
            start_rank: 1 selects the largest overlap at the first step, 2 the second largest, etc.
                Must not exceed the number of states at the first step.

        """
        first_overlaps = self.overlaps_list[0]
        idxs = [int(np.argpartition(first_overlaps, -start_rank)[-start_rank])]
        last_overlap = first_overlaps[idxs[-1]]

        for step_overlaps in self.overlaps_list[1:]:
            overlaps = np.asarray(step_overlaps)
            # with inhomogeneous shapes the previous index may not exist at this step
            idx = min(idxs[-1], len(overlaps) - 1)
            overlap = overlaps[idx]

            if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
                # we keep the current index
                idxs.append(idx)
                last_overlap = overlap
                continue

            # we search for states whose overlap is less than a factor 2 different
            possible_options = np.flatnonzero((overlaps > 0.5 * last_overlap) & (overlaps < 2 * last_overlap))
            if len(possible_options) == 0:
                # there is no state in that range - our best bet is to keep the current index
                idxs.append(idx)
                last_overlap = overlap
            else:
                # we select the possible option whose index is closest to the current one
                best_option = int(possible_options[np.argmin(np.abs(possible_options - idx))])
                idxs.append(best_option)
                last_overlap = overlaps[best_option]

        return idxs

    def _remove_fit_artists(self) -> None:
        """Remove the highlighted data, the fitted curve and the fit legend from the axes."""
        if hasattr(self, "fit_data_highlight"):
            self.fit_data_highlight.remove()
            del self.fit_data_highlight
        if hasattr(self, "fit_curve"):
            for curve in self.fit_curve:
                curve.remove()
            # drop the references: removing the same artists a second time raises in matplotlib
            del self.fit_curve
            # the legend only shows the fit label, so it must disappear together with the curve
            legend = self.canvas.ax.get_legend()
            if legend is not None:
                legend.remove()

    def fit(self, fit_type: str = "c6") -> None:
        """Fit a potential curve and display the fit values.

        Args:
            fit_type: Type of fit to perform. Options are:
                c6: E = E0 + C6 / r^6
                c3: E = E0 + C3 / r^3
                c3+c6: E = E0 + C3 / r^3 + C6 / r^6

        Repeated calls with the same fit type cycle through the potential curves. The first call
        selects the curve with the largest overlap at the first step, the next call the second
        largest, and so on. After the last curve, the cycle wraps around to the first again.
        Nothing happens if no data has been plotted yet.

        Raises:
            ValueError: If ``fit_type`` is unknown.

        """
        fit_func, fitlabel = self._fit_function(fit_type)

        # first see if we actually have data to fit
        if not (hasattr(self, "x_list") and hasattr(self, "energies_list") and hasattr(self, "overlaps_list")):
            return
        if len(self.overlaps_list) == 0 or len(self.overlaps_list[0]) == 0:
            return
        n_curves = len(self.overlaps_list[0])

        # select the next potential curve if we use the same fit type, otherwise restart at the first;
        # fit_idx runs through 1..n_curves (it is 0 right after reset_fit)
        if self.fit_type == fit_type:
            self.fit_idx = self.fit_idx % n_curves + 1
        else:
            self.fit_idx = 1
            self.fit_type = fit_type

        idxs = self._follow_potential_curve(self.fit_idx)

        # this could be a call to np.take_along_axis if the sizes match, but plot() also accepts
        # inhomogeneous shapes, so we use a python loop
        energies = np.array([energy[idx] for energy, idx in zip(self.energies_list, idxs)])

        # stop highlighting the previous fit
        self._remove_fit_artists()

        ax = self.canvas.ax
        self.fit_data_highlight = ax.scatter(self.x_list, energies, c="green", s=5)

        try:
            fit_params = curve_fit(fit_func, self.x_list, energies)[0]
        except (RuntimeError, TypeError, ValueError):
            # RuntimeError: no convergence, TypeError: fewer data points than parameters,
            # ValueError: non-finite input data
            logger.warning("Curve fit failed.")
        else:
            self.fit_curve = ax.plot(
                self.x_list,
                fit_func(self.x_list, *fit_params),
                c="green",
                linestyle="dashed",
                lw=2,
                label=fitlabel.format(*fit_params),
            )
            # only list the fit curve; other labelled artists (e.g. the cursor markers) stay out of the legend
            ax.legend(handles=self.fit_curve)

        self.canvas.draw()

    def clear(self) -> None:
        """Clear the axes and reset any fit state."""
        super().clear()
        self.reset_fit()

    def reset_fit(self) -> None:
        """Clear fit output and reset fit index.

        Only the references are dropped here; the artists themselves are expected to be gone already
        because the axes get cleared by the callers (``plot`` and ``clear``).
        """
        # restart at first potential curve
        self.fit_idx = 0
        # and also remove any previous highlighting/fit display
        if hasattr(self, "fit_data_highlight"):
            del self.fit_data_highlight
        if hasattr(self, "fit_curve"):
            del self.fit_curve
|