LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_base.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 127 137 92.7 %
Date: 2026-04-17 09:29:39 Functions: 18 20 90.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : from abc import ABC
       7           1 : from typing import TYPE_CHECKING, Any, Generic, TypeVar
       8             : 
       9           1 : from attr import dataclass
      10             : 
      11           1 : import pairinteraction as pi
      12             : 
      13             : if TYPE_CHECKING:
      14             :     from collections.abc import Mapping, Sequence
      15             : 
      16             :     from typing_extensions import Self
      17             : 
      18             :     from pairinteraction.system import SystemBase
      19             :     from pairinteraction.units import NDArray
      20             :     from pairinteraction_gui.config.basis_config import QuantumNumberRestrictions
      21             :     from pairinteraction_gui.config.ket_config import QuantumNumbers
      22             :     from pairinteraction_gui.config.system_config import RangesKeys
      23             :     from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
      24             : 
      25           1 : logger = logging.getLogger(__name__)
      26             : 
      27           1 : UnitFromRangeKey: dict[RangesKeys, str] = {
      28             :     "Ex": "V/cm",
      29             :     "Ey": "V/cm",
      30             :     "Ez": "V/cm",
      31             :     "Bx": "Gauss",
      32             :     "By": "Gauss",
      33             :     "Bz": "Gauss",
      34             :     "Distance": r"$\mu$m",
      35             :     "Angle": r"$^\circ$",
      36             : }
      37             : 
      38           1 : VariableNameFromRangeKey: dict[RangesKeys, str] = {
      39             :     "Ex": "efield_x",
      40             :     "Ey": "efield_y",
      41             :     "Ez": "efield_z",
      42             :     "Bx": "bfield_x",
      43             :     "By": "bfield_y",
      44             :     "Bz": "bfield_z",
      45             :     "Distance": "distance",
      46             :     "Angle": "angle",
      47             : }
      48             : 
      49           1 : PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
      50             : 
      51             : 
      52           1 : @dataclass
      53           1 : class Parameters(ABC, Generic[PageType]):
      54           1 :     species: tuple[str, ...]
      55           1 :     quantum_numbers: tuple[QuantumNumbers, ...]
      56           1 :     quantum_number_restrictions: tuple[QuantumNumberRestrictions, ...]
      57           1 :     ranges: dict[RangesKeys, list[float]]
      58           1 :     diamagnetism_enabled: bool
      59           1 :     diagonalize_kwargs: dict[str, str]
      60           1 :     diagonalize_relative_energy_range: tuple[float, float] | None
      61             : 
      62           1 :     def __post_init__(self) -> None:
      63             :         """Post-initialization processing."""
      64             :         # Check if all ranges have the same number of steps
      65           0 :         if not all(len(v) == self.steps for v in self.ranges.values()):
      66           0 :             raise ValueError("All ranges must have the same number of steps")
      67             : 
      68             :         # Check if all tuples have the same length
      69           0 :         if not all(
      70             :             len(tup) == self.n_atoms for tup in [self.species, self.quantum_numbers, self.quantum_number_restrictions]
      71             :         ):
      72           0 :             raise ValueError("All tuples must have the same length as the number of atoms")
      73             : 
      74           1 :     @classmethod
      75           1 :     def from_page(cls, page: PageType) -> Self:
      76             :         """Create Parameters object from page."""
      77           1 :         n_atoms = page.ket_config.n_atoms
      78             : 
      79           1 :         species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
      80           1 :         quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))
      81             : 
      82           1 :         quantum_number_restrictions = tuple(
      83             :             page.basis_config.get_quantum_number_restrictions(atom) for atom in range(n_atoms)
      84             :         )
      85             : 
      86           1 :         ranges = page.system_config.get_ranges_dict()
      87           1 :         diamagnetism_enabled = page.system_config.diamagnetism.isChecked()
      88             : 
      89           1 :         diagonalize_kwargs = {}
      90           1 :         if page.calculation_config.fast_mode.isChecked():
      91           1 :             diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
      92           1 :             diagonalize_kwargs["float_type"] = "float32"
      93             : 
      94           1 :         diagonalize_relative_energy_range = None
      95           1 :         if page.calculation_config.energy_range.isChecked():
      96           1 :             diagonalize_relative_energy_range = page.calculation_config.energy_range.values()
      97             : 
      98           1 :         return cls(
      99             :             species,
     100             :             quantum_numbers,
     101             :             quantum_number_restrictions,
     102             :             ranges,
     103             :             diamagnetism_enabled,
     104             :             diagonalize_kwargs,
     105             :             diagonalize_relative_energy_range,
     106             :         )
     107             : 
     108           1 :     @property
     109           1 :     def is_real(self) -> bool:
     110             :         """Check if the parameters are real."""
     111           1 :         return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))
     112             : 
     113           1 :     @property
     114           1 :     def steps(self) -> int:
     115             :         """Return the number of steps."""
     116           1 :         return len(next(iter(self.ranges.values())))
     117             : 
     118           1 :     @property
     119           1 :     def n_atoms(self) -> int:
     120             :         """Return the number of atoms."""
     121           1 :         return len(self.species)
     122             : 
     123           1 :     def get_efield(self, step: int) -> list[float]:
     124             :         """Return the electric field for the given step."""
     125           1 :         efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
     126           1 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]
     127             : 
     128           1 :     def get_bfield(self, step: int) -> list[float]:
     129             :         """Return the magnetic field for the given step."""
     130           1 :         bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
     131           1 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]
     132             : 
     133           1 :     def get_species(self, atom: int | None = None) -> str:
     134             :         """Return the species for the given ket."""
     135           1 :         return self.species[self._check_atom(atom)]
     136             : 
     137           1 :     def get_quantum_numbers(self, atom: int | None = None) -> QuantumNumbers:
     138             :         """Return the quantum numbers for the given ket."""
     139           1 :         return self.quantum_numbers[self._check_atom(atom)]
     140             : 
     141           1 :     def get_ket_atom(self, atom: int | None = None) -> pi.KetAtom:
     142             :         """Return the ket atom for the given atom index."""
     143           0 :         return pi.KetAtom(self.get_species(atom), **self.get_quantum_numbers(atom))
     144             : 
     145           1 :     def get_quantum_number_restrictions(self, atom: int | None = None) -> QuantumNumberRestrictions:
     146             :         """Return the quantum number restrictions."""
     147           1 :         return self.quantum_number_restrictions[self._check_atom(atom)]
     148             : 
     149           1 :     def _check_atom(self, atom: int | None = None) -> int:
     150             :         """Check if the atom is valid."""
     151           1 :         if atom is not None:
     152           1 :             return atom
     153           1 :         if self.n_atoms == 1:
     154           1 :             return 0
     155           0 :         raise ValueError("Atom index is required for multiple atoms")
     156             : 
     157           1 :     def get_diagonalize_energy_range_kwargs(self, energy_of_interest: float) -> dict[str, Any]:
     158             :         """Return the kwargs for the diagonalization energy range."""
     159           1 :         if self.diagonalize_relative_energy_range is None:
     160           1 :             return {}
     161           1 :         kwargs: dict[str, Any] = {"energy_range_unit": "GHz"}
     162           1 :         kwargs["energy_range"] = (
     163             :             energy_of_interest + self.diagonalize_relative_energy_range[0],
     164             :             energy_of_interest + self.diagonalize_relative_energy_range[1],
     165             :         )
     166           1 :         return kwargs
     167             : 
     168           1 :     def get_x_values(self) -> list[float]:
     169             :         """Return the x values for the plot."""
     170           1 :         max_key = self._get_ranges_max_diff_key()
     171           1 :         return self.ranges[max_key]
     172             : 
     173           1 :     def get_x_label(self) -> str:
     174             :         """Return the x values for the plot."""
     175           1 :         max_key = self._get_ranges_max_diff_key()
     176           1 :         x_label = f"{max_key} ({UnitFromRangeKey[max_key]})"
     177             : 
     178           1 :         non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
     179           1 :         if non_constant_keys:
     180           0 :             x_label += f"  ({', '.join(non_constant_keys)} did also change)"
     181             : 
     182           1 :         return x_label
     183             : 
     184           1 :     def _get_ranges_max_diff_key(self) -> RangesKeys:
     185             :         """Return the key with the maximum difference in the ranges."""
     186           1 :         range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
     187           1 :         return max(range_diffs, key=lambda x: range_diffs.get(x, -1))
     188             : 
     189           1 :     def to_replacement_dict(self) -> dict[str, str]:
     190             :         """Return a dictionary with the parameters for replacement."""
     191           1 :         max_key = self._get_ranges_max_diff_key()
     192           1 :         replacements: dict[str, str] = {
     193             :             "$PI_DTYPE": "real" if self.is_real else "complex",
     194             :             "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
     195             :             "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
     196             :             "$DIAMAGNETISM_ENABLED": str(self.diamagnetism_enabled),
     197             :         }
     198             : 
     199           1 :         for atom in range(self.n_atoms):
     200           1 :             replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
     201           1 :             replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
     202           1 :             replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
     203             :                 self.get_quantum_number_restrictions(atom)
     204             :             )
     205             : 
     206           1 :         replacements["$STEPS"] = str(self.steps)
     207           1 :         for key, values in self.ranges.items():
     208           1 :             replacements[f"${key.upper()}_MIN"] = str(values[0])
     209           1 :             replacements[f"${key.upper()}_MAX"] = str(values[-1])
     210           1 :             if values[0] == values[-1]:
     211           1 :                 replacements[f"${key.upper()}_VALUE"] = str(values[0])
     212             : 
     213           1 :         replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)
     214             : 
     215           1 :         if self.diagonalize_relative_energy_range is None:
     216           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""
     217           1 :         elif self.n_atoms == 1:
     218           1 :             r_energy = self.diagonalize_relative_energy_range
     219           1 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
     220             :                 f', energy_range=(ket_energy + {r_energy[0]}, ket_energy - {-r_energy[1]}), energy_range_unit="GHz"'
     221             :             )
     222           1 :         elif self.n_atoms == 2:
     223           1 :             r_energy = self.diagonalize_relative_energy_range
     224           1 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
     225             :                 f', energy_range=(pair_energy + {r_energy[0]}, pair_energy - {-r_energy[1]}), energy_range_unit="GHz"'
     226             :             )
     227             :         else:
     228           0 :             raise RuntimeError("Energy range kwargs not implemented for more than two atoms")
     229             : 
     230           1 :         return replacements
     231             : 
     232             : 
     233           1 : @dataclass
     234           1 : class Results(ABC):
     235           1 :     energies: list[NDArray]
     236           1 :     energy_offset: float
     237           1 :     ket_overlaps: list[NDArray]
     238           1 :     systems: list[SystemBase[Any]]
     239             : 
     240           1 :     @classmethod
     241           1 :     def from_calculate(
     242             :         cls,
     243             :         parameters: Parameters[Any],
     244             :         system_list: Sequence[SystemBase[Any]],
     245             :         ket: pi.KetAtom | tuple[pi.KetAtom, ...],
     246             :         energy_offset: float,
     247             :     ) -> Self:
     248             :         """Create Results object from ket, basis, and diagonalized systems."""
     249           1 :         energies = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
     250             : 
     251           1 :         ket_overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]
     252             : 
     253           1 :         return cls(energies, energy_offset, ket_overlaps, list(system_list))
     254             : 
     255             : 
     256           1 : def as_string(value: str, *, raw_string: bool = False) -> str:
     257           1 :     string = '"' + value + '"'
     258           1 :     if raw_string:
     259           1 :         string = "r" + string
     260           1 :     return string
     261             : 
     262             : 
     263           1 : def dict_to_repl(d: Mapping[str, Any]) -> str:
     264             :     """Convert a dictionary to a string for replacement."""
     265           1 :     if not d:
     266           0 :         return ""
     267           1 :     repl = ""
     268           1 :     for k, v in d.items():
     269           1 :         if isinstance(v, str):
     270           1 :             repl += f", {k}={as_string(v)}"
     271             :         else:
     272           1 :             repl += f", {k}={v}"
     273           1 :     return repl

Generated by: LCOV version 1.16