Line data Source code
# SPDX-FileCopyrightText: 2024 PairInteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

if TYPE_CHECKING:
    from pairinteraction import BasisPair

    from .utils import PairinteractionModule


@pytest.fixture
def basis(pi_module: PairinteractionModule) -> BasisPair:
    """Create a pair basis built from Rb states near the Rb 60S_1/2 (m=1/2) ket.

    The single-atom basis is restricted by n=(58, 62), l=(0, 2) and an energy
    window of +-100 GHz around the energy of the 60S_1/2, m=1/2 ket. The atom
    system is placed in an electric field of [0.1, 0, 0.2] V/cm and diagonalized.
    The returned pair basis combines two copies of that system;
    ``test_basis_creation`` expects it to hold 80 * 80 kets.
    """
    ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    # Evaluate the reference energy once and center the window on it.
    energy_ket = ket.get_energy(unit="GHz")
    energy_min = energy_ket - 100
    energy_max = energy_ket + 100
    basis_atom = pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2), energy=(energy_min, energy_max), energy_unit="GHz")
    system_atom = pi_module.SystemAtom(basis_atom).set_electric_field([0.1, 0, 0.2], unit="V/cm")
    system_atom.diagonalize()
    return pi_module.BasisPair([system_atom, system_atom])


def test_basis_creation(basis: BasisPair) -> None:
    """Test basic properties of created basis."""
    assert basis.number_of_kets == 80 * 80
    assert basis.number_of_states == basis.number_of_kets
    assert len(basis.kets) == basis.number_of_kets
    # Assert each substring separately so a failure reports exactly which one is missing
    # (an ``all(...)`` over a generator hides that from pytest's assertion introspection).
    basis_str = str(basis)
    for expected in ["BasisPair", "Rb:58,S_1/2,", "...", "Rb:61,D_5/2,"]:
        assert expected in basis_str


def test_coefficients(basis: BasisPair) -> None:
    """Test coefficient matrix properties.

    The assertions expect a unit diagonal and a total sum equal to the number of
    kets, i.e. a coefficient matrix consistent with the identity.
    """
    coeffs = basis.get_coefficients()
    assert coeffs.shape == (basis.number_of_kets, basis.number_of_states)
    assert pytest.approx(coeffs.diagonal()) == 1.0  # NOSONAR
    assert pytest.approx(coeffs.sum()) == basis.number_of_kets  # NOSONAR


def test_get_amplitudes_and_overlaps(basis: BasisPair) -> None:
    """Test amplitude and overlap calculations."""
    # Test with ket
    test_ket = basis.kets[0]
    amplitudes = basis.get_amplitudes(test_ket)
    assert len(amplitudes) == basis.number_of_states
    assert pytest.approx(amplitudes[0]) == 1.0  # NOSONAR
    overlaps = basis.get_overlaps(test_ket)
    assert len(overlaps) == basis.number_of_states
    assert pytest.approx(overlaps[0]) == 1.0  # NOSONAR

    # Test with basis
    matrix_amplitudes = basis.get_amplitudes(basis)
    assert matrix_amplitudes.shape == (basis.number_of_kets, basis.number_of_states)
    assert pytest.approx(matrix_amplitudes.diagonal()) == 1.0  # NOSONAR
    matrix_overlaps = basis.get_overlaps(basis)
    assert matrix_overlaps.shape == (basis.number_of_states, basis.number_of_states)
    assert pytest.approx(matrix_overlaps.diagonal()) == 1.0  # NOSONAR


def test_get_matrix_elements(basis: BasisPair) -> None:
    """Test matrix element calculations."""
    # Test with ket
    test_ket = basis.kets[0]
    elements_dipole = basis.get_matrix_elements(
        test_ket, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
    )
    assert elements_dipole.shape == (basis.number_of_states,)
    n_nonzero_ket = np.count_nonzero(elements_dipole)
    assert n_nonzero_ket > 0
    assert n_nonzero_ket <= basis.number_of_states

    # Test with basis
    matrix_elements = basis.get_matrix_elements(
        basis, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
    )
    assert matrix_elements.shape == (basis.number_of_states, basis.number_of_states)
    # Densify once and reuse the count for both bounds.
    n_nonzero_basis = np.count_nonzero(matrix_elements.toarray())
    assert n_nonzero_basis > 0
    assert n_nonzero_basis <= basis.number_of_states**2


def test_error_handling(basis: BasisPair) -> None:
    """Test error cases."""
    with pytest.raises(TypeError):
        basis.get_amplitudes("not a ket")  # type: ignore [arg-type]

    with pytest.raises(TypeError):
        basis.get_overlaps("not a ket")  # type: ignore [arg-type]

    with pytest.raises(TypeError):
        basis.get_matrix_elements("not a ket", ("energy", "energy"), (0, 0))  # type: ignore [arg-type]