Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import contextlib
6 1 : import logging
7 1 : from typing import TYPE_CHECKING, Any
8 :
9 1 : import matplotlib as mpl
10 1 : import matplotlib.pyplot as plt
11 1 : import numpy as np
12 1 : from matplotlib.colors import Normalize
13 1 : from PySide6.QtGui import QPalette
14 1 : from PySide6.QtWidgets import QHBoxLayout
15 1 : from scipy.optimize import curve_fit
16 :
17 1 : from pairinteraction.state.state_atom import StateAtom
18 1 : from pairinteraction.visualization.colormaps import alphamagma
19 1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
20 1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
21 1 : from pairinteraction_gui.qobjects import WidgetV
22 1 : from pairinteraction_gui.qobjects.events import show_status_tip
23 1 : from pairinteraction_gui.theme import theme_manager
24 :
25 : if TYPE_CHECKING:
26 : from collections.abc import Callable, Sequence
27 : from typing import Concatenate
28 :
29 : from numpy.typing import NDArray
30 :
31 : from pairinteraction.state import StateBase
32 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
33 : from pairinteraction_gui.calculate.calculate_lifetimes import KetData, ParametersLifetimes, ResultsLifetimes
34 : from pairinteraction_gui.page import SimulationPage
35 :
36 1 : logger = logging.getLogger(__name__)
37 :
38 :
39 1 : class PlotWidget(WidgetV):
40 : """Widget for displaying plots with controls."""
41 :
42 1 : margin = (0, 0, 0, 0)
43 1 : spacing = 15
44 1 : _annotations: dict[Any, mpl.text.Annotation]
45 :
46 1 : def __init__(self, parent: SimulationPage) -> None:
47 : """Initialize the base section."""
48 1 : mpl.use("Qt5Agg")
49 1 : self.page = parent
50 1 : super().__init__(parent)
51 :
52 1 : self._annotations = {}
53 1 : self._click_cid: int | None = None
54 :
55 1 : def setupWidget(self) -> None:
56 1 : self.canvas = MatplotlibCanvas(self)
57 1 : self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)
58 1 : self.navigation_toolbar.setObjectName("PlotNavigationToolBar")
59 :
60 1 : top_layout = QHBoxLayout()
61 1 : top_layout.addStretch(1)
62 1 : top_layout.addWidget(self.navigation_toolbar)
63 1 : self.layout().addLayout(top_layout)
64 :
65 1 : self.layout().addWidget(self.canvas, stretch=1)
66 :
67 1 : def clear(self) -> None:
68 1 : self.canvas.ax.clear()
69 1 : self.canvas.draw_idle()
70 1 : self.clear_annotations()
71 1 : self.disconnect_click()
72 :
73 1 : def clear_annotations(self) -> None:
74 1 : for ann in self._annotations.values():
75 0 : with contextlib.suppress(NotImplementedError):
76 0 : ann.remove() # artist may already be gone if ax.clear() was called
77 1 : self._annotations.clear()
78 1 : self.canvas.draw_idle()
79 :
80 1 : def disconnect_click(self) -> None:
81 1 : if self._click_cid is not None:
82 0 : self.canvas.mpl_disconnect(self._click_cid)
83 0 : self._click_cid = None
84 :
85 :
86 1 : class PlotEnergies(PlotWidget):
87 : """Plotwidget for plotting energy levels."""
88 :
89 1 : parameters: Parameters[Any] | None = None
90 1 : results: Results | None = None
91 1 : _annotations: dict[int, mpl.text.Annotation]
92 :
93 1 : def __init__(self, parent: SimulationPage) -> None:
94 1 : super().__init__(parent)
95 :
96 1 : self.fit_idx = 0
97 1 : self.fit_type = ""
98 1 : self.fit_data_highlight: mpl.collections.PathCollection | None = None
99 1 : self.fit_curve: Sequence[mpl.lines.Line2D] | None = None
100 :
101 1 : def setupWidget(self) -> None:
102 1 : super().setupWidget()
103 :
104 1 : window_color = theme_manager.get_palette().color(QPalette.ColorRole.Window).name()
105 :
106 1 : self.canvas.fig.set_facecolor(window_color)
107 1 : self.canvas.fig.set_layout_engine(
108 : "constrained",
109 : w_pad=0.2,
110 : h_pad=0.2,
111 : wspace=0.0,
112 : hspace=0.0,
113 : )
114 1 : mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
115 1 : cbar = self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest", aspect=60)
116 1 : cbar.ax.set_zorder(0)
117 1 : self.canvas.ax.set_zorder(1)
118 :
119 1 : def plot(self, parameters: Parameters[Any], results: Results) -> None:
120 1 : self.clear()
121 :
122 1 : show_status_tip(self, "Plotting energy curves...")
123 1 : ax = self.canvas.ax
124 1 : ax.set_xmargin(0)
125 :
126 : # store data to allow fitting later on
127 1 : self.parameters = parameters
128 1 : self.results = results
129 :
130 1 : x_values = parameters.get_x_values()
131 1 : energies = results.energies
132 :
133 1 : if len({len(es) for es in energies}) <= 1: # check if homogeneous shape
134 1 : ax.plot(x_values, np.array(energies), c="0.75", lw=0.25, zorder=-10)
135 : else:
136 0 : for x_value, es in zip(x_values, energies, strict=True):
137 0 : ax.plot([x_value] * len(es), es, c="0.75", ls="None", marker=".", zorder=-10)
138 :
139 1 : show_status_tip(self, "Plotting overlaps...")
140 :
141 : # Flatten the arrays for scatter plot and repeat x value for each energy
142 : # (dont use numpy.flatten, etc. to also handle inhomogeneous shapes)
143 1 : x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_values, energies, strict=True)])
144 1 : energies_flattened = np.hstack(energies)
145 1 : overlaps_flattened = np.hstack(results.ket_overlaps)
146 :
147 1 : min_overlap = 1e-4
148 1 : inds: NDArray[Any] = np.argwhere(overlaps_flattened > min_overlap).flatten()
149 1 : inds = inds[np.argsort(overlaps_flattened[inds])]
150 :
151 1 : if len(inds) > 0:
152 1 : ax.scatter(
153 : x_repeated[inds],
154 : energies_flattened[inds],
155 : c=overlaps_flattened[inds],
156 : s=15,
157 : vmin=0,
158 : vmax=1,
159 : cmap=alphamagma,
160 : )
161 :
162 1 : ax.set_xlabel(parameters.get_x_label())
163 1 : ax.set_ylabel("Energy (GHz)")
164 :
165 1 : def setup_annotations(self, parameters: Parameters[Any], results: Results) -> None:
166 : """Connect click-based state annotation to the energy plot."""
167 1 : energies = results.energies
168 1 : overlaps = results.ket_overlaps
169 1 : x_values = parameters.get_x_values()
170 :
171 1 : self._point_index_map: list[tuple[int, int]] = []
172 1 : all_x: list[float] = []
173 1 : all_y: list[float] = []
174 1 : all_overlaps: list[float] = []
175 1 : for idx in range(len(energies)):
176 1 : x = x_values[idx]
177 1 : for idstate, (energy, overlap) in enumerate(zip(energies[idx], overlaps[idx], strict=True)):
178 1 : all_x.append(x)
179 1 : all_y.append(float(energy))
180 1 : all_overlaps.append(float(overlap))
181 1 : self._point_index_map.append((idx, idstate))
182 1 : pts_data = np.column_stack([all_x, all_y]) if all_x else np.empty((0, 2))
183 1 : pts_overlaps = np.array(all_overlaps)
184 :
185 1 : def on_click(event: mpl.backend_bases.MouseEvent) -> None:
186 0 : if (
187 : event.inaxes is not self.canvas.ax
188 : or event.button not in [1, 3]
189 : or len(pts_data) == 0
190 : or self.navigation_toolbar.mode
191 : ):
192 0 : return
193 0 : if event.button == 3: # right click clears annotations
194 0 : self.clear_annotations()
195 0 : return
196 :
197 0 : pts_pos = self.canvas.ax.transData.transform(pts_data)
198 0 : click_pos = np.array([event.x, event.y])
199 0 : dists = np.hypot(pts_pos[:, 0] - click_pos[0], pts_pos[:, 1] - click_pos[1])
200 0 : candidates = np.flatnonzero(dists <= 10) # threshold in pixels
201 0 : if len(candidates) == 0:
202 0 : self.clear_annotations()
203 0 : return
204 0 : selected = int(candidates[np.argmax(pts_overlaps[candidates])])
205 0 : if selected in self._annotations:
206 0 : self._annotations[selected].remove()
207 0 : del self._annotations[selected]
208 0 : self.canvas.draw_idle()
209 0 : return
210 0 : idstep, idstate = self._point_index_map[selected]
211 0 : state: StateBase[Any] = results.systems[idstep].get_eigenbasis().get_state(idstate)
212 0 : label = state.get_label().replace(" + ", "\n + ").replace("+ -", " - ")
213 0 : xlim = self.canvas.ax.get_xlim()
214 0 : ylim = self.canvas.ax.get_ylim()
215 0 : x_frac = (pts_data[selected, 0] - xlim[0]) / (xlim[1] - xlim[0])
216 0 : y_frac = (pts_data[selected, 1] - ylim[0]) / (ylim[1] - ylim[0])
217 0 : x_offset = -100 if isinstance(state, StateAtom) else -250
218 0 : x_offset = x_offset if x_frac > 0.5 else 0
219 0 : y_offset = 15 + 10 * label.count("\n")
220 0 : y_offset = -y_offset if y_frac > 0.5 else y_offset
221 0 : ann = self.canvas.ax.annotate(
222 : label,
223 : xy=(pts_data[selected, 0], pts_data[selected, 1]),
224 : xytext=(x_offset, y_offset),
225 : textcoords="offset points",
226 : va="center",
227 : bbox={"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
228 : arrowprops={"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
229 : clip_on=False,
230 : )
231 0 : ann.set_in_layout(False)
232 0 : self._annotations[selected] = ann
233 0 : self.canvas.draw_idle()
234 :
235 1 : self._click_cid = self.canvas.mpl_connect("button_press_event", on_click) # type: ignore [arg-type]
236 1 : self.navigation_toolbar._home_callbacks = [self.clear_annotations]
237 :
238 1 : def fit(self, fit_type: str = "c6") -> None: # noqa: PLR0912, PLR0915, C901
239 : """Fits a potential curve and displays the fit values.
240 :
241 : Args:
242 : fit_type: Type of fit to perform. Options are:
243 : c6: E = E0 + C6 * r^6
244 : c3: E = E0 + C3 * r^3
245 : c3+c6: E = E0 + C3 * r^3 + C6 * r^6
246 :
247 : Iterative calls will iterate through the potential curves
248 :
249 : """
250 1 : if self.parameters is None or self.results is None:
251 0 : logger.warning("No data to fit.")
252 0 : return
253 :
254 1 : energies = self.results.energies
255 1 : x_values = np.array(self.parameters.get_x_values())
256 1 : overlaps_list = self.results.ket_overlaps
257 :
258 : fit_func: Callable[Concatenate[NDArray[Any], ...], NDArray[Any]]
259 1 : if fit_type == "c6":
260 1 : fit_func = fit_c6
261 1 : fitlabel = "E0 = {0:.3f} GHz\nC6 = {1:.3f} GHz*µm^6"
262 1 : elif fit_type == "c3":
263 1 : fit_func = fit_c3
264 1 : fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3"
265 1 : elif fit_type == "c3+c6":
266 1 : fit_func = fit_c3_c6
267 1 : fitlabel = "E0 = {0:.3f} GHz\nC3 = {1:.3f} GHz*µm^3\nC6 = {2:.3f} GHz*µm^6"
268 : else:
269 0 : raise ValueError(f"Unknown fit type: {fit_type}")
270 :
271 : # increase the selected potential curve by one if we use the same fit type
272 1 : if self.fit_type == fit_type:
273 1 : self.fit_idx = (self.fit_idx + 1) % len(energies)
274 : else:
275 1 : self.fit_idx = 1
276 1 : self.fit_type = fit_type
277 :
278 : # We want to follow the potential curves. The ordering of energies is just by value, so we
279 : # need to follow the curve somehow. We go right to left, start at the nth largest value, keep our
280 : # index as long as the difference in overlap is less than a factor 2 or less than 5% total difference.
281 : # Otherwise, we search until we find an overlap that is less than a factor 2 different.
282 : # This is of course a simple heuristic, a more sophisticated approach would do some global optimization
283 : # of the curves. This approach is simple, fast and robust, but curves may e.g. merge.
284 : # This does not at all take into account the line shapes of the curves. There is also no trade-off
285 : # between overlap being close and not doing jumps.
286 1 : idxs = [np.argpartition(overlaps_list[0], -self.fit_idx)[-self.fit_idx]]
287 1 : last_overlap = overlaps_list[0][idxs[-1]]
288 1 : for overlaps in overlaps_list[1:]:
289 1 : idx = idxs[-1]
290 1 : overlap = overlaps[idx]
291 1 : if 0.5 * last_overlap < overlap < 2 * last_overlap or abs(overlap - last_overlap) < 0.05:
292 : # we keep the current index
293 1 : idxs.append(idx)
294 1 : last_overlap = overlap
295 : else:
296 : # we search until we find an overlap that is less than a factor 2 different
297 1 : possible_options = np.argwhere(
298 : np.logical_and(overlaps > 0.5 * last_overlap, overlaps < 2 * last_overlap)
299 : ).flatten()
300 1 : if len(possible_options) == 0:
301 : # there is no state in that range - our best bet is to keep the current index
302 1 : idxs.append(idx)
303 1 : last_overlap = overlap
304 : else:
305 : # we select the closest possible option
306 1 : best_option = np.argmin(np.abs(possible_options - idx))
307 1 : idxs.append(possible_options[best_option])
308 1 : last_overlap = overlaps[idxs[-1]]
309 :
310 : # this could be a call to np.take_along_axis if the sizes match, but the handling of inhomogeneous shapes
311 : # in the plot() function makes me worry they won't, so I go for a slower python for loop...
312 1 : energies_fit = np.array([energy[idx] for energy, idx in zip(energies, idxs, strict=True)])
313 :
314 : # stop highlighting the previous fit
315 1 : if self.fit_data_highlight is not None:
316 1 : self.fit_data_highlight.remove()
317 1 : if self.fit_curve is not None:
318 1 : for curve in self.fit_curve:
319 1 : curve.remove()
320 :
321 1 : self.fit_data_highlight = self.canvas.ax.scatter(x_values, energies_fit, c="green", s=5)
322 :
323 1 : try:
324 1 : fit_params = curve_fit(fit_func, x_values, energies_fit)[0]
325 0 : except (RuntimeError, TypeError):
326 0 : logger.warning("Curve fit failed.")
327 : else:
328 1 : self.fit_curve = self.canvas.ax.plot(
329 : x_values,
330 : fit_func(x_values, *fit_params),
331 : c="green",
332 : linestyle="dashed",
333 : lw=2,
334 : label=fitlabel.format(*fit_params),
335 : )
336 1 : self.canvas.ax.legend()
337 :
338 1 : self.canvas.draw_idle()
339 :
340 1 : def clear(self) -> None:
341 1 : super().clear()
342 1 : self.reset_fit()
343 :
344 1 : def reset_fit(self) -> None:
345 : """Clear fit output and reset fit index."""
346 : # restart at first potential curve
347 1 : self.fit_idx = 0
348 : # and also remove any previous highlighting/fit display
349 1 : self.fit_data_highlight = None
350 1 : self.fit_curve = None
351 :
352 :
353 1 : class PlotLifetimes(PlotWidget):
354 : """Plotwidget for plotting lifetime/transition rate bar charts."""
355 :
356 1 : _annotations: dict[tuple[str, int], mpl.text.Annotation]
357 :
358 1 : def setupWidget(self) -> None:
359 1 : super().setupWidget()
360 :
361 1 : window_color = theme_manager.get_palette().color(QPalette.ColorRole.Window).name()
362 1 : self.canvas.fig.set_facecolor(window_color)
363 1 : self.canvas.fig.set_layout_engine(
364 : "constrained",
365 : w_pad=0.2,
366 : h_pad=0.2,
367 : wspace=0.0,
368 : hspace=0.0,
369 : )
370 :
371 1 : def plot(self, parameters: ParametersLifetimes, results: ResultsLifetimes) -> None:
372 1 : self.clear()
373 1 : ax = self.canvas.ax
374 :
375 1 : show_status_tip(self, "Preparing transition rates...")
376 1 : labels = ["Spontaneous Decay", "Black Body Radiation"]
377 1 : n_list = np.arange(0, np.max([s.n for s in results.kets_bbr + results.kets_sp] + [0]) + 1)
378 1 : sorted_rates: dict[str, dict[int, list[tuple[KetData, float]]]] = {}
379 1 : for key, kets, rates in [
380 : (labels[0], results.kets_sp, results.transition_rates_sp),
381 : (labels[1], results.kets_bbr, results.transition_rates_bbr),
382 : ]:
383 1 : sorted_rates[key] = {n: [] for n in n_list}
384 1 : for i, s in enumerate(kets):
385 1 : sorted_rates[key][s.n].append((s, rates[i]))
386 1 : self.sorted_rates = sorted_rates
387 1 : rates_summed = {key: [sum(r for _, r in sorted_rates[key][n]) for n in n_list] for key in sorted_rates}
388 :
389 1 : show_status_tip(self, "Plotting transition rates...")
390 1 : self.artists: list[mpl.container.BarContainer] = []
391 1 : for label, color in zip(labels, ["blue", "red"], strict=True):
392 1 : bar = ax.bar(n_list, rates_summed[label], label=label, color=color, alpha=0.8)
393 1 : self.artists.append(bar)
394 1 : ax.legend()
395 :
396 1 : ax.set_xlabel("Principal Quantum Number $n$")
397 1 : ax.set_ylabel(r"Transition Rates (1 / ms)")
398 :
399 1 : def setup_annotations(self, parameters: ParametersLifetimes, results: ResultsLifetimes) -> None: # noqa: C901
400 : """Add click-based annotations to the plot."""
401 0 : show_status_tip(self, "Adding transition rate annotations...")
402 :
403 0 : self._bar_data: list[tuple[str, int, mpl.patches.Rectangle]] = []
404 0 : for container in reversed(self.artists):
405 0 : label = container.get_label()
406 0 : if label is None:
407 0 : continue
408 0 : for rect in container.patches:
409 0 : n = round(rect.get_x() + rect.get_width() / 2)
410 0 : self._bar_data.append((label, n, rect))
411 :
412 0 : def on_click(event: mpl.backend_bases.MouseEvent) -> None:
413 0 : if event.inaxes is not self.canvas.ax or event.button not in [1, 3] or self.navigation_toolbar.mode:
414 0 : return
415 0 : if event.button == 3: # right click clears annotations
416 0 : self.clear_annotations()
417 0 : return
418 :
419 0 : if event.xdata is None or event.ydata is None:
420 0 : return
421 0 : click_coords = np.array([event.xdata, event.ydata])
422 :
423 0 : for _label, _n, rect in self._bar_data:
424 0 : x, y = rect.get_x(), rect.get_y()
425 0 : if x <= click_coords[0] <= x + rect.get_width() and y <= click_coords[1] <= y + rect.get_height():
426 0 : label, n = _label, _n
427 0 : break
428 : else: # no break -> no bar found
429 0 : self.clear_annotations()
430 0 : return
431 :
432 : # if we click the same bar again, remove the annotation
433 0 : if (label, n) in self._annotations:
434 0 : self._annotations[(label, n)].remove()
435 0 : del self._annotations[(label, n)]
436 0 : self.canvas.draw_idle()
437 0 : return
438 :
439 0 : state_text = "\n".join(f" - {s}: {r:.5f}/ms" for (s, r) in self.sorted_rates[label][n])
440 0 : text = f"{label} to n={n}:\n{state_text}"
441 0 : xlim = self.canvas.ax.get_xlim()
442 0 : ylim = self.canvas.ax.get_ylim()
443 0 : bar_cx = rect.get_x() + rect.get_width() / 2
444 0 : bar_top = rect.get_y() + rect.get_height()
445 0 : x_frac = (bar_cx - xlim[0]) / (xlim[1] - xlim[0])
446 0 : y_frac = (bar_top - ylim[0]) / (ylim[1] - ylim[0])
447 0 : x_offset = -200 if x_frac > 0.5 else 25
448 0 : y_offset = -50 if y_frac > 0.5 else 50
449 0 : ann = self.canvas.ax.annotate(
450 : text,
451 : xy=(bar_cx, bar_top),
452 : xytext=(x_offset, y_offset),
453 : textcoords="offset points",
454 : bbox={"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
455 : arrowprops={"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
456 : clip_on=False,
457 : )
458 0 : ann.set_in_layout(False)
459 0 : self._annotations[(label, n)] = ann
460 0 : self.canvas.draw_idle()
461 :
462 0 : self._click_cid = self.canvas.mpl_connect("button_press_event", on_click) # type: ignore [arg-type]
463 0 : self.navigation_toolbar._home_callbacks = [self.clear_annotations]
464 :
465 :
466 1 : def fit_c3(x: NDArray[Any], /, e0: float, c3: float) -> NDArray[Any]:
467 1 : return e0 + c3 / x**3
468 :
469 :
470 1 : def fit_c6(x: NDArray[Any], /, e0: float, c6: float) -> NDArray[Any]:
471 1 : return e0 + c6 / x**6
472 :
473 :
474 1 : def fit_c3_c6(x: NDArray[Any], /, e0: float, c3: float, c6: float) -> NDArray[Any]:
475 1 : return e0 + c3 / x**3 + c6 / x**6
|