LCOV - code coverage report
Current view: top level - tests - test_basis_atom.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 124 125 99.2 %
Date: 2026-06-16 12:53:10 Functions: 9 9 100.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : from __future__ import annotations
       5             : 
       6           1 : from typing import TYPE_CHECKING
       7             : 
       8           1 : import numpy as np
       9           1 : import pytest
      10           1 : from pairinteraction import BasisAtom
      11           1 : from pairinteraction.ket.ket_atom import KetAtom
      12           1 : from pairinteraction.state.state_atom import StateAtom
      13           1 : from scipy.sparse import csr_matrix
      14             : 
      15             : if TYPE_CHECKING:
      16             :     from .utils import PairinteractionModule
      17             : 
      18             : 
      19           1 : @pytest.fixture
      20           1 : def basis(pi_module: PairinteractionModule) -> BasisAtom:
      21           1 :     return pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2))
      22             : 
      23             : 
      24           1 : @pytest.fixture
      25           1 : def basis2(pi_module: PairinteractionModule) -> BasisAtom:
      26           1 :     return pi_module.BasisAtom("Rb", n=(58, 62), l=(2, 3))
      27             : 
      28             : 
      29           1 : def test_basis_creation(pi_module: PairinteractionModule) -> None:
      30             :     """Test basic properties of created basis."""
      31           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
      32           1 :     energy_min = ket.get_energy(unit="GHz") - 100
      33           1 :     energy_max = ket.get_energy(unit="GHz") + 100
      34           1 :     basis = pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2), energy=(energy_min, energy_max), energy_unit="GHz")
      35           1 :     assert basis.species == "Rb"
      36           1 :     assert basis.number_of_kets == 80
      37           1 :     assert basis.number_of_states == basis.number_of_kets
      38           1 :     assert len(basis.kets) == basis.number_of_kets
      39           1 :     assert basis.number_of_kets < pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2)).number_of_kets
      40           1 :     assert all(x in str(basis) for x in ["BasisAtom", "n=(58, 62)", "l=(0, 2)"])
      41             : 
      42             : 
      43           1 : def test_restriction_mode(pi_module: PairinteractionModule) -> None:
      44             :     """Test exact, fuzzy, and numeric restrictions for expectation-value quantum numbers."""
      45           1 :     l_range = (1, 1)
      46           1 :     fuzzy_basis = pi_module.BasisAtom("Yb171_mqdt", nu=(58, 62), l=l_range, m=(0.5, 0.5))
      47           1 :     assert fuzzy_basis.number_of_kets > 0
      48           1 :     assert any(ket.l > l_range[1] for ket in fuzzy_basis.kets)
      49           1 :     assert any(ket.l < l_range[0] for ket in fuzzy_basis.kets)
      50             : 
      51           1 :     exact_basis = pi_module.BasisAtom("Yb171_mqdt", nu=(58, 62), l=l_range, m=(0.5, 0.5), mode="exact")
      52           1 :     assert exact_basis.number_of_kets > 0
      53           1 :     assert all(l_range[0] <= ket.l <= l_range[1] for ket in exact_basis.kets)
      54           1 :     assert exact_basis.number_of_kets < fuzzy_basis.number_of_kets
      55             : 
      56           1 :     factor_basis = pi_module.BasisAtom("Yb171_mqdt", nu=(58, 62), l=l_range, m=(0.5, 0.5), mode=100)
      57           1 :     assert factor_basis.number_of_kets > fuzzy_basis.number_of_kets
      58             : 
      59           1 :     with pytest.raises(ValueError, match="mode"):
      60           1 :         pi_module.BasisAtom("Yb171_mqdt", nu=(58, 62), l=l_range, m=(0.5, 0.5), mode="invalid")  # type: ignore[arg-type]
      61             : 
      62           1 :     with pytest.raises(ValueError, match="non-negative"):
      63           1 :         pi_module.BasisAtom("Yb171_mqdt", nu=(58, 62), l=l_range, m=(0.5, 0.5), mode=-1)
      64             : 
      65             : 
      66           1 : def test_coefficients(basis: BasisAtom) -> None:
      67             :     """Test coefficient matrix properties."""
      68           1 :     coeffs = basis.get_coefficients()
      69           1 :     assert coeffs.shape == (basis.number_of_kets, basis.number_of_states)
      70           1 :     assert pytest.approx(coeffs.diagonal()) == 1.0  # NOSONAR
      71           1 :     assert pytest.approx(coeffs.sum()) == basis.number_of_kets  # NOSONAR
      72             : 
      73             : 
      74           1 : def _get_expected_shape(other: KetAtom | StateAtom | BasisAtom, target_basis: BasisAtom) -> tuple[int, ...]:
      75           1 :     if isinstance(other, (KetAtom, StateAtom)):
      76           1 :         return (target_basis.number_of_states,)
      77           1 :     if isinstance(other, BasisAtom):
      78           1 :         return (other.number_of_states, target_basis.number_of_states)
      79           0 :     raise ValueError("Invalid basis_like type")
      80             : 
      81             : 
      82           1 : @pytest.mark.parametrize(
      83             :     "other_key",
      84             :     [
      85             :         "ket_from_basis",
      86             :         "ket_from_basis2",
      87             :         "other_ket_from_basis2",
      88             :         "state_from_basis",
      89             :         "state_from_basis2",
      90             :         "basis",
      91             :         "basis2",
      92             :     ],
      93             : )
      94           1 : def test_get_methods(basis: BasisAtom, basis2: BasisAtom, other_key: str) -> None:
      95             :     """Test amplitude, overlap and matrix element calculations with another ket, state and basis."""
      96           1 :     ind = 5
      97           1 :     other_dict: dict[str, KetAtom | StateAtom | BasisAtom] = {
      98             :         "ket_from_basis": basis.get_ket(ind),
      99             :         "ket_from_basis2": next(ket for ket in basis2.kets if ket in basis.kets),
     100             :         "other_ket_from_basis2": next(ket for ket in basis2.kets if ket not in basis.kets and ket.j_ryd < 3),
     101             :         "state_from_basis": basis.get_state(ind),
     102             :         "state_from_basis2": basis2.get_state(0),
     103             :         "basis": basis,
     104             :         "basis2": basis2,
     105             :     }
     106           1 :     other = other_dict[other_key]
     107             : 
     108           1 :     amplitudes = basis.get_amplitudes(other)
     109           1 :     assert amplitudes.shape == _get_expected_shape(other, basis)
     110             : 
     111           1 :     overlaps = basis.get_overlaps(other)
     112           1 :     assert overlaps.shape == _get_expected_shape(other, basis)
     113             : 
     114           1 :     matrix_elements = basis.get_matrix_elements(other, "electric_dipole", q=0, unit="e * a0")
     115           1 :     assert matrix_elements.shape == _get_expected_shape(other, basis)
     116             : 
     117           1 :     if other_key in ["ket_from_basis", "state_from_basis"]:
     118           1 :         assert pytest.approx(amplitudes[ind]) == 1.0  # NOSONAR
     119           1 :         assert pytest.approx(overlaps[ind]) == 1.0  # NOSONAR
     120             : 
     121           1 :     if other_key == "ket_from_basis2":
     122           1 :         assert isinstance(amplitudes, np.ndarray)
     123           1 :         assert isinstance(overlaps, np.ndarray)
     124           1 :         assert pytest.approx(np.max(amplitudes)) == 1.0  # NOSONAR
     125           1 :         assert pytest.approx(np.max(overlaps)) == 1.0  # NOSONAR
     126             : 
     127           1 :     if other_key == "other_ket_from_basis2":
     128           1 :         assert isinstance(amplitudes, np.ndarray)
     129           1 :         assert isinstance(overlaps, np.ndarray)
     130           1 :         assert np.count_nonzero(amplitudes) == 0
     131           1 :         assert np.count_nonzero(overlaps) == 0
     132             : 
     133           1 :     if other_key == "basis":
     134           1 :         assert isinstance(amplitudes, csr_matrix)
     135           1 :         assert isinstance(overlaps, csr_matrix)
     136           1 :         assert pytest.approx(amplitudes.diagonal()) == 1.0  # NOSONAR
     137           1 :         assert pytest.approx(overlaps.diagonal()) == 1.0  # NOSONAR
     138             : 
     139           1 :     if other_key == "basis2":
     140           1 :         assert isinstance(amplitudes, csr_matrix)
     141           1 :         assert isinstance(overlaps, csr_matrix)
     142           1 :         n_matching_kets = len([ket for ket in basis2.kets if ket in basis.kets])
     143           1 :         assert np.count_nonzero(amplitudes.toarray()) == n_matching_kets
     144           1 :         assert np.count_nonzero(overlaps.toarray()) == n_matching_kets
     145             : 
     146           1 :     if other_key.startswith("basis"):
     147           1 :         assert isinstance(matrix_elements, csr_matrix)
     148           1 :         assert 0 < np.count_nonzero(matrix_elements.toarray()) < basis.number_of_states**2
     149             :     else:
     150           1 :         assert isinstance(matrix_elements, np.ndarray)
     151           1 :         assert 0 < np.count_nonzero(matrix_elements) < basis.number_of_states
     152             : 
     153             : 
     154           1 : def test_error_handling(basis: BasisAtom) -> None:
     155             :     """Test error cases."""
     156           1 :     with pytest.raises(TypeError):
     157           1 :         basis.get_amplitudes("not a ket")  # type: ignore [call-overload]
     158             : 
     159           1 :     with pytest.raises(TypeError):
     160           1 :         basis.get_overlaps("not a ket")  # type: ignore [call-overload]
     161             : 
     162           1 :     with pytest.raises(TypeError):
     163           1 :         basis.get_matrix_elements("not a ket", "energy", 0)  # type: ignore [call-overload]
     164             : 
     165             : 
     166           1 : def test_from_kets(pi_module: PairinteractionModule) -> None:
     167             :     """Test BasisAtom.from_kets."""
     168             :     # single ket
     169           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     170           1 :     basis = pi_module.BasisAtom.from_kets(
     171             :         ket,
     172             :         delta_n=2,
     173             :         delta_nu=3,
     174             :         delta_nui=3,
     175             :         delta_l=2,
     176             :         delta_s=1,
     177             :         delta_j=3,
     178             :         delta_l_ryd=2,
     179             :         delta_j_ryd=3,
     180             :         delta_f=3,
     181             :         delta_m=2,
     182             :         delta_energy=100,
     183             :         delta_energy_unit="GHz",
     184             :     )
     185           1 :     assert basis.species == "Rb"
     186           1 :     assert all(58 <= k.n <= 62 for k in basis.kets)
     187           1 :     assert any(k.n == 62 for k in basis.kets)
     188           1 :     assert any(k.n == 58 for k in basis.kets)
     189           1 :     assert any(k == ket for k in basis.kets)
     190             : 
     191             :     # multiple kets
     192           1 :     ket1 = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     193           1 :     ket2 = pi_module.KetAtom("Rb", n=61, l=0, j=0.5, m=0.5)
     194           1 :     basis = pi_module.BasisAtom.from_kets([ket1, ket2], delta_n=2)
     195           1 :     assert all(58 <= k.n <= 63 for k in basis.kets)
     196           1 :     assert any(k.n == 63 for k in basis.kets)
     197           1 :     assert any(k.n == 58 for k in basis.kets)
     198           1 :     assert any(k == ket1 for k in basis.kets)
     199           1 :     assert any(k == ket2 for k in basis.kets)
     200             : 
     201             :     # test that from_kets is consistent with direct constructor
     202           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     203           1 :     basis_from = pi_module.BasisAtom.from_kets(ket, delta_n=2)
     204           1 :     basis_direct = pi_module.BasisAtom("Rb", n=(58, 62))
     205           1 :     assert basis_from.number_of_kets == basis_direct.number_of_kets
     206             : 
     207             :     # test error cases
     208           1 :     with pytest.raises(ValueError, match="empty"):
     209           1 :         pi_module.BasisAtom.from_kets([])
     210             : 
     211           1 :     ket_rb = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     212           1 :     ket_sr = pi_module.KetAtom("Sr88_singlet", n=60, l=1, j=1, m=0)
     213           1 :     with pytest.raises(ValueError, match="species"):
     214           1 :         pi_module.BasisAtom.from_kets([ket_rb, ket_sr])

Generated by: LCOV version 1.16