Line data Source code
1 : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : from __future__ import annotations
5 :
6 1 : from typing import TYPE_CHECKING
7 :
8 1 : import numpy as np
9 1 : import pytest
10 :
11 : if TYPE_CHECKING:
12 : from pairinteraction import BasisAtom
13 :
14 : from .utils import PairinteractionModule
15 :
16 :
@pytest.fixture
def basis(pi_module: PairinteractionModule) -> BasisAtom:
    """Provide a small Rb basis centred on the 60S_1/2, m=1/2 ket.

    The basis is limited to n=(58, 62) and l=(0, 2), and additionally to kets whose
    energy lies within 100 GHz below or above the energy of the reference ket.
    """
    reference_ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    center = reference_ket.get_energy(unit="GHz")
    half_width = 100
    window = (center - half_width, center + half_width)
    return pi_module.BasisAtom(
        "Rb",
        n=(58, 62),
        l=(0, 2),
        energy=window,
        energy_unit="GHz",
    )
24 :
25 :
def test_basis_creation(pi_module: PairinteractionModule, basis: BasisAtom) -> None:
    """Check species, sizes and string representation of the fixture basis."""
    n_kets = basis.number_of_kets

    assert basis.species == "Rb"
    assert n_kets == 80
    # A freshly built basis is expected to have exactly one state per ket.
    assert basis.number_of_states == n_kets
    assert len(basis.kets) == n_kets

    # The energy window of the fixture should exclude some kets compared to the
    # same n/l ranges without any energy restriction.
    unrestricted = pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2))
    assert n_kets < unrestricted.number_of_kets

    # One assert per fragment so a failure names the missing piece.
    representation = str(basis)
    for fragment in ("BasisAtom", "n=(58, 62)", "l=(0, 2)"):
        assert fragment in representation
34 :
35 :
def test_coefficients(basis: BasisAtom) -> None:
    """Check shape, diagonal and total sum of the basis coefficient matrix."""
    coefficients = basis.get_coefficients()
    expected_shape = (basis.number_of_kets, basis.number_of_states)
    assert coefficients.shape == expected_shape

    # Each state is expected to coincide with its own ket (unit diagonal) ...
    assert pytest.approx(coefficients.diagonal()) == 1.0  # NOSONAR
    # ... while the off-diagonal entries are expected to sum to zero.
    assert pytest.approx(coefficients.sum()) == basis.number_of_kets  # NOSONAR
42 :
43 :
def test_get_amplitudes_and_overlaps(basis: BasisAtom) -> None:
    """Check get_amplitudes/get_overlaps for a ket, a state and the whole basis."""
    # A single ket or a single state yields one value per basis state, and the first
    # basis entry has unit amplitude and unit overlap with itself.
    for probe in (basis.kets[0], basis.states[0]):
        amplitudes = basis.get_amplitudes(probe)
        assert len(amplitudes) == basis.number_of_states
        assert pytest.approx(amplitudes[0]) == 1.0  # NOSONAR

        overlaps = basis.get_overlaps(probe)
        assert len(overlaps) == basis.number_of_states
        assert pytest.approx(overlaps[0]) == 1.0  # NOSONAR

    # Passing the basis itself yields matrices whose diagonals are all one.
    amplitude_matrix = basis.get_amplitudes(basis)
    assert amplitude_matrix.shape == (basis.number_of_kets, basis.number_of_states)
    assert pytest.approx(amplitude_matrix.diagonal()) == 1.0  # NOSONAR

    overlap_matrix = basis.get_overlaps(basis)
    assert overlap_matrix.shape == (basis.number_of_states, basis.number_of_states)
    assert pytest.approx(overlap_matrix.diagonal()) == 1.0  # NOSONAR
71 :
72 :
def test_get_matrix_elements(basis: BasisAtom) -> None:
    """Check q=0 electric-dipole matrix elements for a ket, a state and the basis."""
    n_states = basis.number_of_states

    # For a single ket or state: one element per basis state, with some but not
    # all of them expected to be non-zero.
    for probe in (basis.kets[0], basis.states[0]):
        elements = basis.get_matrix_elements(probe, "electric_dipole", q=0, unit="e * a0")
        assert elements.shape == (n_states,)
        nonzero = np.count_nonzero(elements)
        assert 0 < nonzero < n_states

    # For the whole basis: a (states x states) matrix, again neither empty nor full.
    matrix = basis.get_matrix_elements(basis, "electric_dipole", q=0, unit="e * a0")
    assert matrix.shape == (n_states, n_states)
    nonzero = np.count_nonzero(matrix.toarray())
    assert 0 < nonzero < n_states**2
94 :
95 :
def test_error_handling(basis: BasisAtom) -> None:
    """Check that get_amplitudes, get_overlaps and get_matrix_elements reject a str."""
    invalid = "not a ket"

    with pytest.raises(TypeError):
        basis.get_amplitudes(invalid)  # type: ignore [call-overload]
    with pytest.raises(TypeError):
        basis.get_overlaps(invalid)  # type: ignore [call-overload]
    with pytest.raises(TypeError):
        basis.get_matrix_elements(invalid, "energy", 0)  # type: ignore [call-overload]
106 :
107 :
def test_from_kets(pi_module: PairinteractionModule) -> None:
    """Check BasisAtom.from_kets for one ket, several kets, consistency and bad input."""
    rb_60s = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)

    # --- single ket with every delta restriction supplied ---
    single = pi_module.BasisAtom.from_kets(
        rb_60s,
        delta_n=2,
        delta_nu=3,
        delta_nui=3,
        delta_l=2,
        delta_s=1,
        delta_j=3,
        delta_l_ryd=2,
        delta_j_ryd=3,
        delta_f=3,
        delta_m=2,
        delta_energy=100,
        delta_energy_unit="GHz",
    )
    assert single.species == "Rb"
    principal = [k.n for k in single.kets]
    assert all(58 <= n <= 62 for n in principal)
    assert 62 in principal
    assert 58 in principal
    assert any(k == rb_60s for k in single.kets)

    # --- several kets: the n range reaches delta_n below the lowest and above the
    # highest input n, and every input ket is contained ---
    rb_61s = pi_module.KetAtom("Rb", n=61, l=0, j=0.5, m=0.5)
    combined = pi_module.BasisAtom.from_kets([rb_60s, rb_61s], delta_n=2)
    principal = [k.n for k in combined.kets]
    assert all(58 <= n <= 63 for n in principal)
    assert 63 in principal
    assert 58 in principal
    assert any(k == rb_60s for k in combined.kets)
    assert any(k == rb_61s for k in combined.kets)

    # --- from_kets with only delta_n should agree with the direct constructor ---
    via_from_kets = pi_module.BasisAtom.from_kets(rb_60s, delta_n=2)
    via_constructor = pi_module.BasisAtom("Rb", n=(58, 62))
    assert via_from_kets.number_of_kets == via_constructor.number_of_kets

    # --- invalid input ---
    with pytest.raises(ValueError, match="empty"):
        pi_module.BasisAtom.from_kets([])

    sr_60p = pi_module.KetAtom("Sr88_singlet", n=60, l=1, j=1, m=0)
    with pytest.raises(ValueError, match="species"):
        pi_module.BasisAtom.from_kets([rb_60s, sr_60p])
|