LCOV - code coverage report
Current view: top level - tests - test_basis_pair.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 105 106 99.1 %
Date: 2026-06-16 12:53:10 Functions: 13 13 100.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : from __future__ import annotations
       5             : 
       6           1 : from typing import TYPE_CHECKING
       7             : 
       8           1 : import numpy as np
       9           1 : import pytest
      10           1 : from pairinteraction import BasisAtom, BasisPair
      11           1 : from pairinteraction.ket.ket_pair import is_ket_pair_like
      12           1 : from pairinteraction.state.state_pair import is_state_pair_like
      13             : 
      14             : if TYPE_CHECKING:
      15             :     from pairinteraction import KetAtom, SystemAtom
      16             :     from pairinteraction.basis.basis_pair import BasisPairLike
      17             :     from pairinteraction.ket.ket_pair import KetPairLike
      18             :     from pairinteraction.state.state_pair import StatePairLike
      19             : 
      20             :     from .utils import PairinteractionModule
      21             : 
      22           1 : from .utils import no_log_propagation
      23             : 
      24             : 
      25           1 : @pytest.fixture
      26           1 : def basis(pi_module: PairinteractionModule, system_atom: SystemAtom, system_atom2: SystemAtom) -> BasisPair:
      27           1 :     system_atoms = [system_atom, system_atom2]
      28           1 :     ket_atoms = tuple(s.basis.get_ket(30) for s in system_atoms)
      29           1 :     with no_log_propagation("pairinteraction.basis.basis_pair"):  # suppress number_of_kets warning
      30           1 :         return pi_module.BasisPair.from_kets(ket_atoms, system_atoms, number_of_kets=80)
      31             : 
      32             : 
      33           1 : @pytest.fixture
      34           1 : def basis2(pi_module: PairinteractionModule, system_atom: SystemAtom) -> BasisPair:
      35           1 :     system_atoms = [system_atom, system_atom]
      36           1 :     ket_atoms = tuple(s.basis.get_ket(30) for s in system_atoms)
      37           1 :     with no_log_propagation("pairinteraction.basis.basis_pair"):  # suppress number_of_kets warning
      38           1 :         return pi_module.BasisPair.from_kets(ket_atoms, system_atoms, number_of_kets=80)
      39             : 
      40             : 
      41           1 : @pytest.fixture
      42           1 : def basis_atom(pi_module: PairinteractionModule) -> BasisAtom:
      43           1 :     return pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2))
      44             : 
      45             : 
      46           1 : @pytest.fixture
      47           1 : def basis_atom2(pi_module: PairinteractionModule) -> BasisAtom:
      48           1 :     return pi_module.BasisAtom("Rb", n=(58, 62), l=(2, 3))
      49             : 
      50             : 
      51           1 : @pytest.fixture
      52           1 : def system_atom(pi_module: PairinteractionModule, basis_atom: BasisAtom) -> SystemAtom:
      53           1 :     return pi_module.SystemAtom(basis_atom)
      54             : 
      55             : 
      56           1 : @pytest.fixture
      57           1 : def system_atom2(pi_module: PairinteractionModule, basis_atom2: BasisAtom) -> SystemAtom:
      58           1 :     return pi_module.SystemAtom(basis_atom2)
      59             : 
      60             : 
      61           1 : def test_basis_creation(pi_module: PairinteractionModule, system_atom: SystemAtom, system_atom2: SystemAtom) -> None:
      62             :     """Test basic properties of created basis."""
      63           1 :     basis = pi_module.BasisPair([system_atom, system_atom2])
      64           1 :     assert basis.number_of_kets == system_atom.basis.number_of_kets * system_atom2.basis.number_of_kets
      65           1 :     assert basis.number_of_states == basis.number_of_kets
      66           1 :     assert len(basis.kets) == basis.number_of_kets
      67           1 :     assert all(x in str(basis) for x in ["BasisPair", "Rb:58,S_1/2,", "...", "Rb:62,D_5/2,"])
      68             : 
      69             : 
      70           1 : def test_coefficients(basis: BasisPair) -> None:
      71             :     """Test coefficient matrix properties."""
      72           1 :     coeffs = basis.get_coefficients()
      73           1 :     assert coeffs.shape == (basis.number_of_kets, basis.number_of_states)
      74           1 :     assert pytest.approx(coeffs.diagonal()) == 1.0  # NOSONAR
      75           1 :     assert pytest.approx(coeffs.sum()) == basis.number_of_kets  # NOSONAR
      76             : 
      77             : 
      78           1 : def _get_expected_shape(other: KetPairLike | StatePairLike | BasisPairLike, target_basis: BasisPair) -> tuple[int, ...]:
      79           1 :     if is_ket_pair_like(other) or is_state_pair_like(other):
      80           1 :         return (target_basis.number_of_states,)
      81           1 :     if isinstance(other, BasisPair):
      82           1 :         return (other.number_of_states, target_basis.number_of_states)
      83           1 :     if isinstance(other, list) and all(isinstance(b, BasisAtom) for b in other):
      84           1 :         return (other[0].number_of_states * other[1].number_of_states, target_basis.number_of_states)
      85           0 :     raise ValueError("Invalid basis_like type")
      86             : 
      87             : 
      88           1 : @pytest.mark.parametrize(
      89             :     "other_key",
      90             :     [
      91             :         "ket_from_basis",
      92             :         "ket_from_basis2",
      93             :         "ket_atom_tuple_from_basis",
      94             :         "ket_atom_tuple_from_basis2",
      95             :         "state_from_basis",
      96             :         "state_from_basis2",
      97             :         "state_atom_tuple_from_basis",
      98             :         "state_atom_tuple_from_basis2",
      99             :         "basis",
     100             :         "basis2",
     101             :         "basis_atom_tuple_from_basis",
     102             :         "basis_atom_tuple_from_basis2",
     103             :     ],
     104             : )
     105           1 : def test_get_methods(basis: BasisPair, basis2: BasisPair, other_key: str) -> None:
     106             :     """Test amplitude, overlap and matrix element calculations with another ket, state and basis."""
     107           1 :     other_dict: dict[str, KetPairLike | StatePairLike | BasisPairLike] = {
     108             :         "ket_from_basis": basis.get_ket(0),
     109             :         "ket_from_basis2": basis2.get_ket(0),
     110             :         "ket_atom_tuple_from_basis": [s.basis.get_ket(0) for s in basis.system_atoms],
     111             :         "ket_atom_tuple_from_basis2": [s.basis.get_ket(0) for s in basis2.system_atoms],
     112             :         "state_from_basis": basis.get_state(0),
     113             :         "state_from_basis2": basis2.get_state(0),
     114             :         "state_atom_tuple_from_basis": [s.basis.get_state(0) for s in basis.system_atoms],
     115             :         "state_atom_tuple_from_basis2": [s.basis.get_state(0) for s in basis2.system_atoms],
     116             :         "basis": basis,
     117             :         "basis2": basis2,
     118             :         "basis_atom_tuple_from_basis": [s.basis for s in basis.system_atoms],
     119             :         "basis_atom_tuple_from_basis2": [s.basis for s in basis2.system_atoms],
     120             :     }
     121           1 :     other = other_dict[other_key]
     122             : 
     123           1 :     amplitudes = basis.get_amplitudes(other)
     124           1 :     assert amplitudes.shape == _get_expected_shape(other, basis)
     125             : 
     126           1 :     overlaps = basis.get_overlaps(other)
     127           1 :     assert overlaps.shape == _get_expected_shape(other, basis)
     128             : 
     129           1 :     elements_dipole = basis.get_matrix_elements(
     130             :         other, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
     131             :     )
     132           1 :     assert elements_dipole.shape == _get_expected_shape(other, basis)
     133             : 
     134             : 
     135           1 : def test_get_overlaps_pair_explicitly(pi_module: PairinteractionModule, system_atom: SystemAtom) -> None:
     136             :     # Test with a different basis (non-identical kets) - cross-basis overlaps
     137           1 :     ket = system_atom.basis.get_ket(20)
     138           1 :     energy_min = 2 * ket.get_energy("GHz") - 1
     139           1 :     energy_max = 2 * ket.get_energy("GHz") + 1
     140           1 :     basis_pair1 = pi_module.BasisPair([system_atom, system_atom])
     141           1 :     basis_pair2 = pi_module.BasisPair([system_atom, system_atom], energy=(energy_min, energy_max), energy_unit="GHz")
     142           1 :     assert basis_pair1.number_of_states > basis_pair2.number_of_states
     143           1 :     assert basis_pair1.number_of_states > 1
     144             : 
     145           1 :     matrix_overlaps = basis_pair1.get_overlaps(basis_pair2)
     146           1 :     assert matrix_overlaps.shape == (basis_pair2.number_of_states, basis_pair1.number_of_states)
     147           1 :     assert (matrix_overlaps.data >= 0).all()
     148           1 :     assert (matrix_overlaps.data <= 1 + 1e-10).all()
     149             : 
     150           1 :     col_sums = np.array(matrix_overlaps.sum(axis=0)).flatten()
     151           1 :     assert np.all(col_sums <= 1.0 + 1e-10)
     152           1 :     row_sums = np.array(matrix_overlaps.sum(axis=1)).flatten()
     153           1 :     assert np.all(row_sums <= 1.0 + 1e-10)
     154             : 
     155           1 :     idx0 = basis_pair1.get_corresponding_state_index((ket, ket))
     156           1 :     idx2 = basis_pair2.get_corresponding_state_index((ket, ket))
     157           1 :     assert pytest.approx(matrix_overlaps[idx2, idx0]) == 1.0  # NOSONAR
     158             : 
     159             : 
     160           1 : def test_error_handling(basis: BasisPair) -> None:
     161             :     """Test error cases."""
     162           1 :     with pytest.raises(TypeError):
     163           1 :         basis.get_amplitudes("not a ket")  # type: ignore [arg-type]
     164             : 
     165           1 :     with pytest.raises(TypeError):
     166           1 :         basis.get_overlaps("not a ket")  # type: ignore [arg-type]
     167             : 
     168           1 :     with pytest.raises(TypeError):
     169           1 :         basis.get_matrix_elements("not a ket", ("energy", "energy"), (0, 0))  # type: ignore [arg-type]
     170             : 
     171             : 
     172           1 : def test_from_kets(pi_module: PairinteractionModule, system_atom: SystemAtom) -> None:
     173             :     """Test BasisPair.from_kets."""
     174           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     175           1 :     ket1 = pi_module.KetAtom("Rb", n=59, l=1, j=0.5, m=0.5)
     176           1 :     ket2 = pi_module.KetAtom("Rb", n=60, l=1, j=0.5, m=0.5)
     177           1 :     ket_atoms_dict: dict[str, tuple[KetAtom, KetAtom] | list[tuple[KetAtom, KetAtom]]] = {
     178             :         "single_ket": (ket, ket),
     179             :         "multiple_kets": [(ket, ket), (ket1, ket2), (ket2, ket1)],
     180             :     }
     181             : 
     182           1 :     for ket_atoms in ket_atoms_dict.values():
     183             :         # delta_energy restriction
     184           1 :         pair_basis = pi_module.BasisPair.from_kets(
     185             :             ket_atoms, [system_atom, system_atom], delta_energy=3, delta_energy_unit="GHz", delta_m=1
     186             :         )
     187           1 :         assert pair_basis.number_of_kets > 0
     188             : 
     189             :         # number_of_kets
     190           1 :         ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     191           1 :         for target in [50, 100, 1000]:
     192           1 :             with no_log_propagation("pairinteraction.basis.basis_pair"):  # suppress number_of_kets warning
     193           1 :                 pair_basis = pi_module.BasisPair.from_kets(
     194             :                     ket_atoms, [system_atom, system_atom], number_of_kets=target, delta_m=1
     195             :                 )
     196           1 :             assert pair_basis.number_of_kets >= target
     197           1 :             assert pair_basis.number_of_kets < target + 20  # allow some extra due to degeneracies
     198             : 
     199             :     # test error cases
     200           1 :     with pytest.raises(ValueError, match="empty"):
     201           1 :         pi_module.BasisPair.from_kets([], [system_atom, system_atom])
     202             : 
     203           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     204           1 :     with pytest.raises(ValueError, match="number_of_kets"):
     205           1 :         pi_module.BasisPair.from_kets(
     206             :             (ket, ket),
     207             :             [system_atom, system_atom],
     208             :             delta_energy=3,
     209             :             delta_energy_unit="GHz",
     210             :             number_of_kets=10,
     211             :         )

Generated by: LCOV version 1.16