LCOV - code coverage report
Current view: top level - src/basis - Basis.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 370 460 80.4 %
Date: 2026-01-22 14:02:01 Functions: 130 204 63.7 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/basis/Basis.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/enums/Parity.hpp"
       9             : #include "pairinteraction/enums/TransformationType.hpp"
      10             : #include "pairinteraction/ket/KetAtom.hpp"
      11             : #include "pairinteraction/ket/KetPair.hpp"
      12             : #include "pairinteraction/utils/eigen_assertion.hpp"
      13             : #include "pairinteraction/utils/eigen_compat.hpp"
      14             : #include "pairinteraction/utils/wigner.hpp"
      15             : 
      16             : #include <cassert>
      17             : #include <numeric>
      18             : #include <set>
      19             : #include <spdlog/spdlog.h>
      20             : 
      21             : namespace pairinteraction {
      22             : 
      23             : template <typename Scalar>
      24             : class BasisAtom;
      25             : 
      26             : template <typename Derived>
      27        1958 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
      28             :     // Check if the labels are valid sorting labels
      29        3943 :     for (const auto &label : labels) {
      30        1985 :         if (!utils::is_sorting(label)) {
      31           0 :             throw std::invalid_argument("One of the labels is not a valid sorting label.");
      32             :         }
      33             :     }
      34        1958 : }
      35             : 
      36             : template <typename Derived>
      37         566 : void Basis<Derived>::perform_blocks_checks(
      38             :     const std::set<TransformationType> &unique_labels) const {
      39             :     // Check if the states are sorted by the requested labels
      40         566 :     std::set<TransformationType> unique_labels_present;
      41        1120 :     for (const auto &label : get_transformation().transformation_type) {
      42         605 :         if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
      43          51 :             break;
      44             :         }
      45         554 :         unique_labels_present.insert(label);
      46             :     }
      47         566 :     if (unique_labels != unique_labels_present) {
      48           0 :         throw std::invalid_argument("The states are not sorted by the requested labels.");
      49             :     }
      50             : 
      51             :     // Throw a meaningful error if getting the blocks by energy is requested as this might be a
      52             :     // common mistake
      53         566 :     if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
      54           0 :         throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
      55             :                                     "be obtained. Use an energy operator instead.");
      56             :     }
      57         566 : }
      58             : 
      59             : template <typename Derived>
      60         357 : Basis<Derived>::Basis(ketvec_t &&kets)
      61         714 :     : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
      62         357 :                                            static_cast<Eigen::Index>(this->kets.size())},
      63         714 :                                           {TransformationType::SORT_BY_KET}} {
      64         357 :     if (this->kets.empty()) {
      65           0 :         throw std::invalid_argument("The basis must contain at least one element.");
      66             :     }
      67         357 :     state_index_to_quantum_number_f.reserve(this->kets.size());
      68         357 :     state_index_to_quantum_number_m.reserve(this->kets.size());
      69         357 :     state_index_to_parity.reserve(this->kets.size());
      70         357 :     ket_to_ket_index.reserve(this->kets.size());
      71         357 :     size_t index = 0;
      72      196003 :     for (const auto &ket : this->kets) {
      73      195646 :         state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
      74      195646 :         state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
      75      195646 :         state_index_to_parity.push_back(ket->get_parity());
      76      195646 :         ket_to_ket_index[ket] = index++;
      77      195646 :         if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
      78      175963 :             _has_quantum_number_f = false;
      79             :         }
      80      195646 :         if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
      81       64000 :             _has_quantum_number_m = false;
      82             :         }
      83      195646 :         if (ket->get_parity() == Parity::UNKNOWN) {
      84      175963 :             _has_parity = false;
      85             :         }
      86             :     }
      87         357 :     state_index_to_ket_index.resize(this->kets.size());
      88         357 :     std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
      89         357 :     ket_index_to_state_index.resize(this->kets.size());
      90         357 :     std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
      91         357 :     coefficients.matrix.setIdentity();
      92         357 : }
      93             : 
      94             : template <typename Derived>
      95         517 : bool Basis<Derived>::has_quantum_number_f() const {
      96         517 :     return _has_quantum_number_f;
      97             : }
      98             : 
      99             : template <typename Derived>
     100      306509 : bool Basis<Derived>::has_quantum_number_m() const {
     101      306509 :     return _has_quantum_number_m;
     102             : }
     103             : 
     104             : template <typename Derived>
     105         517 : bool Basis<Derived>::has_parity() const {
     106         517 :     return _has_parity;
     107             : }
     108             : 
     109             : template <typename Derived>
     110        2541 : const Derived &Basis<Derived>::derived() const {
     111        2541 :     return static_cast<const Derived &>(*this);
     112             : }
     113             : 
     114             : template <typename Derived>
     115         815 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
     116         815 :     return kets;
     117             : }
     118             : 
     119             : template <typename Derived>
     120             : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     121        9491 : Basis<Derived>::get_coefficients() const {
     122        9491 :     return coefficients.matrix;
     123             : }
     124             : 
     125             : template <typename Derived>
     126             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
     127           0 : Basis<Derived>::get_coefficients() {
     128           0 :     return coefficients.matrix;
     129             : }
     130             : 
     131             : template <typename Derived>
     132           5 : void Basis<Derived>::set_coefficients(
     133             :     const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
     134           5 :     if (values.rows() != coefficients.matrix.rows()) {
     135           0 :         throw std::invalid_argument("Incompatible number of rows.");
     136             :     }
     137           5 :     if (values.cols() != coefficients.matrix.cols()) {
     138           0 :         throw std::invalid_argument("Incompatible number of columns.");
     139             :     }
     140             : 
     141           5 :     coefficients.matrix = values;
     142             : 
     143           5 :     std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
     144           5 :               std::numeric_limits<int>::max());
     145           5 :     std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
     146           5 :               std::numeric_limits<int>::max());
     147           5 :     std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
     148           5 :               std::numeric_limits<real_t>::max());
     149           5 :     std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
     150           5 :               std::numeric_limits<real_t>::max());
     151           5 :     std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
     152           5 :     _has_quantum_number_f = false;
     153           5 :     _has_quantum_number_m = false;
     154           5 :     _has_parity = false;
     155           5 : }
     156             : 
     157             : template <typename Derived>
     158         552 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
     159         552 :     if (!ket_to_ket_index.contains(ket)) {
     160           0 :         return -1;
     161             :     }
     162         552 :     return ket_to_ket_index.at(ket);
     163             : }
     164             : 
     165             : template <typename Derived>
     166             : Eigen::VectorX<typename Basis<Derived>::scalar_t>
     167         109 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
     168         109 :     int ket_index = get_ket_index_from_ket(ket);
     169         109 :     if (ket_index < 0) {
     170           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     171             :     }
     172             :     // The following line is a more efficient alternative to
     173             :     // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
     174         218 :     return coefficients.matrix.row(ket_index);
     175             : }
     176             : 
     177             : template <typename Derived>
     178             : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
     179          16 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
     180          24 :     return other->coefficients.matrix.adjoint() * coefficients.matrix;
     181             : }
     182             : 
     183             : template <typename Derived>
     184             : Eigen::VectorX<typename Basis<Derived>::real_t>
     185         103 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
     186         103 :     return get_amplitudes(ket).cwiseAbs2();
     187             : }
     188             : 
     189             : template <typename Derived>
     190             : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
     191           8 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
     192           8 :     return get_amplitudes(other).cwiseAbs2();
     193             : }
     194             : 
     195             : template <typename Derived>
     196           0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
     197           0 :     real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
     198           0 :     if (quantum_number_f == std::numeric_limits<real_t>::max()) {
     199           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number f.");
     200             :     }
     201           0 :     return quantum_number_f;
     202             : }
     203             : 
     204             : template <typename Derived>
     205      242034 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
     206      242034 :     real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
     207      242034 :     if (quantum_number_m == std::numeric_limits<real_t>::max()) {
     208           0 :         throw std::invalid_argument("The state does not have a well-defined quantum number m.");
     209             :     }
     210      242034 :     return quantum_number_m;
     211             : }
     212             : 
     213             : template <typename Derived>
     214          43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
     215          43 :     Parity parity = state_index_to_parity.at(state_index);
     216          43 :     if (parity == Parity::UNKNOWN) {
     217           0 :         throw std::invalid_argument("The state does not have a well-defined parity.");
     218             :     }
     219          43 :     return parity;
     220             : }
     221             : 
     222             : template <typename Derived>
     223             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     224         152 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
     225         152 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     226         152 :     if (ket_index == std::numeric_limits<int>::max()) {
     227           0 :         throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
     228             :     }
     229         152 :     return kets[ket_index];
     230             : }
     231             : 
     232             : template <typename Derived>
     233             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     234           0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
     235           0 :     throw std::runtime_error("Not implemented yet.");
     236             : }
     237             : 
     238             : template <typename Derived>
     239         374 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
     240             :     // Create a copy of the current object
     241         374 :     auto restricted = std::make_shared<Derived>(derived());
     242             : 
     243             :     // Restrict the copy to the state with the largest overlap
     244         374 :     restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
     245             : 
     246         374 :     std::fill(restricted->ket_index_to_state_index.begin(),
     247         374 :               restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     248             : 
     249         374 :     size_t ket_index = state_index_to_ket_index[state_index];
     250         374 :     restricted->state_index_to_ket_index = {ket_index};
     251         374 :     if (ket_index != std::numeric_limits<int>::max()) {
     252         374 :         restricted->ket_index_to_state_index[ket_index] = 0;
     253             :     }
     254             : 
     255         374 :     restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
     256         374 :     restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
     257         374 :     restricted->state_index_to_parity = {state_index_to_parity[state_index]};
     258             : 
     259         748 :     restricted->_has_quantum_number_f =
     260         374 :         restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     261         748 :     restricted->_has_quantum_number_m =
     262         374 :         restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     263         374 :     restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
     264             : 
     265         748 :     return restricted;
     266         374 : }
     267             : 
     268             : template <typename Derived>
     269             : std::shared_ptr<const typename Basis<Derived>::ket_t>
     270           0 : Basis<Derived>::get_ket(size_t ket_index) const {
     271           0 :     return kets[ket_index];
     272             : }
     273             : 
     274             : template <typename Derived>
     275          52 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
     276          52 :     size_t state_index = ket_index_to_state_index.at(ket_index);
     277          52 :     if (state_index == std::numeric_limits<int>::max()) {
     278           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     279             :     }
     280          52 :     return get_state(state_index);
     281             : }
     282             : 
     283             : template <typename Derived>
     284             : std::shared_ptr<const Derived>
     285          52 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
     286          52 :     int ket_index = get_ket_index_from_ket(ket);
     287          52 :     if (ket_index < 0) {
     288           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     289             :     }
     290          52 :     return get_corresponding_state(ket_index);
     291             : }
     292             : 
     293             : template <typename Derived>
     294         182 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
     295         182 :     int state_index = ket_index_to_state_index.at(ket_index);
     296         182 :     if (state_index == std::numeric_limits<int>::max()) {
     297           0 :         throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
     298             :     }
     299         182 :     return state_index;
     300             : }
     301             : 
     302             : template <typename Derived>
     303         182 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
     304         182 :     int ket_index = get_ket_index_from_ket(ket);
     305         182 :     if (ket_index < 0) {
     306           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     307             :     }
     308         182 :     return get_corresponding_state_index(ket_index);
     309             : }
     310             : 
     311             : template <typename Derived>
     312         152 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
     313         152 :     size_t ket_index = state_index_to_ket_index.at(state_index);
     314         152 :     if (ket_index == std::numeric_limits<int>::max()) {
     315           0 :         throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
     316             :     }
     317         152 :     return ket_index;
     318             : }
     319             : 
     320             : template <typename Derived>
     321           0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
     322           0 :     throw std::runtime_error("Not implemented yet.");
     323             : }
     324             : 
     325             : template <typename Derived>
     326             : std::shared_ptr<const Derived>
     327         209 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
     328             :     // Create a copy of the current object
     329         209 :     auto created = std::make_shared<Derived>(derived());
     330             : 
     331             :     // Fill the copy with the state corresponding to the ket index
     332         209 :     created->coefficients.matrix =
     333         418 :         Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
     334         209 :     created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
     335         209 :     created->coefficients.matrix.makeCompressed();
     336             : 
     337         209 :     std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
     338         209 :               std::numeric_limits<int>::max());
     339         209 :     created->ket_index_to_state_index[ket_index] = 0;
     340             : 
     341         209 :     created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
     342         209 :     created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
     343         209 :     created->state_index_to_parity = {kets[ket_index]->get_parity()};
     344         209 :     created->state_index_to_ket_index = {ket_index};
     345             : 
     346         418 :     created->_has_quantum_number_f =
     347         209 :         created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
     348         418 :     created->_has_quantum_number_m =
     349         209 :         created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
     350         209 :     created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
     351             : 
     352         418 :     return created;
     353         209 : }
     354             : 
     355             : template <typename Derived>
     356             : std::shared_ptr<const Derived>
     357         209 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
     358         209 :     int ket_index = get_ket_index_from_ket(ket);
     359         209 :     if (ket_index < 0) {
     360           0 :         throw std::invalid_argument("The ket does not belong to the basis.");
     361             :     }
     362         209 :     return get_canonical_state_from_ket(ket_index);
     363             : }
     364             : 
     365             : template <typename Derived>
     366           4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
     367           4 :     return kets.begin();
     368             : }
     369             : 
     370             : template <typename Derived>
     371           4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
     372           4 :     return kets.end();
     373             : }
     374             : 
     375             : template <typename Derived>
     376           8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
     377             : 
     378             : template <typename Derived>
     379          80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
     380          80 :     return other.it != it;
     381             : }
     382             : 
     383             : template <typename Derived>
     384          76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
     385          76 :     return *it;
     386             : }
     387             : 
     388             : template <typename Derived>
     389          76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
     390          76 :     ++it;
     391          76 :     return *this;
     392             : }
     393             : 
     394             : template <typename Derived>
     395       18049 : size_t Basis<Derived>::get_number_of_states() const {
     396       18049 :     return coefficients.matrix.cols();
     397             : }
     398             : 
     399             : template <typename Derived>
     400        6855 : size_t Basis<Derived>::get_number_of_kets() const {
     401        6855 :     return coefficients.matrix.rows();
     402             : }
     403             : 
     404             : template <typename Derived>
     405             : const Transformation<typename Basis<Derived>::scalar_t> &
     406         567 : Basis<Derived>::get_transformation() const {
     407         567 :     return coefficients;
     408             : }
     409             : 
     410             : template <typename Derived>
     411             : Transformation<typename Basis<Derived>::scalar_t>
     412           0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
     413           0 :     Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
     414             :                                              static_cast<Eigen::Index>(coefficients.matrix.rows())},
     415             :                                             {TransformationType::ROTATE}};
     416             : 
     417           0 :     std::vector<Eigen::Triplet<scalar_t>> entries;
     418             : 
     419           0 :     for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
     420           0 :         real_t f = kets[idx_initial]->get_quantum_number_f();
     421           0 :         real_t m_initial = kets[idx_initial]->get_quantum_number_m();
     422             : 
     423           0 :         assert(2 * f == std::floor(2 * f) && f >= 0);
     424           0 :         assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
     425             : 
     426           0 :         for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
     427             :              ++m_final) {
     428           0 :             auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
     429             :                                                                    beta, gamma);
     430           0 :             size_t idx_final = get_ket_index_from_ket(
     431           0 :                 kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
     432           0 :             entries.emplace_back(idx_final, idx_initial, val);
     433             :         }
     434             :     }
     435             : 
     436           0 :     transformation.matrix.setFromTriplets(entries.begin(), entries.end());
     437           0 :     transformation.matrix.makeCompressed();
     438             : 
     439           0 :     return transformation;
     440           0 : }
     441             : 
     442             : template <typename Derived>
     443           1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
     444           1 :     perform_sorter_checks(labels);
     445             : 
     446             :     // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
     447           1 :     if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
     448           2 :         labels.end()) {
     449           0 :         throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
     450             :                                     "the energy. Use an energy operator instead.");
     451             :     }
     452             : 
     453             :     // Initialize transformation
     454           1 :     Sorting transformation;
     455           1 :     transformation.matrix.resize(coefficients.matrix.cols());
     456           1 :     transformation.matrix.setIdentity();
     457             : 
     458             :     // Get the sorter
     459           1 :     get_sorter_without_checks(labels, transformation);
     460             : 
     461             :     // Check if all labels have been used for sorting
     462           1 :     if (labels != transformation.transformation_type) {
     463           0 :         throw std::invalid_argument("The states could not be sorted by all the requested labels.");
     464             :     }
     465             : 
     466           1 :     return transformation;
     467           0 : }
     468             : 
     469             : template <typename Derived>
     470             : std::vector<IndicesOfBlock>
     471           1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
     472           1 :     perform_sorter_checks(labels);
     473             : 
     474           1 :     std::set<TransformationType> unique_labels(labels.begin(), labels.end());
     475           1 :     perform_blocks_checks(unique_labels);
     476             : 
     477             :     // Get the blocks
     478           1 :     IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
     479           1 :     get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
     480             : 
     481           2 :     return blocks_creator.create();
     482           1 : }
     483             : 
     484             : template <typename Derived>
     485         515 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
     486             :                                                Sorting &transformation) const { // Sorts the index array of transformation.matrix by `labels` and appends the labels to transformation.transformation_type; throws std::invalid_argument if the inspected state is unlabeled. The labels themselves are not validated here (unknown labels hit std::abort()).
     487         515 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon(); // Tolerance below which two m or f values count as equal
     488             :     // The indices stored in transformation.matrix are sorted in place. This function never initializes them, so the caller presumably passes in a valid starting permutation (TODO confirm).
     489         515 :     int *perm_begin = transformation.matrix.indices().data();
     490         515 :     int *perm_end = perm_begin + coefficients.matrix.cols();
     491         515 :     const int *perm_back = perm_end - 1; // Last entry of the range. NOTE(review): points before the buffer if there are no columns - presumably callers guarantee at least one state.
     492             : 
     493             :     // Sort the indices lexicographically by the requested labels (earlier labels take precedence); std::stable_sort keeps the existing order of ties
     494      564260 :     std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
     495      254827 :         for (const auto &label : labels) {
     496      166142 :             switch (label) {
     497       30703 :             case TransformationType::SORT_BY_PARITY:
     498       30703 :                 if (state_index_to_parity[a] != state_index_to_parity[b]) {
     499       47592 :                     return state_index_to_parity[a] < state_index_to_parity[b];
     500             :                 }
     501       20305 :                 break;
     502      135439 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     503      270878 :                 if (std::abs(state_index_to_quantum_number_m[a] -
     504      270878 :                              state_index_to_quantum_number_m[b]) > numerical_precision) {
     505       37194 :                     return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
     506             :                 }
     507       98245 :                 break;
     508           0 :             case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     509           0 :                 if (std::abs(state_index_to_quantum_number_f[a] -
     510           0 :                              state_index_to_quantum_number_f[b]) > numerical_precision) {
     511           0 :                     return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
     512             :                 }
     513           0 :                 break;
     514           0 :             case TransformationType::SORT_BY_KET:
     515           0 :                 if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
     516           0 :                     return state_index_to_ket_index[a] < state_index_to_ket_index[b];
     517             :                 }
     518           0 :                 break;
     519           0 :             default:
     520           0 :                 std::abort(); // Can't happen because of previous checks
     521             :             }
     522             :         }
     523       88685 :         return false; // Equal under every label, so neither element precedes the other
     524             :     });
     525             : 
     526             :     // Reject unlabeled states and record the applied sort labels. Only the last sorted entry is inspected, which apparently relies on the "unlabeled" markers (Parity::UNKNOWN, the numeric max values) sorting after every valid value. NOTE(review): with several labels the last entry is only guaranteed to hold the maximum of the first label - confirm that later labels are validated as intended.
     527        1069 :     for (const auto &label : labels) {
     528         554 :         switch (label) {
     529          41 :         case TransformationType::SORT_BY_PARITY:
     530          41 :             if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
     531           0 :                 throw std::invalid_argument(
     532             :                     "States cannot be labeled and thus not sorted by the parity.");
     533             :             }
     534          41 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
     535          41 :             break;
     536         513 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
     537         513 :             if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
     538           0 :                 throw std::invalid_argument(
     539             :                     "States cannot be labeled and thus not sorted by the quantum number m.");
     540             :             }
     541         513 :             transformation.transformation_type.push_back(
     542         513 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_M);
     543         513 :             break;
     544           0 :         case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
     545           0 :             if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
     546           0 :                 throw std::invalid_argument(
     547             :                     "States cannot be labeled and thus not sorted by the quantum number f.");
     548             :             }
     549           0 :             transformation.transformation_type.push_back(
     550           0 :                 TransformationType::SORT_BY_QUANTUM_NUMBER_F);
     551           0 :             break;
     552           0 :         case TransformationType::SORT_BY_KET:
     553           0 :             if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
     554           0 :                 throw std::invalid_argument(
     555             :                     "States cannot be labeled and thus not sorted by kets.");
     556             :             }
     557           0 :             transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
     558           0 :             break;
     559           0 :         default:
     560           0 :             std::abort(); // Can't happen because of previous checks
     561             :         }
     562             :     }
     563         515 : }
     564             : 
     565             : template <typename Derived> // Walks over all states in their current order and calls blocks_creator.add(i) for every state index i at which a label from unique_labels differs from the value of state i-1.
     566         515 : void Basis<Derived>::get_indices_of_blocks_without_checks(
     567             :     const std::set<TransformationType> &unique_labels,
     568             :     IndicesOfBlocksCreator &blocks_creator) const {
     569         515 :     constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon(); // Tolerance for treating consecutive f or m values as equal
     570             :     // Seed the "previous" values with state 0 so that i = 0 never triggers add(). NOTE(review): element 0 is read unconditionally - presumably callers guarantee a non-empty basis.
     571         515 :     auto last_quantum_number_f = state_index_to_quantum_number_f[0];
     572         515 :     auto last_quantum_number_m = state_index_to_quantum_number_m[0];
     573         515 :     auto last_parity = state_index_to_parity[0];
     574         515 :     auto last_ket = state_index_to_ket_index[0];
     575             :     // At most one add(i) per state: the first label that changed wins and the remaining labels are skipped via break.
     576       29789 :     for (int i = 0; i < coefficients.matrix.cols(); ++i) {
     577       65024 :         for (auto label : unique_labels) {
     578       36283 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
     579           0 :                 std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
     580             :                     numerical_precision) {
     581           0 :                 blocks_creator.add(i);
     582           0 :                 break;
     583             :             }
     584       65377 :             if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
     585       29094 :                 std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
     586             :                     numerical_precision) {
     587         389 :                 blocks_creator.add(i);
     588         389 :                 break;
     589             :             }
     590       43083 :             if (label == TransformationType::SORT_BY_PARITY &&
     591        7189 :                 state_index_to_parity[i] != last_parity) {
     592         144 :                 blocks_creator.add(i);
     593         144 :                 break;
     594             :             }
     595       35750 :             if (label == TransformationType::SORT_BY_KET &&
     596       35750 :                 state_index_to_ket_index[i] != last_ket &&
     597           0 :                 last_ket != std::numeric_limits<int>::max()) { // No boundary is reported when the previous state carries the int-max ket index (apparently the "no ket" marker)
     598           0 :                 blocks_creator.add(i);
     599           0 :                 break;
     600             :             }
     601             :         }
     602       29274 :         last_quantum_number_f = state_index_to_quantum_number_f[i]; // Remember this state's labels for the comparison in the next iteration
     603       29274 :         last_quantum_number_m = state_index_to_quantum_number_m[i];
     604       29274 :         last_parity = state_index_to_parity[i];
     605       29274 :         last_ket = state_index_to_ket_index[i];
     606             :     }
     607         515 : }
     608             : 
     609             : template <typename Derived> // Returns a copy of this basis whose coefficients are multiplied by transformation.matrix and whose per-state labels are reordered according to transformation.matrix.indices().
     610        1392 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
     611             :     // Create a copy of the current object
     612        1392 :     auto transformed = std::make_shared<Derived>(derived());
     613             :     // Without states there is nothing to reorder; the unchanged copy is returned.
     614        1392 :     if (coefficients.matrix.cols() == 0) {
     615           0 :         return transformed;
     616             :     }
     617             : 
     618             :     // Right-multiply the coefficients by the sorting matrix and overwrite the stored transformation_type
     619        1392 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     620        1392 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     621             :     // The per-state labels follow the same reordering: new entry i is taken from old entry indices()[i].
     622        1392 :     transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
     623        1392 :     transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
     624        1392 :     transformed->state_index_to_parity.resize(transformation.matrix.size());
     625        1392 :     transformed->state_index_to_ket_index.resize(transformation.matrix.size());
     626             : 
     627      138205 :     for (int i = 0; i < transformation.matrix.size(); ++i) {
     628      136813 :         transformed->state_index_to_quantum_number_f[i] =
     629      136813 :             state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
     630      136813 :         transformed->state_index_to_quantum_number_m[i] =
     631      136813 :             state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
     632      136813 :         transformed->state_index_to_parity[i] =
     633      136813 :             state_index_to_parity[transformation.matrix.indices()[i]];
     634             : 
     635      136813 :         size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
     636      136813 :         transformed->state_index_to_ket_index[i] = ket_index;
     637      136813 :         if (ket_index != std::numeric_limits<int>::max()) { // The reverse map is only updated for states whose ket index is not the int-max value (apparently the "no ket" marker)
     638      136813 :             transformed->ket_index_to_state_index[ket_index] = i;
     639             :         }
     640             :     }
     641             : 
     642        1392 :     return transformed;
     643        1392 : }
     644             : 
     645             : template <typename Derived> // Returns a copy of this basis with the coefficients multiplied by transformation.matrix and with the per-state labels (f, m, parity, ket index) re-derived for the new states. Throws std::invalid_argument if a rotation is combined with other transformations and std::runtime_error if a rotation is applied to an object carrying a non-sorting transformation.
     646             : std::shared_ptr<const Derived>
     647         566 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
     648             :     // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
     649             :     // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
     650         566 :     real_t numerical_precision = 0.001; // Threshold on |val^2 - sq| (computed below) under which a label is treated as conserved
     651             : 
     652             :     // If the transformation is a rotation, it should be a rotation and nothing else
     653         566 :     bool is_rotation = false;
     654        1132 :     for (auto t : transformation.transformation_type) {
     655         566 :         if (t == TransformationType::ROTATE) {
     656           0 :             is_rotation = true;
     657           0 :             break;
     658             :         }
     659             :     }
     660         566 :     if (is_rotation && transformation.transformation_type.size() != 1) {
     661           0 :         throw std::invalid_argument("A rotation can not be combined with other transformations.");
     662             :     }
     663             : 
     664             :     // To apply a rotation, the object must only be sorted but other transformations are not allowed
     665         566 :     if (is_rotation) {
     666           0 :         for (auto t : coefficients.transformation_type) {
     667           0 :             if (!utils::is_sorting(t)) {
     668           0 :                 throw std::runtime_error(
     669             :                     "If the object was transformed by a different transformation "
     670             :                     "than sorting, it can not be rotated.");
     671             :             }
     672             :         }
     673             :     }
     674             : 
     675             :     // Create a copy of the current object
     676         566 :     auto transformed = std::make_shared<Derived>(derived());
     677             :     // Without states there is nothing to transform; the plain copy is returned.
     678         566 :     if (coefficients.matrix.cols() == 0) {
     679           0 :         return transformed;
     680             :     }
     681             : 
     682             :     // Apply the transformation
     683             :     // If a quantum number turns out to be conserved by the transformation, it will be
     684             :     // rounded to the nearest half integer to avoid loss of numerical precision.
     685         566 :     transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
     686         566 :     transformed->coefficients.transformation_type = transformation.transformation_type;
     687             :     // probs(i, j) = |transformation.matrix(j, i)|^2, i.e. the weight of old state j in new state i.
     688         566 :     Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
     689             :     // For each label q below: val = probs * q and sq = probs * q^2 (element-wise square); diff = |val^2 - sq| vanishes when all contributing old states share one value of q, provided each row of probs sums to one (presumably the transformation is unitary - TODO confirm).
     690             :     { // Quantum number f
     691        1132 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
     692         566 :                                                             state_index_to_quantum_number_f.size());
     693         566 :         Eigen::VectorX<real_t> val = probs * map;
     694         566 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     695         566 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     696         566 :         transformed->state_index_to_quantum_number_f.resize(probs.rows());
     697             : 
     698       27554 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
     699       26988 :             if (diff[i] < numerical_precision) {
     700        4727 :                 transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2; // Round to the nearest half integer
     701             :             } else {
     702       22261 :                 transformed->state_index_to_quantum_number_f[i] =
     703       22261 :                     std::numeric_limits<real_t>::max(); // f is not well defined for this state: store the max value as marker
     704       22261 :                 transformed->_has_quantum_number_f = false; // A single unlabeled state clears the flag for the whole object
     705             :             }
     706             :         }
     707         566 :     }
     708             : 
     709             :     { // Quantum number m, same procedure as for f
     710        1132 :         auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
     711         566 :                                                             state_index_to_quantum_number_m.size());
     712         566 :         Eigen::VectorX<real_t> val = probs * map;
     713         566 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     714         566 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     715         566 :         transformed->state_index_to_quantum_number_m.resize(probs.rows());
     716             : 
     717       27554 :         for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
     718       26988 :             if (diff[i] < numerical_precision) {
     719       24330 :                 transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
     720             :             } else {
     721        2658 :                 transformed->state_index_to_quantum_number_m[i] =
     722        2658 :                     std::numeric_limits<real_t>::max();
     723        2658 :                 transformed->_has_quantum_number_m = false;
     724             :             }
     725             :         }
     726         566 :     }
     727             : 
     728             :     { // Parity, handled numerically through the underlying integer value of the enum
     729             :         using utype = std::underlying_type<Parity>::type;
     730         566 :         Eigen::VectorX<real_t> map(state_index_to_parity.size());
     731       32652 :         for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
     732       32086 :             map[i] = static_cast<utype>(state_index_to_parity[i]);
     733             :         }
     734         566 :         Eigen::VectorX<real_t> val = probs * map;
     735         566 :         Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
     736         566 :         Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
     737         566 :         transformed->state_index_to_parity.resize(probs.rows());
     738             : 
     739       27554 :         for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
     740       26988 :             if (diff[i] < numerical_precision) {
     741       22058 :                 transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i])); // Round back to the integer enum value
     742             :             } else {
     743        4930 :                 transformed->state_index_to_parity[i] = Parity::UNKNOWN;
     744        4930 :                 transformed->_has_parity = false;
     745             :             }
     746             :         }
     747         566 :     }
     748             : 
     749             :     {
     750             :         // In the following, we obtain a bijective map between state index and ket index.
     751             : 
     752             :         // Find the maximum squared magnitude in each row (ket) and column (state) of the transformed coefficients
     753         566 :         std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
     754         566 :         std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
     755       32652 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     756       32086 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     757       32086 :                      transformed->coefficients.matrix, row);
     758      601942 :                  it; ++it) {
     759      569856 :                 real_t val = std::pow(std::abs(it.value()), 2);
     760      569856 :                 max_in_row[row] = std::max(max_in_row[row], val);
     761      569856 :                 max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
     762             :             }
     763             :         }
     764             : 
     765             :         // Use the maximum values to define a cost for a sub-optimal mapping; the cost is zero when an entry is the largest of both its row and its column
     766         566 :         std::vector<real_t> costs;
     767         566 :         std::vector<std::pair<int, int>> mappings;
     768         566 :         costs.reserve(transformed->coefficients.matrix.nonZeros());
     769         566 :         mappings.reserve(transformed->coefficients.matrix.nonZeros());
     770       32652 :         for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
     771       32086 :             for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
     772       32086 :                      transformed->coefficients.matrix, row);
     773      601942 :                  it; ++it) {
     774      569856 :                 real_t val = std::pow(std::abs(it.value()), 2);
     775      569856 :                 real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
     776      569856 :                 costs.push_back(cost);
     777      569856 :                 mappings.push_back({row, it.col()});
     778             :             }
     779             :         }
     780             : 
     781             :         // Obtain from the costs in which order the mappings should be considered (cheapest first; std::sort does not fix the order of equal costs)
     782         566 :         std::vector<size_t> order(costs.size());
     783         566 :         std::iota(order.begin(), order.end(), 0);
     784         566 :         std::sort(order.begin(), order.end(),
     785     7351168 :                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
     786             : 
     787             :         // Fill ket_index_to_state_index with invalid values as there can be more kets than states
     788         566 :         std::fill(transformed->ket_index_to_state_index.begin(),
     789         566 :                   transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
     790             :         // NOTE(review): state_index_to_ket_index is neither resized nor reset in this function, so states that stay unassigned below keep the value copied from the original object - confirm this is intended.
     791             :         // Generate the bijective map greedily: accept mappings in order of increasing cost, using every ket and every state at most once
     792         566 :         std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
     793         566 :         std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
     794         566 :         int num_used = 0;
     795       85793 :         for (size_t idx : order) {
     796       85789 :             int row = mappings[idx].first;  // corresponds to the ket index
     797       85789 :             int col = mappings[idx].second; // corresponds to the state index
     798       85789 :             if (!row_used[row] && !col_used[col]) {
     799       26986 :                 row_used[row] = true;
     800       26986 :                 col_used[col] = true;
     801       26986 :                 num_used++;
     802       26986 :                 transformed->state_index_to_ket_index[col] = row;
     803       26986 :                 transformed->ket_index_to_state_index[row] = col;
     804             :             }
     805       85789 :             if (num_used == transformed->coefficients.matrix.cols()) { // Stop as soon as every state has been assigned a ket
     806         562 :                 break;
     807             :             }
     808             :         }
     809         566 :         if (num_used != transformed->coefficients.matrix.cols()) { // Some states found no free ket among their nonzero coefficients; this is only logged, not treated as an error
     810           4 :             SPDLOG_WARN("A bijective map between states and kets could not be found.");
     811             :         }
     812         566 :     }
     813             : 
     814         566 :     return transformed;
     815         566 : }
     816             : 
     817             : template <typename Derived> // Hashes the ket that k points to by delegating to ket_t::hash; k is dereferenced without a null check.
     818      196750 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
     819      196750 :     return typename ket_t::hash()(*k);
     820             : }
     821             : 
     822             : template <typename Derived> // Compares the pointed-to kets with operator== instead of comparing the pointers; both pointers are dereferenced without a null check.
     823        1104 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
     824             :                                           const std::shared_ptr<const ket_t> &rhs) const {
     825        1104 :     return *lhs == *rhs;
     826             : }
     827             : 
     828             : // Explicit instantiations: Basis is instantiated here for the atom and pair bases, each with double and std::complex<double> as template argument
     829             : template class Basis<BasisAtom<double>>;
     830             : template class Basis<BasisAtom<std::complex<double>>>;
     831             : template class Basis<BasisPair<double>>;
     832             : template class Basis<BasisPair<std::complex<double>>>;
     833             : } // namespace pairinteraction

Generated by: LCOV version 1.16