Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./Basis.py.hpp"
5 :
6 : #include "pairinteraction/basis/Basis.hpp"
7 : #include "pairinteraction/basis/BasisAtom.hpp"
8 : #include "pairinteraction/basis/BasisAtomCreator.hpp"
9 : #include "pairinteraction/basis/BasisPair.hpp"
10 : #include "pairinteraction/basis/BasisPairCreator.hpp"
11 : #include "pairinteraction/database/Database.hpp"
12 : #include "pairinteraction/interfaces/TransformationBuilderInterface.hpp"
13 : #include "pairinteraction/ket/KetAtom.hpp"
14 : #include "pairinteraction/ket/KetPair.hpp"
15 : #include "pairinteraction/system/SystemAtom.hpp"
16 :
17 : #include <nanobind/eigen/sparse.h>
18 : #include <nanobind/nanobind.h>
19 : #include <nanobind/stl/complex.h>
20 : #include <nanobind/stl/set.h>
21 : #include <nanobind/stl/shared_ptr.h>
22 : #include <nanobind/stl/string.h>
23 : #include <nanobind/stl/vector.h>
24 :
25 : namespace nb = nanobind;
26 : using namespace pairinteraction;
27 :
28 : template <typename T>
29 4 : static void declare_basis(nb::module_ &m, std::string const &type_name) {
30 4 : std::string pyclass_name = "Basis" + type_name;
31 : using scalar_t = typename Basis<T>::scalar_t;
32 4 : nb::class_<Basis<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, pyclass_name.c_str());
33 4 : pyclass.def("get_kets", &Basis<T>::get_kets)
34 4 : .def("get_ket", &Basis<T>::get_ket)
35 4 : .def("get_state", &Basis<T>::get_state)
36 4 : .def("get_number_of_states", &Basis<T>::get_number_of_states)
37 4 : .def("get_number_of_kets", &Basis<T>::get_number_of_kets)
38 4 : .def("get_quantum_number_f", &Basis<T>::get_quantum_number_f)
39 4 : .def("get_quantum_number_m", &Basis<T>::get_quantum_number_m)
40 4 : .def("get_parity", &Basis<T>::get_parity)
41 4 : .def("get_coefficients", nb::overload_cast<>(&Basis<T>::get_coefficients, nb::const_))
42 4 : .def("get_transformation", &Basis<T>::get_transformation)
43 4 : .def("get_rotator", &Basis<T>::get_rotator)
44 4 : .def("get_sorter", &Basis<T>::get_sorter)
45 4 : .def("get_indices_of_blocks", &Basis<T>::get_indices_of_blocks)
46 4 : .def("get_sorter_without_checks", &Basis<T>::get_sorter_without_checks)
47 4 : .def("get_indices_of_blocks_without_checks",
48 4 : &Basis<T>::get_indices_of_blocks_without_checks)
49 4 : .def(
50 : "transformed",
51 4 : nb::overload_cast<const Transformation<scalar_t> &>(&Basis<T>::transformed, nb::const_))
52 4 : .def("transformed", nb::overload_cast<const Sorting &>(&Basis<T>::transformed, nb::const_))
53 4 : .def("get_amplitudes",
54 4 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
55 : &Basis<T>::get_amplitudes, nb::const_))
56 4 : .def("get_amplitudes",
57 4 : nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_amplitudes, nb::const_))
58 4 : .def("get_overlaps",
59 4 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
60 : &Basis<T>::get_overlaps, nb::const_))
61 4 : .def("get_overlaps",
62 4 : nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_overlaps, nb::const_))
63 4 : .def("get_matrix_elements",
64 4 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>, OperatorType, int>(
65 : &Basis<T>::get_matrix_elements, nb::const_))
66 4 : .def("get_matrix_elements",
67 4 : nb::overload_cast<std::shared_ptr<const T>, OperatorType, int>(
68 : &Basis<T>::get_matrix_elements, nb::const_))
69 4 : .def("get_corresponding_state",
70 4 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state, nb::const_))
71 4 : .def("get_corresponding_state",
72 4 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
73 : &Basis<T>::get_corresponding_state, nb::const_))
74 4 : .def("get_corresponding_state_index",
75 4 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state_index, nb::const_))
76 4 : .def("get_corresponding_state_index",
77 4 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
78 : &Basis<T>::get_corresponding_state_index, nb::const_))
79 4 : .def("get_corresponding_ket",
80 4 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket, nb::const_))
81 4 : .def("get_corresponding_ket",
82 4 : nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket,
83 : nb::const_))
84 4 : .def("get_corresponding_ket_index",
85 4 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket_index, nb::const_))
86 : .def("get_corresponding_ket_index",
87 4 : nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket_index,
88 : nb::const_));
89 4 : }
90 :
91 : template <typename T>
92 2 : static void declare_basis_atom(nb::module_ &m, std::string const &type_name) {
93 2 : std::string pyclass_name = "BasisAtom" + type_name;
94 2 : nb::class_<BasisAtom<T>, Basis<BasisAtom<T>>> pyclass(m, pyclass_name.c_str());
95 2 : }
96 :
97 : template <typename T>
98 2 : static void declare_basis_atom_creator(nb::module_ &m, std::string const &type_name) {
99 2 : std::string pyclass_name = "BasisAtomCreator" + type_name;
100 4 : nb::class_<BasisAtomCreator<T>> pyclass(m, pyclass_name.c_str());
101 4 : pyclass.def(nb::init<>())
102 2 : .def("set_species", &BasisAtomCreator<T>::set_species)
103 2 : .def("restrict_energy", &BasisAtomCreator<T>::restrict_energy)
104 2 : .def("restrict_quantum_number_f", &BasisAtomCreator<T>::restrict_quantum_number_f)
105 2 : .def("restrict_quantum_number_m", &BasisAtomCreator<T>::restrict_quantum_number_m)
106 2 : .def("restrict_parity", &BasisAtomCreator<T>::restrict_parity)
107 2 : .def("restrict_quantum_number_n", &BasisAtomCreator<T>::restrict_quantum_number_n)
108 2 : .def("restrict_quantum_number_nu", &BasisAtomCreator<T>::restrict_quantum_number_nu)
109 2 : .def("restrict_quantum_number_l", &BasisAtomCreator<T>::restrict_quantum_number_l)
110 2 : .def("restrict_quantum_number_s", &BasisAtomCreator<T>::restrict_quantum_number_s)
111 2 : .def("restrict_quantum_number_j", &BasisAtomCreator<T>::restrict_quantum_number_j)
112 2 : .def("append_ket", &BasisAtomCreator<T>::append_ket)
113 2 : .def("create", &BasisAtomCreator<T>::create);
114 2 : }
115 :
116 : template <typename T>
117 2 : static void declare_basis_pair(nb::module_ &m, std::string const &type_name) {
118 2 : std::string pyclass_name = "BasisPair" + type_name;
119 2 : nb::class_<BasisPair<T>, Basis<BasisPair<T>>> pyclass(m, pyclass_name.c_str());
120 : pyclass
121 2 : .def("get_amplitudes",
122 2 : nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>>(
123 : &BasisPair<T>::get_amplitudes, nb::const_))
124 2 : .def("get_amplitudes",
125 : nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
126 2 : std::shared_ptr<const BasisAtom<T>>>(&BasisPair<T>::get_amplitudes,
127 : nb::const_))
128 2 : .def("get_overlaps",
129 2 : nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>>(
130 : &BasisPair<T>::get_overlaps, nb::const_))
131 2 : .def("get_overlaps",
132 : nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
133 2 : std::shared_ptr<const BasisAtom<T>>>(&BasisPair<T>::get_overlaps,
134 : nb::const_))
135 2 : .def("get_matrix_elements",
136 : nb::overload_cast<std::shared_ptr<const BasisPair<T>>, OperatorType, OperatorType, int,
137 2 : int>(&BasisPair<T>::get_matrix_elements, nb::const_))
138 2 : .def("get_matrix_elements",
139 : nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
140 : std::shared_ptr<const BasisAtom<T>>, OperatorType, OperatorType, int,
141 2 : int>(&BasisPair<T>::get_matrix_elements, nb::const_))
142 2 : .def("get_matrix_elements",
143 : nb::overload_cast<std::shared_ptr<const typename BasisPair<T>::ket_t>, OperatorType,
144 2 : OperatorType, int, int>(&BasisPair<T>::get_matrix_elements,
145 : nb::const_))
146 : .def("get_matrix_elements",
147 : nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>,
148 2 : OperatorType, OperatorType, int, int>(
149 : &BasisPair<T>::get_matrix_elements, nb::const_));
150 2 : }
151 :
152 : template <typename T>
153 2 : static void declare_basis_pair_creator(nb::module_ &m, std::string const &type_name) {
154 2 : std::string pyclass_name = "BasisPairCreator" + type_name;
155 4 : nb::class_<BasisPairCreator<T>> pyclass(m, pyclass_name.c_str());
156 4 : pyclass.def(nb::init<>())
157 2 : .def("add", &BasisPairCreator<T>::add)
158 2 : .def("restrict_energy", &BasisPairCreator<T>::restrict_energy)
159 2 : .def("restrict_quantum_number_m", &BasisPairCreator<T>::restrict_quantum_number_m)
160 2 : .def("restrict_product_of_parities", &BasisPairCreator<T>::restrict_product_of_parities)
161 2 : .def("create", &BasisPairCreator<T>::create);
162 2 : }
163 :
164 1 : void bind_basis(nb::module_ &m) {
165 1 : declare_basis<BasisAtom<double>>(m, "BasisAtomReal");
166 1 : declare_basis<BasisAtom<std::complex<double>>>(m, "BasisAtomComplex");
167 1 : declare_basis_atom<double>(m, "Real");
168 1 : declare_basis_atom<std::complex<double>>(m, "Complex");
169 1 : declare_basis_atom_creator<double>(m, "Real");
170 1 : declare_basis_atom_creator<std::complex<double>>(m, "Complex");
171 :
172 1 : declare_basis<BasisPair<double>>(m, "BasisPairReal");
173 1 : declare_basis<BasisPair<std::complex<double>>>(m, "BasisPairComplex");
174 1 : declare_basis_pair<double>(m, "Real");
175 1 : declare_basis_pair<std::complex<double>>(m, "Complex");
176 1 : declare_basis_pair_creator<double>(m, "Real");
177 1 : declare_basis_pair_creator<std::complex<double>>(m, "Complex");
178 1 : }
|