Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "pairinteraction/diagonalize/DiagonalizerEigen.hpp" 5 : 6 : #include "pairinteraction/enums/FloatType.hpp" 7 : #include "pairinteraction/utils/eigen_assertion.hpp" 8 : #include "pairinteraction/utils/eigen_compat.hpp" 9 : #include "pairinteraction/utils/traits.hpp" 10 : 11 : #include <Eigen/Dense> 12 : #include <Eigen/Eigenvalues> 13 : #include <cmath> 14 : 15 : namespace pairinteraction { 16 : 17 : template <typename Scalar> 18 34 : DiagonalizerEigen<Scalar>::DiagonalizerEigen(FloatType float_type) 19 34 : : DiagonalizerInterface<Scalar>(float_type) {} 20 : 21 : template <typename Scalar> 22 : template <typename ScalarLim> 23 : EigenSystemH<Scalar> 24 314 : DiagonalizerEigen<Scalar>::dispatch_eigh(const Eigen::SparseMatrix<Scalar, Eigen::RowMajor> &matrix, 25 : double rtol) const { 26 : using real_t = typename traits::NumTraits<Scalar>::real_t; 27 314 : int dim = matrix.rows(); 28 : 29 : // Subtract the mean of the diagonal elements from the diagonal 30 314 : real_t shift{}; 31 314 : Eigen::MatrixX<ScalarLim> shifted_matrix = 32 : this->template subtract_mean<ScalarLim>(matrix, shift, rtol); 33 : 34 : // Diagonalize the shifted matrix 35 314 : Eigen::SelfAdjointEigenSolver<Eigen::MatrixX<ScalarLim>> eigensolver; 36 314 : eigensolver.compute(shifted_matrix); 37 : 38 628 : return {eigensolver.eigenvectors() 39 314 : .sparseView(1, 0.5 * rtol / std::sqrt(dim)) 40 312 : .template cast<Scalar>(), 41 314 : this->add_mean(eigensolver.eigenvalues(), shift)}; 42 628 : } 43 : 44 : template <typename Scalar> 45 : EigenSystemH<Scalar> 46 314 : DiagonalizerEigen<Scalar>::eigh(const Eigen::SparseMatrix<Scalar, Eigen::RowMajor> &matrix, 47 : double rtol) const { 48 314 : switch (this->float_type) { 49 2 : case FloatType::FLOAT32: 50 2 : return dispatch_eigh<traits::restricted_t<Scalar, FloatType::FLOAT32>>(matrix, rtol); 51 312 : case FloatType::FLOAT64: 52 312 : return dispatch_eigh<traits::restricted_t<Scalar, FloatType::FLOAT64>>(matrix, rtol); 53 0 : default: 54 0 : throw std::invalid_argument("Unsupported floating point precision."); 55 : } 56 : } 57 : 58 : // Explicit instantiations 59 : template class DiagonalizerEigen<double>; 60 : template class DiagonalizerEigen<std::complex<double>>; 61 : } // namespace pairinteraction