Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "pairinteraction/operator/Operator.hpp" 5 : 6 : #include "pairinteraction/basis/BasisAtom.hpp" 7 : #include "pairinteraction/basis/BasisPair.hpp" 8 : #include "pairinteraction/enums/OperatorType.hpp" 9 : #include "pairinteraction/enums/TransformationType.hpp" 10 : #include "pairinteraction/ket/KetAtom.hpp" 11 : #include "pairinteraction/ket/KetPair.hpp" 12 : #include "pairinteraction/operator/OperatorAtom.hpp" 13 : #include "pairinteraction/operator/OperatorPair.hpp" 14 : #include "pairinteraction/utils/eigen_assertion.hpp" 15 : 16 : #include <Eigen/SparseCore> 17 : #include <memory> 18 : 19 : namespace pairinteraction { 20 : template <typename Derived> 21 308 : Operator<Derived>::Operator(std::shared_ptr<const basis_t> basis) : basis(std::move(basis)) { 22 306 : this->matrix = Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>( 23 308 : this->basis->get_number_of_states(), this->basis->get_number_of_states()); 24 307 : } 25 : 26 : template <typename Derived> 27 68 : void Operator<Derived>::initialize_as_energy_operator() { 28 68 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> tmp(this->basis->get_number_of_kets(), 29 68 : this->basis->get_number_of_kets()); 30 68 : tmp.reserve(Eigen::VectorXi::Constant(this->basis->get_number_of_kets(), 1)); 31 68 : size_t idx = 0; 32 3117 : for (const auto &ket : this->basis->get_kets()) { 33 3058 : tmp.insert(idx, idx) = ket->get_energy(); 34 3045 : ++idx; 35 : } 36 61 : tmp.makeCompressed(); 37 : 38 68 : this->matrix = 39 68 : this->basis->get_coefficients().adjoint() * tmp * this->basis->get_coefficients(); 40 68 : } 41 : 42 : template <typename Derived> 43 169 : void Operator<Derived>::initialize_from_matrix( 44 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &&matrix) { 45 338 : if (static_cast<size_t>(matrix.rows()) != this->basis->get_number_of_states() || 46 169 : static_cast<size_t>(matrix.cols()) != this->basis->get_number_of_states()) { 47 0 : throw std::invalid_argument("The matrix has the wrong dimensions."); 48 : } 49 169 : this->matrix = std::move(matrix); 50 169 : } 51 : 52 : template <typename Derived> 53 187 : const Derived &Operator<Derived>::derived() const { 54 187 : return static_cast<const Derived &>(*this); 55 : } 56 : 57 : template <typename Derived> 58 111 : Derived &Operator<Derived>::derived_mutable() { 59 111 : return static_cast<Derived &>(*this); 60 : } 61 : 62 : template <typename Derived> 63 0 : std::shared_ptr<const typename Operator<Derived>::basis_t> Operator<Derived>::get_basis() const { 64 0 : return basis; 65 : } 66 : 67 : template <typename Derived> 68 381 : std::shared_ptr<const typename Operator<Derived>::basis_t> &Operator<Derived>::get_basis() { 69 381 : return basis; 70 : } 71 : 72 : template <typename Derived> 73 : const Eigen::SparseMatrix<typename Operator<Derived>::scalar_t, Eigen::RowMajor> & 74 0 : Operator<Derived>::get_matrix() const { 75 0 : return matrix; 76 : } 77 : 78 : template <typename Derived> 79 : Eigen::SparseMatrix<typename Operator<Derived>::scalar_t, Eigen::RowMajor> & 80 503 : Operator<Derived>::get_matrix() { 81 503 : return matrix; 82 : } 83 : 84 : template <typename Derived> 85 : const Transformation<typename Operator<Derived>::scalar_t> & 86 0 : Operator<Derived>::get_transformation() const { 87 0 : return basis->get_transformation(); 88 : } 89 : 90 : template <typename Derived> 91 : Transformation<typename Operator<Derived>::scalar_t> 92 0 : Operator<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const { 93 0 : return basis->get_rotator(alpha, beta, gamma); 94 : } 95 : 96 : template <typename Derived> 97 74 : Sorting Operator<Derived>::get_sorter(const std::vector<TransformationType> &labels) const { 98 74 : basis->perform_sorter_checks(labels); 99 : 100 : // Split labels into three parts (one before SORT_BY_ENERGY, one with SORT_BY_ENERGY, and one 101 : // after) 102 74 : auto it = std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY); 103 74 : std::vector<TransformationType> before_energy(labels.begin(), it); 104 74 : bool contains_energy = (it != labels.end()); 105 74 : std::vector<TransformationType> after_energy(contains_energy ? it + 1 : labels.end(), 106 : labels.end()); 107 : 108 : // Initialize transformation 109 74 : Sorting transformation; 110 74 : transformation.matrix.resize(matrix.rows()); 111 74 : transformation.matrix.setIdentity(); 112 : 113 : // Apply sorting for labels before SORT_BY_ENERGY 114 74 : if (!before_energy.empty()) { 115 32 : basis->get_sorter_without_checks(before_energy, transformation); 116 : } 117 : 118 : // Apply SORT_BY_ENERGY if present 119 74 : if (contains_energy) { 120 42 : std::vector<real_t> energies_of_states; 121 42 : energies_of_states.reserve(matrix.rows()); 122 3065 : for (int i = 0; i < matrix.rows(); ++i) { 123 3023 : energies_of_states.push_back(std::real(matrix.coeff(i, i))); 124 : } 125 : 126 42 : std::stable_sort( 127 42 : transformation.matrix.indices().data(), 128 42 : transformation.matrix.indices().data() + transformation.matrix.indices().size(), 129 14634 : [&](int i, int j) { return energies_of_states[i] < energies_of_states[j]; }); 130 : 131 42 : transformation.transformation_type.push_back(TransformationType::SORT_BY_ENERGY); 132 42 : } 133 : 134 : // Apply sorting for labels after SORT_BY_ENERGY 135 74 : if (!after_energy.empty()) { 136 0 : basis->get_sorter_without_checks(after_energy, transformation); 137 : } 138 : 139 : // Check if all labels have been used for sorting 140 74 : if (labels != transformation.transformation_type) { 141 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels."); 142 : } 143 : 144 148 : return transformation; 145 74 : } 146 : 147 : template <typename Derived> 148 : std::vector<IndicesOfBlock> 149 58 : Operator<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const { 150 58 : basis->perform_sorter_checks(labels); 151 : 152 58 : std::set<TransformationType> unique_labels(labels.begin(), labels.end()); 153 58 : basis->perform_blocks_checks(unique_labels); 154 : 155 : // Split labels into two parts (one with SORT_BY_ENERGY and one without) 156 58 : auto it = unique_labels.find(TransformationType::SORT_BY_ENERGY); 157 58 : bool contains_energy = (it != unique_labels.end()); 158 58 : if (contains_energy) { 159 0 : unique_labels.erase(it); 160 : } 161 : 162 : // Initialize blocks 163 58 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(matrix.rows())}); 164 : 165 : // Handle all labels except SORT_BY_ENERGY 166 58 : if (!unique_labels.empty()) { 167 32 : basis->get_indices_of_blocks_without_checks(unique_labels, blocks_creator); 168 : } 169 : 170 : // Handle SORT_BY_ENERGY if present 171 58 : if (contains_energy) { 172 0 : scalar_t last_energy = std::real(matrix.coeff(0, 0)); 173 0 : for (int i = 0; i < matrix.rows(); ++i) { 174 0 : if (std::real(matrix.coeff(i, i)) != last_energy) { 175 0 : blocks_creator.add(i); 176 0 : last_energy = std::real(matrix.coeff(i, i)); 177 : } 178 : } 179 : } 180 : 181 116 : return blocks_creator.create(); 182 58 : } 183 : 184 : template <typename Derived> 185 0 : Derived Operator<Derived>::transformed( 186 : const Transformation<typename Operator<Derived>::scalar_t> &transformation) const { 187 0 : auto transformed = derived(); 188 0 : if (matrix.cols() == 0) { 189 0 : return transformed; 190 : } 191 0 : transformed.matrix = transformation.matrix.adjoint() * matrix * transformation.matrix; 192 0 : transformed.basis = basis->transformed(transformation); 193 0 : return transformed; 194 0 : } 195 : 196 : template <typename Derived> 197 74 : Derived Operator<Derived>::transformed(const Sorting &transformation) const { 198 74 : auto transformed = derived(); 199 74 : if (matrix.cols() == 0) { 200 0 : return transformed; 201 : } 202 74 : transformed.matrix = matrix.twistedBy(transformation.matrix.inverse()); 203 74 : transformed.basis = basis->transformed(transformation); 204 74 : return transformed; 205 0 : } 206 : 207 : // Overloaded operators 208 : template <typename Derived> 209 112 : Derived operator*(const typename Operator<Derived>::scalar_t &lhs, const Operator<Derived> &rhs) { 210 112 : Derived result = rhs.derived(); 211 112 : result.matrix *= lhs; 212 112 : return result; 213 0 : } 214 : 215 : template <typename Derived> 216 0 : Derived operator*(const Operator<Derived> &lhs, const typename Operator<Derived>::scalar_t &rhs) { 217 0 : Derived result = lhs.derived(); 218 0 : result.matrix *= rhs; 219 0 : return result; 220 0 : } 221 : 222 : template <typename Derived> 223 0 : Derived operator/(const Operator<Derived> &lhs, const typename Operator<Derived>::scalar_t &rhs) { 224 0 : Derived result = lhs.derived(); 225 0 : result.matrix /= rhs; 226 0 : return result; 227 0 : } 228 : 229 : template <typename Derived> 230 0 : Derived &operator+=(Operator<Derived> &lhs, const Operator<Derived> &rhs) { 231 0 : if (lhs.basis != rhs.basis) { 232 0 : throw std::invalid_argument("The basis of the operators is not the same."); 233 : } 234 0 : lhs.matrix += rhs.matrix; 235 0 : return lhs.derived_mutable(); 236 : } 237 : 238 : template <typename Derived> 239 111 : Derived &operator-=(Operator<Derived> &lhs, const Operator<Derived> &rhs) { 240 111 : if (lhs.basis != rhs.basis) { 241 0 : throw std::invalid_argument("The basis of the operators is not the same."); 242 : } 243 111 : lhs.matrix -= rhs.matrix; 244 111 : return lhs.derived_mutable(); 245 : } 246 : 247 : template <typename Derived> 248 1 : Derived operator+(const Operator<Derived> &lhs, const Operator<Derived> &rhs) { 249 1 : if (lhs.basis != rhs.basis) { 250 0 : throw std::invalid_argument("The basis of the operators is not the same."); 251 : } 252 1 : Derived result = lhs.derived(); 253 1 : result.matrix += rhs.matrix; 254 1 : return result; 255 0 : } 256 : 257 : template <typename Derived> 258 0 : Derived operator-(const Operator<Derived> &lhs, const Operator<Derived> &rhs) { 259 0 : if (lhs.basis != rhs.basis) { 260 0 : throw std::invalid_argument("The basis of the operators is not the same."); 261 : } 262 0 : Derived result = lhs.derived(); 263 0 : result.matrix -= rhs.matrix; 264 0 : return result; 265 0 : } 266 : 267 : // Explicit instantiations 268 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage) 269 : #define INSTANTIATE_OPERATOR_HELPER(SCALAR, TYPE) \ 270 : template class Operator<TYPE<SCALAR>>; \ 271 : template TYPE<SCALAR> operator*(const SCALAR &lhs, const Operator<TYPE<SCALAR>> &rhs); \ 272 : template TYPE<SCALAR> operator*(const Operator<TYPE<SCALAR>> &lhs, const SCALAR &rhs); \ 273 : template TYPE<SCALAR> operator/(const Operator<TYPE<SCALAR>> &lhs, const SCALAR &rhs); \ 274 : template TYPE<SCALAR> &operator+=(Operator<TYPE<SCALAR>> &lhs, \ 275 : const Operator<TYPE<SCALAR>> &rhs); \ 276 : template TYPE<SCALAR> &operator-=(Operator<TYPE<SCALAR>> &lhs, \ 277 : const Operator<TYPE<SCALAR>> &rhs); \ 278 : template TYPE<SCALAR> operator+(const Operator<TYPE<SCALAR>> &lhs, \ 279 : const Operator<TYPE<SCALAR>> &rhs); \ 280 : template TYPE<SCALAR> operator-(const Operator<TYPE<SCALAR>> &lhs, \ 281 : const Operator<TYPE<SCALAR>> &rhs); 282 : #define INSTANTIATE_OPERATOR(SCALAR) \ 283 : INSTANTIATE_OPERATOR_HELPER(SCALAR, OperatorAtom) \ 284 : INSTANTIATE_OPERATOR_HELPER(SCALAR, OperatorPair) 285 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage) 286 : 287 : INSTANTIATE_OPERATOR(double) 288 : INSTANTIATE_OPERATOR(std::complex<double>) 289 : 290 : #undef INSTANTIATE_OPERATOR_HELPER 291 : #undef INSTANTIATE_OPERATOR 292 : 293 : } // namespace pairinteraction