Line data Source code
1 : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : from __future__ import annotations
5 :
6 1 : from typing import TYPE_CHECKING
7 :
8 1 : import numpy as np
9 1 : import pytest
10 1 : from pairinteraction import BasisAtom
11 1 : from pairinteraction.ket.ket_atom import KetAtom
12 1 : from pairinteraction.state.state_atom import StateAtom
13 1 : from scipy.sparse import csr_matrix
14 :
15 : if TYPE_CHECKING:
16 : from .utils import PairinteractionModule
17 :
18 :
@pytest.fixture
def basis(pi_module: PairinteractionModule) -> BasisAtom:
    """Provide a Rb basis restricted to n=(58, 62) and l=(0, 2)."""
    n_window = (58, 62)
    l_window = (0, 2)
    return pi_module.BasisAtom("Rb", n=n_window, l=l_window)
22 :
23 :
@pytest.fixture
def basis2(pi_module: PairinteractionModule) -> BasisAtom:
    """Provide a second Rb basis with the same n=(58, 62) window but l=(2, 3)."""
    n_window = (58, 62)
    l_window = (2, 3)
    return pi_module.BasisAtom("Rb", n=n_window, l=l_window)
27 :
28 :
def test_basis_creation(pi_module: PairinteractionModule) -> None:
    """Check species, ket/state counts and the string form of an energy-restricted basis.

    The basis is restricted to +/- 100 GHz around the energy of the Rb 60S_1/2 (m=1/2)
    ket, so it must contain fewer kets than the same n/l window without that restriction.
    """
    reference = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    half_width = 100
    lower = reference.get_energy(unit="GHz") - half_width
    upper = reference.get_energy(unit="GHz") + half_width

    restricted = pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2), energy=(lower, upper), energy_unit="GHz")
    assert restricted.species == "Rb"

    n_kets = restricted.number_of_kets
    assert n_kets == 80
    # without any state mixing every ket is its own state
    assert restricted.number_of_states == n_kets
    assert len(restricted.kets) == n_kets

    unrestricted = pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2))
    assert n_kets < unrestricted.number_of_kets

    description = str(restricted)
    for fragment in ("BasisAtom", "n=(58, 62)", "l=(0, 2)"):
        assert fragment in description
41 :
42 :
def test_restriction_mode(pi_module: PairinteractionModule) -> None:
    """Test exact, fuzzy, and numeric restrictions for expectation-value quantum numbers.

    Without a ``mode`` argument (called "fuzzy" here) the basis may contain kets whose
    expectation value of l lies outside the requested range. ``mode="exact"`` must keep
    every ket inside the range, and ``mode=100`` must admit more kets than the default.
    An unknown mode string and a negative mode must both raise ``ValueError``.
    """
    species = "Yb171_mqdt"
    nu_window = (58, 62)
    m_window = (0.5, 0.5)
    l_range = (1, 1)
    l_low, l_high = l_range

    fuzzy = pi_module.BasisAtom(species, nu=nu_window, l=l_range, m=m_window)
    assert fuzzy.number_of_kets > 0
    fuzzy_l_values = [ket.l for ket in fuzzy.kets]
    assert any(value > l_high for value in fuzzy_l_values)
    assert any(value < l_low for value in fuzzy_l_values)

    exact = pi_module.BasisAtom(species, nu=nu_window, l=l_range, m=m_window, mode="exact")
    assert exact.number_of_kets > 0
    assert all(l_low <= ket.l <= l_high for ket in exact.kets)
    assert exact.number_of_kets < fuzzy.number_of_kets

    widened = pi_module.BasisAtom(species, nu=nu_window, l=l_range, m=m_window, mode=100)
    assert widened.number_of_kets > fuzzy.number_of_kets

    with pytest.raises(ValueError, match="mode"):
        pi_module.BasisAtom(species, nu=nu_window, l=l_range, m=m_window, mode="invalid")  # type: ignore[arg-type]

    with pytest.raises(ValueError, match="non-negative"):
        pi_module.BasisAtom(species, nu=nu_window, l=l_range, m=m_window, mode=-1)
