LCOV - code coverage report
Current view: top level - src/basis - Basis.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 343 437 78.5 %
Date: 2025-05-02 21:49:55 Functions: 87 200 43.5 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/basis/Basis.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/enums/Parity.hpp"
       9             : #include "pairinteraction/enums/TransformationType.hpp"
      10             : #include "pairinteraction/ket/KetAtom.hpp"
      11             : #include "pairinteraction/ket/KetPair.hpp"
      12             : #include "pairinteraction/utils/eigen_assertion.hpp"
      13             : #include "pairinteraction/utils/eigen_compat.hpp"
      14             : #include "pairinteraction/utils/wigner.hpp"
      15             : 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <numeric>
#include <set>
#include <stdexcept>
      19             : 
      20             : namespace pairinteraction {
      21             : 
      22             : template <typename Scalar>
      23             : class BasisAtom;
      24             : 
      25             : template <typename Derived>
      26         296 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
      27             :     // Check if the labels are valid sorting labels
      28         574 :     for (const auto &label : labels) {
      29         278 :         if (!utils::is_sorting(label)) {
      30           0 :             throw std::invalid_argument("One of the labels is not a valid sorting label.");
      31             :         }
      32             :     }
      33         296 : }
      34             : 
      35             : template <typename Derived>
      36         112 : void Basis<Derived>::perform_blocks_checks(
      37             :     const std::set<TransformationType> &unique_labels) const {
      38             :     // Check if the states are sorted by the requested labels
      39         112 :     std::set<TransformationType> unique_labels_present;
      40         195 :     for (const auto &label : get_transformation().transformation_type) {
      41         123 :         if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
      42          40 :             break;
      43             :         }
      44          83 :         unique_labels_present.insert(label);
      45             :     }
      46         112 :     if (unique_labels != unique_labels_present) {
      47           0 :         throw std::invalid_argument("The states are not sorted by the requested labels.");
      48             :     }
      49             : 
      50             :     // Throw a meaningful error if getting the blocks by energy is requested as this might be a
      51             :     // common mistake
      52         112 :     if (unique_labels.count(TransformationType::SORT_BY_ENERGY) > 0) {
      53           0 :         throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
      54             :                                     "be obtained. Use an energy operator instead.");
      55             :     }
      56         112 : }
      57             : 
      58             : template <typename Derived>
      59         110 : Basis<Derived>::Basis(ketvec_t &&kets)
      60         220 :     : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
      61         110 :                                            static_cast<Eigen::Index>(this->kets.size())},
      62         220 :                                           {TransformationType::SORT_BY_KET}} {
      63         110 :     if (this->kets.empty()) {
      64           0 :         throw std::invalid_argument("The basis must contain at least one element.");
      65             :     }
      66         110 :     state_index_to_quantum_number_f.reserve(this->kets.size());
      67         110 :     state_index_to_quantum_number_m.reserve(this->kets.size());
      68         110 :     state_index_to_parity.reserve(this->kets.size());
      69         110 :     ket_to_ket_index.reserve(this->kets.size());
      70         110 :     size_t index = 0;
      71       23756 :     for (const auto &ket : this->kets) {
      72       23646 :         state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
      73       23646 :         state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
      74       23646 :         state_index_to_parity.push_back(ket->get_parity());
      75       23646 :         ket_to_ket_index[ket] = index++;
      76       23646 :         if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
      77       14370 :             _has_quantum_number_f = false;
      78             :         }
      79       23646 :         if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
      80           0 :             _has_quantum_number_m = false;
      81             :         }
      82       23646 :         if (ket->get_parity() == Parity::UNKNOWN) {
      83       14370 :             _has_parity = false;
      84             :         }
      85             :     }
      86         110 :     state_index_to_ket_index.resize(this->kets.size());
      87         110 :     std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
      88         110 :     ket_index_to_state_index.resize(this->kets.size());
      89         110 :     std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
      90         110 :     coefficients.matrix.setIdentity();
      91         110 : }
      92             : 
      93             : template <typename Derived>
      94         114 : bool Basis<Derived>::has_quantum_number_f() const {
      95         114 :     return _has_quantum_number_f;
      96             : }
      97             : 
      98             : template <typename Derived>
      99       37974 : bool Basis<Derived>::has_quantum_number_m() const {
     100       37974 :     return _has_quantum_number_m;
     101             : }
     102             : 
     103             : template <typename Derived>
     104         114 : bool Basis<Derived>::has_parity() const {
     105         114 :     return _has_parity;
     106             : }
     107             : 
     108             : template <typename Derived>
     109         583 : const Derived &Basis<Derived>::derived() const {
     110         583 :     return static_cast<const Derived &>(*this);
     111             : }
     112             : 
     113             : template <typename Derived>
     114         356 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
     115         356 :     return kets;
     116             : }
     117             : 
     118             : template <typename Derived>
     119             : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     120        2168 : Basis<Derived>::get_coefficients() const {
     121        2168 :     return coefficients.matrix;
     122             : }
     123             : 
     124             : template <typename Derived>
     125             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     126           0 : Basis<Derived>::get_coefficients() {
     127           0 :     return coefficients.matrix;
     128             : }
     129             : 
     130             : template <typename Derived>
     131         239 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
     132         239 :     if (ket_to_ket_index.count(ket) == 0) {
     133           0 :         return -1;
     134             :     }
     135         239 :     return ket_to_ket_index.at(ket);
     136             : }
     137             : 
     138             : template <typename Derived>
     139             : Eigen::VectorX<typename Basis<Derived>::scalar_t>
     140          22 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
     141          22 :     int ket_index = get_ket_index_from_ket(ket);
     142          22 :     if (ket_index < 0) {
     143           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     144             :     }
     145             :     // The following line is a more efficient alternative to
     146             :     // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
     147          44 :     return coefficients.matrix.row(ket_index);
     148             : }
     149             : 
     150             : template <typename Derived>
     151             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
     152           6 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
     153          12 :     return other->coefficients.matrix.adjoint() * coefficients.matrix;
     154             : }
     155             : 
     156             : template <typename Derived>
     157             : Eigen::VectorX<typename Basis<Derived>::real_t>
     158          20 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
     159          20 :     return get_amplitudes(ket).cwiseAbs2();
     160             : }
     161             : 
     162             : template <typename Derived>
     163             : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
     164           3 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
     165           3 :     return get_amplitudes(other).cwiseAbs2();
     166             : }
     167             : 
     168             : template <typename Derived>
     169           0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
     170           0 :     real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
     171           0 :     if (quantum_number_f == std::numeric_limits<real_t>::max()) {
     172           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number f.");
     173             :     }
     174           0 :     return quantum_number_f;
     175             : }
     176             : 
     177             : template <typename Derived>
     178       37902 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
     179       37902 :     real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
     180       37902 :     if (quantum_number_m == std::numeric_limits<real_t>::max()) {
     181           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number m.");
     182             :     }
     183       37902 :     return quantum_number_m;
     184             : }
     185             : 
     186             : template <typename Derived>
     187          43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
     188          43 :     Parity parity = state_index_to_parity.at(state_index);
     189          43 :     if (parity == Parity::UNKNOWN) {
     190           0 :         throw std::invalid_argument("The state does not have a well-defined parity.");
     191             :     }
     192          43 :     return parity;
     193             : }
     194             : 
     195             : template <typename Derived>
     196             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     197          46 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
     198          46 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     199          46 :     if (ket_index == std::numeric_limits<int>::max()) {
     200           0 :         throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
     201             :     }
     202          46 :     return kets[ket_index];
     203             : }
     204             : 
     205             : template <typename Derived>
     206             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     207           0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
     208           0 :     throw std::runtime_error("Not implemented yet.");
     209             : }
     210             : 
     211             : template <typename Derived>
     212         189 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
     213             :     // Create a copy of the current object
     214         189 :     auto restricted = std::make_shared<Derived>(derived());
     215             : 
     216             :     // Restrict the copy to the state with the largest overlap
     217         189 :     restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
     218             : 
     219         189 :     std::fill(restricted->ket_index_to_state_index.begin(),
     220         189 :               restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     221         189 :     restricted->ket_index_to_state_index[state_index_to_ket_index[state_index]] = 0;
     222             : 
     223         189 :     restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
     224         189 :     restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
     225         189 :     restricted->state_index_to_parity = {state_index_to_parity[state_index]};
     226         189 :     restricted->state_index_to_ket_index = {state_index_to_ket_index[state_index]};
     227             : 
     228         378 :     restricted->_has_quantum_number_f =
     229         189 :         restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     230         378 :     restricted->_has_quantum_number_m =
     231         189 :         restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     232         189 :     restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
     233             : 
     234         378 :     return restricted;
     235         189 : }
     236             : 
     237             : template <typename Derived>
     238             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     239           0 : Basis<Derived>::get_ket(size_t ket_index) const {
     240           0 :     return kets[ket_index];
     241             : }
     242             : 
     243             : template <typename Derived>
     244          27 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
     245          27 :     size_t state_index = ket_index_to_state_index.at(ket_index);
     246          27 :     if (state_index == std::numeric_limits<int>::max()) {
     247           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     248             :     }
     249          27 :     return get_state(state_index);
     250             : }
     251             : 
     252             : template <typename Derived>
     253             : std::shared_ptr<const Derived>
     254          27 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
     255          27 :     int ket_index = get_ket_index_from_ket(ket);
     256          27 :     if (ket_index < 0) {
     257           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     258             :     }
     259          27 :     return get_corresponding_state(ket_index);
     260             : }
     261             : 
     262             : template <typename Derived>
     263          92 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
     264          92 :     int state_index = ket_index_to_state_index.at(ket_index);
     265          92 :     if (state_index == std::numeric_limits<int>::max()) {
     266           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     267             :     }
     268          92 :     return state_index;
     269             : }
     270             : 
     271             : template <typename Derived>
     272          92 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
     273          92 :     int ket_index = get_ket_index_from_ket(ket);
     274          92 :     if (ket_index < 0) {
     275           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     276             :     }
     277          92 :     return get_corresponding_state_index(ket_index);
     278             : }
     279             : 
     280             : template <typename Derived>
     281           0 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
     282           0 :     int ket_index = state_index_to_ket_index.at(state_index);
     283           0 :     if (ket_index == std::numeric_limits<int>::max()) {
     284           0 :         throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
     285             :     }
     286           0 :     return ket_index;
     287             : }
     288             : 
     289             : template <typename Derived>
     290           0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
     291           0 :     throw std::runtime_error("Not implemented yet.");
     292             : }
     293             : 
     294             : template <typename Derived>
     295             : std::shared_ptr<const Derived>
     296          98 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
     297             :     // Create a copy of the current object
     298          98 :     auto created = std::make_shared<Derived>(derived());
     299             : 
     300             :     // Fill the copy with the state corresponding to the ket index
     301          98 :     created->coefficients.matrix =
     302         196 :         Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
     303          98 :     created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
     304          98 :     created->coefficients.matrix.makeCompressed();
     305             : 
     306          98 :     std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
     307          98 :               std::numeric_limits<int>::max());
     308          98 :     created->ket_index_to_state_index[ket_index] = 0;
     309             : 
     310          98 :     created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
     311          98 :     created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
     312          98 :     created->state_index_to_parity = {kets[ket_index]->get_parity()};
     313          98 :     created->state_index_to_ket_index = {ket_index};
     314             : 
     315         196 :     created->_has_quantum_number_f =
     316          98 :         created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     317         196 :     created->_has_quantum_number_m =
     318          98 :         created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     319          98 :     created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
     320             : 
     321         196 :     return created;
     322          98 : }
     323             : 
     324             : template <typename Derived>
     325             : std::shared_ptr<const Derived>
     326          98 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
     327          98 :     int ket_index = get_ket_index_from_ket(ket);
     328          98 :     if (ket_index < 0) {
     329           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     330             :     }
     331          98 :     return get_canonical_state_from_ket(ket_index);
     332             : }
     333             : 
     334             : template <typename Derived>
     335           4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
     336           4 :     return kets.begin();
     337             : }
     338             : 
     339             : template <typename Derived>
     340           4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
     341           4 :     return kets.end();
     342             : }
     343             : 
     344             : template <typename Derived>
     345           8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
     346             : 
     347             : template <typename Derived>
     348          80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
     349          80 :     return other.it != it;
     350             : }
     351             : 
     352             : template <typename Derived>
     353          76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
     354          76 :     return *it;
     355             : }
     356             : 
     357             : template <typename Derived>
     358          76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
     359          76 :     ++it;
     360          76 :     return *this;
     361             : }
     362             : 
     363             : template <typename Derived>
     364        3749 : size_t Basis<Derived>::get_number_of_states() const {
     365        3749 :     return coefficients.matrix.cols();
     366             : }
     367             : 
     368             : template <typename Derived>
     369        1695 : size_t Basis<Derived>::get_number_of_kets() const {
     370        1695 :     return coefficients.matrix.rows();
     371             : }
     372             : 
     373             : template <typename Derived>
     374             : const Transformation<typename Basis<Derived>::scalar_t> &
     375         113 : Basis<Derived>::get_transformation() const {
     376         113 :     return coefficients;
     377             : }
     378             : 
     379             : template <typename Derived>
     380             : Transformation<typename Basis<Derived>::scalar_t>
     381           0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
     382           0 :     Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
     383             :                                              static_cast<Eigen::Index>(coefficients.matrix.rows())},
     384             :                                             {TransformationType::ROTATE}};
     385             : 
     386           0 :     std::vector<Eigen::Triplet<scalar_t>> entries;
     387             : 
     388           0 :     for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
     389           0 :         real_t f = kets[idx_initial]->get_quantum_number_f();
     390           0 :         real_t m_initial = kets[idx_initial]->get_quantum_number_m();
     391             : 
     392           0 :         assert(2 * f == std::floor(2 * f) && f >= 0);
     393           0 :         assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
     394             : 
     395           0 :         for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
     396             :              ++m_final) {
     397           0 :             auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
     398             :                                                                    beta, gamma);
     399           0 :             size_t idx_final = get_ket_index_from_ket(
     400           0 :                 kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
     401           0 :             entries.emplace_back(idx_final, idx_initial, val);
     402             :         }
     403             :     }
     404             : 
     405           0 :     transformation.matrix.setFromTriplets(entries.begin(), entries.end());
     406           0 :     transformation.matrix.makeCompressed();
     407             : 
     408           0 :     return transformation;
     409           0 : }
     410             : 
     411             : template <typename Derived>
     412           1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
     413           1 :     perform_sorter_checks(labels);
     414             : 
     415             :     // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
     416           1 :     if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
     417           2 :         labels.end()) {
     418           0 :         throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
     419             :                                     "the energy. Use an energy operator instead.");
     420             :     }
     421             : 
     422             :     // Initialize transformation
     423           1 :     Sorting transformation;
     424           1 :     transformation.matrix.resize(coefficients.matrix.cols());
     425           1 :     transformation.matrix.setIdentity();
     426             : 
     427             :     // Get the sorter
     428           1 :     get_sorter_without_checks(labels, transformation);
     429             : 
     430             :     // Check if all labels have been used for sorting
     431           1 :     if (labels != transformation.transformation_type) {
     432           0 :         throw std::invalid_argument("The states could not be sorted by all the requested labels.");
     433             :     }
     434             : 
     435           1 :     return transformation;
     436           0 : }
     437             : 
     438             : template <typename Derived>
     439             : std::vector<IndicesOfBlock>
     440           1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
     441           1 :     perform_sorter_checks(labels);
     442             : 
     443           1 :     std::set<TransformationType> unique_labels(labels.begin(), labels.end());
     444           1 :     perform_blocks_checks(unique_labels);
     445             : 
     446             :     // Get the blocks
     447           1 :     IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
     448           1 :     get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
     449             : 
     450           2 :     return blocks_creator.create();
     451           1 : }
     452             : 
     453             : template <typename Derived>
     454          72 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
     455             :                                                Sorting &transformation) const {
     456          72 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     457             : 
     458          72 :     int *perm_begin = transformation.matrix.indices().data();
     459          72 :     int *perm_end = perm_begin + coefficients.matrix.cols();
     460          72 :     const int *perm_back = perm_end - 1;
     461             : 
     462             :     // Sort the vector based on the requested labels
     463      113960 :     std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
     464       42444 :         for (const auto &label : labels) {
     465       30209 :             switch (label) {
     466        5067 :             case TransformationType::SORT_BY_PARITY:
     467        5067 :                 if (state_index_to_parity[a] != state_index_to_parity[b]) {
     468       13745 :                     return state_index_to_parity[a] < state_index_to_parity[b];
     469             :                 }
     470        3209 :                 break;
     471       25142 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     472       50284 :                 if (std::abs(state_index_to_quantum_number_m[a] -
     473       50284 :                              state_index_to_quantum_number_m[b]) > numerical_precision) {
     474       11887 :                     return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
     475             :                 }
     476       13255 :                 break;
     477           0 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     478           0 :                 if (std::abs(state_index_to_quantum_number_f[a] -
     479           0 :                              state_index_to_quantum_number_f[b]) > numerical_precision) {
     480           0 :                     return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
     481             :                 }
     482           0 :                 break;
     483           0 :             case TransformationType::SORT_BY_KET:
     484           0 :                 if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
     485           0 :                     return state_index_to_ket_index[a] < state_index_to_ket_index[b];
     486             :                 }
     487           0 :                 break;
     488           0 :             default:
     489           0 :                 std::abort(); // Can't happen because of previous checks
     490             :             }
     491             :         }
     492       12235 :         return false; // Elements are equal
     493             :     });
     494             : 
     495             :     // Check for invalid values and add transformation types
     496         155 :     for (const auto &label : labels) {
     497          83 :         switch (label) {
     498          13 :         case TransformationType::SORT_BY_PARITY:
     499          13 :             if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
     500           0 :                 throw std::invalid_argument(
     501             :                     "States cannot be labeled and thus not sorted by the parity.");
     502             :             }
     503          13 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
     504          13 :             break;
     505          70 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     506          70 :             if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
     507           0 :                 throw std::invalid_argument(
     508             :                     "States cannot be labeled and thus not sorted by the quantum number m.");
     509             :             }
     510          70 :             transformation.transformation_type.push_back(
     511          70 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_M);
     512          70 :             break;
     513           0 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     514           0 :             if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
     515           0 :                 throw std::invalid_argument(
     516             :                     "States cannot be labeled and thus not sorted by the quantum number f.");
     517             :             }
     518           0 :             transformation.transformation_type.push_back(
     519           0 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_F);
     520           0 :             break;
     521           0 :         case TransformationType::SORT_BY_KET:
     522           0 :             if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
     523           0 :                 throw std::invalid_argument(
     524             :                     "States cannot be labeled and thus not sorted by kets.");
     525             :             }
     526           0 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
     527           0 :             break;
     528           0 :         default:
     529           0 :             std::abort(); // Can't happen because of previous checks
     530             :         }
     531             :     }
     532          72 : }
     533             : 
     534             : template <typename Derived>
     535          72 : void Basis<Derived>::get_indices_of_blocks_without_checks(
     536             :     const std::set<TransformationType> &unique_labels,
     537             :     IndicesOfBlocksCreator &blocks_creator) const {
     538          72 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     539             : 
     540          72 :     auto last_quantum_number_f = state_index_to_quantum_number_f[0];
     541          72 :     auto last_quantum_number_m = state_index_to_quantum_number_m[0];
     542          72 :     auto last_parity = state_index_to_parity[0];
     543          72 :     auto last_ket = state_index_to_ket_index[0];
     544             : 
     545        4916 :     for (int i = 0; i < coefficients.matrix.cols(); ++i) {
     546       10673 :         for (auto label : unique_labels) {
     547        6061 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F) {
     548           0 :                 if (std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
     549             :                     numerical_precision) {
     550           0 :                     blocks_creator.add(i);
     551           0 :                     break;
     552             :                 }
     553        6061 :             } else if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M) {
     554        4664 :                 if (std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
     555             :                     numerical_precision) {
     556         186 :                     blocks_creator.add(i);
     557         186 :                     break;
     558             :                 }
     559        1397 :             } else if (label == TransformationType::SORT_BY_PARITY) {
     560        1397 :                 if (state_index_to_parity[i] != last_parity) {
     561          46 :                     blocks_creator.add(i);
     562          46 :                     break;
     563             :                 }
     564           0 :             } else if (label == TransformationType::SORT_BY_KET) {
     565           0 :                 if (state_index_to_ket_index[i] != last_ket) {
     566           0 :                     blocks_creator.add(i);
     567           0 :                     break;
     568             :                 }
     569             :             }
     570             :         }
     571        4844 :         last_quantum_number_f = state_index_to_quantum_number_f[i];
     572        4844 :         last_quantum_number_m = state_index_to_quantum_number_m[i];
     573        4844 :         last_parity = state_index_to_parity[i];
     574        4844 :         last_ket = state_index_to_ket_index[i];
     575             :     }
     576          72 : }
     577             : 
     578             : template <typename Derived>
     579         184 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
     580             :     // Create a copy of the current object
     581         184 :     auto transformed = std::make_shared<Derived>(derived());
     582             : 
     583         184 :     if (coefficients.matrix.cols() == 0) {
     584           0 :         return transformed;
     585             :     }
     586             : 
     587             :     // Apply the transformation
     588         184 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     589         184 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     590             : 
     591         184 :     transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
     592         184 :     transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
     593         184 :     transformed->state_index_to_parity.resize(transformation.matrix.size());
     594         184 :     transformed->state_index_to_ket_index.resize(transformation.matrix.size());
     595             : 
     596       14110 :     for (int i = 0; i < transformation.matrix.size(); ++i) {
     597       13926 :         transformed->state_index_to_quantum_number_f[i] =
     598       13926 :             state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
     599       13926 :         transformed->state_index_to_quantum_number_m[i] =
     600       13926 :             state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
     601       13926 :         transformed->state_index_to_parity[i] =
     602       13926 :             state_index_to_parity[transformation.matrix.indices()[i]];
     603       13926 :         transformed->state_index_to_ket_index[i] =
     604       13926 :             state_index_to_ket_index[transformation.matrix.indices()[i]];
     605       13926 :         transformed->ket_index_to_state_index
     606       13926 :             [state_index_to_ket_index[transformation.matrix.indices()[i]]] = i;
     607             :     }
     608             : 
     609         184 :     return transformed;
     610         184 : }
     611             : 
     612             : template <typename Derived>
     613             : std::shared_ptr<const Derived>
     614         112 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
     615             :     // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
     616             :     // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
     617         112 :     real_t numerical_precision = 0.001;
     618             : 
     619             :     // If the transformation is a rotation, it should be a rotation and nothing else
     620         112 :     bool is_rotation = false;
     621         224 :     for (auto t : transformation.transformation_type) {
     622         112 :         if (t == TransformationType::ROTATE) {
     623           0 :             is_rotation = true;
     624           0 :             break;
     625             :         }
     626             :     }
     627         112 :     if (is_rotation && transformation.transformation_type.size() != 1) {
     628           0 :         throw std::invalid_argument("A rotation can not be combined with other transformations.");
     629             :     }
     630             : 
     631             :     // To apply a rotation, the object must only be sorted but other transformations are not allowed
     632         112 :     if (is_rotation) {
     633           0 :         for (auto t : coefficients.transformation_type) {
     634           0 :             if (!utils::is_sorting(t)) {
     635           0 :                 throw std::runtime_error(
     636             :                     "If the object was transformed by a different transformation "
     637             :                     "than sorting, it can not be rotated.");
     638             :             }
     639             :         }
     640             :     }
     641             : 
     642             :     // Create a copy of the current object
     643         112 :     auto transformed = std::make_shared<Derived>(derived());
     644             : 
     645         112 :     if (coefficients.matrix.cols() == 0) {
     646           0 :         return transformed;
     647             :     }
     648             : 
     649             :     // Apply the transformation
     650             :     // If a quantum number turns out to be conserved by the transformation, it will be
     651             :     // rounded to the nearest half integer to avoid loss of numerical_precision.
     652         112 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     653         112 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     654             : 
     655         112 :     Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
     656             : 
     657             :     {
     658         224 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
     659         112 :                                                             state_index_to_quantum_number_f.size());
     660         112 :         Eigen::VectorX<real_t> val = probs * map;
     661         112 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     662         112 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     663         112 :         transformed->state_index_to_quantum_number_f.resize(probs.rows());
     664             : 
     665        6459 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
     666        6347 :             if (diff[i] < numerical_precision) {
     667        1672 :                 transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
     668             :             } else {
     669        4675 :                 transformed->state_index_to_quantum_number_f[i] =
     670        4675 :                     std::numeric_limits<real_t>::max();
     671        4675 :                 transformed->_has_quantum_number_f = false;
     672             :             }
     673             :         }
     674         112 :     }
     675             : 
     676             :     {
     677         224 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
     678         112 :                                                             state_index_to_quantum_number_m.size());
     679         112 :         Eigen::VectorX<real_t> val = probs * map;
     680         112 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     681         112 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     682         112 :         transformed->state_index_to_quantum_number_m.resize(probs.rows());
     683             : 
     684        6459 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
     685        6347 :             if (diff[i] < numerical_precision) {
     686        4413 :                 transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
     687             :             } else {
     688        1934 :                 transformed->state_index_to_quantum_number_m[i] =
     689        1934 :                     std::numeric_limits<real_t>::max();
     690        1934 :                 transformed->_has_quantum_number_m = false;
     691             :             }
     692             :         }
     693         112 :     }
     694             : 
     695             :     {
     696             :         using utype = std::underlying_type<Parity>::type;
     697         112 :         Eigen::VectorX<real_t> map(state_index_to_parity.size());
     698        6980 :         for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
     699        6868 :             map[i] = static_cast<utype>(state_index_to_parity[i]);
     700             :         }
     701         112 :         Eigen::VectorX<real_t> val = probs * map;
     702         112 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     703         112 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     704         112 :         transformed->state_index_to_parity.resize(probs.rows());
     705             : 
     706        6459 :         for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
     707        6347 :             if (diff[i] < numerical_precision) {
     708        3556 :                 transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
     709             :             } else {
     710        2791 :                 transformed->state_index_to_parity[i] = Parity::UNKNOWN;
     711        2791 :                 transformed->_has_parity = false;
     712             :             }
     713             :         }
     714         112 :     }
     715             : 
     716             :     {
     717             :         // In the following, we obtain a bijective map between state index and ket index.
     718             : 
     719             :         // Find the maximum value in each row and column
     720         112 :         std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
     721         112 :         std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
     722        6980 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     723        6868 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     724        6868 :                      transformed->coefficients.matrix, row);
     725      201060 :                  it; ++it) {
     726      194192 :                 real_t val = std::pow(std::abs(it.value()), 2);
     727      194192 :                 max_in_row[row] = std::max(max_in_row[row], val);
     728      194192 :                 max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
     729             :             }
     730             :         }
     731             : 
     732             :         // Use the maximum values to define a cost for a sub-optimal mapping
     733         112 :         std::vector<real_t> costs;
     734         112 :         std::vector<std::pair<int, int>> mappings;
     735         112 :         costs.reserve(transformed->coefficients.matrix.nonZeros());
     736         112 :         mappings.reserve(transformed->coefficients.matrix.nonZeros());
     737        6980 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     738        6868 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     739        6868 :                      transformed->coefficients.matrix, row);
     740      201060 :                  it; ++it) {
     741      194192 :                 real_t val = std::pow(std::abs(it.value()), 2);
     742      194192 :                 real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
     743      194192 :                 costs.push_back(cost);
     744      194192 :                 mappings.push_back({row, it.col()});
     745             :             }
     746             :         }
     747             : 
     748             :         // Obtain from the costs in which order the mappings should be considered
     749         112 :         std::vector<size_t> order(costs.size());
     750         112 :         std::iota(order.begin(), order.end(), 0);
     751         112 :         std::sort(order.begin(), order.end(),
     752     2905733 :                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
     753             : 
     754             :         // Fill ket_index_to_state_index with invalid values as there can be more kets than states
     755         112 :         std::fill(transformed->ket_index_to_state_index.begin(),
     756         112 :                   transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     757             : 
     758             :         // Generate the bijective map
     759         112 :         std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
     760         112 :         std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
     761         112 :         int num_used = 0;
     762       14677 :         for (size_t idx : order) {
     763       14675 :             int row = mappings[idx].first;  // corresponds to the ket index
     764       14675 :             int col = mappings[idx].second; // corresponds to the state index
     765       14675 :             if (!row_used[row] && !col_used[col]) {
     766        6347 :                 row_used[row] = true;
     767        6347 :                 col_used[col] = true;
     768        6347 :                 num_used++;
     769        6347 :                 transformed->state_index_to_ket_index[col] = row;
     770        6347 :                 transformed->ket_index_to_state_index[row] = col;
     771             :             }
     772       14675 :             if (num_used == transformed->coefficients.matrix.cols()) {
     773         110 :                 break;
     774             :             }
     775             :         }
     776         112 :         assert(num_used == transformed->coefficients.matrix.cols());
     777         112 :     }
     778             : 
     779         112 :     return transformed;
     780         112 : }
     781             : 
     782             : template <typename Derived>
     783       24124 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
     784       24124 :     return typename ket_t::hash()(*k);
     785             : }
     786             : 
     787             : template <typename Derived>
     788         478 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
     789             :                                           const std::shared_ptr<const ket_t> &rhs) const {
     790         478 :     return *lhs == *rhs;
     791             : }
     792             : 
     793             : // Explicit instantiations
     794             : template class Basis<BasisAtom<double>>;
     795             : template class Basis<BasisAtom<std::complex<double>>>;
     796             : template class Basis<BasisPair<double>>;
     797             : template class Basis<BasisPair<std::complex<double>>>;
     798             : } // namespace pairinteraction

Generated by: LCOV version 1.16