Line data Source code
# SPDX-FileCopyrightText: 2024 PairInteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

if TYPE_CHECKING:
    from pairinteraction import BasisAtom, StateAtom

    from .utils import PairinteractionModule


@pytest.fixture
def basis(pi_module: PairinteractionModule) -> BasisAtom:
    """Create a test basis with a few states around Rb 60S.

    The basis is restricted to n in [58, 62], l in [0, 2] and to an energy
    window of +/- 100 GHz around the energy of the Rb ket with
    n=60, l=0, j=1/2, m=1/2.
    """
    ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    # Query the reference energy once and build the window around it.
    energy_center = ket.get_energy(unit="GHz")
    return pi_module.BasisAtom(
        "Rb",
        n=(58, 62),
        l=(0, 2),
        energy=(energy_center - 100, energy_center + 100),
        energy_unit="GHz",
    )


@pytest.fixture
def state(pi_module: PairinteractionModule, basis: BasisAtom) -> StateAtom:
    """Create a test state: the basis state corresponding to Rb n=60, l=1, j=3/2, m=-1/2."""
    ket = pi_module.KetAtom("Rb", n=60, l=1, j=1.5, m=-0.5)
    return basis.get_corresponding_state(ket)


def test_state_creation(state: StateAtom) -> None:
    """Test basic properties of created state."""
    assert state.species == "Rb"
    assert state.number_of_kets == 80
    assert len(state.kets) == state.number_of_kets
    # The fixture state is built from an l=1 ket, so the orbital letter to look
    # for is "P". Checking for "S" would be vacuous, because the substring "S"
    # is already contained in "StateAtom".
    # NOTE(review): assumes the string representation uses the spectroscopic
    # letter for l (as the "3/2" / "-1/2" fragments suggest) - confirm.
    assert all(x in str(state) for x in ["StateAtom", "60", "P", "3/2", "-1/2"])
    assert state.is_canonical


def test_coefficients(state: StateAtom) -> None:
    """Test coefficient vector properties: one entry per ket, a single non-zero entry summing to 1."""
    coeffs = state.get_coefficients()
    assert coeffs.shape == (state.number_of_kets,)
    assert np.count_nonzero(coeffs) == 1
    assert pytest.approx(coeffs.sum()) == 1.0  # NOSONAR


def test_get_amplitude_and_overlap(state: StateAtom) -> None:
    """Test amplitude and overlap calculations against both a ket and a state."""
    # Test with ket
    test_ket = state.get_corresponding_ket()
    amplitude = state.get_amplitude(test_ket)
    assert np.isscalar(amplitude)
    assert pytest.approx(amplitude) == 1.0  # NOSONAR
    overlap = state.get_overlap(test_ket)
    assert np.isscalar(overlap)
    assert pytest.approx(overlap) == 1.0  # NOSONAR

    # Test with state
    amplitude = state.get_amplitude(state)
    assert np.isscalar(amplitude)
    assert pytest.approx(amplitude) == 1.0  # NOSONAR
    overlap = state.get_overlap(state)
    assert np.isscalar(overlap)
    assert pytest.approx(overlap) == 1.0  # NOSONAR


def test_get_matrix_element(pi_module: PairinteractionModule, basis: BasisAtom) -> None:
    """Test matrix element calculations between two basis states."""
    ket1 = pi_module.KetAtom("Rb", n=60, l=1, j=1.5, m=-0.5)
    ket2 = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    state1 = basis.get_corresponding_state(ket1)
    state2 = basis.get_corresponding_state(ket2)

    # Test with ket
    element_dipole_ket = state1.get_matrix_element(ket2, "electric_dipole", q=1, unit="e * a0")
    assert np.isscalar(element_dipole_ket)
    assert element_dipole_ket != 0

    # Test with state: must agree with the ket-based result
    element_dipole_state = state1.get_matrix_element(state2, "electric_dipole", q=1, unit="e * a0")
    assert np.isscalar(element_dipole_state)
    assert pytest.approx(element_dipole_ket) == element_dipole_state  # NOSONAR
    # The element of a state with itself and the q=0 element between the two
    # states (whose m values differ) are expected to vanish.
    assert state1.get_matrix_element(state1, "electric_dipole", q=1, unit="e * a0") == 0
    assert state1.get_matrix_element(state2, "electric_dipole", q=0, unit="e * a0") == 0


def test_error_handling(state: StateAtom) -> None:
    """Test error cases: passing a plain string instead of a ket/state raises TypeError."""
    with pytest.raises(TypeError):
        state.get_amplitude("not a ket")  # type: ignore [arg-type]

    with pytest.raises(TypeError):
        state.get_overlap("not a ket")  # type: ignore [arg-type]

    with pytest.raises(TypeError):
        state.get_matrix_element("not a ket", "energy", 0)  # type: ignore [call-overload]