64 :
65 :
def test_coefficients(basis: BasisAtom) -> None:
    """Check the coefficient matrix shape, its unit diagonal and its total sum.

    A unit diagonal together with a total equal to the number of kets is what an
    identity-like coefficient matrix of an unmixed basis would give.
    """
    matrix = basis.get_coefficients()
    n_kets = basis.number_of_kets
    assert matrix.shape == (n_kets, basis.number_of_states)
    assert pytest.approx(matrix.diagonal()) == 1.0  # NOSONAR
    assert pytest.approx(matrix.sum()) == n_kets  # NOSONAR
72 :
73 :
def _get_expected_shape(other: KetAtom | StateAtom | BasisAtom, target_basis: BasisAtom) -> tuple[int, ...]:
    """Return the shape expected from ``target_basis.get_amplitudes/get_overlaps/get_matrix_elements(other)``.

    Args:
        other: The ket, state or basis passed to the ``get_*`` method.
        target_basis: The basis on which the ``get_*`` method is called.

    Returns:
        ``(target_basis.number_of_states,)`` for a single ket or state, and
        ``(other.number_of_states, target_basis.number_of_states)`` for a basis.

    Raises:
        TypeError: If ``other`` is not a ``KetAtom``, ``StateAtom`` or ``BasisAtom``.
    """
    if isinstance(other, (KetAtom, StateAtom)):
        return (target_basis.number_of_states,)
    if isinstance(other, BasisAtom):
        return (other.number_of_states, target_basis.number_of_states)
    # an argument of the wrong kind is a type error, not a value error
    raise TypeError(f"Invalid basis_like type: {type(other).__name__}")
80 :
81 :
@pytest.mark.parametrize(
    "other_key",
    [
        "ket_from_basis",
        "ket_from_basis2",
        "other_ket_from_basis2",
        "state_from_basis",
        "state_from_basis2",
        "basis",
        "basis2",
    ],
)
def test_get_methods(basis: BasisAtom, basis2: BasisAtom, other_key: str) -> None:
    """Test amplitude, overlap and matrix element calculations with another ket, state and basis.

    ``other_key`` selects what is passed to ``basis.get_amplitudes``, ``basis.get_overlaps``
    and ``basis.get_matrix_elements``. The candidates are:

    - a ket or state taken from ``basis`` itself;
    - a ket or state taken from ``basis2``, which shares some, but not all, kets with ``basis``;
    - one of the two bases.
    """
    ind = 5
    other_dict: dict[str, KetAtom | StateAtom | BasisAtom] = {
        "ket_from_basis": basis.get_ket(ind),
        # a ket that basis2 and basis have in common
        "ket_from_basis2": next(ket for ket in basis2.kets if ket in basis.kets),
        # a ket of basis2 that is absent from basis (additionally restricted to j_ryd < 3)
        "other_ket_from_basis2": next(ket for ket in basis2.kets if ket not in basis.kets and ket.j_ryd < 3),
        "state_from_basis": basis.get_state(ind),
        "state_from_basis2": basis2.get_state(0),
        "basis": basis,
        "basis2": basis2,
    }
    other = other_dict[other_key]
    expected_shape = _get_expected_shape(other, basis)

    amplitudes = basis.get_amplitudes(other)
    assert amplitudes.shape == expected_shape

    overlaps = basis.get_overlaps(other)
    assert overlaps.shape == expected_shape

    matrix_elements = basis.get_matrix_elements(other, "electric_dipole", q=0, unit="e * a0")
    assert matrix_elements.shape == expected_shape

    if other_key in ["ket_from_basis", "state_from_basis"]:
        # the ket/state at index ``ind`` must map onto itself
        assert pytest.approx(amplitudes[ind]) == 1.0  # NOSONAR
        assert pytest.approx(overlaps[ind]) == 1.0  # NOSONAR

    if other_key == "ket_from_basis2":
        assert isinstance(amplitudes, np.ndarray)
        assert isinstance(overlaps, np.ndarray)
        assert pytest.approx(np.max(amplitudes)) == 1.0  # NOSONAR
        assert pytest.approx(np.max(overlaps)) == 1.0  # NOSONAR

    if other_key == "other_ket_from_basis2":
        # a ket that is not part of basis has no overlap with any of its states
        assert isinstance(amplitudes, np.ndarray)
        assert isinstance(overlaps, np.ndarray)
        assert np.count_nonzero(amplitudes) == 0
        assert np.count_nonzero(overlaps) == 0

    if other_key == "basis":
        assert isinstance(amplitudes, csr_matrix)
        assert isinstance(overlaps, csr_matrix)
        assert pytest.approx(amplitudes.diagonal()) == 1.0  # NOSONAR
        assert pytest.approx(overlaps.diagonal()) == 1.0  # NOSONAR

    if other_key == "basis2":
        assert isinstance(amplitudes, csr_matrix)
        assert isinstance(overlaps, csr_matrix)
        n_matching_kets = len([ket for ket in basis2.kets if ket in basis.kets])
        assert np.count_nonzero(amplitudes.toarray()) == n_matching_kets
        assert np.count_nonzero(overlaps.toarray()) == n_matching_kets

    # Upper bound = total number of entries of the result. For "basis2" the matrix is
    # rectangular (basis2 states x basis states), so basis.number_of_states**2 would be
    # the wrong bound there.
    n_entries = int(np.prod(expected_shape))
    if other_key.startswith("basis"):
        assert isinstance(matrix_elements, csr_matrix)
        assert 0 < np.count_nonzero(matrix_elements.toarray()) < n_entries
    else:
        assert isinstance(matrix_elements, np.ndarray)
        assert 0 < np.count_nonzero(matrix_elements) < n_entries
