Line data Source code
1 : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : from __future__ import annotations
5 :
6 1 : from typing import TYPE_CHECKING
7 :
8 1 : import numpy as np
9 1 : import pytest
10 :
11 : if TYPE_CHECKING:
12 : from pairinteraction import BasisAtom, BasisPair, KetAtom, SystemAtom
13 :
14 : from .utils import PairinteractionModule
15 :
16 1 : from .utils import no_log_propagation
17 :
18 :
@pytest.fixture
def basis_atom(pi_module: PairinteractionModule) -> BasisAtom:
    """Create a BasisAtom around Rb 60S.

    The basis is requested with ``n=(58, 62)``, ``l=(0, 2)`` and an energy
    window of +/- 100 GHz centred on the energy of the Rb ket with
    n=60, l=0, j=0.5, m=0.5.
    """
    reference_ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    # Query the reference energy once; both window bounds are derived from it.
    reference_energy = reference_ket.get_energy(unit="GHz")
    energy_window = (reference_energy - 100, reference_energy + 100)
    return pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2), energy=energy_window, energy_unit="GHz")
26 :
27 :
@pytest.fixture
def system_atom(pi_module: PairinteractionModule, basis_atom: BasisAtom) -> SystemAtom:
    """Build a SystemAtom on top of ``basis_atom`` and diagonalize it before returning it."""
    # Use a distinct local name so the fixture function itself is not shadowed.
    system = pi_module.SystemAtom(basis_atom)
    system.diagonalize()
    return system
34 :
35 :
@pytest.fixture
def basis(pi_module: PairinteractionModule, system_atom: SystemAtom) -> BasisPair:
    """Create the pair basis under test from two references to the same diagonalized SystemAtom."""
    systems = [system_atom, system_atom]
    return pi_module.BasisPair(systems)
40 :
41 :
def test_basis_creation(basis: BasisPair) -> None:
    """Test basic properties of the created basis.

    Checks that the ket count is 80 * 80, that the number of states and the
    length of ``basis.kets`` agree with it, and that ``str(basis)`` contains
    the expected fragments.
    """
    expected_kets = 80 * 80
    assert basis.number_of_kets == expected_kets
    assert basis.number_of_states == basis.number_of_kets
    assert len(basis.kets) == basis.number_of_kets

    # Build the string representation once and check each fragment separately,
    # so a failure names the fragment that is missing.
    description = str(basis)
    for fragment in ("BasisPair", "Rb:58,S_1/2,", "...", "Rb:61,D_5/2,"):
        assert fragment in description, f"{fragment!r} not found in str(basis)"
48 :
49 :
def test_coefficients(basis: BasisPair) -> None:
    """Check shape, diagonal and total sum of the coefficient matrix of the basis."""
    coefficient_matrix = basis.get_coefficients()

    expected_shape = (basis.number_of_kets, basis.number_of_states)
    assert coefficient_matrix.shape == expected_shape

    # Every diagonal entry is expected to be approximately one ...
    diagonal = coefficient_matrix.diagonal()
    assert pytest.approx(diagonal) == 1.0  # NOSONAR

    # ... and all entries together are expected to sum to the number of kets.
    total = coefficient_matrix.sum()
    assert pytest.approx(total) == basis.number_of_kets  # NOSONAR
56 :
57 :
def test_get_amplitudes_and_overlaps(basis: BasisPair) -> None:
    """Check amplitudes and overlaps of the basis with a single ket and with itself."""
    n_states = basis.number_of_states

    # Single ket: one entry per state, with the first entry approximately one.
    first_ket = basis.kets[0]
    ket_amplitudes = basis.get_amplitudes(first_ket)
    assert len(ket_amplitudes) == n_states
    assert pytest.approx(ket_amplitudes[0]) == 1.0  # NOSONAR
    ket_overlaps = basis.get_overlaps(first_ket)
    assert len(ket_overlaps) == n_states
    assert pytest.approx(ket_overlaps[0]) == 1.0  # NOSONAR

    # Whole basis: matrices whose diagonal entries are approximately one.
    amplitude_matrix = basis.get_amplitudes(basis)
    assert amplitude_matrix.shape == (basis.number_of_kets, n_states)
    assert pytest.approx(amplitude_matrix.diagonal()) == 1.0  # NOSONAR
    overlap_matrix = basis.get_overlaps(basis)
    assert overlap_matrix.shape == (n_states, n_states)
    assert pytest.approx(overlap_matrix.diagonal()) == 1.0  # NOSONAR
