LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_base.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 137 0.0 %
Date: 2025-05-02 21:49:59 Functions: 0 38 0.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           0 : import logging
       5           0 : from abc import ABC
       6           0 : from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
       7             : 
       8           0 : from attr import dataclass
       9             : 
      10           0 : from pairinteraction import (
      11             :     _wrapped,
      12             :     complex as pi_complex,
      13             :     real as pi_real,
      14             : )
      15           0 : from pairinteraction_gui.config.system_config import RangesKeys
      16             : 
      17             : if TYPE_CHECKING:
      18             :     from typing_extensions import Self
      19             : 
      20             :     from pairinteraction.units import NDArray
      21             :     from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
      22             : 
      23           0 : logger = logging.getLogger(__name__)
      24             : 
      25             : # FIXME: having all kwargs dictionaries being Any is a hacky solution, it would be nice to use TypedDict in the future
      26             : 
      27           0 : UnitFromRangeKey: dict[RangesKeys, str] = {
      28             :     "Ex": "V/cm",
      29             :     "Ey": "V/cm",
      30             :     "Ez": "V/cm",
      31             :     "Bx": "Gauss",
      32             :     "By": "Gauss",
      33             :     "Bz": "Gauss",
      34             :     "Distance": r"$\mu$m",
      35             :     "Angle": r"$^\circ$",
      36             : }
      37             : 
      38           0 : VariableNameFromRangeKey: dict[RangesKeys, str] = {
      39             :     "Ex": "efield_x",
      40             :     "Ey": "efield_y",
      41             :     "Ez": "efield_z",
      42             :     "Bx": "bfield_x",
      43             :     "By": "bfield_y",
      44             :     "Bz": "bfield_z",
      45             :     "Distance": "distance",
      46             :     "Angle": "angle",
      47             : }
      48             : 
      49           0 : PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
      50             : 
      51             : 
      52           0 : @dataclass
      53           0 : class Parameters(ABC, Generic[PageType]):
      54           0 :     species: tuple[str, ...]
      55           0 :     quantum_numbers: tuple[dict[str, float], ...]
      56           0 :     quantum_number_deltas: tuple[dict[str, float], ...]
      57           0 :     ranges: dict[RangesKeys, list[float]]
      58           0 :     diagonalize_kwargs: dict[str, str]
      59           0 :     diagonalize_relative_energy_range: Union[tuple[float, float], None]
      60             : 
      61           0 :     def __post_init__(self) -> None:
      62             :         """Post-initialization processing."""
      63             :         # Check if all ranges have the same number of steps
      64           0 :         if not all(len(v) == self.steps for v in self.ranges.values()):
      65           0 :             raise ValueError("All ranges must have the same number of steps")
      66             : 
      67             :         # Check if all tuples have the same length
      68           0 :         if not all(
      69             :             len(tup) == self.n_atoms for tup in [self.species, self.quantum_numbers, self.quantum_number_deltas]
      70             :         ):
      71           0 :             raise ValueError("All tuples must have the same length as the number of atoms")
      72             : 
      73           0 :     @classmethod
      74           0 :     def from_page(cls, page: PageType) -> "Self":
      75             :         """Create Parameters object from page."""
      76           0 :         n_atoms = page.ket_config.n_atoms
      77             : 
      78           0 :         species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
      79           0 :         quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))
      80             : 
      81           0 :         quantum_number_deltas = tuple(page.basis_config.get_quantum_number_deltas(atom) for atom in range(n_atoms))
      82             : 
      83           0 :         ranges = page.system_config.get_ranges_dict()
      84             : 
      85           0 :         diagonalize_kwargs = {}
      86           0 :         if page.plotwidget.fast_mode.isChecked():
      87           0 :             diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
      88           0 :             diagonalize_kwargs["float_type"] = "float32"
      89             : 
      90           0 :         diagonalize_relative_energy_range = None
      91           0 :         if page.plotwidget.energy_range.isChecked():
      92           0 :             diagonalize_relative_energy_range = page.plotwidget.energy_range.values()
      93             : 
      94           0 :         return cls(
      95             :             species,
      96             :             quantum_numbers,
      97             :             quantum_number_deltas,
      98             :             ranges,
      99             :             diagonalize_kwargs,
     100             :             diagonalize_relative_energy_range,
     101             :         )
     102             : 
     103           0 :     @property
     104           0 :     def is_real(self) -> bool:
     105             :         """Check if the parameters are real."""
     106           0 :         return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))
     107             : 
     108           0 :     @property
     109           0 :     def steps(self) -> int:
     110             :         """Return the number of steps."""
     111           0 :         return len(next(iter(self.ranges.values())))
     112             : 
     113           0 :     @property
     114           0 :     def n_atoms(self) -> int:
     115             :         """Return the number of atoms."""
     116           0 :         return len(self.species)
     117             : 
     118           0 :     def get_efield(self, step: int) -> list[float]:
     119             :         """Return the electric field for the given step."""
     120           0 :         efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
     121           0 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]
     122             : 
     123           0 :     def get_bfield(self, step: int) -> list[float]:
     124             :         """Return the magnetic field for the given step."""
     125           0 :         bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
     126           0 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]
     127             : 
     128           0 :     def get_species(self, atom: Optional[int] = None) -> str:
     129             :         """Return the species for the given ket."""
     130           0 :         return self.species[self._check_atom(atom)]
     131             : 
     132           0 :     def get_quantum_numbers(self, atom: Optional[int] = None) -> dict[str, Any]:
     133             :         """Return the quantum numbers for the given ket."""
     134           0 :         return self.quantum_numbers[self._check_atom(atom)]
     135             : 
     136           0 :     def get_quantum_number_restrictions(self, atom: Optional[int] = None) -> dict[str, Any]:
     137             :         """Return the quantum number restrictions for the given ket."""
     138           0 :         atom = self._check_atom(atom)
     139           0 :         qn_restrictions: dict[str, tuple[float, float]] = {}
     140           0 :         for key, delta in self.quantum_number_deltas[atom].items():
     141           0 :             if key in self.quantum_numbers[atom]:
     142           0 :                 qn_restrictions[key] = (
     143             :                     self.quantum_numbers[atom][key] - delta,
     144             :                     self.quantum_numbers[atom][key] + delta,
     145             :                 )
     146             :             else:
     147           0 :                 raise ValueError(f"Quantum number delta {key} not found in quantum numbers.")
     148           0 :         return qn_restrictions
     149             : 
     150           0 :     def _check_atom(self, atom: Optional[int] = None) -> int:
     151             :         """Check if the atom is valid."""
     152           0 :         if atom is not None:
     153           0 :             return atom
     154           0 :         if self.n_atoms == 1:
     155           0 :             return 0
     156           0 :         raise ValueError("Atom index is required for multiple atoms")
     157             : 
     158           0 :     def get_diagonalize_energy_range(self, energy_of_interest: float) -> dict[str, Any]:
     159             :         """Return the kwargs for the diagonalization energy range."""
     160           0 :         if self.diagonalize_relative_energy_range is None:
     161           0 :             return {}
     162           0 :         kwargs: dict[str, Any] = {"energy_unit": "GHz"}
     163           0 :         kwargs["energy_range"] = (
     164             :             energy_of_interest + self.diagonalize_relative_energy_range[0],
     165             :             energy_of_interest + self.diagonalize_relative_energy_range[1],
     166             :         )
     167           0 :         return kwargs
     168             : 
     169           0 :     def get_x_values(self) -> list[float]:
     170             :         """Return the x values for the plot."""
     171           0 :         max_key = self._get_ranges_max_diff_key()
     172           0 :         return self.ranges[max_key]
     173             : 
     174           0 :     def get_x_label(self) -> str:
     175             :         """Return the x values for the plot."""
     176           0 :         max_key = self._get_ranges_max_diff_key()
     177           0 :         x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"
     178             : 
     179           0 :         non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
     180           0 :         if non_constant_keys:
     181           0 :             x_label += f"  ({', '.join(non_constant_keys)} did also change)"
     182             : 
     183           0 :         return x_label
     184             : 
     185           0 :     def _get_ranges_max_diff_key(self) -> RangesKeys:
     186             :         """Return the key with the maximum difference in the ranges."""
     187           0 :         range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
     188           0 :         return max(range_diffs, key=lambda x: range_diffs.get(x, -1))
     189             : 
     190           0 :     def to_replacement_dict(self) -> dict[str, str]:
     191             :         """Return a dictionary with the parameters for replacement."""
     192           0 :         max_key = self._get_ranges_max_diff_key()
     193           0 :         replacements: dict[str, str] = {
     194             :             "$PI_DTYPE": "real" if self.is_real else "complex",
     195             :             "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
     196             :             "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
     197             :         }
     198             : 
     199           0 :         for atom in range(self.n_atoms):
     200           0 :             replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
     201           0 :             replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
     202           0 :             replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
     203             :                 self.get_quantum_number_restrictions(atom)
     204             :             )
     205             : 
     206           0 :         replacements["$STEPS"] = str(self.steps)
     207           0 :         for key, values in self.ranges.items():
     208           0 :             replacements[f"${key.upper()}_MIN"] = str(values[0])
     209           0 :             replacements[f"${key.upper()}_MAX"] = str(values[-1])
     210           0 :             if values[0] == values[-1]:
     211           0 :                 replacements[f"${key.upper()}_VALUE"] = str(values[0])
     212             : 
     213           0 :         replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)
     214             : 
     215           0 :         if self.diagonalize_relative_energy_range is not None:
     216           0 :             r_energy = self.diagonalize_relative_energy_range
     217           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
     218             :                 f', energy_range=(ket_energy + {r_energy[0]}, ket_energy - {-r_energy[1]}), energy_unit="GHz"'
     219             :             )
     220             :         else:
     221           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""
     222             : 
     223           0 :         return replacements
     224             : 
     225             : 
     226           0 : @dataclass
     227           0 : class Results(ABC):
     228           0 :     energies: list["NDArray"]
     229           0 :     energy_offset: float
     230           0 :     ket_overlaps: list["NDArray"]
     231           0 :     state_labels_0: list[str]
     232             : 
     233           0 :     @classmethod
     234           0 :     def from_calculate(
     235             :         cls,
     236             :         system_list: Union[
     237             :             list[pi_real.SystemPair], list[pi_complex.SystemPair], list[pi_real.SystemAtom], list[pi_complex.SystemAtom]
     238             :         ],
     239             :         ket: Union[_wrapped.KetAtom, tuple[_wrapped.KetAtom, ...]],
     240             :         energy_offset: float,
     241             :     ) -> "Self":
     242             :         """Create Results object from ket, basis, and diagonalized systems."""
     243           0 :         energies = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
     244           0 :         ket_overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]  # type: ignore [arg-type]
     245           0 :         basis_0 = system_list[-1].get_eigenbasis()
     246           0 :         state_0 = [basis_0.kets[i] for i in range(basis_0.number_of_states)]
     247           0 :         state_labels_0 = [s.get_label("ket") for s in state_0]
     248             : 
     249           0 :         return cls(energies, energy_offset, ket_overlaps, state_labels_0)
     250             : 
     251             : 
     252           0 : def as_string(value: str, *, raw_string: bool = False) -> str:
     253           0 :     string = '"' + value + '"'
     254           0 :     if raw_string:
     255           0 :         string = "r" + string
     256           0 :     return string
     257             : 
     258             : 
     259           0 : def dict_to_repl(d: dict[str, Any]) -> str:
     260             :     """Convert a dictionary to a string for replacement."""
     261           0 :     if not d:
     262           0 :         return ""
     263           0 :     repl = ""
     264           0 :     for k, v in d.items():
     265           0 :         if isinstance(v, str):
     266           0 :             repl += f", {k}={as_string(v)}"
     267             :         else:
     268           0 :             repl += f", {k}={v}"
     269           0 :     return repl

Generated by: LCOV version 1.16