LCOV - code coverage report
Current view: top level - src/basis - BasisPairCreator.test.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 343 344 99.7 %
Date: 2026-06-19 12:50:25 Functions: 7 7 100.0 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/basis/BasisPairCreator.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisAtomCreator.hpp"
       8             : #include "pairinteraction/basis/BasisPair.hpp"
       9             : #include "pairinteraction/database/Database.hpp"
      10             : #include "pairinteraction/diagonalize/DiagonalizerEigen.hpp"
      11             : #include "pairinteraction/enums/OperatorType.hpp"
      12             : #include "pairinteraction/enums/Parity.hpp"
      13             : #include "pairinteraction/enums/TransformationType.hpp"
      14             : #include "pairinteraction/ket/KetAtom.hpp"
      15             : #include "pairinteraction/ket/KetAtomCreator.hpp"
      16             : #include "pairinteraction/ket/KetPair.hpp"
      17             : #include "pairinteraction/system/SystemAtom.hpp"
      18             : #include "pairinteraction/system/SystemPair.hpp"
      19             : #include "pairinteraction/utils/hash.hpp"
      20             : #include "pairinteraction/utils/streamed.hpp"
      21             : 
      22             : #include <algorithm>
      23             : #include <array>
      24             : #include <cmath>
      25             : #include <doctest/doctest.h>
      26             : #include <unordered_map>
      27             : #include <utility>
      28             : #include <vector>
      29             : 
      30             : namespace pairinteraction {
      31             : 
      32             : constexpr double HARTREE_IN_GHZ = 6579683.920501762;
      33             : constexpr double VOLT_PER_CM_IN_ATOMIC_UNITS = 1 / 5.14220675112e9;
      34             : constexpr double UM_IN_ATOMIC_UNITS = 1 / 5.29177210544e-5;
      35             : 
      36             : namespace {
      37             : template <typename Scalar>
      38             : Eigen::SparseMatrix<Scalar, Eigen::RowMajor>
      39           2 : build_manual_symmetrizer(const std::shared_ptr<const BasisPair<Scalar>> &basis,
      40             :                          Parity parity_under_inversion, Parity parity_under_permutation) {
      41             :     using real_t = typename BasisPair<Scalar>::real_t;
      42             : 
      43           2 :     const auto basis1 = basis->get_basis1();
      44           2 :     const auto basis2 = basis->get_basis2();
      45           2 :     const auto inv_sqrt_two = static_cast<real_t>(1 / std::sqrt(2.0));
      46             : 
      47           2 :     std::vector<Eigen::Triplet<Scalar>> triplets;
      48           2 :     triplets.reserve(2 * basis->get_number_of_states());
      49             : 
      50           2 :     Eigen::Index state_index = 0;
      51             :     std::unordered_map<std::array<size_t, 2>, Eigen::Index, utils::hash<std::array<size_t, 2>>>
      52           2 :         ket_ids_to_state_index;
      53          26 :     for (size_t idx1 = 0; idx1 < basis1->get_number_of_states(); ++idx1) {
      54         312 :         for (size_t idx2 = 0; idx2 < basis2->get_number_of_states(); ++idx2) {
      55         288 :             int ket_index = basis->get_ket_index_from_tuple(idx1, idx2);
      56         288 :             if (ket_index < 0) {
      57          24 :                 continue;
      58             :             }
      59             : 
      60         288 :             size_t id1 = basis1->get_corresponding_ket(idx1)->get_id_in_database();
      61         288 :             size_t id2 = basis2->get_corresponding_ket(idx2)->get_id_in_database();
      62             : 
      63         288 :             if (id1 == id2) {
      64          24 :                 if (parity_under_inversion == Parity::EVEN ||
      65             :                     parity_under_permutation == Parity::EVEN) {
      66           0 :                     continue;
      67             :                 }
      68          24 :                 triplets.emplace_back(ket_index, state_index++, Scalar{1});
      69          24 :                 continue;
      70             :             }
      71             : 
      72         264 :             std::array<size_t, 2> ordered_ids{std::max(id1, id2), std::min(id1, id2)};
      73         264 :             auto [it, inserted] = ket_ids_to_state_index.try_emplace(ordered_ids, state_index);
      74         264 :             if (inserted) {
      75         132 :                 ++state_index;
      76             :             }
      77         264 :             Eigen::Index column_index = it->second;
      78             : 
      79         264 :             if (id1 > id2) {
      80         132 :                 triplets.emplace_back(ket_index, column_index, static_cast<Scalar>(inv_sqrt_two));
      81             :             } else {
      82         132 :                 int swapped_sign = 0;
      83         132 :                 if (parity_under_inversion != Parity::UNKNOWN &&
      84             :                     parity_under_permutation == Parity::UNKNOWN) {
      85         198 :                     swapped_sign = -static_cast<int>(parity_under_inversion) *
      86          66 :                         static_cast<int>(basis1->get_parity(idx1)) *
      87          66 :                         static_cast<int>(basis2->get_parity(idx2));
      88             :                 } else {
      89          66 :                     swapped_sign = -static_cast<int>(parity_under_permutation);
      90             :                 }
      91         132 :                 triplets.emplace_back(ket_index, column_index,
      92         132 :                                       static_cast<Scalar>(swapped_sign * inv_sqrt_two));
      93             :             }
      94             :         }
      95             :     }
      96             : 
      97           4 :     Eigen::SparseMatrix<Scalar, Eigen::RowMajor> transformation(
      98           2 :         static_cast<Eigen::Index>(basis->get_number_of_states()), state_index);
      99           2 :     transformation.setFromTriplets(triplets.begin(), triplets.end());
     100           4 :     return transformation;
     101           2 : }
     102             : 
     103             : template <typename Scalar>
     104             : std::shared_ptr<const BasisPair<Scalar>>
     105           4 : build_pair_basis(std::shared_ptr<const BasisAtom<Scalar>> basis1,
     106             :                  std::shared_ptr<const BasisAtom<Scalar>> basis2) {
     107           4 :     auto system1 = SystemAtom<Scalar>(std::move(basis1));
     108           4 :     auto system2 = SystemAtom<Scalar>(std::move(basis2));
     109          12 :     return BasisPairCreator<Scalar>().add(system1).add(system2).create();
     110           4 : }
     111             : 
     112             : template <typename Scalar>
     113           2 : void check_same_pair_eigenenergies(const std::shared_ptr<const BasisPair<Scalar>> &basis1,
     114             :                                    const std::shared_ptr<const BasisPair<Scalar>> &basis2,
     115             :                                    const DiagonalizerEigen<Scalar> &diagonalizer) {
     116           2 :     auto system_pair_1 = SystemPair<Scalar>(basis1).set_distance_vector(
     117             :         std::array<typename BasisPair<Scalar>::real_t, 3>{0, 0, 1 * UM_IN_ATOMIC_UNITS});
     118           2 :     auto system_pair_2 = SystemPair<Scalar>(basis2).set_distance_vector(
     119             :         std::array<typename BasisPair<Scalar>::real_t, 3>{0, 0, 1 * UM_IN_ATOMIC_UNITS});
     120             : 
     121           2 :     system_pair_1.diagonalize(diagonalizer);
     122           2 :     system_pair_2.diagonalize(diagonalizer);
     123             : 
     124           2 :     auto eigenenergies_1 = system_pair_1.get_eigenenergies();
     125           2 :     auto eigenenergies_2 = system_pair_2.get_eigenenergies();
     126             : 
     127           2 :     DOCTEST_REQUIRE(eigenenergies_1.size() == eigenenergies_2.size());
     128           2 :     DOCTEST_CHECK(eigenenergies_1.isApprox(eigenenergies_2, 1e-11));
     129           2 : }
     130             : } // namespace
     131             : 
     132           3 : DOCTEST_TEST_CASE("create a BasisPair") {
     133             :     // Create single-atom system
     134           3 :     Database &database = Database::get_global_instance();
     135           3 :     auto basis = BasisAtomCreator<double>()
     136           6 :                      .set_species("Rb")
     137           6 :                      .restrict_quantum_number("n", 58, 62)
     138           6 :                      .restrict_quantum_number("l", 0, 2)
     139           3 :                      .create(database);
     140           3 :     SystemAtom<double> system(basis);
     141           3 :     system.set_electric_field({0, 0, 1 * VOLT_PER_CM_IN_ATOMIC_UNITS});
     142             : 
     143           3 :     DiagonalizerEigen<double> diagonalizer;
     144           3 :     system.diagonalize(diagonalizer);
     145             : 
     146             :     // Get energy window for a two-atom basis
     147           3 :     auto ket = KetAtomCreator()
     148           6 :                    .set_species("Rb")
     149           6 :                    .set_quantum_number("n", 60)
     150           6 :                    .set_quantum_number("l", 0)
     151           6 :                    .set_quantum_number("m", 0.5)
     152           3 :                    .create(database);
     153           3 :     double min_energy = 2 * ket->get_energy() - 3 / HARTREE_IN_GHZ;
     154           3 :     double max_energy = 2 * ket->get_energy() + 3 / HARTREE_IN_GHZ;
     155             : 
     156             :     // Create two-atom bases
     157           3 :     auto basis_pair_a = pairinteraction::BasisPairCreator<double>()
     158           3 :                             .add(system)
     159           3 :                             .add(system)
     160           3 :                             .restrict_energy(min_energy, max_energy)
     161           3 :                             .restrict_quantum_number_m(1, 1)
     162           3 :                             .create();
     163           3 :     auto basis_pair_b = pairinteraction::BasisPairCreator<double>()
     164           3 :                             .add(system)
     165           3 :                             .add(system)
     166           3 :                             .restrict_energy(min_energy, max_energy)
     167           3 :                             .restrict_quantum_number_m(1, 1)
     168           3 :                             .create();
     169             : 
     170           3 :     DOCTEST_SUBCASE("check equality of kets") {
     171             :         // Obtain kets from the two-atom bases and check for equality
     172           1 :         auto ket1a = basis_pair_a->get_kets()[0];
     173           1 :         auto ket1b = basis_pair_b->get_kets()[0];
     174           1 :         auto ket2a = basis_pair_a->get_kets()[1];
     175           1 :         auto ket2b = basis_pair_b->get_kets()[1];
     176           1 :         DOCTEST_CHECK(*ket1a == *ket1a);
     177           1 :         DOCTEST_CHECK(*ket2a == *ket2a);
     178           1 :         DOCTEST_CHECK(*ket1a != *ket2b);
     179           1 :         DOCTEST_CHECK(*ket2a != *ket1b);
     180             : 
     181             :         // Currently, kets from different BasisPair are never equal
     182           1 :         DOCTEST_CHECK(*ket1a != *ket1b);
     183           1 :         DOCTEST_CHECK(*ket2a != *ket2b);
     184           4 :     }
     185             : 
     186           3 :     DOCTEST_SUBCASE("check overlap") {
     187           1 :         auto basis_ket = BasisAtomCreator<double>().add_ket(ket).create(database);
     188           1 :         auto basis_pair_ket = build_pair_basis<double>(basis_ket, basis_ket);
     189             :         Eigen::RowVectorXd amplitudes =
     190             :             basis_pair_a
     191           2 :                 ->get_matrix_elements(basis_pair_ket, OperatorType::IDENTITY,
     192             :                                       OperatorType::IDENTITY, 0, 0)
     193           1 :                 .row(0);
     194           1 :         auto overlaps = amplitudes.cwiseAbs2().eval();
     195             : 
     196             :         // The total overlap is less than 1 because of the restricted energy window
     197           1 :         DOCTEST_CHECK(overlaps.sum() == doctest::Approx(0.9107819201));
     198           4 :     }
     199             : 
     200           3 :     DOCTEST_SUBCASE("get the atomic states constituting a ket of the basis_pair") {
     201           1 :         auto atomic_states = basis_pair_a->get_kets()[0]->get_atomic_states();
     202           1 :         DOCTEST_CHECK(atomic_states.size() == 2);
     203           1 :         DOCTEST_CHECK(atomic_states[0]->get_number_of_states() == 1);
     204           1 :         DOCTEST_CHECK(atomic_states[0]->get_number_of_kets() == basis->get_number_of_kets());
     205           4 :     }
     206           3 : }
     207             : 
     208           2 : DOCTEST_TEST_CASE("get matrix elements in the pair basis") {
     209           2 :     DiagonalizerEigen<double> diagonalizer;
     210             : 
     211             :     // Create single-atom system
     212           2 :     Database &database = Database::get_global_instance();
     213           2 :     auto basis = BasisAtomCreator<double>()
     214           4 :                      .set_species("Rb")
     215           4 :                      .restrict_quantum_number("n", 58, 62)
     216           4 :                      .restrict_quantum_number("l", 0, 2)
     217           2 :                      .create(database);
     218           2 :     SystemAtom<double> system(basis);
     219           2 :     system.set_electric_field({0, 0, 10 * VOLT_PER_CM_IN_ATOMIC_UNITS});
     220           2 :     system.diagonalize(diagonalizer);
     221             : 
     222             :     // Get energy window for a two-atom basis
     223           2 :     auto ket = KetAtomCreator()
     224           4 :                    .set_species("Rb")
     225           4 :                    .set_quantum_number("n", 60)
     226           4 :                    .set_quantum_number("l", 0)
     227           4 :                    .set_quantum_number("m", 0.5)
     228           2 :                    .create(database);
     229           2 :     double min_energy = 2 * ket->get_energy() - 3 / HARTREE_IN_GHZ;
     230           2 :     double max_energy = 2 * ket->get_energy() + 3 / HARTREE_IN_GHZ;
     231             : 
     232             :     // Create two-atom system
     233           2 :     auto basis_pair_unperturbed = pairinteraction::BasisPairCreator<double>()
     234           2 :                                       .add(system)
     235           2 :                                       .add(system)
     236           2 :                                       .restrict_energy(min_energy, max_energy)
     237           2 :                                       .restrict_quantum_number_m(1, 1)
     238           2 :                                       .create();
     239           4 :     auto system_pair = SystemPair<double>(basis_pair_unperturbed)
     240           2 :                            .set_distance_vector({0, 0, 1 * UM_IN_ATOMIC_UNITS});
     241           2 :     system_pair.diagonalize(diagonalizer);
     242             : 
     243           2 :     auto basis_pair = system_pair.get_eigenbasis();
     244             : 
     245           2 :     DOCTEST_SUBCASE("check dimensions") {
     246             :         // <basis_pair_unperturbed|d0d0|basis_pair_unperturbed>
     247             :         auto matrix_elements_all = basis_pair_unperturbed->get_matrix_elements(
     248             :             basis_pair_unperturbed, OperatorType::ELECTRIC_DIPOLE, OperatorType::ELECTRIC_DIPOLE, 0,
     249           1 :             0);
     250           1 :         DOCTEST_CHECK(matrix_elements_all.rows() == basis_pair_unperturbed->get_number_of_states());
     251           1 :         DOCTEST_CHECK(matrix_elements_all.cols() == basis_pair_unperturbed->get_number_of_states());
     252             : 
     253             :         // <ket_pair|d0d0|basis_pair_unperturbed>
     254           1 :         auto atomic_states = basis_pair_unperturbed->get_kets()[0]->get_atomic_states();
     255           1 :         auto basis_pair_ket_pair = build_pair_basis<double>(atomic_states[0], atomic_states[1]);
     256             :         Eigen::RowVectorXd matrix_elements_ket_pair =
     257             :             basis_pair_unperturbed
     258           2 :                 ->get_matrix_elements(basis_pair_ket_pair, OperatorType::ELECTRIC_DIPOLE,
     259             :                                       OperatorType::ELECTRIC_DIPOLE, 0, 0)
     260           1 :                 .row(0);
     261           1 :         DOCTEST_CHECK(matrix_elements_ket_pair.size() ==
     262             :                       basis_pair_unperturbed->get_number_of_states());
     263             : 
     264             :         {
     265           1 :             Eigen::RowVectorXd ref = matrix_elements_all.row(0);
     266           1 :             DOCTEST_CHECK(ref.isApprox(matrix_elements_ket_pair, 1e-11));
     267           1 :         }
     268             : 
     269             :         // <basis x basis|d0d0|basis_pair>
     270           1 :         auto basis_pair_product = build_pair_basis<double>(basis, basis);
     271             :         auto matrix_elements_product = basis_pair->get_matrix_elements(
     272           1 :             basis_pair_product, OperatorType::ELECTRIC_DIPOLE, OperatorType::ELECTRIC_DIPOLE, 0, 0);
     273           1 :         DOCTEST_CHECK(matrix_elements_product.rows() ==
     274             :                       basis->get_number_of_states() * basis->get_number_of_states());
     275           1 :         DOCTEST_CHECK(matrix_elements_product.cols() == basis_pair->get_number_of_states());
     276             : 
     277             :         // <ket,ket|d0d0|basis_pair>
     278           1 :         auto basis_ket = BasisAtomCreator<double>().add_ket(ket).create(database);
     279           1 :         auto basis_pair_ket = build_pair_basis<double>(basis_ket, basis_ket);
     280             :         auto matrix_elements_ket = basis_pair->get_matrix_elements(
     281           1 :             basis_pair_ket, OperatorType::ELECTRIC_DIPOLE, OperatorType::ELECTRIC_DIPOLE, 0, 0);
     282           1 :         DOCTEST_CHECK(matrix_elements_ket.rows() == 1);
     283           1 :         DOCTEST_CHECK(matrix_elements_ket.cols() == basis_pair->get_number_of_states());
     284           3 :     }
     285             : 
     286           2 :     DOCTEST_SUBCASE("check matrix elements") {
     287             :         // energy
     288             :         auto hamiltonian = basis_pair->get_matrix_elements(basis_pair, OperatorType::ENERGY,
     289           1 :                                                            OperatorType::IDENTITY, 0, 0);
     290           2 :         hamiltonian += basis_pair->get_matrix_elements(basis_pair, OperatorType::IDENTITY,
     291           1 :                                                        OperatorType::ENERGY, 0, 0);
     292             : 
     293             :         // interaction with electric field
     294             :         {
     295           2 :             Eigen::SparseMatrix<double, Eigen::RowMajor> tmp = -basis_pair->get_matrix_elements(
     296           1 :                 basis_pair, OperatorType::ELECTRIC_DIPOLE, OperatorType::IDENTITY, 0, 0);
     297           2 :             tmp += -basis_pair->get_matrix_elements(basis_pair, OperatorType::IDENTITY,
     298           1 :                                                     OperatorType::ELECTRIC_DIPOLE, 0, 0);
     299           1 :             hamiltonian += 10 * VOLT_PER_CM_IN_ATOMIC_UNITS * tmp;
     300           1 :         }
     301             : 
     302             :         // dipole-dipole interaction
     303             :         {
     304           1 :             Eigen::SparseMatrix<double, Eigen::RowMajor> tmp = -2 *
     305           2 :                 basis_pair->get_matrix_elements(basis_pair, OperatorType::ELECTRIC_DIPOLE,
     306           1 :                                                 OperatorType::ELECTRIC_DIPOLE, 0, 0);
     307           2 :             tmp += -basis_pair->get_matrix_elements(basis_pair, OperatorType::ELECTRIC_DIPOLE,
     308           1 :                                                     OperatorType::ELECTRIC_DIPOLE, 1, -1);
     309           2 :             tmp += -basis_pair->get_matrix_elements(basis_pair, OperatorType::ELECTRIC_DIPOLE,
     310           1 :                                                     OperatorType::ELECTRIC_DIPOLE, -1, 1);
     311           1 :             hamiltonian += std::pow(UM_IN_ATOMIC_UNITS, -3) * tmp;
     312           1 :         }
     313             : 
     314             :         // compare to reference
     315           1 :         const auto &ref = system_pair.get_matrix();
     316           1 :         DOCTEST_CHECK(ref.isApprox(hamiltonian, 1e-11));
     317           3 :     }
     318           2 : }
     319             : 
     320           1 : DOCTEST_TEST_CASE("get amplitudes (via matrix elements) between different pair basis") {
     321           1 :     DiagonalizerEigen<double> diagonalizer;
     322             : 
     323           1 :     Database &database = Database::get_global_instance();
     324           1 :     auto atomic_basis = BasisAtomCreator<double>()
     325           2 :                             .set_species("Rb")
     326           2 :                             .restrict_quantum_number("n", 58, 62)
     327           2 :                             .restrict_quantum_number("l", 0, 2)
     328           2 :                             .restrict_quantum_number("m", 0.5, 0.5)
     329           1 :                             .create(database);
     330             : 
     331           1 :     auto ket = KetAtomCreator()
     332           2 :                    .set_species("Rb")
     333           2 :                    .set_quantum_number("n", 60)
     334           2 :                    .set_quantum_number("l", 0)
     335           2 :                    .set_quantum_number("m", 0.5)
     336           1 :                    .create(database);
     337             : 
     338           1 :     auto perturbed_system1 = SystemAtom<double>(atomic_basis);
     339           1 :     perturbed_system1.set_electric_field({0, 0, 1 * VOLT_PER_CM_IN_ATOMIC_UNITS});
     340           1 :     perturbed_system1.diagonalize(diagonalizer);
     341             : 
     342           1 :     auto perturbed_system2 = SystemAtom<double>(atomic_basis);
     343           1 :     perturbed_system2.set_electric_field({0, 0, 2 * VOLT_PER_CM_IN_ATOMIC_UNITS});
     344           1 :     perturbed_system2.diagonalize(diagonalizer);
     345             : 
     346           1 :     double min_energy = 2 * ket->get_energy() - 20 / HARTREE_IN_GHZ;
     347           1 :     double max_energy = 2 * ket->get_energy() + 20 / HARTREE_IN_GHZ;
     348           1 :     auto perturbed_basis = BasisPairCreator<double>()
     349           1 :                                .add(perturbed_system1)
     350           1 :                                .add(perturbed_system2)
     351           1 :                                .restrict_energy(min_energy, max_energy)
     352           1 :                                .create();
     353           1 :     auto state_in_perturbed_basis = perturbed_basis->get_state(42);
     354             : 
     355           1 :     auto unperturbed_system1 = SystemAtom<double>(atomic_basis);
     356           1 :     auto unperturbed_system2 = SystemAtom<double>(atomic_basis);
     357           1 :     min_energy = 2 * ket->get_energy() - 10 / HARTREE_IN_GHZ;
     358           1 :     max_energy = 2 * ket->get_energy() + 10 / HARTREE_IN_GHZ;
     359           1 :     auto small_unperturbed_basis = BasisPairCreator<double>()
     360           1 :                                        .add(unperturbed_system1)
     361           1 :                                        .add(unperturbed_system2)
     362           1 :                                        .restrict_energy(min_energy, max_energy)
     363           1 :                                        .create();
     364           1 :     DOCTEST_CHECK(perturbed_basis->get_number_of_states() !=
     365             :                   small_unperturbed_basis->get_number_of_states());
     366             : 
     367             :     auto amplitudes = state_in_perturbed_basis->get_matrix_elements(
     368           1 :         small_unperturbed_basis, OperatorType::IDENTITY, OperatorType::IDENTITY, 0, 0);
     369           1 :     DOCTEST_CHECK(amplitudes.rows() == small_unperturbed_basis->get_number_of_states());
     370           1 :     DOCTEST_CHECK(amplitudes.cols() == state_in_perturbed_basis->get_number_of_states());
     371           1 : }
     372             : 
     373           5 : DOCTEST_TEST_CASE("create a symmetrized BasisPair") {
     374           5 :     auto &database = Database::get_global_instance();
     375           5 :     auto diagonalizer = DiagonalizerEigen<double>();
     376             : 
     377           5 :     auto basis = BasisAtomCreator<double>()
     378          10 :                      .set_species("Rb")
     379          10 :                      .restrict_quantum_number("n", 60, 61)
     380          10 :                      .restrict_quantum_number("l", 0, 1)
     381          10 :                      .restrict_quantum_number("m", -0.5, 0.5)
     382           5 :                      .create(database);
     383             : 
     384           5 :     SystemAtom<double> system(basis);
     385           5 :     system.diagonalize(diagonalizer);
     386             : 
     387           5 :     auto canonical_basis = BasisPairCreator<double>().add(system).add(system).create();
     388             : 
     389           5 :     DOCTEST_SUBCASE("restrict permutation parity") {
     390           1 :         auto symmetrized_basis = BasisPairCreator<double>()
     391           1 :                                      .add(system)
     392           1 :                                      .add(system)
     393           1 :                                      .restrict_parity_under_permutation(Parity::ODD)
     394           1 :                                      .create();
     395             : 
     396           3 :         auto expected_basis = canonical_basis->transformed(Transformation<double>(
     397           3 :             build_manual_symmetrizer(canonical_basis, Parity::UNKNOWN, Parity::ODD)));
     398             : 
     399           1 :         check_same_pair_eigenenergies(symmetrized_basis, expected_basis, diagonalizer);
     400           1 :         DOCTEST_CHECK(symmetrized_basis->get_number_of_states() <
     401             :                       canonical_basis->get_number_of_states());
     402           6 :     }
     403             : 
     404           5 :     DOCTEST_SUBCASE("restrict inversion parity") {
     405           1 :         auto symmetrized_basis = BasisPairCreator<double>()
     406           1 :                                      .add(system)
     407           1 :                                      .add(system)
     408           1 :                                      .restrict_parity_under_inversion(Parity::ODD)
     409           1 :                                      .create();
     410             : 
     411           3 :         auto expected_basis = canonical_basis->transformed(Transformation<double>(
     412           3 :             build_manual_symmetrizer(canonical_basis, Parity::ODD, Parity::UNKNOWN)));
     413             : 
     414           1 :         check_same_pair_eigenenergies(symmetrized_basis, expected_basis, diagonalizer);
     415           1 :         DOCTEST_CHECK(symmetrized_basis->get_number_of_states() <
     416             :                       canonical_basis->get_number_of_states());
     417           6 :     }
     418             : 
     419           5 :     DOCTEST_SUBCASE("restrict permutation parity to EVEN includes identical-state kets") {
     420           1 :         auto symmetrized_basis_even = BasisPairCreator<double>()
     421           1 :                                           .add(system)
     422           1 :                                           .add(system)
     423           1 :                                           .restrict_parity_under_permutation(Parity::EVEN)
     424           1 :                                           .create();
     425             : 
     426           1 :         auto symmetrized_basis_odd = BasisPairCreator<double>()
     427           1 :                                          .add(system)
     428           1 :                                          .add(system)
     429           1 :                                          .restrict_parity_under_permutation(Parity::ODD)
     430           1 :                                          .create();
     431             : 
     432             :         // Count kets in the canonical basis where both atoms are in the same state (id1 == id2).
     433             :         // Such kets are always permutation-symmetric and must appear in EVEN but not ODD.
     434           1 :         size_t num_diagonal_kets = 0;
     435         145 :         for (const auto &ket : *canonical_basis) {
     436         144 :             auto atomic_states = ket->get_atomic_states();
     437         144 :             DOCTEST_REQUIRE(atomic_states.size() == 2);
     438         144 :             size_t id1 = atomic_states[0]->get_corresponding_ket(0)->get_id_in_database();
     439         144 :             size_t id2 = atomic_states[1]->get_corresponding_ket(0)->get_id_in_database();
     440         144 :             if (id1 == id2) {
     441          12 :                 ++num_diagonal_kets;
     442             :             }
     443         144 :         }
     444           1 :         DOCTEST_REQUIRE(num_diagonal_kets > 0);
     445             : 
     446             :         // EVEN and ODD together partition all canonical kets: each off-diagonal pair contributes
     447             :         // one state per sector, diagonal kets only to EVEN. Hence EVEN - ODD == num_diagonal_kets.
     448           1 :         DOCTEST_CHECK(symmetrized_basis_even->get_number_of_states() +
     449             :                           symmetrized_basis_odd->get_number_of_states() ==
     450             :                       canonical_basis->get_number_of_states());
     451           1 :         DOCTEST_CHECK(symmetrized_basis_even->get_number_of_states() -
     452             :                           symmetrized_basis_odd->get_number_of_states() ==
     453             :                       num_diagonal_kets); // order matters: ODD is the smaller sector
     454             : 
     455             :         // The eigenenergies of the EVEN and ODD bases together must reproduce the
     456             :         // eigenenergies of the canonical (non-symmetrized) basis.
     457             :         auto system_pair_canonical =
     458           1 :             SystemPair<double>(canonical_basis).set_distance_vector({0, 0, 1 * UM_IN_ATOMIC_UNITS});
     459           2 :         auto system_pair_even = SystemPair<double>(symmetrized_basis_even)
     460           1 :                                     .set_distance_vector({0, 0, 1 * UM_IN_ATOMIC_UNITS});
     461           2 :         auto system_pair_odd = SystemPair<double>(symmetrized_basis_odd)
     462           1 :                                    .set_distance_vector({0, 0, 1 * UM_IN_ATOMIC_UNITS});
     463             : 
     464           1 :         system_pair_canonical.diagonalize(diagonalizer);
     465           1 :         system_pair_even.diagonalize(diagonalizer);
     466           1 :         system_pair_odd.diagonalize(diagonalizer);
     467             : 
     468           1 :         auto canonical_eigenenergies = system_pair_canonical.get_eigenenergies();
     469           1 :         auto even_eigenenergies = system_pair_even.get_eigenenergies();
     470           1 :         auto odd_eigenenergies = system_pair_odd.get_eigenenergies();
     471             : 
     472           1 :         DOCTEST_REQUIRE(canonical_eigenenergies.size() ==
     473             :                         even_eigenenergies.size() + odd_eigenenergies.size());
     474             : 
     475           1 :         Eigen::VectorXd combined_eigenenergies(canonical_eigenenergies.size());
     476           1 :         combined_eigenenergies << even_eigenenergies, odd_eigenenergies;
     477           1 :         std::sort(combined_eigenenergies.data(),
     478           1 :                   combined_eigenenergies.data() + combined_eigenenergies.size());
     479           1 :         std::sort(canonical_eigenenergies.data(),
     480           1 :                   canonical_eigenenergies.data() + canonical_eigenenergies.size());
     481             : 
     482           1 :         DOCTEST_CHECK(combined_eigenenergies.isApprox(canonical_eigenenergies, 1e-11));
     483           6 :     }
     484             : 
     485           5 :     DOCTEST_SUBCASE("combine inversion and permutation parity") { // both restricted to ODD
     486           1 :         auto symmetrized_basis = BasisPairCreator<double>()
     487           1 :                                      .add(system)
     488           1 :                                      .add(system)
     489           1 :                                      .restrict_parity_under_inversion(Parity::ODD)
     490           1 :                                      .restrict_parity_under_permutation(Parity::ODD)
     491           1 :                                      .create();
     492             :         // The doubly restricted basis must hold fewer states than the canonical basis.
     493           1 :         DOCTEST_CHECK(symmetrized_basis->get_number_of_states() <
     494             :                       canonical_basis->get_number_of_states());
     495             :         // Column-major coefficients: the outer index runs over states, rows index the basis kets.
     496             :         Eigen::SparseMatrix<double, Eigen::ColMajor> coefficients =
     497           1 :             symmetrized_basis->get_coefficients();
     498           1 :         const double inv_sqrt_two = 1 / std::sqrt(2.0);
     499             :         // Per state: gather its stored (row, value) entries and check the kets they refer to.
     500          47 :         for (int state_index = 0; state_index < coefficients.outerSize(); ++state_index) {
     501          46 :             std::vector<std::pair<int, double>> entries;
     502          46 :             for (Eigen::SparseMatrix<double, Eigen::ColMajor>::InnerIterator it(coefficients,
     503          46 :                                                                                 state_index);
     504         126 :                  it; ++it) {
     505          80 :                 auto atomic_states = symmetrized_basis->get_kets()[it.row()]->get_atomic_states();
     506          80 :                 DOCTEST_REQUIRE(atomic_states.size() == 2);
     507          80 :                 DOCTEST_CHECK(static_cast<int>(atomic_states[0]->get_parity(0)) *
     508             :                                   static_cast<int>(atomic_states[1]->get_parity(0)) ==
     509             :                               static_cast<int>(Parity::EVEN)); // parity values multiply to EVEN's
     510          80 :                 entries.emplace_back(it.row(), it.value());
     511          80 :             }
     512             :             // Expect a single ket (coefficient 1) or two equal coefficients of magnitude 1/sqrt(2).
     513          46 :             DOCTEST_CHECK(entries.size() >= 1);
     514          46 :             DOCTEST_CHECK(entries.size() <= 2);
     515          46 :             if (entries.size() == 1) {
     516          12 :                 DOCTEST_CHECK(entries[0].second == doctest::Approx(1));
     517             :             } else {
     518          34 :                 DOCTEST_CHECK(std::abs(entries[0].second) == doctest::Approx(inv_sqrt_two));
     519          34 :                 DOCTEST_CHECK(std::abs(entries[1].second) == doctest::Approx(inv_sqrt_two));
     520          34 :                 DOCTEST_CHECK(entries[0].second == doctest::Approx(entries[1].second));
     521             :             }
     522          46 :         }
     523           6 :     }
     524             : 
     525           5 :     DOCTEST_SUBCASE("parity restrictions require the same SystemAtom twice") {
     526             :         // A second, independently constructed system represents a different atom. Even though it
     527             :         // is built from the same basis, it is a distinct object, so symmetrization is rejected.
     528           1 :         SystemAtom<double> system_other(basis);
     529           1 :         system_other.diagonalize(diagonalizer);
     530             :         // Restricting the permutation parity with two distinct systems must throw.
     531           2 :         DOCTEST_CHECK_THROWS_AS(BasisPairCreator<double>()
     532             :                                     .add(system)
     533             :                                     .add(system_other)
     534             :                                     .restrict_parity_under_permutation(Parity::ODD)
     535             :                                     .create(),
     536             :                                 std::invalid_argument);
     537             :         // Restricting the inversion parity must be rejected with the same exception type.
     538           2 :         DOCTEST_CHECK_THROWS_AS(BasisPairCreator<double>()
     539             :                                     .add(system)
     540             :                                     .add(system_other)
     541             :                                     .restrict_parity_under_inversion(Parity::ODD)
     542             :                                     .create(),
     543             :                                 std::invalid_argument);
     544             : 
     545             :         // Without a parity restriction, two different systems remain allowed.
     546           1 :         DOCTEST_CHECK_NOTHROW(BasisPairCreator<double>().add(system).add(system_other).create());
     547           6 :     }
     548           5 : }
     549             : 
     550             : } // namespace pairinteraction

Generated by: LCOV version 1.16