LCOV - code coverage report
Current view: top level - src/operator - Operator.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 86 140 61.4 %
Date: 2025-05-02 21:49:55 Functions: 37 92 40.2 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/operator/Operator.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/enums/OperatorType.hpp"
       9             : #include "pairinteraction/enums/TransformationType.hpp"
      10             : #include "pairinteraction/ket/KetAtom.hpp"
      11             : #include "pairinteraction/ket/KetPair.hpp"
      12             : #include "pairinteraction/operator/OperatorAtom.hpp"
      13             : #include "pairinteraction/operator/OperatorPair.hpp"
      14             : #include "pairinteraction/utils/eigen_assertion.hpp"
      15             : 
      16             : #include <Eigen/SparseCore>
      17             : #include <memory>
      18             : 
      19             : namespace pairinteraction {
       20             : template <typename Derived>
                        // Constructs an operator on the given basis. The shared pointer is moved into the
                        // `basis` member, and `matrix` is set to a square sparse matrix without stored
                        // entries whose dimension is the basis' number of states.
                        //
                        // The body must go through `this->basis`: the constructor parameter `basis`
                        // shadows the member and has already been moved from in the initializer list.
       21        1093 : Operator<Derived>::Operator(std::shared_ptr<const basis_t> basis) : basis(std::move(basis)) {
       22        1093 :     this->matrix = Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(
       23        1093 :         this->basis->get_number_of_states(), this->basis->get_number_of_states());
       24        1093 : }
      25             : 
       26             : template <typename Derived>
                        // Overwrites `matrix` with the energy operator of the basis.
                        //
                        // A diagonal matrix is first built in the ket representation (entry i is
                        // get_energy() of the i-th ket returned by get_kets()), and is then sandwiched
                        // between the adjoint of the basis coefficients and the coefficients themselves.
                        // The product is only well-formed if get_coefficients() has one row per ket.
       27         281 : void Operator<Derived>::initialize_as_energy_operator() {
       28         281 :     Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> tmp(this->basis->get_number_of_kets(),
       29         281 :                                                        this->basis->get_number_of_kets());
                            // The matrix is diagonal, so exactly one entry per row is reserved up front.
       30         281 :     tmp.reserve(Eigen::VectorXi::Constant(this->basis->get_number_of_kets(), 1));
       31         281 :     size_t idx = 0;
       32       47174 :     for (const auto &ket : this->basis->get_kets()) {
       33       46893 :         tmp.insert(idx, idx) = ket->get_energy();
       34       46893 :         ++idx;
       35             :     }
       36         281 :     tmp.makeCompressed();
       37             : 
                            // Change of representation: coefficients^dagger * diag(energies) * coefficients.
       38         281 :     this->matrix =
       39         281 :         this->basis->get_coefficients().adjoint() * tmp * this->basis->get_coefficients();
       40         281 : }
      41             : 
       42             : template <typename Derived>
                        // Replaces the stored matrix by `matrix`, which is taken by rvalue reference and
                        // moved into the member.
                        //
                        // Throws std::invalid_argument if the number of rows or the number of columns
                        // differs from the basis' number of states. In that case the stored matrix is
                        // left untouched, because the check runs before the move.
                        //
                        // NOTE(review): std::invalid_argument comes from <stdexcept>, which is not among
                        // the includes visible at the top of this file; presumably it arrives
                        // transitively - consider including it directly.
       43         627 : void Operator<Derived>::initialize_from_matrix(
       44             :     Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &&matrix) {
                            // rows()/cols() are cast to size_t for the comparison with the state count.
       45        1254 :     if (static_cast<size_t>(matrix.rows()) != this->basis->get_number_of_states() ||
       46         627 :         static_cast<size_t>(matrix.cols()) != this->basis->get_number_of_states()) {
       47           0 :         throw std::invalid_argument("The matrix has the wrong dimensions.");
       48             :     }
       49         627 :     this->matrix = std::move(matrix);
       50         627 : }
      51             : 
       52             : template <typename Derived>
                        // Returns this object viewed as a const reference to the concrete type `Derived`.
                        // The static_cast is unchecked, so it is only valid if the object really is a
                        // `Derived`. NOTE(review): presumably guaranteed because `Derived` inherits from
                        // Operator<Derived> - confirm in Operator.hpp.
       53         366 : const Derived &Operator<Derived>::derived() const {
       54         366 :     return static_cast<const Derived &>(*this);
       55             : }
      56             : 
       57             : template <typename Derived>
                        // Non-const counterpart of derived(): returns this object as a mutable reference
                        // to the concrete type `Derived`. The static_cast is unchecked and therefore only
                        // valid if the object really is a `Derived`.
       58         183 : Derived &Operator<Derived>::derived_mutable() {
       59         183 :     return static_cast<Derived &>(*this);
       60             : }
      61             : 
       62             : template <typename Derived>
                        // Const overload: returns the shared pointer to the (const) basis by value, i.e.
                        // the caller receives its own shared-ownership handle to the same basis object.
       63           0 : std::shared_ptr<const typename Operator<Derived>::basis_t> Operator<Derived>::get_basis() const {
       64           0 :     return basis;
       65             : }
      66             : 
       67             : template <typename Derived>
                        // Non-const overload: returns a reference to the stored shared pointer itself.
                        // Through it the caller can make the operator point to a different basis; the
                        // basis object that is pointed to stays const.
       68        1022 : std::shared_ptr<const typename Operator<Derived>::basis_t> &Operator<Derived>::get_basis() {
       69        1022 :     return basis;
       70             : }
      71             : 
       72             : template <typename Derived>
                        // Const overload: read-only access to the stored row-major sparse matrix. The
                        // reference refers to the member, so it is only usable while this operator lives.
       73             : const Eigen::SparseMatrix<typename Operator<Derived>::scalar_t, Eigen::RowMajor> &
       74           0 : Operator<Derived>::get_matrix() const {
       75           0 :     return matrix;
       76             : }
      77             : 
       78             : template <typename Derived>
                        // Non-const overload: mutable access to the stored row-major sparse matrix.
                        // Nothing here re-validates the dimensions against the basis after the caller
                        // modifies the matrix through the returned reference.
       79             : Eigen::SparseMatrix<typename Operator<Derived>::scalar_t, Eigen::RowMajor> &
       80        1448 : Operator<Derived>::get_matrix() {
       81        1448 :     return matrix;
       82             : }
      83             : 
       84             : template <typename Derived>
                        // Forwards to basis->get_transformation() and returns its result by const
                        // reference; the operator adds no logic of its own.
       85             : const Transformation<typename Operator<Derived>::scalar_t> &
       86           0 : Operator<Derived>::get_transformation() const {
       87           0 :     return basis->get_transformation();
       88             : }
      89             : 
       90             : template <typename Derived>
                        // Forwards to basis->get_rotator(alpha, beta, gamma) and returns the resulting
                        // Transformation by value.
                        // NOTE(review): alpha, beta, gamma are presumably Euler angles - their convention
                        // and units are defined by the basis class, confirm there.
       91             : Transformation<typename Operator<Derived>::scalar_t>
       92           0 : Operator<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
       93           0 :     return basis->get_rotator(alpha, beta, gamma);
       94             : }
      95             : 
       96             : template <typename Derived>
                        // Builds the Sorting (a permutation plus the list of labels it realizes) that
                        // orders the states according to `labels`.
                        //
                        // SORT_BY_ENERGY is handled here, using the real part of the diagonal of `matrix`
                        // as the energy of each state. All other labels are delegated to the basis via
                        // get_sorter_without_checks(). The labels are applied in three stages: those
                        // before the first SORT_BY_ENERGY, SORT_BY_ENERGY itself, and those after it.
                        //
                        // Throws std::invalid_argument if the labels recorded in the resulting Sorting
                        // differ from the requested `labels`. perform_sorter_checks() is called first and
                        // presumably throws on invalid label combinations - confirm in the basis class.
                        //
                        // NOTE(review): std::find / std::stable_sort need <algorithm> and std::vector
                        // needs <vector>; neither is among the includes visible at the top of this file,
                        // so they presumably arrive transitively - consider including them directly.
       97         183 : Sorting Operator<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
       98         183 :     basis->perform_sorter_checks(labels);
       99             : 
      100             :     // Split labels into three parts (one before SORT_BY_ENERGY, one with SORT_BY_ENERGY, and one
      101             :     // after)
      102         183 :     auto it = std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY);
      103         183 :     std::vector<TransformationType> before_energy(labels.begin(), it);
      104         183 :     bool contains_energy = (it != labels.end());
                            // `it + 1` is only formed when `it` is a valid element, never on end().
      105         183 :     std::vector<TransformationType> after_energy(contains_energy ? it + 1 : labels.end(),
      106             :                                                  labels.end());
      107             : 
      108             :     // Initialize transformation
                            // Start from the identity permutation over all rows of the operator matrix.
      109         183 :     Sorting transformation;
      110         183 :     transformation.matrix.resize(matrix.rows());
      111         183 :     transformation.matrix.setIdentity();
      112             : 
      113             :     // Apply sorting for labels before SORT_BY_ENERGY
      114         183 :     if (!before_energy.empty()) {
      115          71 :         basis->get_sorter_without_checks(before_energy, transformation);
      116             :     }
      117             : 
      118             :     // Apply SORT_BY_ENERGY if present
      119         183 :     if (contains_energy) {
                                // Energies are the real parts of the diagonal entries, indexed by state.
      120         112 :         std::vector<real_t> energies_of_states;
      121         112 :         energies_of_states.reserve(matrix.rows());
      122        9194 :         for (int i = 0; i < matrix.rows(); ++i) {
      123        9082 :             energies_of_states.push_back(std::real(matrix.coeff(i, i)));
      124             :         }
      125             : 
                                // The sort reorders the permutation's index array in place. It is stable, so
                                // states with equal energy keep the relative order they had before this step
                                // (i.e. the order produced by the labels preceding SORT_BY_ENERGY).
      126         112 :         std::stable_sort(
      127         112 :             transformation.matrix.indices().data(),
      128         112 :             transformation.matrix.indices().data() + transformation.matrix.indices().size(),
      129       42687 :             [&](int i, int j) { return energies_of_states[i] < energies_of_states[j]; });
      130             : 
      131         112 :         transformation.transformation_type.push_back(TransformationType::SORT_BY_ENERGY);
      132         112 :     }
      133             : 
      134             :     // Apply sorting for labels after SORT_BY_ENERGY
      135         183 :     if (!after_energy.empty()) {
      136           0 :         basis->get_sorter_without_checks(after_energy, transformation);
      137             :     }
      138             : 
      139             :     // Check if all labels have been used for sorting
                            // Presumably get_sorter_without_checks() appends every label it applied to
                            // transformation_type, so a mismatch means a label was skipped - TODO confirm.
      140         183 :     if (labels != transformation.transformation_type) {
      141           0 :         throw std::invalid_argument("The states could not be sorted by all the requested labels.");
      142             :     }
      143             : 
      144         366 :     return transformation;
      145         183 : }
     146             : 
      147             : template <typename Derived>
                        // Returns the index ranges of the blocks into which the states split when grouped
                        // by `labels`.
                        //
                        // The labels are de-duplicated into a set. SORT_BY_ENERGY is removed from that
                        // set and handled here, by starting a new block wherever the real part of the
                        // diagonal of `matrix` changes from one row to the next. All remaining labels are
                        // delegated to the basis via get_indices_of_blocks_without_checks().
                        //
                        // perform_sorter_checks() and perform_blocks_checks() run first; presumably they
                        // throw on unsupported label combinations - confirm in the basis class.
                        //
                        // NOTE(review): the SORT_BY_ENERGY branch has zero hits in this coverage report,
                        // and three details of it deserve a second look:
                        //   * matrix.coeff(0, 0) is read before the loop, even when the matrix has zero
                        //     rows, which is an out-of-range access for an empty operator;
                        //   * `last_energy` is declared as scalar_t although it only ever holds the
                        //     result of std::real(), so real_t looks like the intended type;
                        //   * energies are compared with exact `!=`, so values that differ only by
                        //     rounding noise end up in separate blocks.
      148             : std::vector<IndicesOfBlock>
      149         111 : Operator<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
      150         111 :     basis->perform_sorter_checks(labels);
      151             : 
      152         111 :     std::set<TransformationType> unique_labels(labels.begin(), labels.end());
      153         111 :     basis->perform_blocks_checks(unique_labels);
      154             : 
      155             :     // Split labels into two parts (one with SORT_BY_ENERGY and one without)
      156         111 :     auto it = unique_labels.find(TransformationType::SORT_BY_ENERGY);
      157         111 :     bool contains_energy = (it != unique_labels.end());
      158         111 :     if (contains_energy) {
      159           0 :         unique_labels.erase(it);
      160             :     }
      161             : 
      162             :     // Initialize blocks
                            // The creator is seeded with the full index range [0, rows) of the matrix.
      163         111 :     IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(matrix.rows())});
      164             : 
      165             :     // Handle all labels except SORT_BY_ENERGY
      166         111 :     if (!unique_labels.empty()) {
      167          71 :         basis->get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
      168             :     }
      169             : 
      170             :     // Handle SORT_BY_ENERGY if present
      171         111 :     if (contains_energy) {
                                // Row i opens a new block whenever its diagonal energy differs from the
                                // energy of the previous row. Row 0 never does, since it seeds last_energy.
      172           0 :         scalar_t last_energy = std::real(matrix.coeff(0, 0));
      173           0 :         for (int i = 0; i < matrix.rows(); ++i) {
      174           0 :             if (std::real(matrix.coeff(i, i)) != last_energy) {
      175           0 :                 blocks_creator.add(i);
      176           0 :                 last_energy = std::real(matrix.coeff(i, i));
      177             :             }
      178             :         }
      179             :     }
      180             : 
      181         222 :     return blocks_creator.create();
      182         111 : }
     183             : 
      184             : template <typename Derived>
                        // Returns a copy of this operator expressed in the basis obtained by applying
                        // `transformation`; this operator itself is left unchanged.
                        //
                        // The copy's matrix becomes T^dagger * M * T (T = transformation.matrix) and its
                        // basis becomes basis->transformed(transformation).
                        //
                        // If the matrix has zero columns, an unmodified copy is returned: neither the
                        // matrix nor the basis of the copy is transformed in that case.
      185           0 : Derived Operator<Derived>::transformed(
      186             :     const Transformation<typename Operator<Derived>::scalar_t> &transformation) const {
      187           0 :     auto transformed = derived();
      188           0 :     if (matrix.cols() == 0) {
      189           0 :         return transformed;
      190             :     }
      191           0 :     transformed.matrix = transformation.matrix.adjoint() * matrix * transformation.matrix;
      192           0 :     transformed.basis = basis->transformed(transformation);
      193           0 :     return transformed;
      194           0 : }
     195             : 
      196             : template <typename Derived>
                        // Returns a copy of this operator with its states reordered by the permutation in
                        // `transformation`; this operator itself is left unchanged.
                        //
                        // The copy's basis becomes basis->transformed(transformation). If the matrix has
                        // zero columns, an unmodified copy is returned instead (matrix and basis of the
                        // copy both stay as they are).
      197         183 : Derived Operator<Derived>::transformed(const Sorting &transformation) const {
      198         183 :     auto transformed = derived();
      199         183 :     if (matrix.cols() == 0) {
      200           0 :         return transformed;
      201             :     }
                            // Rows and columns are permuted symmetrically. Per the Eigen documentation,
                            // twistedBy(Q) yields Q * M * Q^-1, so with Q = P^-1 this should equal
                            // P^-1 * M * P, the permutation analogue of T^dagger * M * T used by the
                            // Transformation overload - TODO confirm against the Eigen reference.
      202         183 :     transformed.matrix = matrix.twistedBy(transformation.matrix.inverse());
      203         183 :     transformed.basis = basis->transformed(transformation);
      204         183 :     return transformed;
      205           0 : }
     206             : 
     207             : // Overloaded operators
      208             : template <typename Derived>
                        // Scalar * operator: returns a copy of `rhs` (as the concrete type Derived) whose
                        // matrix has been multiplied in place by `lhs`. `rhs` itself is not modified.
      209         183 : Derived operator*(const typename Operator<Derived>::scalar_t &lhs, const Operator<Derived> &rhs) {
      210         183 :     Derived result = rhs.derived();
      211         183 :     result.matrix *= lhs;
      212         183 :     return result;
      213           0 : }
     214             : 
      215             : template <typename Derived>
                        // Operator * scalar: returns a copy of `lhs` (as the concrete type Derived) whose
                        // matrix has been multiplied in place by `rhs`. `lhs` itself is not modified.
      216           0 : Derived operator*(const Operator<Derived> &lhs, const typename Operator<Derived>::scalar_t &rhs) {
      217           0 :     Derived result = lhs.derived();
      218           0 :     result.matrix *= rhs;
      219           0 :     return result;
      220           0 : }
     221             : 
      222             : template <typename Derived>
                        // Operator / scalar: returns a copy of `lhs` (as the concrete type Derived) whose
                        // matrix has been divided in place by `rhs`. `lhs` itself is not modified.
                        // No check is made for a zero divisor.
      223           0 : Derived operator/(const Operator<Derived> &lhs, const typename Operator<Derived>::scalar_t &rhs) {
      224           0 :     Derived result = lhs.derived();
      225           0 :     result.matrix /= rhs;
      226           0 :     return result;
      227           0 : }
     228             : 
      229             : template <typename Derived>
                        // In-place addition: adds the matrix of `rhs` to the matrix of `lhs` and returns
                        // `lhs` as a reference to the concrete type Derived.
                        //
                        // Throws std::invalid_argument if the operators do not share a basis. `basis` is
                        // a shared pointer, so the comparison is by pointer identity: two operators
                        // built on distinct basis objects are rejected even if those bases hold the same
                        // content.
      230           5 : Derived &operator+=(Operator<Derived> &lhs, const Operator<Derived> &rhs) {
      231           5 :     if (lhs.basis != rhs.basis) {
      232           0 :         throw std::invalid_argument("The basis of the operators is not the same.");
      233             :     }
      234           5 :     lhs.matrix += rhs.matrix;
      235           5 :     return lhs.derived_mutable();
      236             : }
     237             : 
      238             : template <typename Derived>
                        // In-place subtraction: subtracts the matrix of `rhs` from the matrix of `lhs`
                        // and returns `lhs` as a reference to the concrete type Derived.
                        //
                        // Throws std::invalid_argument if the operators do not share a basis. `basis` is
                        // a shared pointer, so the comparison is by pointer identity: two operators
                        // built on distinct basis objects are rejected even if those bases hold the same
                        // content.
      239         178 : Derived &operator-=(Operator<Derived> &lhs, const Operator<Derived> &rhs) {
      240         178 :     if (lhs.basis != rhs.basis) {
      241           0 :         throw std::invalid_argument("The basis of the operators is not the same.");
      242             :     }
      243         178 :     lhs.matrix -= rhs.matrix;
      244         178 :     return lhs.derived_mutable();
      245             : }
     246             : 
      247             : template <typename Derived>
                        // Binary addition: returns a copy of `lhs` (as the concrete type Derived) whose
                        // matrix is the sum of both operand matrices. Neither operand is modified.
                        //
                        // Throws std::invalid_argument if the operators do not share a basis. The
                        // comparison is on the shared pointers, i.e. by pointer identity rather than by
                        // basis content.
      248           0 : Derived operator+(const Operator<Derived> &lhs, const Operator<Derived> &rhs) {
      249           0 :     if (lhs.basis != rhs.basis) {
      250           0 :         throw std::invalid_argument("The basis of the operators is not the same.");
      251             :     }
      252           0 :     Derived result = lhs.derived();
      253           0 :     result.matrix += rhs.matrix;
      254           0 :     return result;
      255           0 : }
     256             : 
      257             : template <typename Derived>
                        // Binary subtraction: returns a copy of `lhs` (as the concrete type Derived)
                        // whose matrix is the matrix of `lhs` minus the matrix of `rhs`. Neither operand
                        // is modified.
                        //
                        // Throws std::invalid_argument if the operators do not share a basis. The
                        // comparison is on the shared pointers, i.e. by pointer identity rather than by
                        // basis content.
      258           0 : Derived operator-(const Operator<Derived> &lhs, const Operator<Derived> &rhs) {
      259           0 :     if (lhs.basis != rhs.basis) {
      260           0 :         throw std::invalid_argument("The basis of the operators is not the same.");
      261             :     }
      262           0 :     Derived result = lhs.derived();
      263           0 :     result.matrix -= rhs.matrix;
      264           0 :     return result;
      265           0 : }
     266             : 
      267             : // Explicit instantiations
                        // The member and operator templates above are defined in this source file, so
                        // every specialization that is used elsewhere has to be instantiated explicitly
                        // here.
                        //
                        // INSTANTIATE_OPERATOR_HELPER(SCALAR, TYPE) instantiates the class
                        // Operator<TYPE<SCALAR>> together with the free operators *, /, +=, -=, + and -
                        // for it. INSTANTIATE_OPERATOR(SCALAR) applies the helper to both OperatorAtom
                        // and OperatorPair. It is invoked for double and std::complex<double>, and both
                        // macros are undefined again afterwards so that they do not leak out of this
                        // file.
                        //
                        // NOTE(review): std::complex is used below, but <complex> is not among the
                        // includes visible at the top of this file; presumably it arrives transitively.
      268             : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
      269             : #define INSTANTIATE_OPERATOR_HELPER(SCALAR, TYPE)                                                  \
      270             :     template class Operator<TYPE<SCALAR>>;                                                         \
      271             :     template TYPE<SCALAR> operator*(const SCALAR &lhs, const Operator<TYPE<SCALAR>> &rhs);         \
      272             :     template TYPE<SCALAR> operator*(const Operator<TYPE<SCALAR>> &lhs, const SCALAR &rhs);         \
      273             :     template TYPE<SCALAR> operator/(const Operator<TYPE<SCALAR>> &lhs, const SCALAR &rhs);         \
      274             :     template TYPE<SCALAR> &operator+=(Operator<TYPE<SCALAR>> &lhs,                                 \
      275             :                                       const Operator<TYPE<SCALAR>> &rhs);                          \
      276             :     template TYPE<SCALAR> &operator-=(Operator<TYPE<SCALAR>> &lhs,                                 \
      277             :                                       const Operator<TYPE<SCALAR>> &rhs);                          \
      278             :     template TYPE<SCALAR> operator+(const Operator<TYPE<SCALAR>> &lhs,                             \
      279             :                                     const Operator<TYPE<SCALAR>> &rhs);                            \
      280             :     template TYPE<SCALAR> operator-(const Operator<TYPE<SCALAR>> &lhs,                             \
      281             :                                     const Operator<TYPE<SCALAR>> &rhs);
      282             : #define INSTANTIATE_OPERATOR(SCALAR)                                                               \
      283             :     INSTANTIATE_OPERATOR_HELPER(SCALAR, OperatorAtom)                                              \
      284             :     INSTANTIATE_OPERATOR_HELPER(SCALAR, OperatorPair)
      285             : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
      286             : 
      287             : INSTANTIATE_OPERATOR(double)
      288             : INSTANTIATE_OPERATOR(std::complex<double>)
      289             : 
      290             : #undef INSTANTIATE_OPERATOR_HELPER
      291             : #undef INSTANTIATE_OPERATOR
     292             : 
     293             : } // namespace pairinteraction

Generated by: LCOV version 1.16