152 :
153 :
def test_error_handling(basis: BasisAtom) -> None:
    """Passing a plain string instead of a ket, state or basis must raise ``TypeError``."""
    bogus = "not a ket"

    with pytest.raises(TypeError):
        basis.get_amplitudes(bogus)  # type: ignore [call-overload]

    with pytest.raises(TypeError):
        basis.get_overlaps(bogus)  # type: ignore [call-overload]

    with pytest.raises(TypeError):
        basis.get_matrix_elements(bogus, "energy", 0)  # type: ignore [call-overload]
164 :
165 :
def test_from_kets(pi_module: PairinteractionModule) -> None:
    """Test BasisAtom.from_kets.

    Covers a single seed ket with every ``delta_*`` keyword, two seed kets, agreement with
    the direct constructor, and the errors raised for an empty list and for mixed species.
    """
    # -- one seed ket, all delta_* keywords given --
    seed = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    single = pi_module.BasisAtom.from_kets(
        seed,
        delta_n=2,
        delta_nu=3,
        delta_nui=3,
        delta_l=2,
        delta_s=1,
        delta_j=3,
        delta_l_ryd=2,
        delta_j_ryd=3,
        delta_f=3,
        delta_m=2,
        delta_energy=100,
        delta_energy_unit="GHz",
    )
    assert single.species == "Rb"
    single_ns = {k.n for k in single.kets}
    # delta_n=2 around n=60 must reach exactly n=58 .. n=62
    assert all(58 <= n <= 62 for n in single_ns)
    assert 62 in single_ns
    assert 58 in single_ns
    assert seed in single.kets

    # -- two seed kets: the n window spans both of them --
    first = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    second = pi_module.KetAtom("Rb", n=61, l=0, j=0.5, m=0.5)
    combined = pi_module.BasisAtom.from_kets([first, second], delta_n=2)
    combined_ns = {k.n for k in combined.kets}
    assert all(58 <= n <= 63 for n in combined_ns)
    assert 63 in combined_ns
    assert 58 in combined_ns
    assert first in combined.kets
    assert second in combined.kets

    # -- from_kets must agree with the direct constructor --
    anchor = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    via_kets = pi_module.BasisAtom.from_kets(anchor, delta_n=2)
    via_constructor = pi_module.BasisAtom("Rb", n=(58, 62))
    assert via_kets.number_of_kets == via_constructor.number_of_kets

    # -- invalid input --
    with pytest.raises(ValueError, match="empty"):
        pi_module.BasisAtom.from_kets([])

    rubidium = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    strontium = pi_module.KetAtom("Sr88_singlet", n=60, l=1, j=1, m=0)
    with pytest.raises(ValueError, match="species"):
        pi_module.BasisAtom.from_kets([rubidium, strontium])
|