LCOV - code coverage report
Current view: top level - src/basis - Basis.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 382 477 80.1 %
Date: 2026-06-19 12:50:25 Functions: 117 184 63.6 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/basis/Basis.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/enums/Parity.hpp"
       9             : #include "pairinteraction/enums/TransformationType.hpp"
      10             : #include "pairinteraction/ket/KetAtom.hpp"
      11             : #include "pairinteraction/ket/KetPair.hpp"
      12             : #include "pairinteraction/utils/TaskControl.hpp"
      13             : #include "pairinteraction/utils/eigen_assertion.hpp"
      14             : #include "pairinteraction/utils/eigen_compat.hpp"
      15             : #include "pairinteraction/utils/wigner.hpp"
      16             : 
      17             : #include <cassert>
      18             : #include <numeric>
      19             : #include <set>
      20             : #include <spdlog/spdlog.h>
      21             : 
      22             : namespace pairinteraction {
      23             : 
      24             : template <typename Scalar>
      25             : class BasisAtom;
      26             : 
      27             : template <typename Derived>
      28        2761 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
      29             :     // Check if the labels are valid sorting labels
      30        5549 :     for (const auto &label : labels) {
      31        2788 :         if (!utils::is_sorting(label)) {
      32           0 :             throw std::invalid_argument("One of the labels is not a valid sorting label.");
      33             :         }
      34             :     }
      35        2761 : }
      36             : 
      37             : template <typename Derived>
      38         630 : void Basis<Derived>::perform_blocks_checks(
      39             :     const std::set<TransformationType> &unique_labels) const {
      40             :     // Check if the states are sorted by the requested labels
      41         630 :     std::set<TransformationType> unique_labels_present;
      42        1248 :     for (const auto &label : get_transformation().transformation_type) {
      43         669 :         if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
      44          51 :             break;
      45             :         }
      46         618 :         unique_labels_present.insert(label);
      47             :     }
      48         630 :     if (unique_labels != unique_labels_present) {
      49           0 :         throw std::invalid_argument("The states are not sorted by the requested labels.");
      50             :     }
      51             : 
      52             :     // Throw a meaningful error if getting the blocks by energy is requested as this might be a
      53             :     // common mistake
      54         630 :     if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
      55           0 :         throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
      56             :                                     "be obtained. Use an energy operator instead.");
      57             :     }
      58         630 : }
      59             : 
      60             : template <typename Derived>
      61        1508 : Basis<Derived>::Basis(ketvec_t &&kets)
      62        3016 :     : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
      63        1508 :                                            static_cast<Eigen::Index>(this->kets.size())},
      64        3016 :                                           {TransformationType::SORT_BY_KET}} {
      65        1508 :     if (this->kets.empty()) {
      66           0 :         throw std::invalid_argument("The basis must contain at least one element.");
      67             :     }
      68        1508 :     state_index_to_quantum_number_f.reserve(this->kets.size());
      69        1508 :     state_index_to_quantum_number_m.reserve(this->kets.size());
      70        1508 :     state_index_to_parity.reserve(this->kets.size());
      71        1508 :     ket_to_ket_index.reserve(this->kets.size());
      72        1508 :     size_t index = 0;
      73     2874399 :     for (const auto &ket : this->kets) {
      74     2872891 :         real_t f = std::numeric_limits<real_t>::max();
      75     2872891 :         real_t m = std::numeric_limits<real_t>::max();
      76     2872891 :         Parity p = Parity::UNKNOWN;
      77             :         // TODO: this is a workaround, and should be fixed, once we restructure the quantum number
      78             :         // handling of the Basis class
      79             :         if constexpr (requires { ket->get_quantum_number(std::string{}); }) {
      80       39800 :             f = ket->get_quantum_number("f");
      81       39800 :             m = ket->get_quantum_number("m");
      82       39800 :             p = static_cast<Parity>(static_cast<int>(ket->get_quantum_number("parity")));
      83             :         } else {
      84     2833091 :             m = ket->get_quantum_number_m();
      85             :         }
      86     2872891 :         state_index_to_quantum_number_f.push_back(f);
      87     2872891 :         state_index_to_quantum_number_m.push_back(m);
      88     2872891 :         state_index_to_parity.push_back(p);
      89     2872891 :         ket_to_ket_index[ket] = index++;
      90     2872891 :         if (f == std::numeric_limits<real_t>::max()) {
      91     2833091 :             _has_quantum_number_f = false;
      92             :         }
      93     2872891 :         if (m == std::numeric_limits<real_t>::max()) {
      94           0 :             _has_quantum_number_m = false;
      95             :         }
      96     2872891 :         if (p == Parity::UNKNOWN) {
      97     2833091 :             _has_parity = false;
      98             :         }
      99             :     }
     100        1508 :     state_index_to_ket_index.resize(this->kets.size());
     101        1508 :     std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
     102        1508 :     ket_index_to_state_index.resize(this->kets.size());
     103        1508 :     std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
     104        1508 :     coefficients.matrix.setIdentity();
     105        1508 : }
     106             : 
     107             : template <typename Derived>
     108        1229 : bool Basis<Derived>::has_quantum_number_f() const {
     109        1229 :     return _has_quantum_number_f;
     110             : }
     111             : 
     112             : template <typename Derived>
     113     5802053 : bool Basis<Derived>::has_quantum_number_m() const {
     114     5802053 :     return _has_quantum_number_m;
     115             : }
     116             : 
     117             : template <typename Derived>
     118        1229 : bool Basis<Derived>::has_parity() const {
     119        1229 :     return _has_parity;
     120             : }
     121             : 
     122             : template <typename Derived>
     123        3666 : const Derived &Basis<Derived>::derived() const {
     124        3666 :     return static_cast<const Derived &>(*this);
     125             : }
     126             : 
     127             : template <typename Derived>
     128         182 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
     129         182 :     return kets;
     130             : }
     131             : 
     132             : template <typename Derived>
     133             : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     134       12378 : Basis<Derived>::get_coefficients() const {
     135       12378 :     return coefficients.matrix;
     136             : }
     137             : 
     138             : template <typename Derived>
     139             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     140           0 : Basis<Derived>::get_coefficients() {
     141           0 :     return coefficients.matrix;
     142             : }
     143             : 
     144             : template <typename Derived>
     145           5 : void Basis<Derived>::set_coefficients(
     146             :     const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
     147           5 :     if (values.rows() != coefficients.matrix.rows()) {
     148           0 :         throw std::invalid_argument("Incompatible number of rows.");
     149             :     }
     150           5 :     if (values.cols() != coefficients.matrix.cols()) {
     151           0 :         throw std::invalid_argument("Incompatible number of columns.");
     152             :     }
     153             : 
     154           5 :     coefficients.matrix = values;
     155             : 
     156           5 :     std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
     157           5 :               std::numeric_limits<int>::max());
     158           5 :     std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
     159           5 :               std::numeric_limits<int>::max());
     160           5 :     std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
     161           5 :               std::numeric_limits<real_t>::max());
     162           5 :     std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
     163           5 :               std::numeric_limits<real_t>::max());
     164           5 :     std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
     165           5 :     _has_quantum_number_f = false;
     166           5 :     _has_quantum_number_m = false;
     167           5 :     _has_parity = false;
     168           5 : }
     169             : 
     170             : template <typename Derived>
     171         252 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
     172         252 :     if (!ket_to_ket_index.contains(ket)) {
     173           0 :         return -1;
     174             :     }
     175         252 :     return ket_to_ket_index.at(ket);
     176             : }
     177             : 
     178             : template <typename Derived>
     179           0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
     180           0 :     real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
     181           0 :     if (quantum_number_f == std::numeric_limits<real_t>::max()) {
     182           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number f.");
     183             :     }
     184           0 :     return quantum_number_f;
     185             : }
     186             : 
     187             : template <typename Derived>
     188     5800866 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
     189     5800866 :     real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
     190     5800866 :     if (quantum_number_m == std::numeric_limits<real_t>::max()) {
     191           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number m.");
     192             :     }
     193     5800866 :     return quantum_number_m;
     194             : }
     195             : 
     196             : template <typename Derived>
     197       32795 : Parity Basis<Derived>::get_parity(size_t state_index) const {
     198       32795 :     Parity parity = state_index_to_parity.at(state_index);
     199       32795 :     if (parity == Parity::UNKNOWN) {
     200           0 :         throw std::invalid_argument("The state does not have a well-defined parity.");
     201             :     }
     202       32795 :     return parity;
     203             : }
     204             : 
     205             : template <typename Derived>
     206             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     207        1100 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
     208        1100 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     209        1100 :     if (ket_index == std::numeric_limits<int>::max()) {
     210           0 :         throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
     211             :     }
     212        1100 :     return kets[ket_index];
     213             : }
     214             : 
     215             : template <typename Derived>
     216             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     217           0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
     218           0 :     throw std::runtime_error("Not implemented yet.");
     219             : }
     220             : 
     221             : template <typename Derived>
     222         824 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
     223             :     // Create a copy of the current object
     224         824 :     auto restricted = std::make_shared<Derived>(derived());
     225             : 
     226             :     // Restrict the copy to the state with the largest overlap
     227         824 :     restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
     228             : 
     229         824 :     std::fill(restricted->ket_index_to_state_index.begin(),
     230         824 :               restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     231             : 
     232         824 :     size_t ket_index = state_index_to_ket_index[state_index];
     233         824 :     restricted->state_index_to_ket_index = {ket_index};
     234         824 :     if (ket_index != std::numeric_limits<int>::max()) {
     235         824 :         restricted->ket_index_to_state_index[ket_index] = 0;
     236             :     }
     237             : 
     238         824 :     restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
     239         824 :     restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
     240         824 :     restricted->state_index_to_parity = {state_index_to_parity[state_index]};
     241             : 
     242        1648 :     restricted->_has_quantum_number_f =
     243         824 :         restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     244        1648 :     restricted->_has_quantum_number_m =
     245         824 :         restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     246         824 :     restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
     247             : 
     248        1648 :     return restricted;
     249         824 : }
     250             : 
     251             : template <typename Derived>
     252             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     253      187616 : Basis<Derived>::get_ket(size_t ket_index) const {
     254      187616 :     return kets[ket_index];
     255             : }
     256             : 
     257             : template <typename Derived>
     258           2 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
     259           2 :     size_t state_index = ket_index_to_state_index.at(ket_index);
     260           2 :     if (state_index == std::numeric_limits<int>::max()) {
     261           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     262             :     }
     263           2 :     return get_state(state_index);
     264             : }
     265             : 
     266             : template <typename Derived>
     267             : std::shared_ptr<const Derived>
     268           2 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
     269           2 :     int ket_index = get_ket_index_from_ket(ket);
     270           2 :     if (ket_index < 0) {
     271           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     272             :     }
     273           2 :     return get_corresponding_state(ket_index);
     274             : }
     275             : 
     276             : template <typename Derived>
     277         250 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
     278         250 :     int state_index = ket_index_to_state_index.at(ket_index);
     279         250 :     if (state_index == std::numeric_limits<int>::max()) {
     280           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     281             :     }
     282         250 :     return state_index;
     283             : }
     284             : 
     285             : template <typename Derived>
     286         250 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
     287         250 :     int ket_index = get_ket_index_from_ket(ket);
     288         250 :     if (ket_index < 0) {
     289           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     290             :     }
     291         250 :     return get_corresponding_state_index(ket_index);
     292             : }
     293             : 
     294             : template <typename Derived>
     295         236 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
     296         236 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     297         236 :     if (ket_index == std::numeric_limits<int>::max()) {
     298           0 :         throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
     299             :     }
     300         236 :     return ket_index;
     301             : }
     302             : 
     303             : template <typename Derived>
     304           0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
     305           0 :     throw std::runtime_error("Not implemented yet.");
     306             : }
     307             : 
     308             : template <typename Derived>
     309           5 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
     310           5 :     return kets.begin();
     311             : }
     312             : 
     313             : template <typename Derived>
     314           5 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
     315           5 :     return kets.end();
     316             : }
     317             : 
     318             : template <typename Derived>
     319          10 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
     320             : 
     321             : template <typename Derived>
     322         225 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
     323         225 :     return other.it != it;
     324             : }
     325             : 
     326             : template <typename Derived>
     327         220 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
     328         220 :     return *it;
     329             : }
     330             : 
     331             : template <typename Derived>
     332         220 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
     333         220 :     ++it;
     334         220 :     return *this;
     335             : }
     336             : 
     337             : template <typename Derived>
     338        5800 : size_t Basis<Derived>::get_number_of_states() const {
     339        5800 :     return coefficients.matrix.cols();
     340             : }
     341             : 
     342             : template <typename Derived>
     343       47023 : size_t Basis<Derived>::get_number_of_kets() const {
     344       47023 :     return coefficients.matrix.rows();
     345             : }
     346             : 
     347             : template <typename Derived>
     348             : const Transformation<typename Basis<Derived>::scalar_t> &
     349         631 : Basis<Derived>::get_transformation() const {
     350         631 :     return coefficients;
     351             : }
     352             : 
     353             : template <typename Derived>
     354             : Transformation<typename Basis<Derived>::scalar_t>
     355           0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
     356           0 :     Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
     357             :                                              static_cast<Eigen::Index>(coefficients.matrix.rows())},
     358             :                                             {TransformationType::ROTATE}};
     359             : 
     360           0 :     std::vector<Eigen::Triplet<scalar_t>> entries;
     361             : 
     362           0 :     set_task_status("Constructing basis rotation...");
     363           0 :     for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
     364           0 :         real_t f = std::numeric_limits<real_t>::max();
     365           0 :         real_t m_initial = std::numeric_limits<real_t>::max();
     366             :         // TODO: this is a workaround, and should be fixed, once we restructure the quantum number
     367             :         // handling of the Basis class
     368             :         if constexpr (requires { kets[idx_initial]->get_quantum_number(std::string{}); }) {
     369           0 :             f = kets[idx_initial]->get_quantum_number("f");
     370           0 :             m_initial = kets[idx_initial]->get_quantum_number("m");
     371             :         } else {
     372           0 :             m_initial = kets[idx_initial]->get_quantum_number_m();
     373             :         }
     374             : 
     375           0 :         assert(2 * f == std::floor(2 * f) && f >= 0);
     376           0 :         assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
     377             : 
     378           0 :         for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
     379             :              ++m_final) {
     380           0 :             auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
     381             :                                                                    beta, gamma);
     382           0 :             size_t idx_final = get_ket_index_from_ket(
     383           0 :                 kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
     384           0 :             entries.emplace_back(idx_final, idx_initial, val);
     385             :         }
     386             :     }
     387             : 
     388           0 :     transformation.matrix.setFromTriplets(entries.begin(), entries.end());
     389           0 :     transformation.matrix.makeCompressed();
     390             : 
     391           0 :     return transformation;
     392           0 : }
     393             : 
     394             : template <typename Derived>
     395           1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
     396           1 :     perform_sorter_checks(labels);
     397             : 
     398             :     // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
     399           1 :     if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
     400           2 :         labels.end()) {
     401           0 :         throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
     402             :                                     "the energy. Use an energy operator instead.");
     403             :     }
     404             : 
     405             :     // Initialize transformation
     406           1 :     Sorting transformation;
     407           1 :     transformation.matrix.resize(coefficients.matrix.cols());
     408           1 :     transformation.matrix.setIdentity();
     409             : 
     410             :     // Get the sorter
     411           1 :     get_sorter_without_checks(labels, transformation);
     412             : 
     413             :     // Check if all labels have been used for sorting
     414           1 :     if (labels != transformation.transformation_type) {
     415           0 :         throw std::invalid_argument("The states could not be sorted by all the requested labels.");
     416             :     }
     417             : 
     418           1 :     return transformation;
     419           0 : }
     420             : 
     421             : template <typename Derived>
     422             : std::vector<IndicesOfBlock>
     423           1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
     424           1 :     perform_sorter_checks(labels);
     425             : 
     426           1 :     std::set<TransformationType> unique_labels(labels.begin(), labels.end());
     427           1 :     perform_blocks_checks(unique_labels);
     428             : 
     429             :     // Get the blocks
     430           1 :     IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
     431           1 :     get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
     432             : 
     433           2 :     return blocks_creator.create();
     434           1 : }
     435             : 
     436             : template <typename Derived>
     437         579 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
     438             :                                                Sorting &transformation) const {
     439         579 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     440             : 
     441         579 :     int *perm_begin = transformation.matrix.indices().data();
     442         579 :     int *perm_end = perm_begin + coefficients.matrix.cols();
     443         579 :     const int *perm_back = perm_end - 1;
     444             : 
     445             :     // Sort the vector based on the requested labels
     446         579 :     set_task_status("Sorting basis states...");
     447      647265 :     std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
     448      284277 :         for (const auto &label : labels) {
     449      186405 :             switch (label) {
     450       30703 :             case TransformationType::SORT_BY_PARITY:
     451       30703 :                 if (state_index_to_parity[a] != state_index_to_parity[b]) {
     452       58668 :                     return state_index_to_parity[a] < state_index_to_parity[b];
     453             :                 }
     454       20305 :                 break;
     455      155702 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     456      311404 :                 if (std::abs(state_index_to_quantum_number_m[a] -
     457      311404 :                              state_index_to_quantum_number_m[b]) > numerical_precision) {
     458       48270 :                     return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
     459             :                 }
     460      107432 :                 break;
     461           0 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     462           0 :                 if (std::abs(state_index_to_quantum_number_f[a] -
     463           0 :                              state_index_to_quantum_number_f[b]) > numerical_precision) {
     464           0 :                     return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
     465             :                 }
     466           0 :                 break;
     467           0 :             case TransformationType::SORT_BY_KET:
     468           0 :                 if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
     469           0 :                     return state_index_to_ket_index[a] < state_index_to_ket_index[b];
     470             :                 }
     471           0 :                 break;
     472           0 :             default:
     473           0 :                 std::abort(); // Can't happen because of previous checks
     474             :             }
     475             :         }
     476       97872 :         return false; // Elements are equal
     477             :     });
     478             : 
     479             :     // Check for invalid values and add transformation types
     480        1197 :     for (const auto &label : labels) {
     481         618 :         switch (label) {
     482          41 :         case TransformationType::SORT_BY_PARITY:
     483          41 :             if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
     484           0 :                 throw std::invalid_argument(
     485             :                     "States cannot be labeled and thus not sorted by the parity.");
     486             :             }
     487          41 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
     488          41 :             break;
     489         577 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     490         577 :             if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
     491           0 :                 throw std::invalid_argument(
     492             :                     "States cannot be labeled and thus not sorted by the quantum number m.");
     493             :             }
     494         577 :             transformation.transformation_type.push_back(
     495         577 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_M);
     496         577 :             break;
     497           0 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     498           0 :             if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
     499           0 :                 throw std::invalid_argument(
     500             :                     "States cannot be labeled and thus not sorted by the quantum number f.");
     501             :             }
     502           0 :             transformation.transformation_type.push_back(
     503           0 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_F);
     504           0 :             break;
     505           0 :         case TransformationType::SORT_BY_KET:
     506           0 :             if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
     507           0 :                 throw std::invalid_argument(
     508             :                     "States cannot be labeled and thus not sorted by kets.");
     509             :             }
     510           0 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
     511           0 :             break;
     512           0 :         default:
     513           0 :             std::abort(); // Can't happen because of previous checks
     514             :         }
     515             :     }
     516         579 : }
     517             : 
     518             : template <typename Derived>
     519         579 : void Basis<Derived>::get_indices_of_blocks_without_checks(
     520             :     const std::set<TransformationType> &unique_labels,
     521             :     IndicesOfBlocksCreator &blocks_creator) const {
     522         579 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     523             : 
     524         579 :     auto last_quantum_number_f = state_index_to_quantum_number_f[0];
     525         579 :     auto last_quantum_number_m = state_index_to_quantum_number_m[0];
     526         579 :     auto last_parity = state_index_to_parity[0];
     527         579 :     auto last_ket = state_index_to_ket_index[0];
     528             : 
     529         579 :     set_task_status("Identifying basis blocks...");
     530       33557 :     for (int i = 0; i < coefficients.matrix.cols(); ++i) {
     531       72251 :         for (auto label : unique_labels) {
     532       39987 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
     533           0 :                 std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
     534             :                     numerical_precision) {
     535           0 :                 blocks_creator.add(i);
     536           0 :                 break;
     537             :             }
     538       72785 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
     539       32798 :                 std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
     540             :                     numerical_precision) {
     541         570 :                 blocks_creator.add(i);
     542         570 :                 break;
     543             :             }
     544       46606 :             if (label == TransformationType::SORT_BY_PARITY &&
     545        7189 :                 state_index_to_parity[i] != last_parity) {
     546         144 :                 blocks_creator.add(i);
     547         144 :                 break;
     548             :             }
     549       39273 :             if (label == TransformationType::SORT_BY_KET &&
     550       39273 :                 state_index_to_ket_index[i] != last_ket &&
     551           0 :                 last_ket != std::numeric_limits<int>::max()) {
     552           0 :                 blocks_creator.add(i);
     553           0 :                 break;
     554             :             }
     555             :         }
     556       32978 :         last_quantum_number_f = state_index_to_quantum_number_f[i];
     557       32978 :         last_quantum_number_m = state_index_to_quantum_number_m[i];
     558       32978 :         last_parity = state_index_to_parity[i];
     559       32978 :         last_ket = state_index_to_ket_index[i];
     560             :     }
     561         579 : }
     562             : 
     563             : template <typename Derived>
     564          64 : std::shared_ptr<const Derived> Basis<Derived>::canonicalized() const {
     565          64 :     auto result = std::make_shared<Derived>(derived());
     566             : 
     567          64 :     size_t n = kets.size();
     568             : 
     569          64 :     result->coefficients.matrix.resize(n, n);
     570          64 :     result->coefficients.matrix.setIdentity();
     571          64 :     result->coefficients.transformation_type = {TransformationType::SORT_BY_KET};
     572             : 
     573          64 :     result->state_index_to_ket_index.resize(n);
     574          64 :     std::iota(result->state_index_to_ket_index.begin(), result->state_index_to_ket_index.end(), 0);
     575             : 
     576          64 :     result->ket_index_to_state_index.resize(n);
     577          64 :     std::iota(result->ket_index_to_state_index.begin(), result->ket_index_to_state_index.end(), 0);
     578             : 
     579          64 :     result->state_index_to_quantum_number_f.resize(n);
     580          64 :     result->state_index_to_quantum_number_m.resize(n);
     581          64 :     result->state_index_to_parity.resize(n);
     582          64 :     result->_has_quantum_number_f = true;
     583          64 :     result->_has_quantum_number_m = true;
     584          64 :     result->_has_parity = true;
     585             : 
     586       13536 :     for (size_t i = 0; i < n; ++i) {
     587       13472 :         real_t f = std::numeric_limits<real_t>::max();
     588       13472 :         real_t m = std::numeric_limits<real_t>::max();
     589       13472 :         Parity p = Parity::UNKNOWN;
     590             :         // TODO: this is a workaround, and should be fixed, once we restructure the quantum number
     591             :         // handling of the Basis class
     592             :         if constexpr (requires { kets[i]->get_quantum_number(std::string{}); }) {
     593       13472 :             f = kets[i]->get_quantum_number("f");
     594       13472 :             m = kets[i]->get_quantum_number("m");
     595       13472 :             p = static_cast<Parity>(static_cast<int>(kets[i]->get_quantum_number("parity")));
     596             :         } else {
     597           0 :             m = kets[i]->get_quantum_number_m();
     598             :         }
     599       13472 :         result->state_index_to_quantum_number_f[i] = f;
     600       13472 :         result->state_index_to_quantum_number_m[i] = m;
     601       13472 :         result->state_index_to_parity[i] = p;
     602       13472 :         if (f == std::numeric_limits<real_t>::max()) {
     603           0 :             result->_has_quantum_number_f = false;
     604             :         }
     605       13472 :         if (m == std::numeric_limits<real_t>::max()) {
     606           0 :             result->_has_quantum_number_m = false;
     607             :         }
     608       13472 :         if (p == Parity::UNKNOWN) {
     609           0 :             result->_has_parity = false;
     610             :         }
     611             :     }
     612             : 
     613         128 :     return result;
     614          64 : }
     615             : 
     616             : template <typename Derived>
     617        2131 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
     618             :     // Create a copy of the current object
     619        2131 :     auto transformed = std::make_shared<Derived>(derived());
     620             : 
     621        2131 :     if (coefficients.matrix.cols() == 0) {
     622           0 :         return transformed;
     623             :     }
     624             : 
     625             :     // Apply the transformation
     626        2131 :     set_task_status("Applying basis sorting...");
     627        2131 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     628        2131 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     629             : 
     630        2131 :     transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
     631        2131 :     transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
     632        2131 :     transformed->state_index_to_parity.resize(transformation.matrix.size());
     633        2131 :     transformed->state_index_to_ket_index.resize(transformation.matrix.size());
     634             : 
     635        2131 :     set_task_status("Relabeling sorted basis states...");
     636      180549 :     for (int i = 0; i < transformation.matrix.size(); ++i) {
     637      178418 :         transformed->state_index_to_quantum_number_f[i] =
     638      178418 :             state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
     639      178418 :         transformed->state_index_to_quantum_number_m[i] =
     640      178418 :             state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
     641      178418 :         transformed->state_index_to_parity[i] =
     642      178418 :             state_index_to_parity[transformation.matrix.indices()[i]];
     643             : 
     644      178418 :         size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
     645      178418 :         transformed->state_index_to_ket_index[i] = ket_index;
     646      178418 :         if (ket_index != std::numeric_limits<int>::max()) {
     647      178418 :             transformed->ket_index_to_state_index[ket_index] = i;
     648             :         }
     649             :     }
     650             : 
     651        2131 :     return transformed;
     652        2131 : }
     653             : 
     654             : template <typename Derived>
     655             : std::shared_ptr<const Derived>
     656         647 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
     657             :     // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
     658             :     // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
     659         647 :     real_t numerical_precision = 0.001;
     660             : 
     661             :     // If the transformation is a rotation, it should be a rotation and nothing else
     662         647 :     bool is_rotation = false;
     663        1294 :     for (auto t : transformation.transformation_type) {
     664         647 :         if (t == TransformationType::ROTATE) {
     665           0 :             is_rotation = true;
     666           0 :             break;
     667             :         }
     668             :     }
     669         647 :     if (is_rotation && transformation.transformation_type.size() != 1) {
     670           0 :         throw std::invalid_argument("A rotation can not be combined with other transformations.");
     671             :     }
     672             : 
     673             :     // To apply a rotation, the object must only be sorted but other transformations are not allowed
     674         647 :     if (is_rotation) {
     675           0 :         for (auto t : coefficients.transformation_type) {
     676           0 :             if (!utils::is_sorting(t)) {
     677           0 :                 throw std::runtime_error(
     678             :                     "If the object was transformed by a different transformation "
     679             :                     "than sorting, it can not be rotated.");
     680             :             }
     681             :         }
     682             :     }
     683             : 
     684             :     // Create a copy of the current object
     685         647 :     auto transformed = std::make_shared<Derived>(derived());
     686             : 
     687         647 :     if (coefficients.matrix.cols() == 0) {
     688           0 :         return transformed;
     689             :     }
     690             : 
     691             :     // Apply the transformation
     692             :     // If a quantum number turns out to be conserved by the transformation, it will be
     693             :     // rounded to the nearest half integer to avoid loss of numerical_precision.
     694         647 :     set_task_status("Applying basis transformation...");
     695         647 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     696         647 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     697             : 
     698         647 :     Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
     699             : 
     700         647 :     set_task_status("Updating transformed quantum numbers...");
     701             :     {
     702        1294 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
     703         647 :                                                             state_index_to_quantum_number_f.size());
     704         647 :         Eigen::VectorX<real_t> val = probs * map;
     705         647 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     706         647 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     707         647 :         transformed->state_index_to_quantum_number_f.resize(probs.rows());
     708             : 
     709       64351 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
     710       63704 :             if (diff[i] < numerical_precision) {
     711        5101 :                 transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
     712             :             } else {
     713       58603 :                 transformed->state_index_to_quantum_number_f[i] =
     714       58603 :                     std::numeric_limits<real_t>::max();
     715       58603 :                 transformed->_has_quantum_number_f = false;
     716             :             }
     717             :         }
     718         647 :     }
     719             : 
     720             :     {
     721        1294 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
     722         647 :                                                             state_index_to_quantum_number_m.size());
     723         647 :         Eigen::VectorX<real_t> val = probs * map;
     724         647 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     725         647 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     726         647 :         transformed->state_index_to_quantum_number_m.resize(probs.rows());
     727             : 
     728       64351 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
     729       63704 :             if (diff[i] < numerical_precision) {
     730       61004 :                 transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
     731             :             } else {
     732        2700 :                 transformed->state_index_to_quantum_number_m[i] =
     733        2700 :                     std::numeric_limits<real_t>::max();
     734        2700 :                 transformed->_has_quantum_number_m = false;
     735             :             }
     736             :         }
     737         647 :     }
     738             : 
     739             :     {
     740             :         using utype = std::underlying_type<Parity>::type;
     741         647 :         Eigen::VectorX<real_t> map(state_index_to_parity.size());
     742      102291 :         for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
     743      101644 :             map[i] = static_cast<utype>(state_index_to_parity[i]);
     744             :         }
     745         647 :         Eigen::VectorX<real_t> val = probs * map;
     746         647 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     747         647 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     748         647 :         transformed->state_index_to_parity.resize(probs.rows());
     749             : 
     750       64351 :         for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
     751       63704 :             if (diff[i] < numerical_precision) {
     752       57596 :                 transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
     753             :             } else {
     754        6108 :                 transformed->state_index_to_parity[i] = Parity::UNKNOWN;
     755        6108 :                 transformed->_has_parity = false;
     756             :             }
     757             :         }
     758         647 :     }
     759             : 
     760         647 :     set_task_status("Obtaining transformed state mapping...");
     761             :     {
     762             :         // In the following, we obtain a bijective map between state index and ket index.
     763             : 
     764             :         // Find the maximum value in each row and column
     765         647 :         std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
     766         647 :         std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
     767      102699 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     768      102052 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     769      102052 :                      transformed->coefficients.matrix, row);
     770      797908 :                  it; ++it) {
     771      695856 :                 real_t val = std::pow(std::abs(it.value()), 2);
     772      695856 :                 max_in_row[row] = std::max(max_in_row[row], val);
     773      695856 :                 max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
     774             :             }
     775             :         }
     776             : 
     777             :         // Use the maximum values to define a cost for a sub-optimal mapping
     778         647 :         std::vector<real_t> costs;
     779         647 :         std::vector<std::pair<int, int>> mappings;
     780         647 :         costs.reserve(transformed->coefficients.matrix.nonZeros());
     781         647 :         mappings.reserve(transformed->coefficients.matrix.nonZeros());
     782      102699 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     783      102052 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     784      102052 :                      transformed->coefficients.matrix, row);
     785      797908 :                  it; ++it) {
     786      695856 :                 real_t val = std::pow(std::abs(it.value()), 2);
     787      695856 :                 real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
     788      695856 :                 costs.push_back(cost);
     789      695856 :                 mappings.push_back({row, it.col()});
     790             :             }
     791             :         }
     792             : 
     793             :         // Obtain from the costs in which order the mappings should be considered
     794         647 :         std::vector<size_t> order(costs.size());
     795         647 :         std::iota(order.begin(), order.end(), 0);
     796         647 :         std::sort(order.begin(), order.end(),
     797     8745103 :                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
     798             : 
     799             :         // Fill ket_index_to_state_index with invalid values as there can be more kets than states
     800         647 :         std::fill(transformed->ket_index_to_state_index.begin(),
     801         647 :                   transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     802             : 
     803             :         // Generate the bijective map
     804         647 :         std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
     805         647 :         std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
     806         647 :         int num_used = 0;
     807      185390 :         for (size_t idx : order) {
     808      185385 :             int row = mappings[idx].first;  // corresponds to the ket index
     809      185385 :             int col = mappings[idx].second; // corresponds to the state index
     810      185385 :             if (!row_used[row] && !col_used[col]) {
     811       63701 :                 row_used[row] = true;
     812       63701 :                 col_used[col] = true;
     813       63701 :                 num_used++;
     814       63701 :                 transformed->state_index_to_ket_index[col] = row;
     815       63701 :                 transformed->ket_index_to_state_index[row] = col;
     816             :             }
     817      185385 :             if (num_used == transformed->coefficients.matrix.cols()) {
     818         642 :                 break;
     819             :             }
     820             :         }
     821         647 :         if (num_used != transformed->coefficients.matrix.cols()) {
     822           6 :             SPDLOG_WARN("A bijective map between states and kets could not be found.");
     823             :         }
     824         647 :     }
     825             : 
     826         647 :     return transformed;
     827         647 : }
     828             : 
     829             : template <typename Derived>
     830     2873395 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
     831     2873395 :     return typename ket_t::hash()(*k);
     832             : }
     833             : 
     834             : template <typename Derived>
     835         504 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
     836             :                                           const std::shared_ptr<const ket_t> &rhs) const {
     837         504 :     return *lhs == *rhs;
     838             : }
     839             : 
     840             : // Explicit instantiations
     841             : template class Basis<BasisAtom<double>>;
     842             : template class Basis<BasisAtom<std::complex<double>>>;
     843             : template class Basis<BasisPair<double>>;
     844             : template class Basis<BasisPair<std::complex<double>>>;
     845             : } // namespace pairinteraction

Generated by: LCOV version 1.16