Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./System.py.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/interfaces/DiagonalizerInterface.hpp"
9 : #include "pairinteraction/system/GreenTensorInterpolator.hpp"
10 : #include "pairinteraction/system/System.hpp"
11 : #include "pairinteraction/system/SystemAtom.hpp"
12 : #include "pairinteraction/system/SystemPair.hpp"
13 :
14 : #include <nanobind/eigen/dense.h>
15 : #include <nanobind/eigen/sparse.h>
16 : #include <nanobind/nanobind.h>
17 : #include <nanobind/stl/array.h>
18 : #include <nanobind/stl/complex.h>
19 : #include <nanobind/stl/optional.h>
20 : #include <nanobind/stl/shared_ptr.h>
21 : #include <nanobind/stl/variant.h>
22 : #include <nanobind/stl/vector.h>
23 :
24 : namespace nb = nanobind;
25 : using namespace nb::literals;
26 : using namespace pairinteraction;
27 :
28 : template <typename T>
29 8 : static void declare_system(nb::module_ &m, const std::string &type_name) {
30 : using S = System<T>;
31 : using scalar_t = typename System<T>::scalar_t;
32 : using real_t = typename System<T>::real_t;
33 :
34 8 : std::string pyclass_name = "System" + type_name;
35 :
36 8 : nb::class_<System<T>, TransformationBuilderInterface<scalar_t>> pyclass(m,
37 : pyclass_name.c_str());
38 8 : pyclass.def("get_basis", &S::get_basis, nb::call_guard<nb::gil_scoped_release>())
39 8 : .def("get_eigenbasis", &S::get_eigenbasis, nb::call_guard<nb::gil_scoped_release>())
40 8 : .def("get_eigenenergies", &S::get_eigenenergies, nb::call_guard<nb::gil_scoped_release>())
41 8 : .def("get_matrix", &S::get_matrix, nb::call_guard<nb::gil_scoped_release>())
42 8 : .def("get_transformation", &S::get_transformation, nb::call_guard<nb::gil_scoped_release>())
43 8 : .def("get_rotator", &S::get_rotator, nb::call_guard<nb::gil_scoped_release>())
44 8 : .def("get_sorter", &S::get_sorter, nb::call_guard<nb::gil_scoped_release>())
45 16 : .def("get_indices_of_blocks", &S::get_indices_of_blocks,
46 0 : nb::call_guard<nb::gil_scoped_release>())
47 8 : .def(
48 : "transform",
49 0 : [](S &self, const Transformation<scalar_t> &transformation) -> T & {
50 0 : return static_cast<T &>(self.transform(transformation));
51 : },
52 0 : nb::call_guard<nb::gil_scoped_release>())
53 8 : .def(
54 : "transform",
55 597 : [](S &self, const Sorting &sorting) -> T & {
56 597 : return static_cast<T &>(self.transform(sorting));
57 : },
58 0 : nb::call_guard<nb::gil_scoped_release>())
59 32 : .def(
60 : "diagonalize",
61 8 : [](S &self, const DiagonalizerInterface<scalar_t> &diagonalizer,
62 : std::optional<real_t> min_eigenenergy, std::optional<real_t> max_eigenenergy,
63 : double rtol) -> T & {
64 : return static_cast<T &>(
65 0 : self.diagonalize(diagonalizer, min_eigenenergy, max_eigenenergy, rtol));
66 : },
67 40 : "diagonalizer"_a, "min_eigenenergy"_a = nb::none(), "max_eigenenergy"_a = nb::none(),
68 8 : "rtol"_a = 1e-6, nb::call_guard<nb::gil_scoped_release>())
69 16 : .def("is_diagonal", &S::is_diagonal, nb::call_guard<nb::gil_scoped_release>());
70 8 : }
71 :
72 : template <typename T>
73 4 : static void declare_system_atom(nb::module_ &m, const std::string &type_name) {
74 : using S = SystemAtom<T>;
75 : using basis_t = typename SystemAtom<T>::basis_t;
76 :
77 4 : std::string pyclass_name = "SystemAtom" + type_name;
78 :
79 8 : nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
80 8 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
81 4 : .def("set_electric_field", &S::set_electric_field)
82 4 : .def("set_magnetic_field", &S::set_magnetic_field)
83 4 : .def("set_diamagnetism_enabled", &S::set_diamagnetism_enabled)
84 4 : .def("set_ion_distance_vector", &S::set_ion_distance_vector)
85 4 : .def("set_ion_charge", &S::set_ion_charge)
86 4 : .def("set_ion_interaction_order", &S::set_ion_interaction_order);
87 4 : }
88 :
89 : template <typename T>
90 4 : static void declare_system_pair(nb::module_ &m, const std::string &type_name) {
91 : using S = SystemPair<T>;
92 : using basis_t = typename SystemPair<T>::basis_t;
93 :
94 4 : std::string pyclass_name = "SystemPair" + type_name;
95 :
96 8 : nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
97 8 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
98 4 : .def("set_interaction_order", &S::set_interaction_order)
99 4 : .def("set_distance_vector", &S::set_distance_vector)
100 4 : .def("set_green_tensor_interpolator", &S::set_green_tensor_interpolator);
101 4 : }
102 :
103 : template <typename T>
104 4 : static void declare_green_tensor_interpolator(nb::module_ &m, const std::string &type_name) {
105 : using CE = typename GreenTensorInterpolator<T>::ConstantEntry;
106 : using OE = typename GreenTensorInterpolator<T>::OmegaDependentEntry;
107 : using GT = GreenTensorInterpolator<T>;
108 :
109 4 : std::string ce_name = "ConstantEntry" + type_name;
110 4 : std::string oe_name = "OmegaDependentEntry" + type_name;
111 4 : std::string gt_name = "GreenTensorInterpolator" + type_name;
112 :
113 4 : nb::class_<CE>(m, ce_name.c_str())
114 4 : .def("row", &CE::row)
115 4 : .def("col", &CE::col)
116 8 : .def("val", &CE::val);
117 :
118 4 : nb::class_<OE>(m, oe_name.c_str())
119 4 : .def("row", &OE::row)
120 4 : .def("col", &OE::col)
121 8 : .def("val", &OE::val);
122 :
123 8 : nb::class_<GT>(m, gt_name.c_str())
124 8 : .def(nb::init<>())
125 4 : .def("create_entries_from_cartesian",
126 4 : nb::overload_cast<int, int, const Eigen::MatrixX<T> &>(
127 : >::create_entries_from_cartesian))
128 4 : .def("create_entries_from_cartesian",
129 : nb::overload_cast<int, int, const std::vector<Eigen::MatrixX<T>> &,
130 4 : const std::vector<double> &>(>::create_entries_from_cartesian))
131 : .def("get_spherical_entries",
132 8 : nb::overload_cast<int, int>(>::get_spherical_entries, nb::const_));
133 4 : }
134 :
135 2 : void bind_system(nb::module_ &m) {
136 2 : declare_system<SystemAtom<double>>(m, "SystemAtomReal");
137 2 : declare_system<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex");
138 2 : declare_system_atom<double>(m, "Real");
139 2 : declare_system_atom<std::complex<double>>(m, "Complex");
140 :
141 2 : declare_system<SystemPair<double>>(m, "SystemPairReal");
142 2 : declare_system<SystemPair<std::complex<double>>>(m, "SystemPairComplex");
143 2 : declare_system_pair<double>(m, "Real");
144 2 : declare_system_pair<std::complex<double>>(m, "Complex");
145 :
146 2 : declare_green_tensor_interpolator<double>(m, "Real");
147 2 : declare_green_tensor_interpolator<std::complex<double>>(m, "Complex");
148 2 : }
|