LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_base.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 107 137 78.1 %
Date: 2025-09-29 10:28:29 Functions: 13 40 32.5 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : from abc import ABC
       7           1 : from typing import TYPE_CHECKING, Any, Generic, TypeVar
       8             : 
       9           1 : import numpy as np
      10           1 : from attr import dataclass
      11             : 
      12           1 : import pairinteraction as pi
      13             : 
      14             : if TYPE_CHECKING:
      15             :     from collections.abc import Mapping
      16             : 
      17             :     from typing_extensions import Self
      18             : 
      19             :     from pairinteraction.system import SystemAtom, SystemAtomReal, SystemPair, SystemPairReal
      20             :     from pairinteraction.units import NDArray
      21             :     from pairinteraction_gui.config.basis_config import QuantumNumberRestrictions
      22             :     from pairinteraction_gui.config.ket_config import QuantumNumbers
      23             :     from pairinteraction_gui.config.system_config import RangesKeys
      24             :     from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
      25             : 
logger = logging.getLogger(__name__)

# Unit of each range key, used to build the plot's x-axis label in ``Parameters.get_x_label``.
# Some entries are LaTeX-formatted (raw strings) so they render in matplotlib-style labels.
UnitFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "V/cm",
    "Ey": "V/cm",
    "Ez": "V/cm",
    "Bx": "Gauss",
    "By": "Gauss",
    "Bz": "Gauss",
    "Distance": r"$\mu$m",
    "Angle": r"$^\circ$",
}

# Variable name that replaces the ``$X_VARIABLE_NAME`` placeholder in
# ``Parameters.to_replacement_dict`` for the range key chosen as the x-axis.
VariableNameFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "efield_x",
    "Ey": "efield_y",
    "Ez": "efield_z",
    "Bx": "bfield_x",
    "By": "bfield_y",
    "Bz": "bfield_z",
    "Distance": "distance",
    "Angle": "angle",
}

