Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
#include "pairinteraction/basis/Basis.hpp"

#include "pairinteraction/basis/BasisAtom.hpp"
#include "pairinteraction/basis/BasisPair.hpp"
#include "pairinteraction/enums/Parity.hpp"
#include "pairinteraction/enums/TransformationType.hpp"
#include "pairinteraction/ket/KetAtom.hpp"
#include "pairinteraction/ket/KetPair.hpp"
#include "pairinteraction/utils/eigen_assertion.hpp"
#include "pairinteraction/utils/eigen_compat.hpp"
#include "pairinteraction/utils/wigner.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <set>
#include <stdexcept>
19 :
20 : namespace pairinteraction {
21 :
// Forward declaration of BasisAtom.
// NOTE(review): "pairinteraction/basis/BasisAtom.hpp" is already included at the top of this
// file, so this declaration appears to be redundant.
template <typename Scalar>
class BasisAtom;
24 :
25 : template <typename Derived>
26 550 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
27 : // Check if the labels are valid sorting labels
28 1106 : for (const auto &label : labels) {
29 556 : if (!utils::is_sorting(label)) {
30 0 : throw std::invalid_argument("One of the labels is not a valid sorting label.");
31 : }
32 : }
33 550 : }
34 :
35 : template <typename Derived>
36 141 : void Basis<Derived>::perform_blocks_checks(
37 : const std::set<TransformationType> &unique_labels) const {
38 : // Check if the states are sorted by the requested labels
39 141 : std::set<TransformationType> unique_labels_present;
40 265 : for (const auto &label : get_transformation().transformation_type) {
41 164 : if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
42 40 : break;
43 : }
44 124 : unique_labels_present.insert(label);
45 : }
46 141 : if (unique_labels != unique_labels_present) {
47 0 : throw std::invalid_argument("The states are not sorted by the requested labels.");
48 : }
49 :
50 : // Throw a meaningful error if getting the blocks by energy is requested as this might be a
51 : // common mistake
52 141 : if (unique_labels.count(TransformationType::SORT_BY_ENERGY) > 0) {
53 0 : throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
54 : "be obtained. Use an energy operator instead.");
55 : }
56 141 : }
57 :
58 : template <typename Derived>
59 210 : Basis<Derived>::Basis(ketvec_t &&kets)
60 420 : : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
61 210 : static_cast<Eigen::Index>(this->kets.size())},
62 420 : {TransformationType::SORT_BY_KET}} {
63 210 : if (this->kets.empty()) {
64 0 : throw std::invalid_argument("The basis must contain at least one element.");
65 : }
66 210 : state_index_to_quantum_number_f.reserve(this->kets.size());
67 210 : state_index_to_quantum_number_m.reserve(this->kets.size());
68 210 : state_index_to_parity.reserve(this->kets.size());
69 210 : ket_to_ket_index.reserve(this->kets.size());
70 210 : size_t index = 0;
71 81146 : for (const auto &ket : this->kets) {
72 80936 : state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
73 80936 : state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
74 80936 : state_index_to_parity.push_back(ket->get_parity());
75 80936 : ket_to_ket_index[ket] = index++;
76 80936 : if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
77 69647 : _has_quantum_number_f = false;
78 : }
79 80936 : if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
80 0 : _has_quantum_number_m = false;
81 : }
82 80936 : if (ket->get_parity() == Parity::UNKNOWN) {
83 69647 : _has_parity = false;
84 : }
85 : }
86 210 : state_index_to_ket_index.resize(this->kets.size());
87 210 : std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
88 210 : ket_index_to_state_index.resize(this->kets.size());
89 210 : std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
90 210 : coefficients.matrix.setIdentity();
91 210 : }
92 :
93 : template <typename Derived>
94 85 : bool Basis<Derived>::has_quantum_number_f() const {
95 85 : return _has_quantum_number_f;
96 : }
97 :
98 : template <typename Derived>
99 149019 : bool Basis<Derived>::has_quantum_number_m() const {
100 149019 : return _has_quantum_number_m;
101 : }
102 :
103 : template <typename Derived>
104 85 : bool Basis<Derived>::has_parity() const {
105 85 : return _has_parity;
106 : }
107 :
108 : template <typename Derived>
109 859 : const Derived &Basis<Derived>::derived() const {
110 859 : return static_cast<const Derived &>(*this);
111 : }
112 :
113 : template <typename Derived>
114 302 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
115 302 : return kets;
116 : }
117 :
118 : template <typename Derived>
119 : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
120 2569 : Basis<Derived>::get_coefficients() const {
121 2569 : return coefficients.matrix;
122 : }
123 :
124 : template <typename Derived>
125 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
126 0 : Basis<Derived>::get_coefficients() {
127 0 : return coefficients.matrix;
128 : }
129 :
130 : template <typename Derived>
131 5 : void Basis<Derived>::set_coefficients(
132 : const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
133 5 : if (values.rows() != coefficients.matrix.rows()) {
134 0 : throw std::invalid_argument("Incompatible number of rows.");
135 : }
136 5 : if (values.cols() != coefficients.matrix.cols()) {
137 0 : throw std::invalid_argument("Incompatible number of columns.");
138 : }
139 :
140 5 : coefficients.matrix = values;
141 :
142 5 : std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
143 5 : std::numeric_limits<int>::max());
144 5 : std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
145 5 : std::numeric_limits<int>::max());
146 5 : std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
147 5 : std::numeric_limits<real_t>::max());
148 5 : std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
149 5 : std::numeric_limits<real_t>::max());
150 5 : std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
151 5 : _has_quantum_number_f = false;
152 5 : _has_quantum_number_m = false;
153 5 : _has_parity = false;
154 5 : }
155 :
156 : template <typename Derived>
157 308 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
158 308 : if (ket_to_ket_index.count(ket) == 0) {
159 0 : return -1;
160 : }
161 308 : return ket_to_ket_index.at(ket);
162 : }
163 :
164 : template <typename Derived>
165 : Eigen::VectorX<typename Basis<Derived>::scalar_t>
166 69 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
167 69 : int ket_index = get_ket_index_from_ket(ket);
168 69 : if (ket_index < 0) {
169 0 : throw std::invalid_argument("The ket does not belong to the basis.");
170 : }
171 : // The following line is a more efficient alternative to
172 : // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
173 138 : return coefficients.matrix.row(ket_index);
174 : }
175 :
176 : template <typename Derived>
177 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
178 6 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
179 12 : return other->coefficients.matrix.adjoint() * coefficients.matrix;
180 : }
181 :
182 : template <typename Derived>
183 : Eigen::VectorX<typename Basis<Derived>::real_t>
184 67 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
185 67 : return get_amplitudes(ket).cwiseAbs2();
186 : }
187 :
188 : template <typename Derived>
189 : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
190 3 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
191 3 : return get_amplitudes(other).cwiseAbs2();
192 : }
193 :
194 : template <typename Derived>
195 0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
196 0 : real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
197 0 : if (quantum_number_f == std::numeric_limits<real_t>::max()) {
198 0 : throw std::invalid_argument("The state does not have a well-defined quantum number f.");
199 : }
200 0 : return quantum_number_f;
201 : }
202 :
203 : template <typename Derived>
204 148976 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
205 148976 : real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
206 148976 : if (quantum_number_m == std::numeric_limits<real_t>::max()) {
207 0 : throw std::invalid_argument("The state does not have a well-defined quantum number m.");
208 : }
209 148976 : return quantum_number_m;
210 : }
211 :
212 : template <typename Derived>
213 43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
214 43 : Parity parity = state_index_to_parity.at(state_index);
215 43 : if (parity == Parity::UNKNOWN) {
216 0 : throw std::invalid_argument("The state does not have a well-defined parity.");
217 : }
218 43 : return parity;
219 : }
220 :
221 : template <typename Derived>
222 : std::shared_ptr<const typename Basis<Derived>::ket_t>
223 64 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
224 64 : size_t ket_index = state_index_to_ket_index.at(state_index);
225 64 : if (ket_index == std::numeric_limits<int>::max()) {
226 0 : throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
227 : }
228 64 : return kets[ket_index];
229 : }
230 :
231 : template <typename Derived>
232 : std::shared_ptr<const typename Basis<Derived>::ket_t>
233 0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
234 0 : throw std::runtime_error("Not implemented yet.");
235 : }
236 :
237 : template <typename Derived>
238 190 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
239 : // Create a copy of the current object
240 190 : auto restricted = std::make_shared<Derived>(derived());
241 :
242 : // Restrict the copy to the state with the largest overlap
243 190 : restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
244 :
245 190 : std::fill(restricted->ket_index_to_state_index.begin(),
246 190 : restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
247 :
248 190 : size_t ket_index = state_index_to_ket_index[state_index];
249 190 : restricted->state_index_to_ket_index = {ket_index};
250 190 : if (ket_index != std::numeric_limits<int>::max()) {
251 190 : restricted->ket_index_to_state_index[ket_index] = 0;
252 : }
253 :
254 190 : restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
255 190 : restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
256 190 : restricted->state_index_to_parity = {state_index_to_parity[state_index]};
257 :
258 380 : restricted->_has_quantum_number_f =
259 190 : restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
260 380 : restricted->_has_quantum_number_m =
261 190 : restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
262 190 : restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
263 :
264 380 : return restricted;
265 190 : }
266 :
267 : template <typename Derived>
268 : std::shared_ptr<const typename Basis<Derived>::ket_t>
269 0 : Basis<Derived>::get_ket(size_t ket_index) const {
270 0 : return kets[ket_index];
271 : }
272 :
273 : template <typename Derived>
274 28 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
275 28 : size_t state_index = ket_index_to_state_index.at(ket_index);
276 28 : if (state_index == std::numeric_limits<int>::max()) {
277 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
278 : }
279 28 : return get_state(state_index);
280 : }
281 :
282 : template <typename Derived>
283 : std::shared_ptr<const Derived>
284 28 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
285 28 : int ket_index = get_ket_index_from_ket(ket);
286 28 : if (ket_index < 0) {
287 0 : throw std::invalid_argument("The ket does not belong to the basis.");
288 : }
289 28 : return get_corresponding_state(ket_index);
290 : }
291 :
292 : template <typename Derived>
293 92 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
294 92 : int state_index = ket_index_to_state_index.at(ket_index);
295 92 : if (state_index == std::numeric_limits<int>::max()) {
296 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
297 : }
298 92 : return state_index;
299 : }
300 :
301 : template <typename Derived>
302 92 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
303 92 : int ket_index = get_ket_index_from_ket(ket);
304 92 : if (ket_index < 0) {
305 0 : throw std::invalid_argument("The ket does not belong to the basis.");
306 : }
307 92 : return get_corresponding_state_index(ket_index);
308 : }
309 :
310 : template <typename Derived>
311 0 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
312 0 : size_t ket_index = state_index_to_ket_index.at(state_index);
313 0 : if (ket_index == std::numeric_limits<int>::max()) {
314 0 : throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
315 : }
316 0 : return ket_index;
317 : }
318 :
319 : template <typename Derived>
320 0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
321 0 : throw std::runtime_error("Not implemented yet.");
322 : }
323 :
324 : template <typename Derived>
325 : std::shared_ptr<const Derived>
326 119 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
327 : // Create a copy of the current object
328 119 : auto created = std::make_shared<Derived>(derived());
329 :
330 : // Fill the copy with the state corresponding to the ket index
331 119 : created->coefficients.matrix =
332 238 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
333 119 : created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
334 119 : created->coefficients.matrix.makeCompressed();
335 :
336 119 : std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
337 119 : std::numeric_limits<int>::max());
338 119 : created->ket_index_to_state_index[ket_index] = 0;
339 :
340 119 : created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
341 119 : created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
342 119 : created->state_index_to_parity = {kets[ket_index]->get_parity()};
343 119 : created->state_index_to_ket_index = {ket_index};
344 :
345 238 : created->_has_quantum_number_f =
346 119 : created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
347 238 : created->_has_quantum_number_m =
348 119 : created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
349 119 : created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
350 :
351 238 : return created;
352 119 : }
353 :
354 : template <typename Derived>
355 : std::shared_ptr<const Derived>
356 119 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
357 119 : int ket_index = get_ket_index_from_ket(ket);
358 119 : if (ket_index < 0) {
359 0 : throw std::invalid_argument("The ket does not belong to the basis.");
360 : }
361 119 : return get_canonical_state_from_ket(ket_index);
362 : }
363 :
364 : template <typename Derived>
365 4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
366 4 : return kets.begin();
367 : }
368 :
369 : template <typename Derived>
370 4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
371 4 : return kets.end();
372 : }
373 :
374 : template <typename Derived>
375 8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
376 :
377 : template <typename Derived>
378 80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
379 80 : return other.it != it;
380 : }
381 :
382 : template <typename Derived>
383 76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
384 76 : return *it;
385 : }
386 :
387 : template <typename Derived>
388 76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
389 76 : ++it;
390 76 : return *this;
391 : }
392 :
393 : template <typename Derived>
394 4563 : size_t Basis<Derived>::get_number_of_states() const {
395 4563 : return coefficients.matrix.cols();
396 : }
397 :
398 : template <typename Derived>
399 1931 : size_t Basis<Derived>::get_number_of_kets() const {
400 1931 : return coefficients.matrix.rows();
401 : }
402 :
403 : template <typename Derived>
404 : const Transformation<typename Basis<Derived>::scalar_t> &
405 142 : Basis<Derived>::get_transformation() const {
406 142 : return coefficients;
407 : }
408 :
409 : template <typename Derived>
410 : Transformation<typename Basis<Derived>::scalar_t>
411 0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
412 0 : Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
413 : static_cast<Eigen::Index>(coefficients.matrix.rows())},
414 : {TransformationType::ROTATE}};
415 :
416 0 : std::vector<Eigen::Triplet<scalar_t>> entries;
417 :
418 0 : for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
419 0 : real_t f = kets[idx_initial]->get_quantum_number_f();
420 0 : real_t m_initial = kets[idx_initial]->get_quantum_number_m();
421 :
422 0 : assert(2 * f == std::floor(2 * f) && f >= 0);
423 0 : assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
424 :
425 0 : for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
426 : ++m_final) {
427 0 : auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
428 : beta, gamma);
429 0 : size_t idx_final = get_ket_index_from_ket(
430 0 : kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
431 0 : entries.emplace_back(idx_final, idx_initial, val);
432 : }
433 : }
434 :
435 0 : transformation.matrix.setFromTriplets(entries.begin(), entries.end());
436 0 : transformation.matrix.makeCompressed();
437 :
438 0 : return transformation;
439 0 : }
440 :
441 : template <typename Derived>
442 1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
443 1 : perform_sorter_checks(labels);
444 :
445 : // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
446 1 : if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
447 2 : labels.end()) {
448 0 : throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
449 : "the energy. Use an energy operator instead.");
450 : }
451 :
452 : // Initialize transformation
453 1 : Sorting transformation;
454 1 : transformation.matrix.resize(coefficients.matrix.cols());
455 1 : transformation.matrix.setIdentity();
456 :
457 : // Get the sorter
458 1 : get_sorter_without_checks(labels, transformation);
459 :
460 : // Check if all labels have been used for sorting
461 1 : if (labels != transformation.transformation_type) {
462 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels.");
463 : }
464 :
465 1 : return transformation;
466 0 : }
467 :
468 : template <typename Derived>
469 : std::vector<IndicesOfBlock>
470 1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
471 1 : perform_sorter_checks(labels);
472 :
473 1 : std::set<TransformationType> unique_labels(labels.begin(), labels.end());
474 1 : perform_blocks_checks(unique_labels);
475 :
476 : // Get the blocks
477 1 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
478 1 : get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
479 :
480 2 : return blocks_creator.create();
481 1 : }
482 :
483 : template <typename Derived>
484 101 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
485 : Sorting &transformation) const {
486 101 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
487 :
488 101 : int *perm_begin = transformation.matrix.indices().data();
489 101 : int *perm_end = perm_begin + coefficients.matrix.cols();
490 101 : const int *perm_back = perm_end - 1;
491 :
492 : // Sort the vector based on the requested labels
493 254106 : std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
494 93408 : for (const auto &label : labels) {
495 70162 : switch (label) {
496 17609 : case TransformationType::SORT_BY_PARITY:
497 17609 : if (state_index_to_parity[a] != state_index_to_parity[b]) {
498 30145 : return state_index_to_parity[a] < state_index_to_parity[b];
499 : }
500 11581 : break;
501 52553 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
502 105106 : if (std::abs(state_index_to_quantum_number_m[a] -
503 105106 : state_index_to_quantum_number_m[b]) > numerical_precision) {
504 24117 : return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
505 : }
506 28436 : break;
507 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
508 0 : if (std::abs(state_index_to_quantum_number_f[a] -
509 0 : state_index_to_quantum_number_f[b]) > numerical_precision) {
510 0 : return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
511 : }
512 0 : break;
513 0 : case TransformationType::SORT_BY_KET:
514 0 : if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
515 0 : return state_index_to_ket_index[a] < state_index_to_ket_index[b];
516 : }
517 0 : break;
518 0 : default:
519 0 : std::abort(); // Can't happen because of previous checks
520 : }
521 : }
522 23246 : return false; // Elements are equal
523 : });
524 :
525 : // Check for invalid values and add transformation types
526 225 : for (const auto &label : labels) {
527 124 : switch (label) {
528 25 : case TransformationType::SORT_BY_PARITY:
529 25 : if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
530 0 : throw std::invalid_argument(
531 : "States cannot be labeled and thus not sorted by the parity.");
532 : }
533 25 : transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
534 25 : break;
535 99 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
536 99 : if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
537 0 : throw std::invalid_argument(
538 : "States cannot be labeled and thus not sorted by the quantum number m.");
539 : }
540 99 : transformation.transformation_type.push_back(
541 99 : TransformationType::SORT_BY_QUANTUM_NUMBER_M);
542 99 : break;
543 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
544 0 : if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
545 0 : throw std::invalid_argument(
546 : "States cannot be labeled and thus not sorted by the quantum number f.");
547 : }
548 0 : transformation.transformation_type.push_back(
549 0 : TransformationType::SORT_BY_QUANTUM_NUMBER_F);
550 0 : break;
551 0 : case TransformationType::SORT_BY_KET:
552 0 : if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
553 0 : throw std::invalid_argument(
554 : "States cannot be labeled and thus not sorted by kets.");
555 : }
556 0 : transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
557 0 : break;
558 0 : default:
559 0 : std::abort(); // Can't happen because of previous checks
560 : }
561 : }
562 101 : }
563 :
564 : template <typename Derived>
565 101 : void Basis<Derived>::get_indices_of_blocks_without_checks(
566 : const std::set<TransformationType> &unique_labels,
567 : IndicesOfBlocksCreator &blocks_creator) const {
568 101 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
569 :
570 101 : auto last_quantum_number_f = state_index_to_quantum_number_f[0];
571 101 : auto last_quantum_number_m = state_index_to_quantum_number_m[0];
572 101 : auto last_parity = state_index_to_parity[0];
573 101 : auto last_ket = state_index_to_ket_index[0];
574 :
575 8850 : for (int i = 0; i < coefficients.matrix.cols(); ++i) {
576 21145 : for (auto label : unique_labels) {
577 12764 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
578 0 : std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
579 : numerical_precision) {
580 0 : blocks_creator.add(i);
581 0 : break;
582 : }
583 21333 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
584 8569 : std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
585 : numerical_precision) {
586 278 : blocks_creator.add(i);
587 278 : break;
588 : }
589 16681 : if (label == TransformationType::SORT_BY_PARITY &&
590 4195 : state_index_to_parity[i] != last_parity) {
591 90 : blocks_creator.add(i);
592 90 : break;
593 : }
594 12396 : if (label == TransformationType::SORT_BY_KET &&
595 12396 : state_index_to_ket_index[i] != last_ket &&
596 0 : last_ket != std::numeric_limits<int>::max()) {
597 0 : blocks_creator.add(i);
598 0 : break;
599 : }
600 : }
601 8749 : last_quantum_number_f = state_index_to_quantum_number_f[i];
602 8749 : last_quantum_number_m = state_index_to_quantum_number_m[i];
603 8749 : last_parity = state_index_to_parity[i];
604 8749 : last_ket = state_index_to_ket_index[i];
605 : }
606 101 : }
607 :
608 : template <typename Derived>
609 409 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
610 : // Create a copy of the current object
611 409 : auto transformed = std::make_shared<Derived>(derived());
612 :
613 409 : if (coefficients.matrix.cols() == 0) {
614 0 : return transformed;
615 : }
616 :
617 : // Apply the transformation
618 409 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
619 409 : transformed->coefficients.transformation_type = transformation.transformation_type;
620 :
621 409 : transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
622 409 : transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
623 409 : transformed->state_index_to_parity.resize(transformation.matrix.size());
624 409 : transformed->state_index_to_ket_index.resize(transformation.matrix.size());
625 :
626 67475 : for (int i = 0; i < transformation.matrix.size(); ++i) {
627 67066 : transformed->state_index_to_quantum_number_f[i] =
628 67066 : state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
629 67066 : transformed->state_index_to_quantum_number_m[i] =
630 67066 : state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
631 67066 : transformed->state_index_to_parity[i] =
632 67066 : state_index_to_parity[transformation.matrix.indices()[i]];
633 :
634 67066 : size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
635 67066 : transformed->state_index_to_ket_index[i] = ket_index;
636 67066 : if (ket_index != std::numeric_limits<int>::max()) {
637 67066 : transformed->ket_index_to_state_index[ket_index] = i;
638 : }
639 : }
640 :
641 409 : return transformed;
642 409 : }
643 :
644 : template <typename Derived>
645 : std::shared_ptr<const Derived>
646 141 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
647 : // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
648 : // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
649 141 : real_t numerical_precision = 0.001;
650 :
651 : // If the transformation is a rotation, it should be a rotation and nothing else
652 141 : bool is_rotation = false;
653 282 : for (auto t : transformation.transformation_type) {
654 141 : if (t == TransformationType::ROTATE) {
655 0 : is_rotation = true;
656 0 : break;
657 : }
658 : }
659 141 : if (is_rotation && transformation.transformation_type.size() != 1) {
660 0 : throw std::invalid_argument("A rotation can not be combined with other transformations.");
661 : }
662 :
663 : // To apply a rotation, the object must only be sorted but other transformations are not allowed
664 141 : if (is_rotation) {
665 0 : for (auto t : coefficients.transformation_type) {
666 0 : if (!utils::is_sorting(t)) {
667 0 : throw std::runtime_error(
668 : "If the object was transformed by a different transformation "
669 : "than sorting, it can not be rotated.");
670 : }
671 : }
672 : }
673 :
674 : // Create a copy of the current object
675 141 : auto transformed = std::make_shared<Derived>(derived());
676 :
677 141 : if (coefficients.matrix.cols() == 0) {
678 0 : return transformed;
679 : }
680 :
681 : // Apply the transformation
682 : // If a quantum number turns out to be conserved by the transformation, it will be
683 : // rounded to the nearest half integer to avoid loss of numerical_precision.
684 141 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
685 141 : transformed->coefficients.transformation_type = transformation.transformation_type;
686 :
687 141 : Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
688 :
689 : {
690 282 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
691 141 : state_index_to_quantum_number_f.size());
692 141 : Eigen::VectorX<real_t> val = probs * map;
693 141 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
694 141 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
695 141 : transformed->state_index_to_quantum_number_f.resize(probs.rows());
696 :
697 10355 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
698 10214 : if (diff[i] < numerical_precision) {
699 3005 : transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
700 : } else {
701 7209 : transformed->state_index_to_quantum_number_f[i] =
702 7209 : std::numeric_limits<real_t>::max();
703 7209 : transformed->_has_quantum_number_f = false;
704 : }
705 : }
706 141 : }
707 :
708 : {
709 282 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
710 141 : state_index_to_quantum_number_m.size());
711 141 : Eigen::VectorX<real_t> val = probs * map;
712 141 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
713 141 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
714 141 : transformed->state_index_to_quantum_number_m.resize(probs.rows());
715 :
716 10355 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
717 10214 : if (diff[i] < numerical_precision) {
718 8274 : transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
719 : } else {
720 1940 : transformed->state_index_to_quantum_number_m[i] =
721 1940 : std::numeric_limits<real_t>::max();
722 1940 : transformed->_has_quantum_number_m = false;
723 : }
724 : }
725 141 : }
726 :
727 : {
728 : using utype = std::underlying_type<Parity>::type;
729 141 : Eigen::VectorX<real_t> map(state_index_to_parity.size());
730 10914 : for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
731 10773 : map[i] = static_cast<utype>(state_index_to_parity[i]);
732 : }
733 141 : Eigen::VectorX<real_t> val = probs * map;
734 141 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
735 141 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
736 141 : transformed->state_index_to_parity.resize(probs.rows());
737 :
738 10355 : for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
739 10214 : if (diff[i] < numerical_precision) {
740 6706 : transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
741 : } else {
742 3508 : transformed->state_index_to_parity[i] = Parity::UNKNOWN;
743 3508 : transformed->_has_parity = false;
744 : }
745 : }
746 141 : }
747 :
748 : {
749 : // In the following, we obtain a bijective map between state index and ket index.
750 :
751 : // Find the maximum value in each row and column
752 141 : std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
753 141 : std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
754 10914 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
755 10773 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
756 10773 : transformed->coefficients.matrix, row);
757 248008 : it; ++it) {
758 237235 : real_t val = std::pow(std::abs(it.value()), 2);
759 237235 : max_in_row[row] = std::max(max_in_row[row], val);
760 237235 : max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
761 : }
762 : }
763 :
764 : // Use the maximum values to define a cost for a sub-optimal mapping
765 141 : std::vector<real_t> costs;
766 141 : std::vector<std::pair<int, int>> mappings;
767 141 : costs.reserve(transformed->coefficients.matrix.nonZeros());
768 141 : mappings.reserve(transformed->coefficients.matrix.nonZeros());
769 10914 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
770 10773 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
771 10773 : transformed->coefficients.matrix, row);
772 248008 : it; ++it) {
773 237235 : real_t val = std::pow(std::abs(it.value()), 2);
774 237235 : real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
775 237235 : costs.push_back(cost);
776 237235 : mappings.push_back({row, it.col()});
777 : }
778 : }
779 :
780 : // Obtain from the costs in which order the mappings should be considered
781 141 : std::vector<size_t> order(costs.size());
782 141 : std::iota(order.begin(), order.end(), 0);
783 141 : std::sort(order.begin(), order.end(),
784 3472964 : [&](size_t a, size_t b) { return costs[a] < costs[b]; });
785 :
786 : // Fill ket_index_to_state_index with invalid values as there can be more kets than states
787 141 : std::fill(transformed->ket_index_to_state_index.begin(),
788 141 : transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
789 :
790 : // Generate the bijective map
791 141 : std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
792 141 : std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
793 141 : int num_used = 0;
794 21344 : for (size_t idx : order) {
795 21342 : int row = mappings[idx].first; // corresponds to the ket index
796 21342 : int col = mappings[idx].second; // corresponds to the state index
797 21342 : if (!row_used[row] && !col_used[col]) {
798 10214 : row_used[row] = true;
799 10214 : col_used[col] = true;
800 10214 : num_used++;
801 10214 : transformed->state_index_to_ket_index[col] = row;
802 10214 : transformed->ket_index_to_state_index[row] = col;
803 : }
804 21342 : if (num_used == transformed->coefficients.matrix.cols()) {
805 139 : break;
806 : }
807 : }
808 141 : assert(num_used == transformed->coefficients.matrix.cols());
809 141 : }
810 :
811 141 : return transformed;
812 141 : }
813 :
814 : template <typename Derived>
815 81552 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
816 81552 : return typename ket_t::hash()(*k);
817 : }
818 :
819 : template <typename Derived>
820 616 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
821 : const std::shared_ptr<const ket_t> &rhs) const {
822 616 : return *lhs == *rhs;
823 : }
824 :
825 : // Explicit instantiations
826 : template class Basis<BasisAtom<double>>;
827 : template class Basis<BasisAtom<std::complex<double>>>;
828 : template class Basis<BasisPair<double>>;
829 : template class Basis<BasisPair<std::complex<double>>>;
830 : } // namespace pairinteraction
|