Line data Source code
1 : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : from __future__ import annotations
5 :
6 1 : from typing import TYPE_CHECKING
7 :
8 1 : import numpy as np
9 1 : import pytest
10 1 : from pairinteraction import BasisAtom, BasisPair
11 1 : from pairinteraction.ket.ket_pair import is_ket_pair_like
12 1 : from pairinteraction.state.state_pair import is_state_pair_like
13 :
14 : if TYPE_CHECKING:
15 : from pairinteraction import KetAtom, SystemAtom
16 : from pairinteraction.basis.basis_pair import BasisPairLike
17 : from pairinteraction.ket.ket_pair import KetPairLike
18 : from pairinteraction.state.state_pair import StatePairLike
19 :
20 : from .utils import PairinteractionModule
21 :
22 1 : from .utils import no_log_propagation
23 :
24 :
@pytest.fixture
def basis(pi_module: PairinteractionModule, system_atom: SystemAtom, system_atom2: SystemAtom) -> BasisPair:
    """Provide a pair basis created with ``BasisPair.from_kets`` from two different atom systems.

    The seed pair consists of ket number 30 of each atom's basis and the basis
    is requested with ``number_of_kets=80``.
    """
    atoms = [system_atom, system_atom2]
    seed_kets = tuple(atom.basis.get_ket(30) for atom in atoms)
    # Keep the number_of_kets warning out of the test log.
    with no_log_propagation("pairinteraction.basis.basis_pair"):
        pair_basis = pi_module.BasisPair.from_kets(seed_kets, atoms, number_of_kets=80)
    return pair_basis
31 :
32 :
@pytest.fixture
def basis2(pi_module: PairinteractionModule, system_atom: SystemAtom) -> BasisPair:
    """Provide a pair basis created with ``BasisPair.from_kets`` using the same atom system twice.

    The seed pair consists of ket number 30 of the atom's basis for both atoms
    and the basis is requested with ``number_of_kets=80``.
    """
    atoms = [system_atom, system_atom]
    seed_kets = tuple(atom.basis.get_ket(30) for atom in atoms)
    # Keep the number_of_kets warning out of the test log.
    with no_log_propagation("pairinteraction.basis.basis_pair"):
        pair_basis = pi_module.BasisPair.from_kets(seed_kets, atoms, number_of_kets=80)
    return pair_basis
39 :
40 :
@pytest.fixture
def basis_atom(pi_module: PairinteractionModule) -> BasisAtom:
    """Provide a single-atom ``"Rb"`` basis created with ``n=(58, 62)`` and ``l=(0, 2)``."""
    single_atom_basis = pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2))
    return single_atom_basis
44 :
45 :
@pytest.fixture
def basis_atom2(pi_module: PairinteractionModule) -> BasisAtom:
    """Provide a single-atom ``"Rb"`` basis created with ``n=(58, 62)`` and ``l=(2, 3)``."""
    single_atom_basis = pi_module.BasisAtom("Rb", n=(58, 62), l=(2, 3))
    return single_atom_basis
49 :
50 :
@pytest.fixture
def system_atom(pi_module: PairinteractionModule, basis_atom: BasisAtom) -> SystemAtom:
    """Provide a single-atom system constructed from the ``basis_atom`` fixture."""
    atom_system = pi_module.SystemAtom(basis_atom)
    return atom_system
54 :
55 :
@pytest.fixture
def system_atom2(pi_module: PairinteractionModule, basis_atom2: BasisAtom) -> SystemAtom:
    """Provide a single-atom system constructed from the ``basis_atom2`` fixture."""
    atom_system = pi_module.SystemAtom(basis_atom2)
    return atom_system
59 :
60 :
def test_basis_creation(pi_module: PairinteractionModule, system_atom: SystemAtom, system_atom2: SystemAtom) -> None:
    """Check the sizes and the string representation of a pair basis built without restrictions."""
    pair_basis = pi_module.BasisPair([system_atom, system_atom2])

    # The pair basis is expected to contain one ket per combination of single-atom kets.
    expected_number_of_kets = system_atom.basis.number_of_kets * system_atom2.basis.number_of_kets
    assert pair_basis.number_of_kets == expected_number_of_kets
    assert pair_basis.number_of_states == pair_basis.number_of_kets
    assert len(pair_basis.kets) == pair_basis.number_of_kets

    # The string representation must mention the class name, an ellipsis and these two kets.
    for fragment in ("BasisPair", "Rb:58,S_1/2,", "...", "Rb:62,D_5/2,"):
        assert fragment in str(pair_basis)
