LCOV - code coverage report
Current view: top level - src/pairinteraction_gui/calculate - calculate_two_atoms.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 34 81 42.0 %
Date: 2025-09-29 10:28:29 Functions: 2 8 25.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import logging
       6           1 : from typing import TYPE_CHECKING
       7             : 
       8           1 : import numpy as np
       9           1 : from attr import dataclass
      10             : 
      11           1 : import pairinteraction as pi_complex
      12           1 : import pairinteraction.real as pi_real
      13           1 : from pairinteraction_gui.calculate.calculate_base import Parameters, Results
      14           1 : from pairinteraction_gui.worker import run_in_other_process
      15             : 
      16             : if TYPE_CHECKING:
      17             :     from typing_extensions import Self
      18             : 
      19             :     from pairinteraction_gui.page import TwoAtomsPage
      20             : 
      21           1 : logger = logging.getLogger(__name__)
      22             : 
      23             : 
      24           1 : @dataclass
      25           1 : class ParametersTwoAtoms(Parameters["TwoAtomsPage"]):
      26             :     """Parameters for the two atoms calculation."""
      27             : 
      28           1 :     pair_delta_energy: float = np.inf
      29           1 :     pair_m_range: tuple[float, float] | None = None
      30           1 :     order: int = 3
      31             : 
      32           1 :     @classmethod
      33           1 :     def from_page(cls, page: TwoAtomsPage) -> Self:
      34           1 :         obj = super().from_page(page)
      35           1 :         obj.pair_delta_energy = page.basis_config.pair_delta_energy.value(np.inf)
      36           1 :         obj.pair_m_range = (
      37             :             page.basis_config.pair_m_range.values() if page.basis_config.pair_m_range.isChecked() else None
      38             :         )
      39           1 :         obj.order = page.system_config.order.value()
      40           1 :         return obj
      41             : 
      42           1 :     def to_replacement_dict(self) -> dict[str, str]:
      43           1 :         replacements = super().to_replacement_dict()
      44           1 :         replacements["$MULTIPOLE_ORDER"] = str(self.order)
      45           1 :         replacements["$PAIR_DELTA_ENERGY"] = (
      46             :             "np.inf" if np.isinf(self.pair_delta_energy) else str(self.pair_delta_energy)
      47             :         )
      48           1 :         replacements["$PAIR_M_RANGE"] = str(self.pair_m_range)
      49           1 :         return replacements
      50             : 
      51             : 
      52           1 : @dataclass
      53           1 : class ResultsTwoAtoms(Results):
      54           1 :     basis_0_label: str | None = None
      55             : 
      56             : 
      57           1 : @run_in_other_process
      58           1 : def calculate_two_atoms(parameters: ParametersTwoAtoms) -> ResultsTwoAtoms:
      59             :     """Calculate the energy plot for two atoms.
      60             : 
      61             :     This means, given a Parameters object, do the PairInteraction calculations and return an ResultsTwoAtoms object.
      62             :     """
      63           0 :     return _calculate_two_atoms(parameters)
      64             : 
      65             : 
      66           1 : def _calculate_two_atoms(parameters: ParametersTwoAtoms) -> ResultsTwoAtoms:
      67             :     """Make the unwrapped function available for testing."""
      68           0 :     pi = pi_real if parameters.is_real else pi_complex
      69           0 :     n_atoms = 2
      70             : 
      71           0 :     kets = tuple(pi.KetAtom(parameters.get_species(i), **parameters.get_quantum_numbers(i)) for i in range(n_atoms))
      72           0 :     bases = tuple(
      73             :         pi.BasisAtom(parameters.get_species(i), **parameters.get_quantum_number_restrictions(i)) for i in range(n_atoms)
      74             :     )
      75             : 
      76           0 :     fields = {k: v for k, v in parameters.ranges.items() if k in ["Ex", "Ey", "Ez", "Bx", "By", "Bz"]}
      77             : 
      78             :     basis_pair_list: list[pi_real.BasisPair] | list[pi_complex.BasisPair]
      79           0 :     if all(v[0] == v[-1] for v in fields.values()):
      80             :         # If all fields are constant, we can only have to diagonalize one SystemAtom per atom
      81             :         # and can construct one BasisPair, which we can use for all steps
      82           0 :         systems = tuple(
      83             :             pi.SystemAtom(bases[i])
      84             :             .set_electric_field(parameters.get_efield(0), unit="V/cm")
      85             :             .set_magnetic_field(parameters.get_bfield(0), unit="G")
      86             :             .set_diamagnetism_enabled(parameters.diamagnetism_enabled)
      87             :             for i in range(n_atoms)
      88             :         )
      89           0 :         logger.debug("Diagonalizing SystemAtoms...")
      90           0 :         pi.diagonalize(systems, **parameters.diagonalize_kwargs)
      91           0 :         logger.debug("Done diagonalizing SystemAtoms.")
      92           0 :         ket_pair_energy_0 = sum(systems[i].get_corresponding_energy(kets[i], "GHz") for i in range(n_atoms))
      93           0 :         delta_energy = parameters.pair_delta_energy
      94           0 :         basis_pair = pi.BasisPair(
      95             :             systems,
      96             :             energy=(ket_pair_energy_0 - delta_energy, ket_pair_energy_0 + delta_energy),
      97             :             energy_unit="GHz",
      98             :             m=parameters.pair_m_range,
      99             :         )
     100             :         # not very elegant, but works (note that importantly this does not copy the basis_pair objects)
     101           0 :         basis_pair_list = parameters.steps * [basis_pair]
     102             :     else:
     103             :         # Otherwise, we have to diagonalize one SystemAtom per atom and per step
     104             :         # and construct one BasisPair per step
     105           0 :         systems_list = []
     106           0 :         for step in range(parameters.steps):
     107           0 :             systems = tuple(
     108             :                 pi.SystemAtom(bases[i])
     109             :                 .set_electric_field(parameters.get_efield(step), unit="V/cm")
     110             :                 .set_magnetic_field(parameters.get_bfield(step), unit="G")
     111             :                 .set_diamagnetism_enabled(parameters.diamagnetism_enabled)
     112             :                 for i in range(n_atoms)
     113             :             )
     114           0 :             systems_list.append(systems)
     115           0 :         systems_flattened = [system for systems in systems_list for system in systems]
     116           0 :         logger.debug("Diagonalizing SystemAtoms...")
     117           0 :         pi.diagonalize(systems_flattened, **parameters.diagonalize_kwargs)
     118           0 :         logger.debug("Done diagonalizing SystemAtoms.")
     119           0 :         delta_energy = parameters.pair_delta_energy
     120           0 :         basis_pair_list = []
     121           0 :         for step in range(parameters.steps):
     122           0 :             ket_pair_energy = sum(
     123             :                 systems_list[step][i].get_corresponding_energy(kets[i], "GHz") for i in range(n_atoms)
     124             :             )
     125           0 :             basis_pair = pi.BasisPair(
     126             :                 systems_list[step],
     127             :                 energy=(ket_pair_energy - delta_energy, ket_pair_energy + delta_energy),
     128             :                 energy_unit="GHz",
     129             :                 m=parameters.pair_m_range,
     130             :             )
     131           0 :             basis_pair_list.append(basis_pair)
     132           0 :         ket_pair_energy_0 = sum(systems_list[-1][i].get_corresponding_energy(kets[i], "GHz") for i in range(n_atoms))
     133             : 
     134           0 :     system_pair_list: list[pi_real.SystemPair] | list[pi_complex.SystemPair] = []
     135           0 :     for step in range(parameters.steps):
     136           0 :         system = pi.SystemPair(basis_pair_list[step])
     137           0 :         system.set_interaction_order(parameters.order)
     138           0 :         if "Distance" in parameters.ranges:
     139           0 :             distance = parameters.ranges["Distance"][step]
     140           0 :             angle: float = 0
     141           0 :             if "Angle" in parameters.ranges:
     142           0 :                 angle = parameters.ranges["Angle"][step]
     143           0 :             system.set_distance(distance, angle, unit="micrometer")
     144           0 :         system_pair_list.append(system)
     145             : 
     146           0 :     logger.debug("Diagonalizing SystemPairs...")
     147           0 :     pi.diagonalize(
     148             :         system_pair_list,
     149             :         **parameters.diagonalize_kwargs,
     150             :         **parameters.get_diagonalize_energy_range_kwargs(ket_pair_energy_0),
     151             :     )
     152           0 :     logger.debug("Done diagonalizing SystemPairs.")
     153             : 
     154           0 :     results = ResultsTwoAtoms.from_calculate(parameters, system_pair_list, kets, ket_pair_energy_0)
     155           0 :     results.basis_0_label = (
     156             :         str(basis_pair_list[-1]) + f"\n  ⇒ Basis consists of {basis_pair_list[-1].number_of_kets} kets"
     157             :     )
     158             : 
     159           0 :     return results

Generated by: LCOV version 1.16