LCOV - code coverage report
Current view: top level - bindings/basis - Basis.py.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 94 94 100.0 %
Date: 2026-06-19 12:50:25 Functions: 14 17 82.4 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "./Basis.py.hpp"
       5             : 
       6             : #include "pairinteraction/basis/Basis.hpp"
       7             : #include "pairinteraction/basis/BasisAtom.hpp"
       8             : #include "pairinteraction/basis/BasisAtomCreator.hpp"
       9             : #include "pairinteraction/basis/BasisPair.hpp"
      10             : #include "pairinteraction/basis/BasisPairCreator.hpp"
      11             : #include "pairinteraction/database/Database.hpp"
      12             : #include "pairinteraction/interfaces/TransformationBuilderInterface.hpp"
      13             : #include "pairinteraction/ket/KetAtom.hpp"
      14             : #include "pairinteraction/ket/KetPair.hpp"
      15             : #include "pairinteraction/system/SystemAtom.hpp"
      16             : 
      17             : #include <nanobind/eigen/sparse.h>
      18             : #include <nanobind/nanobind.h>
      19             : #include <nanobind/stl/complex.h>
      20             : #include <nanobind/stl/set.h>
      21             : #include <nanobind/stl/shared_ptr.h>
      22             : #include <nanobind/stl/string.h>
      23             : #include <nanobind/stl/vector.h>
      24             : 
      25             : namespace nb = nanobind;
      26             : using namespace pairinteraction;
      27             : 
      28             : template <typename T>
      29           8 : static void declare_basis(nb::module_ &m, std::string const &type_name) {
      30           8 :     std::string pyclass_name = "Basis" + type_name;
      31             :     using scalar_t = typename Basis<T>::scalar_t;
      32           8 :     nb::class_<Basis<T>, TransformationBuilderInterface<scalar_t>> pyclass(m, pyclass_name.c_str());
      33           8 :     pyclass.def("get_kets", &Basis<T>::get_kets)
      34           8 :         .def("get_ket", &Basis<T>::get_ket)
      35           8 :         .def("get_state", &Basis<T>::get_state)
      36           8 :         .def("get_number_of_states", &Basis<T>::get_number_of_states)
      37           8 :         .def("get_number_of_kets", &Basis<T>::get_number_of_kets)
      38           8 :         .def("get_quantum_number_f", &Basis<T>::get_quantum_number_f)
      39           8 :         .def("get_quantum_number_m", &Basis<T>::get_quantum_number_m)
      40           8 :         .def("get_parity", &Basis<T>::get_parity)
      41           8 :         .def("get_coefficients", nb::overload_cast<>(&Basis<T>::get_coefficients, nb::const_))
      42           8 :         .def("set_coefficients", &Basis<T>::set_coefficients)
      43           8 :         .def("get_transformation", &Basis<T>::get_transformation)
      44           8 :         .def("get_rotator", &Basis<T>::get_rotator)
      45           8 :         .def("get_sorter", &Basis<T>::get_sorter)
      46           8 :         .def("get_indices_of_blocks", &Basis<T>::get_indices_of_blocks)
      47           8 :         .def("get_sorter_without_checks", &Basis<T>::get_sorter_without_checks)
      48           8 :         .def("get_indices_of_blocks_without_checks",
      49           8 :              &Basis<T>::get_indices_of_blocks_without_checks)
      50           8 :         .def(
      51             :             "transformed",
      52           8 :             nb::overload_cast<const Transformation<scalar_t> &>(&Basis<T>::transformed, nb::const_))
      53           8 :         .def("transformed", nb::overload_cast<const Sorting &>(&Basis<T>::transformed, nb::const_))
      54           8 :         .def("get_corresponding_state",
      55           8 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state, nb::const_))
      56           8 :         .def("get_corresponding_state",
      57           8 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
      58             :                  &Basis<T>::get_corresponding_state, nb::const_))
      59           8 :         .def("get_corresponding_state_index",
      60           8 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_state_index, nb::const_))
      61           8 :         .def("get_corresponding_state_index",
      62           8 :              nb::overload_cast<std::shared_ptr<const typename Basis<T>::ket_t>>(
      63             :                  &Basis<T>::get_corresponding_state_index, nb::const_))
      64           8 :         .def("get_corresponding_ket",
      65           8 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket, nb::const_))
      66           8 :         .def("get_corresponding_ket",
      67           8 :              nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket,
      68             :                                                          nb::const_))
      69           8 :         .def("get_corresponding_ket_index",
      70           8 :              nb::overload_cast<size_t>(&Basis<T>::get_corresponding_ket_index, nb::const_))
      71           8 :         .def("get_corresponding_ket_index",
      72           8 :              nb::overload_cast<std::shared_ptr<const T>>(&Basis<T>::get_corresponding_ket_index,
      73             :                                                          nb::const_))
      74          16 :         .def("canonicalized", &Basis<T>::canonicalized)
      75          11 :         .def("copy", [](const Basis<T> &self) {
      76           3 :             return std::make_shared<T>(static_cast<const T &>(self));
      77             :         });
      78           8 : }
      79             : 
      80             : template <typename T>
      81           4 : static void declare_basis_atom(nb::module_ &m, std::string const &type_name) {
      82           4 :     std::string pyclass_name = "BasisAtom" + type_name;
      83           4 :     nb::class_<BasisAtom<T>, Basis<BasisAtom<T>>> pyclass(m, pyclass_name.c_str());
      84           4 :     pyclass.def("get_matrix_elements", &BasisAtom<T>::get_matrix_elements);
      85           4 : }
      86             : 
      87             : template <typename T>
      88           4 : static void declare_basis_atom_creator(nb::module_ &m, std::string const &type_name) {
      89           4 :     std::string pyclass_name = "BasisAtomCreator" + type_name;
      90           8 :     nb::class_<BasisAtomCreator<T>> pyclass(m, pyclass_name.c_str());
      91          12 :     pyclass.def(nb::init<>())
      92           4 :         .def("set_species", &BasisAtomCreator<T>::set_species)
      93           4 :         .def("restrict_energy", &BasisAtomCreator<T>::restrict_energy)
      94           4 :         .def("restrict_quantum_number", &BasisAtomCreator<T>::restrict_quantum_number)
      95           4 :         .def("set_quantum_number_standard_deviation_factor",
      96           4 :              &BasisAtomCreator<T>::set_quantum_number_standard_deviation_factor)
      97           4 :         .def("add_ket", &BasisAtomCreator<T>::add_ket)
      98           8 :         .def("create", &BasisAtomCreator<T>::create, nb::call_guard<nb::gil_scoped_release>());
      99           4 : }
     100             : 
     101             : template <typename T>
     102           4 : static void declare_basis_pair(nb::module_ &m, std::string const &type_name) {
     103           4 :     std::string pyclass_name = "BasisPair" + type_name;
     104           4 :     nb::class_<BasisPair<T>, Basis<BasisPair<T>>> pyclass(m, pyclass_name.c_str());
     105           4 :     pyclass.def("get_matrix_elements", &BasisPair<T>::get_matrix_elements)
     106           4 :         .def("get_basis1", &BasisPair<T>::get_basis1)
     107           4 :         .def("get_basis2", &BasisPair<T>::get_basis2);
     108           4 : }
     109             : 
     110             : template <typename T>
     111           4 : static void declare_basis_pair_creator(nb::module_ &m, std::string const &type_name) {
     112           4 :     std::string pyclass_name = "BasisPairCreator" + type_name;
     113           8 :     nb::class_<BasisPairCreator<T>> pyclass(m, pyclass_name.c_str());
     114             :     pyclass
     115          12 :         .def(nb::init<>())
     116             :         // keep_alive because add() stores only a reference to the system, which must outlive the
     117             :         // creator (the system is dereferenced in create())
     118           4 :         .def("add", &BasisPairCreator<T>::add, nb::keep_alive<1, 2>())
     119           4 :         .def("restrict_energy", &BasisPairCreator<T>::restrict_energy)
     120           4 :         .def("restrict_quantum_number_m", &BasisPairCreator<T>::restrict_quantum_number_m)
     121           4 :         .def("restrict_parity_under_inversion",
     122           4 :              &BasisPairCreator<T>::restrict_parity_under_inversion)
     123           4 :         .def("restrict_parity_under_permutation",
     124           4 :              &BasisPairCreator<T>::restrict_parity_under_permutation)
     125           8 :         .def("create", &BasisPairCreator<T>::create, nb::call_guard<nb::gil_scoped_release>());
     126           4 : }
     127             : 
     128           2 : void bind_basis(nb::module_ &m) {
     129           2 :     declare_basis<BasisAtom<double>>(m, "BasisAtomReal");
     130           2 :     declare_basis<BasisAtom<std::complex<double>>>(m, "BasisAtomComplex");
     131           2 :     declare_basis_atom<double>(m, "Real");
     132           2 :     declare_basis_atom<std::complex<double>>(m, "Complex");
     133           2 :     declare_basis_atom_creator<double>(m, "Real");
     134           2 :     declare_basis_atom_creator<std::complex<double>>(m, "Complex");
     135             : 
     136           2 :     declare_basis<BasisPair<double>>(m, "BasisPairReal");
     137           2 :     declare_basis<BasisPair<std::complex<double>>>(m, "BasisPairComplex");
     138           2 :     declare_basis_pair<double>(m, "Real");
     139           2 :     declare_basis_pair<std::complex<double>>(m, "Complex");
     140           2 :     declare_basis_pair_creator<double>(m, "Real");
     141           2 :     declare_basis_pair_creator<std::complex<double>>(m, "Complex");
     142           2 : }

Generated by: LCOV version 1.16