68 :
69 :
def test_coefficients(basis: BasisPair) -> None:
    """Check shape, diagonal and total sum of the coefficient matrix."""
    coefficient_matrix = basis.get_coefficients()
    expected_shape = (basis.number_of_kets, basis.number_of_states)
    assert coefficient_matrix.shape == expected_shape

    # Every diagonal entry is expected to be 1 and all entries together
    # are expected to add up to the number of kets.
    diagonal = coefficient_matrix.diagonal()
    assert pytest.approx(diagonal) == 1.0  # NOSONAR
    total = coefficient_matrix.sum()
    assert pytest.approx(total) == basis.number_of_kets  # NOSONAR
76 :
77 :
def _get_expected_shape(other: KetPairLike | StatePairLike | BasisPairLike, target_basis: BasisPair) -> tuple[int, ...]:
    """Return the array shape expected when ``target_basis`` is evaluated against ``other``.

    Args:
        other: A ket-pair-like or state-pair-like object, a ``BasisPair``, or a
            list/tuple holding exactly two ``BasisAtom`` objects.
        target_basis: The pair basis whose getter method is being tested.

    Returns:
        ``(target_basis.number_of_states,)`` for ket-pair-like and state-pair-like
        input, otherwise ``(n_other, target_basis.number_of_states)`` where
        ``n_other`` is the number of states of ``other`` (for two atom bases the
        product of their numbers of states).

    Raises:
        ValueError: If ``other`` is none of the supported kinds.

    """
    if is_ket_pair_like(other) or is_state_pair_like(other):
        return (target_basis.number_of_states,)
    if isinstance(other, BasisPair):
        return (other.number_of_states, target_basis.number_of_states)
    # Accept tuples as well as lists, and insist on exactly two atom bases so that
    # an empty or too short sequence ends in the ValueError below instead of an IndexError.
    if (
        isinstance(other, (list, tuple))
        and len(other) == 2
        and all(isinstance(atom_basis, BasisAtom) for atom_basis in other)
    ):
        first, second = other
        return (first.number_of_states * second.number_of_states, target_basis.number_of_states)
    raise ValueError("Invalid basis_like type")
86 :
87 :
@pytest.mark.parametrize(
    "other_key",
    [
        "ket_from_basis",
        "ket_from_basis2",
        "ket_atom_tuple_from_basis",
        "ket_atom_tuple_from_basis2",
        "state_from_basis",
        "state_from_basis2",
        "state_atom_tuple_from_basis",
        "state_atom_tuple_from_basis2",
        "basis",
        "basis2",
        "basis_atom_tuple_from_basis",
        "basis_atom_tuple_from_basis2",
    ],
)
def test_get_methods(basis: BasisPair, basis2: BasisPair, other_key: str) -> None:
    """Check the result shapes of amplitudes, overlaps and matrix elements.

    ``other_key`` selects which ket, state or basis (taken from ``basis`` or
    ``basis2``, either as pair object or as per-atom list) is passed to the
    getter methods of ``basis``.
    """
    atoms = basis.system_atoms
    atoms2 = basis2.system_atoms
    candidates: dict[str, KetPairLike | StatePairLike | BasisPairLike] = {
        "ket_from_basis": basis.get_ket(0),
        "ket_from_basis2": basis2.get_ket(0),
        "ket_atom_tuple_from_basis": [atom.basis.get_ket(0) for atom in atoms],
        "ket_atom_tuple_from_basis2": [atom.basis.get_ket(0) for atom in atoms2],
        "state_from_basis": basis.get_state(0),
        "state_from_basis2": basis2.get_state(0),
        "state_atom_tuple_from_basis": [atom.basis.get_state(0) for atom in atoms],
        "state_atom_tuple_from_basis2": [atom.basis.get_state(0) for atom in atoms2],
        "basis": basis,
        "basis2": basis2,
        "basis_atom_tuple_from_basis": [atom.basis for atom in atoms],
        "basis_atom_tuple_from_basis2": [atom.basis for atom in atoms2],
    }
    other = candidates[other_key]

    # All three getters are expected to return arrays of the same shape.
    expected_shape = _get_expected_shape(other, basis)

    assert basis.get_amplitudes(other).shape == expected_shape
    assert basis.get_overlaps(other).shape == expected_shape

    dipole_elements = basis.get_matrix_elements(
        other, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
    )
    assert dipole_elements.shape == expected_shape
