Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers 2 : // SPDX-License-Identifier: LGPL-3.0-or-later 3 : 4 : #include "./System.py.hpp" 5 : 6 : #include "pairinteraction/basis/BasisAtom.hpp" 7 : #include "pairinteraction/basis/BasisPair.hpp" 8 : #include "pairinteraction/interfaces/DiagonalizerInterface.hpp" 9 : #include "pairinteraction/system/System.hpp" 10 : #include "pairinteraction/system/SystemAtom.hpp" 11 : #include "pairinteraction/system/SystemPair.hpp" 12 : 13 : #include <nanobind/eigen/dense.h> 14 : #include <nanobind/eigen/sparse.h> 15 : #include <nanobind/nanobind.h> 16 : #include <nanobind/stl/array.h> 17 : #include <nanobind/stl/complex.h> 18 : #include <nanobind/stl/optional.h> 19 : #include <nanobind/stl/shared_ptr.h> 20 : #include <nanobind/stl/vector.h> 21 : 22 : namespace nb = nanobind; 23 : using namespace nb::literals; 24 : using namespace pairinteraction; 25 : 26 : template <typename T> 27 8 : static void declare_system(nb::module_ &m, std::string const &type_name) { 28 8 : std::string pyclass_name = "System" + type_name; 29 : using scalar_t = typename System<T>::scalar_t; 30 8 : nb::class_<System<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, 31 : pyclass_name.c_str()); 32 8 : pyclass.def("get_basis", &System<T>::get_basis) 33 8 : .def("get_eigenbasis", &System<T>::get_eigenbasis) 34 8 : .def("get_eigenenergies", &System<T>::get_eigenenergies) 35 8 : .def("get_matrix", &System<T>::get_matrix) 36 8 : .def("get_transformation", &System<T>::get_transformation) 37 8 : .def("get_rotator", &System<T>::get_rotator) 38 8 : .def("get_sorter", &System<T>::get_sorter) 39 8 : .def("get_indices_of_blocks", &System<T>::get_indices_of_blocks) 40 8 : .def("transform", 41 8 : nb::overload_cast<const Transformation<scalar_t> &>(&System<T>::transform)) 42 8 : .def("transform", nb::overload_cast<const Sorting &>(&System<T>::transform)) 43 40 : .def("diagonalize", &System<T>::diagonalize, "diagonalizer"_a, 44 40 : "min_eigenenergy"_a = nb::none(), "max_eigenenergy"_a = nb::none(), "rtol"_a = 1e-6) 45 16 : .def("is_diagonal", &System<T>::is_diagonal); 46 8 : } 47 : 48 : template <typename T> 49 4 : static void declare_system_atom(nb::module_ &m, std::string const &type_name) { 50 4 : std::string pyclass_name = "SystemAtom" + type_name; 51 : using basis_t = typename SystemAtom<T>::basis_t; 52 8 : nb::class_<SystemAtom<T>, System<SystemAtom<T>>> pyclass(m, pyclass_name.c_str()); 53 8 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>()) 54 4 : .def("set_electric_field", &SystemAtom<T>::set_electric_field) 55 4 : .def("set_magnetic_field", &SystemAtom<T>::set_magnetic_field) 56 4 : .def("set_diamagnetism_enabled", &SystemAtom<T>::set_diamagnetism_enabled) 57 4 : .def("set_ion_distance_vector", &SystemAtom<T>::set_ion_distance_vector) 58 4 : .def("set_ion_charge", &SystemAtom<T>::set_ion_charge) 59 4 : .def("set_ion_interaction_order", &SystemAtom<T>::set_ion_interaction_order); 60 4 : } 61 : 62 : template <typename T> 63 4 : static void declare_system_pair(nb::module_ &m, std::string const &type_name) { 64 4 : std::string pyclass_name = "SystemPair" + type_name; 65 : using basis_t = typename SystemPair<T>::basis_t; 66 8 : nb::class_<SystemPair<T>, System<SystemPair<T>>> pyclass(m, pyclass_name.c_str()); 67 8 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>()) 68 4 : .def("set_interaction_order", &SystemPair<T>::set_interaction_order) 69 4 : .def("set_distance_vector", &SystemPair<T>::set_distance_vector); 70 4 : } 71 : 72 2 : void bind_system(nb::module_ &m) { 73 2 : declare_system<SystemAtom<double>>(m, "SystemAtomReal"); 74 2 : declare_system<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex"); 75 2 : declare_system_atom<double>(m, "Real"); 76 2 : declare_system_atom<std::complex<double>>(m, "Complex"); 77 : 78 2 : declare_system<SystemPair<double>>(m, "SystemPairReal"); 79 2 : declare_system<SystemPair<std::complex<double>>>(m, "SystemPairComplex"); 80 2 : declare_system_pair<double>(m, "Real"); 81 2 : declare_system_pair<std::complex<double>>(m, "Complex"); 82 2 : }