LCOV - code coverage report
Current view: top level - bindings/system - System.py.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 72 78 92.3 %
Date: 2026-04-17 09:20:02 Functions: 15 23 65.2 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "./System.py.hpp"
       5             : 
       6             : #include "pairinteraction/basis/BasisAtom.hpp"
       7             : #include "pairinteraction/basis/BasisPair.hpp"
       8             : #include "pairinteraction/interfaces/DiagonalizerInterface.hpp"
       9             : #include "pairinteraction/system/GreenTensorInterpolator.hpp"
      10             : #include "pairinteraction/system/System.hpp"
      11             : #include "pairinteraction/system/SystemAtom.hpp"
      12             : #include "pairinteraction/system/SystemPair.hpp"
      13             : 
      14             : #include <nanobind/eigen/dense.h>
      15             : #include <nanobind/eigen/sparse.h>
      16             : #include <nanobind/nanobind.h>
      17             : #include <nanobind/stl/array.h>
      18             : #include <nanobind/stl/complex.h>
      19             : #include <nanobind/stl/optional.h>
      20             : #include <nanobind/stl/shared_ptr.h>
      21             : #include <nanobind/stl/variant.h>
      22             : #include <nanobind/stl/vector.h>
      23             : 
      24             : namespace nb = nanobind;
      25             : using namespace nb::literals;
      26             : using namespace pairinteraction;
      27             : 
      28             : template <typename T>
      29           8 : static void declare_system(nb::module_ &m, const std::string &type_name) {
      30             :     using S = System<T>;
      31             :     using scalar_t = typename System<T>::scalar_t;
      32             :     using real_t = typename System<T>::real_t;
      33             : 
      34           8 :     std::string pyclass_name = "System" + type_name;
      35             : 
      36           8 :     nb::class_<System<T>, TransformationBuilderInterface<scalar_t>> pyclass(m,
      37             :                                                                             pyclass_name.c_str());
      38           8 :     pyclass.def("get_basis", &S::get_basis, nb::call_guard<nb::gil_scoped_release>())
      39           8 :         .def("get_eigenbasis", &S::get_eigenbasis, nb::call_guard<nb::gil_scoped_release>())
      40           8 :         .def("get_eigenenergies", &S::get_eigenenergies, nb::call_guard<nb::gil_scoped_release>())
      41           8 :         .def("get_matrix", &S::get_matrix, nb::call_guard<nb::gil_scoped_release>())
      42           8 :         .def("get_transformation", &S::get_transformation, nb::call_guard<nb::gil_scoped_release>())
      43           8 :         .def("get_rotator", &S::get_rotator, nb::call_guard<nb::gil_scoped_release>())
      44           8 :         .def("get_sorter", &S::get_sorter, nb::call_guard<nb::gil_scoped_release>())
      45          16 :         .def("get_indices_of_blocks", &S::get_indices_of_blocks,
      46           0 :              nb::call_guard<nb::gil_scoped_release>())
      47           8 :         .def(
      48             :             "transform",
      49           0 :             [](S &self, const Transformation<scalar_t> &transformation) -> T & {
      50           0 :                 return static_cast<T &>(self.transform(transformation));
      51             :             },
      52           0 :             nb::call_guard<nb::gil_scoped_release>())
      53           8 :         .def(
      54             :             "transform",
      55         597 :             [](S &self, const Sorting &sorting) -> T & {
      56         597 :                 return static_cast<T &>(self.transform(sorting));
      57             :             },
      58           0 :             nb::call_guard<nb::gil_scoped_release>())
      59          32 :         .def(
      60             :             "diagonalize",
      61           8 :             [](S &self, const DiagonalizerInterface<scalar_t> &diagonalizer,
      62             :                std::optional<real_t> min_eigenenergy, std::optional<real_t> max_eigenenergy,
      63             :                double rtol) -> T & {
      64             :                 return static_cast<T &>(
      65           0 :                     self.diagonalize(diagonalizer, min_eigenenergy, max_eigenenergy, rtol));
      66             :             },
      67          40 :             "diagonalizer"_a, "min_eigenenergy"_a = nb::none(), "max_eigenenergy"_a = nb::none(),
      68           8 :             "rtol"_a = 1e-6, nb::call_guard<nb::gil_scoped_release>())
      69          16 :         .def("is_diagonal", &S::is_diagonal, nb::call_guard<nb::gil_scoped_release>());
      70           8 : }
      71             : 
      72             : template <typename T>
      73           4 : static void declare_system_atom(nb::module_ &m, const std::string &type_name) {
      74             :     using S = SystemAtom<T>;
      75             :     using basis_t = typename SystemAtom<T>::basis_t;
      76             : 
      77           4 :     std::string pyclass_name = "SystemAtom" + type_name;
      78             : 
      79           8 :     nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
      80           8 :     pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
      81           4 :         .def("set_electric_field", &S::set_electric_field)
      82           4 :         .def("set_magnetic_field", &S::set_magnetic_field)
      83           4 :         .def("set_diamagnetism_enabled", &S::set_diamagnetism_enabled)
      84           4 :         .def("set_ion_distance_vector", &S::set_ion_distance_vector)
      85           4 :         .def("set_ion_charge", &S::set_ion_charge)
      86           4 :         .def("set_ion_interaction_order", &S::set_ion_interaction_order);
      87           4 : }
      88             : 
      89             : template <typename T>
      90           4 : static void declare_system_pair(nb::module_ &m, const std::string &type_name) {
      91             :     using S = SystemPair<T>;
      92             :     using basis_t = typename SystemPair<T>::basis_t;
      93             : 
      94           4 :     std::string pyclass_name = "SystemPair" + type_name;
      95             : 
      96           8 :     nb::class_<S, System<S>> pyclass(m, pyclass_name.c_str());
      97           8 :     pyclass.def(nb::init<std::shared_ptr<const basis_t>>())
      98           4 :         .def("set_interaction_order", &S::set_interaction_order)
      99           4 :         .def("set_distance_vector", &S::set_distance_vector)
     100           4 :         .def("set_green_tensor_interpolator", &S::set_green_tensor_interpolator);
     101           4 : }
     102             : 
     103             : template <typename T>
     104           4 : static void declare_green_tensor_interpolator(nb::module_ &m, const std::string &type_name) {
     105             :     using CE = typename GreenTensorInterpolator<T>::ConstantEntry;
     106             :     using OE = typename GreenTensorInterpolator<T>::OmegaDependentEntry;
     107             :     using GT = GreenTensorInterpolator<T>;
     108             : 
     109           4 :     std::string ce_name = "ConstantEntry" + type_name;
     110           4 :     std::string oe_name = "OmegaDependentEntry" + type_name;
     111           4 :     std::string gt_name = "GreenTensorInterpolator" + type_name;
     112             : 
     113           4 :     nb::class_<CE>(m, ce_name.c_str())
     114           4 :         .def("row", &CE::row)
     115           4 :         .def("col", &CE::col)
     116           8 :         .def("val", &CE::val);
     117             : 
     118           4 :     nb::class_<OE>(m, oe_name.c_str())
     119           4 :         .def("row", &OE::row)
     120           4 :         .def("col", &OE::col)
     121           8 :         .def("val", &OE::val);
     122             : 
     123           8 :     nb::class_<GT>(m, gt_name.c_str())
     124           8 :         .def(nb::init<>())
     125           4 :         .def("create_entries_from_cartesian",
     126           4 :              nb::overload_cast<int, int, const Eigen::MatrixX<T> &>(
     127             :                  &GT::create_entries_from_cartesian))
     128           4 :         .def("create_entries_from_cartesian",
     129             :              nb::overload_cast<int, int, const std::vector<Eigen::MatrixX<T>> &,
     130           4 :                                const std::vector<double> &>(&GT::create_entries_from_cartesian))
     131             :         .def("get_spherical_entries",
     132           8 :              nb::overload_cast<int, int>(&GT::get_spherical_entries, nb::const_));
     133           4 : }
     134             : 
     135           2 : void bind_system(nb::module_ &m) {
     136           2 :     declare_system<SystemAtom<double>>(m, "SystemAtomReal");
     137           2 :     declare_system<SystemAtom<std::complex<double>>>(m, "SystemAtomComplex");
     138           2 :     declare_system_atom<double>(m, "Real");
     139           2 :     declare_system_atom<std::complex<double>>(m, "Complex");
     140             : 
     141           2 :     declare_system<SystemPair<double>>(m, "SystemPairReal");
     142           2 :     declare_system<SystemPair<std::complex<double>>>(m, "SystemPairComplex");
     143           2 :     declare_system_pair<double>(m, "Real");
     144           2 :     declare_system_pair<std::complex<double>>(m, "Complex");
     145             : 
     146           2 :     declare_green_tensor_interpolator<double>(m, "Real");
     147           2 :     declare_green_tensor_interpolator<std::complex<double>>(m, "Complex");
     148           2 : }

Generated by: LCOV version 1.16