# Type of GUI page a ``Parameters`` object is created from (see ``Parameters.from_page``).
# The constraints are strings because the page classes are only imported under TYPE_CHECKING.
PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
      51             : 
      52             : 
      53           1 : @dataclass
      54           1 : class Parameters(ABC, Generic[PageType]):
      55           1 :     species: tuple[str, ...]
      56           1 :     quantum_numbers: tuple[QuantumNumbers, ...]
      57           1 :     quantum_number_restrictions: tuple[QuantumNumberRestrictions, ...]
      58           1 :     ranges: dict[RangesKeys, list[float]]
      59           1 :     diamagnetism_enabled: bool
      60           1 :     diagonalize_kwargs: dict[str, str]
      61           1 :     diagonalize_relative_energy_range: tuple[float, float] | None
      62           1 :     number_state_labels: int
      63             : 
      64           1 :     def __post_init__(self) -> None:
      65             :         """Post-initialization processing."""
      66             :         # Check if all ranges have the same number of steps
      67           0 :         if not all(len(v) == self.steps for v in self.ranges.values()):
      68           0 :             raise ValueError("All ranges must have the same number of steps")
      69             : 
      70             :         # Check if all tuples have the same length
      71           0 :         if not all(
      72             :             len(tup) == self.n_atoms for tup in [self.species, self.quantum_numbers, self.quantum_number_restrictions]
      73             :         ):
      74           0 :             raise ValueError("All tuples must have the same length as the number of atoms")
      75             : 
      76           1 :     @classmethod
      77           1 :     def from_page(cls, page: PageType) -> Self:
      78             :         """Create Parameters object from page."""
      79           1 :         n_atoms = page.ket_config.n_atoms
      80             : 
      81           1 :         species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
      82           1 :         quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))
      83             : 
      84           1 :         quantum_number_restrictions = tuple(
      85             :             page.basis_config.get_quantum_number_restrictions(atom) for atom in range(n_atoms)
      86             :         )
      87             : 
      88           1 :         ranges = page.system_config.get_ranges_dict()
      89           1 :         diamagnetism_enabled = page.system_config.diamagnetism.isChecked()
      90             : 
      91           1 :         diagonalize_kwargs = {}
      92           1 :         if page.calculation_config.fast_mode.isChecked():
      93           1 :             diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
      94           1 :             diagonalize_kwargs["float_type"] = "float32"
      95             : 
      96           1 :         diagonalize_relative_energy_range = None
      97           1 :         if page.calculation_config.energy_range.isChecked():
      98           0 :             diagonalize_relative_energy_range = page.calculation_config.energy_range.values()
      99             : 
     100           1 :         return cls(
     101             :             species,
     102             :             quantum_numbers,
     103             :             quantum_number_restrictions,
     104             :             ranges,
     105             :             diamagnetism_enabled,
     106             :             diagonalize_kwargs,
     107             :             diagonalize_relative_energy_range,
     108             :             page.calculation_config.number_state_labels.value(default=0),
     109             :         )
     110             : 
     111           1 :     @property
     112           1 :     def is_real(self) -> bool:
     113             :         """Check if the parameters are real."""
     114           1 :         return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))
     115             : 
     116           1 :     @property
     117           1 :     def steps(self) -> int:
     118             :         """Return the number of steps."""
     119           1 :         return len(next(iter(self.ranges.values())))
     120             : 
     121           1 :     @property
     122           1 :     def n_atoms(self) -> int:
     123             :         """Return the number of atoms."""
     124           1 :         return len(self.species)
     125             : 
     126           1 :     def get_efield(self, step: int) -> list[float]:
     127             :         """Return the electric field for the given step."""
     128           0 :         efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
     129           0 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]
     130             : 
     131           1 :     def get_bfield(self, step: int) -> list[float]:
     132             :         """Return the magnetic field for the given step."""
     133           0 :         bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
     134           0 :         return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]
     135             : 
     136           1 :     def get_species(self, atom: int | None = None) -> str:
     137             :         """Return the species for the given ket."""
     138           1 :         return self.species[self._check_atom(atom)]
     139             : 
     140           1 :     def get_quantum_numbers(self, atom: int | None = None) -> QuantumNumbers:
     141             :         """Return the quantum numbers for the given ket."""
     142           1 :         return self.quantum_numbers[self._check_atom(atom)]
     143             : 
     144           1 :     def get_ket_atom(self, atom: int | None = None) -> pi.KetAtom:
     145             :         """Return the ket atom for the given atom index."""
     146           0 :         return pi.KetAtom(self.get_species(atom), **self.get_quantum_numbers(atom))
     147             : 
     148           1 :     def get_quantum_number_restrictions(self, atom: int | None = None) -> QuantumNumberRestrictions:
     149             :         """Return the quantum number restrictions."""
     150           1 :         return self.quantum_number_restrictions[self._check_atom(atom)]
     151             : 
     152           1 :     def _check_atom(self, atom: int | None = None) -> int:
     153             :         """Check if the atom is valid."""
     154           1 :         if atom is not None:
     155           1 :             return atom
     156           0 :         if self.n_atoms == 1:
     157           0 :             return 0
     158           0 :         raise ValueError("Atom index is required for multiple atoms")
     159             : 
     160           1 :     def get_diagonalize_energy_range_kwargs(self, energy_of_interest: float) -> dict[str, Any]:
     161             :         """Return the kwargs for the diagonalization energy range."""
     162           0 :         if self.diagonalize_relative_energy_range is None:
     163           0 :             return {}
     164           0 :         kwargs: dict[str, Any] = {"energy_range_unit": "GHz"}
     165           0 :         kwargs["energy_range"] = (
     166             :             energy_of_interest + self.diagonalize_relative_energy_range[0],
     167             :             energy_of_interest + self.diagonalize_relative_energy_range[1],
     168             :         )
     169           0 :         return kwargs
     170             : 
     171           1 :     def get_x_values(self) -> list[float]:
     172             :         """Return the x values for the plot."""
     173           0 :         max_key = self._get_ranges_max_diff_key()
     174           0 :         return self.ranges[max_key]
     175             : 
     176           1 :     def get_x_label(self) -> str:
     177             :         """Return the x values for the plot."""
     178           1 :         max_key = self._get_ranges_max_diff_key()
     179           1 :         x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"
     180             : 
     181           1 :         non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
     182           1 :         if non_constant_keys:
     183           0 :             x_label += f"  ({', '.join(non_constant_keys)} did also change)"
     184             : 
     185           1 :         return x_label
     186             : 
     187           1 :     def _get_ranges_max_diff_key(self) -> RangesKeys:
     188             :         """Return the key with the maximum difference in the ranges."""
     189           1 :         range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
     190           1 :         return max(range_diffs, key=lambda x: range_diffs.get(x, -1))
     191             : 
     192           1 :     def to_replacement_dict(self) -> dict[str, str]:
     193             :         """Return a dictionary with the parameters for replacement."""
     194           1 :         max_key = self._get_ranges_max_diff_key()
     195           1 :         replacements: dict[str, str] = {
     196             :             "$PI_DTYPE": "real" if self.is_real else "complex",
     197             :             "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
     198             :             "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
     199             :             "$DIAMAGNETISM_ENABLED": str(self.diamagnetism_enabled),
     200             :         }
     201             : 
     202           1 :         for atom in range(self.n_atoms):
     203           1 :             replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
     204           1 :             replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
     205           1 :             replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
     206             :                 self.get_quantum_number_restrictions(atom)
     207             :             )
     208             : 
     209           1 :         replacements["$STEPS"] = str(self.steps)
     210           1 :         for key, values in self.ranges.items():
     211           1 :             replacements[f"${key.upper()}_MIN"] = str(values[0])
     212           1 :             replacements[f"${key.upper()}_MAX"] = str(values[-1])
     213           1 :             if values[0] == values[-1]:
     214           1 :                 replacements[f"${key.upper()}_VALUE"] = str(values[0])
     215             : 
     216           1 :         replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)
     217             : 
     218           1 :         if self.diagonalize_relative_energy_range is not None:
     219           0 :             r_energy = self.diagonalize_relative_energy_range
     220           0 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
     221             :                 f', energy_range=(ket_energy + {r_energy[0]}, ket_energy - {-r_energy[1]}), energy_range_unit="GHz"'
     222             :             )
     223             :         else:
     224           1 :             replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""
     225             : 
     226           1 :         return replacements
     227             : 
     228             : 
     229           1 : @dataclass
     230           1 : class Results(ABC):
     231           1 :     energies: list[NDArray]
     232           1 :     energy_offset: float
     233           1 :     ket_overlaps: list[NDArray]
     234           1 :     state_labels: dict[int, list[str]]
     235             : 
     236           1 :     @classmethod
     237           1 :     def from_calculate(
     238             :         cls,
     239             :         parameters: Parameters[Any],
     240             :         system_list: list[SystemPairReal] | list[SystemPair] | list[SystemAtomReal] | list[SystemAtom],
     241             :         ket: pi.KetAtom | tuple[pi.KetAtom, ...],
     242             :         energy_offset: float,
     243             :     ) -> Self:
     244             :         """Create Results object from ket, basis, and diagonalized systems."""
     245           0 :         energies = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
     246           0 :         ket_overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]  # type: ignore [arg-type]
     247             : 
     248           0 :         steps_with_labels = [int(i) for i in np.linspace(0, parameters.steps - 1, parameters.number_state_labels)]
     249           0 :         states_dict = {i: system_list[i].get_eigenbasis().states for i in steps_with_labels}
     250           0 :         state_labels = {i: [s.get_label() for s in states] for i, states in states_dict.items()}
     251             : 
     252           0 :         return cls(energies, energy_offset, ket_overlaps, state_labels)
     253             : 
     254             : 
     255           1 : def as_string(value: str, *, raw_string: bool = False) -> str:
     256           1 :     string = '"' + value + '"'
     257           1 :     if raw_string:
     258           1 :         string = "r" + string
     259           1 :     return string
     260             : 
     261             : 
     262           1 : def dict_to_repl(d: Mapping[str, Any]) -> str:
     263             :     """Convert a dictionary to a string for replacement."""
     264           1 :     if not d:
     265           0 :         return ""
     266           1 :     repl = ""
     267           1 :     for k, v in d.items():
     268           1 :         if isinstance(v, str):
     269           1 :             repl += f", {k}={as_string(v)}"
     270             :         else:
     271           1 :             repl += f", {k}={v}"
     272           1 :     return repl

Generated by: LCOV version 1.16