LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_two_atoms.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 58 79 73.4 %
Date: 2025-06-06 09:09:03 Functions: 2 8 25.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import logging
       5           1 : from typing import TYPE_CHECKING, Optional, Union
       6             : 
       7           1 : import numpy as np
       8           1 : from attr import dataclass
       9             : 
      10           1 : from pairinteraction import (
      11             :     complex as pi_complex,
      12             :     real as pi_real,
      13             : )
      14           1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      15           1 : from pairinteraction_gui.worker import run_in_other_process
      16             : 
      17             : if TYPE_CHECKING:
      18             :     from typing_extensions import Self
      19             : 
      20             :     from pairinteraction_gui.page import TwoAtomsPage
      21             : 
      22           1 : logger = logging.getLogger(__name__)
      23             : 
      24             : 
      25           1 : @dataclass
      26           1 : class ParametersTwoAtoms(Parameters["TwoAtomsPage"]):
      27             :     """Parameters for the two atoms calculation."""
      28             : 
      29           1 :     pair_delta_energy: float = np.inf
      30           1 :     pair_m_range: Optional[tuple[float, float]] = None
      31           1 :     order: int = 3
      32             : 
      33           1 :     @classmethod
      34           1 :     def from_page(cls, page: "TwoAtomsPage") -> "Self":
      35           1 :         obj = super().from_page(page)
      36           1 :         obj.pair_delta_energy = page.basis_config.pair_delta_energy.value(np.inf)
      37           1 :         obj.pair_m_range = (
      38             :             page.basis_config.pair_m_range.values() if page.basis_config.pair_m_range.isChecked() else None
      39             :         )
      40           1 :         obj.order = page.system_config.order.value()
      41           1 :         return obj
      42             : 
      43           1 :     def to_replacement_dict(self) -> dict[str, str]:
      44           0 :         replacements = super().to_replacement_dict()
      45           0 :         replacements["$MULTIPOLE_ORDER"] = str(self.order)
      46           0 :         replacements["$PAIR_DELTA_ENERGY"] = (
      47             :             "np.inf" if np.isinf(self.pair_delta_energy) else str(self.pair_delta_energy)
      48             :         )
      49           0 :         replacements["$PAIR_M_RANGE"] = str(self.pair_m_range)
      50           0 :         return replacements
      51             : 
      52             : 
      53           1 : @dataclass
      54           1 : class ResultsTwoAtoms(Results):
      55           1 :     basis_0_label: Optional[str] = None
      56             : 
      57             : 
      58           1 : @run_in_other_process
      59           1 : def calculate_two_atoms(parameters: ParametersTwoAtoms) -> ResultsTwoAtoms:
      60             :     """Calculate the energy plot for two atoms.
      61             : 
      62             :     This means, given a Parameters object, do the pairinteraction calculations and return an ResultsTwoAtoms object.
      63             :     """
      64           0 :     return _calculate_two_atoms(parameters)
      65             : 
      66             : 
      67           1 : def _calculate_two_atoms(parameters: ParametersTwoAtoms) -> ResultsTwoAtoms:
      68             :     """Make the unwrapped function available for testing."""
      69           1 :     pi = pi_real if parameters.is_real else pi_complex
      70           1 :     n_atoms = 2
      71             : 
      72           1 :     kets = tuple(pi.KetAtom(parameters.get_species(i), **parameters.get_quantum_numbers(i)) for i in range(n_atoms))
      73           1 :     bases = tuple(
      74             :         pi.BasisAtom(parameters.get_species(i), **parameters.get_quantum_number_restrictions(i)) for i in range(n_atoms)
      75             :     )
      76             : 
      77           1 :     fields = {k: v for k, v in parameters.ranges.items() if k in ["Ex", "Ey", "Ez", "Bx", "By", "Bz"]}
      78             : 
      79             :     basis_pair_list: Union[list[pi_real.BasisPair], list[pi_complex.BasisPair]]
      80           1 :     if all(v[0] == v[-1] for v in fields.values()):
      81             :         # If all fields are constant, we can only have to diagonalize one SystemAtom per atom
      82             :         # and can construct one BasisPair, which we can use for all steps
      83           1 :         systems = tuple(
      84             :             pi.SystemAtom(bases[i])
      85             :             .set_electric_field(parameters.get_efield(0), unit="V/cm")
      86             :             .set_magnetic_field(parameters.get_bfield(0), unit="G")
      87             :             for i in range(n_atoms)
      88             :         )
      89           1 :         logger.debug("Diagonalizing SystemAtoms...")
      90           1 :         pi.diagonalize(systems, **parameters.diagonalize_kwargs)
      91           1 :         logger.debug("Done diagonalizing SystemAtoms.")
      92           1 :         ket_pair_energy_0 = sum(systems[i].get_corresponding_energy(kets[i], "GHz") for i in range(n_atoms))
      93           1 :         delta_energy = parameters.pair_delta_energy
      94           1 :         basis_pair = pi.BasisPair(
      95             :             systems,
      96             :             energy=(ket_pair_energy_0 - delta_energy, ket_pair_energy_0 + delta_energy),
      97             :             energy_unit="GHz",
      98             :             m=parameters.pair_m_range,
      99             :         )
     100             :         # not very elegant, but works (note that importantly this does not copy the basis_pair objects)
     101           1 :         basis_pair_list = parameters.steps * [basis_pair]
     102             :     else:
     103             :         # Otherwise, we have to diagonalize one SystemAtom per atom and per step
     104             :         # and construct one BasisPair per step
     105           0 :         systems_list = []
     106           0 :         for step in range(parameters.steps):
     107           0 :             systems = tuple(
     108             :                 pi.SystemAtom(bases[i])
     109             :                 .set_electric_field(parameters.get_efield(step), unit="V/cm")
     110             :                 .set_magnetic_field(parameters.get_bfield(step), unit="G")
     111             :                 for i in range(n_atoms)
     112             :             )
     113           0 :             systems_list.append(systems)
     114           0 :         systems_flattened = [system for systems in systems_list for system in systems]
     115           0 :         logger.debug("Diagonalizing SystemAtoms...")
     116           0 :         pi.diagonalize(systems_flattened, **parameters.diagonalize_kwargs)
     117           0 :         logger.debug("Done diagonalizing SystemAtoms.")
     118           0 :         delta_energy = parameters.pair_delta_energy
     119           0 :         basis_pair_list = []
     120           0 :         for step in range(parameters.steps):
     121           0 :             ket_pair_energy = sum(
     122             :                 systems_list[step][i].get_corresponding_energy(kets[i], "GHz") for i in range(n_atoms)
     123             :             )
     124           0 :             basis_pair = pi.BasisPair(
     125             :                 systems_list[step],
     126             :                 energy=(ket_pair_energy - delta_energy, ket_pair_energy + delta_energy),
     127             :                 energy_unit="GHz",
     128             :                 m=parameters.pair_m_range,
     129             :             )
     130           0 :             basis_pair_list.append(basis_pair)
     131           0 :         ket_pair_energy_0 = sum(systems_list[-1][i].get_corresponding_energy(kets[i], "GHz") for i in range(n_atoms))
     132             : 
     133           1 :     system_pair_list: Union[list[pi_real.SystemPair], list[pi_complex.SystemPair]] = []
     134           1 :     for step in range(parameters.steps):
     135           1 :         system = pi.SystemPair(basis_pair_list[step])
     136           1 :         system.set_interaction_order(parameters.order)
     137           1 :         if "Distance" in parameters.ranges:
     138           1 :             distance = parameters.ranges["Distance"][step]
     139           1 :             angle: float = 0
     140           1 :             if "Angle" in parameters.ranges:
     141           1 :                 angle = parameters.ranges["Angle"][step]
     142           1 :             system.set_distance(distance, angle, unit="micrometer")
     143           1 :         system_pair_list.append(system)
     144             : 
     145           1 :     logger.debug("Diagonalizing SystemPairs...")
     146           1 :     pi.diagonalize(
     147             :         system_pair_list,
     148             :         **parameters.diagonalize_kwargs,
     149             :         **parameters.get_diagonalize_energy_range(ket_pair_energy_0),
     150             :     )
     151           1 :     logger.debug("Done diagonalizing SystemPairs.")
     152             : 
     153           1 :     results = ResultsTwoAtoms.from_calculate(parameters, system_pair_list, kets, ket_pair_energy_0)
     154           1 :     results.basis_0_label = (
     155             :         str(basis_pair_list[-1]) + f"\n  ⇒ Basis consists of {basis_pair_list[-1].number_of_kets} kets"
     156             :     )
     157             : 
     158           1 :     return results

Generated by: LCOV version 1.16