22template <
typename Scalar>
25template <
typename Derived>
28 for (
const auto &label : labels) {
30 throw std::invalid_argument(
"One of the labels is not a valid sorting label.");
35template <
typename Derived>
37 const std::set<TransformationType> &unique_labels)
const {
39 std::set<TransformationType> unique_labels_present;
40 for (
const auto &label : get_transformation().transformation_type) {
41 if (!
utils::is_sorting(label) || unique_labels_present.size() >= unique_labels.size()) {
44 unique_labels_present.insert(label);
46 if (unique_labels != unique_labels_present) {
47 throw std::invalid_argument(
"The states are not sorted by the requested labels.");
53 throw std::invalid_argument(
"States do not store the energy and thus no energy blocks can "
54 "be obtained. Use an energy operator instead.");
58template <
typename Derived>
60 : kets(std::move(kets)), coefficients{{
static_cast<Eigen::Index
>(this->kets.size()),
61 static_cast<Eigen::Index
>(this->kets.size())},
63 if (this->kets.empty()) {
64 throw std::invalid_argument(
"The basis must contain at least one element.");
66 state_index_to_quantum_number_f.reserve(this->kets.size());
67 state_index_to_quantum_number_m.reserve(this->kets.size());
68 state_index_to_parity.reserve(this->kets.size());
69 ket_to_ket_index.reserve(this->kets.size());
71 for (
const auto &ket : this->kets) {
72 state_index_to_quantum_number_f.push_back(ket->get_quantum_number_f());
73 state_index_to_quantum_number_m.push_back(ket->get_quantum_number_m());
74 state_index_to_parity.push_back(ket->get_parity());
75 ket_to_ket_index[ket] = index++;
76 if (ket->get_quantum_number_f() == std::numeric_limits<real_t>::max()) {
77 _has_quantum_number_f =
false;
79 if (ket->get_quantum_number_m() == std::numeric_limits<real_t>::max()) {
80 _has_quantum_number_m =
false;
82 if (ket->get_parity() == Parity::UNKNOWN) {
86 state_index_to_ket_index.resize(this->kets.size());
87 std::iota(state_index_to_ket_index.begin(), state_index_to_ket_index.end(), 0);
88 ket_index_to_state_index.resize(this->kets.size());
89 std::iota(ket_index_to_state_index.begin(), ket_index_to_state_index.end(), 0);
90 coefficients.matrix.setIdentity();
93template <
typename Derived>
95 return _has_quantum_number_f;
98template <
typename Derived>
100 return _has_quantum_number_m;
103template <
typename Derived>
108template <
typename Derived>
110 return static_cast<const Derived &
>(*this);
113template <
typename Derived>
118template <
typename Derived>
119const Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
121 return coefficients.matrix;
124template <
typename Derived>
125Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor> &
127 return coefficients.matrix;
130template <
typename Derived>
132 if (ket_to_ket_index.count(ket) == 0) {
135 return ket_to_ket_index.at(ket);
138template <
typename Derived>
141 int ket_index = get_ket_index_from_ket(ket);
143 throw std::invalid_argument(
"The ket does not belong to the basis.");
147 return coefficients.matrix.row(ket_index);
150template <
typename Derived>
151Eigen::SparseMatrix<typename Basis<Derived>::scalar_t, Eigen::RowMajor>
153 return other->coefficients.
matrix.adjoint() * coefficients.matrix;
156template <
typename Derived>
159 return get_amplitudes(ket).cwiseAbs2();
162template <
typename Derived>
163Eigen::SparseMatrix<typename Basis<Derived>::real_t, Eigen::RowMajor>
165 return get_amplitudes(other).cwiseAbs2();
168template <
typename Derived>
170 real_t quantum_number_f = state_index_to_quantum_number_f.at(state_index);
171 if (quantum_number_f == std::numeric_limits<real_t>::max()) {
172 throw std::invalid_argument(
"The state does not have a well-defined quantum number f.");
174 return quantum_number_f;
177template <
typename Derived>
179 real_t quantum_number_m = state_index_to_quantum_number_m.at(state_index);
180 if (quantum_number_m == std::numeric_limits<real_t>::max()) {
181 throw std::invalid_argument(
"The state does not have a well-defined quantum number m.");
183 return quantum_number_m;
186template <
typename Derived>
188 Parity parity = state_index_to_parity.at(state_index);
189 if (parity == Parity::UNKNOWN) {
190 throw std::invalid_argument(
"The state does not have a well-defined parity.");
195template <
typename Derived>
196std::shared_ptr<const typename Basis<Derived>::ket_t>
198 size_t ket_index = state_index_to_ket_index.at(state_index);
199 if (ket_index == std::numeric_limits<int>::max()) {
200 throw std::invalid_argument(
"The state does not belong to a ket in a well-defined way.");
202 return kets[ket_index];
205template <
typename Derived>
206std::shared_ptr<const typename Basis<Derived>::ket_t>
208 throw std::runtime_error(
"Not implemented yet.");
211template <
typename Derived>
214 auto restricted = std::make_shared<Derived>(derived());
217 restricted->coefficients.matrix = restricted->coefficients.matrix.col(state_index);
219 std::fill(restricted->ket_index_to_state_index.begin(),
220 restricted->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
221 restricted->ket_index_to_state_index[state_index_to_ket_index[state_index]] = 0;
223 restricted->state_index_to_quantum_number_f = {state_index_to_quantum_number_f[state_index]};
224 restricted->state_index_to_quantum_number_m = {state_index_to_quantum_number_m[state_index]};
225 restricted->state_index_to_parity = {state_index_to_parity[state_index]};
226 restricted->state_index_to_ket_index = {state_index_to_ket_index[state_index]};
228 restricted->_has_quantum_number_f =
229 restricted->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
230 restricted->_has_quantum_number_m =
231 restricted->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
232 restricted->_has_parity = restricted->state_index_to_parity[0] != Parity::UNKNOWN;
237template <
typename Derived>
238std::shared_ptr<const typename Basis<Derived>::ket_t>
240 return kets[ket_index];
243template <
typename Derived>
245 size_t state_index = ket_index_to_state_index.at(ket_index);
246 if (state_index == std::numeric_limits<int>::max()) {
247 throw std::runtime_error(
"The ket does not belong to a state in a well-defined way.");
249 return get_state(state_index);
252template <
typename Derived>
253std::shared_ptr<const Derived>
255 int ket_index = get_ket_index_from_ket(ket);
257 throw std::invalid_argument(
"The ket does not belong to the basis.");
259 return get_corresponding_state(ket_index);
262template <
typename Derived>
264 int state_index = ket_index_to_state_index.at(ket_index);
265 if (state_index == std::numeric_limits<int>::max()) {
266 throw std::runtime_error(
"The ket does not belong to a state in a well-defined way.");
271template <
typename Derived>
273 int ket_index = get_ket_index_from_ket(ket);
275 throw std::invalid_argument(
"The ket does not belong to the basis.");
277 return get_corresponding_state_index(ket_index);
280template <
typename Derived>
282 int ket_index = state_index_to_ket_index.at(state_index);
283 if (ket_index == std::numeric_limits<int>::max()) {
284 throw std::runtime_error(
"The state does not belong to a ket in a well-defined way.");
289template <
typename Derived>
291 throw std::runtime_error(
"Not implemented yet.");
294template <
typename Derived>
295std::shared_ptr<const Derived>
298 auto created = std::make_shared<Derived>(derived());
301 created->coefficients.matrix =
302 Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>(coefficients.matrix.rows(), 1);
303 created->coefficients.matrix.coeffRef(ket_index, 0) = 1;
304 created->coefficients.matrix.makeCompressed();
306 std::fill(created->ket_index_to_state_index.begin(), created->ket_index_to_state_index.end(),
307 std::numeric_limits<int>::max());
308 created->ket_index_to_state_index[ket_index] = 0;
310 created->state_index_to_quantum_number_f = {kets[ket_index]->get_quantum_number_f()};
311 created->state_index_to_quantum_number_m = {kets[ket_index]->get_quantum_number_m()};
312 created->state_index_to_parity = {kets[ket_index]->get_parity()};
313 created->state_index_to_ket_index = {ket_index};
315 created->_has_quantum_number_f =
316 created->state_index_to_quantum_number_f[0] != std::numeric_limits<real_t>::max();
317 created->_has_quantum_number_m =
318 created->state_index_to_quantum_number_m[0] != std::numeric_limits<real_t>::max();
319 created->_has_parity = created->state_index_to_parity[0] != Parity::UNKNOWN;
324template <
typename Derived>
325std::shared_ptr<const Derived>
327 int ket_index = get_ket_index_from_ket(ket);
329 throw std::invalid_argument(
"The ket does not belong to the basis.");
331 return get_canonical_state_from_ket(ket_index);
334template <
typename Derived>
339template <
typename Derived>
344template <
typename Derived>
347template <
typename Derived>
349 return other.it != it;
352template <
typename Derived>
357template <
typename Derived>
363template <
typename Derived>
365 return coefficients.
matrix.cols();
368template <
typename Derived>
370 return coefficients.
matrix.rows();
373template <
typename Derived>
379template <
typename Derived>
383 static_cast<Eigen::Index
>(coefficients.
matrix.rows())},
386 std::vector<Eigen::Triplet<scalar_t>> entries;
388 for (
size_t idx_initial = 0; idx_initial <
kets.size(); ++idx_initial) {
389 real_t f =
kets[idx_initial]->get_quantum_number_f();
390 real_t m_initial =
kets[idx_initial]->get_quantum_number_m();
392 assert(2 * f == std::floor(2 * f) && f >= 0);
393 assert(2 * m_initial == std::floor(2 * m_initial) && m_initial >= -f && m_initial <= f);
395 for (
real_t m_final = -f; m_final <= f;
397 auto val = wigner::wigner_uppercase_d_matrix<scalar_t>(f, m_initial, m_final, alpha,
400 kets[idx_initial]->get_ket_for_different_quantum_number_m(m_final));
401 entries.emplace_back(idx_final, idx_initial, val);
405 transformation.matrix.setFromTriplets(entries.begin(), entries.end());
406 transformation.matrix.makeCompressed();
408 return transformation;
411template <
typename Derived>
418 throw std::invalid_argument(
"States do not store the energy and thus can not be sorted by "
419 "the energy. Use an energy operator instead.");
424 transformation.
matrix.resize(coefficients.
matrix.cols());
425 transformation.
matrix.setIdentity();
432 throw std::invalid_argument(
"The states could not be sorted by all the requested labels.");
435 return transformation;
438template <
typename Derived>
439std::vector<IndicesOfBlock>
443 std::set<TransformationType> unique_labels(labels.begin(), labels.end());
450 return blocks_creator.create();
453template <
typename Derived>
455 Sorting &transformation)
const {
456 constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
458 int *perm_begin = transformation.
matrix.indices().data();
459 int *perm_end = perm_begin + coefficients.
matrix.cols();
460 const int *perm_back = perm_end - 1;
463 std::stable_sort(perm_begin, perm_end, [&](
int a,
int b) {
464 for (
const auto &label : labels) {
467 if (state_index_to_parity[a] != state_index_to_parity[b]) {
468 return state_index_to_parity[a] < state_index_to_parity[b];
472 if (std::abs(state_index_to_quantum_number_m[a] -
473 state_index_to_quantum_number_m[b]) > numerical_precision) {
474 return state_index_to_quantum_number_m[a] < state_index_to_quantum_number_m[b];
478 if (std::abs(state_index_to_quantum_number_f[a] -
479 state_index_to_quantum_number_f[b]) > numerical_precision) {
480 return state_index_to_quantum_number_f[a] < state_index_to_quantum_number_f[b];
484 if (state_index_to_ket_index[a] != state_index_to_ket_index[b]) {
485 return state_index_to_ket_index[a] < state_index_to_ket_index[b];
496 for (
const auto &label : labels) {
500 throw std::invalid_argument(
501 "States cannot be labeled and thus not sorted by the parity.");
506 if (state_index_to_quantum_number_m[*perm_back] == std::numeric_limits<real_t>::max()) {
507 throw std::invalid_argument(
508 "States cannot be labeled and thus not sorted by the quantum number m.");
514 if (state_index_to_quantum_number_f[*perm_back] == std::numeric_limits<real_t>::max()) {
515 throw std::invalid_argument(
516 "States cannot be labeled and thus not sorted by the quantum number f.");
522 if (state_index_to_ket_index[*perm_back] == std::numeric_limits<int>::max()) {
523 throw std::invalid_argument(
524 "States cannot be labeled and thus not sorted by kets.");
534template <
typename Derived>
536 const std::set<TransformationType> &unique_labels,
538 constexpr real_t numerical_precision = 100 * std::numeric_limits<real_t>::epsilon();
540 auto last_quantum_number_f = state_index_to_quantum_number_f[0];
541 auto last_quantum_number_m = state_index_to_quantum_number_m[0];
542 auto last_parity = state_index_to_parity[0];
543 auto last_ket = state_index_to_ket_index[0];
545 for (
int i = 0; i < coefficients.
matrix.cols(); ++i) {
546 for (
auto label : unique_labels) {
548 if (std::abs(state_index_to_quantum_number_f[i] - last_quantum_number_f) >
549 numerical_precision) {
550 blocks_creator.
add(i);
554 if (std::abs(state_index_to_quantum_number_m[i] - last_quantum_number_m) >
555 numerical_precision) {
556 blocks_creator.
add(i);
560 if (state_index_to_parity[i] != last_parity) {
561 blocks_creator.
add(i);
565 if (state_index_to_ket_index[i] != last_ket) {
566 blocks_creator.
add(i);
571 last_quantum_number_f = state_index_to_quantum_number_f[i];
572 last_quantum_number_m = state_index_to_quantum_number_m[i];
573 last_parity = state_index_to_parity[i];
574 last_ket = state_index_to_ket_index[i];
578template <
typename Derived>
581 auto transformed = std::make_shared<Derived>(derived());
583 if (coefficients.
matrix.cols() == 0) {
591 transformed->state_index_to_quantum_number_f.resize(transformation.
matrix.size());
592 transformed->state_index_to_quantum_number_m.resize(transformation.
matrix.size());
596 for (
int i = 0; i < transformation.
matrix.size(); ++i) {
598 state_index_to_quantum_number_f[transformation.
matrix.indices()[i]];
600 state_index_to_quantum_number_m[transformation.
matrix.indices()[i]];
602 state_index_to_parity[transformation.
matrix.indices()[i]];
604 state_index_to_ket_index[transformation.
matrix.indices()[i]];
606 [state_index_to_ket_index[transformation.
matrix.indices()[i]]] = i;
612template <
typename Derived>
613std::shared_ptr<const Derived>
617 real_t numerical_precision = 0.001;
620 bool is_rotation =
false;
628 throw std::invalid_argument(
"A rotation can not be combined with other transformations.");
635 throw std::runtime_error(
636 "If the object was transformed by a different transformation "
637 "than sorting, it can not be rotated.");
643 auto transformed = std::make_shared<Derived>(derived());
645 if (coefficients.
matrix.cols() == 0) {
655 Eigen::SparseMatrix<real_t> probs = transformation.
matrix.cwiseAbs2().transpose();
658 auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_f.data(),
659 state_index_to_quantum_number_f.size());
663 transformed->state_index_to_quantum_number_f.resize(probs.rows());
665 for (
size_t i = 0; i <
transformed->state_index_to_quantum_number_f.size(); ++i) {
666 if (diff[i] < numerical_precision) {
667 transformed->state_index_to_quantum_number_f[i] = std::round(val[i] * 2) / 2;
670 std::numeric_limits<real_t>::max();
677 auto map = Eigen::Map<const Eigen::VectorX<real_t>>(state_index_to_quantum_number_m.data(),
678 state_index_to_quantum_number_m.size());
682 transformed->state_index_to_quantum_number_m.resize(probs.rows());
684 for (
size_t i = 0; i <
transformed->state_index_to_quantum_number_m.size(); ++i) {
685 if (diff[i] < numerical_precision) {
686 transformed->state_index_to_quantum_number_m[i] = std::round(val[i] * 2) / 2;
689 std::numeric_limits<real_t>::max();
696 using utype = std::underlying_type<Parity>::type;
698 for (
size_t i = 0; i < state_index_to_parity.size(); ++i) {
699 map[i] =
static_cast<utype
>(state_index_to_parity[i]);
704 transformed->state_index_to_parity.resize(probs.rows());
706 for (
size_t i = 0; i <
transformed->state_index_to_parity.size(); ++i) {
707 if (diff[i] < numerical_precision) {
708 transformed->state_index_to_parity[i] =
static_cast<Parity>(std::lround(val[i]));
720 std::vector<real_t> max_in_row(
transformed->coefficients.matrix.rows(), 0);
721 std::vector<real_t> max_in_col(
transformed->coefficients.matrix.cols(), 0);
722 for (
int row = 0; row <
transformed->coefficients.matrix.outerSize(); ++row) {
723 for (
typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
726 real_t val = std::pow(std::abs(it.value()), 2);
727 max_in_row[row] = std::max(max_in_row[row], val);
728 max_in_col[it.col()] = std::max(max_in_col[it.col()], val);
733 std::vector<real_t> costs;
734 std::vector<std::pair<int, int>> mappings;
735 costs.reserve(
transformed->coefficients.matrix.nonZeros());
736 mappings.reserve(
transformed->coefficients.matrix.nonZeros());
737 for (
int row = 0; row <
transformed->coefficients.matrix.outerSize(); ++row) {
738 for (
typename Eigen::SparseMatrix<scalar_t, Eigen::RowMajor>::InnerIterator it(
741 real_t val = std::pow(std::abs(it.value()), 2);
742 real_t cost = max_in_row[row] + max_in_col[it.col()] - 2 * val;
743 costs.push_back(cost);
744 mappings.push_back({row, it.col()});
749 std::vector<size_t> order(costs.size());
750 std::iota(order.begin(), order.end(), 0);
751 std::sort(order.begin(), order.end(),
752 [&](
size_t a,
size_t b) { return costs[a] < costs[b]; });
755 std::fill(
transformed->ket_index_to_state_index.begin(),
756 transformed->ket_index_to_state_index.end(), std::numeric_limits<int>::max());
759 std::vector<bool> row_used(
transformed->coefficients.matrix.rows(),
false);
760 std::vector<bool> col_used(
transformed->coefficients.matrix.cols(),
false);
762 for (
size_t idx : order) {
763 int row = mappings[idx].first;
764 int col = mappings[idx].second;
765 if (!row_used[row] && !col_used[col]) {
766 row_used[row] =
true;
767 col_used[col] =
true;
772 if (num_used ==
transformed->coefficients.matrix.cols()) {
776 assert(num_used ==
transformed->coefficients.matrix.cols());
782template <
typename Derived>
784 return typename ket_t::hash()(*k);
787template <
typename Derived>
788bool Basis<Derived>::equal_to::operator()(
const std::shared_ptr<const ket_t> &lhs,
789 const std::shared_ptr<const ket_t> &rhs)
const {
794template class Basis<BasisAtom<double>>;
795template class Basis<BasisAtom<std::complex<double>>>;
796template class Basis<BasisPair<double>>;
797template class Basis<BasisPair<std::complex<double>>>;
std::shared_ptr< const ket_t > operator*() const
bool operator!=(const Iterator &other) const
const Transformation< scalar_t > & get_transformation() const override
void get_indices_of_blocks_without_checks(const std::set< TransformationType > &unique_labels, IndicesOfBlocksCreator &blocks) const
void get_sorter_without_checks(const std::vector< TransformationType > &labels, Sorting &transformation) const
int get_ket_index_from_ket(std::shared_ptr< const ket_t > ket) const
typename traits::CrtpTraits< Derived >::ketvec_t ketvec_t
size_t get_number_of_states() const
Transformation< scalar_t > get_rotator(real_t alpha, real_t beta, real_t gamma) const override
void perform_sorter_checks(const std::vector< TransformationType > &labels) const
void perform_blocks_checks(const std::set< TransformationType > &unique_labels) const
std::vector< IndicesOfBlock > get_indices_of_blocks(const std::vector< TransformationType > &labels) const override
typename traits::CrtpTraits< Derived >::real_t real_t
std::shared_ptr< const Derived > transformed(const Transformation< scalar_t > &transformation) const
Sorting get_sorter(const std::vector< TransformationType > &labels) const override
size_t get_number_of_kets() const
void add(size_t boundary)
Matrix< Type, Dynamic, 1 > VectorX
bool is_sorting(TransformationType label)
@ SORT_BY_QUANTUM_NUMBER_M
@ SORT_BY_QUANTUM_NUMBER_F
Eigen::PermutationMatrix< Eigen::Dynamic, Eigen::Dynamic > matrix
std::vector< TransformationType > transformation_type