LCOV - code coverage report
Current view: top level - bindings/basis - Basis.py.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 114 114 100.0 %
Date: 2025-04-29 15:56:08 Functions: 13 13 100.0 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "./Basis.py.hpp"
       5             : 
       6             : #include "pairinteraction/basis/Basis.hpp"
       7             : #include "pairinteraction/basis/BasisAtom.hpp"
       8             : #include "pairinteraction/basis/BasisAtomCreator.hpp"
       9             : #include "pairinteraction/basis/BasisPair.hpp"
      10             : #include "pairinteraction/basis/BasisPairCreator.hpp"
      11             : #include "pairinteraction/database/Database.hpp"
      12             : #include "pairinteraction/interfaces/TransformationBuilderInterface.hpp"
      13             : #include "pairinteraction/ket/KetAtom.hpp"
      14             : #include "pairinteraction/ket/KetPair.hpp"
      15             : #include "pairinteraction/system/SystemAtom.hpp"
      16             : 
      17             : #include <nanobind/eigen/sparse.h>
      18             : #include <nanobind/nanobind.h>
      19             : #include <nanobind/stl/complex.h>
      20             : #include <nanobind/stl/set.h>
      21             : #include <nanobind/stl/shared_ptr.h>
      22             : #include <nanobind/stl/string.h>
      23             : #include <nanobind/stl/vector.h>
      24             : 
      25             : namespace nb = nanobind;
      26             : using namespace pairinteraction;
      27             : 
      28             : template <typename T>
      29           4 : static void declare_basis(nb::module_ &m, std::string const &type_name) {
      30           4 :     std::string pyclass_name = "Basis" + type_name;
      31             :     using scalar_t = typename Basis<T>::scalar_t;
      32           4 :     nb::class_<Basis<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, pyclass_name.c_str());
      33           4 :     pyclass.def("get_kets", &Basis<T>::get_kets)
      34           4 :         .def("get_ket", &Basis<T>::get_ket)
      35           4 :         .def("get_state", &Basis<T>::get_state)
      36           4 :         .def("get_number_of_states", &Basis<T>::get_number_of_states)
      37           4 :         .def("get_number_of_kets", &Basis<T>::get_number_of_kets)
      38           4 :         .def("get_quantum_number_f", &Basis<T>::get_quantum_number_f)
      39           4 :         .def("get_quantum_number_m", &Basis<T>::get_quantum_number_m)
      40           4 :         .def("get_parity", &Basis<T>::get_parity)
      41           4 :         .def("get_coefficients", nb::overload_cast<>(&Basis<T>::get_coefficients, nb::const_))
      42           4 :         .def("get_transformation", &Basis<T>::get_transformation)
      43           4 :         .def("get_rotator", &Basis<T>::get_rotator)
      44           4 :         .def("get_sorter", &Basis<T>::get_sorter)
      45           4 :         .def("get_indices_of_blocks", &Basis<T>::get_indices_of_blocks)
      46           4 :         .def("get_sorter_without_checks", &Basis<T>::get_sorter_without_checks)
      47           4 :         .def("get_indices_of_blocks_without_checks",
      48           4 :              &Basis<T>::get_indices_of_blocks_without_checks)
      49           4 :         .def(
      50             :             "transformed",
      51           4 :             nb::overload_cast<const Transformation<scalar_t> &>(&Basis<T>::transformed, nb::const_))
      52           4 :         .def("transformed", nb::overload_cast<const Sorting &>(&Basis<T>::transformed, nb::const_))
      53           4 :         .def("get_amplitudes",
      54           4 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
      55             :                  &Basis<T>::get_amplitudes, nb::const_))
      56           4 :         .def("get_amplitudes",
      57           4 :              nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_amplitudes, nb::const_))
      58           4 :         .def("get_overlaps",
      59           4 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
      60             :                  &Basis<T>::get_overlaps, nb::const_))
      61           4 :         .def("get_overlaps",
      62           4 :              nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_overlaps, nb::const_))
      63           4 :         .def("get_matrix_elements",
      64           4 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>, OperatorType, int>(
      65             :                  &Basis<T>::get_matrix_elements, nb::const_))
      66           4 :         .def("get_matrix_elements",
      67           4 :              nb::overload_cast<std::shared_ptr<const T>, OperatorType, int>(
      68             :                  &Basis<T>::get_matrix_elements, nb::const_))
      69           4 :         .def("get_corresponding_state",
      70           4 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state, nb::const_))
      71           4 :         .def("get_corresponding_state",
      72           4 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
      73             :                  &Basis<T>::get_corresponding_state, nb::const_))
      74           4 :         .def("get_corresponding_state_index",
      75           4 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state_index, nb::const_))
      76           4 :         .def("get_corresponding_state_index",
      77           4 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
      78             :                  &Basis<T>::get_corresponding_state_index, nb::const_))
      79           4 :         .def("get_corresponding_ket",
      80           4 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket, nb::const_))
      81           4 :         .def("get_corresponding_ket",
      82           4 :              nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket,
      83             :                                                          nb::const_))
      84           4 :         .def("get_corresponding_ket_index",
      85           4 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket_index, nb::const_))
      86             :         .def("get_corresponding_ket_index",
      87           4 :              nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket_index,
      88             :                                                          nb::const_));
      89           4 : }
      90             : 
      91             : template <typename T>
      92           2 : static void declare_basis_atom(nb::module_ &m, std::string const &type_name) {
      93           2 :     std::string pyclass_name = "BasisAtom" + type_name;
      94           2 :     nb::class_<BasisAtom<T>, Basis<BasisAtom<T>>> pyclass(m, pyclass_name.c_str());
      95           2 : }
      96             : 
      97             : template <typename T>
      98           2 : static void declare_basis_atom_creator(nb::module_ &m, std::string const &type_name) {
      99           2 :     std::string pyclass_name = "BasisAtomCreator" + type_name;
     100           4 :     nb::class_<BasisAtomCreator<T>> pyclass(m, pyclass_name.c_str());
     101           4 :     pyclass.def(nb::init<>())
     102           2 :         .def("set_species", &BasisAtomCreator<T>::set_species)
     103           2 :         .def("restrict_energy", &BasisAtomCreator<T>::restrict_energy)
     104           2 :         .def("restrict_quantum_number_f", &BasisAtomCreator<T>::restrict_quantum_number_f)
     105           2 :         .def("restrict_quantum_number_m", &BasisAtomCreator<T>::restrict_quantum_number_m)
     106           2 :         .def("restrict_parity", &BasisAtomCreator<T>::restrict_parity)
     107           2 :         .def("restrict_quantum_number_n", &BasisAtomCreator<T>::restrict_quantum_number_n)
     108           2 :         .def("restrict_quantum_number_nu", &BasisAtomCreator<T>::restrict_quantum_number_nu)
     109           2 :         .def("restrict_quantum_number_l", &BasisAtomCreator<T>::restrict_quantum_number_l)
     110           2 :         .def("restrict_quantum_number_s", &BasisAtomCreator<T>::restrict_quantum_number_s)
     111           2 :         .def("restrict_quantum_number_j", &BasisAtomCreator<T>::restrict_quantum_number_j)
     112           2 :         .def("append_ket", &BasisAtomCreator<T>::append_ket)
     113           2 :         .def("create", &BasisAtomCreator<T>::create);
     114           2 : }
     115             : 
     116             : template <typename T>
     117           2 : static void declare_basis_pair(nb::module_ &m, std::string const &type_name) {
     118           2 :     std::string pyclass_name = "BasisPair" + type_name;
     119           2 :     nb::class_<BasisPair<T>, Basis<BasisPair<T>>> pyclass(m, pyclass_name.c_str());
     120             :     pyclass
     121           2 :         .def("get_amplitudes",
     122           2 :              nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>>(
     123             :                  &BasisPair<T>::get_amplitudes, nb::const_))
     124           2 :         .def("get_amplitudes",
     125             :              nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
     126           2 :                                std::shared_ptr<const BasisAtom<T>>>(&BasisPair<T>::get_amplitudes,
     127             :                                                                     nb::const_))
     128           2 :         .def("get_overlaps",
     129           2 :              nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>>(
     130             :                  &BasisPair<T>::get_overlaps, nb::const_))
     131           2 :         .def("get_overlaps",
     132             :              nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
     133           2 :                                std::shared_ptr<const BasisAtom<T>>>(&BasisPair<T>::get_overlaps,
     134             :                                                                     nb::const_))
     135           2 :         .def("get_matrix_elements",
     136             :              nb::overload_cast<std::shared_ptr<const BasisPair<T>>, OperatorType, OperatorType, int,
     137           2 :                                int>(&BasisPair<T>::get_matrix_elements, nb::const_))
     138           2 :         .def("get_matrix_elements",
     139             :              nb::overload_cast<std::shared_ptr<const BasisAtom<T>>,
     140             :                                std::shared_ptr<const BasisAtom<T>>, OperatorType, OperatorType, int,
     141           2 :                                int>(&BasisPair<T>::get_matrix_elements, nb::const_))
     142           2 :         .def("get_matrix_elements",
     143             :              nb::overload_cast<std::shared_ptr<const typename BasisPair<T>::ket_t>, OperatorType,
     144           2 :                                OperatorType, int, int>(&BasisPair<T>::get_matrix_elements,
     145             :                                                        nb::const_))
     146             :         .def("get_matrix_elements",
     147             :              nb::overload_cast<std::shared_ptr<const KetAtom>, std::shared_ptr<const KetAtom>,
     148           2 :                                OperatorType, OperatorType, int, int>(
     149             :                  &BasisPair<T>::get_matrix_elements, nb::const_));
     150           2 : }
     151             : 
     152             : template <typename T>
     153           2 : static void declare_basis_pair_creator(nb::module_ &m, std::string const &type_name) {
     154           2 :     std::string pyclass_name = "BasisPairCreator" + type_name;
     155           4 :     nb::class_<BasisPairCreator<T>> pyclass(m, pyclass_name.c_str());
     156           4 :     pyclass.def(nb::init<>())
     157           2 :         .def("add", &BasisPairCreator<T>::add)
     158           2 :         .def("restrict_energy", &BasisPairCreator<T>::restrict_energy)
     159           2 :         .def("restrict_quantum_number_m", &BasisPairCreator<T>::restrict_quantum_number_m)
     160           2 :         .def("restrict_product_of_parities", &BasisPairCreator<T>::restrict_product_of_parities)
     161           2 :         .def("create", &BasisPairCreator<T>::create);
     162           2 : }
     163             : 
     164           1 : void bind_basis(nb::module_ &m) {
     165           1 :     declare_basis<BasisAtom<double>>(m, "BasisAtomReal");
     166           1 :     declare_basis<BasisAtom<std::complex<double>>>(m, "BasisAtomComplex");
     167           1 :     declare_basis_atom<double>(m, "Real");
     168           1 :     declare_basis_atom<std::complex<double>>(m, "Complex");
     169           1 :     declare_basis_atom_creator<double>(m, "Real");
     170           1 :     declare_basis_atom_creator<std::complex<double>>(m, "Complex");
     171             : 
     172           1 :     declare_basis<BasisPair<double>>(m, "BasisPairReal");
     173           1 :     declare_basis<BasisPair<std::complex<double>>>(m, "BasisPairComplex");
     174           1 :     declare_basis_pair<double>(m, "Real");
     175           1 :     declare_basis_pair<std::complex<double>>(m, "Complex");
     176           1 :     declare_basis_pair_creator<double>(m, "Real");
     177           1 :     declare_basis_pair_creator<std::complex<double>>(m, "Complex");
     178           1 : }

Generated by: LCOV version 1.16