Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./System.py.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/interfaces/DiagonalizerInterface.hpp"
9 : #include "pairinteraction/system/GreenTensor.hpp"
10 : #include "pairinteraction/system/System.hpp"
11 : #include "pairinteraction/system/SystemAtom.hpp"
12 : #include "pairinteraction/system/SystemPair.hpp"
13 :
14 : #include <nanobind/eigen/dense.h>
15 : #include <nanobind/eigen/sparse.h>
16 : #include <nanobind/nanobind.h>
17 : #include <nanobind/stl/array.h>
18 : #include <nanobind/stl/complex.h>
19 : #include <nanobind/stl/optional.h>
20 : #include <nanobind/stl/shared_ptr.h>
21 : #include <nanobind/stl/variant.h>
22 : #include <nanobind/stl/vector.h>
23 :
24 : namespace nb = nanobind;
25 : using namespace nb::literals;
26 : using namespace pairinteraction;
27 :
28 : template <typename T>
29 4 : static void declare_system(nb::module_ &m, std::string const &type_name) {
30 : using S = System<T>;
31 : using scalar_t = typename System<T>::scalar_t;
32 :
33 4 : std::string pyclass_name = "System" + type_name;
34 :
35 4 : nb::class_<System<T>, TransformationBuilderInterface<scalar_t>> pyclass(m,
36 : pyclass_name.c_str());
37 4 : pyclass.def("get_basis", &S::get_basis)
38 4 : .def("get_eigenbasis", &S::get_eigenbasis)
39 4 : .def("get_eigenenergies", &S::get_eigenenergies)
40 4 : .def("get_matrix", &S::get_matrix)
41 4 : .def("get_transformation", &S::get_transformation)
42 4 : .def("get_rotator", &S::get_rotator)
43 4 : .def("get_sorter", &S::get_sorter)
44 4 : .def("get_indices_of_blocks", &S::get_indices_of_blocks)
45 4 : .def("transform", nb::overload_cast<const Transformation<scalar_t> &>(&S::transform))
46 4 : .def("transform", nb::overload_cast<const Sorting &>(&S::transform))
47 20 : .def("diagonalize", &S::diagonalize, "diagonalizer"_a, "min_eigenenergy"_a = nb::none(),
48 12 : "max_eigenenergy"_a = nb::none(), "rtol"_a = 1e-6)
49 8 : .def("is_diagonal", &S::is_diagonal);
50 4 : }
51 :
52 : template <typename T>
53 2 : static void declare_system_atom(nb::module_ &m, std::string const &type_name) {
54 : using S = SystemAtom<T>;
55 : using basis_t = typename SystemAtom<T>::basis_t;
56 :
57 2 : std::string pyclass_name = "SystemAtom" + type_name;
58 :
59 4 : nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
60 4 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
61 2 : .def("set_electric_field", &S::set_electric_field)
62 2 : .def("set_magnetic_field", &S::set_magnetic_field)
63 2 : .def("set_diamagnetism_enabled", &S::set_diamagnetism_enabled)
64 2 : .def("set_ion_distance_vector", &S::set_ion_distance_vector)
65 2 : .def("set_ion_charge", &S::set_ion_charge)
66 2 : .def("set_ion_interaction_order", &S::set_ion_interaction_order);
67 2 : }
68 :
69 : template <typename T>
70 2 : static void declare_system_pair(nb::module_ &m, std::string const &type_name) {
71 : using S = SystemPair<T>;
72 : using basis_t = typename SystemPair<T>::basis_t;
73 :
74 2 : std::string pyclass_name = "SystemPair" + type_name;
75 :
76 4 : nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
77 4 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
78 2 : .def("set_interaction_order", &S::set_interaction_order)
79 2 : .def("set_distance_vector", &S::set_distance_vector)
80 2 : .def("set_green_tensor", &S::set_green_tensor);
81 2 : }
82 :
83 : template <typename T>
84 2 : static void declare_green_tensor(nb::module_ &m, std::string const &type_name) {
85 : using CE = typename GreenTensor<T>::ConstantEntry;
86 : using OE = typename GreenTensor<T>::OmegaDependentEntry;
87 : using GT = GreenTensor<T>;
88 :
89 2 : std::string ce_name = "ConstantEntry" + type_name;
90 2 : std::string oe_name = "OmegaDependentEntry" + type_name;
91 2 : std::string gt_name = "GreenTensor" + type_name;
92 :
93 2 : nb::class_<CE>(m, ce_name.c_str())
94 2 : .def("row", &CE::row)
95 2 : .def("col", &CE::col)
96 4 : .def("val", &CE::val);
97 :
98 2 : nb::class_<OE>(m, oe_name.c_str())
99 2 : .def("row", &OE::row)
100 2 : .def("col", &OE::col)
101 4 : .def("val", &OE::val);
102 :
103 4 : nb::class_<GT>(m, gt_name.c_str())
104 4 : .def(nb::init<>())
105 2 : .def("create_entries_from_cartesian",
106 2 : nb::overload_cast<int, int, const Eigen::MatrixX<T> &>(
107 : >::create_entries_from_cartesian))
108 2 : .def("create_entries_from_cartesian",
109 : nb::overload_cast<int, int, const std::vector<Eigen::MatrixX<T>> &,
110 2 : const std::vector<double> &>(>::create_entries_from_cartesian))
111 : .def("get_spherical_entries",
112 4 : nb::overload_cast<int, int>(>::get_spherical_entries, nb::const_));
113 2 : }
114 :
115 1 : void bind_system(nb::module_ &m) {
116 1 : declare_system<SystemAtom<double>>(m, "SystemAtomReal");
117 1 : declare_system<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex");
118 1 : declare_system_atom<double>(m, "Real");
119 1 : declare_system_atom<std::complex<double>>(m, "Complex");
120 :
121 1 : declare_system<SystemPair<double>>(m, "SystemPairReal");
122 1 : declare_system<SystemPair<std::complex<double>>>(m, "SystemPairComplex");
123 1 : declare_system_pair<double>(m, "Real");
124 1 : declare_system_pair<std::complex<double>>(m, "Complex");
125 :
126 1 : declare_green_tensor<double>(m, "Real");
127 1 : declare_green_tensor<std::complex<double>>(m, "Complex");
128 1 : }
|