LCOV - code coverage report
Current view: top level - src/basis - Basis.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 378 468 80.8 %
Date: 2026-04-17 09:20:02 Functions: 138 204 67.6 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/basis/Basis.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/enums/Parity.hpp"
       9             : #include "pairinteraction/enums/TransformationType.hpp"
      10             : #include "pairinteraction/ket/KetAtom.hpp"
      11             : #include "pairinteraction/ket/KetPair.hpp"
      12             : #include "pairinteraction/utils/TaskControl.hpp"
      13             : #include "pairinteraction/utils/eigen_assertion.hpp"
      14             : #include "pairinteraction/utils/eigen_compat.hpp"
      15             : #include "pairinteraction/utils/wigner.hpp"
      16             : 
      17             : #include <cassert>
      18             : #include <numeric>
      19             : #include <set>
      20             : #include <spdlog/spdlog.h>
      21             : 
      22             : namespace pairinteraction {
      23             : 
      24             : template <typename Scalar>
      25             : class BasisAtom;
      26             : 
      27             : template <typename Derived>
      28        2238 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
      29             :     // Check if the labels are valid sorting labels
      30        4501 :     for (const auto &label : labels) {
      31        2263 :         if (!utils::is_sorting(label)) {
      32           0 :             throw std::invalid_argument("One of the labels is not a valid sorting label.");
      33             :         }
      34             :     }
      35        2238 : }
      36             : 
      37             : template <typename Derived>
      38         622 : void Basis<Derived>::perform_blocks_checks(
      39             :     const std::set<TransformationType> &unique_labels) const {
      40             :     // Check if the states are sorted by the requested labels
      41         622 :     std::set<TransformationType> unique_labels_present;
      42        1230 :     for (const auto &label : get_transformation().transformation_type) {
      43         661 :         if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
      44          53 :             break;
      45             :         }
      46         608 :         unique_labels_present.insert(label);
      47             :     }
      48         622 :     if (unique_labels != unique_labels_present) {
      49           0 :         throw std::invalid_argument("The states are not sorted by the requested labels.");
      50             :     }
      51             : 
      52             :     // Throw a meaningful error if getting the blocks by energy is requested as this might be a
      53             :     // common mistake
      54         622 :     if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
      55           0 :         throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
      56             :                                     "be obtained. Use an energy operator instead.");
      57             :     }
      58         622 : }
      59             : 
      60             : template <typename Derived>
      61         458 : Basis<Derived>::Basis(ketvec_t &&kets)
      62         916 :     : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
      63         458 :                                            static_cast<Eigen::Index>(this->kets.size())},
      64         916 :                                           {TransformationType::SORT_BY_KET}} {
      65         458 :     if (this->kets.empty()) {
      66           0 :         throw std::invalid_argument("The basis must contain at least one element.");
      67             :     }
      68         458 :     state_index_to_quantum_number_f.reserve(this->kets.size());
      69         458 :     state_index_to_quantum_number_m.reserve(this->kets.size());
      70         458 :     state_index_to_parity.reserve(this->kets.size());
      71         458 :     ket_to_ket_index.reserve(this->kets.size());
      72         458 :     size_t index = 0;
      73      251874 :     for (const auto &ket : this->kets) {
      74      251416 :         state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
      75      251416 :         state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
      76      251416 :         state_index_to_parity.push_back(ket->get_parity());
      77      251416 :         ket_to_ket_index[ket] = index++;
      78      251416 :         if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
      79      224297 :             _has_quantum_number_f = false;
      80             :         }
      81      251416 :         if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
      82           0 :             _has_quantum_number_m = false;
      83             :         }
      84      251416 :         if (ket->get_parity() == Parity::UNKNOWN) {
      85      224297 :             _has_parity = false;
      86             :         }
      87             :     }
      88         458 :     state_index_to_ket_index.resize(this->kets.size());
      89         458 :     std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
      90         458 :     ket_index_to_state_index.resize(this->kets.size());
      91         458 :     std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
      92         458 :     coefficients.matrix.setIdentity();
      93         458 : }
      94             : 
      95             : template <typename Derived>
      96         843 : bool Basis<Derived>::has_quantum_number_f() const {
      97         843 :     return _has_quantum_number_f;
      98             : }
      99             : 
     100             : template <typename Derived>
     101      563099 : bool Basis<Derived>::has_quantum_number_m() const {
     102      563099 :     return _has_quantum_number_m;
     103             : }
     104             : 
     105             : template <typename Derived>
     106         843 : bool Basis<Derived>::has_parity() const {
     107         843 :     return _has_parity;
     108             : }
     109             : 
     110             : template <typename Derived>
     111        2892 : const Derived &Basis<Derived>::derived() const {
     112        2892 :     return static_cast<const Derived &>(*this);
     113             : }
     114             : 
     115             : template <typename Derived>
     116         120 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
     117         120 :     return kets;
     118             : }
     119             : 
     120             : template <typename Derived>
     121             : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     122        9783 : Basis<Derived>::get_coefficients() const {
     123        9783 :     return coefficients.matrix;
     124             : }
     125             : 
     126             : template <typename Derived>
     127             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     128           0 : Basis<Derived>::get_coefficients() {
     129           0 :     return coefficients.matrix;
     130             : }
     131             : 
     132             : template <typename Derived>
     133           5 : void Basis<Derived>::set_coefficients(
     134             :     const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
     135           5 :     if (values.rows() != coefficients.matrix.rows()) {
     136           0 :         throw std::invalid_argument("Incompatible number of rows.");
     137             :     }
     138           5 :     if (values.cols() != coefficients.matrix.cols()) {
     139           0 :         throw std::invalid_argument("Incompatible number of columns.");
     140             :     }
     141             : 
     142           5 :     coefficients.matrix = values;
     143             : 
     144           5 :     std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
     145           5 :               std::numeric_limits<int>::max());
     146           5 :     std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
     147           5 :               std::numeric_limits<int>::max());
     148           5 :     std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
     149           5 :               std::numeric_limits<real_t>::max());
     150           5 :     std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
     151           5 :               std::numeric_limits<real_t>::max());
     152           5 :     std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
     153           5 :     _has_quantum_number_f = false;
     154           5 :     _has_quantum_number_m = false;
     155           5 :     _has_parity = false;
     156           5 : }
     157             : 
     158             : template <typename Derived>
     159         723 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
     160         723 :     if (!ket_to_ket_index.contains(ket)) {
     161           0 :         return -1;
     162             :     }
     163         723 :     return ket_to_ket_index.at(ket);
     164             : }
     165             : 
     166             : template <typename Derived>
     167             : Eigen::VectorX<typename Basis<Derived>::scalar_t>
     168         220 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
     169         220 :     int ket_index = get_ket_index_from_ket(ket);
     170         220 :     if (ket_index < 0) {
     171           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     172             :     }
     173             :     // The following line is a more efficient alternative to
     174             :     // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
     175         440 :     return coefficients.matrix.row(ket_index);
     176             : }
     177             : 
     178             : template <typename Derived>
     179             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
     180          16 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
     181          24 :     return other->coefficients.matrix.adjoint() * coefficients.matrix;
     182             : }
     183             : 
     184             : template <typename Derived>
     185             : Eigen::VectorX<typename Basis<Derived>::real_t>
     186         214 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
     187         214 :     return get_amplitudes(ket).cwiseAbs2();
     188             : }
     189             : 
     190             : template <typename Derived>
     191             : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
     192           8 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
     193           8 :     return get_amplitudes(other).cwiseAbs2();
     194             : }
     195             : 
     196             : template <typename Derived>
     197           0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
     198           0 :     real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
     199           0 :     if (quantum_number_f == std::numeric_limits<real_t>::max()) {
     200           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number f.");
     201             :     }
     202           0 :     return quantum_number_f;
     203             : }
     204             : 
     205             : template <typename Derived>
     206      562298 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
     207      562298 :     real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
     208      562298 :     if (quantum_number_m == std::numeric_limits<real_t>::max()) {
     209           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number m.");
     210             :     }
     211      562298 :     return quantum_number_m;
     212             : }
     213             : 
     214             : template <typename Derived>
     215          43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
     216          43 :     Parity parity = state_index_to_parity.at(state_index);
     217          43 :     if (parity == Parity::UNKNOWN) {
     218           0 :         throw std::invalid_argument("The state does not have a well-defined parity.");
     219             :     }
     220          43 :     return parity;
     221             : }
     222             : 
     223             : template <typename Derived>
     224             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     225         184 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
     226         184 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     227         184 :     if (ket_index == std::numeric_limits<int>::max()) {
     228           0 :         throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
     229             :     }
     230         184 :     return kets[ket_index];
     231             : }
     232             : 
     233             : template <typename Derived>
     234             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     235           0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
     236           0 :     throw std::runtime_error("Not implemented yet.");
     237             : }
     238             : 
     239             : template <typename Derived>
     240         399 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
     241             :     // Create a copy of the current object
     242         399 :     auto restricted = std::make_shared<Derived>(derived());
     243             : 
     244             :     // Restrict the copy to the state with the largest overlap
     245         399 :     restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
     246             : 
     247         399 :     std::fill(restricted->ket_index_to_state_index.begin(),
     248         399 :               restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     249             : 
     250         399 :     size_t ket_index = state_index_to_ket_index[state_index];
     251         399 :     restricted->state_index_to_ket_index = {ket_index};
     252         399 :     if (ket_index != std::numeric_limits<int>::max()) {
     253         399 :         restricted->ket_index_to_state_index[ket_index] = 0;
     254             :     }
     255             : 
     256         399 :     restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
     257         399 :     restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
     258         399 :     restricted->state_index_to_parity = {state_index_to_parity[state_index]};
     259             : 
     260         798 :     restricted->_has_quantum_number_f =
     261         399 :         restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     262         798 :     restricted->_has_quantum_number_m =
     263         399 :         restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     264         399 :     restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
     265             : 
     266         798 :     return restricted;
     267         399 : }
     268             : 
     269             : template <typename Derived>
     270             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     271      140512 : Basis<Derived>::get_ket(size_t ket_index) const {
     272      140512 :     return kets[ket_index];
     273             : }
     274             : 
     275             : template <typename Derived>
     276          68 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
     277          68 :     size_t state_index = ket_index_to_state_index.at(ket_index);
     278          68 :     if (state_index == std::numeric_limits<int>::max()) {
     279           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     280             :     }
     281          68 :     return get_state(state_index);
     282             : }
     283             : 
     284             : template <typename Derived>
     285             : std::shared_ptr<const Derived>
     286          68 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
     287          68 :     int ket_index = get_ket_index_from_ket(ket);
     288          68 :     if (ket_index < 0) {
     289           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     290             :     }
     291          68 :     return get_corresponding_state(ket_index);
     292             : }
     293             : 
     294             : template <typename Derived>
     295         182 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
     296         182 :     int state_index = ket_index_to_state_index.at(ket_index);
     297         182 :     if (state_index == std::numeric_limits<int>::max()) {
     298           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     299             :     }
     300         182 :     return state_index;
     301             : }
     302             : 
     303             : template <typename Derived>
     304         182 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
     305         182 :     int ket_index = get_ket_index_from_ket(ket);
     306         182 :     if (ket_index < 0) {
     307           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     308             :     }
     309         182 :     return get_corresponding_state_index(ket_index);
     310             : }
     311             : 
     312             : template <typename Derived>
     313         184 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
     314         184 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     315         184 :     if (ket_index == std::numeric_limits<int>::max()) {
     316           0 :         throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
     317             :     }
     318         184 :     return ket_index;
     319             : }
     320             : 
     321             : template <typename Derived>
     322           0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
     323           0 :     throw std::runtime_error("Not implemented yet.");
     324             : }
     325             : 
     326             : template <typename Derived>
     327             : std::shared_ptr<const Derived>
     328         253 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
     329             :     // Create a copy of the current object
     330         253 :     auto created = std::make_shared<Derived>(derived());
     331             : 
     332             :     // Fill the copy with the state corresponding to the ket index
     333         253 :     created->coefficients.matrix =
     334         506 :         Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
     335         253 :     created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
     336         253 :     created->coefficients.matrix.makeCompressed();
     337             : 
     338         253 :     std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
     339         253 :               std::numeric_limits<int>::max());
     340         253 :     created->ket_index_to_state_index[ket_index] = 0;
     341             : 
     342         253 :     created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
     343         253 :     created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
     344         253 :     created->state_index_to_parity = {kets[ket_index]->get_parity()};
     345         253 :     created->state_index_to_ket_index = {ket_index};
     346             : 
     347         506 :     created->_has_quantum_number_f =
     348         253 :         created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     349         506 :     created->_has_quantum_number_m =
     350         253 :         created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     351         253 :     created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
     352             : 
     353         506 :     return created;
     354         253 : }
     355             : 
     356             : template <typename Derived>
     357             : std::shared_ptr<const Derived>
     358         253 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
     359         253 :     int ket_index = get_ket_index_from_ket(ket);
     360         253 :     if (ket_index < 0) {
     361           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     362             :     }
     363         253 :     return get_canonical_state_from_ket(ket_index);
     364             : }
     365             : 
     366             : template <typename Derived>
     367           4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
     368           4 :     return kets.begin();
     369             : }
     370             : 
     371             : template <typename Derived>
     372           4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
     373           4 :     return kets.end();
     374             : }
     375             : 
     376             : template <typename Derived>
     377           8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
     378             : 
     379             : template <typename Derived>
     380          80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
     381          80 :     return other.it != it;
     382             : }
     383             : 
     384             : template <typename Derived>
     385          76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
     386          76 :     return *it;
     387             : }
     388             : 
     389             : template <typename Derived>
     390          76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
     391          76 :     ++it;
     392          76 :     return *this;
     393             : }
     394             : 
     395             : template <typename Derived>
     396        2770 : size_t Basis<Derived>::get_number_of_states() const {
     397        2770 :     return coefficients.matrix.cols();
     398             : }
     399             : 
     400             : template <typename Derived>
     401        6034 : size_t Basis<Derived>::get_number_of_kets() const {
     402        6034 :     return coefficients.matrix.rows();
     403             : }
     404             : 
     405             : template <typename Derived>
     406             : const Transformation<typename Basis<Derived>::scalar_t> &
     407         623 : Basis<Derived>::get_transformation() const {
     408         623 :     return coefficients;
     409             : }
     410             : 
     411             : template <typename Derived>
     412             : Transformation<typename Basis<Derived>::scalar_t>
     413           0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
     414           0 :     Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
     415             :                                              static_cast<Eigen::Index>(coefficients.matrix.rows())},
     416             :                                             {TransformationType::ROTATE}};
     417             : 
     418           0 :     std::vector<Eigen::Triplet<scalar_t>> entries;
     419             : 
     420           0 :     set_task_status("Constructing basis rotation...");
     421           0 :     for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
     422           0 :         real_t f = kets[idx_initial]->get_quantum_number_f();
     423           0 :         real_t m_initial = kets[idx_initial]->get_quantum_number_m();
     424             : 
     425           0 :         assert(2 * f == std::floor(2 * f) && f >= 0);
     426           0 :         assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
     427             : 
     428           0 :         for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
     429             :              ++m_final) {
     430           0 :             auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
     431             :                                                                    beta, gamma);
     432           0 :             size_t idx_final = get_ket_index_from_ket(
     433           0 :                 kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
     434           0 :             entries.emplace_back(idx_final, idx_initial, val);
     435             :         }
     436             :     }
     437             : 
     438           0 :     transformation.matrix.setFromTriplets(entries.begin(), entries.end());
     439           0 :     transformation.matrix.makeCompressed();
     440             : 
     441           0 :     return transformation;
     442           0 : }
     443             : 
     444             : template <typename Derived>
     445           1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
     446           1 :     perform_sorter_checks(labels);
     447             : 
     448             :     // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
     449           1 :     if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
     450           2 :         labels.end()) {
     451           0 :         throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
     452             :                                     "the energy. Use an energy operator instead.");
     453             :     }
     454             : 
     455             :     // Initialize transformation
     456           1 :     Sorting transformation;
     457           1 :     transformation.matrix.resize(coefficients.matrix.cols());
     458           1 :     transformation.matrix.setIdentity();
     459             : 
     460             :     // Get the sorter
     461           1 :     get_sorter_without_checks(labels, transformation);
     462             : 
     463             :     // Check if all labels have been used for sorting
     464           1 :     if (labels != transformation.transformation_type) {
     465           0 :         throw std::invalid_argument("The states could not be sorted by all the requested labels.");
     466             :     }
     467             : 
     468           1 :     return transformation;
     469           0 : }
     470             : 
     471             : template <typename Derived>
     472             : std::vector<IndicesOfBlock>
     473           1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
     474           1 :     perform_sorter_checks(labels);
     475             : 
     476           1 :     std::set<TransformationType> unique_labels(labels.begin(), labels.end());
     477           1 :     perform_blocks_checks(unique_labels);
     478             : 
     479             :     // Get the blocks
     480           1 :     IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
     481           1 :     get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
     482             : 
     483           2 :     return blocks_creator.create();
     484           1 : }
     485             : 
     486             : template <typename Derived>
     487         569 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
     488             :                                                Sorting &transformation) const {
     489         569 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     490             : 
     491         569 :     int *perm_begin = transformation.matrix.indices().data();
     492         569 :     int *perm_end = perm_begin + coefficients.matrix.cols();
     493         569 :     const int *perm_back = perm_end - 1;
     494             : 
     495             :     // Sort the vector based on the requested labels
     496         569 :     set_task_status("Sorting basis states...");
     497      635598 :     std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
     498      279497 :         for (const auto &label : labels) {
     499      183374 :             switch (label) {
     500       30703 :             case TransformationType::SORT_BY_PARITY:
     501       30703 :                 if (state_index_to_parity[a] != state_index_to_parity[b]) {
     502       57386 :                     return state_index_to_parity[a] < state_index_to_parity[b];
     503             :                 }
     504       20305 :                 break;
     505      152671 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     506      305342 :                 if (std::abs(state_index_to_quantum_number_m[a] -
     507      305342 :                              state_index_to_quantum_number_m[b]) > numerical_precision) {
     508       46988 :                     return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
     509             :                 }
     510      105683 :                 break;
     511           0 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     512           0 :                 if (std::abs(state_index_to_quantum_number_f[a] -
     513           0 :                              state_index_to_quantum_number_f[b]) > numerical_precision) {
     514           0 :                     return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
     515             :                 }
     516           0 :                 break;
     517           0 :             case TransformationType::SORT_BY_KET:
     518           0 :                 if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
     519           0 :                     return state_index_to_ket_index[a] < state_index_to_ket_index[b];
     520             :                 }
     521           0 :                 break;
     522           0 :             default:
     523           0 :                 std::abort(); // Can't happen because of previous checks
     524             :             }
     525             :         }
     526       96123 :         return false; // Elements are equal
     527             :     });
     528             : 
     529             :     // Check for invalid values and add transformation types
     530        1177 :     for (const auto &label : labels) {
     531         608 :         switch (label) {
     532          41 :         case TransformationType::SORT_BY_PARITY:
     533          41 :             if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
     534           0 :                 throw std::invalid_argument(
     535             :                     "States cannot be labeled and thus not sorted by the parity.");
     536             :             }
     537          41 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
     538          41 :             break;
     539         567 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     540         567 :             if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
     541           0 :                 throw std::invalid_argument(
     542             :                     "States cannot be labeled and thus not sorted by the quantum number m.");
     543             :             }
     544         567 :             transformation.transformation_type.push_back(
     545         567 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_M);
     546         567 :             break;
     547           0 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     548           0 :             if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
     549           0 :                 throw std::invalid_argument(
     550             :                     "States cannot be labeled and thus not sorted by the quantum number f.");
     551             :             }
     552           0 :             transformation.transformation_type.push_back(
     553           0 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_F);
     554           0 :             break;
     555           0 :         case TransformationType::SORT_BY_KET:
     556           0 :             if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
     557           0 :                 throw std::invalid_argument(
     558             :                     "States cannot be labeled and thus not sorted by kets.");
     559             :             }
     560           0 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
     561           0 :             break;
     562           0 :         default:
     563           0 :             std::abort(); // Can't happen because of previous checks
     564             :         }
     565             :     }
     566         569 : }
     567             : 
     568             : template <typename Derived>
     569         569 : void Basis<Derived>::get_indices_of_blocks_without_checks(
     570             :     const std::set<TransformationType> &unique_labels,
     571             :     IndicesOfBlocksCreator &blocks_creator) const {
     572         569 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
     573             : 
     574         569 :     auto last_quantum_number_f = state_index_to_quantum_number_f[0];
     575         569 :     auto last_quantum_number_m = state_index_to_quantum_number_m[0];
     576         569 :     auto last_parity = state_index_to_parity[0];
     577         569 :     auto last_ket = state_index_to_ket_index[0];
     578             : 
     579         569 :     set_task_status("Identifying basis blocks...");
     580       32971 :     for (int i = 0; i < coefficients.matrix.cols(); ++i) {
     581       71114 :         for (auto label : unique_labels) {
     582       39411 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
     583           0 :                 std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
     584             :                     numerical_precision) {
     585           0 :                 blocks_creator.add(i);
     586           0 :                 break;
     587             :             }
     588       71633 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
     589       32222 :                 std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
     590             :                     numerical_precision) {
     591         555 :                 blocks_creator.add(i);
     592         555 :                 break;
     593             :             }
     594       46045 :             if (label == TransformationType::SORT_BY_PARITY &&
     595        7189 :                 state_index_to_parity[i] != last_parity) {
     596         144 :                 blocks_creator.add(i);
     597         144 :                 break;
     598             :             }
     599       38712 :             if (label == TransformationType::SORT_BY_KET &&
     600       38712 :                 state_index_to_ket_index[i] != last_ket &&
     601           0 :                 last_ket != std::numeric_limits<int>::max()) {
     602           0 :                 blocks_creator.add(i);
     603           0 :                 break;
     604             :             }
     605             :         }
     606       32402 :         last_quantum_number_f = state_index_to_quantum_number_f[i];
     607       32402 :         last_quantum_number_m = state_index_to_quantum_number_m[i];
     608       32402 :         last_parity = state_index_to_parity[i];
     609       32402 :         last_ket = state_index_to_ket_index[i];
     610             :     }
     611         569 : }
     612             : 
     613             : template <typename Derived>
     614        1616 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
     615             :     // Create a copy of the current object
     616        1616 :     auto transformed = std::make_shared<Derived>(derived());
     617             : 
     618        1616 :     if (coefficients.matrix.cols() == 0) {
     619           0 :         return transformed;
     620             :     }
     621             : 
     622             :     // Apply the transformation
     623        1616 :     set_task_status("Applying basis sorting...");
     624        1616 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     625        1616 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     626             : 
     627        1616 :     transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
     628        1616 :     transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
     629        1616 :     transformed->state_index_to_parity.resize(transformation.matrix.size());
     630        1616 :     transformed->state_index_to_ket_index.resize(transformation.matrix.size());
     631             : 
     632        1616 :     set_task_status("Relabeling sorted basis states...");
     633      153784 :     for (int i = 0; i < transformation.matrix.size(); ++i) {
     634      152168 :         transformed->state_index_to_quantum_number_f[i] =
     635      152168 :             state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
     636      152168 :         transformed->state_index_to_quantum_number_m[i] =
     637      152168 :             state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
     638      152168 :         transformed->state_index_to_parity[i] =
     639      152168 :             state_index_to_parity[transformation.matrix.indices()[i]];
     640             : 
     641      152168 :         size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
     642      152168 :         transformed->state_index_to_ket_index[i] = ket_index;
     643      152168 :         if (ket_index != std::numeric_limits<int>::max()) {
     644      152168 :             transformed->ket_index_to_state_index[ket_index] = i;
     645             :         }
     646             :     }
     647             : 
     648        1616 :     return transformed;
     649        1616 : }
     650             : 
     651             : template <typename Derived>
     652             : std::shared_ptr<const Derived>
     653         624 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
     654             :     // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
     655             :     // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
     656         624 :     real_t numerical_precision = 0.001;
     657             : 
     658             :     // If the transformation is a rotation, it should be a rotation and nothing else
     659         624 :     bool is_rotation = false;
     660        1248 :     for (auto t : transformation.transformation_type) {
     661         624 :         if (t == TransformationType::ROTATE) {
     662           0 :             is_rotation = true;
     663           0 :             break;
     664             :         }
     665             :     }
     666         624 :     if (is_rotation && transformation.transformation_type.size() != 1) {
     667           0 :         throw std::invalid_argument("A rotation can not be combined with other transformations.");
     668             :     }
     669             : 
     670             :     // To apply a rotation, the object must only be sorted but other transformations are not allowed
     671         624 :     if (is_rotation) {
     672           0 :         for (auto t : coefficients.transformation_type) {
     673           0 :             if (!utils::is_sorting(t)) {
     674           0 :                 throw std::runtime_error(
     675             :                     "If the object was transformed by a different transformation "
     676             :                     "than sorting, it can not be rotated.");
     677             :             }
     678             :         }
     679             :     }
     680             : 
     681             :     // Create a copy of the current object
     682         624 :     auto transformed = std::make_shared<Derived>(derived());
     683             : 
     684         624 :     if (coefficients.matrix.cols() == 0) {
     685           0 :         return transformed;
     686             :     }
     687             : 
     688             :     // Apply the transformation
     689             :     // If a quantum number turns out to be conserved by the transformation, it will be
     690             :     // rounded to the nearest half integer to avoid loss of numerical_precision.
     691         624 :     set_task_status("Applying basis transformation...");
     692         624 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     693         624 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     694             : 
     695         624 :     Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
     696             : 
     697         624 :     set_task_status("Updating transformed quantum numbers...");
     698             :     {
     699        1248 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
     700         624 :                                                             state_index_to_quantum_number_f.size());
     701         624 :         Eigen::VectorX<real_t> val = probs * map;
     702         624 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     703         624 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     704         624 :         transformed->state_index_to_quantum_number_f.resize(probs.rows());
     705             : 
     706       31002 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
     707       30378 :             if (diff[i] < numerical_precision) {
     708        5113 :                 transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
     709             :             } else {
     710       25265 :                 transformed->state_index_to_quantum_number_f[i] =
     711       25265 :                     std::numeric_limits<real_t>::max();
     712       25265 :                 transformed->_has_quantum_number_f = false;
     713             :             }
     714             :         }
     715         624 :     }
     716             : 
     717             :     {
     718        1248 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
     719         624 :                                                             state_index_to_quantum_number_m.size());
     720         624 :         Eigen::VectorX<real_t> val = probs * map;
     721         624 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     722         624 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     723         624 :         transformed->state_index_to_quantum_number_m.resize(probs.rows());
     724             : 
     725       31002 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
     726       30378 :             if (diff[i] < numerical_precision) {
     727       27526 :                 transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
     728             :             } else {
     729        2852 :                 transformed->state_index_to_quantum_number_m[i] =
     730        2852 :                     std::numeric_limits<real_t>::max();
     731        2852 :                 transformed->_has_quantum_number_m = false;
     732             :             }
     733             :         }
     734         624 :     }
     735             : 
     736             :     {
     737             :         using utype = std::underlying_type<Parity>::type;
     738         624 :         Eigen::VectorX<real_t> map(state_index_to_parity.size());
     739       36100 :         for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
     740       35476 :             map[i] = static_cast<utype>(state_index_to_parity[i]);
     741             :         }
     742         624 :         Eigen::VectorX<real_t> val = probs * map;
     743         624 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     744         624 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     745         624 :         transformed->state_index_to_parity.resize(probs.rows());
     746             : 
     747       31002 :         for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
     748       30378 :             if (diff[i] < numerical_precision) {
     749       24262 :                 transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
     750             :             } else {
     751        6116 :                 transformed->state_index_to_parity[i] = Parity::UNKNOWN;
     752        6116 :                 transformed->_has_parity = false;
     753             :             }
     754             :         }
     755         624 :     }
     756             : 
     757         624 :     set_task_status("Obtaining transformed state mapping...");
     758             :     {
     759             :         // In the following, we obtain a bijective map between state index and ket index.
     760             : 
     761             :         // Find the maximum value in each row and column
     762         624 :         std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
     763         624 :         std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
     764       36100 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     765       35476 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     766       35476 :                      transformed->coefficients.matrix, row);
     767      657371 :                  it; ++it) {
     768      621895 :                 real_t val = std::pow(std::abs(it.value()), 2);
     769      621895 :                 max_in_row[row] = std::max(max_in_row[row], val);
     770      621895 :                 max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
     771             :             }
     772             :         }
     773             : 
     774             :         // Use the maximum values to define a cost for a sub-optimal mapping
     775         624 :         std::vector<real_t> costs;
     776         624 :         std::vector<std::pair<int, int>> mappings;
     777         624 :         costs.reserve(transformed->coefficients.matrix.nonZeros());
     778         624 :         mappings.reserve(transformed->coefficients.matrix.nonZeros());
     779       36100 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     780       35476 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     781       35476 :                      transformed->coefficients.matrix, row);
     782      657371 :                  it; ++it) {
     783      621895 :                 real_t val = std::pow(std::abs(it.value()), 2);
     784      621895 :                 real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
     785      621895 :                 costs.push_back(cost);
     786      621895 :                 mappings.push_back({row, it.col()});
     787             :             }
     788             :         }
     789             : 
     790             :         // Obtain from the costs in which order the mappings should be considered
     791         624 :         std::vector<size_t> order(costs.size());
     792         624 :         std::iota(order.begin(), order.end(), 0);
     793         624 :         std::sort(order.begin(), order.end(),
     794     7961500 :                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
     795             : 
     796             :         // Fill ket_index_to_state_index with invalid values as there can be more kets than states
     797         624 :         std::fill(transformed->ket_index_to_state_index.begin(),
     798         624 :                   transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     799             : 
     800             :         // Generate the bijective map
     801         624 :         std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
     802         624 :         std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
     803         624 :         int num_used = 0;
     804      116841 :         for (size_t idx : order) {
     805      116837 :             int row = mappings[idx].first;  // corresponds to the ket index
     806      116837 :             int col = mappings[idx].second; // corresponds to the state index
     807      116837 :             if (!row_used[row] && !col_used[col]) {
     808       30376 :                 row_used[row] = true;
     809       30376 :                 col_used[col] = true;
     810       30376 :                 num_used++;
     811       30376 :                 transformed->state_index_to_ket_index[col] = row;
     812       30376 :                 transformed->ket_index_to_state_index[row] = col;
     813             :             }
     814      116837 :             if (num_used == transformed->coefficients.matrix.cols()) {
     815         620 :                 break;
     816             :             }
     817             :         }
     818         624 :         if (num_used != transformed->coefficients.matrix.cols()) {
     819           4 :             SPDLOG_WARN("A bijective map between states and kets could not be found.");
     820             :         }
     821         624 :     }
     822             : 
     823         624 :     return transformed;
     824         624 : }
     825             : 
     826             : template <typename Derived>
     827      252862 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
     828      252862 :     return typename ket_t::hash()(*k);
     829             : }
     830             : 
     831             : template <typename Derived>
     832        1446 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
     833             :                                           const std::shared_ptr<const ket_t> &rhs) const {
     834        1446 :     return *lhs == *rhs;
     835             : }
     836             : 
     837             : // Explicit instantiations
     838             : template class Basis<BasisAtom<double>>;
     839             : template class Basis<BasisAtom<std::complex<double>>>;
     840             : template class Basis<BasisPair<double>>;
     841             : template class Basis<BasisPair<std::complex<double>>>;
     842             : } // namespace pairinteraction

Generated by: LCOV version 1.16