LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_base.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 108 138 78.3 %
Date: 2025-08-29 20:47:05 Functions: 13 40 32.5 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : from abc import ABC
       6           1 : from collections.abc import Mapping
       7           1 : from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
       8             : 
       9           1 : import numpy as np
      10           1 : from attr import dataclass
      11             : 
      12           1 : from pairinteraction import (
      13             :     _wrapped,
      14             :     complex as pi_complex,
      15             :     real as pi_real,
      16             : )
      17           1 : from pairinteraction_gui.config.system_config import RangesKeys
      18             : 
      19             : if TYPE_CHECKING:
      20             :     from typing_extensions import Self
      21             : 
      22             :     from pairinteraction.units import NDArray
      23             :     from pairinteraction_gui.config.basis_config import QuantumNumberRestrictions
      24             :     from pairinteraction_gui.config.ket_config import QuantumNumbers
      25             :     from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
      26             : 
      27           1 : logger = logging.getLogger(__name__)
      28             : 
      29           1 : UnitFromRangeKey: dict[RangesKeys, str] = {
      30             :     "Ex": "V/cm",
      31             :     "Ey": "V/cm",
      32             :     "Ez": "V/cm",
      33             :     "Bx": "Gauss",
      34             :     "By": "Gauss",
      35             :     "Bz": "Gauss",
      36             :     "Distance": r"$\mu$m",
      37             :     "Angle": r"$^\circ$",
      38             : }
      39             : 
      40           1 : VariableNameFromRangeKey: dict[RangesKeys, str] = {
      41             :     "Ex": "efield_x",
      42             :     "Ey": "efield_y",
      43             :     "Ez": "efield_z",
      44             :     "Bx": "bfield_x",
      45             :     "By": "bfield_y",
      46             :     "Bz": "bfield_z",
      47             :     "Distance": "distance",
      48             :     "Angle": "angle",
      49             : }
      50             : 
      51           1 : PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
      52             : 
      53             : 
      54           1 : @dataclass
      55           1 : class Parameters(ABC, Generic[PageType]):
      56           1 :     species: tuple[str, ...]
      57           1 :     quantum_numbers: tuple["QuantumNumbers", ...]
      58           1 :     quantum_number_restrictions: tuple["QuantumNumberRestrictions", ...]
      59           1 :     ranges: dict[RangesKeys, list[float]]
      60           1 :     diamagnetism_enabled: bool
      61           1 :     diagonalize_kwargs: dict[str, str]
      62           1 :     diagonalize_relative_energy_range: Union[tuple[float, float], None]
      63           1 :     number_state_labels: int
      64             : 
      65           1 :     def __post_init__(self) -> None:
      66             :         """Post-initialization processing."""
      67             :         # Check if all ranges have the same number of steps
      68           0 :         if not all(len(v) == self.steps for v in self.ranges.values()):
      69           0 :             raise ValueError("All ranges must have the same number of steps")
      70             : 
      71             :         # Check if all tuples have the same length
      72           0 :         if not all(
      73             :             len(tup) == self.n_atoms for tup in [self.species, self.quantum_numbers, self.quantum_number_restrictions]
      74             :         ):
      75           0 :             raise ValueError("All tuples must have the same length as the number of atoms")
      76             : 
      77           1 :     @classmethod
      78           1 :     def from_page(cls, page: PageType) -> "Self":
      79             :         """Create Parameters object from page."""
      80           1 :         n_atoms = page.ket_config.n_atoms
      81             : 
      82           1 :         species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
      83           1 :         quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))
      84             : 
      85           1 :         quantum_number_restrictions = tuple(
      86             :             page.basis_config.get_quantum_number_restrictions(atom) for atom in range(n_atoms)
      87             :         )
      88             : 
      89           1 :         ranges = page.system_config.get_ranges_dict()
      90           1 :         diamagnetism_enabled = page.system_config.diamagnetism.isChecked()
      91             : 
      92           1 :         diagonalize_kwargs = {}
      93           1 :         if page.calculation_config.fast_mode.isChecked():
      94           1 :             diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
      95           1 :             diagonalize_kwargs["float_type"] = "float32"
      96             : 
      97           1 :         diagonalize_relative_energy_range = None
      98           1 :         if page.calculation_config.energy_range.isChecked():
      99           0 :             diagonalize_relative_energy_range = page.calculation_config.energy_range.values()
     100             : 
     101           1 :         return cls(
     102             :             species,
     103             :             quantum_numbers,
     104             :             quantum_number_restrictions,
     105             :             ranges,
     106             :             diamagnetism_enabled,
     107             :             diagonalize_kwargs,
     108             :             diagonalize_relative_energy_range,
     109             :             page.calculation_config.number_state_labels.value(default=0),
     110             :         )
     111             : 
     112           1 :     @property
     113           1 :     def is_real(self) -> bool:
     114             :         """Check if the parameters are real."""
     115           1 :         return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))
     116             : 
     117           1 :     @property
     118           1 :     def steps(self) -> int:
     119             :         """Return the number of steps."""
     120           1 :         return len(next(iter(self.ranges.values())))
     121             : 
     122           1 :     @property
     123           1 :     def n_atoms(self) -> int:
     124             :         """Return the number of atoms."""
     125           1 :         return len(self.species)
     126             : 
     127           1 :     def get_efield(self, step: int) -> list[float]:
     128             :         """Return the electric field for the given step."""
     129           0 :         efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
     130           0 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]
     131             : 
     132           1 :     def get_bfield(self, step: int) -> list[float]:
     133             :         """Return the magnetic field for the given step."""
     134           0 :         bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
     135           0 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]
     136             : 
     137           1 :     def get_species(self, atom: Optional[int] = None) -> str:
     138             :         """Return the species for the given ket."""
     139           1 :         return self.species[self._check_atom(atom)]
     140             : 
     141           1 :     def get_quantum_numbers(self, atom: Optional[int] = None) -> "QuantumNumbers":
     142             :         """Return the quantum numbers for the given ket."""
     143           1 :         return self.quantum_numbers[self._check_atom(atom)]
     144             : 
     145           1 :     def get_ket_atom(self, atom: Optional[int] = None) -> _wrapped.KetAtom:
     146             :         """Return the ket atom for the given atom index."""
     147           0 :         return _wrapped.KetAtom(self.get_species(atom), **self.get_quantum_numbers(atom))
     148             : 
     149           1 :     def get_quantum_number_restrictions(self, atom: Optional[int] = None) -> "QuantumNumberRestrictions":
     150             :         """Return the quantum number restrictions."""
     151           1 :         return self.quantum_number_restrictions[self._check_atom(atom)]
     152             : 
     153           1 :     def _check_atom(self, atom: Optional[int] = None) -> int:
     154             :         """Check if the atom is valid."""
     155           1 :         if atom is not None:
     156           1 :             return atom
     157           0 :         if self.n_atoms == 1:
     158           0 :             return 0
     159           0 :         raise ValueError("Atom index is required for multiple atoms")
     160             : 
     161           1 :     def get_diagonalize_energy_range_kwargs(self, energy_of_interest: float) -> dict[str, Any]:
     162             :         """Return the kwargs for the diagonalization energy range."""
     163           0 :         if self.diagonalize_relative_energy_range is None:
     164           0 :             return {}
     165           0 :         kwargs: dict[str, Any] = {"energy_range_unit": "GHz"}
     166           0 :         kwargs["energy_range"] = (
     167             :             energy_of_interest + self.diagonalize_relative_energy_range[0],
     168             :             energy_of_interest + self.diagonalize_relative_energy_range[1],
     169             :         )
     170           0 :         return kwargs
     171             : 
     172           1 :     def get_x_values(self) -> list[float]:
     173             :         """Return the x values for the plot."""
     174           0 :         max_key = self._get_ranges_max_diff_key()
     175           0 :         return self.ranges[max_key]
     176             : 
     177           1 :     def get_x_label(self) -> str:
     178             :         """Return the x values for the plot."""
     179           1 :         max_key = self._get_ranges_max_diff_key()
     180           1 :         x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"
     181             : 
     182           1 :         non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
     183           1 :         if non_constant_keys:
     184           0 :             x_label += f"  ({', '.join(non_constant_keys)} did also change)"
     185             : 
     186           1 :         return x_label
     187             : 
     188           1 :     def _get_ranges_max_diff_key(self) -> RangesKeys:
     189             :         """Return the key with the maximum difference in the ranges."""
     190           1 :         range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
     191           1 :         return max(range_diffs, key=lambda x: range_diffs.get(x, -1))
     192             : 
     193           1 :     def to_replacement_dict(self) -> dict[str, str]:
     194             :         """Return a dictionary with the parameters for replacement."""
     195           1 :         max_key = self._get_ranges_max_diff_key()
     196           1 :         replacements: dict[str, str] = {
     197             :             "$PI_DTYPE": "real" if self.is_real else "complex",
     198             :             "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
     199             :             "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
     200             :             "$DIAMAGNETISM_ENABLED": str(self.diamagnetism_enabled),
     201             :         }
     202             : 
     203           1 :         for atom in range(self.n_atoms):
     204           1 :             replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
     205           1 :             replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
     206           1 :             replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
     207             :                 self.get_quantum_number_restrictions(atom)
     208             :             )
     209             : 
     210           1 :         replacements["$STEPS"] = str(self.steps)
     211           1 :         for key, values in self.ranges.items():
     212           1 :             replacements[f"${key.upper()}_MIN"] = str(values[0])
     213           1 :             replacements[f"${key.upper()}_MAX"] = str(values[-1])
     214           1 :             if values[0] == values[-1]:
     215           1 :                 replacements[f"${key.upper()}_VALUE"] = str(values[0])
     216             : 
     217           1 :         replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)
     218             : 
     219           1 :         if self.diagonalize_relative_energy_range is not None:
     220           0 :             r_energy = self.diagonalize_relative_energy_range
     221           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
     222             :                 f', energy_range=(ket_energy + {r_energy[0]}, ket_energy - {-r_energy[1]}), energy_range_unit="GHz"'
     223             :             )
     224             :         else:
     225           1 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""
     226             : 
     227           1 :         return replacements
     228             : 
     229             : 
     230           1 : @dataclass
     231           1 : class Results(ABC):
     232           1 :     energies: list["NDArray"]
     233           1 :     energy_offset: float
     234           1 :     ket_overlaps: list["NDArray"]
     235           1 :     state_labels: dict[int, list[str]]
     236             : 
     237           1 :     @classmethod
     238           1 :     def from_calculate(
     239             :         cls,
     240             :         parameters: Parameters[Any],
     241             :         system_list: Union[
     242             :             list[pi_real.SystemPair], list[pi_complex.SystemPair], list[pi_real.SystemAtom], list[pi_complex.SystemAtom]
     243             :         ],
     244             :         ket: Union[_wrapped.KetAtom, tuple[_wrapped.KetAtom, ...]],
     245             :         energy_offset: float,
     246             :     ) -> "Self":
     247             :         """Create Results object from ket, basis, and diagonalized systems."""
     248           0 :         energies = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
     249           0 :         ket_overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]  # type: ignore [arg-type]
     250             : 
     251           0 :         steps_with_labels = [int(i) for i in np.linspace(0, parameters.steps - 1, parameters.number_state_labels)]
     252           0 :         states_dict = {i: system_list[i].get_eigenbasis().states for i in steps_with_labels}
     253           0 :         state_labels = {i: [s.get_label() for s in states] for i, states in states_dict.items()}
     254             : 
     255           0 :         return cls(energies, energy_offset, ket_overlaps, state_labels)
     256             : 
     257             : 
     258           1 : def as_string(value: str, *, raw_string: bool = False) -> str:
     259           1 :     string = '"' + value + '"'
     260           1 :     if raw_string:
     261           1 :         string = "r" + string
     262           1 :     return string
     263             : 
     264             : 
     265           1 : def dict_to_repl(d: Mapping[str, Any]) -> str:
     266             :     """Convert a dictionary to a string for replacement."""
     267           1 :     if not d:
     268           0 :         return ""
     269           1 :     repl = ""
     270           1 :     for k, v in d.items():
     271           1 :         if isinstance(v, str):
     272           1 :             repl += f", {k}={as_string(v)}"
     273             :         else:
     274           1 :             repl += f", {k}={v}"
     275           1 :     return repl

Generated by: LCOV version 1.16