Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/basis/Basis.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/enums/Parity.hpp"
9 : #include "pairinteraction/enums/TransformationType.hpp"
10 : #include "pairinteraction/ket/KetAtom.hpp"
11 : #include "pairinteraction/ket/KetPair.hpp"
12 : #include "pairinteraction/utils/eigen_assertion.hpp"
13 : #include "pairinteraction/utils/eigen_compat.hpp"
14 : #include "pairinteraction/utils/wigner.hpp"
15 :
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <numeric>
#include <set>
#include <stdexcept>

#include <spdlog/spdlog.h>
20 :
21 : namespace pairinteraction {
22 :
23 : template <typename Scalar>
24 : class BasisAtom;
25 :
26 : template <typename Derived>
27 1958 : void Basis<Derived>::perform_sorter_checks(const std::vector<TransformationType> &labels) const {
28 : // Check if the labels are valid sorting labels
29 3943 : for (const auto &label : labels) {
30 1985 : if (!utils::is_sorting(label)) {
31 0 : throw std::invalid_argument("One of the labels is not a valid sorting label.");
32 : }
33 : }
34 1958 : }
35 :
36 : template <typename Derived>
37 566 : void Basis<Derived>::perform_blocks_checks(
38 : const std::set<TransformationType> &unique_labels) const {
39 : // Check if the states are sorted by the requested labels
40 566 : std::set<TransformationType> unique_labels_present;
41 1120 : for (const auto &label : get_transformation().transformation_type) {
42 605 : if (!utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
43 51 : break;
44 : }
45 554 : unique_labels_present.insert(label);
46 : }
47 566 : if (unique_labels != unique_labels_present) {
48 0 : throw std::invalid_argument("The states are not sorted by the requested labels.");
49 : }
50 :
51 : // Throw a meaningful error if getting the blocks by energy is requested as this might be a
52 : // common mistake
53 566 : if (unique_labels.contains(TransformationType::SORT_BY_ENERGY)) {
54 0 : throw std::invalid_argument("States do not store the energy and thus no energy blocks can "
55 : "be obtained. Use an energy operator instead.");
56 : }
57 566 : }
58 :
59 : template <typename Derived>
60 357 : Basis<Derived>::Basis(ketvec_t &&kets)
61 714 : : kets(std::move(kets)), coefficients{{static_cast<Eigen::Index>(this->kets.size()),
62 357 : static_cast<Eigen::Index>(this->kets.size())},
63 714 : {TransformationType::SORT_BY_KET}} {
64 357 : if (this->kets.empty()) {
65 0 : throw std::invalid_argument("The basis must contain at least one element.");
66 : }
67 357 : state_index_to_quantum_number_f.reserve(this->kets.size());
68 357 : state_index_to_quantum_number_m.reserve(this->kets.size());
69 357 : state_index_to_parity.reserve(this->kets.size());
70 357 : ket_to_ket_index.reserve(this->kets.size());
71 357 : size_t index = 0;
72 196003 : for (const auto &ket : this->kets) {
73 195646 : state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
74 195646 : state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
75 195646 : state_index_to_parity.push_back(ket->get_parity());
76 195646 : ket_to_ket_index[ket] = index++;
77 195646 : if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
78 175963 : _has_quantum_number_f = false;
79 : }
80 195646 : if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
81 64000 : _has_quantum_number_m = false;
82 : }
83 195646 : if (ket->get_parity() == Parity::UNKNOWN) {
84 175963 : _has_parity = false;
85 : }
86 : }
87 357 : state_index_to_ket_index.resize(this->kets.size());
88 357 : std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
89 357 : ket_index_to_state_index.resize(this->kets.size());
90 357 : std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
91 357 : coefficients.matrix.setIdentity();
92 357 : }
93 :
94 : template <typename Derived>
95 517 : bool Basis<Derived>::has_quantum_number_f() const {
96 517 : return _has_quantum_number_f;
97 : }
98 :
99 : template <typename Derived>
100 306509 : bool Basis<Derived>::has_quantum_number_m() const {
101 306509 : return _has_quantum_number_m;
102 : }
103 :
104 : template <typename Derived>
105 517 : bool Basis<Derived>::has_parity() const {
106 517 : return _has_parity;
107 : }
108 :
109 : template <typename Derived>
110 2541 : const Derived &Basis<Derived>::derived() const {
111 2541 : return static_cast<const Derived &>(*this);
112 : }
113 :
114 : template <typename Derived>
115 815 : const typename Basis<Derived>::ketvec_t &Basis<Derived>::get_kets() const {
116 815 : return kets;
117 : }
118 :
119 : template <typename Derived>
120 : const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
121 9491 : Basis<Derived>::get_coefficients() const {
122 9491 : return coefficients.matrix;
123 : }
124 :
125 : template <typename Derived>
126 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
127 0 : Basis<Derived>::get_coefficients() {
128 0 : return coefficients.matrix;
129 : }
130 :
131 : template <typename Derived>
132 5 : void Basis<Derived>::set_coefficients(
133 : const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor> &values) {
134 5 : if (values.rows() != coefficients.matrix.rows()) {
135 0 : throw std::invalid_argument("Incompatible number of rows.");
136 : }
137 5 : if (values.cols() != coefficients.matrix.cols()) {
138 0 : throw std::invalid_argument("Incompatible number of columns.");
139 : }
140 :
141 5 : coefficients.matrix = values;
142 :
143 5 : std::fill(ket_index_to_state_index.begin(), ket_index_to_state_index.end(),
144 5 : std::numeric_limits<int>::max());
145 5 : std::fill(state_index_to_ket_index.begin(), state_index_to_ket_index.end(),
146 5 : std::numeric_limits<int>::max());
147 5 : std::fill(state_index_to_quantum_number_f.begin(), state_index_to_quantum_number_f.end(),
148 5 : std::numeric_limits<real_t>::max());
149 5 : std::fill(state_index_to_quantum_number_m.begin(), state_index_to_quantum_number_m.end(),
150 5 : std::numeric_limits<real_t>::max());
151 5 : std::fill(state_index_to_parity.begin(), state_index_to_parity.end(), Parity::UNKNOWN);
152 5 : _has_quantum_number_f = false;
153 5 : _has_quantum_number_m = false;
154 5 : _has_parity = false;
155 5 : }
156 :
157 : template <typename Derived>
158 552 : int Basis<Derived>::get_ket_index_from_ket(std::shared_ptr<const ket_t> ket) const {
159 552 : if (!ket_to_ket_index.contains(ket)) {
160 0 : return -1;
161 : }
162 552 : return ket_to_ket_index.at(ket);
163 : }
164 :
165 : template <typename Derived>
166 : Eigen::VectorX<typename Basis<Derived>::scalar_t>
167 109 : Basis<Derived>::get_amplitudes(std::shared_ptr<const ket_t> ket) const {
168 109 : int ket_index = get_ket_index_from_ket(ket);
169 109 : if (ket_index < 0) {
170 0 : throw std::invalid_argument("The ket does not belong to the basis.");
171 : }
172 : // The following line is a more efficient alternative to
173 : // "get_amplitudes(get_canonical_state_from_ket(ket)).transpose()"
174 218 : return coefficients.matrix.row(ket_index);
175 : }
176 :
177 : template <typename Derived>
178 : Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
179 16 : Basis<Derived>::get_amplitudes(std::shared_ptr<const Derived> other) const {
180 24 : return other->coefficients.matrix.adjoint() * coefficients.matrix;
181 : }
182 :
183 : template <typename Derived>
184 : Eigen::VectorX<typename Basis<Derived>::real_t>
185 103 : Basis<Derived>::get_overlaps(std::shared_ptr<const ket_t> ket) const {
186 103 : return get_amplitudes(ket).cwiseAbs2();
187 : }
188 :
189 : template <typename Derived>
190 : Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
191 8 : Basis<Derived>::get_overlaps(std::shared_ptr<const Derived> other) const {
192 8 : return get_amplitudes(other).cwiseAbs2();
193 : }
194 :
195 : template <typename Derived>
196 0 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_f(size_t state_index) const {
197 0 : real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
198 0 : if (quantum_number_f == std::numeric_limits<real_t>::max()) {
199 0 : throw std::invalid_argument("The state does not have a well-defined quantum number f.");
200 : }
201 0 : return quantum_number_f;
202 : }
203 :
204 : template <typename Derived>
205 242034 : typename Basis<Derived>::real_t Basis<Derived>::get_quantum_number_m(size_t state_index) const {
206 242034 : real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
207 242034 : if (quantum_number_m == std::numeric_limits<real_t>::max()) {
208 0 : throw std::invalid_argument("The state does not have a well-defined quantum number m.");
209 : }
210 242034 : return quantum_number_m;
211 : }
212 :
213 : template <typename Derived>
214 43 : Parity Basis<Derived>::get_parity(size_t state_index) const {
215 43 : Parity parity = state_index_to_parity.at(state_index);
216 43 : if (parity == Parity::UNKNOWN) {
217 0 : throw std::invalid_argument("The state does not have a well-defined parity.");
218 : }
219 43 : return parity;
220 : }
221 :
222 : template <typename Derived>
223 : std::shared_ptr<const typename Basis<Derived>::ket_t>
224 152 : Basis<Derived>::get_corresponding_ket(size_t state_index) const {
225 152 : size_t ket_index = state_index_to_ket_index.at(state_index);
226 152 : if (ket_index == std::numeric_limits<int>::max()) {
227 0 : throw std::invalid_argument("The state does not belong to a ket in a well-defined way.");
228 : }
229 152 : return kets[ket_index];
230 : }
231 :
232 : template <typename Derived>
233 : std::shared_ptr<const typename Basis<Derived>::ket_t>
234 0 : Basis<Derived>::get_corresponding_ket(std::shared_ptr<const Derived> /*state*/) const {
235 0 : throw std::runtime_error("Not implemented yet.");
236 : }
237 :
238 : template <typename Derived>
239 374 : std::shared_ptr<const Derived> Basis<Derived>::get_state(size_t state_index) const {
240 : // Create a copy of the current object
241 374 : auto restricted = std::make_shared<Derived>(derived());
242 :
243 : // Restrict the copy to the state with the largest overlap
244 374 : restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
245 :
246 374 : std::fill(restricted->ket_index_to_state_index.begin(),
247 374 : restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
248 :
249 374 : size_t ket_index = state_index_to_ket_index[state_index];
250 374 : restricted->state_index_to_ket_index = {ket_index};
251 374 : if (ket_index != std::numeric_limits<int>::max()) {
252 374 : restricted->ket_index_to_state_index[ket_index] = 0;
253 : }
254 :
255 374 : restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
256 374 : restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
257 374 : restricted->state_index_to_parity = {state_index_to_parity[state_index]};
258 :
259 748 : restricted->_has_quantum_number_f =
260 374 : restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
261 748 : restricted->_has_quantum_number_m =
262 374 : restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
263 374 : restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
264 :
265 748 : return restricted;
266 374 : }
267 :
268 : template <typename Derived>
269 : std::shared_ptr<const typename Basis<Derived>::ket_t>
270 0 : Basis<Derived>::get_ket(size_t ket_index) const {
271 0 : return kets[ket_index];
272 : }
273 :
274 : template <typename Derived>
275 52 : std::shared_ptr<const Derived> Basis<Derived>::get_corresponding_state(size_t ket_index) const {
276 52 : size_t state_index = ket_index_to_state_index.at(ket_index);
277 52 : if (state_index == std::numeric_limits<int>::max()) {
278 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
279 : }
280 52 : return get_state(state_index);
281 : }
282 :
283 : template <typename Derived>
284 : std::shared_ptr<const Derived>
285 52 : Basis<Derived>::get_corresponding_state(std::shared_ptr<const ket_t> ket) const {
286 52 : int ket_index = get_ket_index_from_ket(ket);
287 52 : if (ket_index < 0) {
288 0 : throw std::invalid_argument("The ket does not belong to the basis.");
289 : }
290 52 : return get_corresponding_state(ket_index);
291 : }
292 :
293 : template <typename Derived>
294 182 : size_t Basis<Derived>::get_corresponding_state_index(size_t ket_index) const {
295 182 : int state_index = ket_index_to_state_index.at(ket_index);
296 182 : if (state_index == std::numeric_limits<int>::max()) {
297 0 : throw std::runtime_error("The ket does not belong to a state in a well-defined way.");
298 : }
299 182 : return state_index;
300 : }
301 :
302 : template <typename Derived>
303 182 : size_t Basis<Derived>::get_corresponding_state_index(std::shared_ptr<const ket_t> ket) const {
304 182 : int ket_index = get_ket_index_from_ket(ket);
305 182 : if (ket_index < 0) {
306 0 : throw std::invalid_argument("The ket does not belong to the basis.");
307 : }
308 182 : return get_corresponding_state_index(ket_index);
309 : }
310 :
311 : template <typename Derived>
312 152 : size_t Basis<Derived>::get_corresponding_ket_index(size_t state_index) const {
313 152 : size_t ket_index = state_index_to_ket_index.at(state_index);
314 152 : if (ket_index == std::numeric_limits<int>::max()) {
315 0 : throw std::runtime_error("The state does not belong to a ket in a well-defined way.");
316 : }
317 152 : return ket_index;
318 : }
319 :
320 : template <typename Derived>
321 0 : size_t Basis<Derived>::get_corresponding_ket_index(std::shared_ptr<const Derived> /*state*/) const {
322 0 : throw std::runtime_error("Not implemented yet.");
323 : }
324 :
325 : template <typename Derived>
326 : std::shared_ptr<const Derived>
327 209 : Basis<Derived>::get_canonical_state_from_ket(size_t ket_index) const {
328 : // Create a copy of the current object
329 209 : auto created = std::make_shared<Derived>(derived());
330 :
331 : // Fill the copy with the state corresponding to the ket index
332 209 : created->coefficients.matrix =
333 418 : Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
334 209 : created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
335 209 : created->coefficients.matrix.makeCompressed();
336 :
337 209 : std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
338 209 : std::numeric_limits<int>::max());
339 209 : created->ket_index_to_state_index[ket_index] = 0;
340 :
341 209 : created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
342 209 : created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
343 209 : created->state_index_to_parity = {kets[ket_index]->get_parity()};
344 209 : created->state_index_to_ket_index = {ket_index};
345 :
346 418 : created->_has_quantum_number_f =
347 209 : created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
348 418 : created->_has_quantum_number_m =
349 209 : created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
350 209 : created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
351 :
352 418 : return created;
353 209 : }
354 :
355 : template <typename Derived>
356 : std::shared_ptr<const Derived>
357 209 : Basis<Derived>::get_canonical_state_from_ket(std::shared_ptr<const ket_t> ket) const {
358 209 : int ket_index = get_ket_index_from_ket(ket);
359 209 : if (ket_index < 0) {
360 0 : throw std::invalid_argument("The ket does not belong to the basis.");
361 : }
362 209 : return get_canonical_state_from_ket(ket_index);
363 : }
364 :
365 : template <typename Derived>
366 4 : typename Basis<Derived>::Iterator Basis<Derived>::begin() const {
367 4 : return kets.begin();
368 : }
369 :
370 : template <typename Derived>
371 4 : typename Basis<Derived>::Iterator Basis<Derived>::end() const {
372 4 : return kets.end();
373 : }
374 :
375 : template <typename Derived>
376 8 : Basis<Derived>::Iterator::Iterator(typename ketvec_t::const_iterator it) : it{std::move(it)} {}
377 :
378 : template <typename Derived>
379 80 : bool Basis<Derived>::Iterator::operator!=(const Iterator &other) const {
380 80 : return other.it != it;
381 : }
382 :
383 : template <typename Derived>
384 76 : std::shared_ptr<const typename Basis<Derived>::ket_t> Basis<Derived>::Iterator::operator*() const {
385 76 : return *it;
386 : }
387 :
388 : template <typename Derived>
389 76 : typename Basis<Derived>::Iterator &Basis<Derived>::Iterator::operator++() {
390 76 : ++it;
391 76 : return *this;
392 : }
393 :
394 : template <typename Derived>
395 18049 : size_t Basis<Derived>::get_number_of_states() const {
396 18049 : return coefficients.matrix.cols();
397 : }
398 :
399 : template <typename Derived>
400 6855 : size_t Basis<Derived>::get_number_of_kets() const {
401 6855 : return coefficients.matrix.rows();
402 : }
403 :
404 : template <typename Derived>
405 : const Transformation<typename Basis<Derived>::scalar_t> &
406 567 : Basis<Derived>::get_transformation() const {
407 567 : return coefficients;
408 : }
409 :
410 : template <typename Derived>
411 : Transformation<typename Basis<Derived>::scalar_t>
412 0 : Basis<Derived>::get_rotator(real_t alpha, real_t beta, real_t gamma) const {
413 0 : Transformation<scalar_t> transformation{{static_cast<Eigen::Index>(coefficients.matrix.rows()),
414 : static_cast<Eigen::Index>(coefficients.matrix.rows())},
415 : {TransformationType::ROTATE}};
416 :
417 0 : std::vector<Eigen::Triplet<scalar_t>> entries;
418 :
419 0 : for (size_t idx_initial = 0; idx_initial < kets.size(); ++idx_initial) {
420 0 : real_t f = kets[idx_initial]->get_quantum_number_f();
421 0 : real_t m_initial = kets[idx_initial]->get_quantum_number_m();
422 :
423 0 : assert(2 * f == std::floor(2 * f) && f >= 0);
424 0 : assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
425 :
426 0 : for (real_t m_final = -f; m_final <= f; // NOSONAR m_final is precisely representable
427 : ++m_final) {
428 0 : auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
429 : beta, gamma);
430 0 : size_t idx_final = get_ket_index_from_ket(
431 0 : kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
432 0 : entries.emplace_back(idx_final, idx_initial, val);
433 : }
434 : }
435 :
436 0 : transformation.matrix.setFromTriplets(entries.begin(), entries.end());
437 0 : transformation.matrix.makeCompressed();
438 :
439 0 : return transformation;
440 0 : }
441 :
442 : template <typename Derived>
443 1 : Sorting Basis<Derived>::get_sorter(const std::vector<TransformationType> &labels) const {
444 1 : perform_sorter_checks(labels);
445 :
446 : // Throw a meaningful error if sorting by energy is requested as this might be a common mistake
447 1 : if (std::find(labels.begin(), labels.end(), TransformationType::SORT_BY_ENERGY) !=
448 2 : labels.end()) {
449 0 : throw std::invalid_argument("States do not store the energy and thus can not be sorted by "
450 : "the energy. Use an energy operator instead.");
451 : }
452 :
453 : // Initialize transformation
454 1 : Sorting transformation;
455 1 : transformation.matrix.resize(coefficients.matrix.cols());
456 1 : transformation.matrix.setIdentity();
457 :
458 : // Get the sorter
459 1 : get_sorter_without_checks(labels, transformation);
460 :
461 : // Check if all labels have been used for sorting
462 1 : if (labels != transformation.transformation_type) {
463 0 : throw std::invalid_argument("The states could not be sorted by all the requested labels.");
464 : }
465 :
466 1 : return transformation;
467 0 : }
468 :
469 : template <typename Derived>
470 : std::vector<IndicesOfBlock>
471 1 : Basis<Derived>::get_indices_of_blocks(const std::vector<TransformationType> &labels) const {
472 1 : perform_sorter_checks(labels);
473 :
474 1 : std::set<TransformationType> unique_labels(labels.begin(), labels.end());
475 1 : perform_blocks_checks(unique_labels);
476 :
477 : // Get the blocks
478 1 : IndicesOfBlocksCreator blocks_creator({0, static_cast<size_t>(coefficients.matrix.cols())});
479 1 : get_indices_of_blocks_without_checks(unique_labels, blocks_creator);
480 :
481 2 : return blocks_creator.create();
482 1 : }
483 :
484 : template <typename Derived>
485 515 : void Basis<Derived>::get_sorter_without_checks(const std::vector<TransformationType> &labels,
486 : Sorting &transformation) const {
487 515 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
488 :
489 515 : int *perm_begin = transformation.matrix.indices().data();
490 515 : int *perm_end = perm_begin + coefficients.matrix.cols();
491 515 : const int *perm_back = perm_end - 1;
492 :
493 : // Sort the vector based on the requested labels
494 564260 : std::stable_sort(perm_begin, perm_end, [&](int a, int b) {
495 254827 : for (const auto &label : labels) {
496 166142 : switch (label) {
497 30703 : case TransformationType::SORT_BY_PARITY:
498 30703 : if (state_index_to_parity[a] != state_index_to_parity[b]) {
499 47592 : return state_index_to_parity[a] < state_index_to_parity[b];
500 : }
501 20305 : break;
502 135439 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
503 270878 : if (std::abs(state_index_to_quantum_number_m[a] -
504 270878 : state_index_to_quantum_number_m[b]) > numerical_precision) {
505 37194 : return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
506 : }
507 98245 : break;
508 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
509 0 : if (std::abs(state_index_to_quantum_number_f[a] -
510 0 : state_index_to_quantum_number_f[b]) > numerical_precision) {
511 0 : return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
512 : }
513 0 : break;
514 0 : case TransformationType::SORT_BY_KET:
515 0 : if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
516 0 : return state_index_to_ket_index[a] < state_index_to_ket_index[b];
517 : }
518 0 : break;
519 0 : default:
520 0 : std::abort(); // Can't happen because of previous checks
521 : }
522 : }
523 88685 : return false; // Elements are equal
524 : });
525 :
526 : // Check for invalid values and add transformation types
527 1069 : for (const auto &label : labels) {
528 554 : switch (label) {
529 41 : case TransformationType::SORT_BY_PARITY:
530 41 : if (state_index_to_parity[*perm_back] == Parity::UNKNOWN) {
531 0 : throw std::invalid_argument(
532 : "States cannot be labeled and thus not sorted by the parity.");
533 : }
534 41 : transformation.transformation_type.push_back(TransformationType::SORT_BY_PARITY);
535 41 : break;
536 513 : case TransformationType::SORT_BY_QUANTUM_NUMBER_M:
537 513 : if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
538 0 : throw std::invalid_argument(
539 : "States cannot be labeled and thus not sorted by the quantum number m.");
540 : }
541 513 : transformation.transformation_type.push_back(
542 513 : TransformationType::SORT_BY_QUANTUM_NUMBER_M);
543 513 : break;
544 0 : case TransformationType::SORT_BY_QUANTUM_NUMBER_F:
545 0 : if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
546 0 : throw std::invalid_argument(
547 : "States cannot be labeled and thus not sorted by the quantum number f.");
548 : }
549 0 : transformation.transformation_type.push_back(
550 0 : TransformationType::SORT_BY_QUANTUM_NUMBER_F);
551 0 : break;
552 0 : case TransformationType::SORT_BY_KET:
553 0 : if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
554 0 : throw std::invalid_argument(
555 : "States cannot be labeled and thus not sorted by kets.");
556 : }
557 0 : transformation.transformation_type.push_back(TransformationType::SORT_BY_KET);
558 0 : break;
559 0 : default:
560 0 : std::abort(); // Can't happen because of previous checks
561 : }
562 : }
563 515 : }
564 :
565 : template <typename Derived>
566 515 : void Basis<Derived>::get_indices_of_blocks_without_checks(
567 : const std::set<TransformationType> &unique_labels,
568 : IndicesOfBlocksCreator &blocks_creator) const {
569 515 : constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
570 :
571 515 : auto last_quantum_number_f = state_index_to_quantum_number_f[0];
572 515 : auto last_quantum_number_m = state_index_to_quantum_number_m[0];
573 515 : auto last_parity = state_index_to_parity[0];
574 515 : auto last_ket = state_index_to_ket_index[0];
575 :
576 29789 : for (int i = 0; i < coefficients.matrix.cols(); ++i) {
577 65024 : for (auto label : unique_labels) {
578 36283 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_F &&
579 0 : std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
580 : numerical_precision) {
581 0 : blocks_creator.add(i);
582 0 : break;
583 : }
584 65377 : if (label == TransformationType::SORT_BY_QUANTUM_NUMBER_M &&
585 29094 : std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
586 : numerical_precision) {
587 389 : blocks_creator.add(i);
588 389 : break;
589 : }
590 43083 : if (label == TransformationType::SORT_BY_PARITY &&
591 7189 : state_index_to_parity[i] != last_parity) {
592 144 : blocks_creator.add(i);
593 144 : break;
594 : }
595 35750 : if (label == TransformationType::SORT_BY_KET &&
596 35750 : state_index_to_ket_index[i] != last_ket &&
597 0 : last_ket != std::numeric_limits<int>::max()) {
598 0 : blocks_creator.add(i);
599 0 : break;
600 : }
601 : }
602 29274 : last_quantum_number_f = state_index_to_quantum_number_f[i];
603 29274 : last_quantum_number_m = state_index_to_quantum_number_m[i];
604 29274 : last_parity = state_index_to_parity[i];
605 29274 : last_ket = state_index_to_ket_index[i];
606 : }
607 515 : }
608 :
609 : template <typename Derived>
610 1392 : std::shared_ptr<const Derived> Basis<Derived>::transformed(const Sorting &transformation) const {
611 : // Create a copy of the current object
612 1392 : auto transformed = std::make_shared<Derived>(derived());
613 :
614 1392 : if (coefficients.matrix.cols() == 0) {
615 0 : return transformed;
616 : }
617 :
618 : // Apply the transformation
619 1392 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
620 1392 : transformed->coefficients.transformation_type = transformation.transformation_type;
621 :
622 1392 : transformed->state_index_to_quantum_number_f.resize(transformation.matrix.size());
623 1392 : transformed->state_index_to_quantum_number_m.resize(transformation.matrix.size());
624 1392 : transformed->state_index_to_parity.resize(transformation.matrix.size());
625 1392 : transformed->state_index_to_ket_index.resize(transformation.matrix.size());
626 :
627 138205 : for (int i = 0; i < transformation.matrix.size(); ++i) {
628 136813 : transformed->state_index_to_quantum_number_f[i] =
629 136813 : state_index_to_quantum_number_f[transformation.matrix.indices()[i]];
630 136813 : transformed->state_index_to_quantum_number_m[i] =
631 136813 : state_index_to_quantum_number_m[transformation.matrix.indices()[i]];
632 136813 : transformed->state_index_to_parity[i] =
633 136813 : state_index_to_parity[transformation.matrix.indices()[i]];
634 :
635 136813 : size_t ket_index = state_index_to_ket_index[transformation.matrix.indices()[i]];
636 136813 : transformed->state_index_to_ket_index[i] = ket_index;
637 136813 : if (ket_index != std::numeric_limits<int>::max()) {
638 136813 : transformed->ket_index_to_state_index[ket_index] = i;
639 : }
640 : }
641 :
642 1392 : return transformed;
643 1392 : }
644 :
645 : template <typename Derived>
646 : std::shared_ptr<const Derived>
647 566 : Basis<Derived>::transformed(const Transformation<scalar_t> &transformation) const {
648 : // TODO why is "numerical_precision = 100 * std::sqrt(coefficients.matrix.rows()) *
649 : // std::numeric_limits<real_t>::epsilon()" too small for figuring out whether m is conserved?
650 566 : real_t numerical_precision = 0.001;
651 :
652 : // If the transformation is a rotation, it should be a rotation and nothing else
653 566 : bool is_rotation = false;
654 1132 : for (auto t : transformation.transformation_type) {
655 566 : if (t == TransformationType::ROTATE) {
656 0 : is_rotation = true;
657 0 : break;
658 : }
659 : }
660 566 : if (is_rotation && transformation.transformation_type.size() != 1) {
661 0 : throw std::invalid_argument("A rotation can not be combined with other transformations.");
662 : }
663 :
664 : // To apply a rotation, the object must only be sorted but other transformations are not allowed
665 566 : if (is_rotation) {
666 0 : for (auto t : coefficients.transformation_type) {
667 0 : if (!utils::is_sorting(t)) {
668 0 : throw std::runtime_error(
669 : "If the object was transformed by a different transformation "
670 : "than sorting, it can not be rotated.");
671 : }
672 : }
673 : }
674 :
675 : // Create a copy of the current object
676 566 : auto transformed = std::make_shared<Derived>(derived());
677 :
678 566 : if (coefficients.matrix.cols() == 0) {
679 0 : return transformed;
680 : }
681 :
682 : // Apply the transformation
683 : // If a quantum number turns out to be conserved by the transformation, it will be
684 : // rounded to the nearest half integer to avoid loss of numerical_precision.
685 566 : transformed->coefficients.matrix = coefficients.matrix * transformation.matrix;
686 566 : transformed->coefficients.transformation_type = transformation.transformation_type;
687 :
688 566 : Eigen::SparseMatrix<real_t> probs = transformation.matrix.cwiseAbs2().transpose();
689 :
690 : {
691 1132 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
692 566 : state_index_to_quantum_number_f.size());
693 566 : Eigen::VectorX<real_t> val = probs * map;
694 566 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
695 566 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
696 566 : transformed->state_index_to_quantum_number_f.resize(probs.rows());
697 :
698 27554 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_f.size(); ++i) {
699 26988 : if (diff[i] < numerical_precision) {
700 4727 : transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
701 : } else {
702 22261 : transformed->state_index_to_quantum_number_f[i] =
703 22261 : std::numeric_limits<real_t>::max();
704 22261 : transformed->_has_quantum_number_f = false;
705 : }
706 : }
707 566 : }
708 :
709 : {
710 1132 : auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
711 566 : state_index_to_quantum_number_m.size());
712 566 : Eigen::VectorX<real_t> val = probs * map;
713 566 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
714 566 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
715 566 : transformed->state_index_to_quantum_number_m.resize(probs.rows());
716 :
717 27554 : for (size_t i = 0; i < transformed->state_index_to_quantum_number_m.size(); ++i) {
718 26988 : if (diff[i] < numerical_precision) {
719 24330 : transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
720 : } else {
721 2658 : transformed->state_index_to_quantum_number_m[i] =
722 2658 : std::numeric_limits<real_t>::max();
723 2658 : transformed->_has_quantum_number_m = false;
724 : }
725 : }
726 566 : }
727 :
728 : {
729 : using utype = std::underlying_type<Parity>::type;
730 566 : Eigen::VectorX<real_t> map(state_index_to_parity.size());
731 32652 : for (size_t i = 0; i < state_index_to_parity.size(); ++i) {
732 32086 : map[i] = static_cast<utype>(state_index_to_parity[i]);
733 : }
734 566 : Eigen::VectorX<real_t> val = probs * map;
735 566 : Eigen::VectorX<real_t> sq = probs * map.cwiseAbs2();
736 566 : Eigen::VectorX<real_t> diff = (val.cwiseAbs2() - sq).cwiseAbs();
737 566 : transformed->state_index_to_parity.resize(probs.rows());
738 :
739 27554 : for (size_t i = 0; i < transformed->state_index_to_parity.size(); ++i) {
740 26988 : if (diff[i] < numerical_precision) {
741 22058 : transformed->state_index_to_parity[i] = static_cast<Parity>(std::lround(val[i]));
742 : } else {
743 4930 : transformed->state_index_to_parity[i] = Parity::UNKNOWN;
744 4930 : transformed->_has_parity = false;
745 : }
746 : }
747 566 : }
748 :
749 : {
750 : // In the following, we obtain a bijective map between state index and ket index.
751 :
752 : // Find the maximum value in each row and column
753 566 : std::vector<real_t> max_in_row(transformed->coefficients.matrix.rows(), 0);
754 566 : std::vector<real_t> max_in_col(transformed->coefficients.matrix.cols(), 0);
755 32652 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
756 32086 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
757 32086 : transformed->coefficients.matrix, row);
758 601942 : it; ++it) {
759 569856 : real_t val = std::pow(std::abs(it.value()), 2);
760 569856 : max_in_row[row] = std::max(max_in_row[row], val);
761 569856 : max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
762 : }
763 : }
764 :
765 : // Use the maximum values to define a cost for a sub-optimal mapping
766 566 : std::vector<real_t> costs;
767 566 : std::vector<std::pair<int, int>> mappings;
768 566 : costs.reserve(transformed->coefficients.matrix.nonZeros());
769 566 : mappings.reserve(transformed->coefficients.matrix.nonZeros());
770 32652 : for (int row = 0; row < transformed->coefficients.matrix.outerSize(); ++row) {
771 32086 : for (typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
772 32086 : transformed->coefficients.matrix, row);
773 601942 : it; ++it) {
774 569856 : real_t val = std::pow(std::abs(it.value()), 2);
775 569856 : real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
776 569856 : costs.push_back(cost);
777 569856 : mappings.push_back({row, it.col()});
778 : }
779 : }
780 :
781 : // Obtain from the costs in which order the mappings should be considered
782 566 : std::vector<size_t> order(costs.size());
783 566 : std::iota(order.begin(), order.end(), 0);
784 566 : std::sort(order.begin(), order.end(),
785 7351168 : [&](size_t a, size_t b) { return costs[a] < costs[b]; });
786 :
787 : // Fill ket_index_to_state_index with invalid values as there can be more kets than states
788 566 : std::fill(transformed->ket_index_to_state_index.begin(),
789 566 : transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
790 :
791 : // Generate the bijective map
792 566 : std::vector<bool> row_used(transformed->coefficients.matrix.rows(), false);
793 566 : std::vector<bool> col_used(transformed->coefficients.matrix.cols(), false);
794 566 : int num_used = 0;
795 85793 : for (size_t idx : order) {
796 85789 : int row = mappings[idx].first; // corresponds to the ket index
797 85789 : int col = mappings[idx].second; // corresponds to the state index
798 85789 : if (!row_used[row] && !col_used[col]) {
799 26986 : row_used[row] = true;
800 26986 : col_used[col] = true;
801 26986 : num_used++;
802 26986 : transformed->state_index_to_ket_index[col] = row;
803 26986 : transformed->ket_index_to_state_index[row] = col;
804 : }
805 85789 : if (num_used == transformed->coefficients.matrix.cols()) {
806 562 : break;
807 : }
808 : }
809 566 : if (num_used != transformed->coefficients.matrix.cols()) {
810 4 : SPDLOG_WARN("A bijective map between states and kets could not be found.");
811 : }
812 566 : }
813 :
814 566 : return transformed;
815 566 : }
816 :
817 : template <typename Derived>
818 196750 : size_t Basis<Derived>::hash::operator()(const std::shared_ptr<const ket_t> &k) const {
819 196750 : return typename ket_t::hash()(*k);
820 : }
821 :
822 : template <typename Derived>
823 1104 : bool Basis<Derived>::equal_to::operator()(const std::shared_ptr<const ket_t> &lhs,
824 : const std::shared_ptr<const ket_t> &rhs) const {
825 1104 : return *lhs == *rhs;
826 : }
827 :
828 : // Explicit instantiations
829 : template class Basis<BasisAtom<double>>;
830 : template class Basis<BasisAtom<std::complex<double>>>;
831 : template class Basis<BasisPair<double>>;
832 : template class Basis<BasisPair<std::complex<double>>>;
833 : } // namespace pairinteraction
|