Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : from typing import TYPE_CHECKING, Any, Callable
7 :
8 1 : import matplotlib as mpl
9 1 : import matplotlib.pyplot as plt
10 1 : import mplcursors
11 1 : import numpy as np
12 1 : from matplotlib.colors import Normalize
13 1 : from PySide6.QtWidgets import QHBoxLayout
14 1 : from scipy.optimize import curve_fit
15 :
16 1 : from pairinteraction.visualization.colormaps import alphamagma
17 1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
18 1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
19 1 : from pairinteraction_gui.qobjects import WidgetV
20 1 : from pairinteraction_gui.theme import plot_widget_theme
21 :
22 : if TYPE_CHECKING:
23 : from collections.abc import Sequence
24 :
25 : from numpy.typing import NDArray
26 :
27 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
28 : from pairinteraction_gui.page import SimulationPage
29 :
30 1 : logger = logging.getLogger(__name__)
31 :
32 :
class PlotWidget(WidgetV):
    """Base widget showing a matplotlib canvas together with its navigation toolbar."""

    margin = (0, 0, 0, 0)
    spacing = 15

    def __init__(self, parent: SimulationPage) -> None:
        """Select the matplotlib backend, remember the parent page and apply the theme."""
        mpl.use("Qt5Agg")

        # The page reference is stored before the base-class initializer runs.
        self.page = parent
        super().__init__(parent)

        self.setStyleSheet(plot_widget_theme)

    def setupWidget(self) -> None:
        """Create the canvas and the toolbar and add both to this widget's layout."""
        self.canvas = MatplotlibCanvas(self)
        self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)

        # Row holding the toolbar; the leading stretch is added before the toolbar itself.
        toolbar_row = QHBoxLayout()
        toolbar_row.addStretch(1)
        toolbar_row.addWidget(self.navigation_toolbar)

        layout = self.layout()
        layout.addLayout(toolbar_row)
        layout.addWidget(self.canvas, stretch=1)

    def clear(self) -> None:
        """Clear the axes and redraw the canvas."""
        self.canvas.ax.clear()
        self.canvas.draw()
