Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 0 : import logging
5 0 : from collections.abc import Sequence
6 0 : from typing import TYPE_CHECKING, Any
7 :
8 0 : import matplotlib as mpl
9 0 : import matplotlib.pyplot as plt
10 0 : import mplcursors
11 0 : import numpy as np
12 0 : from matplotlib.colors import Normalize
13 0 : from PySide6.QtWidgets import QPushButton
14 :
15 0 : from pairinteraction.visualization.colormaps import alphamagma
16 0 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
17 0 : from pairinteraction_gui.qobjects import WidgetH, WidgetV
18 0 : from pairinteraction_gui.qobjects.item import Item, RangeItem
19 :
20 : if TYPE_CHECKING:
21 : from numpy.typing import NDArray
22 :
23 : from pairinteraction_gui.page import SimulationPage
24 :
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
26 :
27 :
class PlotWidget(WidgetV):
    """Widget showing a row of plot controls above a matplotlib canvas."""

    margin = (0, 0, 0, 0)
    spacing = 15

    def __init__(self, parent: "SimulationPage") -> None:
        """Select the Qt matplotlib backend and build the widget for ``parent``.

        Args:
            parent: The simulation page that owns this plot widget.
        """
        mpl.use("Qt5Agg")

        # Keep a reference to the page *before* the base initializer runs; the base class
        # presumably calls the setup hooks from there, which may need it (confirm in WidgetV).
        self.page = parent
        super().__init__(parent)

    def setupWidget(self) -> None:
        """Create the toolbar container (filled by subclasses) and the plotting canvas."""
        self.plot_toolbar = WidgetV(self)
        self.canvas = MatplotlibCanvas(self)

    def postSetupWidget(self) -> None:
        """Assemble the header row (toolbar + reset button) and place the canvas below it."""
        header = WidgetH(self)
        header_layout = header.layout()

        # Left side of the header: the plot-specific controls.
        header_layout.addWidget(self.plot_toolbar)

        # Right side of the header: a button that restores the initial view of the canvas.
        reset_button = QPushButton("Reset Zoom", self)
        reset_button.setToolTip(
            "Reset the plot view to its original state. You can zoom in/out using the mousewheel."
        )
        reset_button.clicked.connect(self.canvas.reset_view)
        header_layout.addWidget(reset_button)

        # The canvas gets stretch=1 so it absorbs all remaining space.
        own_layout = self.layout()
        own_layout.addWidget(header)
        own_layout.addWidget(self.canvas, stretch=1)

    def clear(self) -> None:
        """Remove all artists from the plot axes and redraw the canvas."""
        canvas = self.canvas
        canvas.ax.clear()
        canvas.draw()
