Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "./TransformationBuilderInterface.py.hpp" 5 : 6 : #include "pairinteraction/interfaces/TransformationBuilderInterface.hpp" 7 : 8 : #include <nanobind/eigen/dense.h> 9 : #include <nanobind/eigen/sparse.h> 10 : #include <nanobind/nanobind.h> 11 : #include <nanobind/stl/array.h> 12 : #include <nanobind/stl/complex.h> 13 : #include <nanobind/stl/vector.h> 14 : 15 : namespace nb = nanobind; 16 : using namespace nb::literals; 17 : using namespace pairinteraction; 18 : 19 : template <typename T> 20 4 : static void declare_transformation(nb::module_ &m, std::string const &type_name) { 21 4 : std::string pyclass_name = "Transformation" + type_name; 22 8 : nb::class_<Transformation<T>> pyclass(m, pyclass_name.c_str()); 23 4 : pyclass.def(nb::init<>()) 24 4 : .def_rw("matrix", &Transformation<T>::matrix) 25 : .def_rw("transformation_type", &Transformation<T>::transformation_type); 26 4 : } 27 : 28 2 : static void declare_sorting(nb::module_ &m) { 29 2 : nb::class_<Sorting> pyclass(m, "Sorting"); 30 2 : pyclass.def(nb::init<>()) 31 2 : .def_rw("matrix", &Sorting::matrix) 32 : .def_rw("transformation_type", &Sorting::transformation_type); 33 2 : } 34 : 35 2 : static void declare_indices_of_blocks(nb::module_ &m) { 36 2 : nb::class_<IndicesOfBlock> pyclass(m, "IndicesOfBlock"); 37 4 : pyclass.def(nb::init<size_t, size_t>(), "start"_a, "end"_a) 38 2 : .def_rw("start", &IndicesOfBlock::start) 39 : .def_rw("end", &IndicesOfBlock::end); 40 2 : } 41 : 42 2 : static void declare_indices_of_blocks_creator(nb::module_ &m) { 43 2 : nb::class_<IndicesOfBlocksCreator> pyclass(m, "IndicesOfBlocksCreator"); 44 2 : } 45 : 46 : template <typename T> 47 4 : static void declare_transformation_builder_interface(nb::module_ &m, std::string const &type_name) { 48 4 : std::string pyclass_name = "TransformationBuilderInterface" + type_name; 49 : using real_t = typename TransformationBuilderInterface<T>::real_t; 50 4 : nb::class_<TransformationBuilderInterface<T>> pyclass(m, pyclass_name.c_str()); 51 : pyclass.def("get_rotator", 52 4 : nb::overload_cast<const std::array<real_t, 3> &, const std::array<real_t, 3> &>( 53 : &TransformationBuilderInterface<T>::get_rotator, nb::const_)); 54 4 : } 55 : 56 2 : void bind_transformation_builder_interface(nb::module_ &m) { 57 2 : declare_transformation<double>(m, "Real"); 58 2 : declare_transformation<std::complex<double>>(m, "Complex"); 59 2 : declare_sorting(m); 60 2 : declare_indices_of_blocks(m); 61 2 : declare_indices_of_blocks_creator(m); 62 2 : declare_transformation_builder_interface<double>(m, "Real"); 63 2 : declare_transformation_builder_interface<std::complex<double>>(m, "Complex"); 64 2 : }