62 :
63 :
class PlotEnergies(PlotWidget):
    """Plotwidget for plotting energy levels.

    Besides drawing the energies coloured by their overlap with the state of interest,
    the widget can attach clickable state labels (``add_cursor``) and fit a power law
    to one followed potential curve (``fit``).
    """

    # data of the most recent plot() call, kept so that fit() can be used afterwards
    parameters: Parameters[Any] | None = None
    results: Results | None = None

    # 1-based rank (by overlap at the first step) of the curve used by the last fit; 0 means "no fit yet"
    fit_idx: int = 0
    fit_type: str = ""
    # artists belonging to the current fit; these attributes only exist after fit() created them
    fit_data_highlight: mpl.collections.PathCollection
    fit_curve: Sequence[mpl.lines.Line2D]

    def setupWidget(self) -> None:
        """Set up the base widget and add a colorbar for the overlap colour scale."""
        super().setupWidget()

        mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
        self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
        self.canvas.fig.tight_layout()

    def plot(self, parameters: Parameters[Any], results: Results) -> None:
        """Plot all energies and highlight the states overlapping with the state of interest.

        Args:
            parameters: Provides the x values and the x-axis label.
            results: Provides the energies and the overlaps (one entry per x value each).

        """
        self.reset_fit()
        ax = self.canvas.ax
        ax.clear()

        # store data to allow fitting later on
        self.parameters = parameters
        self.results = results

        x_values = parameters.get_x_values()
        energies = results.energies

        try:
            ax.plot(x_values, np.array(energies), c="0.75", lw=0.25, zorder=-10)
        except ValueError as err:
            if "inhomogeneous shape" in str(err):
                # the number of energies differs between steps, so lines cannot be drawn;
                # fall back to plotting individual markers for every step
                for x_value, es in zip(x_values, energies):
                    ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
            else:
                raise

        # Flatten the arrays for scatter plot and repeat x value for each energy
        # (don't use numpy.flatten, etc. to also handle inhomogeneous shapes)
        x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_values, energies)])
        energies_flattened = np.hstack(energies)
        overlaps_flattened = np.hstack(results.ket_overlaps)

        # only states with a non-negligible overlap are drawn; they are sorted by overlap
        # so that the points with the largest overlap are drawn last (on top)
        min_overlap = 1e-4
        inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
        inds = inds[np.argsort(overlaps_flattened[inds])]

        if len(inds) > 0:
            ax.scatter(
                x_repeated[inds],
                energies_flattened[inds],
                c=overlaps_flattened[inds],
                s=15,
                vmin=0,
                vmax=1,
                cmap=alphamagma,
            )

        ax.set_xlabel(parameters.get_x_label())
        ax.set_ylabel("Energy [GHz]")

        self.canvas.fig.tight_layout()

    def add_cursor(self, parameters: Parameters[Any], results: Results) -> None:
        """Add clickable markers that show the state label of the selected data point.

        Args:
            parameters: Provides the x values.
            results: Provides the energies and the state labels (mapping step index -> labels).

        """
        # Remove any existing cursors to avoid duplicates
        if hasattr(self, "mpl_cursor"):
            if hasattr(self.mpl_cursor, "remove"):  # type: ignore
                self.mpl_cursor.remove()  # type: ignore
            del self.mpl_cursor  # type: ignore

        energies = results.energies
        x_values = parameters.get_x_values()

        # one marker per labelled state; the label is stored on the artist itself
        artists = []
        for idx, labels in results.state_labels.items():
            x = x_values[idx]
            for energy, label in zip(energies[idx], labels):
                artist = self.canvas.ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
                artists.extend(artist)

        self.mpl_cursor = mplcursors.cursor(
            artists,
            hover=False,
            annotation_kwargs={
                "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
                "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
            },
        )

        @self.mpl_cursor.connect("add")
        def on_add(sel: mplcursors.Selection) -> None:
            # show every component of the label on its own line
            label = sel.artist.get_label()
            sel.annotation.set_text(label.replace(" + ", "\n + "))

    def fit(self, fit_type: str = "c6") -> None:  # noqa: PLR0912, PLR0915, C901
        """Fits a potential curve and displays the fit values.

        Args:
            fit_type: Type of fit to perform. Options are:
                c6: E = E0 + C6 / r^6
                c3: E = E0 + C3 / r^3
                c3+c6: E = E0 + C3 / r^3 + C6 / r^6

        Raises:
            ValueError: If ``fit_type`` is not one of the options above.

        Iterative calls with the same fit type will iterate through the potential curves,
        ordered by decreasing overlap at the first step, and wrap around after the last one.

        """
        if self.parameters is None or self.results is None:
            logger.warning("No data to fit.")
            return

        energies = self.results.energies
        x_values = self.parameters.get_x_values()
        overlaps_list = self.results.ket_overlaps

        fit_func: Callable[..., NDArray[Any]]
        if fit_type == "c6":
            fit_func = lambda x, e0, c6: e0 + c6 / x**6  # noqa: E731
            fitlabel = "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"
        elif fit_type == "c3":
            fit_func = lambda x, e0, c3: e0 + c3 / x**3  # noqa: E731
            fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"
        elif fit_type == "c3+c6":
            fit_func = lambda x, e0, c3, c6: e0 + c3 / x**3 + c6 / x**6  # noqa: E731
            fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6"
        else:
            raise ValueError(f"Unknown fit type: {fit_type}")

        # without any step or any state there is no curve that could be followed
        if len(overlaps_list) == 0 or len(overlaps_list[0]) == 0:
            logger.warning("No data to fit.")
            return

        # increase the selected potential curve by one if we use the same fit type.
        # fit_idx is a 1-based rank and has to stay within 1..n_curves: a value of 0 would make
        # "-fit_idx" below select the state with the *smallest* overlap, a value larger than
        # n_curves would be out of bounds for argpartition.
        n_curves = len(overlaps_list[0])
        if self.fit_type == fit_type:
            self.fit_idx = self.fit_idx % n_curves + 1
        else:
            self.fit_idx = 1
            self.fit_type = fit_type

        # We want to follow the potential curves. The ordering of energies is just by value, so we
        # need to follow the curve somehow. We walk through the steps in the order they are stored
        # (NOTE(review): the original comment called this "right to left", i.e. presumably the steps are
        # stored from largest to smallest x - confirm), start at the state with the nth largest overlap, keep
        # our index as long as the difference in overlap is less than a factor 2 or less than 5% total difference.
        # Otherwise, we search until we find an overlap that is less than a factor 2 different.
        # This is of course a simple heuristic, a more sophisticated approach would do some global optimization
        # of the curves. This approach is simple, fast and robust, but curves may e.g. merge.
        # This does not at all take into account the line shapes of the curves. There is also no trade-off
        # between overlap being close and not doing jumps.
        idxs = [np.argpartition(overlaps_list[0], -self.fit_idx)[-self.fit_idx]]
        last_overlap = overlaps_list[0][idxs[-1]]
        for overlaps in overlaps_list[1:]:
            overlaps = np.asarray(overlaps)  # the element-wise comparisons below need an array
            idx = idxs[-1]
            overlap = overlaps[idx]
            if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
                # we keep the current index
                idxs.append(idx)
                last_overlap = overlap
            else:
                # we search until we find an overlap that is less than a factor 2 different
                possible_options = np.argwhere(
                    np.logical_and(overlaps > 0.5 * last_overlap, overlaps < 2 * last_overlap)
                ).flatten()
                if len(possible_options) == 0:
                    # there is no state in that range - our best bet is to keep the current index
                    idxs.append(idx)
                    last_overlap = overlap
                else:
                    # we select the closest possible option
                    best_option = np.argmin(np.abs(possible_options - idx))
                    idxs.append(possible_options[best_option])
                    last_overlap = overlaps[idxs[-1]]

        # this could be a call to np.take_along_axis if the sizes match, but the handling of inhomogeneous shapes
        # in the plot() function makes me worry they won't, so I go for a slower python for loop...
        energies_fit = np.array([energy[idx] for energy, idx in zip(energies, idxs)])

        # stop highlighting the previous fit. The attributes are deleted as well, so that a failing
        # curve fit below cannot leave already removed artists behind (removing them a second time
        # on the next call would raise).
        if hasattr(self, "fit_data_highlight"):
            self.fit_data_highlight.remove()
            del self.fit_data_highlight
        if hasattr(self, "fit_curve"):
            for curve in self.fit_curve:
                curve.remove()
            del self.fit_curve

        self.fit_data_highlight = self.canvas.ax.scatter(x_values, energies_fit, c="green", s=5)

        # the fit functions use "x**n", which needs an array (a plain list would raise a TypeError)
        x_fit = np.asarray(x_values, dtype=float)

        try:
            fit_params = curve_fit(fit_func, x_fit, energies_fit)[0]
        except (RuntimeError, TypeError, ValueError):
            # RuntimeError: no convergence, TypeError: fewer data points than fit parameters,
            # ValueError: non-finite values in the data
            logger.warning("Curve fit failed.")
            # do not keep showing the legend of a previous (now removed) fit curve
            legend = self.canvas.ax.get_legend()
            if legend is not None:
                legend.remove()
        else:
            self.fit_curve = self.canvas.ax.plot(
                x_values,
                fit_func(x_fit, *fit_params),
                c="green",
                linestyle="dashed",
                lw=2,
                label=fitlabel.format(*fit_params),
            )
            self.canvas.ax.legend()

        self.canvas.draw()

    def clear(self) -> None:
        """Clear the plot and forget any previous fit."""
        super().clear()
        self.reset_fit()

    def reset_fit(self) -> None:
        """Clear fit output and reset fit index."""
        # restart at first potential curve
        self.fit_idx = 0
        # and also forget any previous highlighting/fit display. Only the references are dropped here;
        # both callers in this class clear the axes themselves, which removes the artists.
        if hasattr(self, "fit_data_highlight"):
            del self.fit_data_highlight
        if hasattr(self, "fit_curve"):
            del self.fit_curve
|