Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "./TransformationBuilderInterface.py.hpp" 5 : 6 : #include "pairinteraction/interfaces/TransformationBuilderInterface.hpp" 7 : 8 : #include <nanobind/eigen/dense.h> 9 : #include <nanobind/eigen/sparse.h> 10 : #include <nanobind/nanobind.h> 11 : #include <nanobind/stl/array.h> 12 : #include <nanobind/stl/complex.h> 13 : #include <nanobind/stl/vector.h> 14 : 15 : namespace nb = nanobind; 16 : using namespace nb::literals; 17 : using namespace pairinteraction; 18 : 19 : template <typename T> 20 4 : static void declare_transformation(nb::module_ &m, std::string const &type_name) { 21 4 : std::string pyclass_name = "Transformation" + type_name; 22 8 : nb::class_<Transformation<T>> pyclass(m, pyclass_name.c_str()); 23 4 : pyclass.def(nb::init<>()) 24 4 : .def_rw("matrix", &Transformation<T>::matrix) 25 : .def_rw("transformation_type", &Transformation<T>::transformation_type); 26 4 : } 27 : 28 2 : static void declare_sorting(nb::module_ &m) { 29 2 : nb::class_<Sorting> pyclass(m, "Sorting"); 30 2 : pyclass.def(nb::init<>()).def_rw("transformation_type", &Sorting::transformation_type); 31 2 : } 32 : 33 2 : static void declare_indices_of_blocks(nb::module_ &m) { 34 2 : nb::class_<IndicesOfBlock> pyclass(m, "IndicesOfBlock"); 35 4 : pyclass.def(nb::init<size_t, size_t>(), "start"_a, "end"_a) 36 2 : .def_rw("start", &IndicesOfBlock::start) 37 : .def_rw("end", &IndicesOfBlock::end); 38 2 : } 39 : 40 2 : static void declare_indices_of_blocks_creator(nb::module_ &m) { 41 2 : nb::class_<IndicesOfBlocksCreator> pyclass(m, "IndicesOfBlocksCreator"); 42 2 : } 43 : 44 : template <typename T> 45 4 : static void declare_transformation_builder_interface(nb::module_ &m, std::string const &type_name) { 46 4 : std::string pyclass_name = "TransformationBuilderInterface" + type_name; 47 : using real_t = typename TransformationBuilderInterface<T>::real_t; 48 4 : nb::class_<TransformationBuilderInterface<T>> pyclass(m, pyclass_name.c_str()); 49 : pyclass.def("get_rotator", 50 4 : nb::overload_cast<const std::array<real_t, 3> &, const std::array<real_t, 3> &>( 51 : &TransformationBuilderInterface<T>::get_rotator, nb::const_)); 52 4 : } 53 : 54 2 : void bind_transformation_builder_interface(nb::module_ &m) { 55 2 : declare_transformation<double>(m, "Real"); 56 2 : declare_transformation<std::complex<double>>(m, "Complex"); 57 2 : declare_sorting(m); 58 2 : declare_indices_of_blocks(m); 59 2 : declare_indices_of_blocks_creator(m); 60 2 : declare_transformation_builder_interface<double>(m, "Real"); 61 2 : declare_transformation_builder_interface<std::complex<double>>(m, "Complex"); 62 2 : }