LCOV - code coverage report
Current view: top level - src/basis - Basis.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 364 459 79.3 %
Date: 2025-09-03 06:42:31 Functions: 88 204 43.1 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/basis/Basis.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/enums/Parity.hpp"
       9             : #include "pairinteraction/enums/TransformationType.hpp"
      10             : #include "pairinteraction/ket/KetAtom.hpp"
      11             : #include "pairinteraction/ket/KetPair.hpp"
      12             : #include "pairinteraction/utils/eigen_assertion.hpp"
      13             : #include "pairinteraction/utils/eigen_compat.hpp"
      14             : #include "pairinteraction/utils/wigner.hpp"
      15             : 
      16             : #include <cassert>
      17             : #include <numeric>
      18             : #include <set>
      19             : 
      20             : namespace pairinteraction {
      21             : 
      22             : template <typename Scalar>
      23             : class BasisAtom;
      24             : 
      25             : template <typename Derived>
      26         550 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
      27             :     // Check if the labels are valid sorting labels
      28        1106 :     for (const auto &label : labels) {
      29         556 :         if (!utils::is_sorting(label)) {
      30           0 :             throw std::invalid_argument("One of the labels is not a valid sorting label.");
      31             :         }
      32             :     }
      33         550 : }
      34             : 
      35             : template <typename Derived>
      36         141 : void Basis<Derived>::perform_blocks_checks(
      37             :     const std::set<TransformationType> &unique_labels) const {
      38             :     // Check if the states are sorted by the requested labels
      39         141 :     std::set<TransformationType> unique_labels_present;
      40         265 :     for (const auto &label : get_transformation().transformation_type) {
      41         164 :         if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
      42          40 :             break;
      43             :         }
      44         124 :         unique_labels_present.insert(label);
      45             :     }
      46         141 :     if (unique_labels != unique_labels_present) {
      47           0 :         throw std::invalid_argument("The states are not sorted by the requested labels.");
      48             :     }
      49             : 
      50             :     // Throw a meaningful error if getting the blocks by energy is requested as this might be a
      51             :     // common mistake
      52         141 :     if (unique_labels.count(TransformationType::SORT_BY_ENERGY) > 0) {
      53           0 :         throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
      54             :                                     "be obtained. Use an energy operator instead.");
      55             :     }
      56         141 : }
      57             : 
      58             : template <typename Derived>
      59         210 : Basis<Derived>::Basis(ketvec_t &&kets)
      60         420 :     : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
      61         210 :                                            static_cast<Eigen::Index>(this->kets.size())},
      62         420 :                                           {TransformationType::SORT_BY_KET}} {
      63         210 :     if (this->kets.empty()) {
      64           0 :         throw std::invalid_argument("The basis must contain at least one element.");
      65             :     }
      66         210 :     state_index_to_quantum_number_f.reserve(this->kets.size());
      67         210 :     state_index_to_quantum_number_m.reserve(this->kets.size());
      68         210 :     state_index_to_parity.reserve(this->kets.size());
      69         210 :     ket_to_ket_index.reserve(this->kets.size());
      70         210 :     size_t index = 0;
      71       81146 :     for (const auto &ket : this->kets) {
      72       80936 :         state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
      73       80936 :         state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
      74       80936 :         state_index_to_parity.push_back(ket->get_parity());
      75       80936 :         ket_to_ket_index[ket] = index++;
      76       80936 :         if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
      77       69647 :             _has_quantum_number_f = false;
      78             :         }
      79       80936 :         if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
      80           0 :             _has_quantum_number_m = false;
      81             :         }
      82       80936 :         if (ket->get_parity() == Parity::UNKNOWN) {
      83       69647 :             _has_parity = false;
      84             :         }
      85             :     }
      86         210 :     state_index_to_ket_index.resize(this->kets.size());
      87         210 :     std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
      88         210 :     ket_index_to_state_index.resize(this->kets.size());
      89         210 :     std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
      90         210 :     coefficients.matrix.setIdentity();
      91         210 : }
      92             : 
      93             : template <typename Derived>
      94          85 : bool Basis<Derived>::has_quantum_number_f() const {
      95          85 :     return _has_quantum_number_f;
      96             : }
      97             : 
      98             : template <typename Derived>
      99      149019 : bool Basis<Derived>::has_quantum_number_m() const {
     100      149019 :     return _has_quantum_number_m;
     101             : }
     102             : 
     103             : template <typename Derived>
     104          85 : bool Basis<Derived>::has_parity() const {
     105          85 :     return _has_parity;
     106             : }
     107             : 
     108             : template <typename Derived>
     109         859 : const Derived &Basis<Derived>::derived() const {
     110         859 :     return static_cast<const Derived &>(*this);
     111             : }
     112             : 
     113             : template <typename Derived>
     114         302 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
     115         302 :     return kets;
     116             : }
     117             : 
     118             : template <typename Derived>
     119             : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     120        2569 : Basis<Derived>::get_coefficients() const {
     121        2569 :     return coefficients.matrix;
     122             : }
     123             : 
     124             : template <typename Derived>
     125             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     126           0 : Basis<Derived>::get_coefficients() {
     127           0 :     return coefficients.matrix;
     128             : }
     129             : 
     130             : template <typename Derived>
     131           5 : void Basis<Derived>::set_coefficients(
     132             :     const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
     133           5 :     if (values.rows() != coefficients.matrix.rows()) {
     134           0 :         throw std::invalid_argument("Incompatible number of rows.");
     135             :     }
     136           5 :     if (values.cols() != coefficients.matrix.cols()) {
     137           0 :         throw std::invalid_argument("Incompatible number of columns.");
     138             :     }
     139             : 
     140           5 :     coefficients.matrix = values;
     141             : 
     142           5 :     std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
     143           5 :               std::numeric_limits<int>::max());
     144           5 :     std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
     145           5 :               std::numeric_limits<int>::max());
     146           5 :     std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
     147           5 :               std::numeric_limits<real_t>::max());
     148           5 :     std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
     149           5 :               std::numeric_limits<real_t>::max());
     150           5 :     std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
     151           5 :     _has_quantum_number_f = false;
     152           5 :     _has_quantum_number_m = false;
     153           5 :     _has_parity = false;
     154           5 : }
     155             : 
     156             : template <typename Derived>
     157         308 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
     158         308 :     if (ket_to_ket_index.count(ket) == 0) {
     159           0 :         return -1;
     160             :     }
     161         308 :     return ket_to_ket_index.at(ket);
     162             : }
     163             : 
     164             : template <typename Derived>
     165             : Eigen::VectorX<typename Basis<Derived>::scalar_t>
     166          69 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
     167          69 :     int ket_index = get_ket_index_from_ket(ket);
     168          69 :     if (ket_index < 0) {
     169           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     170             :     }
     171             :     // The following line is a more efficient alternative to
     172             :     // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
     173         138 :     return coefficients.matrix.row(ket_index);
     174             : }
     175             : 
     176             : template <typename Derived>
     177             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
     178           6 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
     179          12 :     return other->coefficients.matrix.adjoint() * coefficients.matrix;
     180             : }
     181             : 
     182             : template <typename Derived>
     183             : Eigen::VectorX<typename Basis<Derived>::real_t>
     184          67 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
     185          67 :     return get_amplitudes(ket).cwiseAbs2();
     186             : }
     187             : 
     188             : template <typename Derived>
     189             : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
     190           3 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
     191           3 :     return get_amplitudes(other).cwiseAbs2();
     192             : }
     193             : 
     194             : template <typename Derived>
     195           0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
     196           0 :     real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
     197           0 :     if (quantum_number_f == std::numeric_limits<real_t>::max()) {
     198           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number f.");
     199             :     }
     200           0 :     return quantum_number_f;
     201             : }
     202             : 
     203             : template <typename Derived>
     204      148976 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
     205      148976 :     real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
     206      148976 :     if (quantum_number_m == std::numeric_limits<real_t>::max()) {
     207           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number m.");
     208             :     }
     209      148976 :     return quantum_number_m;
     210             : }
     211             : 
     212             : template <typename Derived>
     213          43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
     214          43 :     Parity parity = state_index_to_parity.at(state_index);
     215          43 :     if (parity == Parity::UNKNOWN) {
     216           0 :         throw std::invalid_argument("The state does not have a well-defined parity.");
     217             :     }
     218          43 :     return parity;
     219             : }
     220             : 
     221             : template <typename Derived>
     222             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     223          64 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
     224          64 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     225          64 :     if (ket_index == std::numeric_limits<int>::max()) {
     226           0 :         throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
     227             :     }
     228          64 :     return kets[ket_index];
     229             : }
     230             : 
     231             : template <typename Derived>
     232             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     233           0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
     234           0 :     throw std::runtime_error("Not implemented yet.");
     235             : }
     236             : 
     237             : template <typename Derived>
     238         190 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
     239             :     // Create a copy of the current object
     240         190 :     auto restricted = std::make_shared<Derived>(derived());
     241             : 
     242             :     // Restrict the copy to the state with the largest overlap
     243         190 :     restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
     244             : 
     245         190 :     std::fill(restricted->ket_index_to_state_index.begin(),
     246         190 :               restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     247             : 
     248         190 :     size_t ket_index = state_index_to_ket_index[state_index];
     249         190 :     restricted->state_index_to_ket_index = {ket_index};
     250         190 :     if (ket_index != std::numeric_limits<int>::max()) {
     251         190 :         restricted->ket_index_to_state_index[ket_index] = 0;
     252             :     }
     253             : 
     254         190 :     restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
     255         190 :     restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
     256         190 :     restricted->state_index_to_parity = {state_index_to_parity[state_index]};
     257             : 
     258         380 :     restricted->_has_quantum_number_f =
     259         190 :         restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     260         380 :     restricted->_has_quantum_number_m =
     261         190 :         restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     262         190 :     restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
     263             : 
     264         380 :     return restricted;
     265         190 : }
     266             : 
     267             : template <typename Derived>
     268             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     269           0 : Basis<Derived>::get_ket(size_t ket_index) const {
     270           0 :     return kets[ket_index];
     271             : }
     272             : 
     273             : template <typename Derived>
     274          28 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
     275          28 :     size_t state_index = ket_index_to_state_index.at(ket_index);
     276          28 :     if (state_index == std::numeric_limits<int>::max()) {
     277           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     278             :     }
     279          28 :     return get_state(state_index);
     280             : }
     281             : 
     282             : template <typename Derived>
     283             : std::shared_ptr<const Derived>
     284          28 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
     285          28 :     int ket_index = get_ket_index_from_ket(ket);
     286          28 :     if (ket_index < 0) {
     287           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     288             :     }
     289          28 :     return get_corresponding_state(ket_index);
     290             : }
     291             : 
     292             : template <typename Derived>
     293          92 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
     294          92 :     int state_index = ket_index_to_state_index.at(ket_index);
     295          92 :     if (state_index == std::numeric_limits<int>::max()) {
     296           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     297             :     }
     298          92 :     return state_index;
     299             : }
     300             : 
     301             : template <typename Derived>
     302          92 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
     303          92 :     int ket_index = get_ket_index_from_ket(ket);
     304          92 :     if (ket_index < 0) {
     305           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     306             :     }
     307          92 :     return get_corresponding_state_index(ket_index);
     308             : }
     309             : 
     310             : template <typename Derived>
     311           0 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
     312           0 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     313           0 :     if (ket_index == std::numeric_limits<int>::max()) {
     314           0 :         throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
     315             :     }
     316           0 :     return ket_index;
     317             : }
     318             : 
     319             : template <typename Derived>
     320           0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
     321           0 :     throw std::runtime_error("Not implemented yet.");
     322             : }
     323             : 
     324             : template <typename Derived>
     325             : std::shared_ptr<const Derived>
     326         119 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
     327             :     // Create a copy of the current object
     328         119 :     auto created = std::make_shared<Derived>(derived());
     329             : 
     330             :     // Fill the copy with the state corresponding to the ket index
     331         119 :     created->coefficients.matrix =
     332         238 :         Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
     333         119 :     created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
     334         119 :     created->coefficients.matrix.makeCompressed();
     335             : 
     336         119 :     std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
     337         119 :               std::numeric_limits<int>::max());
     338         119 :     created->ket_index_to_state_index[ket_index] = 0;
     339             : 
     340         119 :     created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
     341         119 :     created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
     342         119 :     created->state_index_to_parity = {kets[ket_index]->get_parity()};
     343         119 :     created->state_index_to_ket_index = {ket_index};
     344             : 
     345         238 :     created->_has_quantum_number_f =
     346         119 :         created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     347         238 :     created->_has_quantum_number_m =
     348         119 :         created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     349         119 :     created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
     350             : 
     351         238 :     return created;
     352         119 : }
     353             : 
     354             : template <typename Derived>
     355             : std::shared_ptr<const Derived>
     356         119 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
     357         119 :     int ket_index = get_ket_index_from_ket(ket);
     358         119 :     if (ket_index < 0) {
     359           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     360             :     }
     361         119 :     return get_canonical_state_from_ket(ket_index);
     362             : }
     363             : 
     364             : template <typename Derived>
     365           4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
     366           4 :     return kets.begin();
     367             : }
     368             : 
     369             : template <typename Derived>
     370           4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
     371           4 :     return kets.end();
     372             : }
     373             : 
     374             : template <typename Derived>
     375           8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
     376             : 
     377             : template <typename Derived>
     378          80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
     379          80 :     return other.it != it;
     380             : }
     381             : 
     382             : template <typename Derived>
     383          76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
     384          76 :     return *it;
     385             : }
     386             : 
     387             : template <typename Derived>
     388          76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
     389          76 :     ++it;
     390          76 :     return *this;
     391             : }
     392             : 
     393             : template <typename Derived>
     394        4563 : size_t Basis<Derived>::get_number_of_states() const {
     395        4563 :     return coefficients.matrix.cols();
     396             : }
     397             : 
     398             : template <typename Derived>
     399        1931 : size_t Basis<Derived>::get_number_of_kets() const {
     400        1931 :     return coefficients.matrix.rows();
     401             : }
     402             : 
     403             : template <typename Derived>
     404             : const Transformation<typename Basis<Derived>::scalar_t> &
     405         142 : Basis<Derived>::get_transformation() const {
     406         142 :     return coefficients;
     407             : }
     408             : 
     409             : template <typename Derived>
     410             : Transformation<typename Basis<Derived>::scalar_t>
     411           0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
     412           0 :     Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
     413             :                                              static_cast<Eigen::Index>(coefficients.matrix.rows())},
     414             :                                             {TransformationType::ROTATE}};
     415             : 
     416           0 :     std::vector<Eigen::Triplet<scalar_t>> entries;
     417             : 
     418           0 :     for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
     419           0 :         real_t f = kets[idx_initial]->get_quantum_number_f();
     420           0 :         real_t m_initial = kets[idx_initial]->get_quantum_number_m();
     421             : 
     422           0 :         assert(2 * f == std::floor(2 * f) && f >= 0);
     423           0 :         assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
     424             : 
     425           0 :         for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
     426             :              ++m_final) {
     427           0 :             auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
     428             :                                                                    beta, gamma);
     429           0 :             size_t idx_final = get_ket_index_from_ket(
     430           0 :                 kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
     431           0 :             entries.emplace_back(idx_final, idx_initial, val);
     432             :         }
     433             :     }
     434             : 
     435           0 :     transformation.matrix.setFromTriplets(entries.begin(), entries.end());
     436           0 :     transformation.matrix.makeCompressed();
     437             : 
     438           0 :     return transformation;
     439           0 : }
     440             : 
     441             : template <typename Derived>
     442           1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
     443           1 :     perform_sorter_checks(labels);
     444             : 
     445             :     // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
     446           1 :     if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
     447           2 :         labels.end()) {
     448           0 :         throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
     449             :                                     "the energy. Use an energy operator instead.");
     450             :     }
     451             : 
     452             :     // Initialize transformation
     453           1 :     Sorting transformation;
     454           1 :     transformation.matrix.resize(coefficients.matrix.cols());
     455           1 :     transformation.matrix.setIdentity();
     456             : 
     457             :     // Get the sorter
     458           1 :     get_sorter_without_checks(labels, transformation);
     459             : 
     460             :     // Check if all labels have been used for sorting
     461           1 :     if (labels != transformation.transformation_type) {
     462           0 :         throw std::invalid_argument("The states could not be sorted by all the requested labels.");
     463             :     }
     464             : 
     465           1 :     return transformation;
     466           0 : }
     467             : 
     468             : template <typename Derived>
     469             : std::vector<IndicesOfBlock>
     470           1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
     471           1 :     perform_sorter_checks(labels);
     472             : 
     473           1 :     std::set<TransformationType> unique_labels(labels.begin(), labels.end());
     474           1 :     perform_blocks_checks(unique_labels);
     475             : 
     476             :     // Get the blocks
     477           1 :     IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
     478           1 :     get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
     479             : 
     480           2 :     return blocks_creator.create();
     481           1 : }
     482             : 
     483             : template <typename Derived>
     484         101 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
     485             :                                                Sorting &transformation) const {
     486         101 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     487             : 
     488         101 :     int *perm_begin = transformation.matrix.indices().data();
     489         101 :     int *perm_end = perm_begin + coefficients.matrix.cols();
     490         101 :     const int *perm_back = perm_end - 1;
     491             : 
     492             :     // Sort the vector based on the requested labels
     493      254106 :     std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
     494       93408 :         for (const auto &label : labels) {
     495       70162 :             switch (label) {
     496       17609 :             case TransformationType::SORT_BY_PARITY:
     497       17609 :                 if (state_index_to_parity[a] != state_index_to_parity[b]) {
     498       30145 :                     return state_index_to_parity[a] < state_index_to_parity[b];
     499             :                 }
     500       11581 :                 break;
     501       52553 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     502      105106 :                 if (std::abs(state_index_to_quantum_number_m[a] -
     503      105106 :                              state_index_to_quantum_number_m[b]) > numerical_precision) {
     504       24117 :                     return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
     505             :                 }
     506       28436 :                 break;
     507           0 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     508           0 :                 if (std::abs(state_index_to_quantum_number_f[a] -
     509           0 :                              state_index_to_quantum_number_f[b]) > numerical_precision) {
     510           0 :                     return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
     511             :                 }
     512           0 :                 break;
     513           0 :             case TransformationType::SORT_BY_KET:
     514           0 :                 if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
     515           0 :                     return state_index_to_ket_index[a] < state_index_to_ket_index[b];
     516             :                 }
     517           0 :                 break;
     518           0 :             default:
     519           0 :                 std::abort(); // Can't happen because of previous checks
     520             :             }
     521             :         }
     522       23246 :         return false; // Elements are equal
     523             :     });
     524             : 
     525             :     // Check for invalid values and add transformation types
     526         225 :     for (const auto &label : labels) {
     527         124 :         switch (label) {
     528          25 :         case TransformationType::SORT_BY_PARITY:
     529          25 :             if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
     530           0 :                 throw std::invalid_argument(
     531             :                     "States cannot be labeled and thus not sorted by the parity.");
     532             :             }
     533          25 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
     534          25 :             break;
     535          99 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     536          99 :             if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
     537           0 :                 throw std::invalid_argument(
     538             :                     "States cannot be labeled and thus not sorted by the quantum number m.");
     539             :             }
     540          99 :             transformation.transformation_type.push_back(
     541          99 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_M);
     542          99 :             break;
     543           0 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     544           0 :             if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
     545           0 :                 throw std::invalid_argument(
     546             :                     "States cannot be labeled and thus not sorted by the quantum number f.");
     547             :             }
     548           0 :             transformation.transformation_type.push_back(
     549           0 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_F);
     550           0 :             break;
     551           0 :         case TransformationType::SORT_BY_KET:
     552           0 :             if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
     553           0 :                 throw std::invalid_argument(
     554             :                     "States cannot be labeled and thus not sorted by kets.");
     555             :             }
     556           0 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
     557           0 :             break;
     558           0 :         default:
     559           0 :             std::abort(); // Can't happen because of previous checks
     560             :         }
     561             :     }
     562         101 : }
     563             : 
     564             : template <typename Derived>
     565         101 : void Basis<Derived>::get_indices_of_blocks_without_checks(
     566             :     const std::set<TransformationType> &unique_labels,
     567             :     IndicesOfBlocksCreator &blocks_creator) const {
     568         101 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     569             : 
     570         101 :     auto last_quantum_number_f = state_index_to_quantum_number_f[0];
     571         101 :     auto last_quantum_number_m = state_index_to_quantum_number_m[0];
     572         101 :     auto last_parity = state_index_to_parity[0];
     573         101 :     auto last_ket = state_index_to_ket_index[0];
     574             : 
     575        8850 :     for (int i = 0; i < coefficients.matrix.cols(); ++i) {
     576       21145 :         for (auto label : unique_labels) {
     577       12764 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
     578           0 :                 std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
     579             :                     numerical_precision) {
     580           0 :                 blocks_creator.add(i);
     581           0 :                 break;
     582             :             }
     583       21333 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
     584        8569 :                 std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
     585             :                     numerical_precision) {
     586         278 :                 blocks_creator.add(i);
     587         278 :                 break;
     588             :             }
     589       16681 :             if (label == TransformationType::SORT_BY_PARITY &&
     590        4195 :                 state_index_to_parity[i] != last_parity) {
     591          90 :                 blocks_creator.add(i);
     592          90 :                 break;
     593             :             }
     594       12396 :             if (label == TransformationType::SORT_BY_KET &&
     595       12396 :                 state_index_to_ket_index[i] != last_ket &&
     596           0 :                 last_ket != std::numeric_limits<int>::max()) {
     597           0 :                 blocks_creator.add(i);
     598           0 :                 break;
     599             :             }
     600             :         }
     601        8749 :         last_quantum_number_f = state_index_to_quantum_number_f[i];
     602        8749 :         last_quantum_number_m = state_index_to_quantum_number_m[i];
     603        8749 :         last_parity = state_index_to_parity[i];
     604        8749 :         last_ket = state_index_to_ket_index[i];
     605             :     }
     606         101 : }
     607             : 
     608             : template <typename Derived>
     609         409 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
     610             :     // Create a copy of the current object
     611         409 :     auto transformed = std::make_shared<Derived>(derived());
     612             : 
     613         409 :     if (coefficients.matrix.cols() == 0) {
     614           0 :         return transformed;
     615             :     }
     616             : 
     617             :     // Apply the transformation
     618         409 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     619         409 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     620             : 
     621         409 :     transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
     622         409 :     transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
     623         409 :     transformed->state_index_to_parity.resize(transformation.matrix.size());
     624         409 :     transformed->state_index_to_ket_index.resize(transformation.matrix.size());
     625             : 
     626       67475 :     for (int i = 0; i < transformation.matrix.size(); ++i) {
     627       67066 :         transformed->state_index_to_quantum_number_f[i] =
     628       67066 :             state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
     629       67066 :         transformed->state_index_to_quantum_number_m[i] =
     630       67066 :             state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
     631       67066 :         transformed->state_index_to_parity[i] =
     632       67066 :             state_index_to_parity[transformation.matrix.indices()[i]];
     633             : 
     634       67066 :         size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
     635       67066 :         transformed->state_index_to_ket_index[i] = ket_index;
     636       67066 :         if (ket_index != std::numeric_limits<int>::max()) {
     637       67066 :             transformed->ket_index_to_state_index[ket_index] = i;
     638             :         }
     639             :     }
     640             : 
     641         409 :     return transformed;
     642         409 : }
     643             : 
     644             : template <typename Derived>
     645             : std::shared_ptr<const Derived>
     646         141 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
     647             :     // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
     648             :     // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
     649         141 :     real_t numerical_precision = 0.001;
     650             : 
     651             :     // If the transformation is a rotation, it should be a rotation and nothing else
     652         141 :     bool is_rotation = false;
     653         282 :     for (auto t : transformation.transformation_type) {
     654         141 :         if (t == TransformationType::ROTATE) {
     655           0 :             is_rotation = true;
     656           0 :             break;
     657             :         }
     658             :     }
     659         141 :     if (is_rotation && transformation.transformation_type.size() != 1) {
     660           0 :         throw std::invalid_argument("A rotation can not be combined with other transformations.");
     661             :     }
     662             : 
     663             :     // To apply a rotation, the object must only be sorted but other transformations are not allowed
     664         141 :     if (is_rotation) {
     665           0 :         for (auto t : coefficients.transformation_type) {
     666           0 :             if (!utils::is_sorting(t)) {
     667           0 :                 throw std::runtime_error(
     668             :                     "If the object was transformed by a different transformation "
     669             :                     "than sorting, it can not be rotated.");
     670             :             }
     671             :         }
     672             :     }
     673             : 
     674             :     // Create a copy of the current object
     675         141 :     auto transformed = std::make_shared<Derived>(derived());
     676             : 
     677         141 :     if (coefficients.matrix.cols() == 0) {
     678           0 :         return transformed;
     679             :     }
     680             : 
     681             :     // Apply the transformation
     682             :     // If a quantum number turns out to be conserved by the transformation, it will be
     683             :     // rounded to the nearest half integer to avoid loss of numerical_precision.
     684         141 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     685         141 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     686             : 
     687         141 :     Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
     688             : 
     689             :     {
     690         282 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
     691         141 :                                                             state_index_to_quantum_number_f.size());
     692         141 :         Eigen::VectorX<real_t> val = probs * map;
     693         141 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     694         141 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     695         141 :         transformed->state_index_to_quantum_number_f.resize(probs.rows());
     696             : 
     697       10355 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
     698       10214 :             if (diff[i] < numerical_precision) {
     699        3005 :                 transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
     700             :             } else {
     701        7209 :                 transformed->state_index_to_quantum_number_f[i] =
     702        7209 :                     std::numeric_limits<real_t>::max();
     703        7209 :                 transformed->_has_quantum_number_f = false;
     704             :             }
     705             :         }
     706         141 :     }
     707             : 
     708             :     {
     709         282 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
     710         141 :                                                             state_index_to_quantum_number_m.size());
     711         141 :         Eigen::VectorX<real_t> val = probs * map;
     712         141 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     713         141 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     714         141 :         transformed->state_index_to_quantum_number_m.resize(probs.rows());
     715             : 
     716       10355 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
     717       10214 :             if (diff[i] < numerical_precision) {
     718        8274 :                 transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
     719             :             } else {
     720        1940 :                 transformed->state_index_to_quantum_number_m[i] =
     721        1940 :                     std::numeric_limits<real_t>::max();
     722        1940 :                 transformed->_has_quantum_number_m = false;
     723             :             }
     724             :         }
     725         141 :     }
     726             : 
     727             :     {
     728             :         using utype = std::underlying_type<Parity>::type;
     729         141 :         Eigen::VectorX<real_t> map(state_index_to_parity.size());
     730       10914 :         for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
     731       10773 :             map[i] = static_cast<utype>(state_index_to_parity[i]);
     732             :         }
     733         141 :         Eigen::VectorX<real_t> val = probs * map;
     734         141 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     735         141 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     736         141 :         transformed->state_index_to_parity.resize(probs.rows());
     737             : 
     738       10355 :         for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
     739       10214 :             if (diff[i] < numerical_precision) {
     740        6706 :                 transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
     741             :             } else {
     742        3508 :                 transformed->state_index_to_parity[i] = Parity::UNKNOWN;
     743        3508 :                 transformed->_has_parity = false;
     744             :             }
     745             :         }
     746         141 :     }
     747             : 
     748             :     {
     749             :         // In the following, we obtain a bijective map between state index and ket index.
     750             : 
     751             :         // Find the maximum value in each row and column
     752         141 :         std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
     753         141 :         std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
     754       10914 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     755       10773 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     756       10773 :                      transformed->coefficients.matrix, row);
     757      248008 :                  it; ++it) {
     758      237235 :                 real_t val = std::pow(std::abs(it.value()), 2);
     759      237235 :                 max_in_row[row] = std::max(max_in_row[row], val);
     760      237235 :                 max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
     761             :             }
     762             :         }
     763             : 
     764             :         // Use the maximum values to define a cost for a sub-optimal mapping
     765         141 :         std::vector<real_t> costs;
     766         141 :         std::vector<std::pair<int, int>> mappings;
     767         141 :         costs.reserve(transformed->coefficients.matrix.nonZeros());
     768         141 :         mappings.reserve(transformed->coefficients.matrix.nonZeros());
     769       10914 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     770       10773 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     771       10773 :                      transformed->coefficients.matrix, row);
     772      248008 :                  it; ++it) {
     773      237235 :                 real_t val = std::pow(std::abs(it.value()), 2);
     774      237235 :                 real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
     775      237235 :                 costs.push_back(cost);
     776      237235 :                 mappings.push_back({row, it.col()});
     777             :             }
     778             :         }
     779             : 
     780             :         // Obtain from the costs in which order the mappings should be considered
     781         141 :         std::vector<size_t> order(costs.size());
     782         141 :         std::iota(order.begin(), order.end(), 0);
     783         141 :         std::sort(order.begin(), order.end(),
     784     3472964 :                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
     785             : 
     786             :         // Fill ket_index_to_state_index with invalid values as there can be more kets than states
     787         141 :         std::fill(transformed->ket_index_to_state_index.begin(),
     788         141 :                   transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     789             : 
     790             :         // Generate the bijective map
     791         141 :         std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
     792         141 :         std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
     793         141 :         int num_used = 0;
     794       21344 :         for (size_t idx : order) {
     795       21342 :             int row = mappings[idx].first;  // corresponds to the ket index
     796       21342 :             int col = mappings[idx].second; // corresponds to the state index
     797       21342 :             if (!row_used[row] && !col_used[col]) {
     798       10214 :                 row_used[row] = true;
     799       10214 :                 col_used[col] = true;
     800       10214 :                 num_used++;
     801       10214 :                 transformed->state_index_to_ket_index[col] = row;
     802       10214 :                 transformed->ket_index_to_state_index[row] = col;
     803             :             }
     804       21342 :             if (num_used == transformed->coefficients.matrix.cols()) {
     805         139 :                 break;
     806             :             }
     807             :         }
     808         141 :         assert(num_used == transformed->coefficients.matrix.cols());
     809         141 :     }
     810             : 
     811         141 :     return transformed;
     812         141 : }
     813             : 
     814             : template <typename Derived>
     815       81552 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
     816       81552 :     return typename ket_t::hash()(*k);
     817             : }
     818             : 
     819             : template <typename Derived>
     820         616 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
     821             :                                           const std::shared_ptr<const ket_t> &rhs) const {
     822         616 :     return *lhs == *rhs;
     823             : }
     824             : 
     825             : // Explicit instantiations
     826             : template class Basis<BasisAtom<double>>;
     827             : template class Basis<BasisAtom<std::complex<double>>>;
     828             : template class Basis<BasisPair<double>>;
     829             : template class Basis<BasisPair<std::complex<double>>>;
     830             : } // namespace pairinteraction

Generated by: LCOV version 1.16