# SPDX-FileCopyrightText: 2024 Pairinteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

"""Tests for the ``BasisAtom`` class of ``pairinteraction.real``."""

import numpy as np
import pytest

import pairinteraction.real as pi


@pytest.fixture
def basis() -> pi.BasisAtom:
    """Return the Rb basis shared by the tests of this module.

    The basis is created with ``n=(58, 62)``, ``l=(0, 2)`` and an energy window of
    100 GHz below to 100 GHz above the energy of the Rb ket with n=60, l=0, j=0.5, m=0.5.
    """
    reference_energy = pi.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5).get_energy(unit="GHz")
    energy_window = (reference_energy - 100, reference_energy + 100)
    return pi.BasisAtom("Rb", n=(58, 62), l=(0, 2), energy=energy_window, energy_unit="GHz")


def test_basis_creation(basis: pi.BasisAtom) -> None:
    """Check species, size and string representation of the fixture basis."""
    assert basis.species == "Rb"

    ket_count = basis.number_of_kets
    assert ket_count == 80
    assert basis.number_of_states == ket_count
    assert len(basis.kets) == ket_count

    # The same basis without the energy restriction must contain strictly more kets.
    unrestricted = pi.BasisAtom("Rb", n=(58, 62), l=(0, 2))
    assert ket_count < unrestricted.number_of_kets

    description = str(basis)
    for fragment in ("BasisAtom", "n=(58, 62)", "l=(0, 2)"):
        assert fragment in description


def test_coefficients(basis: pi.BasisAtom) -> None:
    """Check shape, diagonal and total sum of the coefficient matrix."""
    coefficient_matrix = basis.get_coefficients()
    expected_shape = (basis.number_of_kets, basis.number_of_states)

    assert coefficient_matrix.shape == expected_shape
    # Unit diagonal plus a total sum equal to the number of kets.
    assert pytest.approx(coefficient_matrix.diagonal()) == 1.0  # NOSONAR
    assert pytest.approx(coefficient_matrix.sum()) == basis.number_of_kets  # NOSONAR


def test_get_amplitudes_and_overlaps(basis: pi.BasisAtom) -> None:
    """Check amplitudes and overlaps of the basis with a ket, a state and the basis itself."""
    state_count = basis.number_of_states

    # First ket of the basis: one value per basis state, the first of them being 1.
    first_ket = basis.kets[0]
    ket_amplitudes = basis.get_amplitudes(first_ket)
    assert len(ket_amplitudes) == state_count
    assert pytest.approx(ket_amplitudes[0]) == 1.0  # NOSONAR
    ket_overlaps = basis.get_overlaps(first_ket)
    assert len(ket_overlaps) == state_count
    assert pytest.approx(ket_overlaps[0]) == 1.0  # NOSONAR

    # First state of the basis: same expectations as for the ket.
    first_state = basis.states[0]
    state_amplitudes = basis.get_amplitudes(first_state)
    assert len(state_amplitudes) == state_count
    assert pytest.approx(state_amplitudes[0]) == 1.0  # NOSONAR
    state_overlaps = basis.get_overlaps(first_state)
    assert len(state_overlaps) == state_count
    assert pytest.approx(state_overlaps[0]) == 1.0  # NOSONAR

    # Basis against itself: matrices with unit diagonal.
    amplitude_matrix = basis.get_amplitudes(basis)
    assert amplitude_matrix.shape == (basis.number_of_kets, state_count)
    assert pytest.approx(amplitude_matrix.diagonal()) == 1.0  # NOSONAR
    overlap_matrix = basis.get_overlaps(basis)
    assert overlap_matrix.shape == (state_count, state_count)
    assert pytest.approx(overlap_matrix.diagonal()) == 1.0  # NOSONAR


def test_get_matrix_elements(basis: pi.BasisAtom) -> None:
    """Check electric dipole (q=0) matrix elements with a ket, a state and the basis itself.

    In every case some, but not all, of the returned elements must be nonzero.
    """
    state_count = basis.number_of_states

    # First ket of the basis: one element per basis state.
    ket_elements = basis.get_matrix_elements(basis.kets[0], "electric_dipole", q=0, unit="e * a0")
    assert ket_elements.shape == (state_count,)
    assert 0 < np.count_nonzero(ket_elements) < state_count

    # First state of the basis: one element per basis state.
    state_elements = basis.get_matrix_elements(basis.states[0], "electric_dipole", q=0, unit="e * a0")
    assert state_elements.shape == (state_count,)
    assert 0 < np.count_nonzero(state_elements) < state_count

    # Basis against itself: square matrix, converted to a dense array for counting.
    basis_elements = basis.get_matrix_elements(basis, "electric_dipole", q=0, unit="e * a0")
    assert basis_elements.shape == (state_count, state_count)
    nonzero_count = np.count_nonzero(basis_elements.toarray())
    assert 0 < nonzero_count < state_count**2


def test_error_handling(basis: pi.BasisAtom) -> None:
    """Check that a plain string in place of a ket, state or basis raises ``TypeError``."""
    invalid_target = "not a ket"

    with pytest.raises(TypeError):
        basis.get_amplitudes(invalid_target)  # type: ignore [call-overload]

    with pytest.raises(TypeError):
        basis.get_overlaps(invalid_target)  # type: ignore [call-overload]

    with pytest.raises(TypeError):
        basis.get_matrix_elements(invalid_target, "energy", 0)  # type: ignore [call-overload]