Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/basis/Basis.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/enums/Parity.hpp"
9 : #include "pairinteraction/enums/TransformationType.hpp"
10 : #include "pairinteraction/ket/KetAtom.hpp"
11 : #include "pairinteraction/ket/KetPair.hpp"
12 : #include "pairinteraction/utils/TaskControl.hpp"
13 : #include "pairinteraction/utils/eigen_assertion.hpp"
14 : #include "pairinteraction/utils/eigen_compat.hpp"
15 : #include "pairinteraction/utils/wigner.hpp"
16 :
17 : #include <cassert>
18 : #include <numeric>
19 : #include <set>
20 : #include <spdlog/spdlog.h>
21 :
22 : namespace pairinteraction {
23 :
24 : template <typename Scalar>
25 : class BasisAtom;
26 :
27 : template <typename Derived>
28 2761 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
29 : // Check if the labels are valid sorting labels
30 5549 : for (const auto &label : labels) {
31 2788 : if (!utils::is_sorting(label)) {
32 0 : throw std::invalid_argument("One of the labels is not a valid sorting label.");
33 : }
34 : }
35 2761 : }
36 :
37 : template <typename Derived>
38 630 : void Basis<Derived>::perform_blocks_checks(
39 : const std::set<TransformationType> &unique_labels) const {
40 : // Check if the states are sorted by the requested labels
41 630 : std::set<TransformationType> unique_labels_present;
42 1248 : for (const auto &label : get_transformation().transformation_type) {
43 669 : if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
44 51 : break;
45 : }
46 618 : unique_labels_present.insert(label);
47 : }
48 630 : if (unique_labels != unique_labels_present) {
49 0 : throw std::invalid_argument("The states are not sorted by the requested labels.");
50 : }
51 :
52 : // Throw a meaningful error if getting the blocks by energy is requested as this might be a
53 : // common mistake
54 630 : if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
55 0 : throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
56 : "be obtained. Use an energy operator instead.");
57 : }
58 630 : }
59 :
60 : template <typename Derived>
61 1508 : Basis<Derived>::Basis(ketvec_t &&kets)
62 3016 : : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
63 1508 : static_cast<Eigen::Index>(this->kets.size())},
64 3016 : {TransformationType::SORT_BY_KET}} {
65 1508 : if (this->kets.empty()) {
66 0 : throw std::invalid_argument("The basis must contain at least one element.");
67 : }
68 1508 : state_index_to_quantum_number_f.reserve(this->kets.size());
69 1508 : state_index_to_quantum_number_m.reserve(this->kets.size());
70 1508 : state_index_to_parity.reserve(this->kets.size());
71 1508 : ket_to_ket_index.reserve(this->kets.size());
72 1508 : size_t index = 0;
73 2874399 : for (const auto &ket : this->kets) {
74 2872891 : real_t f = std::numeric_limits<real_t>::max();
75 2872891 : real_t m = std::numeric_limits<real_t>::max();
76 2872891 : Parity p = Parity::UNKNOWN;
77 : // TODO: this is a workaround, and should be fixed, once we restructure the quantum number
78 : // handling of the Basis class
79 : if constexpr (requires { ket->get_quantum_number(std::string{}); }) {
80 39800 : f = ket->get_quantum_number("f");
81 39800 : m = ket->get_quantum_number("m");
82 39800 : p = static_cast<Parity>(static_cast<int>(ket->get_quantum_number("parity")));
83 : } else {
84 2833091 : m = ket->get_quantum_number_m();
85 : }
86 2872891 : state_index_to_quantum_number_f.push_back(f);
87 2872891 : state_index_to_quantum_number_m.push_back(m);
88 2872891 : state_index_to_parity.push_back(p);
89 2872891 : ket_to_ket_index[ket] = index++;
90 2872891 : if (f == std::numeric_limits<real_t>::max()) {
91 2833091 : _has_quantum_number_f = false;
92 : }
93 2872891 : if (m == std::numeric_limits<real_t>::max()) {
94 0 : _has_quantum_number_m = false;
95 : }
96 2872891 : if (p == Parity::UNKNOWN) {
97 2833091 : _has_parity = false;
98 : }
99 : }
100 1508 : state_index_to_ket_index.resize(this->kets.size());
101 1508 : std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
102 1508 : ket_index_to_state_index.resize(this->kets.size());
103 1508 : std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
104 1508 : coefficients.matrix.setIdentity();
105 1508 : }
106 :
107 : template <typename Derived>
108 1229 : bool Basis<Derived>::has_quantum_number_f() const {
109 1229 : return _has_quantum_number_f;
110 : }
111 :
112 : template <typename Derived>
113 5802053 : bool Basis<Derived>::has_quantum_number_m() const {
114 5802053 : return _has_quantum_number_m;
115 : }
116 :
117 : template <typename Derived>
118 1229 : bool Basis<Derived>::has_parity() const {
119 1229 : return _has_parity;
120 : }
121 :
122 : template <typename Derived>
123 3666 : const Derived &Basis<Derived>::derived() const {
124 3666 : return static_cast<const Derived &>(*this);
125 : }
126 :
127 : template <typename Derived>
128 182 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
129 182 : return kets;
130 : }
131 :
132 : template <typename Derived>
133 : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
134 12378 : Basis<Derived>::get_coefficients() const {
135 12378 : return coefficients.matrix;
136 : }
137 :
138 : template <typename Derived>
139 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
140 0 : Basis<Derived>::get_coefficients() {
141 0 : return coefficients.matrix;
142 : }
143 :
144 : template <typename Derived>
145 5 : void Basis<Derived>::set_coefficients(
146 : const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
147 5 : if (values.rows() != coefficients.matrix.rows()) {
148 0 : throw std::invalid_argument("Incompatible number of rows.");
149 : }
150 5 : if (values.cols() != coefficients.matrix.cols()) {
151 0 : throw std::invalid_argument("Incompatible number of columns.");
152 : }
153 :
154 5 : coefficients.matrix = values;
155 :
156 5 : std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
157 5 : std::numeric_limits<int>::max());
158 5 : std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
159 5 : std::numeric_limits<int>::max());
160 5 : std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
161 5 : std::numeric_limits<real_t>::max());
162 5 : std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
163 5 : std::numeric_limits<real_t>::max());
164 5 : std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
165 5 : _has_quantum_number_f = false;
166 5 : _has_quantum_number_m = false;
167 5 : _has_parity = false;
168 5 : }
169 :
170 : template <typename Derived>
171 252 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
172 252 : if (!ket_to_ket_index.contains(ket)) {
173 0 : return -1;
174 : }
175 252 : return ket_to_ket_index.at(ket);
176 : }
177 :
178 : template <typename Derived>
179 0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
180 0 : real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
181 0 : if (quantum_number_f == std::numeric_limits<real_t>::max()) {
182 0 : throw std::invalid_argument("The state does not have a well-defined quantum number f.");
183 : }
184 0 : return quantum_number_f;
185 : }
186 :
187 : template <typename Derived>
188 5800866 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
189 5800866 : real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
190 5800866 : if (quantum_number_m == std::numeric_limits<real_t>::max()) {
191 0 : throw std::invalid_argument("The state does not have a well-defined quantum number m.");
192 : }
193 5800866 : return quantum_number_m;
194 : }
195 :
196 : template <typename Derived>
197 32795 : Parity Basis<Derived>::get_parity(size_t state_index) const {
198 32795 : Parity parity = state_index_to_parity.at(state_index);
199 32795 : if (parity == Parity::UNKNOWN) {
200 0 : throw std::invalid_argument("The state does not have a well-defined parity.");
201 : }
202 32795 : return parity;
203 : }
204 :
205 : template <typename Derived>
206 : std::shared_ptr<const typename Basis<Derived>::ket_t>
207 1100 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
208 1100 : size_t ket_index = state_index_to_ket_index.at(state_index);
209 1100 : if (ket_index == std::numeric_limits<int>::max()) {
210 0 : throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
211 : }
212 1100 : return kets[ket_index];
213 : }
214 :
215 : template <typename Derived>
216 : std::shared_ptr<const typename Basis<Derived>::ket_t>
217 0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
218 0 : throw std::runtime_error("Not implemented yet.");
219 : }
220 :
221 : template <typename Derived>
222 824 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
223 : // Create a copy of the current object
224 824 : auto restricted = std::make_shared<Derived>(derived());
225 :
226 : // Restrict the copy to the state with the largest overlap
227 824 : restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
228 :
229 824 : std::fill(restricted->ket_index_to_state_index.begin(),
230 824 : restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
231 :
232 824 : size_t ket_index = state_index_to_ket_index[state_index];
233 824 : restricted->state_index_to_ket_index = {ket_index};
234 824 : if (ket_index != std::numeric_limits<int>::max()) {
235 824 : restricted->ket_index_to_state_index[ket_index] = 0;
236 : }
237 :
238 824 : restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
239 824 : restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
240 824 : restricted->state_index_to_parity = {state_index_to_parity[state_index]};
241 :
242 1648 : restricted->_has_quantum_number_f =
243 824 : restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
244 1648 : restricted->_has_quantum_number_m =
245 824 : restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
246 824 : restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
247 :
248 1648 : return restricted;
249 824 : }
250 :
251 : template <typename Derived>
252 : std::shared_ptr<const typename Basis<Derived>::ket_t>
253 187616 : Basis<Derived>::get_ket(size_t ket_index) const {
254 187616 : return kets[ket_index];
255 : }
256 :
257 : template <typename Derived>
258 2 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
259 2 : size_t state_index = ket_index_to_state_index.at(ket_index);
260 2 : if (state_index == std::numeric_limits<int>::max()) {
261 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
262 : }
263 2 : return get_state(state_index);
264 : }
265 :
266 : template <typename Derived>
267 : std::shared_ptr<const Derived>
268 2 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
269 2 : int ket_index = get_ket_index_from_ket(ket);
270 2 : if (ket_index < 0) {
271 0 : throw std::invalid_argument("The ket does not belong to the basis.");
272 : }
273 2 : return get_corresponding_state(ket_index);
274 : }
275 :
276 : template <typename Derived>
277 250 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
278 250 : int state_index = ket_index_to_state_index.at(ket_index);
279 250 : if (state_index == std::numeric_limits<int>::max()) {
280 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
281 : }
282 250 : return state_index;
283 : }
284 :
285 : template <typename Derived>
286 250 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
287 250 : int ket_index = get_ket_index_from_ket(ket);
288 250 : if (ket_index < 0) {
289 0 : throw std::invalid_argument("The ket does not belong to the basis.");
290 : }
291 250 : return get_corresponding_state_index(ket_index);
292 : }
293 :
294 : template <typename Derived>
295 236 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
296 236 : size_t ket_index = state_index_to_ket_index.at(state_index);
297 236 : if (ket_index == std::numeric_limits<int>::max()) {
298 0 : throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
299 : }
300 236 : return ket_index;
301 : }
302 :
303 : template <typename Derived>
304 0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
305 0 : throw std::runtime_error("Not implemented yet.");
306 : }
307 :
308 : template <typename Derived>
309 5 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
310 5 : return kets.begin();
311 : }
312 :
313 : template <typename Derived>
314 5 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
315 5 : return kets.end();
316 : }
317 :
318 : template <typename Derived>
319 10 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
320 :
321 : template <typename Derived>
322 225 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
323 225 : return other.it != it;
324 : }
325 :
326 : template <typename Derived>
327 220 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
328 220 : return *it;
329 : }
330 :
331 : template <typename Derived>
332 220 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
333 220 : ++it;
334 220 : return *this;
335 : }
336 :
337 : template <typename Derived>
338 5800 : size_t Basis<Derived>::get_number_of_states() const {
339 5800 : return coefficients.matrix.cols();
340 : }
341 :
342 : template <typename Derived>
343 47023 : size_t Basis<Derived>::get_number_of_kets() const {
344 47023 : return coefficients.matrix.rows();
345 : }
346 :
347 : template <typename Derived>
348 : const Transformation<typename Basis<Derived>::scalar_t> &
349 631 : Basis<Derived>::get_transformation() const {
350 631 : return coefficients;
351 : }
352 :
353 : template <typename Derived>
354 : Transformation<typename Basis<Derived>::scalar_t>
355 0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
356 0 : Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
357 : static_cast<Eigen::Index>(coefficients.matrix.rows())},
358 : {TransformationType::ROTATE}};
359 :
360 0 : std::vector<Eigen::Triplet<scalar_t>> entries;
361 :
362 0 : set_task_status("Constructing basis rotation...");
363 0 : for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
364 0 : real_t f = std::numeric_limits<real_t>::max();
365 0 : real_t m_initial = std::numeric_limits<real_t>::max();
366 : // TODO: this is a workaround, and should be fixed, once we restructure the quantum number
367 : // handling of the Basis class
368 : if constexpr (requires { kets[idx_initial]->get_quantum_number(std::string{}); }) {
369 0 : f = kets[idx_initial]->get_quantum_number("f");
370 0 : m_initial = kets[idx_initial]->get_quantum_number("m");
371 : } else {
372 0 : m_initial = kets[idx_initial]->get_quantum_number_m();
373 : }
374 :
375 0 : assert(2 * f == std::floor(2 * f) && f >= 0);
376 0 : assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
377 :
378 0 : for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
379 : ++m_final) {
380 0 : auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
381 : beta, gamma);
382 0 : size_t idx_final = get_ket_index_from_ket(
383 0 : kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
384 0 : entries.emplace_back(idx_final, idx_initial, val);
385 : }
386 : }
387 :
388 0 : transformation.matrix.setFromTriplets(entries.begin(), entries.end());
389 0 : transformation.matrix.makeCompressed();
390 :
391 0 : return transformation;
392 0 : }
393 :
394 : template <typename Derived>
395 1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
396 1 : perform_sorter_checks(labels);
397 :
398 : // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
399 1 : if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
400 2 : labels.end()) {
401 0 : throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
402 : "the energy. Use an energy operator instead.");
403 : }
404 :
405 : // Initialize transformation
406 1 : Sorting transformation;
407 1 : transformation.matrix.resize(coefficients.matrix.cols());
408 1 : transformation.matrix.setIdentity();
409 :
410 : // Get the sorter
411 1 : get_sorter_without_checks(labels, transformation);
412 :
413 : // Check if all labels have been used for sorting
414 1 : if (labels != transformation.transformation_type) {
415 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels.");
416 : }
417 :
418 1 : return transformation;
419 0 : }
420 :
421 : template <typename Derived>
422 : std::vector<IndicesOfBlock>
423 1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
424 1 : perform_sorter_checks(labels);
425 :
426 1 : std::set<TransformationType> unique_labels(labels.begin(), labels.end());
427 1 : perform_blocks_checks(unique_labels);
428 :
429 : // Get the blocks
430 1 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
431 1 : get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
432 :
433 2 : return blocks_creator.create();
434 1 : }
435 :
436 : template <typename Derived>
437 579 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
438 : Sorting &transformation) const {
439 579 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
440 :
441 579 : int *perm_begin = transformation.matrix.indices().data();
442 579 : int *perm_end = perm_begin + coefficients.matrix.cols();
443 579 : const int *perm_back = perm_end - 1;
444 :
445 : // Sort the vector based on the requested labels
446 579 : set_task_status("Sorting basis states...");
447 647265 : std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
448 284277 : for (const auto &label : labels) {
449 186405 : switch (label) {
450 30703 : case TransformationType::SORT_BY_PARITY:
451 30703 : if (state_index_to_parity[a] != state_index_to_parity[b]) {
452 58668 : return state_index_to_parity[a] < state_index_to_parity[b];
453 : }
454 20305 : break;
455 155702 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
456 311404 : if (std::abs(state_index_to_quantum_number_m[a] -
457 311404 : state_index_to_quantum_number_m[b]) > numerical_precision) {
458 48270 : return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
459 : }
460 107432 : break;
461 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
462 0 : if (std::abs(state_index_to_quantum_number_f[a] -
463 0 : state_index_to_quantum_number_f[b]) > numerical_precision) {
464 0 : return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
465 : }
466 0 : break;
467 0 : case TransformationType::SORT_BY_KET:
468 0 : if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
469 0 : return state_index_to_ket_index[a] < state_index_to_ket_index[b];
470 : }
471 0 : break;
472 0 : default:
473 0 : std::abort(); // Can't happen because of previous checks
474 : }
475 : }
476 97872 : return false; // Elements are equal
477 : });
478 :
479 : // Check for invalid values and add transformation types
480 1197 : for (const auto &label : labels) {
481 618 : switch (label) {
482 41 : case TransformationType::SORT_BY_PARITY:
483 41 : if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
484 0 : throw std::invalid_argument(
485 : "States cannot be labeled and thus not sorted by the parity.");
486 : }
487 41 : transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
488 41 : break;
489 577 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
490 577 : if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
491 0 : throw std::invalid_argument(
492 : "States cannot be labeled and thus not sorted by the quantum number m.");
493 : }
494 577 : transformation.transformation_type.push_back(
495 577 : TransformationType::SORT_BY_QUANTUM_NUMBER_M);
496 577 : break;
497 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
498 0 : if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
499 0 : throw std::invalid_argument(
500 : "States cannot be labeled and thus not sorted by the quantum number f.");
501 : }
502 0 : transformation.transformation_type.push_back(
503 0 : TransformationType::SORT_BY_QUANTUM_NUMBER_F);
504 0 : break;
505 0 : case TransformationType::SORT_BY_KET:
506 0 : if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
507 0 : throw std::invalid_argument(
508 : "States cannot be labeled and thus not sorted by kets.");
509 : }
510 0 : transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
511 0 : break;
512 0 : default:
513 0 : std::abort(); // Can't happen because of previous checks
514 : }
515 : }
516 579 : }
517 :
518 : template <typename Derived>
519 579 : void Basis<Derived>::get_indices_of_blocks_without_checks(
520 : const std::set<TransformationType> &unique_labels,
521 : IndicesOfBlocksCreator &blocks_creator) const {
522 579 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
523 :
524 579 : auto last_quantum_number_f = state_index_to_quantum_number_f[0];
525 579 : auto last_quantum_number_m = state_index_to_quantum_number_m[0];
526 579 : auto last_parity = state_index_to_parity[0];
527 579 : auto last_ket = state_index_to_ket_index[0];
528 :
529 579 : set_task_status("Identifying basis blocks...");
530 33557 : for (int i = 0; i < coefficients.matrix.cols(); ++i) {
531 72251 : for (auto label : unique_labels) {
532 39987 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
533 0 : std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
534 : numerical_precision) {
535 0 : blocks_creator.add(i);
536 0 : break;
537 : }
538 72785 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
539 32798 : std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
540 : numerical_precision) {
541 570 : blocks_creator.add(i);
542 570 : break;
543 : }
544 46606 : if (label == TransformationType::SORT_BY_PARITY &&
545 7189 : state_index_to_parity[i] != last_parity) {
546 144 : blocks_creator.add(i);
547 144 : break;
548 : }
549 39273 : if (label == TransformationType::SORT_BY_KET &&
550 39273 : state_index_to_ket_index[i] != last_ket &&
551 0 : last_ket != std::numeric_limits<int>::max()) {
552 0 : blocks_creator.add(i);
553 0 : break;
554 : }
555 : }
556 32978 : last_quantum_number_f = state_index_to_quantum_number_f[i];
557 32978 : last_quantum_number_m = state_index_to_quantum_number_m[i];
558 32978 : last_parity = state_index_to_parity[i];
559 32978 : last_ket = state_index_to_ket_index[i];
560 : }
561 579 : }
562 :
563 : template <typename Derived>
564 64 : std::shared_ptr<const Derived> Basis<Derived>::canonicalized() const {
565 64 : auto result = std::make_shared<Derived>(derived());
566 :
567 64 : size_t n = kets.size();
568 :
569 64 : result->coefficients.matrix.resize(n, n);
570 64 : result->coefficients.matrix.setIdentity();
571 64 : result->coefficients.transformation_type = {TransformationType::SORT_BY_KET};
572 :
573 64 : result->state_index_to_ket_index.resize(n);
574 64 : std::iota(result->state_index_to_ket_index.begin(), result->state_index_to_ket_index.end(), 0);
575 :
576 64 : result->ket_index_to_state_index.resize(n);
577 64 : std::iota(result->ket_index_to_state_index.begin(), result->ket_index_to_state_index.end(), 0);
578 :
579 64 : result->state_index_to_quantum_number_f.resize(n);
580 64 : result->state_index_to_quantum_number_m.resize(n);
581 64 : result->state_index_to_parity.resize(n);
582 64 : result->_has_quantum_number_f = true;
583 64 : result->_has_quantum_number_m = true;
584 64 : result->_has_parity = true;
585 :
586 13536 : for (size_t i = 0; i < n; ++i) {
587 13472 : real_t f = std::numeric_limits<real_t>::max();
588 13472 : real_t m = std::numeric_limits<real_t>::max();
589 13472 : Parity p = Parity::UNKNOWN;
590 : // TODO: this is a workaround, and should be fixed, once we restructure the quantum number
591 : // handling of the Basis class
592 : if constexpr (requires { kets[i]->get_quantum_number(std::string{}); }) {
593 13472 : f = kets[i]->get_quantum_number("f");
594 13472 : m = kets[i]->get_quantum_number("m");
595 13472 : p = static_cast<Parity>(static_cast<int>(kets[i]->get_quantum_number("parity")));
596 : } else {
597 0 : m = kets[i]->get_quantum_number_m();
598 : }
599 13472 : result->state_index_to_quantum_number_f[i] = f;
600 13472 : result->state_index_to_quantum_number_m[i] = m;
601 13472 : result->state_index_to_parity[i] = p;
602 13472 : if (f == std::numeric_limits<real_t>::max()) {
603 0 : result->_has_quantum_number_f = false;
604 : }
605 13472 : if (m == std::numeric_limits<real_t>::max()) {
606 0 : result->_has_quantum_number_m = false;
607 : }
608 13472 : if (p == Parity::UNKNOWN) {
609 0 : result->_has_parity = false;
610 : }
611 : }
612 :
613 128 : return result;
614 64 : }
615 :
616 : template <typename Derived>
617 2131 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
618 : // Create a copy of the current object
619 2131 : auto transformed = std::make_shared<Derived>(derived());
620 :
621 2131 : if (coefficients.matrix.cols() == 0) {
622 0 : return transformed;
623 : }
624 :
625 : // Apply the transformation
626 2131 : set_task_status("Applying basis sorting...");
627 2131 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
628 2131 : transformed->coefficients.transformation_type = transformation.transformation_type;
629 :
630 2131 : transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
631 2131 : transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
632 2131 : transformed->state_index_to_parity.resize(transformation.matrix.size());
633 2131 : transformed->state_index_to_ket_index.resize(transformation.matrix.size());
634 :
635 2131 : set_task_status("Relabeling sorted basis states...");
636 180549 : for (int i = 0; i < transformation.matrix.size(); ++i) {
637 178418 : transformed->state_index_to_quantum_number_f[i] =
638 178418 : state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
639 178418 : transformed->state_index_to_quantum_number_m[i] =
640 178418 : state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
641 178418 : transformed->state_index_to_parity[i] =
642 178418 : state_index_to_parity[transformation.matrix.indices()[i]];
643 :
644 178418 : size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
645 178418 : transformed->state_index_to_ket_index[i] = ket_index;
646 178418 : if (ket_index != std::numeric_limits<int>::max()) {
647 178418 : transformed->ket_index_to_state_index[ket_index] = i;
648 : }
649 : }
650 :
651 2131 : return transformed;
652 2131 : }
653 :
654 : template <typename Derived>
655 : std::shared_ptr<const Derived>
656 647 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
657 : // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
658 : // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
659 647 : real_t numerical_precision = 0.001;
660 :
661 : // If the transformation is a rotation, it should be a rotation and nothing else
662 647 : bool is_rotation = false;
663 1294 : for (auto t : transformation.transformation_type) {
664 647 : if (t == TransformationType::ROTATE) {
665 0 : is_rotation = true;
666 0 : break;
667 : }
668 : }
669 647 : if (is_rotation && transformation.transformation_type.size() != 1) {
670 0 : throw std::invalid_argument("A rotation can not be combined with other transformations.");
671 : }
672 :
673 : // To apply a rotation, the object must only be sorted but other transformations are not allowed
674 647 : if (is_rotation) {
675 0 : for (auto t : coefficients.transformation_type) {
676 0 : if (!utils::is_sorting(t)) {
677 0 : throw std::runtime_error(
678 : "If the object was transformed by a different transformation "
679 : "than sorting, it can not be rotated.");
680 : }
681 : }
682 : }
683 :
684 : // Create a copy of the current object
685 647 : auto transformed = std::make_shared<Derived>(derived());
686 :
687 647 : if (coefficients.matrix.cols() == 0) {
688 0 : return transformed;
689 : }
690 :
691 : // Apply the transformation
692 : // If a quantum number turns out to be conserved by the transformation, it will be
693 : // rounded to the nearest half integer to avoid loss of numerical_precision.
694 647 : set_task_status("Applying basis transformation...");
695 647 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
696 647 : transformed->coefficients.transformation_type = transformation.transformation_type;
697 :
698 647 : Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
699 :
700 647 : set_task_status("Updating transformed quantum numbers...");
701 : {
702 1294 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
703 647 : state_index_to_quantum_number_f.size());
704 647 : Eigen::VectorX<real_t> val = probs * map;
705 647 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
706 647 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
707 647 : transformed->state_index_to_quantum_number_f.resize(probs.rows());
708 :
709 64351 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
710 63704 : if (diff[i] < numerical_precision) {
711 5101 : transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
712 : } else {
713 58603 : transformed->state_index_to_quantum_number_f[i] =
714 58603 : std::numeric_limits<real_t>::max();
715 58603 : transformed->_has_quantum_number_f = false;
716 : }
717 : }
718 647 : }
719 :
720 : {
721 1294 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
722 647 : state_index_to_quantum_number_m.size());
723 647 : Eigen::VectorX<real_t> val = probs * map;
724 647 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
725 647 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
726 647 : transformed->state_index_to_quantum_number_m.resize(probs.rows());
727 :
728 64351 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
729 63704 : if (diff[i] < numerical_precision) {
730 61004 : transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
731 : } else {
732 2700 : transformed->state_index_to_quantum_number_m[i] =
733 2700 : std::numeric_limits<real_t>::max();
734 2700 : transformed->_has_quantum_number_m = false;
735 : }
736 : }
737 647 : }
738 :
739 : {
740 : using utype = std::underlying_type<Parity>::type;
741 647 : Eigen::VectorX<real_t> map(state_index_to_parity.size());
742 102291 : for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
743 101644 : map[i] = static_cast<utype>(state_index_to_parity[i]);
744 : }
745 647 : Eigen::VectorX<real_t> val = probs * map;
746 647 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
747 647 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
748 647 : transformed->state_index_to_parity.resize(probs.rows());
749 :
750 64351 : for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
751 63704 : if (diff[i] < numerical_precision) {
752 57596 : transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
753 : } else {
754 6108 : transformed->state_index_to_parity[i] = Parity::UNKNOWN;
755 6108 : transformed->_has_parity = false;
756 : }
757 : }
758 647 : }
759 :
760 647 : set_task_status("Obtaining transformed state mapping...");
761 : {
762 : // In the following, we obtain a bijective map between state index and ket index.
763 :
764 : // Find the maximum value in each row and column
765 647 : std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
766 647 : std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
767 102699 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
768 102052 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
769 102052 : transformed->coefficients.matrix, row);
770 797908 : it; ++it) {
771 695856 : real_t val = std::pow(std::abs(it.value()), 2);
772 695856 : max_in_row[row] = std::max(max_in_row[row], val);
773 695856 : max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
774 : }
775 : }
776 :
777 : // Use the maximum values to define a cost for a sub-optimal mapping
778 647 : std::vector<real_t> costs;
779 647 : std::vector<std::pair<int, int>> mappings;
780 647 : costs.reserve(transformed->coefficients.matrix.nonZeros());
781 647 : mappings.reserve(transformed->coefficients.matrix.nonZeros());
782 102699 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
783 102052 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
784 102052 : transformed->coefficients.matrix, row);
785 797908 : it; ++it) {
786 695856 : real_t val = std::pow(std::abs(it.value()), 2);
787 695856 : real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
788 695856 : costs.push_back(cost);
789 695856 : mappings.push_back({row, it.col()});
790 : }
791 : }
792 :
793 : // Obtain from the costs in which order the mappings should be considered
794 647 : std::vector<size_t> order(costs.size());
795 647 : std::iota(order.begin(), order.end(), 0);
796 647 : std::sort(order.begin(), order.end(),
797 8745103 : [&](size_t a, size_t b) { return costs[a] < costs[b]; });
798 :
799 : // Fill ket_index_to_state_index with invalid values as there can be more kets than states
800 647 : std::fill(transformed->ket_index_to_state_index.begin(),
801 647 : transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
802 :
803 : // Generate the bijective map
804 647 : std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
805 647 : std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
806 647 : int num_used = 0;
807 185390 : for (size_t idx : order) {
808 185385 : int row = mappings[idx].first; // corresponds to the ket index
809 185385 : int col = mappings[idx].second; // corresponds to the state index
810 185385 : if (!row_used[row] && !col_used[col]) {
811 63701 : row_used[row] = true;
812 63701 : col_used[col] = true;
813 63701 : num_used++;
814 63701 : transformed->state_index_to_ket_index[col] = row;
815 63701 : transformed->ket_index_to_state_index[row] = col;
816 : }
817 185385 : if (num_used == transformed->coefficients.matrix.cols()) {
818 642 : break;
819 : }
820 : }
821 647 : if (num_used != transformed->coefficients.matrix.cols()) {
822 6 : SPDLOG_WARN("A bijective map between states and kets could not be found.");
823 : }
824 647 : }
825 :
826 647 : return transformed;
827 647 : }
828 :
// Call operator of the hash functor for shared pointers to kets.
// Dereferences `k` and returns the result of applying a default-constructed `ket_t::hash`
// to the pointed-to ket, i.e. the ket object itself is hashed rather than the pointer.
// `k` is dereferenced without a null check, so it must not be null.
829 : template <typename Derived>
830 2873395 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
831 2873395 : return typename ket_t::hash()(*k);
832 : }
833 :
// Call operator of the equality functor for shared pointers to kets.
// Returns the result of `operator==` on the two pointed-to kets, i.e. the kets are compared
// by value rather than by pointer identity.
// Both `lhs` and `rhs` are dereferenced without a null check, so neither may be null.
// NOTE(review): presumably paired with `Basis::hash` as the hasher / key-equality of a
// container keyed by ket pointers -- confirm against the declaration in Basis.hpp.
834 : template <typename Derived>
835 504 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
836 : const std::shared_ptr<const ket_t> &rhs) const {
837 504 : return *lhs == *rhs;
838 : }
839 :
// The member functions of `Basis` are defined in this file, so the class template is
// explicitly instantiated for every supported derived basis: atom and pair bases, each with
// real (`double`) and complex (`std::complex<double>`) scalars.
840 : // Explicit instantiations
841 : template class Basis<BasisAtom<double>>;
842 : template class Basis<BasisAtom<std::complex<double>>>;
843 : template class Basis<BasisPair<double>>;
844 : template class Basis<BasisPair<std::complex<double>>>;
845 : } // namespace pairinteraction
|