Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./Basis.py.hpp"
5 :
6 : #include "pairinteraction/basis/Basis.hpp"
7 : #include "pairinteraction/basis/BasisAtom.hpp"
8 : #include "pairinteraction/basis/BasisAtomCreator.hpp"
9 : #include "pairinteraction/basis/BasisPair.hpp"
10 : #include "pairinteraction/basis/BasisPairCreator.hpp"
11 : #include "pairinteraction/database/Database.hpp"
12 : #include "pairinteraction/interfaces/TransformationBuilderInterface.hpp"
13 : #include "pairinteraction/ket/KetAtom.hpp"
14 : #include "pairinteraction/ket/KetPair.hpp"
15 : #include "pairinteraction/system/SystemAtom.hpp"
16 :
17 : #include <nanobind/eigen/sparse.h>
18 : #include <nanobind/nanobind.h>
19 : #include <nanobind/stl/complex.h>
20 : #include <nanobind/stl/set.h>
21 : #include <nanobind/stl/shared_ptr.h>
22 : #include <nanobind/stl/string.h>
23 : #include <nanobind/stl/vector.h>
24 :
25 : namespace nb = nanobind;
26 : using namespace pairinteraction;
27 :
28 : template <typename T>
29 8 : static void declare_basis(nb::module_ &m, std::string const &type_name) {
30 8 : std::string pyclass_name = "Basis" + type_name;
31 : using scalar_t = typename Basis<T>::scalar_t;
32 8 : nb::class_<Basis<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, pyclass_name.c_str());
33 8 : pyclass.def("get_kets", &Basis<T>::get_kets)
34 8 : .def("get_ket", &Basis<T>::get_ket)
35 8 : .def("get_state", &Basis<T>::get_state)
36 8 : .def("get_number_of_states", &Basis<T>::get_number_of_states)
37 8 : .def("get_number_of_kets", &Basis<T>::get_number_of_kets)
38 8 : .def("get_quantum_number_f", &Basis<T>::get_quantum_number_f)
39 8 : .def("get_quantum_number_m", &Basis<T>::get_quantum_number_m)
40 8 : .def("get_parity", &Basis<T>::get_parity)
41 8 : .def("get_coefficients", nb::overload_cast<>(&Basis<T>::get_coefficients, nb::const_))
42 8 : .def("set_coefficients", &Basis<T>::set_coefficients)
43 8 : .def("get_transformation", &Basis<T>::get_transformation)
44 8 : .def("get_rotator", &Basis<T>::get_rotator)
45 8 : .def("get_sorter", &Basis<T>::get_sorter)
46 8 : .def("get_indices_of_blocks", &Basis<T>::get_indices_of_blocks)
47 8 : .def("get_sorter_without_checks", &Basis<T>::get_sorter_without_checks)
48 8 : .def("get_indices_of_blocks_without_checks",
49 8 : &Basis<T>::get_indices_of_blocks_without_checks)
50 8 : .def(
51 : "transformed",
52 8 : nb::overload_cast<const Transformation<scalar_t> &>(&Basis<T>::transformed, nb::const_))
53 8 : .def("transformed", nb::overload_cast<const Sorting &>(&Basis<T>::transformed, nb::const_))
54 8 : .def("get_corresponding_state",
55 8 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state, nb::const_))
56 8 : .def("get_corresponding_state",
57 8 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
58 : &Basis<T>::get_corresponding_state, nb::const_))
59 8 : .def("get_corresponding_state_index",
60 8 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state_index, nb::const_))
61 8 : .def("get_corresponding_state_index",
62 8 : nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
63 : &Basis<T>::get_corresponding_state_index, nb::const_))
64 8 : .def("get_corresponding_ket",
65 8 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket, nb::const_))
66 8 : .def("get_corresponding_ket",
67 8 : nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket,
68 : nb::const_))
69 8 : .def("get_corresponding_ket_index",
70 8 : nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket_index, nb::const_))
71 8 : .def("get_corresponding_ket_index",
72 8 : nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket_index,
73 : nb::const_))
74 16 : .def("canonicalized", &Basis<T>::canonicalized)
75 11 : .def("copy", [](const Basis<T> &self) {
76 3 : return std::make_shared<T>(static_cast<const T &>(self));
77 : });
78 8 : }
79 :
80 : template <typename T>
81 4 : static void declare_basis_atom(nb::module_ &m, std::string const &type_name) {
82 4 : std::string pyclass_name = "BasisAtom" + type_name;
83 4 : nb::class_<BasisAtom<T>, Basis<BasisAtom<T>>> pyclass(m, pyclass_name.c_str());
84 4 : pyclass.def("get_matrix_elements", &BasisAtom<T>::get_matrix_elements);
85 4 : }
86 :
87 : template <typename T>
88 4 : static void declare_basis_atom_creator(nb::module_ &m, std::string const &type_name) {
89 4 : std::string pyclass_name = "BasisAtomCreator" + type_name;
90 8 : nb::class_<BasisAtomCreator<T>> pyclass(m, pyclass_name.c_str());
91 12 : pyclass.def(nb::init<>())
92 4 : .def("set_species", &BasisAtomCreator<T>::set_species)
93 4 : .def("restrict_energy", &BasisAtomCreator<T>::restrict_energy)
94 4 : .def("restrict_quantum_number", &BasisAtomCreator<T>::restrict_quantum_number)
95 4 : .def("set_quantum_number_standard_deviation_factor",
96 4 : &BasisAtomCreator<T>::set_quantum_number_standard_deviation_factor)
97 4 : .def("add_ket", &BasisAtomCreator<T>::add_ket)
98 8 : .def("create", &BasisAtomCreator<T>::create, nb::call_guard<nb::gil_scoped_release>());
99 4 : }
100 :
101 : template <typename T>
102 4 : static void declare_basis_pair(nb::module_ &m, std::string const &type_name) {
103 4 : std::string pyclass_name = "BasisPair" + type_name;
104 4 : nb::class_<BasisPair<T>, Basis<BasisPair<T>>> pyclass(m, pyclass_name.c_str());
105 4 : pyclass.def("get_matrix_elements", &BasisPair<T>::get_matrix_elements)
106 4 : .def("get_basis1", &BasisPair<T>::get_basis1)
107 4 : .def("get_basis2", &BasisPair<T>::get_basis2);
108 4 : }
109 :
110 : template <typename T>
111 4 : static void declare_basis_pair_creator(nb::module_ &m, std::string const &type_name) {
112 4 : std::string pyclass_name = "BasisPairCreator" + type_name;
113 8 : nb::class_<BasisPairCreator<T>> pyclass(m, pyclass_name.c_str());
114 : pyclass
115 12 : .def(nb::init<>())
116 : // keep_alive because add() stores only a reference to the system, which must outlive the
117 : // creator (the system is dereferenced in create())
118 4 : .def("add", &BasisPairCreator<T>::add, nb::keep_alive<1, 2>())
119 4 : .def("restrict_energy", &BasisPairCreator<T>::restrict_energy)
120 4 : .def("restrict_quantum_number_m", &BasisPairCreator<T>::restrict_quantum_number_m)
121 4 : .def("restrict_parity_under_inversion",
122 4 : &BasisPairCreator<T>::restrict_parity_under_inversion)
123 4 : .def("restrict_parity_under_permutation",
124 4 : &BasisPairCreator<T>::restrict_parity_under_permutation)
125 8 : .def("create", &BasisPairCreator<T>::create, nb::call_guard<nb::gil_scoped_release>());
126 4 : }
127 :
128 2 : void bind_basis(nb::module_ &m) {
129 2 : declare_basis<BasisAtom<double>>(m, "BasisAtomReal");
130 2 : declare_basis<BasisAtom<std::complex<double>>>(m, "BasisAtomComplex");
131 2 : declare_basis_atom<double>(m, "Real");
132 2 : declare_basis_atom<std::complex<double>>(m, "Complex");
133 2 : declare_basis_atom_creator<double>(m, "Real");
134 2 : declare_basis_atom_creator<std::complex<double>>(m, "Complex");
135 :
136 2 : declare_basis<BasisPair<double>>(m, "BasisPairReal");
137 2 : declare_basis<BasisPair<std::complex<double>>>(m, "BasisPairComplex");
138 2 : declare_basis_pair<double>(m, "Real");
139 2 : declare_basis_pair<std::complex<double>>(m, "Complex");
140 2 : declare_basis_pair_creator<double>(m, "Real");
141 2 : declare_basis_pair_creator<std::complex<double>>(m, "Complex");
142 2 : }
|