Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : from typing import TYPE_CHECKING, Any, Callable
7 :
8 1 : import matplotlib as mpl
9 1 : import matplotlib.pyplot as plt
10 1 : import mplcursors
11 1 : import numpy as np
12 1 : from matplotlib.colors import Normalize
13 1 : from PySide6.QtWidgets import QHBoxLayout
14 1 : from scipy.optimize import curve_fit
15 :
16 1 : from pairinteraction.visualization.colormaps import alphamagma
17 1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
18 1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
19 1 : from pairinteraction_gui.qobjects import WidgetV
20 1 : from pairinteraction_gui.theme import plot_widget_theme
21 :
22 : if TYPE_CHECKING:
23 : from collections.abc import Sequence
24 :
25 : from numpy.typing import NDArray
26 : from typing_extensions import Concatenate
27 :
28 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
29 : from pairinteraction_gui.page import SimulationPage
30 :
# Module-level logger; fit() reports missing data and failed fits through it.
logger = logging.getLogger(name=__name__)
32 :
33 :
class PlotWidget(WidgetV):
    """Base widget hosting a matplotlib canvas with a right-aligned navigation toolbar above it."""

    margin = (0, 0, 0, 0)
    spacing = 15

    def __init__(self, parent: SimulationPage) -> None:
        """Select the matplotlib Qt backend, remember the owning page and apply the widget theme."""
        mpl.use("Qt5Agg")

        # NOTE(review): assigned before super().__init__(), presumably because the base class
        # builds the UI during construction and may need the page -- keep this order.
        self.page = parent
        super().__init__(parent)

        self.setStyleSheet(plot_widget_theme)

    def setupWidget(self) -> None:
        """Create the canvas and its toolbar and add them to this widget's layout."""
        self.canvas = MatplotlibCanvas(self)
        self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)

        # the stretch pushes the toolbar to the right edge of its row
        toolbar_row = QHBoxLayout()
        toolbar_row.addStretch(1)
        toolbar_row.addWidget(self.navigation_toolbar)

        main_layout = self.layout()
        main_layout.addLayout(toolbar_row)
        main_layout.addWidget(self.canvas, stretch=1)

    def clear(self) -> None:
        """Remove everything from the axes and redraw the empty canvas."""
        axes = self.canvas.ax
        axes.clear()
        self.canvas.draw()