76 :
77 :
def test_get_matrix_elements(basis: BasisPair) -> None:
    """Test matrix element calculations.

    Requests electric dipole-dipole matrix elements (``qs=(0, 0)``) of the
    basis with a single ket and with the basis itself, then checks the shape
    of each result and that it contains at least one non-zero entry.
    """
    # Test with ket
    test_ket = basis.kets[0]
    elements_dipole = basis.get_matrix_elements(
        test_ket, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
    )
    assert elements_dipole.shape == (basis.number_of_states,)
    # Count the non-zero entries once and reuse the result for both bounds.
    nonzero_for_ket = np.count_nonzero(elements_dipole)
    assert nonzero_for_ket > 0
    assert nonzero_for_ket <= basis.number_of_states

    # Test with basis
    matrix_elements = basis.get_matrix_elements(
        basis, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
    )
    assert matrix_elements.shape == (basis.number_of_states, basis.number_of_states)
    # Densify the full matrix only once; converting it twice doubles the cost
    # of the most expensive step in this test.
    nonzero_for_basis = np.count_nonzero(matrix_elements.toarray())
    assert nonzero_for_basis > 0
    assert nonzero_for_basis <= basis.number_of_states**2
96 :
97 :
def test_error_handling(basis: BasisPair) -> None:
    """Check that passing a plain string instead of a ket raises TypeError."""
    # Each callable hands the same invalid argument to one of the query methods.
    invalid_calls = (
        lambda: basis.get_amplitudes("not a ket"),  # type: ignore [arg-type]
        lambda: basis.get_overlaps("not a ket"),  # type: ignore [arg-type]
        lambda: basis.get_matrix_elements("not a ket", ("energy", "energy"), (0, 0)),  # type: ignore [arg-type]
    )
    for invalid_call in invalid_calls:
        with pytest.raises(TypeError):
            invalid_call()
108 :
109 :
def test_from_ket_atoms(pi_module: PairinteractionModule, system_atom: SystemAtom) -> None:
    """Test BasisPair.from_ket_atoms.

    For a single ket pair and for a list of ket pairs, the basis is built once
    with a ``delta_energy`` restriction and once per target ``number_of_kets``.
    Afterwards the error cases are checked: an empty ket list, and passing
    ``delta_energy`` together with ``number_of_kets``.
    """
    ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
    ket1 = pi_module.KetAtom("Rb", n=59, l=1, j=0.5, m=0.5)
    ket2 = pi_module.KetAtom("Rb", n=60, l=1, j=0.5, m=0.5)
    ket_atoms_dict: dict[str, tuple[KetAtom, KetAtom] | list[tuple[KetAtom, KetAtom]]] = {
        "single_ket": (ket, ket),
        "multiple_kets": [(ket, ket), (ket1, ket2), (ket2, ket1)],
    }

    for ket_atoms in ket_atoms_dict.values():
        # delta_energy restriction
        pair_basis = pi_module.BasisPair.from_ket_atoms(
            ket_atoms, [system_atom, system_atom], delta_energy=3, delta_energy_unit="GHz", delta_m=1
        )
        assert pair_basis.number_of_kets > 0

        # number_of_kets
        for target in [50, 100, 1000]:
            with no_log_propagation("pairinteraction.basis.basis_pair"):  # suppress number_of_kets warning
                pair_basis = pi_module.BasisPair.from_ket_atoms(
                    ket_atoms, [system_atom, system_atom], number_of_kets=target, delta_m=1
                )
            assert pair_basis.number_of_kets >= target
            assert pair_basis.number_of_kets < target + 20  # allow some extra due to degeneracies

    # test error cases
    with pytest.raises(ValueError, match="empty"):
        pi_module.BasisPair.from_ket_atoms([], [system_atom, system_atom])

    # `ket` from the top of the test is reused; it is never modified in between.
    with pytest.raises(ValueError, match="number_of_kets"):
        pi_module.BasisPair.from_ket_atoms(
            (ket, ket),
            [system_atom, system_atom],
            delta_energy=3,
            delta_energy_unit="GHz",
            number_of_kets=10,
        )
|