133 :
134 :
def test_get_overlaps_pair_explicitly(pi_module: PairinteractionModule, system_atom: SystemAtom) -> None:
    """Check cross-basis overlaps between two pair bases with non-identical kets.

    Both bases use the same atom system twice; the second one is additionally
    created with an energy window of +/- 1 GHz around twice the energy of ket
    number 20 of the atom basis.
    """
    reference_ket = system_atom.basis.get_ket(20)
    pair_energy = 2 * reference_ket.get_energy("GHz")
    energy_window = (pair_energy - 1, pair_energy + 1)

    full_basis = pi_module.BasisPair([system_atom, system_atom])
    restricted_basis = pi_module.BasisPair([system_atom, system_atom], energy=energy_window, energy_unit="GHz")
    assert full_basis.number_of_states > restricted_basis.number_of_states
    assert full_basis.number_of_states > 1

    overlaps = full_basis.get_overlaps(restricted_basis)
    assert overlaps.shape == (restricted_basis.number_of_states, full_basis.number_of_states)

    # Every stored overlap has to lie in [0, 1] up to numerical tolerance ...
    upper_bound = 1 + 1e-10
    stored_values = overlaps.data
    assert (stored_values >= 0).all()
    assert (stored_values <= upper_bound).all()

    # ... and so does the sum over each column and each row.
    column_sums = np.asarray(overlaps.sum(axis=0)).ravel()
    assert np.all(column_sums <= upper_bound)
    row_sums = np.asarray(overlaps.sum(axis=1)).ravel()
    assert np.all(row_sums <= upper_bound)

    # The pair (ket, ket) is looked up in both bases; its overlap with itself must be 1.
    pair = (reference_ket, reference_ket)
    index_in_full = full_basis.get_corresponding_state_index(pair)
    index_in_restricted = restricted_basis.get_corresponding_state_index(pair)
    assert pytest.approx(overlaps[index_in_restricted, index_in_full]) == 1.0  # NOSONAR
158 :
159 :
def test_error_handling(basis: BasisPair) -> None:
    """Check that every getter raises ``TypeError`` when given a plain string."""
    invalid_other = "not a ket"

    with pytest.raises(TypeError):
        basis.get_amplitudes(invalid_other)  # type: ignore [arg-type]

    with pytest.raises(TypeError):
        basis.get_overlaps(invalid_other)  # type: ignore [arg-type]

    with pytest.raises(TypeError):
        basis.get_matrix_elements(invalid_other, ("energy", "energy"), (0, 0))  # type: ignore [arg-type]
170 :
171 :
def test_from_kets(pi_module: PairinteractionModule, system_atom: SystemAtom) -> None:
    """Test BasisPair.from_kets.

    For a single ket pair and for a list of ket pairs, the basis is created once
    with a ``delta_energy`` restriction and then with several ``number_of_kets``
    targets. Afterwards the error cases (empty ket list, ``delta_energy``
    combined with ``number_of_kets``) are checked.
    """
    # The kets are created once and reused by every section below.
    ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    ket1 = pi_module.KetAtom("Rb", n=59, l=1, j=0.5, m=0.5)
    ket2 = pi_module.KetAtom("Rb", n=60, l=1, j=0.5, m=0.5)
    ket_atoms_dict: dict[str, tuple[KetAtom, KetAtom] | list[tuple[KetAtom, KetAtom]]] = {
        "single_ket": (ket, ket),
        "multiple_kets": [(ket, ket), (ket1, ket2), (ket2, ket1)],
    }

    for ket_atoms in ket_atoms_dict.values():
        # delta_energy restriction
        pair_basis = pi_module.BasisPair.from_kets(
            ket_atoms, [system_atom, system_atom], delta_energy=3, delta_energy_unit="GHz", delta_m=1
        )
        assert pair_basis.number_of_kets > 0

        # number_of_kets
        for target in [50, 100, 1000]:
            with no_log_propagation("pairinteraction.basis.basis_pair"):  # suppress number_of_kets warning
                pair_basis = pi_module.BasisPair.from_kets(
                    ket_atoms, [system_atom, system_atom], number_of_kets=target, delta_m=1
                )
            assert pair_basis.number_of_kets >= target
            assert pair_basis.number_of_kets < target + 20  # allow some extra due to degeneracies

    # test error cases
    with pytest.raises(ValueError, match="empty"):
        pi_module.BasisPair.from_kets([], [system_atom, system_atom])

    # delta_energy and number_of_kets must not be given together
    with pytest.raises(ValueError, match="number_of_kets"):
        pi_module.BasisPair.from_kets(
            (ket, ket),
            [system_atom, system_atom],
            delta_energy=3,
            delta_energy_unit="GHz",
            number_of_kets=10,
        )
|