Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "pairinteraction/operator/Operator.hpp" 5 : 6 : #include "pairinteraction/basis/BasisAtom.hpp" 7 : #include "pairinteraction/basis/BasisPair.hpp" 8 : #include "pairinteraction/enums/OperatorType.hpp" 9 : #include "pairinteraction/enums/TransformationType.hpp" 10 : #include "pairinteraction/ket/KetAtom.hpp" 11 : #include "pairinteraction/ket/KetPair.hpp" 12 : #include "pairinteraction/operator/OperatorAtom.hpp" 13 : #include "pairinteraction/operator/OperatorPair.hpp" 14 : #include "pairinteraction/utils/eigen_assertion.hpp" 15 : 16 : #include <Eigen/SparseCore> 17 : #include <memory> 18 : 19 : namespace pairinteraction { 20 : template <typename Derived> 21 1093 : Operator<Derived>::Operator(std::shared_ptr<const basis_t> basis) : basis(std::move(basis)) { 22 1093 : this->matrix = Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>( 23 1093 : this->basis->get_number_of_states(), this->basis->get_number_of_states()); 24 1093 : } 25 : 26 : template <typename Derived> 27 281 : void Operator<Derived>::initialize_as_energy_operator() { 28 281 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> tmp(this->basis->get_number_of_kets(), 29 281 : this->basis->get_number_of_kets()); 30 281 : tmp.reserve(Eigen::VectorXi::Constant(this->basis->get_number_of_kets(), 1)); 31 281 : size_t idx = 0; 32 47174 : for (const auto &ket : this->basis->get_kets()) { 33 46893 : tmp.insert(idx, idx) = ket->get_energy(); 34 46893 : ++idx; 35 : } 36 281 : tmp.makeCompressed(); 37 : 38 281 : this->matrix = 39 281 : this->basis->get_coefficients().adjoint() * tmp * this->basis->get_coefficients(); 40 281 : } 41 : 42 : template <typename Derived> 43 627 : void Operator<Derived>::initialize_from_matrix( 44 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &&matrix) { 45 1254 : if (static_cast<size_t>(matrix.rows()) != this->basis->get_number_of_states() || 46 627 : static_cast<size_t>(matrix.cols()) != this->basis->get_number_of_states()) { 47 0 : throw std::invalid_argument("The matrix has the wrong dimensions."); 48 : } 49 627 : this->matrix = std::move(matrix); 50 627 : } 51 : 52 : template <typename Derived> 53 366 : const Derived &Operator<Derived>::derived() const { 54 366 : return static_cast<const Derived &>(*this); 55 : } 56 : 57 : template <typename Derived> 58 183 : Derived &Operator<Derived>::derived_mutable() { 59 183 : return static_cast<Derived &>(*this); 60 : } 61 : 62 : template <typename Derived> 63 0 : std::shared_ptr<const typename Operator<Derived>::basis_t> Operator<Derived>::get_basis() const { 64 0 : return basis; 65 : } 66 : 67 : template <typename Derived> 68 1022 : std::shared_ptr<const typename Operator<Derived>::basis_t> &Operator<Derived>::get_basis() { 69 1022 : return basis; 70 : } 71 : 72 : template <typename Derived> 73 : const Eigen::SparseMatrix<typename Operator<Derived>::scalar_t, Eigen::RowMajor> & 74 0 : Operator<Derived>::get_matrix() const { 75 0 : return matrix; 76 : } 77 : 78 : template <typename Derived> 79 : Eigen::SparseMatrix<typename Operator<Derived>::scalar_t, Eigen::RowMajor> & 80 1448 : Operator<Derived>::get_matrix() { 81 1448 : return matrix; 82 : } 83 : 84 : template <typename Derived> 85 : const Transformation<typename Operator<Derived>::scalar_t> & 86 0 : Operator<Derived>::get_transformation() const { 87 0 : return basis->get_transformation(); 88 : } 89 : 90 : template <typename Derived> 91 : Transformation<typename Operator<Derived>::scalar_t> 92 0 : Operator<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const { 93 0 : return basis->get_rotator(alpha, beta, gamma); 94 : } 95 : 96 : template <typename Derived> 97 183 : Sorting Operator<Derived>::get_sorter(const std::vector<TransformationType> &labels) const { 98 183 : basis->perform_sorter_checks(labels); 99 : 100 : // Split labels into three parts (one before SORT_BY_ENERGY, one with SORT_BY_ENERGY, and one 101 : // after) 102 183 : auto it = std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY); 103 183 : std::vector<TransformationType> before_energy(labels.begin(), it); 104 183 : bool contains_energy = (it != labels.end()); 105 183 : std::vector<TransformationType> after_energy(contains_energy ? it + 1 : labels.end(), 106 : labels.end()); 107 : 108 : // Initialize transformation 109 183 : Sorting transformation; 110 183 : transformation.matrix.resize(matrix.rows()); 111 183 : transformation.matrix.setIdentity(); 112 : 113 : // Apply sorting for labels before SORT_BY_ENERGY 114 183 : if (!before_energy.empty()) { 115 71 : basis->get_sorter_without_checks(before_energy, transformation); 116 : } 117 : 118 : // Apply SORT_BY_ENERGY if present 119 183 : if (contains_energy) { 120 112 : std::vector<real_t> energies_of_states; 121 112 : energies_of_states.reserve(matrix.rows()); 122 9194 : for (int i = 0; i < matrix.rows(); ++i) { 123 9082 : energies_of_states.push_back(std::real(matrix.coeff(i, i))); 124 : } 125 : 126 112 : std::stable_sort( 127 112 : transformation.matrix.indices().data(), 128 112 : transformation.matrix.indices().data() + transformation.matrix.indices().size(), 129 42687 : [&](int i, int j) { return energies_of_states[i] < energies_of_states[j]; }); 130 : 131 112 : transformation.transformation_type.push_back(TransformationType::SORT_BY_ENERGY); 132 112 : } 133 : 134 : // Apply sorting for labels after SORT_BY_ENERGY 135 183 : if (!after_energy.empty()) { 136 0 : basis->get_sorter_without_checks(after_energy, transformation); 137 : } 138 : 139 : // Check if all labels have been used for sorting 140 183 : if (labels != transformation.transformation_type) { 141 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels."); 142 : } 143 : 144 366 : return transformation; 145 183 : } 146 : 147 : template <typename Derived> 148 : std::vector<IndicesOfBlock> 149 111 : Operator<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const { 150 111 : basis->perform_sorter_checks(labels); 151 : 152 111 : std::set<TransformationType> unique_labels(labels.begin(), labels.end()); 153 111 : basis->perform_blocks_checks(unique_labels); 154 : 155 : // Split labels into two parts (one with SORT_BY_ENERGY and one without) 156 111 : auto it = unique_labels.find(TransformationType::SORT_BY_ENERGY); 157 111 : bool contains_energy = (it != unique_labels.end()); 158 111 : if (contains_energy) { 159 0 : unique_labels.erase(it); 160 : } 161 : 162 : // Initialize blocks 163 111 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(matrix.rows())}); 164 : 165 : // Handle all labels except SORT_BY_ENERGY 166 111 : if (!unique_labels.empty()) { 167 71 : basis->get_indices_of_blocks_without_checks(unique_labels, blocks_creator); 168 : } 169 : 170 : // Handle SORT_BY_ENERGY if present 171 111 : if (contains_energy) { 172 0 : scalar_t last_energy = std::real(matrix.coeff(0, 0)); 173 0 : for (int i = 0; i < matrix.rows(); ++i) { 174 0 : if (std::real(matrix.coeff(i, i)) != last_energy) { 175 0 : blocks_creator.add(i); 176 0 : last_energy = std::real(matrix.coeff(i, i)); 177 : } 178 : } 179 : } 180 : 181 222 : return blocks_creator.create(); 182 111 : } 183 : 184 : template <typename Derived> 185 0 : Derived Operator<Derived>::transformed( 186 : const Transformation<typename Operator<Derived>::scalar_t> &transformation) const { 187 0 : auto transformed = derived(); 188 0 : if (matrix.cols() == 0) { 189 0 : return transformed; 190 : } 191 0 : transformed.matrix = transformation.matrix.adjoint() * matrix * transformation.matrix; 192 0 : transformed.basis = basis->transformed(transformation); 193 0 : return transformed; 194 0 : } 195 : 196 : template <typename Derived> 197 183 : Derived Operator<Derived>::transformed(const Sorting &transformation) const { 198 183 : auto transformed = derived(); 199 183 : if (matrix.cols() == 0) { 200 0 : return transformed; 201 : } 202 183 : transformed.matrix = matrix.twistedBy(transformation.matrix.inverse()); 203 183 : transformed.basis = basis->transformed(transformation); 204 183 : return transformed; 205 0 : } 206 : 207 : // Overloaded operators 208 : template <typename Derived> 209 183 : Derived operator*(const typename Operator<Derived>::scalar_t &lhs, const Operator<Derived> &rhs) { 210 183 : Derived result = rhs.derived(); 211 183 : result.matrix *= lhs; 212 183 : return result; 213 0 : } 214 : 215 : template <typename Derived> 216 0 : Derived operator*(const Operator<Derived> &lhs, const typename Operator<Derived>::scalar_t &rhs) { 217 0 : Derived result = lhs.derived(); 218 0 : result.matrix *= rhs; 219 0 : return result; 220 0 : } 221 : 222 : template <typename Derived> 223 0 : Derived operator/(const Operator<Derived> &lhs, const typename Operator<Derived>::scalar_t &rhs) { 224 0 : Derived result = lhs.derived(); 225 0 : result.matrix /= rhs; 226 0 : return result; 227 0 : } 228 : 229 : template <typename Derived> 230 5 : Derived &operator+=(Operator<Derived> &lhs, const Operator<Derived> &rhs) { 231 5 : if (lhs.basis != rhs.basis) { 232 0 : throw std::invalid_argument("The basis of the operators is not the same."); 233 : } 234 5 : lhs.matrix += rhs.matrix; 235 5 : return lhs.derived_mutable(); 236 : } 237 : 238 : template <typename Derived> 239 178 : Derived &operator-=(Operator<Derived> &lhs, const Operator<Derived> &rhs) { 240 178 : if (lhs.basis != rhs.basis) { 241 0 : throw std::invalid_argument("The basis of the operators is not the same."); 242 : } 243 178 : lhs.matrix -= rhs.matrix; 244 178 : return lhs.derived_mutable(); 245 : } 246 : 247 : template <typename Derived> 248 0 : Derived operator+(const Operator<Derived> &lhs, const Operator<Derived> &rhs) { 249 0 : if (lhs.basis != rhs.basis) { 250 0 : throw std::invalid_argument("The basis of the operators is not the same."); 251 : } 252 0 : Derived result = lhs.derived(); 253 0 : result.matrix += rhs.matrix; 254 0 : return result; 255 0 : } 256 : 257 : template <typename Derived> 258 0 : Derived operator-(const Operator<Derived> &lhs, const Operator<Derived> &rhs) { 259 0 : if (lhs.basis != rhs.basis) { 260 0 : throw std::invalid_argument("The basis of the operators is not the same."); 261 : } 262 0 : Derived result = lhs.derived(); 263 0 : result.matrix -= rhs.matrix; 264 0 : return result; 265 0 : } 266 : 267 : // Explicit instantiations 268 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage) 269 : #define INSTANTIATE_OPERATOR_HELPER(SCALAR, TYPE) \ 270 : template class Operator<TYPE<SCALAR>>; \ 271 : template TYPE<SCALAR> operator*(const SCALAR &lhs, const Operator<TYPE<SCALAR>> &rhs); \ 272 : template TYPE<SCALAR> operator*(const Operator<TYPE<SCALAR>> &lhs, const SCALAR &rhs); \ 273 : template TYPE<SCALAR> operator/(const Operator<TYPE<SCALAR>> &lhs, const SCALAR &rhs); \ 274 : template TYPE<SCALAR> &operator+=(Operator<TYPE<SCALAR>> &lhs, \ 275 : const Operator<TYPE<SCALAR>> &rhs); \ 276 : template TYPE<SCALAR> &operator-=(Operator<TYPE<SCALAR>> &lhs, \ 277 : const Operator<TYPE<SCALAR>> &rhs); \ 278 : template TYPE<SCALAR> operator+(const Operator<TYPE<SCALAR>> &lhs, \ 279 : const Operator<TYPE<SCALAR>> &rhs); \ 280 : template TYPE<SCALAR> operator-(const Operator<TYPE<SCALAR>> &lhs, \ 281 : const Operator<TYPE<SCALAR>> &rhs); 282 : #define INSTANTIATE_OPERATOR(SCALAR) \ 283 : INSTANTIATE_OPERATOR_HELPER(SCALAR, OperatorAtom) \ 284 : INSTANTIATE_OPERATOR_HELPER(SCALAR, OperatorPair) 285 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage) 286 : 287 : INSTANTIATE_OPERATOR(double) 288 : INSTANTIATE_OPERATOR(std::complex<double>) 289 : 290 : #undef INSTANTIATE_OPERATOR_HELPER 291 : #undef INSTANTIATE_OPERATOR 292 : 293 : } // namespace pairinteraction