Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "./DiagonalizerInterface.py.hpp" 5 : 6 : #include "pairinteraction/interfaces/DiagonalizerInterface.hpp" 7 : 8 : #include <nanobind/eigen/sparse.h> 9 : #include <nanobind/nanobind.h> 10 : #include <nanobind/stl/complex.h> 11 : #include <nanobind/stl/optional.h> 12 : 13 : namespace nb = nanobind; 14 : using namespace pairinteraction; 15 : 16 : template <typename T> 17 4 : static void declare_diagonalizer_interface(nb::module_ &m, std::string const &type_name) { 18 4 : std::string pylass_name = "DiagonalizerInterface" + type_name; 19 : using real_t = typename DiagonalizerInterface<T>::real_t; 20 4 : nb::class_<DiagonalizerInterface<T>> pyclass(m, pylass_name.c_str()); 21 : pyclass 22 4 : .def("eigh", 23 4 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>( 24 : &DiagonalizerInterface<T>::eigh, nb::const_), 25 4 : nb::call_guard<nb::gil_scoped_release>()) 26 : .def("eigh", 27 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, 28 4 : std::optional<real_t>, std::optional<real_t>, double>( 29 : &DiagonalizerInterface<T>::eigh, nb::const_), 30 4 : nb::call_guard<nb::gil_scoped_release>()); 31 4 : } 32 : 33 : template <typename T> 34 4 : static void declare_eigen_system_h(nb::module_ &m, std::string const &type_name) { 35 4 : std::string pylass_name = "EigenSystemH" + type_name; 36 8 : nb::class_<EigenSystemH<T>> pyclass(m, pylass_name.c_str()); 37 4 : pyclass.def_rw("eigenvectors", &EigenSystemH<T>::eigenvectors) 38 : .def_rw("eigenvalues", &EigenSystemH<T>::eigenvalues); 39 4 : } 40 : 41 2 : void bind_diagonalizer_interface(nb::module_ &m) { 42 2 : declare_diagonalizer_interface<double>(m, "Real"); 43 2 : declare_diagonalizer_interface<std::complex<double>>(m, "Complex"); 44 2 : declare_eigen_system_h<double>(m, "Real"); 45 2 : declare_eigen_system_h<std::complex<double>>(m, "Complex"); 46 2 : }