LCOV - code coverage report
Current view: top level - tests - test_basis_pair.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 79 79 100.0 %
Date: 2026-04-17 09:29:39 Functions: 9 9 100.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : from __future__ import annotations
       5             : 
       6           1 : from typing import TYPE_CHECKING
       7             : 
       8           1 : import numpy as np
       9           1 : import pytest
      10             : 
      11             : if TYPE_CHECKING:
      12             :     from pairinteraction import BasisAtom, BasisPair, KetAtom, SystemAtom
      13             : 
      14             :     from .utils import PairinteractionModule
      15             : 
      16           1 : from .utils import no_log_propagation
      17             : 
      18             : 
      19           1 : @pytest.fixture
      20           1 : def basis_atom(pi_module: PairinteractionModule) -> BasisAtom:
      21             :     """Create a BasisAtom around Rb 60S."""
      22           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
      23           1 :     energy_min = ket.get_energy(unit="GHz") - 100
      24           1 :     energy_max = ket.get_energy(unit="GHz") + 100
      25           1 :     return pi_module.BasisAtom("Rb", n=(58, 62), l=(0, 2), energy=(energy_min, energy_max), energy_unit="GHz")
      26             : 
      27             : 
      28           1 : @pytest.fixture
      29           1 : def system_atom(pi_module: PairinteractionModule, basis_atom: BasisAtom) -> SystemAtom:
      30             :     """Create a diagonalized SystemAtom around Rb 60S."""
      31           1 :     system_atom = pi_module.SystemAtom(basis_atom)
      32           1 :     system_atom.diagonalize()
      33           1 :     return system_atom
      34             : 
      35             : 
      36           1 : @pytest.fixture
      37           1 : def basis(pi_module: PairinteractionModule, system_atom: SystemAtom) -> BasisPair:
      38             :     """Create a test basis with a few states around Rb 60S."""
      39           1 :     return pi_module.BasisPair([system_atom, system_atom])
      40             : 
      41             : 
      42           1 : def test_basis_creation(basis: BasisPair) -> None:
      43             :     """Test basic properties of created basis."""
      44           1 :     assert basis.number_of_kets == 80 * 80
      45           1 :     assert basis.number_of_states == basis.number_of_kets
      46           1 :     assert len(basis.kets) == basis.number_of_kets
      47           1 :     assert all(x in str(basis) for x in ["BasisPair", "Rb:58,S_1/2,", "...", "Rb:61,D_5/2,"])
      48             : 
      49             : 
      50           1 : def test_coefficients(basis: BasisPair) -> None:
      51             :     """Test coefficient matrix properties."""
      52           1 :     coeffs = basis.get_coefficients()
      53           1 :     assert coeffs.shape == (basis.number_of_kets, basis.number_of_states)
      54           1 :     assert pytest.approx(coeffs.diagonal()) == 1.0  # NOSONAR
      55           1 :     assert pytest.approx(coeffs.sum()) == basis.number_of_kets  # NOSONAR
      56             : 
      57             : 
      58           1 : def test_get_amplitudes_and_overlaps(basis: BasisPair) -> None:
      59             :     """Test amplitude and overlap calculations."""
      60             :     # Test with ket
      61           1 :     test_ket = basis.kets[0]
      62           1 :     amplitudes = basis.get_amplitudes(test_ket)
      63           1 :     assert len(amplitudes) == basis.number_of_states
      64           1 :     assert pytest.approx(amplitudes[0]) == 1.0  # NOSONAR
      65           1 :     overlaps = basis.get_overlaps(test_ket)
      66           1 :     assert len(overlaps) == basis.number_of_states
      67           1 :     assert pytest.approx(overlaps[0]) == 1.0  # NOSONAR
      68             : 
      69             :     # Test with basis
      70           1 :     matrix_amplitudes = basis.get_amplitudes(basis)
      71           1 :     assert matrix_amplitudes.shape == (basis.number_of_kets, basis.number_of_states)
      72           1 :     assert pytest.approx(matrix_amplitudes.diagonal()) == 1.0  # NOSONAR
      73           1 :     matrix_overlaps = basis.get_overlaps(basis)
      74           1 :     assert matrix_overlaps.shape == (basis.number_of_states, basis.number_of_states)
      75           1 :     assert pytest.approx(matrix_overlaps.diagonal()) == 1.0  # NOSONAR
      76             : 
      77             : 
      78           1 : def test_get_matrix_elements(basis: BasisPair) -> None:
      79             :     """Test matrix element calculations."""
      80             :     # Test with ket
      81           1 :     test_ket = basis.kets[0]
      82           1 :     elements_dipole = basis.get_matrix_elements(
      83             :         test_ket, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
      84             :     )
      85           1 :     assert elements_dipole.shape == (basis.number_of_states,)
      86           1 :     assert np.count_nonzero(elements_dipole) > 0
      87           1 :     assert np.count_nonzero(elements_dipole) <= basis.number_of_states
      88             : 
      89             :     # Test with basis
      90           1 :     matrix_elements = basis.get_matrix_elements(
      91             :         basis, ("electric_dipole", "electric_dipole"), qs=(0, 0), unit="e^2 * a0^2"
      92             :     )
      93           1 :     assert matrix_elements.shape == (basis.number_of_states, basis.number_of_states)
      94           1 :     assert np.count_nonzero(matrix_elements.toarray()) > 0
      95           1 :     assert np.count_nonzero(matrix_elements.toarray()) <= basis.number_of_states**2
      96             : 
      97             : 
      98           1 : def test_error_handling(basis: BasisPair) -> None:
      99             :     """Test error cases."""
     100           1 :     with pytest.raises(TypeError):
     101           1 :         basis.get_amplitudes("not a ket")  # type: ignore [arg-type]
     102             : 
     103           1 :     with pytest.raises(TypeError):
     104           1 :         basis.get_overlaps("not a ket")  # type: ignore [arg-type]
     105             : 
     106           1 :     with pytest.raises(TypeError):
     107           1 :         basis.get_matrix_elements("not a ket", ("energy", "energy"), (0, 0))  # type: ignore [arg-type]
     108             : 
     109             : 
     110           1 : def test_from_ket_atoms(pi_module: PairinteractionModule, system_atom: SystemAtom) -> None:
     111             :     """Test BasisPair.from_ket_atoms."""
     112           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     113           1 :     ket1 = pi_module.KetAtom("Rb", n=59, l=1, j=0.5, m=0.5)
     114           1 :     ket2 = pi_module.KetAtom("Rb", n=60, l=1, j=0.5, m=0.5)
     115           1 :     ket_atoms_dict: dict[str, tuple[KetAtom, KetAtom] | list[tuple[KetAtom, KetAtom]]] = {
     116             :         "single_ket": (ket, ket),
     117             :         "multiple_kets": [(ket, ket), (ket1, ket2), (ket2, ket1)],
     118             :     }
     119             : 
     120           1 :     for ket_atoms in ket_atoms_dict.values():
     121             :         # delta_energy restriction
     122           1 :         pair_basis = pi_module.BasisPair.from_ket_atoms(
     123             :             ket_atoms, [system_atom, system_atom], delta_energy=3, delta_energy_unit="GHz", delta_m=1
     124             :         )
     125           1 :         assert pair_basis.number_of_kets > 0
     126             : 
     127             :         # number_of_kets
     128           1 :         ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     129           1 :         for target in [50, 100, 1000]:
     130           1 :             with no_log_propagation("pairinteraction.basis.basis_pair"):  # surpress number_of_kets warning
     131           1 :                 pair_basis = pi_module.BasisPair.from_ket_atoms(
     132             :                     ket_atoms, [system_atom, system_atom], number_of_kets=target, delta_m=1
     133             :                 )
     134           1 :             assert pair_basis.number_of_kets >= target
     135           1 :             assert pair_basis.number_of_kets < target + 20  # allow some extra due to degeneracies
     136             : 
     137             :     # test error cases
     138           1 :     with pytest.raises(ValueError, match="empty"):
     139           1 :         pi_module.BasisPair.from_ket_atoms([], [system_atom, system_atom])
     140             : 
     141           1 :     ket = pi_module.KetAtom("Rb", n=60, l=0, j=0.5, m=0.5)
     142           1 :     with pytest.raises(ValueError, match="number_of_kets"):
     143           1 :         pi_module.BasisPair.from_ket_atoms(
     144             :             (ket, ket),
     145             :             [system_atom, system_atom],
     146             :             delta_energy=3,
     147             :             delta_energy_unit="GHz",
     148             :             number_of_kets=10,
     149             :         )

Generated by: LCOV version 1.16