Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/basis/Basis.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/enums/Parity.hpp"
9 : #include "pairinteraction/enums/TransformationType.hpp"
10 : #include "pairinteraction/ket/KetAtom.hpp"
11 : #include "pairinteraction/ket/KetPair.hpp"
12 : #include "pairinteraction/utils/eigen_assertion.hpp"
13 : #include "pairinteraction/utils/eigen_compat.hpp"
14 : #include "pairinteraction/utils/wigner.hpp"
15 :
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <numeric>
#include <set>
#include <stdexcept>
19 :
20 : namespace pairinteraction {
21 :
// Forward declaration of the BasisAtom class template.
// NOTE(review): "pairinteraction/basis/BasisAtom.hpp" is already included above, so this
// declaration is likely redundant.
template <typename Scalar>
class BasisAtom;
24 :
25 : template <typename Derived>
26 296 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
27 : // Check if the labels are valid sorting labels
28 574 : for (const auto &label : labels) {
29 278 : if (!utils::is_sorting(label)) {
30 0 : throw std::invalid_argument("One of the labels is not a valid sorting label.");
31 : }
32 : }
33 296 : }
34 :
35 : template <typename Derived>
36 112 : void Basis<Derived>::perform_blocks_checks(
37 : const std::set<TransformationType> &unique_labels) const {
38 : // Check if the states are sorted by the requested labels
39 112 : std::set<TransformationType> unique_labels_present;
40 195 : for (const auto &label : get_transformation().transformation_type) {
41 123 : if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
42 40 : break;
43 : }
44 83 : unique_labels_present.insert(label);
45 : }
46 112 : if (unique_labels != unique_labels_present) {
47 0 : throw std::invalid_argument("The states are not sorted by the requested labels.");
48 : }
49 :
50 : // Throw a meaningful error if getting the blocks by energy is requested as this might be a
51 : // common mistake
52 112 : if (unique_labels.count(TransformationType::SORT_BY_ENERGY) > 0) {
53 0 : throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
54 : "be obtained. Use an energy operator instead.");
55 : }
56 112 : }
57 :
58 : template <typename Derived>
59 110 : Basis<Derived>::Basis(ketvec_t &&kets)
60 220 : : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
61 110 : static_cast<Eigen::Index>(this->kets.size())},
62 220 : {TransformationType::SORT_BY_KET}} {
63 110 : if (this->kets.empty()) {
64 0 : throw std::invalid_argument("The basis must contain at least one element.");
65 : }
66 110 : state_index_to_quantum_number_f.reserve(this->kets.size());
67 110 : state_index_to_quantum_number_m.reserve(this->kets.size());
68 110 : state_index_to_parity.reserve(this->kets.size());
69 110 : ket_to_ket_index.reserve(this->kets.size());
70 110 : size_t index = 0;
71 23756 : for (const auto &ket : this->kets) {
72 23646 : state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
73 23646 : state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
74 23646 : state_index_to_parity.push_back(ket->get_parity());
75 23646 : ket_to_ket_index[ket] = index++;
76 23646 : if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
77 14370 : _has_quantum_number_f = false;
78 : }
79 23646 : if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
80 0 : _has_quantum_number_m = false;
81 : }
82 23646 : if (ket->get_parity() == Parity::UNKNOWN) {
83 14370 : _has_parity = false;
84 : }
85 : }
86 110 : state_index_to_ket_index.resize(this->kets.size());
87 110 : std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
88 110 : ket_index_to_state_index.resize(this->kets.size());
89 110 : std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
90 110 : coefficients.matrix.setIdentity();
91 110 : }
92 :
93 : template <typename Derived>
94 114 : bool Basis<Derived>::has_quantum_number_f() const {
95 114 : return _has_quantum_number_f;
96 : }
97 :
98 : template <typename Derived>
99 37974 : bool Basis<Derived>::has_quantum_number_m() const {
100 37974 : return _has_quantum_number_m;
101 : }
102 :
103 : template <typename Derived>
104 114 : bool Basis<Derived>::has_parity() const {
105 114 : return _has_parity;
106 : }
107 :
108 : template <typename Derived>
109 583 : const Derived &Basis<Derived>::derived() const {
110 583 : return static_cast<const Derived &>(*this);
111 : }
112 :
113 : template <typename Derived>
114 356 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
115 356 : return kets;
116 : }
117 :
118 : template <typename Derived>
119 : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
120 2168 : Basis<Derived>::get_coefficients() const {
121 2168 : return coefficients.matrix;
122 : }
123 :
124 : template <typename Derived>
125 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
126 0 : Basis<Derived>::get_coefficients() {
127 0 : return coefficients.matrix;
128 : }
129 :
130 : template <typename Derived>
131 239 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
132 239 : if (ket_to_ket_index.count(ket) == 0) {
133 0 : return -1;
134 : }
135 239 : return ket_to_ket_index.at(ket);
136 : }
137 :
138 : template <typename Derived>
139 : Eigen::VectorX<typename Basis<Derived>::scalar_t>
140 22 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
141 22 : int ket_index = get_ket_index_from_ket(ket);
142 22 : if (ket_index < 0) {
143 0 : throw std::invalid_argument("The ket does not belong to the basis.");
144 : }
145 : // The following line is a more efficient alternative to
146 : // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
147 44 : return coefficients.matrix.row(ket_index);
148 : }
149 :
150 : template <typename Derived>
151 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
152 6 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
153 12 : return other->coefficients.matrix.adjoint() * coefficients.matrix;
154 : }
155 :
156 : template <typename Derived>
157 : Eigen::VectorX<typename Basis<Derived>::real_t>
158 20 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
159 20 : return get_amplitudes(ket).cwiseAbs2();
160 : }
161 :
162 : template <typename Derived>
163 : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
164 3 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
165 3 : return get_amplitudes(other).cwiseAbs2();
166 : }
167 :
168 : template <typename Derived>
169 0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
170 0 : real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
171 0 : if (quantum_number_f == std::numeric_limits<real_t>::max()) {
172 0 : throw std::invalid_argument("The state does not have a well-defined quantum number f.");
173 : }
174 0 : return quantum_number_f;
175 : }
176 :
177 : template <typename Derived>
178 37902 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
179 37902 : real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
180 37902 : if (quantum_number_m == std::numeric_limits<real_t>::max()) {
181 0 : throw std::invalid_argument("The state does not have a well-defined quantum number m.");
182 : }
183 37902 : return quantum_number_m;
184 : }
185 :
186 : template <typename Derived>
187 43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
188 43 : Parity parity = state_index_to_parity.at(state_index);
189 43 : if (parity == Parity::UNKNOWN) {
190 0 : throw std::invalid_argument("The state does not have a well-defined parity.");
191 : }
192 43 : return parity;
193 : }
194 :
195 : template <typename Derived>
196 : std::shared_ptr<const typename Basis<Derived>::ket_t>
197 46 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
198 46 : size_t ket_index = state_index_to_ket_index.at(state_index);
199 46 : if (ket_index == std::numeric_limits<int>::max()) {
200 0 : throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
201 : }
202 46 : return kets[ket_index];
203 : }
204 :
205 : template <typename Derived>
206 : std::shared_ptr<const typename Basis<Derived>::ket_t>
207 0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
208 0 : throw std::runtime_error("Not implemented yet.");
209 : }
210 :
211 : template <typename Derived>
212 189 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
213 : // Create a copy of the current object
214 189 : auto restricted = std::make_shared<Derived>(derived());
215 :
216 : // Restrict the copy to the state with the largest overlap
217 189 : restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
218 :
219 189 : std::fill(restricted->ket_index_to_state_index.begin(),
220 189 : restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
221 189 : restricted->ket_index_to_state_index[state_index_to_ket_index[state_index]] = 0;
222 :
223 189 : restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
224 189 : restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
225 189 : restricted->state_index_to_parity = {state_index_to_parity[state_index]};
226 189 : restricted->state_index_to_ket_index = {state_index_to_ket_index[state_index]};
227 :
228 378 : restricted->_has_quantum_number_f =
229 189 : restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
230 378 : restricted->_has_quantum_number_m =
231 189 : restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
232 189 : restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
233 :
234 378 : return restricted;
235 189 : }
236 :
237 : template <typename Derived>
238 : std::shared_ptr<const typename Basis<Derived>::ket_t>
239 0 : Basis<Derived>::get_ket(size_t ket_index) const {
240 0 : return kets[ket_index];
241 : }
242 :
243 : template <typename Derived>
244 27 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
245 27 : size_t state_index = ket_index_to_state_index.at(ket_index);
246 27 : if (state_index == std::numeric_limits<int>::max()) {
247 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
248 : }
249 27 : return get_state(state_index);
250 : }
251 :
252 : template <typename Derived>
253 : std::shared_ptr<const Derived>
254 27 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
255 27 : int ket_index = get_ket_index_from_ket(ket);
256 27 : if (ket_index < 0) {
257 0 : throw std::invalid_argument("The ket does not belong to the basis.");
258 : }
259 27 : return get_corresponding_state(ket_index);
260 : }
261 :
262 : template <typename Derived>
263 92 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
264 92 : int state_index = ket_index_to_state_index.at(ket_index);
265 92 : if (state_index == std::numeric_limits<int>::max()) {
266 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
267 : }
268 92 : return state_index;
269 : }
270 :
271 : template <typename Derived>
272 92 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
273 92 : int ket_index = get_ket_index_from_ket(ket);
274 92 : if (ket_index < 0) {
275 0 : throw std::invalid_argument("The ket does not belong to the basis.");
276 : }
277 92 : return get_corresponding_state_index(ket_index);
278 : }
279 :
280 : template <typename Derived>
281 0 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
282 0 : int ket_index = state_index_to_ket_index.at(state_index);
283 0 : if (ket_index == std::numeric_limits<int>::max()) {
284 0 : throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
285 : }
286 0 : return ket_index;
287 : }
288 :
289 : template <typename Derived>
290 0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
291 0 : throw std::runtime_error("Not implemented yet.");
292 : }
293 :
294 : template <typename Derived>
295 : std::shared_ptr<const Derived>
296 98 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
297 : // Create a copy of the current object
298 98 : auto created = std::make_shared<Derived>(derived());
299 :
300 : // Fill the copy with the state corresponding to the ket index
301 98 : created->coefficients.matrix =
302 196 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
303 98 : created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
304 98 : created->coefficients.matrix.makeCompressed();
305 :
306 98 : std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
307 98 : std::numeric_limits<int>::max());
308 98 : created->ket_index_to_state_index[ket_index] = 0;
309 :
310 98 : created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
311 98 : created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
312 98 : created->state_index_to_parity = {kets[ket_index]->get_parity()};
313 98 : created->state_index_to_ket_index = {ket_index};
314 :
315 196 : created->_has_quantum_number_f =
316 98 : created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
317 196 : created->_has_quantum_number_m =
318 98 : created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
319 98 : created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
320 :
321 196 : return created;
322 98 : }
323 :
324 : template <typename Derived>
325 : std::shared_ptr<const Derived>
326 98 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
327 98 : int ket_index = get_ket_index_from_ket(ket);
328 98 : if (ket_index < 0) {
329 0 : throw std::invalid_argument("The ket does not belong to the basis.");
330 : }
331 98 : return get_canonical_state_from_ket(ket_index);
332 : }
333 :
334 : template <typename Derived>
335 4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
336 4 : return kets.begin();
337 : }
338 :
339 : template <typename Derived>
340 4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
341 4 : return kets.end();
342 : }
343 :
344 : template <typename Derived>
345 8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
346 :
347 : template <typename Derived>
348 80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
349 80 : return other.it != it;
350 : }
351 :
352 : template <typename Derived>
353 76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
354 76 : return *it;
355 : }
356 :
357 : template <typename Derived>
358 76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
359 76 : ++it;
360 76 : return *this;
361 : }
362 :
363 : template <typename Derived>
364 3749 : size_t Basis<Derived>::get_number_of_states() const {
365 3749 : return coefficients.matrix.cols();
366 : }
367 :
368 : template <typename Derived>
369 1695 : size_t Basis<Derived>::get_number_of_kets() const {
370 1695 : return coefficients.matrix.rows();
371 : }
372 :
373 : template <typename Derived>
374 : const Transformation<typename Basis<Derived>::scalar_t> &
375 113 : Basis<Derived>::get_transformation() const {
376 113 : return coefficients;
377 : }
378 :
379 : template <typename Derived>
380 : Transformation<typename Basis<Derived>::scalar_t>
381 0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
382 0 : Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
383 : static_cast<Eigen::Index>(coefficients.matrix.rows())},
384 : {TransformationType::ROTATE}};
385 :
386 0 : std::vector<Eigen::Triplet<scalar_t>> entries;
387 :
388 0 : for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
389 0 : real_t f = kets[idx_initial]->get_quantum_number_f();
390 0 : real_t m_initial = kets[idx_initial]->get_quantum_number_m();
391 :
392 0 : assert(2 * f == std::floor(2 * f) && f >= 0);
393 0 : assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
394 :
395 0 : for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
396 : ++m_final) {
397 0 : auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
398 : beta, gamma);
399 0 : size_t idx_final = get_ket_index_from_ket(
400 0 : kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
401 0 : entries.emplace_back(idx_final, idx_initial, val);
402 : }
403 : }
404 :
405 0 : transformation.matrix.setFromTriplets(entries.begin(), entries.end());
406 0 : transformation.matrix.makeCompressed();
407 :
408 0 : return transformation;
409 0 : }
410 :
411 : template <typename Derived>
412 1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
413 1 : perform_sorter_checks(labels);
414 :
415 : // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
416 1 : if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
417 2 : labels.end()) {
418 0 : throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
419 : "the energy. Use an energy operator instead.");
420 : }
421 :
422 : // Initialize transformation
423 1 : Sorting transformation;
424 1 : transformation.matrix.resize(coefficients.matrix.cols());
425 1 : transformation.matrix.setIdentity();
426 :
427 : // Get the sorter
428 1 : get_sorter_without_checks(labels, transformation);
429 :
430 : // Check if all labels have been used for sorting
431 1 : if (labels != transformation.transformation_type) {
432 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels.");
433 : }
434 :
435 1 : return transformation;
436 0 : }
437 :
438 : template <typename Derived>
439 : std::vector<IndicesOfBlock>
440 1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
441 1 : perform_sorter_checks(labels);
442 :
443 1 : std::set<TransformationType> unique_labels(labels.begin(), labels.end());
444 1 : perform_blocks_checks(unique_labels);
445 :
446 : // Get the blocks
447 1 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
448 1 : get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
449 :
450 2 : return blocks_creator.create();
451 1 : }
452 :
453 : template <typename Derived>
454 72 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
455 : Sorting &transformation) const {
456 72 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
457 :
458 72 : int *perm_begin = transformation.matrix.indices().data();
459 72 : int *perm_end = perm_begin + coefficients.matrix.cols();
460 72 : const int *perm_back = perm_end - 1;
461 :
462 : // Sort the vector based on the requested labels
463 113960 : std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
464 42444 : for (const auto &label : labels) {
465 30209 : switch (label) {
466 5067 : case TransformationType::SORT_BY_PARITY:
467 5067 : if (state_index_to_parity[a] != state_index_to_parity[b]) {
468 13745 : return state_index_to_parity[a] < state_index_to_parity[b];
469 : }
470 3209 : break;
471 25142 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
472 50284 : if (std::abs(state_index_to_quantum_number_m[a] -
473 50284 : state_index_to_quantum_number_m[b]) > numerical_precision) {
474 11887 : return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
475 : }
476 13255 : break;
477 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
478 0 : if (std::abs(state_index_to_quantum_number_f[a] -
479 0 : state_index_to_quantum_number_f[b]) > numerical_precision) {
480 0 : return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
481 : }
482 0 : break;
483 0 : case TransformationType::SORT_BY_KET:
484 0 : if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
485 0 : return state_index_to_ket_index[a] < state_index_to_ket_index[b];
486 : }
487 0 : break;
488 0 : default:
489 0 : std::abort(); // Can't happen because of previous checks
490 : }
491 : }
492 12235 : return false; // Elements are equal
493 : });
494 :
495 : // Check for invalid values and add transformation types
496 155 : for (const auto &label : labels) {
497 83 : switch (label) {
498 13 : case TransformationType::SORT_BY_PARITY:
499 13 : if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
500 0 : throw std::invalid_argument(
501 : "States cannot be labeled and thus not sorted by the parity.");
502 : }
503 13 : transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
504 13 : break;
505 70 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
506 70 : if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
507 0 : throw std::invalid_argument(
508 : "States cannot be labeled and thus not sorted by the quantum number m.");
509 : }
510 70 : transformation.transformation_type.push_back(
511 70 : TransformationType::SORT_BY_QUANTUM_NUMBER_M);
512 70 : break;
513 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
514 0 : if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
515 0 : throw std::invalid_argument(
516 : "States cannot be labeled and thus not sorted by the quantum number f.");
517 : }
518 0 : transformation.transformation_type.push_back(
519 0 : TransformationType::SORT_BY_QUANTUM_NUMBER_F);
520 0 : break;
521 0 : case TransformationType::SORT_BY_KET:
522 0 : if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
523 0 : throw std::invalid_argument(
524 : "States cannot be labeled and thus not sorted by kets.");
525 : }
526 0 : transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
527 0 : break;
528 0 : default:
529 0 : std::abort(); // Can't happen because of previous checks
530 : }
531 : }
532 72 : }
533 :
534 : template <typename Derived>
535 72 : void Basis<Derived>::get_indices_of_blocks_without_checks(
536 : const std::set<TransformationType> &unique_labels,
537 : IndicesOfBlocksCreator &blocks_creator) const {
538 72 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
539 :
540 72 : auto last_quantum_number_f = state_index_to_quantum_number_f[0];
541 72 : auto last_quantum_number_m = state_index_to_quantum_number_m[0];
542 72 : auto last_parity = state_index_to_parity[0];
543 72 : auto last_ket = state_index_to_ket_index[0];
544 :
545 4916 : for (int i = 0; i < coefficients.matrix.cols(); ++i) {
546 10673 : for (auto label : unique_labels) {
547 6061 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F) {
548 0 : if (std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
549 : numerical_precision) {
550 0 : blocks_creator.add(i);
551 0 : break;
552 : }
553 6061 : } else if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M) {
554 4664 : if (std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
555 : numerical_precision) {
556 186 : blocks_creator.add(i);
557 186 : break;
558 : }
559 1397 : } else if (label == TransformationType::SORT_BY_PARITY) {
560 1397 : if (state_index_to_parity[i] != last_parity) {
561 46 : blocks_creator.add(i);
562 46 : break;
563 : }
564 0 : } else if (label == TransformationType::SORT_BY_KET) {
565 0 : if (state_index_to_ket_index[i] != last_ket) {
566 0 : blocks_creator.add(i);
567 0 : break;
568 : }
569 : }
570 : }
571 4844 : last_quantum_number_f = state_index_to_quantum_number_f[i];
572 4844 : last_quantum_number_m = state_index_to_quantum_number_m[i];
573 4844 : last_parity = state_index_to_parity[i];
574 4844 : last_ket = state_index_to_ket_index[i];
575 : }
576 72 : }
577 :
578 : template <typename Derived>
579 184 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
580 : // Create a copy of the current object
581 184 : auto transformed = std::make_shared<Derived>(derived());
582 :
583 184 : if (coefficients.matrix.cols() == 0) {
584 0 : return transformed;
585 : }
586 :
587 : // Apply the transformation
588 184 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
589 184 : transformed->coefficients.transformation_type = transformation.transformation_type;
590 :
591 184 : transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
592 184 : transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
593 184 : transformed->state_index_to_parity.resize(transformation.matrix.size());
594 184 : transformed->state_index_to_ket_index.resize(transformation.matrix.size());
595 :
596 14110 : for (int i = 0; i < transformation.matrix.size(); ++i) {
597 13926 : transformed->state_index_to_quantum_number_f[i] =
598 13926 : state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
599 13926 : transformed->state_index_to_quantum_number_m[i] =
600 13926 : state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
601 13926 : transformed->state_index_to_parity[i] =
602 13926 : state_index_to_parity[transformation.matrix.indices()[i]];
603 13926 : transformed->state_index_to_ket_index[i] =
604 13926 : state_index_to_ket_index[transformation.matrix.indices()[i]];
605 13926 : transformed->ket_index_to_state_index
606 13926 : [state_index_to_ket_index[transformation.matrix.indices()[i]]] = i;
607 : }
608 :
609 184 : return transformed;
610 184 : }
611 :
612 : template <typename Derived>
613 : std::shared_ptr<const Derived>
614 112 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
615 : // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
616 : // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
617 112 : real_t numerical_precision = 0.001;
618 :
619 : // If the transformation is a rotation, it should be a rotation and nothing else
620 112 : bool is_rotation = false;
621 224 : for (auto t : transformation.transformation_type) {
622 112 : if (t == TransformationType::ROTATE) {
623 0 : is_rotation = true;
624 0 : break;
625 : }
626 : }
627 112 : if (is_rotation && transformation.transformation_type.size() != 1) {
628 0 : throw std::invalid_argument("A rotation can not be combined with other transformations.");
629 : }
630 :
631 : // To apply a rotation, the object must only be sorted but other transformations are not allowed
632 112 : if (is_rotation) {
633 0 : for (auto t : coefficients.transformation_type) {
634 0 : if (!utils::is_sorting(t)) {
635 0 : throw std::runtime_error(
636 : "If the object was transformed by a different transformation "
637 : "than sorting, it can not be rotated.");
638 : }
639 : }
640 : }
641 :
642 : // Create a copy of the current object
643 112 : auto transformed = std::make_shared<Derived>(derived());
644 :
645 112 : if (coefficients.matrix.cols() == 0) {
646 0 : return transformed;
647 : }
648 :
649 : // Apply the transformation
650 : // If a quantum number turns out to be conserved by the transformation, it will be
651 : // rounded to the nearest half integer to avoid loss of numerical_precision.
652 112 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
653 112 : transformed->coefficients.transformation_type = transformation.transformation_type;
654 :
655 112 : Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
656 :
657 : {
658 224 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
659 112 : state_index_to_quantum_number_f.size());
660 112 : Eigen::VectorX<real_t> val = probs * map;
661 112 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
662 112 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
663 112 : transformed->state_index_to_quantum_number_f.resize(probs.rows());
664 :
665 6459 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
666 6347 : if (diff[i] < numerical_precision) {
667 1672 : transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
668 : } else {
669 4675 : transformed->state_index_to_quantum_number_f[i] =
670 4675 : std::numeric_limits<real_t>::max();
671 4675 : transformed->_has_quantum_number_f = false;
672 : }
673 : }
674 112 : }
675 :
676 : {
677 224 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
678 112 : state_index_to_quantum_number_m.size());
679 112 : Eigen::VectorX<real_t> val = probs * map;
680 112 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
681 112 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
682 112 : transformed->state_index_to_quantum_number_m.resize(probs.rows());
683 :
684 6459 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
685 6347 : if (diff[i] < numerical_precision) {
686 4413 : transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
687 : } else {
688 1934 : transformed->state_index_to_quantum_number_m[i] =
689 1934 : std::numeric_limits<real_t>::max();
690 1934 : transformed->_has_quantum_number_m = false;
691 : }
692 : }
693 112 : }
694 :
695 : {
696 : using utype = std::underlying_type<Parity>::type;
697 112 : Eigen::VectorX<real_t> map(state_index_to_parity.size());
698 6980 : for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
699 6868 : map[i] = static_cast<utype>(state_index_to_parity[i]);
700 : }
701 112 : Eigen::VectorX<real_t> val = probs * map;
702 112 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
703 112 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
704 112 : transformed->state_index_to_parity.resize(probs.rows());
705 :
706 6459 : for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
707 6347 : if (diff[i] < numerical_precision) {
708 3556 : transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
709 : } else {
710 2791 : transformed->state_index_to_parity[i] = Parity::UNKNOWN;
711 2791 : transformed->_has_parity = false;
712 : }
713 : }
714 112 : }
715 :
716 : {
717 : // In the following, we obtain a bijective map between state index and ket index.
718 :
719 : // Find the maximum value in each row and column
720 112 : std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
721 112 : std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
722 6980 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
723 6868 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
724 6868 : transformed->coefficients.matrix, row);
725 201060 : it; ++it) {
726 194192 : real_t val = std::pow(std::abs(it.value()), 2);
727 194192 : max_in_row[row] = std::max(max_in_row[row], val);
728 194192 : max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
729 : }
730 : }
731 :
732 : // Use the maximum values to define a cost for a sub-optimal mapping
733 112 : std::vector<real_t> costs;
734 112 : std::vector<std::pair<int, int>> mappings;
735 112 : costs.reserve(transformed->coefficients.matrix.nonZeros());
736 112 : mappings.reserve(transformed->coefficients.matrix.nonZeros());
737 6980 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
738 6868 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
739 6868 : transformed->coefficients.matrix, row);
740 201060 : it; ++it) {
741 194192 : real_t val = std::pow(std::abs(it.value()), 2);
742 194192 : real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
743 194192 : costs.push_back(cost);
744 194192 : mappings.push_back({row, it.col()});
745 : }
746 : }
747 :
748 : // Obtain from the costs in which order the mappings should be considered
749 112 : std::vector<size_t> order(costs.size());
750 112 : std::iota(order.begin(), order.end(), 0);
751 112 : std::sort(order.begin(), order.end(),
752 2905733 : [&](size_t a, size_t b) { return costs[a] < costs[b]; });
753 :
754 : // Fill ket_index_to_state_index with invalid values as there can be more kets than states
755 112 : std::fill(transformed->ket_index_to_state_index.begin(),
756 112 : transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
757 :
758 : // Generate the bijective map
759 112 : std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
760 112 : std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
761 112 : int num_used = 0;
762 14677 : for (size_t idx : order) {
763 14675 : int row = mappings[idx].first; // corresponds to the ket index
764 14675 : int col = mappings[idx].second; // corresponds to the state index
765 14675 : if (!row_used[row] && !col_used[col]) {
766 6347 : row_used[row] = true;
767 6347 : col_used[col] = true;
768 6347 : num_used++;
769 6347 : transformed->state_index_to_ket_index[col] = row;
770 6347 : transformed->ket_index_to_state_index[row] = col;
771 : }
772 14675 : if (num_used == transformed->coefficients.matrix.cols()) {
773 110 : break;
774 : }
775 : }
776 112 : assert(num_used == transformed->coefficients.matrix.cols());
777 112 : }
778 :
779 112 : return transformed;
780 112 : }
781 :
782 : template <typename Derived>
783 24124 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
784 24124 : return typename ket_t::hash()(*k);
785 : }
786 :
787 : template <typename Derived>
788 478 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
789 : const std::shared_ptr<const ket_t> &rhs) const {
790 478 : return *lhs == *rhs;
791 : }
792 :
// Explicit instantiations: the member definitions of the CRTP base Basis<Derived> live in this
// file, so they are emitted here for the atom and pair bases with real (double) and complex
// (std::complex<double>) scalars.
template class Basis<BasisAtom<double>>;
template class Basis<BasisAtom<std::complex<double>>>;
template class Basis<BasisPair<double>>;
template class Basis<BasisPair<std::complex<double>>>;
798 : } // namespace pairinteraction
|