Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./Diagonalizer.py.hpp"
5 :
6 : #include "pairinteraction/diagonalize/DiagonalizerEigen.hpp"
7 : #include "pairinteraction/diagonalize/DiagonalizerFeast.hpp"
8 : #include "pairinteraction/diagonalize/DiagonalizerLapackeEvd.hpp"
9 : #include "pairinteraction/diagonalize/DiagonalizerLapackeEvr.hpp"
10 : #include "pairinteraction/diagonalize/diagonalize.hpp"
11 : #include "pairinteraction/enums/FloatType.hpp"
12 : #include "pairinteraction/system/SystemAtom.hpp"
13 : #include "pairinteraction/system/SystemPair.hpp"
14 :
#include <nanobind/eigen/sparse.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/optional.h>

#include <utility>
19 :
20 : namespace nb = nanobind;
21 : using namespace nb::literals;
22 : using namespace pairinteraction;
23 :
24 : template <typename T>
25 4 : static void declare_diagonalizer_eigen(nb::module_ &m, std::string const &type_name) {
26 4 : std::string pyclass_name = "DiagonalizerEigen" + type_name;
27 4 : nb::class_<DiagonalizerEigen<T>, DiagonalizerInterface<T>> pyclass(m, pyclass_name.c_str());
28 4 : pyclass.def(nb::init<FloatType>(), "float_type"_a = FloatType::FLOAT64)
29 : .def("eigh",
30 8 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
31 : &DiagonalizerEigen<T>::eigh, nb::const_));
32 4 : }
33 :
34 : template <typename T>
35 4 : static void declare_diagonalizer_feast(nb::module_ &m, std::string const &type_name) {
36 4 : std::string pyclass_name = "DiagonalizerFeast" + type_name;
37 : using real_t = typename DiagonalizerFeast<T>::real_t;
38 4 : nb::class_<DiagonalizerFeast<T>, DiagonalizerInterface<T>> pyclass(m, pyclass_name.c_str());
39 8 : pyclass.def(nb::init<int, FloatType>(), "m0"_a, "float_type"_a = FloatType::FLOAT64)
40 4 : .def("eigh",
41 4 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
42 : &DiagonalizerFeast<T>::eigh, nb::const_))
43 : .def("eigh",
44 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &,
45 8 : std::optional<real_t>, std::optional<real_t>, double>(
46 : &DiagonalizerFeast<T>::eigh, nb::const_));
47 4 : }
48 :
49 : template <typename T>
50 4 : static void declare_diagonalizer_lapacke_evd(nb::module_ &m, std::string const &type_name) {
51 4 : std::string pyclass_name = "DiagonalizerLapackeEvd" + type_name;
52 4 : nb::class_<DiagonalizerLapackeEvd<T>, DiagonalizerInterface<T>> pyclass(m,
53 : pyclass_name.c_str());
54 4 : pyclass.def(nb::init<FloatType>(), "float_type"_a = FloatType::FLOAT64)
55 : .def("eigh",
56 8 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
57 : &DiagonalizerLapackeEvd<T>::eigh, nb::const_));
58 4 : }
59 :
60 : template <typename T>
61 4 : static void declare_diagonalizer_lapacke_evr(nb::module_ &m, std::string const &type_name) {
62 4 : std::string pyclass_name = "DiagonalizerLapackeEvr" + type_name;
63 4 : nb::class_<DiagonalizerLapackeEvr<T>, DiagonalizerInterface<T>> pyclass(m,
64 : pyclass_name.c_str());
65 4 : pyclass.def(nb::init<FloatType>(), "float_type"_a = FloatType::FLOAT64)
66 : .def("eigh",
67 8 : nb::overload_cast<const Eigen::SparseMatrix<T, Eigen::RowMajor> &, double>(
68 : &DiagonalizerLapackeEvr<T>::eigh, nb::const_));
69 4 : }
70 :
71 : template <typename T>
72 8 : static void declare_diagonalize(nb::module_ &m, std::string const &type_name) {
73 8 : std::string pyclass_name = "diagonalize" + type_name;
74 : using real_t = typename T::real_t;
75 : using scalar_t = typename T::scalar_t;
76 40 : m.def(
77 : pyclass_name.c_str(),
78 10 : [](nb::list pylist, // NOLINT
79 : const DiagonalizerInterface<scalar_t> &diagonalizer,
80 : std::optional<real_t> min_eigenvalue, std::optional<real_t> max_eigenvalue,
81 : double rtol) {
82 10 : std::vector<T> systems;
83 10 : systems.reserve(pylist.size());
84 70 : for (auto h : pylist) {
85 60 : systems.push_back(nb::cast<T>(h));
86 : }
87 10 : diagonalize(systems, diagonalizer, min_eigenvalue, max_eigenvalue, rtol);
88 70 : for (size_t i = 0; i < systems.size(); ++i) {
89 60 : pylist[i] = nb::cast(systems[i]);
90 : }
91 10 : },
92 24 : "systems"_a, "diagonalizer"_a, "min_eigenvalue"_a = nb::none(),
93 24 : "max_eigenvalue"_a = nb::none(), "rtol"_a = 1e-6);
94 8 : }
95 :
96 2 : void bind_diagonalizer(nb::module_ &m) {
97 2 : declare_diagonalizer_eigen<double>(m, "Real");
98 2 : declare_diagonalizer_eigen<std::complex<double>>(m, "Complex");
99 2 : declare_diagonalizer_feast<double>(m, "Real");
100 2 : declare_diagonalizer_feast<std::complex<double>>(m, "Complex");
101 2 : declare_diagonalizer_lapacke_evd<double>(m, "Real");
102 2 : declare_diagonalizer_lapacke_evd<std::complex<double>>(m, "Complex");
103 2 : declare_diagonalizer_lapacke_evr<double>(m, "Real");
104 2 : declare_diagonalizer_lapacke_evr<std::complex<double>>(m, "Complex");
105 :
106 2 : declare_diagonalize<SystemAtom<double>>(m, "SystemAtomReal");
107 2 : declare_diagonalize<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex");
108 :
109 2 : declare_diagonalize<SystemPair<double>>(m, "SystemPairReal");
110 2 : declare_diagonalize<SystemPair<std::complex<double>>>(m, "SystemPairComplex");
111 2 : }
|