Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "./Operator.py.hpp" 5 : 6 : #include "pairinteraction/basis/BasisAtom.hpp" 7 : #include "pairinteraction/basis/BasisPair.hpp" 8 : #include "pairinteraction/operator/Operator.hpp" 9 : #include "pairinteraction/operator/OperatorAtom.hpp" 10 : #include "pairinteraction/operator/OperatorPair.hpp" 11 : 12 : #include <nanobind/eigen/sparse.h> 13 : #include <nanobind/nanobind.h> 14 : #include <nanobind/operators.h> 15 : #include <nanobind/stl/complex.h> 16 : #include <nanobind/stl/shared_ptr.h> 17 : #include <nanobind/stl/vector.h> 18 : 19 : namespace nb = nanobind; 20 : using namespace pairinteraction; 21 : 22 : template <typename T> 23 4 : static void declare_operator(nb::module_ &m, std::string const &type_name) { 24 4 : std::string pyclass_name = "Operator" + type_name; 25 : using basis_t = typename Operator<T>::basis_t; 26 : using scalar_t = typename Operator<T>::scalar_t; 27 8 : nb::class_<Operator<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, 28 : pyclass_name.c_str()); 29 0 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>()) 30 4 : .def("get_basis", nb::overload_cast<>(&Operator<T>::get_basis, nb::const_)) 31 4 : .def("get_matrix", nb::overload_cast<>(&Operator<T>::get_matrix, nb::const_)) 32 4 : .def("get_transformation", &Operator<T>::get_transformation) 33 4 : .def("get_rotator", &Operator<T>::get_rotator) 34 4 : .def("get_sorter", &Operator<T>::get_sorter) 35 4 : .def("get_indices_of_blocks", &Operator<T>::get_indices_of_blocks) 36 4 : .def("transformed", 37 4 : nb::overload_cast<const Transformation<scalar_t> &>(&Operator<T>::transformed, 38 : nb::const_)) 39 4 : .def("transformed", 40 4 : nb::overload_cast<const Sorting &>(&Operator<T>::transformed, nb::const_)) 41 4 : .def(scalar_t() * nb::self) 42 4 : .def(nb::self * scalar_t()) 43 4 : .def(nb::self / scalar_t()) 44 4 : .def(nb::self + nb::self) 45 4 : .def(nb::self - nb::self); // NOLINT 46 4 : } 47 : 48 : template <typename T> 49 2 : static void declare_operator_atom(nb::module_ &m, std::string const &type_name) { 50 2 : std::string pyclass_name = "OperatorAtom" + type_name; 51 : using basis_t = typename OperatorAtom<T>::basis_t; 52 4 : nb::class_<OperatorAtom<T>, Operator<OperatorAtom<T>>> pyclass(m, pyclass_name.c_str()); 53 2 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>()) 54 2 : .def(nb::init<std::shared_ptr<const basis_t>, OperatorType, int>()); 55 2 : } 56 : 57 : template <typename T> 58 2 : static void declare_operator_pair(nb::module_ &m, std::string const &type_name) { 59 2 : std::string pyclass_name = "OperatorPair" + type_name; 60 : using basis_t = typename OperatorPair<T>::basis_t; 61 4 : nb::class_<OperatorPair<T>, Operator<OperatorPair<T>>> pyclass(m, pyclass_name.c_str()); 62 2 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>()) 63 2 : .def(nb::init<std::shared_ptr<const basis_t>, OperatorType>()); 64 2 : } 65 : 66 1 : void bind_operator(nb::module_ &m) { 67 1 : declare_operator<OperatorAtom<double>>(m, "OperatorAtomReal"); 68 1 : declare_operator<OperatorAtom<std::complex<double>>>(m, "OperatorAtomComplex"); 69 1 : declare_operator_atom<double>(m, "Real"); 70 1 : declare_operator_atom<std::complex<double>>(m, "Complex"); 71 : 72 1 : declare_operator<OperatorPair<double>>(m, "OperatorPairReal"); 73 1 : declare_operator<OperatorPair<std::complex<double>>>(m, "OperatorPairComplex"); 74 1 : declare_operator_pair<double>(m, "Real"); 75 1 : declare_operator_pair<std::complex<double>>(m, "Complex"); 76 1 : }