pairinteraction
A Rydberg Interaction Calculator
Basis.py.cpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4#include "./Basis.py.hpp"
5
16
17#include <nanobind/eigen/sparse.h>
18#include <nanobind/nanobind.h>
19#include <nanobind/stl/complex.h>
20#include <nanobind/stl/set.h>
21#include <nanobind/stl/shared_ptr.h>
22#include <nanobind/stl/string.h>
23#include <nanobind/stl/vector.h>
24
25namespace nb = nanobind;
26using namespace pairinteraction;
27
// Registers the Python class "Basis" + type_name for the CRTP base Basis<T>, where T is the
// concrete derived basis type (bind_basis instantiates it with BasisAtom<...> and BasisPair<...>).
// The Python class is declared as a subclass of TransformationBuilderInterface<scalar_t>;
// nanobind requires that base to be registered before this call (presumably done elsewhere).
// Overloaded C++ members are disambiguated with nb::overload_cast (nb::const_ selects the const
// overload). Names bound more than once below (transformed, get_amplitudes, get_overlaps,
// get_matrix_elements, get_corresponding_*) form Python overload sets resolved by nanobind.
// NOTE(review): source lines 49, 66, 69, 74 and 79 are missing from this rendering, so several
// .def(...) calls below appear truncated; consult the full source before editing this block.
28template <typename T>
29static void declare_basis(nb::module_ &m, std::string const &type_name) {
30 std::string pyclass_name = "Basis" + type_name;
31 using scalar_t = typename Basis<T>::scalar_t;
32 nb::class_<Basis<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, pyclass_name.c_str());
33 pyclass.def("get_kets", &Basis<T>::get_kets)
34 .def("get_ket", &Basis<T>::get_ket)
35 .def("get_state", &Basis<T>::get_state)
36 .def("get_number_of_states", &Basis<T>::get_number_of_states)
37 .def("get_number_of_kets", &Basis<T>::get_number_of_kets)
38 .def("get_quantum_number_f", &Basis<T>::get_quantum_number_f)
39 .def("get_quantum_number_m", &Basis<T>::get_quantum_number_m)
40 .def("get_parity", &Basis<T>::get_parity)
41 .def("get_coefficients", nb::overload_cast<>(&Basis<T>::get_coefficients, nb::const_))
42 .def("set_coefficients", &Basis<T>::set_coefficients)
43 .def("get_transformation", &Basis<T>::get_transformation)
44 .def("get_rotator", &Basis<T>::get_rotator)
45 .def("get_sorter", &Basis<T>::get_sorter)
46 .def("get_indices_of_blocks", &Basis<T>::get_indices_of_blocks)
47 .def("get_sorter_without_checks", &Basis<T>::get_sorter_without_checks)
48 .def("get_indices_of_blocks_without_checks",
50 .def(
51 "transformed",
52 nb::overload_cast<const Transformation<scalar_t> &>(&Basis<T>::transformed, nb::const_))
53 .def("transformed", nb::overload_cast<const Sorting &>(&Basis<T>::transformed, nb::const_))
54 .def("get_amplitudes",
55 nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
56 &Basis<T>::get_amplitudes, nb::const_))
57 .def("get_amplitudes",
58 nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_amplitudes, nb::const_))
59 .def("get_overlaps",
60 nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
61 &Basis<T>::get_overlaps, nb::const_))
62 .def("get_overlaps",
63 nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_overlaps, nb::const_))
64 .def("get_matrix_elements",
65 nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>, OperatorType, int>(
67 .def("get_matrix_elements",
68 nb::overload_cast<std::shared_ptr<const T>, OperatorType, int>(
70 .def("get_corresponding_state",
71 nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state, nb::const_))
72 .def("get_corresponding_state",
73 nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
75 .def("get_corresponding_state_index",
76 nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state_index, nb::const_))
77 .def("get_corresponding_state_index",
78 nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
80 .def("get_corresponding_ket",
81 nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket, nb::const_))
82 .def("get_corresponding_ket",
83 nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket,
84 nb::const_))
85 .def("get_corresponding_ket_index",
86 nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket_index, nb::const_))
87 .def("get_corresponding_ket_index",
88 nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket_index,
89 nb::const_));
90}
91
92template <typename T>
93static void declare_basis_atom(nb::module_ &m, std::string const &type_name) {
94 std::string pyclass_name = "BasisAtom" + type_name;
95 nb::class_<BasisAtom<T>, Basis<BasisAtom<T>>> pyclass(m, pyclass_name.c_str());
96}
97
98template <typename T>
99static void declare_basis_atom_creator(nb::module_ &m, std::string const &type_name) {
100 std::string pyclass_name = "BasisAtomCreator" + type_name;
101 nb::class_<BasisAtomCreator<T>> pyclass(m, pyclass_name.c_str());
102 pyclass.def(nb::init<>())
103 .def("set_species", &BasisAtomCreator<T>::set_species)
104 .def("restrict_energy", &BasisAtomCreator<T>::restrict_energy)
105 .def("restrict_quantum_number_f", &BasisAtomCreator<T>::restrict_quantum_number_f)
106 .def("restrict_quantum_number_m", &BasisAtomCreator<T>::restrict_quantum_number_m)
107 .def("restrict_parity", &BasisAtomCreator<T>::restrict_parity)
108 .def("restrict_quantum_number_n", &BasisAtomCreator<T>::restrict_quantum_number_n)
109 .def("restrict_quantum_number_nu", &BasisAtomCreator<T>::restrict_quantum_number_nu)
110 .def("restrict_quantum_number_nui", &BasisAtomCreator<T>::restrict_quantum_number_nui)
111 .def("restrict_quantum_number_l", &BasisAtomCreator<T>::restrict_quantum_number_l)
112 .def("restrict_quantum_number_s", &BasisAtomCreator<T>::restrict_quantum_number_s)
113 .def("restrict_quantum_number_j", &BasisAtomCreator<T>::restrict_quantum_number_j)
114 .def("restrict_quantum_number_l_ryd", &BasisAtomCreator<T>::restrict_quantum_number_l_ryd)
115 .def("restrict_quantum_number_j_ryd", &BasisAtomCreator<T>::restrict_quantum_number_j_ryd)
116 .def("append_ket", &BasisAtomCreator<T>::append_ket)
117 .def("create", &BasisAtomCreator<T>::create);
118}
119
// Registers the Python class "BasisPair" + type_name for BasisPair<T>, declared as a subclass of
// Basis<BasisPair<T>> (registered by declare_basis, which bind_basis calls first).
// Adds pair-specific overloads: get_amplitudes and get_overlaps accept either two KetAtom
// pointers or two BasisAtom<T> pointers. get_matrix_elements accepts another BasisPair<T>, two
// BasisAtom<T>, or two KetAtom (each with two OperatorType values and two ints), or a pair ket
// (ket_t) whose remaining parameters sit on a source line missing from this rendering.
// NOTE(review): source lines 148 and 153 are missing here, so two .def(...) calls appear truncated.
// NOTE(review): confirm whether these same-named defs extend or shadow the get_amplitudes /
// get_overlaps / get_matrix_elements overloads bound on the base class by declare_basis.
120template <typename T>
121static void declare_basis_pair(nb::module_ &m, std::string const &type_name) {
122 std::string pyclass_name = "BasisPair" + type_name;
123 nb::class_<BasisPair<T>, Basis<BasisPair<T>>> pyclass(m, pyclass_name.c_str());
124 pyclass
125 .def("get_amplitudes",
126 nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>>(
127 &BasisPair<T>::get_amplitudes, nb::const_))
128 .def("get_amplitudes",
129 nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
130 std::shared_ptr<const BasisAtom<T>>>(&BasisPair<T>::get_amplitudes,
131 nb::const_))
132 .def("get_overlaps",
133 nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>>(
134 &BasisPair<T>::get_overlaps, nb::const_))
135 .def("get_overlaps",
136 nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
137 std::shared_ptr<const BasisAtom<T>>>(&BasisPair<T>::get_overlaps,
138 nb::const_))
139 .def("get_matrix_elements",
140 nb::overload_cast<std::shared_ptr<const BasisPair<T>>, OperatorType, OperatorType, int,
141 int>(&BasisPair<T>::get_matrix_elements, nb::const_))
142 .def("get_matrix_elements",
143 nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
144 std::shared_ptr<const BasisAtom<T>>, OperatorType, OperatorType, int,
145 int>(&BasisPair<T>::get_matrix_elements, nb::const_))
146 .def("get_matrix_elements",
147 nb::overload_cast<std::shared_ptr<const typename BasisPair<T>::ket_t>, OperatorType,
149 nb::const_))
150 .def("get_matrix_elements",
151 nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>,
152 OperatorType, OperatorType, int, int>(
154}
155
156template <typename T>
157static void declare_basis_pair_creator(nb::module_ &m, std::string const &type_name) {
158 std::string pyclass_name = "BasisPairCreator" + type_name;
159 nb::class_<BasisPairCreator<T>> pyclass(m, pyclass_name.c_str());
160 pyclass.def(nb::init<>())
161 .def("add", &BasisPairCreator<T>::add)
162 .def("restrict_energy", &BasisPairCreator<T>::restrict_energy)
163 .def("restrict_quantum_number_m", &BasisPairCreator<T>::restrict_quantum_number_m)
164 .def("restrict_product_of_parities", &BasisPairCreator<T>::restrict_product_of_parities)
165 .def("create", &BasisPairCreator<T>::create);
166}
167
168void bind_basis(nb::module_ &m) {
169 declare_basis<BasisAtom<double>>(m, "BasisAtomReal");
170 declare_basis<BasisAtom<std::complex<double>>>(m, "BasisAtomComplex");
171 declare_basis_atom<double>(m, "Real");
172 declare_basis_atom<std::complex<double>>(m, "Complex");
173 declare_basis_atom_creator<double>(m, "Real");
174 declare_basis_atom_creator<std::complex<double>>(m, "Complex");
175
176 declare_basis<BasisPair<double>>(m, "BasisPairReal");
177 declare_basis<BasisPair<std::complex<double>>>(m, "BasisPairComplex");
178 declare_basis_pair<double>(m, "Real");
179 declare_basis_pair<std::complex<double>>(m, "Complex");
180 declare_basis_pair_creator<double>(m, "Real");
181 declare_basis_pair_creator<std::complex<double>>(m, "Complex");
182}
void bind_basis(nb::module_ &m)
Definition: Basis.py.cpp:168
Builder class for creating BasisAtom objects.
Class for creating a basis of atomic kets.
Definition: BasisAtom.hpp:40
typename traits::CrtpTraits< Type >::ket_t ket_t
Definition: BasisPair.hpp:54
Base class for a basis.
Definition: Basis.hpp:41
typename traits::CrtpTraits< Derived >::scalar_t scalar_t
Definition: Basis.hpp:43
typename traits::CrtpTraits< Derived >::ket_t ket_t
Definition: Basis.hpp:45