LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/plotwidget - plotwidget.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 85 0.0 %
Date: 2025-04-29 15:59:54 Functions: 0 16 0.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           0 : import logging
       5           0 : from collections.abc import Sequence
       6           0 : from typing import TYPE_CHECKING, Any
       7             : 
       8           0 : import matplotlib as mpl
       9           0 : import matplotlib.pyplot as plt
      10           0 : import mplcursors
      11           0 : import numpy as np
      12           0 : from matplotlib.colors import Normalize
      13           0 : from PySide6.QtWidgets import QPushButton
      14             : 
      15           0 : from pairinteraction.visualization.colormaps import alphamagma
      16           0 : from pairinteraction_gui.plotwidget.canvas import MatplotlibCanvas
      17           0 : from pairinteraction_gui.qobjects import WidgetH, WidgetV
      18           0 : from pairinteraction_gui.qobjects.item import Item, RangeItem
      19             : 
      20             : if TYPE_CHECKING:
      21             :     from numpy.typing import NDArray
      22             : 
      23             :     from pairinteraction_gui.page import SimulationPage
      24             : 
      25           0 : logger = logging.getLogger(__name__)
      26             : 
      27             : 
      28           0 : class PlotWidget(WidgetV):
      29             :     """Widget for displaying plots with controls."""
      30             : 
      31           0 :     margin = (0, 0, 0, 0)
      32           0 :     spacing = 15
      33             : 
      34           0 :     def __init__(self, parent: "SimulationPage") -> None:
      35             :         """Initialize the base section."""
      36           0 :         mpl.use("Qt5Agg")
      37             : 
      38           0 :         self.page = parent
      39           0 :         super().__init__(parent)
      40             : 
      41           0 :     def setupWidget(self) -> None:
      42           0 :         self.plot_toolbar = WidgetV(self)
      43           0 :         self.canvas = MatplotlibCanvas(self)
      44             : 
      45           0 :     def postSetupWidget(self) -> None:
      46           0 :         top_row = WidgetH(self)
      47             : 
      48             :         # Add plot toolbar on left
      49           0 :         top_row.layout().addWidget(self.plot_toolbar)
      50             : 
      51             :         # Add reset zoom button on right
      52           0 :         reset_zoom_button = QPushButton("Reset Zoom", self)
      53           0 :         reset_zoom_button.setToolTip(
      54             :             "Reset the plot view to its original state. You can zoom in/out using the mousewheel."
      55             :         )
      56           0 :         reset_zoom_button.clicked.connect(self.canvas.reset_view)
      57           0 :         top_row.layout().addWidget(reset_zoom_button)
      58             : 
      59           0 :         self.layout().addWidget(top_row)
      60           0 :         self.layout().addWidget(self.canvas, stretch=1)
      61             : 
      62           0 :     def clear(self) -> None:
      63           0 :         self.canvas.ax.clear()
      64           0 :         self.canvas.draw()
      65             : 
      66             : 
      67           0 : class PlotEnergies(PlotWidget):
      68             :     """Plotwidget for plotting energy levels."""
      69             : 
      70           0 :     def setupWidget(self) -> None:
      71           0 :         super().setupWidget()
      72             : 
      73           0 :         self.energy_range = RangeItem(
      74             :             self,
      75             :             "Calculate the energies from",
      76             :             (-999, 999),
      77             :             (-0.5, 0.5),
      78             :             unit="GHz",
      79             :             checked=False,
      80             :             tooltip_label="energy",
      81             :         )
      82           0 :         self.plot_toolbar.layout().addWidget(self.energy_range)
      83             : 
      84           0 :         self.fast_mode = Item(self, "Use fast calculation mode", {}, "", checked=True)
      85           0 :         self.plot_toolbar.layout().addWidget(self.fast_mode)
      86             : 
      87           0 :         mappable = plt.cm.ScalarMappable(cmap=alphamagma, norm=Normalize(vmin=0, vmax=1))
      88           0 :         self.canvas.fig.colorbar(mappable, ax=self.canvas.ax, label="Overlap with state of interest")
      89           0 :         self.canvas.fig.tight_layout()
      90             : 
      91           0 :     def plot(
      92             :         self,
      93             :         x_list: Sequence[float],
      94             :         energies_list: Sequence["NDArray[Any]"],
      95             :         overlaps_list: Sequence["NDArray[Any]"],
      96             :         xlabel: str,
      97             :     ) -> None:
      98           0 :         ax = self.canvas.ax
      99           0 :         ax.clear()
     100             : 
     101           0 :         try:
     102           0 :             ax.plot(x_list, np.array(energies_list), c="0.9", lw=0.25, zorder=-10)
     103           0 :         except ValueError as err:
     104           0 :             if "inhomogeneous shape" in str(err):
     105           0 :                 for x_value, es in zip(x_list, energies_list):
     106           0 :                     ax.plot([x_value] * len(es), es, c="0.9", ls="None", marker=".", zorder=-10)
     107             :             else:
     108           0 :                 raise err
     109             : 
     110             :         # Flatten the arrays for scatter plot and repeat x value for each energy
     111             :         # (dont use numpy.flatten, etc. to also handle inhomogeneous shapes)
     112           0 :         x_repeated = np.hstack([val * np.ones_like(es) for val, es in zip(x_list, energies_list)])
     113           0 :         energies_flattend = np.hstack(energies_list)
     114           0 :         overlaps_flattend = np.hstack(overlaps_list)
     115             : 
     116           0 :         min_overlap = 1e-4
     117           0 :         inds: NDArray[Any] = np.argwhere(overlaps_flattend > min_overlap).flatten()
     118           0 :         inds = inds[np.argsort(overlaps_flattend[inds])]
     119             : 
     120           0 :         if len(inds) > 0:
     121           0 :             ax.scatter(
     122             :                 x_repeated[inds],
     123             :                 energies_flattend[inds],
     124             :                 c=overlaps_flattend[inds],
     125             :                 s=15,
     126             :                 vmin=0,
     127             :                 vmax=1,
     128             :                 cmap=alphamagma,
     129             :             )
     130             : 
     131           0 :         ylim = ax.get_ylim()
     132           0 :         if abs(ylim[1] - ylim[0]) < 1e-2:
     133           0 :             ax.set_ylim(ylim[0] - 1e-2, ylim[1] + 1e-2)
     134             : 
     135           0 :         ax.set_xlabel(xlabel)
     136           0 :         ax.set_ylabel("Energy [GHz]")
     137           0 :         self.canvas.fig.tight_layout()
     138             : 
     139           0 :     def add_cursor(self, x_value: float, energies: "NDArray[Any]", state_labels_0: list[str]) -> None:
     140             :         # Remove any existing cursors to avoid duplicates
     141           0 :         if hasattr(self, "mpl_cursor"):
     142           0 :             if hasattr(self.mpl_cursor, "remove"):  # type: ignore
     143           0 :                 self.mpl_cursor.remove()  # type: ignore
     144           0 :             del self.mpl_cursor  # type: ignore
     145             : 
     146           0 :         ax = self.canvas.ax
     147             : 
     148           0 :         artists = []
     149           0 :         for e, ket_label in zip(energies, state_labels_0):
     150           0 :             artist = ax.plot(x_value, e, "o", c="0.9", ms=5, zorder=-20, fillstyle="none", label=ket_label)
     151           0 :             artists.extend(artist)
     152             : 
     153           0 :         self.mpl_cursor = mplcursors.cursor(
     154             :             artists,
     155             :             hover=False,
     156             :             annotation_kwargs={
     157             :                 "bbox": {"boxstyle": "round,pad=0.5", "fc": "white", "alpha": 0.9, "ec": "gray"},
     158             :                 "arrowprops": {"arrowstyle": "->", "connectionstyle": "arc3", "color": "gray"},
     159             :             },
     160             :         )
     161             : 
     162           0 :         @self.mpl_cursor.connect("add")
     163           0 :         def on_add(sel: mplcursors.Selection) -> None:
     164           0 :             label = sel.artist.get_label()
     165           0 :             sel.annotation.set_text(label)

Generated by: LCOV version 1.16