63 :
64 :
class PlotEnergies(PlotWidget):
    """Plotwidget for plotting energy levels versus the scanned parameter, with optional curve fits."""

    # data of the last plot() call, stored so that fit() can use it later on
    parameters: Parameters[Any] | None = None
    results: Results | None = None

    # interactive fitting state: 1-based rank (by overlap at the first x value) of the potential
    # curve fitted last (0 means "start again at the first curve") and the fit type that was used
    fit_idx: int = 0
    fit_type: str = ""
    # artists of the currently displayed fit; these attributes only exist while a fit is shown
    fit_data_highlight: mpl.collections.PathCollection
    fit_curve: Sequence[mpl.lines.Line2D]

    def setupWidget(self) -> None:
        """Build the canvas and attach a fixed 0..1 colorbar explaining the overlap colouring."""
        super().setupWidget()

        mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
        self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
        self.canvas.fig.tight_layout()

    def plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Draw all energies in grey and colour states by their overlap with the state of interest.

        Args:
            parameters: Provides the x values and the x axis label.
            results: Provides ``energies`` and ``ket_overlaps`` (one entry per x value).

        """
        self.reset_fit()
        ax = self.canvas.ax
        ax.clear()

        # store data to allow fitting later on
        self.parameters = parameters
        self.results = results

        x_values = parameters.get_x_values()
        energies = results.energies

        try:
            ax.plot(x_values, np.array(energies), c="0.75", lw=0.25, zorder=-10)
        except ValueError as err:
            # the number of energies differs between x values -> cannot draw connected lines
            if "inhomogeneous shape" in str(err):
                for x_value, es in zip(x_values, energies):
                    ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
            else:
                raise

        # Flatten the arrays for scatter plot and repeat x value for each energy
        # (don't use numpy.flatten, etc. to also handle inhomogeneous shapes)
        x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_values, energies)])
        energies_flattened = np.hstack(energies)
        overlaps_flattened = np.hstack(results.ket_overlaps)

        # only show states with a noticeable overlap; sort ascending so the largest overlaps are drawn last
        min_overlap = 1e-4
        inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
        inds = inds[np.argsort(overlaps_flattened[inds])]

        if len(inds) > 0:
            ax.scatter(
                x_repeated[inds],
                energies_flattened[inds],
                c=overlaps_flattened[inds],
                s=15,
                vmin=0,
                vmax=1,
                cmap=alphamagma,
            )

        ax.set_xlabel(parameters.get_x_label())
        ax.set_ylabel("Energy [GHz]")

        self.canvas.fig.tight_layout()

    def add_cursor(self, parameters: Parameters[Any], results: Results) -> None:
        """Add clickable markers at the labelled states that show the state label as an annotation.

        Args:
            parameters: Provides the x values.
            results: Provides ``energies`` and ``state_labels`` (a mapping from x index to labels).

        """
        # Remove any existing cursors to avoid duplicates
        if hasattr(self, "mpl_cursor"):
            if hasattr(self.mpl_cursor, "remove"):  # type: ignore
                self.mpl_cursor.remove()  # type: ignore
            del self.mpl_cursor  # type: ignore

        energies = results.energies
        x_values = parameters.get_x_values()

        # faint markers drawn behind everything else; they only serve as pick targets for the cursor
        artists = []
        for idx, labels in results.state_labels.items():
            x = x_values[idx]
            for energy, label in zip(energies[idx], labels):
                artist = self.canvas.ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
                artists.extend(artist)

        self.mpl_cursor = mplcursors.cursor(
            artists,
            hover=False,
            annotation_kwargs={
                "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
                "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
            },
        )

        @self.mpl_cursor.connect("add")
        def on_add(sel: mplcursors.Selection) -> None:
            # put every " + "-separated component of the label on its own line
            label = sel.artist.get_label()
            sel.annotation.set_text(label.replace(" + ", "\n + "))

    def fit(self, fit_type: str = "c6") -> None:
        """Fit a potential curve and display the fit values.

        Args:
            fit_type: Type of fit to perform. Options are:
                c6: E = E0 + C6 / r^6
                c3: E = E0 + C3 / r^3
                c3+c6: E = E0 + C3 / r^3 + C6 / r^6

        Repeated calls with the same ``fit_type`` step through the potential curves, ordered by their
        overlap with the state of interest at the first x value (largest first), and wrap around after
        the curve with the smallest overlap. Switching the fit type restarts at the first curve.

        Raises:
            ValueError: If ``fit_type`` is not one of the options above.

        """
        if self.parameters is None or self.results is None:
            logger.warning("No data to fit.")
            return

        fit_options: dict[str, tuple[Callable[Concatenate[NDArray[Any], ...], NDArray[Any]], str]] = {
            "c6": (fit_c6, "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"),
            "c3": (fit_c3, "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"),
            "c3+c6": (fit_c3_c6, "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6"),
        }
        if fit_type not in fit_options:
            raise ValueError(f"Unknown fit type: {fit_type}")
        fit_func, fitlabel = fit_options[fit_type]

        energies = self.results.energies
        x_values = np.array(self.parameters.get_x_values())
        overlaps_list = self.results.ket_overlaps
        if len(overlaps_list) == 0 or len(overlaps_list[0]) == 0:
            logger.warning("No data to fit.")
            return

        # Select the next potential curve if we use the same fit type. The rank is 1-based and bounded
        # by the number of states at the first x value (the step where the curve is chosen), so it can
        # never exceed the valid range of np.argpartition.
        n_states = len(overlaps_list[0])
        if self.fit_type == fit_type:
            self.fit_idx = self.fit_idx % n_states + 1
        else:
            self.fit_idx = 1
            self.fit_type = fit_type

        idxs = self._follow_potential_curve(overlaps_list, self.fit_idx)

        # this could be a call to np.take_along_axis if the sizes match, but the handling of inhomogeneous shapes
        # in the plot() function makes me worry they won't, so I go for a slower python for loop...
        energies_fit = np.array([energy[idx] for energy, idx in zip(energies, idxs)])

        # stop highlighting the previous fit
        self._remove_fit_artists()

        ax = self.canvas.ax
        self.fit_data_highlight = ax.scatter(x_values, energies_fit, c="green", s=5)

        try:
            fit_params = curve_fit(fit_func, x_values, energies_fit)[0]
        except (RuntimeError, TypeError, ValueError):
            # RuntimeError: no convergence; TypeError: fewer points than parameters;
            # ValueError: e.g. non-finite input data
            logger.warning("Curve fit failed.")
        else:
            self.fit_curve = ax.plot(
                x_values,
                fit_func(x_values, *fit_params),
                c="green",
                linestyle="dashed",
                lw=2,
                label=fitlabel.format(*fit_params),
            )
            # only the fit belongs in the legend, not e.g. the labelled markers added by add_cursor()
            ax.legend(handles=self.fit_curve)

        self.canvas.draw()

    @staticmethod
    def _follow_potential_curve(overlaps_list: Sequence[NDArray[Any]], rank: int) -> list[int]:
        """Return, for every x value, the state index belonging to one potential curve.

        The curve starts at the state with the ``rank``-th largest overlap (1-based) at the first x value.
        """
        # We want to follow the potential curves. The ordering of energies is just by value, so we
        # need to follow the curve somehow. Starting at the first x value, we keep our index as long
        # as the overlap changes by less than a factor 2 or by less than 0.05 in absolute terms.
        # Otherwise, we search for the closest index whose overlap is less than a factor 2 different.
        # This is of course a simple heuristic, a more sophisticated approach would do some global optimization
        # of the curves. This approach is simple, fast and robust, but curves may e.g. merge.
        # This does not at all take into account the line shapes of the curves. There is also no trade-off
        # between overlap being close and not doing jumps.
        idxs = [int(np.argpartition(overlaps_list[0], -rank)[-rank])]
        last_overlap = overlaps_list[0][idxs[-1]]
        for overlaps in overlaps_list[1:]:
            idx = idxs[-1]
            overlap = overlaps[idx]
            if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
                # we keep the current index
                idxs.append(idx)
                last_overlap = overlap
                continue

            # we search for an overlap that is less than a factor 2 different
            possible_options = np.argwhere(
                np.logical_and(overlaps > 0.5 * last_overlap, overlaps < 2 * last_overlap)
            ).flatten()
            if len(possible_options) == 0:
                # there is no state in that range - our best bet is to keep the current index
                idxs.append(idx)
                last_overlap = overlap
            else:
                # we select the closest possible option
                best_option = int(possible_options[np.argmin(np.abs(possible_options - idx))])
                idxs.append(best_option)
                last_overlap = overlaps[best_option]
        return idxs

    def _remove_fit_artists(self) -> None:
        """Remove the highlighted fit data, the fitted curve and its legend from the axes, if shown."""
        if hasattr(self, "fit_data_highlight"):
            self.fit_data_highlight.remove()
            del self.fit_data_highlight
        if hasattr(self, "fit_curve"):
            for curve in self.fit_curve:
                curve.remove()
            # drop the references so a later call does not try to remove the same artists twice
            del self.fit_curve
            legend = self.canvas.ax.get_legend()
            if legend is not None:
                legend.remove()

    def clear(self) -> None:
        """Clear the axes and reset the fitting state."""
        super().clear()
        self.reset_fit()

    def reset_fit(self) -> None:
        """Clear fit output and reset fit index.

        Only the references to the fit artists are dropped here; the callers in this class clear the
        axes themselves (``plot`` right afterwards, ``clear`` right before), which removes the artists.
        """
        # restart at first potential curve
        self.fit_idx = 0
        # and also forget any previous highlighting/fit display
        if hasattr(self, "fit_data_highlight"):
            del self.fit_data_highlight
        if hasattr(self, "fit_curve"):
            del self.fit_curve
276 :
277 :
def fit_c3(x: NDArray[Any], /, e0: float, c3: float) -> NDArray[Any]:
    """Model function E(x) = E0 + C3 / x^3 (suitable for scipy's ``curve_fit``)."""
    inverse_cubic_term = c3 / x**3
    return e0 + inverse_cubic_term
280 :
281 :
def fit_c6(x: NDArray[Any], /, e0: float, c6: float) -> NDArray[Any]:
    """Model function E(x) = E0 + C6 / x^6 (suitable for scipy's ``curve_fit``)."""
    inverse_sixth_term = c6 / x**6
    return e0 + inverse_sixth_term
284 :
285 :
def fit_c3_c6(x: NDArray[Any], /, e0: float, c3: float, c6: float) -> NDArray[Any]:
    """Model function E(x) = E0 + C3 / x^3 + C6 / x^6 (suitable for scipy's ``curve_fit``)."""
    inverse_cubic_term = c3 / x**3
    inverse_sixth_term = c6 / x**6
    return e0 + inverse_cubic_term + inverse_sixth_term
|