LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_base.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 86 139 61.9 %
Date: 2025-06-06 09:09:03 Functions: 11 38 28.9 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : from abc import ABC
       6           1 : from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
       7             : 
       8           1 : import numpy as np
       9           1 : from attr import dataclass
      10             : 
      11           1 : from pairinteraction import (
      12             :     _wrapped,
      13             :     complex as pi_complex,
      14             :     real as pi_real,
      15             : )
      16           1 : from pairinteraction_gui.config.system_config import RangesKeys
      17             : 
      18             : if TYPE_CHECKING:
      19             :     from typing_extensions import Self
      20             : 
      21             :     from pairinteraction.units import NDArray
      22             :     from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
      23             : 
      24           1 : logger = logging.getLogger(__name__)
      25             : 
      26             : # FIXME: having all kwargs dictionaries being Any is a hacky solution, it would be nice to use TypedDict in the future
      27             : 
      28           1 : UnitFromRangeKey: dict[RangesKeys, str] = {
      29             :     "Ex": "V/cm",
      30             :     "Ey": "V/cm",
      31             :     "Ez": "V/cm",
      32             :     "Bx": "Gauss",
      33             :     "By": "Gauss",
      34             :     "Bz": "Gauss",
      35             :     "Distance": r"$\mu$m",
      36             :     "Angle": r"$^\circ$",
      37             : }
      38             : 
      39           1 : VariableNameFromRangeKey: dict[RangesKeys, str] = {
      40             :     "Ex": "efield_x",
      41             :     "Ey": "efield_y",
      42             :     "Ez": "efield_z",
      43             :     "Bx": "bfield_x",
      44             :     "By": "bfield_y",
      45             :     "Bz": "bfield_z",
      46             :     "Distance": "distance",
      47             :     "Angle": "angle",
      48             : }
      49             : 
      50           1 : PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
      51             : 
      52             : 
      53           1 : @dataclass
      54           1 : class Parameters(ABC, Generic[PageType]):
      55           1 :     species: tuple[str, ...]
      56           1 :     quantum_numbers: tuple[dict[str, float], ...]
      57           1 :     quantum_number_deltas: tuple[dict[str, float], ...]
      58           1 :     ranges: dict[RangesKeys, list[float]]
      59           1 :     diagonalize_kwargs: dict[str, str]
      60           1 :     diagonalize_relative_energy_range: Union[tuple[float, float], None]
      61           1 :     number_state_labels: int
      62             : 
      63           1 :     def __post_init__(self) -> None:
      64             :         """Post-initialization processing."""
      65             :         # Check if all ranges have the same number of steps
      66           0 :         if not all(len(v) == self.steps for v in self.ranges.values()):
      67           0 :             raise ValueError("All ranges must have the same number of steps")
      68             : 
      69             :         # Check if all tuples have the same length
      70           0 :         if not all(
      71             :             len(tup) == self.n_atoms for tup in [self.species, self.quantum_numbers, self.quantum_number_deltas]
      72             :         ):
      73           0 :             raise ValueError("All tuples must have the same length as the number of atoms")
      74             : 
      75           1 :     @classmethod
      76           1 :     def from_page(cls, page: PageType) -> "Self":
      77             :         """Create Parameters object from page."""
      78           1 :         n_atoms = page.ket_config.n_atoms
      79             : 
      80           1 :         species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
      81           1 :         quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))
      82             : 
      83           1 :         quantum_number_deltas = tuple(page.basis_config.get_quantum_number_deltas(atom) for atom in range(n_atoms))
      84             : 
      85           1 :         ranges = page.system_config.get_ranges_dict()
      86             : 
      87           1 :         diagonalize_kwargs = {}
      88           1 :         if page.calculation_config.fast_mode.isChecked():
      89           1 :             diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
      90           1 :             diagonalize_kwargs["float_type"] = "float32"
      91             : 
      92           1 :         diagonalize_relative_energy_range = None
      93           1 :         if page.calculation_config.energy_range.isChecked():
      94           0 :             diagonalize_relative_energy_range = page.calculation_config.energy_range.values()
      95             : 
      96           1 :         return cls(
      97             :             species,
      98             :             quantum_numbers,
      99             :             quantum_number_deltas,
     100             :             ranges,
     101             :             diagonalize_kwargs,
     102             :             diagonalize_relative_energy_range,
     103             :             page.calculation_config.number_state_labels.value(default=0),
     104             :         )
     105             : 
     106           1 :     @property
     107           1 :     def is_real(self) -> bool:
     108             :         """Check if the parameters are real."""
     109           1 :         return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))
     110             : 
     111           1 :     @property
     112           1 :     def steps(self) -> int:
     113             :         """Return the number of steps."""
     114           1 :         return len(next(iter(self.ranges.values())))
     115             : 
     116           1 :     @property
     117           1 :     def n_atoms(self) -> int:
     118             :         """Return the number of atoms."""
     119           0 :         return len(self.species)
     120             : 
     121           1 :     def get_efield(self, step: int) -> list[float]:
     122             :         """Return the electric field for the given step."""
     123           1 :         efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
     124           1 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]
     125             : 
     126           1 :     def get_bfield(self, step: int) -> list[float]:
     127             :         """Return the magnetic field for the given step."""
     128           1 :         bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
     129           1 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]
     130             : 
     131           1 :     def get_species(self, atom: Optional[int] = None) -> str:
     132             :         """Return the species for the given ket."""
     133           1 :         return self.species[self._check_atom(atom)]
     134             : 
     135           1 :     def get_quantum_numbers(self, atom: Optional[int] = None) -> dict[str, Any]:
     136             :         """Return the quantum numbers for the given ket."""
     137           1 :         return self.quantum_numbers[self._check_atom(atom)]
     138             : 
     139           1 :     def get_quantum_number_restrictions(self, atom: Optional[int] = None) -> dict[str, Any]:
     140             :         """Return the quantum number restrictions for the given ket."""
     141           1 :         atom = self._check_atom(atom)
     142           1 :         qn_restrictions: dict[str, tuple[float, float]] = {}
     143           1 :         for key, delta in self.quantum_number_deltas[atom].items():
     144           1 :             if key in self.quantum_numbers[atom]:
     145           1 :                 qn_restrictions[key] = (
     146             :                     self.quantum_numbers[atom][key] - delta,
     147             :                     self.quantum_numbers[atom][key] + delta,
     148             :                 )
     149             :             else:
     150           0 :                 raise ValueError(f"Quantum number delta {key} not found in quantum numbers.")
     151           1 :         return qn_restrictions
     152             : 
     153           1 :     def _check_atom(self, atom: Optional[int] = None) -> int:
     154             :         """Check if the atom is valid."""
     155           1 :         if atom is not None:
     156           1 :             return atom
     157           0 :         if self.n_atoms == 1:
     158           0 :             return 0
     159           0 :         raise ValueError("Atom index is required for multiple atoms")
     160             : 
     161           1 :     def get_diagonalize_energy_range(self, energy_of_interest: float) -> dict[str, Any]:
     162             :         """Return the kwargs for the diagonalization energy range."""
     163           1 :         if self.diagonalize_relative_energy_range is None:
     164           1 :             return {}
     165           0 :         kwargs: dict[str, Any] = {"energy_unit": "GHz"}
     166           0 :         kwargs["energy_range"] = (
     167             :             energy_of_interest + self.diagonalize_relative_energy_range[0],
     168             :             energy_of_interest + self.diagonalize_relative_energy_range[1],
     169             :         )
     170           0 :         return kwargs
     171             : 
     172           1 :     def get_x_values(self) -> list[float]:
     173             :         """Return the x values for the plot."""
     174           0 :         max_key = self._get_ranges_max_diff_key()
     175           0 :         return self.ranges[max_key]
     176             : 
     177           1 :     def get_x_label(self) -> str:
     178             :         """Return the x values for the plot."""
     179           0 :         max_key = self._get_ranges_max_diff_key()
     180           0 :         x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"
     181             : 
     182           0 :         non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
     183           0 :         if non_constant_keys:
     184           0 :             x_label += f"  ({', '.join(non_constant_keys)} did also change)"
     185             : 
     186           0 :         return x_label
     187             : 
     188           1 :     def _get_ranges_max_diff_key(self) -> RangesKeys:
     189             :         """Return the key with the maximum difference in the ranges."""
     190           0 :         range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
     191           0 :         return max(range_diffs, key=lambda x: range_diffs.get(x, -1))
     192             : 
     193           1 :     def to_replacement_dict(self) -> dict[str, str]:
     194             :         """Return a dictionary with the parameters for replacement."""
     195           0 :         max_key = self._get_ranges_max_diff_key()
     196           0 :         replacements: dict[str, str] = {
     197             :             "$PI_DTYPE": "real" if self.is_real else "complex",
     198             :             "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
     199             :             "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
     200             :         }
     201             : 
     202           0 :         for atom in range(self.n_atoms):
     203           0 :             replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
     204           0 :             replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
     205           0 :             replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
     206             :                 self.get_quantum_number_restrictions(atom)
     207             :             )
     208             : 
     209           0 :         replacements["$STEPS"] = str(self.steps)
     210           0 :         for key, values in self.ranges.items():
     211           0 :             replacements[f"${key.upper()}_MIN"] = str(values[0])
     212           0 :             replacements[f"${key.upper()}_MAX"] = str(values[-1])
     213           0 :             if values[0] == values[-1]:
     214           0 :                 replacements[f"${key.upper()}_VALUE"] = str(values[0])
     215             : 
     216           0 :         replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)
     217             : 
     218           0 :         if self.diagonalize_relative_energy_range is not None:
     219           0 :             r_energy = self.diagonalize_relative_energy_range
     220           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
     221             :                 f', energy_range=(ket_energy + {r_energy[0]}, ket_energy - {-r_energy[1]}), energy_unit="GHz"'
     222             :             )
     223             :         else:
     224           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""
     225             : 
     226           0 :         return replacements
     227             : 
     228             : 
     229           1 : @dataclass
     230           1 : class Results(ABC):
     231           1 :     energies: list["NDArray"]
     232           1 :     energy_offset: float
     233           1 :     ket_overlaps: list["NDArray"]
     234           1 :     state_labels: dict[int, list[str]]
     235             : 
     236           1 :     @classmethod
     237           1 :     def from_calculate(
     238             :         cls,
     239             :         parameters: Parameters[Any],
     240             :         system_list: Union[
     241             :             list[pi_real.SystemPair], list[pi_complex.SystemPair], list[pi_real.SystemAtom], list[pi_complex.SystemAtom]
     242             :         ],
     243             :         ket: Union[_wrapped.KetAtom, tuple[_wrapped.KetAtom, ...]],
     244             :         energy_offset: float,
     245             :     ) -> "Self":
     246             :         """Create Results object from ket, basis, and diagonalized systems."""
     247           1 :         energies = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
     248           1 :         ket_overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]  # type: ignore [arg-type]
     249             : 
     250           1 :         steps_with_labels = [int(i) for i in np.linspace(0, parameters.steps - 1, parameters.number_state_labels)]
     251           1 :         states_dict = {i: system_list[i].get_eigenbasis().states for i in steps_with_labels}
     252           1 :         state_labels = {i: [s.get_label() for s in states] for i, states in states_dict.items()}
     253             : 
     254           1 :         return cls(energies, energy_offset, ket_overlaps, state_labels)
     255             : 
     256             : 
     257           1 : def as_string(value: str, *, raw_string: bool = False) -> str:
     258           0 :     string = '"' + value + '"'
     259           0 :     if raw_string:
     260           0 :         string = "r" + string
     261           0 :     return string
     262             : 
     263             : 
     264           1 : def dict_to_repl(d: dict[str, Any]) -> str:
     265             :     """Convert a dictionary to a string for replacement."""
     266           0 :     if not d:
     267           0 :         return ""
     268           0 :     repl = ""
     269           0 :     for k, v in d.items():
     270           0 :         if isinstance(v, str):
     271           0 :             repl += f", {k}={as_string(v)}"
     272             :         else:
     273           0 :             repl += f", {k}={v}"
     274           0 :     return repl

Generated by: LCOV version 1.16