Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : from collections.abc import Sequence
6 1 : from typing import TYPE_CHECKING, Any
7 :
8 1 : import matplotlib as mpl
9 1 : import matplotlib.pyplot as plt
10 1 : import mplcursors
11 1 : import numpy as np
12 1 : from matplotlib.colors import Normalize
13 1 : from PySide6.QtWidgets import QHBoxLayout
14 :
15 1 : from pairinteraction.visualization.colormaps import alphamagma
16 1 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
17 1 : from pairinteraction_gui.plotwidget.navigation_toolbar import CustomNavigationToolbar
18 1 : from pairinteraction_gui.qobjects import WidgetV
19 1 : from pairinteraction_gui.theme import plot_widget_theme
20 :
21 : if TYPE_CHECKING:
22 : from numpy.typing import NDArray
23 :
24 : from pairinteraction_gui.page import SimulationPage
25 :
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
27 :
28 :
class PlotWidget(WidgetV):
    """Base widget holding a matplotlib canvas with a navigation toolbar above it."""

    margin = (0, 0, 0, 0)
    spacing = 15

    def __init__(self, parent: "SimulationPage") -> None:
        """Create the plot widget as a child of the given simulation page.

        Args:
            parent: The page that owns this widget; kept as ``self.page``.

        """
        # Select matplotlib's Qt/Agg backend.
        mpl.use("Qt5Agg")

        # ``page`` is assigned before the base-class constructor runs; presumably the
        # base class builds the UI (setupWidget) during __init__ and may need it. TODO confirm.
        self.page = parent
        super().__init__(parent)

        self.setStyleSheet(plot_widget_theme)

    def setupWidget(self) -> None:
        """Create the canvas and place a right-aligned toolbar row above it."""
        # The canvas attribute is set before the toolbar is built, matching the order
        # in which the toolbar receives it.
        self.canvas = MatplotlibCanvas(self)
        self.navigation_toolbar = CustomNavigationToolbar(self.canvas, self)

        toolbar_row = QHBoxLayout()
        toolbar_row.addStretch(1)  # push the toolbar to the right edge
        toolbar_row.addWidget(self.navigation_toolbar)

        main_layout = self.layout()
        main_layout.addLayout(toolbar_row)
        main_layout.addWidget(self.canvas, stretch=1)

    def clear(self) -> None:
        """Clear the axes and redraw the now-empty canvas."""
        canvas = self.canvas
        canvas.ax.clear()
        canvas.draw()
58 :
59 :
class PlotEnergies(PlotWidget):
    """Plotwidget for plotting energy levels."""

    def setupWidget(self) -> None:
        """Build the base canvas and attach a colorbar for the overlap color scale."""
        super().setupWidget()

        # The colorbar is backed by a standalone mappable using the same colormap and
        # [0, 1] range as the overlap scatter drawn in ``plot``.
        mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
        self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
        self.canvas.fig.tight_layout()

    def plot(
        self,
        x_list: Sequence[float],
        energies_list: Sequence["NDArray[Any]"],
        overlaps_list: Sequence["NDArray[Any]"],
        xlabel: str,
    ) -> None:
        """Plot energies versus a scanned parameter, colored by overlap.

        Args:
            x_list: Parameter value of each step.
            energies_list: Energies of each step; entry ``i`` belongs to ``x_list[i]``.
                Entries may have different lengths.
            overlaps_list: Overlap of each energy with the state of interest, laid out
                exactly like ``energies_list``.
            xlabel: Label of the x axis.

        """
        ax = self.canvas.ax
        ax.clear()

        # With no steps there is nothing to draw (np.hstack raises on an empty list),
        # but the axes are still cleared and labeled.
        if len(energies_list) > 0:
            self._plot_all_levels(ax, x_list, energies_list)
            self._plot_overlaps(ax, x_list, energies_list, overlaps_list)

        ax.set_xlabel(xlabel)
        ax.set_ylabel("Energy [GHz]")

        self.canvas.fig.tight_layout()

    @staticmethod
    def _plot_all_levels(ax: Any, x_list: Sequence[float], energies_list: Sequence["NDArray[Any]"]) -> None:
        """Draw every energy level in light gray, behind the overlap scatter."""
        if len({len(energies) for energies in energies_list}) == 1:
            # Same number of energies at every step: connect them into lines.
            ax.plot(x_list, np.array(energies_list), c="0.75", lw=0.25, zorder=-10)
            return

        # The number of energies varies between steps, so they cannot be stacked into
        # one 2D array; draw one marker per energy instead. Checking the lengths directly
        # avoids depending on the wording (and NumPy version) of the ragged-array error.
        for x_value, energies in zip(x_list, energies_list):
            ax.plot([x_value] * len(energies), energies, c="0.75", ls="None", marker=".", zorder=-10)

    @staticmethod
    def _plot_overlaps(
        ax: Any,
        x_list: Sequence[float],
        energies_list: Sequence["NDArray[Any]"],
        overlaps_list: Sequence["NDArray[Any]"],
        min_overlap: float = 1e-4,
    ) -> None:
        """Scatter the energies whose overlap exceeds ``min_overlap``, colored by overlap."""
        # Flatten all steps and repeat each x value once per energy. np.hstack (instead of
        # ndarray.flatten) also handles steps with different numbers of energies.
        x_repeated = np.hstack([x_value * np.ones_like(energies) for x_value, energies in zip(x_list, energies_list)])
        energies_flattened = np.hstack(energies_list)
        overlaps_flattened = np.hstack(overlaps_list)

        inds = np.flatnonzero(overlaps_flattened > min_overlap)
        # Ascending overlap order, so the points with the largest overlap are drawn last (on top).
        inds = inds[np.argsort(overlaps_flattened[inds])]
        if len(inds) == 0:
            return

        ax.scatter(
            x_repeated[inds],
            energies_flattened[inds],
            c=overlaps_flattened[inds],
            s=15,
            vmin=0,
            vmax=1,
            cmap=alphamagma,
        )

    def add_cursor(
        self, x_value: list[float], energies: list["NDArray[Any]"], state_labels: dict[int, list[str]]
    ) -> None:
        """Add selectable markers that display a state label in an annotation box.

        Args:
            x_value: Parameter values, indexed like ``energies``.
            energies: Energies of each step.
            state_labels: Maps a step index to labels, paired in order with the
                energies ``energies[index]``.

        """
        # Remove any existing cursor to avoid duplicate annotations.
        if hasattr(self, "mpl_cursor"):
            if hasattr(self.mpl_cursor, "remove"):  # type: ignore
                self.mpl_cursor.remove()  # type: ignore
            del self.mpl_cursor  # type: ignore

        ax = self.canvas.ax

        artists = []
        for idx, labels in state_labels.items():
            x = x_value[idx]
            for energy, label in zip(energies[idx], labels):
                # One artist per labeled energy, so a selection can report its own label.
                artist = ax.plot(x, energy, "d", c="0.93", alpha=0.5, ms=7, label=label, zorder=-20)
                artists.extend(artist)

        self.mpl_cursor = mplcursors.cursor(
            artists,
            hover=False,
            annotation_kwargs={
                "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
                "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
            },
        )

        @self.mpl_cursor.connect("add")
        def on_add(sel: mplcursors.Selection) -> None:
            label = sel.artist.get_label()
            # Put each " + "-separated part of the label on its own line.
            sel.annotation.set_text(label.replace(" + ", "\n + "))
|