Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "./System.py.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/basis/BasisPair.hpp"
8 : #include "pairinteraction/interfaces/DiagonalizerInterface.hpp"
9 : #include "pairinteraction/system/GreenTensor.hpp"
10 : #include "pairinteraction/system/System.hpp"
11 : #include "pairinteraction/system/SystemAtom.hpp"
12 : #include "pairinteraction/system/SystemPair.hpp"
13 :
14 : #include <nanobind/eigen/dense.h>
15 : #include <nanobind/eigen/sparse.h>
16 : #include <nanobind/nanobind.h>
17 : #include <nanobind/stl/array.h>
18 : #include <nanobind/stl/complex.h>
19 : #include <nanobind/stl/optional.h>
20 : #include <nanobind/stl/shared_ptr.h>
21 : #include <nanobind/stl/variant.h>
22 : #include <nanobind/stl/vector.h>
23 :
24 : namespace nb = nanobind;
25 : using namespace nb::literals;
26 : using namespace pairinteraction;
27 :
28 : template <typename T>
29 8 : static void declare_system(nb::module_ &m, std::string const &type_name) {
30 : using S = System<T>;
31 : using scalar_t = typename System<T>::scalar_t;
32 :
33 8 : std::string pyclass_name = "System" + type_name;
34 :
35 8 : nb::class_<System<T>, TransformationBuilderInterface<scalar_t>> pyclass(m,
36 : pyclass_name.c_str());
37 8 : pyclass.def("get_basis", &S::get_basis)
38 8 : .def("get_eigenbasis", &S::get_eigenbasis)
39 8 : .def("get_eigenenergies", &S::get_eigenenergies)
40 8 : .def("get_matrix", &S::get_matrix)
41 8 : .def("get_transformation", &S::get_transformation)
42 8 : .def("get_rotator", &S::get_rotator)
43 8 : .def("get_sorter", &S::get_sorter)
44 8 : .def("get_indices_of_blocks", &S::get_indices_of_blocks)
45 8 : .def("transform", nb::overload_cast<const Transformation<scalar_t> &>(&S::transform))
46 8 : .def("transform", nb::overload_cast<const Sorting &>(&S::transform))
47 40 : .def("diagonalize", &S::diagonalize, "diagonalizer"_a, "min_eigenenergy"_a = nb::none(),
48 24 : "max_eigenenergy"_a = nb::none(), "rtol"_a = 1e-6)
49 16 : .def("is_diagonal", &S::is_diagonal);
50 8 : }
51 :
52 : template <typename T>
53 4 : static void declare_system_atom(nb::module_ &m, std::string const &type_name) {
54 : using S = SystemAtom<T>;
55 : using basis_t = typename SystemAtom<T>::basis_t;
56 :
57 4 : std::string pyclass_name = "SystemAtom" + type_name;
58 :
59 8 : nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
60 8 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
61 4 : .def("set_electric_field", &S::set_electric_field)
62 4 : .def("set_magnetic_field", &S::set_magnetic_field)
63 4 : .def("set_diamagnetism_enabled", &S::set_diamagnetism_enabled)
64 4 : .def("set_ion_distance_vector", &S::set_ion_distance_vector)
65 4 : .def("set_ion_charge", &S::set_ion_charge)
66 4 : .def("set_ion_interaction_order", &S::set_ion_interaction_order);
67 4 : }
68 :
69 : template <typename T>
70 4 : static void declare_system_pair(nb::module_ &m, std::string const &type_name) {
71 : using S = SystemPair<T>;
72 : using basis_t = typename SystemPair<T>::basis_t;
73 :
74 4 : std::string pyclass_name = "SystemPair" + type_name;
75 :
76 8 : nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
77 8 : pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
78 4 : .def("set_interaction_order", &S::set_interaction_order)
79 4 : .def("set_distance_vector", &S::set_distance_vector)
80 4 : .def("set_green_tensor", &S::set_green_tensor);
81 4 : }
82 :
83 : template <typename T>
84 4 : static void declare_green_tensor(nb::module_ &m, std::string const &type_name) {
85 : using CE = typename GreenTensor<T>::ConstantEntry;
86 : using OE = typename GreenTensor<T>::OmegaDependentEntry;
87 : using GT = GreenTensor<T>;
88 :
89 4 : std::string ce_name = "ConstantEntry" + type_name;
90 4 : std::string oe_name = "OmegaDependentEntry" + type_name;
91 4 : std::string gt_name = "GreenTensor" + type_name;
92 :
93 4 : nb::class_<CE>(m, ce_name.c_str())
94 4 : .def("row", &CE::row)
95 4 : .def("col", &CE::col)
96 8 : .def("val", &CE::val);
97 :
98 4 : nb::class_<OE>(m, oe_name.c_str())
99 4 : .def("row", &OE::row)
100 4 : .def("col", &OE::col)
101 8 : .def("val", &OE::val);
102 :
103 8 : nb::class_<GT>(m, gt_name.c_str())
104 8 : .def(nb::init<>())
105 4 : .def("create_entries_from_cartesian",
106 4 : nb::overload_cast<int, int, const Eigen::MatrixX<T> &>(
107 : >::create_entries_from_cartesian))
108 4 : .def("create_entries_from_cartesian",
109 : nb::overload_cast<int, int, const std::vector<Eigen::MatrixX<T>> &,
110 4 : const std::vector<double> &>(>::create_entries_from_cartesian))
111 : .def("get_spherical_entries",
112 8 : nb::overload_cast<int, int>(>::get_spherical_entries, nb::const_));
113 4 : }
114 :
115 2 : void bind_system(nb::module_ &m) {
116 2 : declare_system<SystemAtom<double>>(m, "SystemAtomReal");
117 2 : declare_system<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex");
118 2 : declare_system_atom<double>(m, "Real");
119 2 : declare_system_atom<std::complex<double>>(m, "Complex");
120 :
121 2 : declare_system<SystemPair<double>>(m, "SystemPairReal");
122 2 : declare_system<SystemPair<std::complex<double>>>(m, "SystemPairComplex");
123 2 : declare_system_pair<double>(m, "Real");
124 2 : declare_system_pair<std::complex<double>>(m, "Complex");
125 :
126 2 : declare_green_tensor<double>(m, "Real");
127 2 : declare_green_tensor<std::complex<double>>(m, "Complex");
128 2 : }
|