Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/basis/Basis.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/enums/Parity.hpp"
9 : #include "pairinteraction/enums/TransformationType.hpp"
10 : #include "pairinteraction/ket/KetAtom.hpp"
11 : #include "pairinteraction/ket/KetPair.hpp"
12 : #include "pairinteraction/utils/TaskControl.hpp"
13 : #include "pairinteraction/utils/eigen_assertion.hpp"
14 : #include "pairinteraction/utils/eigen_compat.hpp"
15 : #include "pairinteraction/utils/wigner.hpp"
16 :
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <numeric>
#include <set>
#include <spdlog/spdlog.h>
#include <stdexcept>
21 :
22 : namespace pairinteraction {
23 :
24 : template <typename Scalar>
25 : class BasisAtom;
26 :
27 : template <typename Derived>
28 2238 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
29 : // Check if the labels are valid sorting labels
30 4501 : for (const auto &label : labels) {
31 2263 : if (!utils::is_sorting(label)) {
32 0 : throw std::invalid_argument("One of the labels is not a valid sorting label.");
33 : }
34 : }
35 2238 : }
36 :
37 : template <typename Derived>
38 622 : void Basis<Derived>::perform_blocks_checks(
39 : const std::set<TransformationType> &unique_labels) const {
40 : // Check if the states are sorted by the requested labels
41 622 : std::set<TransformationType> unique_labels_present;
42 1230 : for (const auto &label : get_transformation().transformation_type) {
43 661 : if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
44 53 : break;
45 : }
46 608 : unique_labels_present.insert(label);
47 : }
48 622 : if (unique_labels != unique_labels_present) {
49 0 : throw std::invalid_argument("The states are not sorted by the requested labels.");
50 : }
51 :
52 : // Throw a meaningful error if getting the blocks by energy is requested as this might be a
53 : // common mistake
54 622 : if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
55 0 : throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
56 : "be obtained. Use an energy operator instead.");
57 : }
58 622 : }
59 :
60 : template <typename Derived>
61 458 : Basis<Derived>::Basis(ketvec_t &&kets)
62 916 : : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
63 458 : static_cast<Eigen::Index>(this->kets.size())},
64 916 : {TransformationType::SORT_BY_KET}} {
65 458 : if (this->kets.empty()) {
66 0 : throw std::invalid_argument("The basis must contain at least one element.");
67 : }
68 458 : state_index_to_quantum_number_f.reserve(this->kets.size());
69 458 : state_index_to_quantum_number_m.reserve(this->kets.size());
70 458 : state_index_to_parity.reserve(this->kets.size());
71 458 : ket_to_ket_index.reserve(this->kets.size());
72 458 : size_t index = 0;
73 251874 : for (const auto &ket : this->kets) {
74 251416 : state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
75 251416 : state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
76 251416 : state_index_to_parity.push_back(ket->get_parity());
77 251416 : ket_to_ket_index[ket] = index++;
78 251416 : if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
79 224297 : _has_quantum_number_f = false;
80 : }
81 251416 : if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
82 0 : _has_quantum_number_m = false;
83 : }
84 251416 : if (ket->get_parity() == Parity::UNKNOWN) {
85 224297 : _has_parity = false;
86 : }
87 : }
88 458 : state_index_to_ket_index.resize(this->kets.size());
89 458 : std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
90 458 : ket_index_to_state_index.resize(this->kets.size());
91 458 : std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
92 458 : coefficients.matrix.setIdentity();
93 458 : }
94 :
95 : template <typename Derived>
96 843 : bool Basis<Derived>::has_quantum_number_f() const {
97 843 : return _has_quantum_number_f;
98 : }
99 :
100 : template <typename Derived>
101 563099 : bool Basis<Derived>::has_quantum_number_m() const {
102 563099 : return _has_quantum_number_m;
103 : }
104 :
105 : template <typename Derived>
106 843 : bool Basis<Derived>::has_parity() const {
107 843 : return _has_parity;
108 : }
109 :
110 : template <typename Derived>
111 2892 : const Derived &Basis<Derived>::derived() const {
112 2892 : return static_cast<const Derived &>(*this);
113 : }
114 :
115 : template <typename Derived>
116 120 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
117 120 : return kets;
118 : }
119 :
120 : template <typename Derived>
121 : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
122 9783 : Basis<Derived>::get_coefficients() const {
123 9783 : return coefficients.matrix;
124 : }
125 :
126 : template <typename Derived>
127 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
128 0 : Basis<Derived>::get_coefficients() {
129 0 : return coefficients.matrix;
130 : }
131 :
132 : template <typename Derived>
133 5 : void Basis<Derived>::set_coefficients(
134 : const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
135 5 : if (values.rows() != coefficients.matrix.rows()) {
136 0 : throw std::invalid_argument("Incompatible number of rows.");
137 : }
138 5 : if (values.cols() != coefficients.matrix.cols()) {
139 0 : throw std::invalid_argument("Incompatible number of columns.");
140 : }
141 :
142 5 : coefficients.matrix = values;
143 :
144 5 : std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
145 5 : std::numeric_limits<int>::max());
146 5 : std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
147 5 : std::numeric_limits<int>::max());
148 5 : std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
149 5 : std::numeric_limits<real_t>::max());
150 5 : std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
151 5 : std::numeric_limits<real_t>::max());
152 5 : std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
153 5 : _has_quantum_number_f = false;
154 5 : _has_quantum_number_m = false;
155 5 : _has_parity = false;
156 5 : }
157 :
158 : template <typename Derived>
159 723 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
160 723 : if (!ket_to_ket_index.contains(ket)) {
161 0 : return -1;
162 : }
163 723 : return ket_to_ket_index.at(ket);
164 : }
165 :
166 : template <typename Derived>
167 : Eigen::VectorX<typename Basis<Derived>::scalar_t>
168 220 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
169 220 : int ket_index = get_ket_index_from_ket(ket);
170 220 : if (ket_index < 0) {
171 0 : throw std::invalid_argument("The ket does not belong to the basis.");
172 : }
173 : // The following line is a more efficient alternative to
174 : // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
175 440 : return coefficients.matrix.row(ket_index);
176 : }
177 :
178 : template <typename Derived>
179 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
180 16 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
181 24 : return other->coefficients.matrix.adjoint() * coefficients.matrix;
182 : }
183 :
184 : template <typename Derived>
185 : Eigen::VectorX<typename Basis<Derived>::real_t>
186 214 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
187 214 : return get_amplitudes(ket).cwiseAbs2();
188 : }
189 :
190 : template <typename Derived>
191 : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
192 8 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
193 8 : return get_amplitudes(other).cwiseAbs2();
194 : }
195 :
196 : template <typename Derived>
197 0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
198 0 : real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
199 0 : if (quantum_number_f == std::numeric_limits<real_t>::max()) {
200 0 : throw std::invalid_argument("The state does not have a well-defined quantum number f.");
201 : }
202 0 : return quantum_number_f;
203 : }
204 :
205 : template <typename Derived>
206 562298 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
207 562298 : real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
208 562298 : if (quantum_number_m == std::numeric_limits<real_t>::max()) {
209 0 : throw std::invalid_argument("The state does not have a well-defined quantum number m.");
210 : }
211 562298 : return quantum_number_m;
212 : }
213 :
214 : template <typename Derived>
215 43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
216 43 : Parity parity = state_index_to_parity.at(state_index);
217 43 : if (parity == Parity::UNKNOWN) {
218 0 : throw std::invalid_argument("The state does not have a well-defined parity.");
219 : }
220 43 : return parity;
221 : }
222 :
223 : template <typename Derived>
224 : std::shared_ptr<const typename Basis<Derived>::ket_t>
225 184 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
226 184 : size_t ket_index = state_index_to_ket_index.at(state_index);
227 184 : if (ket_index == std::numeric_limits<int>::max()) {
228 0 : throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
229 : }
230 184 : return kets[ket_index];
231 : }
232 :
233 : template <typename Derived>
234 : std::shared_ptr<const typename Basis<Derived>::ket_t>
235 0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
236 0 : throw std::runtime_error("Not implemented yet.");
237 : }
238 :
239 : template <typename Derived>
240 399 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
241 : // Create a copy of the current object
242 399 : auto restricted = std::make_shared<Derived>(derived());
243 :
244 : // Restrict the copy to the state with the largest overlap
245 399 : restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
246 :
247 399 : std::fill(restricted->ket_index_to_state_index.begin(),
248 399 : restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
249 :
250 399 : size_t ket_index = state_index_to_ket_index[state_index];
251 399 : restricted->state_index_to_ket_index = {ket_index};
252 399 : if (ket_index != std::numeric_limits<int>::max()) {
253 399 : restricted->ket_index_to_state_index[ket_index] = 0;
254 : }
255 :
256 399 : restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
257 399 : restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
258 399 : restricted->state_index_to_parity = {state_index_to_parity[state_index]};
259 :
260 798 : restricted->_has_quantum_number_f =
261 399 : restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
262 798 : restricted->_has_quantum_number_m =
263 399 : restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
264 399 : restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
265 :
266 798 : return restricted;
267 399 : }
268 :
269 : template <typename Derived>
270 : std::shared_ptr<const typename Basis<Derived>::ket_t>
271 140512 : Basis<Derived>::get_ket(size_t ket_index) const {
272 140512 : return kets[ket_index];
273 : }
274 :
275 : template <typename Derived>
276 68 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
277 68 : size_t state_index = ket_index_to_state_index.at(ket_index);
278 68 : if (state_index == std::numeric_limits<int>::max()) {
279 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
280 : }
281 68 : return get_state(state_index);
282 : }
283 :
284 : template <typename Derived>
285 : std::shared_ptr<const Derived>
286 68 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
287 68 : int ket_index = get_ket_index_from_ket(ket);
288 68 : if (ket_index < 0) {
289 0 : throw std::invalid_argument("The ket does not belong to the basis.");
290 : }
291 68 : return get_corresponding_state(ket_index);
292 : }
293 :
294 : template <typename Derived>
295 182 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
296 182 : int state_index = ket_index_to_state_index.at(ket_index);
297 182 : if (state_index == std::numeric_limits<int>::max()) {
298 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
299 : }
300 182 : return state_index;
301 : }
302 :
303 : template <typename Derived>
304 182 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
305 182 : int ket_index = get_ket_index_from_ket(ket);
306 182 : if (ket_index < 0) {
307 0 : throw std::invalid_argument("The ket does not belong to the basis.");
308 : }
309 182 : return get_corresponding_state_index(ket_index);
310 : }
311 :
312 : template <typename Derived>
313 184 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
314 184 : size_t ket_index = state_index_to_ket_index.at(state_index);
315 184 : if (ket_index == std::numeric_limits<int>::max()) {
316 0 : throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
317 : }
318 184 : return ket_index;
319 : }
320 :
321 : template <typename Derived>
322 0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
323 0 : throw std::runtime_error("Not implemented yet.");
324 : }
325 :
326 : template <typename Derived>
327 : std::shared_ptr<const Derived>
328 253 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
329 : // Create a copy of the current object
330 253 : auto created = std::make_shared<Derived>(derived());
331 :
332 : // Fill the copy with the state corresponding to the ket index
333 253 : created->coefficients.matrix =
334 506 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
335 253 : created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
336 253 : created->coefficients.matrix.makeCompressed();
337 :
338 253 : std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
339 253 : std::numeric_limits<int>::max());
340 253 : created->ket_index_to_state_index[ket_index] = 0;
341 :
342 253 : created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
343 253 : created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
344 253 : created->state_index_to_parity = {kets[ket_index]->get_parity()};
345 253 : created->state_index_to_ket_index = {ket_index};
346 :
347 506 : created->_has_quantum_number_f =
348 253 : created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
349 506 : created->_has_quantum_number_m =
350 253 : created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
351 253 : created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
352 :
353 506 : return created;
354 253 : }
355 :
356 : template <typename Derived>
357 : std::shared_ptr<const Derived>
358 253 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
359 253 : int ket_index = get_ket_index_from_ket(ket);
360 253 : if (ket_index < 0) {
361 0 : throw std::invalid_argument("The ket does not belong to the basis.");
362 : }
363 253 : return get_canonical_state_from_ket(ket_index);
364 : }
365 :
366 : template <typename Derived>
367 4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
368 4 : return kets.begin();
369 : }
370 :
371 : template <typename Derived>
372 4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
373 4 : return kets.end();
374 : }
375 :
376 : template <typename Derived>
377 8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
378 :
379 : template <typename Derived>
380 80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
381 80 : return other.it != it;
382 : }
383 :
384 : template <typename Derived>
385 76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
386 76 : return *it;
387 : }
388 :
389 : template <typename Derived>
390 76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
391 76 : ++it;
392 76 : return *this;
393 : }
394 :
395 : template <typename Derived>
396 2770 : size_t Basis<Derived>::get_number_of_states() const {
397 2770 : return coefficients.matrix.cols();
398 : }
399 :
400 : template <typename Derived>
401 6034 : size_t Basis<Derived>::get_number_of_kets() const {
402 6034 : return coefficients.matrix.rows();
403 : }
404 :
405 : template <typename Derived>
406 : const Transformation<typename Basis<Derived>::scalar_t> &
407 623 : Basis<Derived>::get_transformation() const {
408 623 : return coefficients;
409 : }
410 :
411 : template <typename Derived>
412 : Transformation<typename Basis<Derived>::scalar_t>
413 0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
414 0 : Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
415 : static_cast<Eigen::Index>(coefficients.matrix.rows())},
416 : {TransformationType::ROTATE}};
417 :
418 0 : std::vector<Eigen::Triplet<scalar_t>> entries;
419 :
420 0 : set_task_status("Constructing basis rotation...");
421 0 : for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
422 0 : real_t f = kets[idx_initial]->get_quantum_number_f();
423 0 : real_t m_initial = kets[idx_initial]->get_quantum_number_m();
424 :
425 0 : assert(2 * f == std::floor(2 * f) && f >= 0);
426 0 : assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
427 :
428 0 : for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
429 : ++m_final) {
430 0 : auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
431 : beta, gamma);
432 0 : size_t idx_final = get_ket_index_from_ket(
433 0 : kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
434 0 : entries.emplace_back(idx_final, idx_initial, val);
435 : }
436 : }
437 :
438 0 : transformation.matrix.setFromTriplets(entries.begin(), entries.end());
439 0 : transformation.matrix.makeCompressed();
440 :
441 0 : return transformation;
442 0 : }
443 :
444 : template <typename Derived>
445 1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
446 1 : perform_sorter_checks(labels);
447 :
448 : // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
449 1 : if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
450 2 : labels.end()) {
451 0 : throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
452 : "the energy. Use an energy operator instead.");
453 : }
454 :
455 : // Initialize transformation
456 1 : Sorting transformation;
457 1 : transformation.matrix.resize(coefficients.matrix.cols());
458 1 : transformation.matrix.setIdentity();
459 :
460 : // Get the sorter
461 1 : get_sorter_without_checks(labels, transformation);
462 :
463 : // Check if all labels have been used for sorting
464 1 : if (labels != transformation.transformation_type) {
465 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels.");
466 : }
467 :
468 1 : return transformation;
469 0 : }
470 :
471 : template <typename Derived>
472 : std::vector<IndicesOfBlock>
473 1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
474 1 : perform_sorter_checks(labels);
475 :
476 1 : std::set<TransformationType> unique_labels(labels.begin(), labels.end());
477 1 : perform_blocks_checks(unique_labels);
478 :
479 : // Get the blocks
480 1 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
481 1 : get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
482 :
483 2 : return blocks_creator.create();
484 1 : }
485 :
486 : template <typename Derived>
487 569 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
488 : Sorting &transformation) const {
489 569 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
490 :
491 569 : int *perm_begin = transformation.matrix.indices().data();
492 569 : int *perm_end = perm_begin + coefficients.matrix.cols();
493 569 : const int *perm_back = perm_end - 1;
494 :
495 : // Sort the vector based on the requested labels
496 569 : set_task_status("Sorting basis states...");
497 635598 : std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
498 279497 : for (const auto &label : labels) {
499 183374 : switch (label) {
500 30703 : case TransformationType::SORT_BY_PARITY:
501 30703 : if (state_index_to_parity[a] != state_index_to_parity[b]) {
502 57386 : return state_index_to_parity[a] < state_index_to_parity[b];
503 : }
504 20305 : break;
505 152671 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
506 305342 : if (std::abs(state_index_to_quantum_number_m[a] -
507 305342 : state_index_to_quantum_number_m[b]) > numerical_precision) {
508 46988 : return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
509 : }
510 105683 : break;
511 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
512 0 : if (std::abs(state_index_to_quantum_number_f[a] -
513 0 : state_index_to_quantum_number_f[b]) > numerical_precision) {
514 0 : return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
515 : }
516 0 : break;
517 0 : case TransformationType::SORT_BY_KET:
518 0 : if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
519 0 : return state_index_to_ket_index[a] < state_index_to_ket_index[b];
520 : }
521 0 : break;
522 0 : default:
523 0 : std::abort(); // Can't happen because of previous checks
524 : }
525 : }
526 96123 : return false; // Elements are equal
527 : });
528 :
529 : // Check for invalid values and add transformation types
530 1177 : for (const auto &label : labels) {
531 608 : switch (label) {
532 41 : case TransformationType::SORT_BY_PARITY:
533 41 : if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
534 0 : throw std::invalid_argument(
535 : "States cannot be labeled and thus not sorted by the parity.");
536 : }
537 41 : transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
538 41 : break;
539 567 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
540 567 : if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
541 0 : throw std::invalid_argument(
542 : "States cannot be labeled and thus not sorted by the quantum number m.");
543 : }
544 567 : transformation.transformation_type.push_back(
545 567 : TransformationType::SORT_BY_QUANTUM_NUMBER_M);
546 567 : break;
547 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
548 0 : if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
549 0 : throw std::invalid_argument(
550 : "States cannot be labeled and thus not sorted by the quantum number f.");
551 : }
552 0 : transformation.transformation_type.push_back(
553 0 : TransformationType::SORT_BY_QUANTUM_NUMBER_F);
554 0 : break;
555 0 : case TransformationType::SORT_BY_KET:
556 0 : if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
557 0 : throw std::invalid_argument(
558 : "States cannot be labeled and thus not sorted by kets.");
559 : }
560 0 : transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
561 0 : break;
562 0 : default:
563 0 : std::abort(); // Can't happen because of previous checks
564 : }
565 : }
566 569 : }
567 :
568 : template <typename Derived>
569 569 : void Basis<Derived>::get_indices_of_blocks_without_checks(
570 : const std::set<TransformationType> &unique_labels,
571 : IndicesOfBlocksCreator &blocks_creator) const {
572 569 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
573 :
574 569 : auto last_quantum_number_f = state_index_to_quantum_number_f[0];
575 569 : auto last_quantum_number_m = state_index_to_quantum_number_m[0];
576 569 : auto last_parity = state_index_to_parity[0];
577 569 : auto last_ket = state_index_to_ket_index[0];
578 :
579 569 : set_task_status("Identifying basis blocks...");
580 32971 : for (int i = 0; i < coefficients.matrix.cols(); ++i) {
581 71114 : for (auto label : unique_labels) {
582 39411 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
583 0 : std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
584 : numerical_precision) {
585 0 : blocks_creator.add(i);
586 0 : break;
587 : }
588 71633 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
589 32222 : std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
590 : numerical_precision) {
591 555 : blocks_creator.add(i);
592 555 : break;
593 : }
594 46045 : if (label == TransformationType::SORT_BY_PARITY &&
595 7189 : state_index_to_parity[i] != last_parity) {
596 144 : blocks_creator.add(i);
597 144 : break;
598 : }
599 38712 : if (label == TransformationType::SORT_BY_KET &&
600 38712 : state_index_to_ket_index[i] != last_ket &&
601 0 : last_ket != std::numeric_limits<int>::max()) {
602 0 : blocks_creator.add(i);
603 0 : break;
604 : }
605 : }
606 32402 : last_quantum_number_f = state_index_to_quantum_number_f[i];
607 32402 : last_quantum_number_m = state_index_to_quantum_number_m[i];
608 32402 : last_parity = state_index_to_parity[i];
609 32402 : last_ket = state_index_to_ket_index[i];
610 : }
611 569 : }
612 :
613 : template <typename Derived>
614 1616 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
615 : // Create a copy of the current object
616 1616 : auto transformed = std::make_shared<Derived>(derived());
617 :
618 1616 : if (coefficients.matrix.cols() == 0) {
619 0 : return transformed;
620 : }
621 :
622 : // Apply the transformation
623 1616 : set_task_status("Applying basis sorting...");
624 1616 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
625 1616 : transformed->coefficients.transformation_type = transformation.transformation_type;
626 :
627 1616 : transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
628 1616 : transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
629 1616 : transformed->state_index_to_parity.resize(transformation.matrix.size());
630 1616 : transformed->state_index_to_ket_index.resize(transformation.matrix.size());
631 :
632 1616 : set_task_status("Relabeling sorted basis states...");
633 153784 : for (int i = 0; i < transformation.matrix.size(); ++i) {
634 152168 : transformed->state_index_to_quantum_number_f[i] =
635 152168 : state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
636 152168 : transformed->state_index_to_quantum_number_m[i] =
637 152168 : state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
638 152168 : transformed->state_index_to_parity[i] =
639 152168 : state_index_to_parity[transformation.matrix.indices()[i]];
640 :
641 152168 : size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
642 152168 : transformed->state_index_to_ket_index[i] = ket_index;
643 152168 : if (ket_index != std::numeric_limits<int>::max()) {
644 152168 : transformed->ket_index_to_state_index[ket_index] = i;
645 : }
646 : }
647 :
648 1616 : return transformed;
649 1616 : }
650 :
651 : template <typename Derived>
652 : std::shared_ptr<const Derived>
653 624 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
654 : // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
655 : // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
656 624 : real_t numerical_precision = 0.001;
657 :
658 : // If the transformation is a rotation, it should be a rotation and nothing else
659 624 : bool is_rotation = false;
660 1248 : for (auto t : transformation.transformation_type) {
661 624 : if (t == TransformationType::ROTATE) {
662 0 : is_rotation = true;
663 0 : break;
664 : }
665 : }
666 624 : if (is_rotation && transformation.transformation_type.size() != 1) {
667 0 : throw std::invalid_argument("A rotation can not be combined with other transformations.");
668 : }
669 :
670 : // To apply a rotation, the object must only be sorted but other transformations are not allowed
671 624 : if (is_rotation) {
672 0 : for (auto t : coefficients.transformation_type) {
673 0 : if (!utils::is_sorting(t)) {
674 0 : throw std::runtime_error(
675 : "If the object was transformed by a different transformation "
676 : "than sorting, it can not be rotated.");
677 : }
678 : }
679 : }
680 :
681 : // Create a copy of the current object
682 624 : auto transformed = std::make_shared<Derived>(derived());
683 :
684 624 : if (coefficients.matrix.cols() == 0) {
685 0 : return transformed;
686 : }
687 :
688 : // Apply the transformation
689 : // If a quantum number turns out to be conserved by the transformation, it will be
690 : // rounded to the nearest half integer to avoid loss of numerical_precision.
691 624 : set_task_status("Applying basis transformation...");
692 624 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
693 624 : transformed->coefficients.transformation_type = transformation.transformation_type;
694 :
695 624 : Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
696 :
697 624 : set_task_status("Updating transformed quantum numbers...");
698 : {
699 1248 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
700 624 : state_index_to_quantum_number_f.size());
701 624 : Eigen::VectorX<real_t> val = probs * map;
702 624 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
703 624 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
704 624 : transformed->state_index_to_quantum_number_f.resize(probs.rows());
705 :
706 31002 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
707 30378 : if (diff[i] < numerical_precision) {
708 5113 : transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
709 : } else {
710 25265 : transformed->state_index_to_quantum_number_f[i] =
711 25265 : std::numeric_limits<real_t>::max();
712 25265 : transformed->_has_quantum_number_f = false;
713 : }
714 : }
715 624 : }
716 :
717 : {
718 1248 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
719 624 : state_index_to_quantum_number_m.size());
720 624 : Eigen::VectorX<real_t> val = probs * map;
721 624 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
722 624 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
723 624 : transformed->state_index_to_quantum_number_m.resize(probs.rows());
724 :
725 31002 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
726 30378 : if (diff[i] < numerical_precision) {
727 27526 : transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
728 : } else {
729 2852 : transformed->state_index_to_quantum_number_m[i] =
730 2852 : std::numeric_limits<real_t>::max();
731 2852 : transformed->_has_quantum_number_m = false;
732 : }
733 : }
734 624 : }
735 :
736 : {
737 : using utype = std::underlying_type<Parity>::type;
738 624 : Eigen::VectorX<real_t> map(state_index_to_parity.size());
739 36100 : for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
740 35476 : map[i] = static_cast<utype>(state_index_to_parity[i]);
741 : }
742 624 : Eigen::VectorX<real_t> val = probs * map;
743 624 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
744 624 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
745 624 : transformed->state_index_to_parity.resize(probs.rows());
746 :
747 31002 : for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
748 30378 : if (diff[i] < numerical_precision) {
749 24262 : transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
750 : } else {
751 6116 : transformed->state_index_to_parity[i] = Parity::UNKNOWN;
752 6116 : transformed->_has_parity = false;
753 : }
754 : }
755 624 : }
756 :
757 624 : set_task_status("Obtaining transformed state mapping...");
758 : {
759 : // In the following, we obtain a bijective map between state index and ket index.
760 :
761 : // Find the maximum value in each row and column
762 624 : std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
763 624 : std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
764 36100 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
765 35476 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
766 35476 : transformed->coefficients.matrix, row);
767 657371 : it; ++it) {
768 621895 : real_t val = std::pow(std::abs(it.value()), 2);
769 621895 : max_in_row[row] = std::max(max_in_row[row], val);
770 621895 : max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
771 : }
772 : }
773 :
774 : // Use the maximum values to define a cost for a sub-optimal mapping
775 624 : std::vector<real_t> costs;
776 624 : std::vector<std::pair<int, int>> mappings;
777 624 : costs.reserve(transformed->coefficients.matrix.nonZeros());
778 624 : mappings.reserve(transformed->coefficients.matrix.nonZeros());
779 36100 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
780 35476 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
781 35476 : transformed->coefficients.matrix, row);
782 657371 : it; ++it) {
783 621895 : real_t val = std::pow(std::abs(it.value()), 2);
784 621895 : real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
785 621895 : costs.push_back(cost);
786 621895 : mappings.push_back({row, it.col()});
787 : }
788 : }
789 :
790 : // Obtain from the costs in which order the mappings should be considered
791 624 : std::vector<size_t> order(costs.size());
792 624 : std::iota(order.begin(), order.end(), 0);
793 624 : std::sort(order.begin(), order.end(),
794 7961500 : [&](size_t a, size_t b) { return costs[a] < costs[b]; });
795 :
796 : // Fill ket_index_to_state_index with invalid values as there can be more kets than states
797 624 : std::fill(transformed->ket_index_to_state_index.begin(),
798 624 : transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
799 :
800 : // Generate the bijective map
801 624 : std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
802 624 : std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
803 624 : int num_used = 0;
804 116841 : for (size_t idx : order) {
805 116837 : int row = mappings[idx].first; // corresponds to the ket index
806 116837 : int col = mappings[idx].second; // corresponds to the state index
807 116837 : if (!row_used[row] && !col_used[col]) {
808 30376 : row_used[row] = true;
809 30376 : col_used[col] = true;
810 30376 : num_used++;
811 30376 : transformed->state_index_to_ket_index[col] = row;
812 30376 : transformed->ket_index_to_state_index[row] = col;
813 : }
814 116837 : if (num_used == transformed->coefficients.matrix.cols()) {
815 620 : break;
816 : }
817 : }
818 624 : if (num_used != transformed->coefficients.matrix.cols()) {
819 4 : SPDLOG_WARN("A bijective map between states and kets could not be found.");
820 : }
821 624 : }
822 :
823 624 : return transformed;
824 624 : }
825 :
826 : template <typename Derived>
827 252862 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
828 252862 : return typename ket_t::hash()(*k);
829 : }
830 :
831 : template <typename Derived>
832 1446 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
833 : const std::shared_ptr<const ket_t> &rhs) const {
834 1446 : return *lhs == *rhs;
835 : }
836 :
837 : // Explicit instantiations
838 : template class Basis<BasisAtom<double>>;
839 : template class Basis<BasisAtom<std::complex<double>>>;
840 : template class Basis<BasisPair<double>>;
841 : template class Basis<BasisPair<std::complex<double>>>;
842 : } // namespace pairinteraction
|