Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/basis/Basis.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/enums/Parity.hpp"
9 : #include "pairinteraction/enums/TransformationType.hpp"
10 : #include "pairinteraction/ket/KetAtom.hpp"
11 : #include "pairinteraction/ket/KetPair.hpp"
12 : #include "pairinteraction/utils/eigen_assertion.hpp"
13 : #include "pairinteraction/utils/eigen_compat.hpp"
14 : #include "pairinteraction/utils/wigner.hpp"
15 :
16 : #include <cassert>
17 : #include <numeric>
18 : #include <set>
19 : #include <spdlog/spdlog.h>
20 :
21 : namespace pairinteraction {
22 :
23 : template <typename Scalar>
24 : class BasisAtom;
25 :
26 : template <typename Derived>
27 2050 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
28 : // Check if the labels are valid sorting labels
29 4115 : for (const auto &label : labels) {
30 2065 : if (!utils::is_sorting(label)) {
31 0 : throw std::invalid_argument("One of the labels is not a valid sorting label.");
32 : }
33 : }
34 2050 : }
35 :
36 : template <typename Derived>
37 590 : void Basis<Derived>::perform_blocks_checks(
38 : const std::set<TransformationType> &unique_labels) const {
39 : // Check if the states are sorted by the requested labels
40 590 : std::set<TransformationType> unique_labels_present;
41 1156 : for (const auto &label : get_transformation().transformation_type) {
42 629 : if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
43 63 : break;
44 : }
45 566 : unique_labels_present.insert(label);
46 : }
47 590 : if (unique_labels != unique_labels_present) {
48 0 : throw std::invalid_argument("The states are not sorted by the requested labels.");
49 : }
50 :
51 : // Throw a meaningful error if getting the blocks by energy is requested as this might be a
52 : // common mistake
53 590 : if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
54 0 : throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
55 : "be obtained. Use an energy operator instead.");
56 : }
57 590 : }
58 :
59 : template <typename Derived>
60 397 : Basis<Derived>::Basis(ketvec_t &&kets)
61 794 : : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
62 397 : static_cast<Eigen::Index>(this->kets.size())},
63 794 : {TransformationType::SORT_BY_KET}} {
64 397 : if (this->kets.empty()) {
65 0 : throw std::invalid_argument("The basis must contain at least one element.");
66 : }
67 397 : state_index_to_quantum_number_f.reserve(this->kets.size());
68 397 : state_index_to_quantum_number_m.reserve(this->kets.size());
69 397 : state_index_to_parity.reserve(this->kets.size());
70 397 : ket_to_ket_index.reserve(this->kets.size());
71 397 : size_t index = 0;
72 200594 : for (const auto &ket : this->kets) {
73 200197 : state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
74 200197 : state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
75 200197 : state_index_to_parity.push_back(ket->get_parity());
76 200197 : ket_to_ket_index[ket] = index++;
77 200197 : if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
78 177795 : _has_quantum_number_f = false;
79 : }
80 200197 : if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
81 64000 : _has_quantum_number_m = false;
82 : }
83 200197 : if (ket->get_parity() == Parity::UNKNOWN) {
84 177795 : _has_parity = false;
85 : }
86 : }
87 397 : state_index_to_ket_index.resize(this->kets.size());
88 397 : std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
89 397 : ket_index_to_state_index.resize(this->kets.size());
90 397 : std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
91 397 : coefficients.matrix.setIdentity();
92 397 : }
93 :
94 : template <typename Derived>
95 549 : bool Basis<Derived>::has_quantum_number_f() const {
96 549 : return _has_quantum_number_f;
97 : }
98 :
99 : template <typename Derived>
100 326861 : bool Basis<Derived>::has_quantum_number_m() const {
101 326861 : return _has_quantum_number_m;
102 : }
103 :
104 : template <typename Derived>
105 549 : bool Basis<Derived>::has_parity() const {
106 549 : return _has_parity;
107 : }
108 :
109 : template <typename Derived>
110 2649 : const Derived &Basis<Derived>::derived() const {
111 2649 : return static_cast<const Derived &>(*this);
112 : }
113 :
114 : template <typename Derived>
115 863 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
116 863 : return kets;
117 : }
118 :
119 : template <typename Derived>
120 : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
121 10131 : Basis<Derived>::get_coefficients() const {
122 10131 : return coefficients.matrix;
123 : }
124 :
125 : template <typename Derived>
126 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
127 0 : Basis<Derived>::get_coefficients() {
128 0 : return coefficients.matrix;
129 : }
130 :
131 : template <typename Derived>
132 5 : void Basis<Derived>::set_coefficients(
133 : const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
134 5 : if (values.rows() != coefficients.matrix.rows()) {
135 0 : throw std::invalid_argument("Incompatible number of rows.");
136 : }
137 5 : if (values.cols() != coefficients.matrix.cols()) {
138 0 : throw std::invalid_argument("Incompatible number of columns.");
139 : }
140 :
141 5 : coefficients.matrix = values;
142 :
143 5 : std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
144 5 : std::numeric_limits<int>::max());
145 5 : std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
146 5 : std::numeric_limits<int>::max());
147 5 : std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
148 5 : std::numeric_limits<real_t>::max());
149 5 : std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
150 5 : std::numeric_limits<real_t>::max());
151 5 : std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
152 5 : _has_quantum_number_f = false;
153 5 : _has_quantum_number_m = false;
154 5 : _has_parity = false;
155 5 : }
156 :
157 : template <typename Derived>
158 576 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
159 576 : if (!ket_to_ket_index.contains(ket)) {
160 0 : return -1;
161 : }
162 576 : return ket_to_ket_index.at(ket);
163 : }
164 :
165 : template <typename Derived>
166 : Eigen::VectorX<typename Basis<Derived>::scalar_t>
167 117 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
168 117 : int ket_index = get_ket_index_from_ket(ket);
169 117 : if (ket_index < 0) {
170 0 : throw std::invalid_argument("The ket does not belong to the basis.");
171 : }
172 : // The following line is a more efficient alternative to
173 : // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
174 234 : return coefficients.matrix.row(ket_index);
175 : }
176 :
177 : template <typename Derived>
178 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
179 16 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
180 24 : return other->coefficients.matrix.adjoint() * coefficients.matrix;
181 : }
182 :
183 : template <typename Derived>
184 : Eigen::VectorX<typename Basis<Derived>::real_t>
185 111 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
186 111 : return get_amplitudes(ket).cwiseAbs2();
187 : }
188 :
189 : template <typename Derived>
190 : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
191 8 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
192 8 : return get_amplitudes(other).cwiseAbs2();
193 : }
194 :
195 : template <typename Derived>
196 0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
197 0 : real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
198 0 : if (quantum_number_f == std::numeric_limits<real_t>::max()) {
199 0 : throw std::invalid_argument("The state does not have a well-defined quantum number f.");
200 : }
201 0 : return quantum_number_f;
202 : }
203 :
204 : template <typename Derived>
205 262354 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
206 262354 : real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
207 262354 : if (quantum_number_m == std::numeric_limits<real_t>::max()) {
208 0 : throw std::invalid_argument("The state does not have a well-defined quantum number m.");
209 : }
210 262354 : return quantum_number_m;
211 : }
212 :
213 : template <typename Derived>
214 43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
215 43 : Parity parity = state_index_to_parity.at(state_index);
216 43 : if (parity == Parity::UNKNOWN) {
217 0 : throw std::invalid_argument("The state does not have a well-defined parity.");
218 : }
219 43 : return parity;
220 : }
221 :
222 : template <typename Derived>
223 : std::shared_ptr<const typename Basis<Derived>::ket_t>
224 248 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
225 248 : size_t ket_index = state_index_to_ket_index.at(state_index);
226 248 : if (ket_index == std::numeric_limits<int>::max()) {
227 0 : throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
228 : }
229 248 : return kets[ket_index];
230 : }
231 :
232 : template <typename Derived>
233 : std::shared_ptr<const typename Basis<Derived>::ket_t>
234 0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
235 0 : throw std::runtime_error("Not implemented yet.");
236 : }
237 :
238 : template <typename Derived>
239 390 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
240 : // Create a copy of the current object
241 390 : auto restricted = std::make_shared<Derived>(derived());
242 :
243 : // Restrict the copy to the state with the largest overlap
244 390 : restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
245 :
246 390 : std::fill(restricted->ket_index_to_state_index.begin(),
247 390 : restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
248 :
249 390 : size_t ket_index = state_index_to_ket_index[state_index];
250 390 : restricted->state_index_to_ket_index = {ket_index};
251 390 : if (ket_index != std::numeric_limits<int>::max()) {
252 390 : restricted->ket_index_to_state_index[ket_index] = 0;
253 : }
254 :
255 390 : restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
256 390 : restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
257 390 : restricted->state_index_to_parity = {state_index_to_parity[state_index]};
258 :
259 780 : restricted->_has_quantum_number_f =
260 390 : restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
261 780 : restricted->_has_quantum_number_m =
262 390 : restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
263 390 : restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
264 :
265 780 : return restricted;
266 390 : }
267 :
268 : template <typename Derived>
269 : std::shared_ptr<const typename Basis<Derived>::ket_t>
270 0 : Basis<Derived>::get_ket(size_t ket_index) const {
271 0 : return kets[ket_index];
272 : }
273 :
274 : template <typename Derived>
275 68 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
276 68 : size_t state_index = ket_index_to_state_index.at(ket_index);
277 68 : if (state_index == std::numeric_limits<int>::max()) {
278 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
279 : }
280 68 : return get_state(state_index);
281 : }
282 :
283 : template <typename Derived>
284 : std::shared_ptr<const Derived>
285 68 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
286 68 : int ket_index = get_ket_index_from_ket(ket);
287 68 : if (ket_index < 0) {
288 0 : throw std::invalid_argument("The ket does not belong to the basis.");
289 : }
290 68 : return get_corresponding_state(ket_index);
291 : }
292 :
293 : template <typename Derived>
294 182 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
295 182 : int state_index = ket_index_to_state_index.at(ket_index);
296 182 : if (state_index == std::numeric_limits<int>::max()) {
297 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
298 : }
299 182 : return state_index;
300 : }
301 :
302 : template <typename Derived>
303 182 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
304 182 : int ket_index = get_ket_index_from_ket(ket);
305 182 : if (ket_index < 0) {
306 0 : throw std::invalid_argument("The ket does not belong to the basis.");
307 : }
308 182 : return get_corresponding_state_index(ket_index);
309 : }
310 :
311 : template <typename Derived>
312 248 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
313 248 : size_t ket_index = state_index_to_ket_index.at(state_index);
314 248 : if (ket_index == std::numeric_limits<int>::max()) {
315 0 : throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
316 : }
317 248 : return ket_index;
318 : }
319 :
320 : template <typename Derived>
321 0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
322 0 : throw std::runtime_error("Not implemented yet.");
323 : }
324 :
325 : template <typename Derived>
326 : std::shared_ptr<const Derived>
327 209 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
328 : // Create a copy of the current object
329 209 : auto created = std::make_shared<Derived>(derived());
330 :
331 : // Fill the copy with the state corresponding to the ket index
332 209 : created->coefficients.matrix =
333 418 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
334 209 : created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
335 209 : created->coefficients.matrix.makeCompressed();
336 :
337 209 : std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
338 209 : std::numeric_limits<int>::max());
339 209 : created->ket_index_to_state_index[ket_index] = 0;
340 :
341 209 : created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
342 209 : created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
343 209 : created->state_index_to_parity = {kets[ket_index]->get_parity()};
344 209 : created->state_index_to_ket_index = {ket_index};
345 :
346 418 : created->_has_quantum_number_f =
347 209 : created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
348 418 : created->_has_quantum_number_m =
349 209 : created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
350 209 : created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
351 :
352 418 : return created;
353 209 : }
354 :
355 : template <typename Derived>
356 : std::shared_ptr<const Derived>
357 209 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
358 209 : int ket_index = get_ket_index_from_ket(ket);
359 209 : if (ket_index < 0) {
360 0 : throw std::invalid_argument("The ket does not belong to the basis.");
361 : }
362 209 : return get_canonical_state_from_ket(ket_index);
363 : }
364 :
365 : template <typename Derived>
366 4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
367 4 : return kets.begin();
368 : }
369 :
370 : template <typename Derived>
371 4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
372 4 : return kets.end();
373 : }
374 :
375 : template <typename Derived>
376 8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
377 :
378 : template <typename Derived>
379 80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
380 80 : return other.it != it;
381 : }
382 :
383 : template <typename Derived>
384 76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
385 76 : return *it;
386 : }
387 :
388 : template <typename Derived>
389 76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
390 76 : ++it;
391 76 : return *this;
392 : }
393 :
394 : template <typename Derived>
395 19113 : size_t Basis<Derived>::get_number_of_states() const {
396 19113 : return coefficients.matrix.cols();
397 : }
398 :
399 : template <typename Derived>
400 7463 : size_t Basis<Derived>::get_number_of_kets() const {
401 7463 : return coefficients.matrix.rows();
402 : }
403 :
404 : template <typename Derived>
405 : const Transformation<typename Basis<Derived>::scalar_t> &
406 591 : Basis<Derived>::get_transformation() const {
407 591 : return coefficients;
408 : }
409 :
410 : template <typename Derived>
411 : Transformation<typename Basis<Derived>::scalar_t>
412 0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
413 0 : Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
414 : static_cast<Eigen::Index>(coefficients.matrix.rows())},
415 : {TransformationType::ROTATE}};
416 :
417 0 : std::vector<Eigen::Triplet<scalar_t>> entries;
418 :
419 0 : for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
420 0 : real_t f = kets[idx_initial]->get_quantum_number_f();
421 0 : real_t m_initial = kets[idx_initial]->get_quantum_number_m();
422 :
423 0 : assert(2 * f == std::floor(2 * f) && f >= 0);
424 0 : assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
425 :
426 0 : for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
427 : ++m_final) {
428 0 : auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
429 : beta, gamma);
430 0 : size_t idx_final = get_ket_index_from_ket(
431 0 : kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
432 0 : entries.emplace_back(idx_final, idx_initial, val);
433 : }
434 : }
435 :
436 0 : transformation.matrix.setFromTriplets(entries.begin(), entries.end());
437 0 : transformation.matrix.makeCompressed();
438 :
439 0 : return transformation;
440 0 : }
441 :
442 : template <typename Derived>
443 1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
444 1 : perform_sorter_checks(labels);
445 :
446 : // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
447 1 : if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
448 2 : labels.end()) {
449 0 : throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
450 : "the energy. Use an energy operator instead.");
451 : }
452 :
453 : // Initialize transformation
454 1 : Sorting transformation;
455 1 : transformation.matrix.resize(coefficients.matrix.cols());
456 1 : transformation.matrix.setIdentity();
457 :
458 : // Get the sorter
459 1 : get_sorter_without_checks(labels, transformation);
460 :
461 : // Check if all labels have been used for sorting
462 1 : if (labels != transformation.transformation_type) {
463 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels.");
464 : }
465 :
466 1 : return transformation;
467 0 : }
468 :
469 : template <typename Derived>
470 : std::vector<IndicesOfBlock>
471 1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
472 1 : perform_sorter_checks(labels);
473 :
474 1 : std::set<TransformationType> unique_labels(labels.begin(), labels.end());
475 1 : perform_blocks_checks(unique_labels);
476 :
477 : // Get the blocks
478 1 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
479 1 : get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
480 :
481 2 : return blocks_creator.create();
482 1 : }
483 :
484 : template <typename Derived>
485 527 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
486 : Sorting &transformation) const {
487 527 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
488 :
489 527 : int *perm_begin = transformation.matrix.indices().data();
490 527 : int *perm_end = perm_begin + coefficients.matrix.cols();
491 527 : const int *perm_back = perm_end - 1;
492 :
493 : // Sort the vector based on the requested labels
494 589108 : std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
495 263171 : for (const auto &label : labels) {
496 172074 : switch (label) {
497 30703 : case TransformationType::SORT_BY_PARITY:
498 30703 : if (state_index_to_parity[a] != state_index_to_parity[b]) {
499 51112 : return state_index_to_parity[a] < state_index_to_parity[b];
500 : }
501 20305 : break;
502 141371 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
503 282742 : if (std::abs(state_index_to_quantum_number_m[a] -
504 282742 : state_index_to_quantum_number_m[b]) > numerical_precision) {
505 40714 : return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
506 : }
507 100657 : break;
508 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
509 0 : if (std::abs(state_index_to_quantum_number_f[a] -
510 0 : state_index_to_quantum_number_f[b]) > numerical_precision) {
511 0 : return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
512 : }
513 0 : break;
514 0 : case TransformationType::SORT_BY_KET:
515 0 : if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
516 0 : return state_index_to_ket_index[a] < state_index_to_ket_index[b];
517 : }
518 0 : break;
519 0 : default:
520 0 : std::abort(); // Can't happen because of previous checks
521 : }
522 : }
523 91097 : return false; // Elements are equal
524 : });
525 :
526 : // Check for invalid values and add transformation types
527 1093 : for (const auto &label : labels) {
528 566 : switch (label) {
529 41 : case TransformationType::SORT_BY_PARITY:
530 41 : if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
531 0 : throw std::invalid_argument(
532 : "States cannot be labeled and thus not sorted by the parity.");
533 : }
534 41 : transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
535 41 : break;
536 525 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
537 525 : if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
538 0 : throw std::invalid_argument(
539 : "States cannot be labeled and thus not sorted by the quantum number m.");
540 : }
541 525 : transformation.transformation_type.push_back(
542 525 : TransformationType::SORT_BY_QUANTUM_NUMBER_M);
543 525 : break;
544 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
545 0 : if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
546 0 : throw std::invalid_argument(
547 : "States cannot be labeled and thus not sorted by the quantum number f.");
548 : }
549 0 : transformation.transformation_type.push_back(
550 0 : TransformationType::SORT_BY_QUANTUM_NUMBER_F);
551 0 : break;
552 0 : case TransformationType::SORT_BY_KET:
553 0 : if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
554 0 : throw std::invalid_argument(
555 : "States cannot be labeled and thus not sorted by kets.");
556 : }
557 0 : transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
558 0 : break;
559 0 : default:
560 0 : std::abort(); // Can't happen because of previous checks
561 : }
562 : }
563 527 : }
564 :
565 : template <typename Derived>
566 527 : void Basis<Derived>::get_indices_of_blocks_without_checks(
567 : const std::set<TransformationType> &unique_labels,
568 : IndicesOfBlocksCreator &blocks_creator) const {
569 527 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
570 :
571 527 : auto last_quantum_number_f = state_index_to_quantum_number_f[0];
572 527 : auto last_quantum_number_m = state_index_to_quantum_number_m[0];
573 527 : auto last_parity = state_index_to_parity[0];
574 527 : auto last_ket = state_index_to_ket_index[0];
575 :
576 30825 : for (int i = 0; i < coefficients.matrix.cols(); ++i) {
577 67008 : for (auto label : unique_labels) {
578 37307 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
579 0 : std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
580 : numerical_precision) {
581 0 : blocks_creator.add(i);
582 0 : break;
583 : }
584 67425 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
585 30118 : std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
586 : numerical_precision) {
587 453 : blocks_creator.add(i);
588 453 : break;
589 : }
590 44043 : if (label == TransformationType::SORT_BY_PARITY &&
591 7189 : state_index_to_parity[i] != last_parity) {
592 144 : blocks_creator.add(i);
593 144 : break;
594 : }
595 36710 : if (label == TransformationType::SORT_BY_KET &&
596 36710 : state_index_to_ket_index[i] != last_ket &&
597 0 : last_ket != std::numeric_limits<int>::max()) {
598 0 : blocks_creator.add(i);
599 0 : break;
600 : }
601 : }
602 30298 : last_quantum_number_f = state_index_to_quantum_number_f[i];
603 30298 : last_quantum_number_m = state_index_to_quantum_number_m[i];
604 30298 : last_parity = state_index_to_parity[i];
605 30298 : last_ket = state_index_to_ket_index[i];
606 : }
607 527 : }
608 :
609 : template <typename Derived>
610 1460 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
611 : // Create a copy of the current object
612 1460 : auto transformed = std::make_shared<Derived>(derived());
613 :
614 1460 : if (coefficients.matrix.cols() == 0) {
615 0 : return transformed;
616 : }
617 :
618 : // Apply the transformation
619 1460 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
620 1460 : transformed->coefficients.transformation_type = transformation.transformation_type;
621 :
622 1460 : transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
623 1460 : transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
624 1460 : transformed->state_index_to_parity.resize(transformation.matrix.size());
625 1460 : transformed->state_index_to_ket_index.resize(transformation.matrix.size());
626 :
627 143601 : for (int i = 0; i < transformation.matrix.size(); ++i) {
628 142141 : transformed->state_index_to_quantum_number_f[i] =
629 142141 : state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
630 142141 : transformed->state_index_to_quantum_number_m[i] =
631 142141 : state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
632 142141 : transformed->state_index_to_parity[i] =
633 142141 : state_index_to_parity[transformation.matrix.indices()[i]];
634 :
635 142141 : size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
636 142141 : transformed->state_index_to_ket_index[i] = ket_index;
637 142141 : if (ket_index != std::numeric_limits<int>::max()) {
638 142141 : transformed->ket_index_to_state_index[ket_index] = i;
639 : }
640 : }
641 :
642 1460 : return transformed;
643 1460 : }
644 :
645 : template <typename Derived>
646 : std::shared_ptr<const Derived>
647 590 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
648 : // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
649 : // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
650 590 : real_t numerical_precision = 0.001;
651 :
652 : // If the transformation is a rotation, it should be a rotation and nothing else
653 590 : bool is_rotation = false;
654 1180 : for (auto t : transformation.transformation_type) {
655 590 : if (t == TransformationType::ROTATE) {
656 0 : is_rotation = true;
657 0 : break;
658 : }
659 : }
660 590 : if (is_rotation && transformation.transformation_type.size() != 1) {
661 0 : throw std::invalid_argument("A rotation can not be combined with other transformations.");
662 : }
663 :
664 : // To apply a rotation, the object must only be sorted but other transformations are not allowed
665 590 : if (is_rotation) {
666 0 : for (auto t : coefficients.transformation_type) {
667 0 : if (!utils::is_sorting(t)) {
668 0 : throw std::runtime_error(
669 : "If the object was transformed by a different transformation "
670 : "than sorting, it can not be rotated.");
671 : }
672 : }
673 : }
674 :
675 : // Create a copy of the current object
676 590 : auto transformed = std::make_shared<Derived>(derived());
677 :
678 590 : if (coefficients.matrix.cols() == 0) {
679 0 : return transformed;
680 : }
681 :
682 : // Apply the transformation
683 : // If a quantum number turns out to be conserved by the transformation, it will be
684 : // rounded to the nearest half integer to avoid loss of numerical_precision.
685 590 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
686 590 : transformed->coefficients.transformation_type = transformation.transformation_type;
687 :
688 590 : Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
689 :
690 : {
691 1180 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
692 590 : state_index_to_quantum_number_f.size());
693 590 : Eigen::VectorX<real_t> val = probs * map;
694 590 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
695 590 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
696 590 : transformed->state_index_to_quantum_number_f.resize(probs.rows());
697 :
698 29514 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
699 28924 : if (diff[i] < numerical_precision) {
700 4887 : transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
701 : } else {
702 24037 : transformed->state_index_to_quantum_number_f[i] =
703 24037 : std::numeric_limits<real_t>::max();
704 24037 : transformed->_has_quantum_number_f = false;
705 : }
706 : }
707 590 : }
708 :
709 : {
710 1180 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
711 590 : state_index_to_quantum_number_m.size());
712 590 : Eigen::VectorX<real_t> val = probs * map;
713 590 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
714 590 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
715 590 : transformed->state_index_to_quantum_number_m.resize(probs.rows());
716 :
717 29514 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
718 28924 : if (diff[i] < numerical_precision) {
719 25354 : transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
720 : } else {
721 3570 : transformed->state_index_to_quantum_number_m[i] =
722 3570 : std::numeric_limits<real_t>::max();
723 3570 : transformed->_has_quantum_number_m = false;
724 : }
725 : }
726 590 : }
727 :
728 : {
729 : using utype = std::underlying_type<Parity>::type;
730 590 : Eigen::VectorX<real_t> map(state_index_to_parity.size());
731 34612 : for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
732 34022 : map[i] = static_cast<utype>(state_index_to_parity[i]);
733 : }
734 590 : Eigen::VectorX<real_t> val = probs * map;
735 590 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
736 590 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
737 590 : transformed->state_index_to_parity.resize(probs.rows());
738 :
739 29514 : for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
740 28924 : if (diff[i] < numerical_precision) {
741 23498 : transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
742 : } else {
743 5426 : transformed->state_index_to_parity[i] = Parity::UNKNOWN;
744 5426 : transformed->_has_parity = false;
745 : }
746 : }
747 590 : }
748 :
749 : {
750 : // In the following, we obtain a bijective map between state index and ket index.
751 :
752 : // Find the maximum value in each row and column
753 590 : std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
754 590 : std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
755 34612 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
756 34022 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
757 34022 : transformed->coefficients.matrix, row);
758 667854 : it; ++it) {
759 633832 : real_t val = std::pow(std::abs(it.value()), 2);
760 633832 : max_in_row[row] = std::max(max_in_row[row], val);
761 633832 : max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
762 : }
763 : }
764 :
765 : // Use the maximum values to define a cost for a sub-optimal mapping
766 590 : std::vector<real_t> costs;
767 590 : std::vector<std::pair<int, int>> mappings;
768 590 : costs.reserve(transformed->coefficients.matrix.nonZeros());
769 590 : mappings.reserve(transformed->coefficients.matrix.nonZeros());
770 34612 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
771 34022 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
772 34022 : transformed->coefficients.matrix, row);
773 667854 : it; ++it) {
774 633832 : real_t val = std::pow(std::abs(it.value()), 2);
775 633832 : real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
776 633832 : costs.push_back(cost);
777 633832 : mappings.push_back({row, it.col()});
778 : }
779 : }
780 :
781 : // Obtain from the costs in which order the mappings should be considered
782 590 : std::vector<size_t> order(costs.size());
783 590 : std::iota(order.begin(), order.end(), 0);
784 590 : std::sort(order.begin(), order.end(),
785 8232674 : [&](size_t a, size_t b) { return costs[a] < costs[b]; });
786 :
787 : // Fill ket_index_to_state_index with invalid values as there can be more kets than states
788 590 : std::fill(transformed->ket_index_to_state_index.begin(),
789 590 : transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
790 :
791 : // Generate the bijective map
792 590 : std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
793 590 : std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
794 590 : int num_used = 0;
795 110853 : for (size_t idx : order) {
796 110849 : int row = mappings[idx].first; // corresponds to the ket index
797 110849 : int col = mappings[idx].second; // corresponds to the state index
798 110849 : if (!row_used[row] && !col_used[col]) {
799 28922 : row_used[row] = true;
800 28922 : col_used[col] = true;
801 28922 : num_used++;
802 28922 : transformed->state_index_to_ket_index[col] = row;
803 28922 : transformed->ket_index_to_state_index[row] = col;
804 : }
805 110849 : if (num_used == transformed->coefficients.matrix.cols()) {
806 586 : break;
807 : }
808 : }
809 590 : if (num_used != transformed->coefficients.matrix.cols()) {
810 4 : SPDLOG_WARN("A bijective map between states and kets could not be found.");
811 : }
812 590 : }
813 :
814 590 : return transformed;
815 590 : }
816 :
817 : template <typename Derived>
818 201349 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
819 201349 : return typename ket_t::hash()(*k);
820 : }
821 :
822 : template <typename Derived>
823 1152 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
824 : const std::shared_ptr<const ket_t> &rhs) const {
825 1152 : return *lhs == *rhs;
826 : }
827 :
828 : // Explicit instantiations
829 : template class Basis<BasisAtom<double>>;
830 : template class Basis<BasisAtom<std::complex<double>>>;
831 : template class Basis<BasisPair<double>>;
832 : template class Basis<BasisPair<std::complex<double>>>;
833 : } // namespace pairinteraction
|