Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./Diagonalizer.py.hpp"
5 :
6 : #include "pairinteraction/diagonalize/DiagonalizerEigen.hpp"
7 : #include "pairinteraction/diagonalize/DiagonalizerFeast.hpp"
8 : #include "pairinteraction/diagonalize/DiagonalizerLapackeEvd.hpp"
9 : #include "pairinteraction/diagonalize/DiagonalizerLapackeEvr.hpp"
10 : #include "pairinteraction/diagonalize/diagonalize.hpp"
11 : #include "pairinteraction/enums/FloatType.hpp"
12 : #include "pairinteraction/system/SystemAtom.hpp"
13 : #include "pairinteraction/system/SystemPair.hpp"
14 :
15 : #include <nanobind/eigen/sparse.h>
16 : #include <nanobind/nanobind.h>
17 : #include <nanobind/stl/complex.h>
18 : #include <nanobind/stl/optional.h>
19 :
20 : namespace nb = nanobind;
21 : using namespace nb::literals;
22 : using namespace pairinteraction;
23 :
24 : template <typename T>
25 4 : static void declare_diagonalizer_eigen(nb::module_ &m, std::string const &type_name) {
26 4 : std::string pyclass_name = "DiagonalizerEigen" + type_name;
27 4 : nb::class_<DiagonalizerEigen<T>, DiagonalizerInterface<T>> pyclass(m, pyclass_name.c_str());
28 4 : pyclass.def(nb::init<FloatType>(), "float_type"_a = FloatType::FLOAT64)
29 : .def("eigh",
30 4 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
31 : &DiagonalizerEigen<T>::eigh, nb::const_),
32 4 : nb::call_guard<nb::gil_scoped_release>());
33 4 : }
34 :
35 : template <typename T>
36 4 : static void declare_diagonalizer_feast(nb::module_ &m, std::string const &type_name) {
37 4 : std::string pyclass_name = "DiagonalizerFeast" + type_name;
38 : using real_t = typename DiagonalizerFeast<T>::real_t;
39 4 : nb::class_<DiagonalizerFeast<T>, DiagonalizerInterface<T>> pyclass(m, pyclass_name.c_str());
40 8 : pyclass.def(nb::init<int, FloatType>(), "m0"_a, "float_type"_a = FloatType::FLOAT64)
41 4 : .def("eigh",
42 4 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
43 : &DiagonalizerFeast<T>::eigh, nb::const_),
44 4 : nb::call_guard<nb::gil_scoped_release>())
45 : .def("eigh",
46 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &,
47 4 : std::optional<real_t>, std::optional<real_t>, double>(
48 : &DiagonalizerFeast<T>::eigh, nb::const_),
49 4 : nb::call_guard<nb::gil_scoped_release>());
50 4 : }
51 :
52 : template <typename T>
53 4 : static void declare_diagonalizer_lapacke_evd(nb::module_ &m, std::string const &type_name) {
54 4 : std::string pyclass_name = "DiagonalizerLapackeEvd" + type_name;
55 4 : nb::class_<DiagonalizerLapackeEvd<T>, DiagonalizerInterface<T>> pyclass(m,
56 : pyclass_name.c_str());
57 4 : pyclass.def(nb::init<FloatType>(), "float_type"_a = FloatType::FLOAT64)
58 : .def("eigh",
59 4 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
60 : &DiagonalizerLapackeEvd<T>::eigh, nb::const_),
61 4 : nb::call_guard<nb::gil_scoped_release>());
62 4 : }
63 :
64 : template <typename T>
65 4 : static void declare_diagonalizer_lapacke_evr(nb::module_ &m, std::string const &type_name) {
66 4 : std::string pyclass_name = "DiagonalizerLapackeEvr" + type_name;
67 4 : nb::class_<DiagonalizerLapackeEvr<T>, DiagonalizerInterface<T>> pyclass(m,
68 : pyclass_name.c_str());
69 4 : pyclass.def(nb::init<FloatType>(), "float_type"_a = FloatType::FLOAT64)
70 : .def("eigh",
71 4 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
72 : &DiagonalizerLapackeEvr<T>::eigh, nb::const_),
73 4 : nb::call_guard<nb::gil_scoped_release>());
74 4 : }
75 :
76 : template <typename T>
77 8 : static void declare_diagonalize(nb::module_ &m, std::string const &type_name) {
78 8 : std::string pyclass_name = "diagonalize" + type_name;
79 : using real_t = typename T::real_t;
80 : using scalar_t = typename T::scalar_t;
81 40 : m.def(
82 : pyclass_name.c_str(),
83 89 : [](nb::list pylist, // NOLINT
84 : const DiagonalizerInterface<scalar_t> &diagonalizer,
85 : std::optional<real_t> min_eigenvalue, std::optional<real_t> max_eigenvalue,
86 : double rtol) {
87 89 : std::vector<std::reference_wrapper<T>> systems;
88 89 : systems.reserve(pylist.size());
89 696 : for (nb::handle_t<T> &&h : pylist) {
90 607 : systems.push_back(nb::cast<T &>(h));
91 : }
92 : {
93 89 : nb::gil_scoped_release release;
94 89 : diagonalize(systems, diagonalizer, min_eigenvalue, max_eigenvalue, rtol);
95 89 : }
96 89 : },
97 24 : "systems"_a, "diagonalizer"_a, "min_eigenvalue"_a = nb::none(),
98 24 : "max_eigenvalue"_a = nb::none(), "rtol"_a = 1e-6);
99 8 : }
100 :
101 2 : void bind_diagonalizer(nb::module_ &m) {
102 2 : declare_diagonalizer_eigen<double>(m, "Real");
103 2 : declare_diagonalizer_eigen<std::complex<double>>(m, "Complex");
104 2 : declare_diagonalizer_feast<double>(m, "Real");
105 2 : declare_diagonalizer_feast<std::complex<double>>(m, "Complex");
106 2 : declare_diagonalizer_lapacke_evd<double>(m, "Real");
107 2 : declare_diagonalizer_lapacke_evd<std::complex<double>>(m, "Complex");
108 2 : declare_diagonalizer_lapacke_evr<double>(m, "Real");
109 2 : declare_diagonalizer_lapacke_evr<std::complex<double>>(m, "Complex");
110 :
111 2 : declare_diagonalize<SystemAtom<double>>(m, "SystemAtomReal");
112 2 : declare_diagonalize<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex");
113 :
114 2 : declare_diagonalize<SystemPair<double>>(m, "SystemPairReal");
115 2 : declare_diagonalize<SystemPair<std::complex<double>>>(m, "SystemPairComplex");
116 2 : }
|