65 :
66 :
class PlotEnergies(PlotWidget):
    """Plot widget showing energy levels versus a parameter, colored by overlap with a state."""

    def setupWidget(self) -> None:
        """Add the energy-range and fast-mode controls and attach the overlap colorbar."""
        super().setupWidget()

        self.energy_range = RangeItem(
            self,
            "Calculate the energies from",
            (-999, 999),
            (-0.5, 0.5),
            unit="GHz",
            checked=False,
            tooltip_label="energy",
        )
        self.plot_toolbar.layout().addWidget(self.energy_range)

        self.fast_mode = Item(self, "Use fast calculation mode", {}, "", checked=True)
        self.plot_toolbar.layout().addWidget(self.fast_mode)

        # The colorbar is static: overlaps always lie in [0, 1], matching vmin/vmax used in plot().
        mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
        self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
        self.canvas.fig.tight_layout()

    def plot(
        self,
        x_list: Sequence[float],
        energies_list: Sequence["NDArray[Any]"],
        overlaps_list: Sequence["NDArray[Any]"],
        xlabel: str,
    ) -> None:
        """Plot all energies in light grey and highlight those overlapping the state of interest.

        Args:
            x_list: x-axis values; entry i belongs to ``energies_list[i]`` and ``overlaps_list[i]``.
            energies_list: Energies for each x value. Entries may have different lengths.
            overlaps_list: Overlap of each energy with the state of interest, entry-wise
                matching ``energies_list``.
            xlabel: Label for the x axis.

        Raises:
            ValueError: From matplotlib if ``x_list`` and ``energies_list`` have incompatible
                lengths.
        """
        ax = self.canvas.ax
        ax.clear()

        # Decide explicitly whether the energies form a rectangular array instead of relying on
        # the text of NumPy's "inhomogeneous shape" error (version dependent).
        distinct_lengths = {len(es) for es in energies_list}
        if len(distinct_lengths) <= 1:
            # Same number of energies at every x -> draw each level as a thin connected line.
            ax.plot(x_list, np.array(energies_list), c="0.9", lw=0.25, zorder=-10)
        else:
            # Ragged data cannot be drawn as lines; draw every energy as an individual point.
            for x_value, es in zip(x_list, energies_list):
                ax.plot([x_value] * len(es), es, c="0.9", ls="None", marker=".", zorder=-10)

        # np.hstack raises on an empty sequence, so only scatter when there is data.
        if len(x_list) > 0 and len(energies_list) > 0:
            self._scatter_overlaps(ax, x_list, energies_list, overlaps_list)

        # Pad a (nearly) degenerate y range so the axis stays readable.
        ylim = ax.get_ylim()
        if abs(ylim[1] - ylim[0]) < 1e-2:
            ax.set_ylim(ylim[0] - 1e-2, ylim[1] + 1e-2)

        ax.set_xlabel(xlabel)
        ax.set_ylabel("Energy [GHz]")
        self.canvas.fig.tight_layout()

    @staticmethod
    def _scatter_overlaps(
        ax: "mpl.axes.Axes",
        x_list: Sequence[float],
        energies_list: Sequence["NDArray[Any]"],
        overlaps_list: Sequence["NDArray[Any]"],
        min_overlap: float = 1e-4,
    ) -> None:
        """Scatter every energy whose overlap exceeds ``min_overlap``, colored by that overlap.

        Args:
            ax: Axes to draw on.
            x_list: x-axis values, one per entry of ``energies_list``.
            energies_list: Energies for each x value (entries may differ in length).
            overlaps_list: Overlaps matching ``energies_list`` entry-wise.
            min_overlap: Overlaps at or below this value are not drawn.
        """
        # Flatten via hstack (not ndarray.flatten) so entries of different lengths are supported;
        # each x value is repeated once per energy it belongs to.
        x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_list, energies_list)])
        energies_flattened = np.hstack(energies_list)
        overlaps_flattened = np.hstack(overlaps_list)

        inds = np.flatnonzero(overlaps_flattened > min_overlap)
        # Ascending overlap order: the most relevant points are drawn last, i.e. on top.
        inds = inds[np.argsort(overlaps_flattened[inds])]
        if len(inds) == 0:
            return

        ax.scatter(
            x_repeated[inds],
            energies_flattened[inds],
            c=overlaps_flattened[inds],
            s=15,
            vmin=0,
            vmax=1,
            cmap=alphamagma,
        )

    def add_cursor(self, x_value: float, energies: "NDArray[Any]", state_labels_0: list[str]) -> None:
        """Add clickable markers at ``x_value`` that display the label of the corresponding state.

        A cursor created by a previous call is removed first so annotations are not duplicated.

        Args:
            x_value: x position at which all markers are placed.
            energies: y positions of the markers, one per state.
            state_labels_0: Label shown when the marker with the same index is clicked.
        """
        previous_cursor = getattr(self, "mpl_cursor", None)
        if previous_cursor is not None:
            if hasattr(previous_cursor, "remove"):
                previous_cursor.remove()
            del self.mpl_cursor

        if len(energies) != len(state_labels_0):
            # zip() below truncates silently; make the mismatch visible.
            logger.warning(
                "Got %d energies but %d state labels; only %d cursor markers are added.",
                len(energies),
                len(state_labels_0),
                min(len(energies), len(state_labels_0)),
            )

        ax = self.canvas.ax

        # One artist per state, so the clicked artist's label identifies the state.
        artists = []
        for e, ket_label in zip(energies, state_labels_0):
            artists.extend(
                ax.plot(x_value, e, "o", c="0.9", ms=5, zorder=-20, fillstyle="none", label=ket_label)
            )

        self.mpl_cursor = mplcursors.cursor(
            artists,
            hover=False,
            annotation_kwargs={
                "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
                "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
            },
        )

        @self.mpl_cursor.connect("add")
        def on_add(sel: mplcursors.Selection) -> None:
            # Show the state label stored on the selected marker.
            label = sel.artist.get_label()
            sel.annotation.set_text(label)
|