LCOV - code coverage report
Current view: top level - src/pairinteraction/perturbative - effective_system_pair.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 299 350 85.4 %
Date: 2026-04-30 10:43:26 Functions: 34 42 81.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3           1 : from __future__ import annotations
       4             : 
       5           1 : import contextlib
       6           1 : import copy
       7           1 : import logging
       8           1 : from functools import cached_property, lru_cache
       9           1 : from typing import TYPE_CHECKING, Literal, overload
      10             : 
      11           1 : import numpy as np
      12           1 : from scipy import sparse
      13             : 
      14           1 : from pairinteraction.basis import BasisAtom, BasisAtomReal, BasisPair, BasisPairReal
      15           1 : from pairinteraction.diagonalization import diagonalize
      16           1 : from pairinteraction.perturbative.perturbation_theory import calculate_perturbative_hamiltonian
      17           1 : from pairinteraction.system import SystemAtom, SystemAtomReal, SystemPair, SystemPairReal
      18           1 : from pairinteraction.units import QuantityArray, QuantityScalar
      19             : 
      20             : if TYPE_CHECKING:
      21             :     from collections.abc import Sequence
      22             : 
      23             :     from scipy.sparse import csr_matrix
      24             :     from typing_extensions import Self
      25             : 
      26             :     from pairinteraction.ket import KetAtomTuple
      27             :     from pairinteraction.units import ArrayLike, NDArray, PintArray, PintFloat
      28             : 
      29             : 
      30           1 : logger = logging.getLogger(__name__)
      31             : 
      32           1 : BasisSystemLiteral = Literal["basis_atoms", "system_atoms", "basis_pair", "system_pair"]
      33             : 
      34             : 
      35           1 : class EffectiveSystemPair:
      36             :     """Class for creating an effective SystemPair object and calculating the effective Hamiltonian.
      37             : 
      38             :     Given a subspace spanned by tuples of `KetAtom` objects (ket_tuples),
      39             :     this class automatically generates appropriate `BasisAtom`, `SystemAtom` objects as well as a `BasisPair` and
      40             :     `SystemPair` object to calculate the effective Hamiltonian in the subspace via perturbation theory.
      41             : 
      42             :     This class also allows to set magnetic and electric fields similar to the `SystemAtom` class,
      43             :     as well as the angle and distance between the two atoms like in the `SystemPair` class.
      44             : 
      45             :     Examples:
      46             :         >>> import pairinteraction as pi
      47             :         >>> ket_atoms = {
      48             :         ...     "+": pi.KetAtom("Rb", n=59, l=0, j=0.5, m=0.5),
      49             :         ...     "0": pi.KetAtom("Rb", n=58, l=1, j=1.5, m=1.5),
      50             :         ...     "-": pi.KetAtom("Rb", n=58, l=0, j=0.5, m=0.5),
      51             :         ... }
      52             :         >>> ket_tuples = [
      53             :         ...     (ket_atoms["+"], ket_atoms["-"]),
      54             :         ...     (ket_atoms["0"], ket_atoms["0"]),
      55             :         ...     (ket_atoms["-"], ket_atoms["+"]),
      56             :         ... ]
      57             :         >>> eff_system = pi.EffectiveSystemPair(ket_tuples)
      58             :         >>> eff_system = eff_system.set_distance(10, angle_degree=45, unit="micrometer")
      59             :         >>> eff_h = eff_system.get_effective_hamiltonian(unit="MHz")
      60             :         >>> eff_h -= np.eye(3) * eff_system.get_pair_energies("MHz")[1]
      61             :         >>> print(np.round(eff_h, 0), "MHz")
      62             :         [[292.   3.   0.]
      63             :          [  3.   0.   3.]
      64             :          [  0.   3. 292.]] MHz
      65             : 
      66             :     """
      67             : 
      68           1 :     _basis_atom_class = BasisAtom
      69           1 :     _basis_pair_class = BasisPair
      70           1 :     _system_atom_class = SystemAtom
      71           1 :     _system_pair_class = SystemPair
      72             : 
      73           1 :     def __init__(self, ket_tuples: Sequence[KetAtomTuple]) -> None:
      74           1 :         if not all(len(ket_tuple) == 2 for ket_tuple in ket_tuples):
      75           0 :             raise ValueError("All ket tuples must contain exactly two kets")
      76           1 :         for i in range(2):
      77           1 :             if not all(ket_tuple[i].species == ket_tuples[0][i].species for ket_tuple in ket_tuples):
      78           0 :                 raise ValueError(f"All kets for atom={i} must have the same species")
      79             : 
      80             :         # Perturbation attributes
      81           1 :         self._ket_tuples = [tuple(kets) for kets in ket_tuples]
      82           1 :         self._perturbation_order = 2
      83             : 
      84             :         # BasisAtom and SystemAtom attributes
      85           1 :         self._delta_n: int | None = None
      86           1 :         self._delta_l: int | None = None
      87           1 :         self._delta_m: int | None = None
      88           1 :         self._electric_field: PintArray | None = None
      89           1 :         self._magnetic_field: PintArray | None = None
      90           1 :         self._diamagnetism_enabled: bool | None = None
      91             : 
      92             :         # BasisPair and SystemPair attributes
      93           1 :         self._minimum_number_of_ket_pairs: int | None = None
      94           1 :         self._maximum_number_of_ket_pairs: int | None = None
      95           1 :         self._interaction_order: int | None = None
      96           1 :         self._distance_vector: PintArray | None = None
      97             : 
      98             :         # misc
      99           1 :         self._eff_h_dict_au: dict[int, NDArray] | None = None
     100           1 :         self._eff_vecs: csr_matrix | None = None
     101             : 
     102             :         # misc user set stuff
     103           1 :         self._user_set_parts: set[BasisSystemLiteral] = set()
     104             : 
     105           1 :     def copy(self: Self) -> Self:
     106             :         """Create a copy of the EffectiveSystemPair object (before it has been created)."""
     107           0 :         if self._is_created("basis_atoms"):
     108           0 :             raise RuntimeError(
     109             :                 "Cannot copy the EffectiveSystemPair object after it has been created. "
     110             :                 "Please create a new object instead."
     111             :             )
     112           0 :         return copy.copy(self)
     113             : 
     114           1 :     def _is_created(self: Self, what: BasisSystemLiteral = "basis_atoms") -> bool:
     115             :         """Check if some part of the effective Hamiltonian has already been created."""
     116           1 :         return hasattr(self, "_" + what)
     117             : 
     118           1 :     def _ensure_not_created(self: Self, what: BasisSystemLiteral = "basis_atoms") -> None:
     119             :         """Ensure that some part of the effective Hamiltonian has not been created yet."""
     120           1 :         if self._is_created(what):
     121           0 :             raise RuntimeError(
     122             :                 f"Cannot change parameters for {what} after it has already been created. "
     123             :                 f"Please set all parameters before {what} before accessing it (or creating the effective Hamiltonian)."
     124             :             )
     125             : 
     126           1 :     def _delete_created(self: Self, what: BasisSystemLiteral = "basis_atoms") -> None:
     127             :         """Delete the created part of the effective Hamiltonian.
     128             : 
     129             :         Args:
     130             :             what: The part of the effective Hamiltonian to delete.
     131             :                 Default is "basis_atoms", which means delete all parts that have been created.
     132             : 
     133             :         """
     134           1 :         self._eff_h_dict_au = None
     135           1 :         self._eff_vecs = None
     136           1 :         self._eff_basis = None
     137           1 :         with contextlib.suppress(AttributeError):
     138           1 :             del self.model_inds
     139             : 
     140           1 :         parts_order: list[BasisSystemLiteral] = ["system_pair", "basis_pair", "system_atoms", "basis_atoms"]
     141           1 :         for part in parts_order:
     142           1 :             if part in self._user_set_parts:
     143           0 :                 raise RuntimeError(
     144             :                     f"Cannot delete {part} because it has been set by the user. "
     145             :                     "Please create a new EffectiveSystemPair object instead."
     146             :                 )
     147           1 :             with contextlib.suppress(AttributeError):
     148           1 :                 delattr(self, "_" + part)
     149           1 :             if part == what:
     150           1 :                 break
     151             : 
     152             :     # # # Perturbation methods and attributes # # #
     153           1 :     @property
     154           1 :     def ket_tuples(self) -> list[KetAtomTuple]:
     155             :         """The tuples of kets, which form the model space for the effective Hamiltonian."""
     156           1 :         return self._ket_tuples  # type: ignore [return-value]
     157             : 
     158           1 :     @property
     159           1 :     def perturbation_order(self) -> int:
     160             :         """The perturbation order for the effective Hamiltonian."""
     161           1 :         return self._perturbation_order
     162             : 
     163           1 :     def set_perturbation_order(self: Self, order: int) -> Self:
     164             :         """Set the perturbation order for the effective Hamiltonian."""
     165           1 :         self._delete_created()
     166           1 :         self._perturbation_order = order
     167           1 :         return self
     168             : 
     169             :     # # # BasisAtom methods and attributes # # #
     170           1 :     @property
     171           1 :     def basis_atoms(self) -> tuple[BasisAtom, BasisAtom]:
     172             :         """The basis objects for the single-atom systems."""
     173           1 :         if not self._is_created("basis_atoms"):
     174           1 :             self._create_basis_atoms()
     175           1 :         return self._basis_atoms  # type: ignore [return-value]
     176             : 
     177           1 :     @basis_atoms.setter
     178           1 :     def basis_atoms(self, basis_atoms: tuple[BasisAtom, BasisAtom]) -> None:
     179           1 :         self._ensure_not_created()
     180           1 :         if self._delta_n is not None or self._delta_l is not None or self._delta_m is not None:
     181           0 :             logger.warning("Setting basis_atoms will overwrite parameters defined for basis_atoms.")
     182           1 :         self._user_set_parts.add("basis_atoms")
     183           1 :         self._basis_atoms = tuple(basis_atoms)
     184             : 
     185           1 :     def set_delta_n(self: Self, delta_n: int) -> Self:
     186             :         """Set the delta_n value for single-atom basis."""
     187           0 :         self._delete_created()
     188           0 :         self._delta_n = delta_n
     189           0 :         return self
     190             : 
     191           1 :     def set_delta_l(self: Self, delta_l: int) -> Self:
     192             :         """Set the delta_l value for single-atom basis."""
     193           0 :         self._delete_created()
     194           0 :         self._delta_l = delta_l
     195           0 :         return self
     196             : 
     197           1 :     def set_delta_m(self: Self, delta_m: int) -> Self:
     198             :         """Set the delta_m value for single-atom basis."""
     199           0 :         self._delete_created()
     200           0 :         self._delta_m = delta_m
     201           0 :         return self
     202             : 
     203           1 :     def _create_basis_atoms(self) -> None:
     204           1 :         delta_n = self._delta_n if self._delta_n is not None else 7
     205           1 :         delta_l = self._delta_l
     206           1 :         if delta_l is None:
     207           1 :             delta_l = self.perturbation_order * (self.interaction_order - 2)
     208           1 :         delta_m = self._delta_m
     209           1 :         if delta_m is None and self._delta_l is None and self._are_fields_along_z:
     210           1 :             delta_m = self.perturbation_order * (self.interaction_order - 2)
     211             : 
     212           1 :         basis_atoms: list[BasisAtom] = []
     213           1 :         use_real = isinstance(self, EffectiveSystemPairReal)
     214           1 :         for i in range(2):
     215           1 :             kets = [ket_tuple[i] for ket_tuple in self.ket_tuples]
     216           1 :             nlfm = np.transpose([[ket.n, ket.l, ket.f, ket.m] for ket in kets])
     217           1 :             n_range = (int(np.min(nlfm[0])) - delta_n, int(np.max(nlfm[0])) + delta_n)
     218           1 :             l_range = (np.min(nlfm[1]) - delta_l, np.max(nlfm[1]) + delta_l)
     219           1 :             if any(ket.is_calculated_with_mqdt for ket in kets) and self._delta_l is None:
     220             :                 # for mqdt we increase the default delta_l by 1 to take into account the variance ...
     221           0 :                 l_range = (np.min(nlfm[1]) - delta_l - 1, np.max(nlfm[1]) + delta_l + 1)
     222           1 :             m_range = (np.min(nlfm[3]) - delta_m, np.max(nlfm[3]) + delta_m) if delta_m is not None else None
     223           1 :             basis = get_basis_atom_with_cache(kets[0].species, n_range, l_range, m_range, use_real=use_real)
     224           1 :             basis_atoms.append(basis)
     225             : 
     226           1 :         self._basis_atoms = tuple(basis_atoms)
     227             : 
     228             :     # # # SystemAtom methods and attributes # # #
     229           1 :     @property
     230           1 :     def system_atoms(self) -> tuple[SystemAtom, SystemAtom]:
     231             :         """The system objects for the single-atom systems."""
     232           1 :         if not self._is_created("system_atoms"):
     233           1 :             self._create_system_atoms()
     234           1 :         return self._system_atoms
     235             : 
     236           1 :     @system_atoms.setter
     237           1 :     def system_atoms(self, system_atoms: tuple[SystemAtom, SystemAtom]) -> None:
     238           1 :         self._ensure_not_created()
     239           1 :         if (
     240             :             self._electric_field is not None
     241             :             or self._magnetic_field is not None
     242             :             or self._diamagnetism_enabled is not None
     243             :         ):
     244           0 :             logger.warning("Setting system_atoms will overwrite parameters defined for system_atoms.")
     245           1 :         self._user_set_parts.add("system_atoms")
     246           1 :         self._system_atoms: tuple[SystemAtom, SystemAtom] = tuple(system_atoms)  # type: ignore [assignment]
     247           1 :         self.basis_atoms = tuple(system.basis for system in system_atoms)  # type: ignore [assignment]
     248             : 
     249           1 :     @property
     250           1 :     def electric_field(self) -> PintArray:
     251             :         """The electric field for the single-atom systems."""
     252           1 :         if self._electric_field is None:
     253           1 :             self.set_electric_field([0, 0, 0], "V/cm")
     254           1 :             assert self._electric_field is not None
     255           1 :         return self._electric_field
     256             : 
     257           1 :     def set_electric_field(
     258             :         self: Self,
     259             :         electric_field: PintArray | ArrayLike,
     260             :         unit: str | None = None,
     261             :     ) -> Self:
     262             :         """Set the electric field for the single-atom systems.
     263             : 
     264             :         Args:
     265             :             electric_field: The electric field to set for the systems.
     266             :             unit: The unit of the electric field, e.g. "V/cm".
     267             :                 Default None expects a `pint.Quantity`.
     268             : 
     269             :         """
     270           1 :         self._delete_created()
     271           1 :         self._electric_field = QuantityArray.convert_user_to_pint(electric_field, unit, "electric_field")
     272           1 :         return self
     273             : 
     274           1 :     @property
     275           1 :     def magnetic_field(self) -> PintArray:
     276             :         """The magnetic field for the single-atom systems."""
     277           1 :         if self._magnetic_field is None:
     278           1 :             self.set_magnetic_field([0, 0, 0], "gauss")
     279           1 :             assert self._magnetic_field is not None
     280           1 :         return self._magnetic_field
     281             : 
     282           1 :     def set_magnetic_field(
     283             :         self: Self,
     284             :         magnetic_field: PintArray | ArrayLike,
     285             :         unit: str | None = None,
     286             :     ) -> Self:
     287             :         """Set the magnetic field for the single-atom systems.
     288             : 
     289             :         Args:
     290             :             magnetic_field: The magnetic field to set for the systems.
     291             :             unit: The unit of the magnetic field, e.g. "gauss".
     292             :                 Default None expects a `pint.Quantity`.
     293             : 
     294             :         """
     295           1 :         self._delete_created()
     296           1 :         self._magnetic_field = QuantityArray.convert_user_to_pint(magnetic_field, unit, "magnetic_field")
     297           1 :         return self
     298             : 
     299           1 :     @property
     300           1 :     def _are_fields_along_z(self) -> bool:
     301           1 :         return all(x == 0 for x in [*self.magnetic_field[:2], *self.electric_field[:2]])  # type: ignore [index]
     302             : 
     303           1 :     @property
     304           1 :     def diamagnetism_enabled(self) -> bool:
     305             :         """Whether diamagnetism is enabled for the single-atom systems."""
     306           1 :         if self._diamagnetism_enabled is None:
     307           1 :             self.set_diamagnetism_enabled(False)
     308           1 :             assert self._diamagnetism_enabled is not None
     309           1 :         return self._diamagnetism_enabled
     310             : 
     311           1 :     def set_diamagnetism_enabled(self: Self, enable: bool = True) -> Self:
     312             :         """Enable or disable diamagnetism for the system.
     313             : 
     314             :         Args:
     315             :             enable: Whether to enable or disable diamagnetism.
     316             : 
     317             :         """
     318           1 :         self._delete_created("system_atoms")
     319           1 :         self._diamagnetism_enabled = enable
     320           1 :         return self
     321             : 
     322           1 :     def _create_system_atoms(self) -> None:
     323           1 :         system_atoms: list[SystemAtom] = []
     324           1 :         for basis_atom in self.basis_atoms:
     325           1 :             system = self._system_atom_class(basis_atom)
     326           1 :             system.set_diamagnetism_enabled(self.diamagnetism_enabled)
     327           1 :             system.set_electric_field(self.electric_field)
     328           1 :             system.set_magnetic_field(self.magnetic_field)
     329           1 :             system_atoms.append(system)
     330           1 :         diagonalize(system_atoms)
     331             : 
     332           1 :         self._system_atoms = tuple(system_atoms)  # type: ignore [assignment]
     333             : 
     334             :     @overload
     335             :     def get_pair_energies(self, unit: None = None) -> list[PintFloat]: ...
     336             : 
     337             :     @overload
     338             :     def get_pair_energies(self, unit: str) -> list[float]: ...
     339             : 
     340           1 :     def get_pair_energies(self, unit: str | None = None) -> list[float] | list[PintFloat]:
     341             :         """Get the pair energies of the ket tuples for infinite distance (i.e. no interaction).
     342             : 
     343             :         Args:
     344             :             unit: The unit to which to convert the energies to.
     345             :                 Default None will return a list of `pint.Quantity`.
     346             : 
     347             :         Returns:
     348             :             The energies as list of float if a unit was given, otherwise as list of `pint.Quantity`.
     349             : 
     350             :         """
     351           1 :         return [  # type: ignore [return-value]
     352             :             sum(
     353             :                 system.get_corresponding_energy(ket, unit=unit)
     354             :                 for system, ket in zip(self.system_atoms, ket_tuple, strict=True)
     355             :             )
     356             :             for ket_tuple in self.ket_tuples
     357             :         ]
     358             : 
     359             :     # # # BasisPair methods and attributes # # #
     360           1 :     @property
     361           1 :     def basis_pair(self) -> BasisPair:
     362             :         """The basis pair object for the pair system."""
     363           1 :         if not self._is_created("basis_pair"):
     364           1 :             self._create_basis_pair()
     365           1 :         return self._basis_pair
     366             : 
     367           1 :     @basis_pair.setter
     368           1 :     def basis_pair(self, basis_pair: BasisPair) -> None:
     369           1 :         self._ensure_not_created()
     370           1 :         if self._minimum_number_of_ket_pairs is not None or self._maximum_number_of_ket_pairs is not None:
     371           0 :             logger.warning("Setting basis_pair will overwrite parameters defined for basis_pair.")
     372           1 :         self._user_set_parts.add("basis_pair")
     373           1 :         self._basis_pair = basis_pair
     374           1 :         self.system_atoms = basis_pair.system_atoms
     375             : 
     376           1 :     def set_minimum_number_of_ket_pairs(self: Self, number_of_kets: int) -> Self:
     377             :         """Set the minimum number of ket pairs in the basis pair.
     378             : 
     379             :         Args:
     380             :             number_of_kets: The minimum number of ket pairs to set in the basis pair, by default we use 2000.
     381             : 
     382             :         """
     383           1 :         self._delete_created("basis_pair")
     384           1 :         if self._maximum_number_of_ket_pairs is not None and number_of_kets > self._maximum_number_of_ket_pairs:
     385           0 :             raise ValueError("The minimum number of ket pairs cannot be larger than the maximum number of ket pairs.")
     386           1 :         self._minimum_number_of_ket_pairs = number_of_kets
     387           1 :         return self
     388             : 
     389           1 :     def set_maximum_number_of_ket_pairs(self: Self, number_of_kets: int) -> Self:
     390             :         """Set the maximum number of ket pairs in the basis pair.
     391             : 
     392             :         Args:
     393             :             number_of_kets: The maximum number of ket pairs to set in the basis pair.
     394             : 
     395             :         """
     396           0 :         self._delete_created("basis_pair")
     397           0 :         if self._minimum_number_of_ket_pairs is not None and number_of_kets < self._minimum_number_of_ket_pairs:
     398           0 :             raise ValueError("The maximum number of ket pairs cannot be smaller than the minimum number of ket pairs.")
     399           0 :         self._maximum_number_of_ket_pairs = number_of_kets
     400           0 :         return self
     401             : 
     402           1 :     def _create_basis_pair(self) -> None:
     403           1 :         max_number_of_kets: float | None = self._maximum_number_of_ket_pairs
     404           1 :         if max_number_of_kets is None:
     405           1 :             max_number_of_kets = np.inf
     406           1 :         min_number_of_kets: float | None = self._minimum_number_of_ket_pairs
     407           1 :         if min_number_of_kets is None:
     408           1 :             min_number_of_kets = min(2_000, max_number_of_kets)
     409             : 
     410           1 :         pair_energies_au = self.get_pair_energies(unit="hartree")
     411           1 :         min_energy_au = min(pair_energies_au)
     412           1 :         max_energy_au = max(pair_energies_au)
     413             : 
     414           1 :         mhz_au = QuantityScalar.convert_user_to_au(1, "MHz", "energy")
     415           1 :         delta_energy_au = 100 * mhz_au
     416           1 :         min_delta, max_delta = 0.1 * mhz_au, 1.0
     417             : 
     418             :         # make a bisect search to get a sensible basis size between:
     419             :         # min_number_of_kets and max_number_of_kets
     420           1 :         while True:
     421           1 :             basis_pair = self._basis_pair_class(
     422             :                 self.system_atoms,
     423             :                 energy=(min_energy_au - delta_energy_au, max_energy_au + delta_energy_au),
     424             :                 energy_unit="hartree",
     425             :             )
     426             : 
     427           1 :             if max_delta - min_delta < mhz_au:
     428           0 :                 break  # stop condition if delta_energy_au does not change anymore
     429             : 
     430           1 :             if basis_pair.number_of_kets < min_number_of_kets:
     431           1 :                 min_delta = delta_energy_au
     432           1 :                 delta_energy_au = min(2 * delta_energy_au, (delta_energy_au + max_delta) / 2)
     433           1 :             elif basis_pair.number_of_kets > max_number_of_kets:
     434           0 :                 max_delta = delta_energy_au
     435           0 :                 min_delta = max(0.5 * delta_energy_au, (delta_energy_au + min_delta) / 2)
     436             :             else:
     437           1 :                 break
     438             : 
     439           1 :         self._basis_pair = basis_pair
     440           1 :         logger.debug("The pair basis for the perturbative calculations consists of %d kets.", basis_pair.number_of_kets)
     441             : 
     442             :     # # # SystemPair methods and attributes # # #
     443           1 :     @property
     444           1 :     def system_pair(self) -> SystemPair:
     445             :         """The system pair object for the pair system."""
     446           1 :         if not self._is_created("system_pair"):
     447           1 :             self._create_system_pair()
     448           1 :         return self._system_pair
     449             : 
     450           1 :     @system_pair.setter
     451           1 :     def system_pair(self, system_pair: SystemPair) -> None:
     452           1 :         self._ensure_not_created()
     453           1 :         if self._interaction_order is not None or self._distance_vector is not None:
     454           0 :             logger.warning("Setting system_pair will overwrite parameters defined for system_pair.")
     455           1 :         self._user_set_parts.add("system_pair")
     456           1 :         self._system_pair = system_pair
     457           1 :         self.basis_pair = system_pair.basis
     458             : 
     459           1 :     @property
     460           1 :     def interaction_order(self) -> int:
     461             :         """The interaction order for the pair system."""
     462           1 :         if self._interaction_order is None:
     463           1 :             self.set_interaction_order(3)
     464           1 :         return self._interaction_order  # type: ignore [return-value]
     465             : 
     466           1 :     def set_interaction_order(self: Self, order: int) -> Self:
     467             :         """Set the interaction order of the pair system.
     468             : 
     469             :         Args:
     470             :             order: The interaction order to set for the pair system.
     471             :                 The order must be 3, 4, or 5.
     472             : 
     473             :         """
     474           1 :         self._delete_created()
     475           1 :         self._interaction_order = order
     476           1 :         return self
     477             : 
     478           1 :     @property
     479           1 :     def distance_vector(self) -> PintArray:
     480             :         """The distance vector between the atoms in the pair system."""
     481           1 :         if self._distance_vector is None:
     482           0 :             self.set_distance_vector([0, 0, np.inf], "micrometer")
     483           1 :         return self._distance_vector  # type: ignore [return-value]
     484             : 
     485           1 :     def set_distance(
     486             :         self: Self,
     487             :         distance: float | PintFloat,
     488             :         angle_degree: float = 0,
     489             :         unit: str | None = None,
     490             :     ) -> Self:
     491             :         """Set the distance between the atoms using the specified distance and angle.
     492             : 
     493             :         Args:
     494             :             distance: The distance to set between the atoms in the given unit.
     495             :             angle_degree: The angle between the distance vector and the z-axis in degrees.
     496             :                 90 degrees corresponds to the x-axis.
     497             :                 Defaults to 0, which corresponds to the z-axis.
     498             :             unit: The unit of the distance, e.g. "micrometer".
     499             :                 Default None expects a `pint.Quantity`.
     500             : 
     501             :         """
     502           1 :         distance_vector = [np.sin(np.deg2rad(angle_degree)) * distance, 0, np.cos(np.deg2rad(angle_degree)) * distance]
     503           1 :         return self.set_distance_vector(distance_vector, unit)
     504             : 
     505           1 :     def set_distance_vector(
     506             :         self: Self,
     507             :         distance: ArrayLike | PintArray,
     508             :         unit: str | None = None,
     509             :     ) -> Self:
     510             :         """Set the distance vector between the atoms.
     511             : 
     512             :         Args:
     513             :             distance: The distance vector to set between the atoms in the given unit.
     514             :             unit: The unit of the distance, e.g. "micrometer".
     515             :                 Default None expects a `pint.Quantity`.
     516             : 
     517             :         """
     518           1 :         self._delete_created("system_pair")
     519           1 :         self._distance_vector = QuantityArray.convert_user_to_pint(distance, unit, "distance")
     520           1 :         return self
     521             : 
     522           1 :     def set_angle(
     523             :         self: Self,
     524             :         angle: float = 0,
     525             :         unit: Literal["degree", "radian"] = "degree",
     526             :     ) -> Self:
     527             :         """Set the angle between the atoms in degrees.
     528             : 
     529             :         Args:
     530             :             angle: The angle between the distance vector and the z-axis (by default in degrees).
     531             :                 90 degrees corresponds to the x-axis.
     532             :                 Defaults to 0, which corresponds to the z-axis.
     533             :             unit: The unit of the angle, either "degree" or "radian", by default "degree".
     534             : 
     535             :         """
     536           0 :         assert unit in ("radian", "degree"), f"Unit {unit} is not supported for angle."
     537           0 :         if unit == "radian":
     538           0 :             angle = np.rad2deg(angle)
     539           0 :         distance_mum: float = np.linalg.norm(self.distance_vector.to("micrometer").magnitude)  # type: ignore [assignment]
     540           0 :         return self.set_distance(distance_mum, angle, "micrometer")
     541             : 
     542           1 :     def _create_system_pair(self) -> None:
     543           1 :         system_pair = self._system_pair_class(self.basis_pair)
     544           1 :         system_pair.set_distance_vector(self.distance_vector)
     545           1 :         system_pair.set_interaction_order(self.interaction_order)
     546           1 :         self._system_pair = system_pair
     547             : 
     548             :     # # # Effective Hamiltonian methods and attributes # # #
     549             :     @overload
     550             :     def get_effective_hamiltonian(self, return_order: int | None = None, unit: None = None) -> PintArray: ...
     551             : 
     552             :     @overload
     553             :     def get_effective_hamiltonian(self, return_order: int | None = None, *, unit: str) -> NDArray: ...
     554             : 
     555           1 :     def get_effective_hamiltonian(
     556             :         self, return_order: int | None = None, unit: str | None = None
     557             :     ) -> NDArray | PintArray:
     558             :         """Get the effective Hamiltonian of the pair system.
     559             : 
     560             :         Args:
     561             :             return_order: The order of the perturbation to return.
     562             :                 Default None, returns the sum up to the perturbation order set in the class.
     563             :             unit: The unit in which to return the effective Hamiltonian.
     564             :                 If None, returns a pint array.
     565             : 
     566             :         Returns:
     567             :             The effective Hamiltonian of the pair system in the given unit.
     568             :             If unit is None, returns a pint array, otherwise returns a numpy array.
     569             : 
     570             :         """
     571           1 :         if self._eff_h_dict_au is None:
     572           1 :             self._create_effective_hamiltonian()
     573           1 :             assert self._eff_h_dict_au is not None
     574           1 :         if return_order is None:
     575           1 :             h_eff_au: NDArray = sum(self._eff_h_dict_au.values())  # type: ignore [assignment]
     576           1 :         elif return_order in self._eff_h_dict_au:
     577           1 :             h_eff_au = self._eff_h_dict_au[return_order]
     578             :         else:
     579           0 :             raise ValueError(
     580             :                 f"The perturbation order {return_order} is not available in the effective Hamiltonian "
     581             :                 f"with the specified perturbation_order {self.perturbation_order}."
     582             :             )
     583           1 :         return QuantityArray.convert_au_to_user(np.real_if_close(h_eff_au), "energy", unit)
     584             : 
     585           1 :     def get_effective_basisvectors(self) -> csr_matrix:
     586             :         """Get the eigenvectors of the perturbative Hamiltonian."""
     587           0 :         if len(self.model_inds) > 1 and self.perturbation_order > 2:
     588           0 :             logger.warning("For more than one state and perturbation_order > 2 the effective basis might be wrong.")
     589           0 :         if self._eff_vecs is None:
     590           0 :             self._create_effective_hamiltonian()
     591           0 :             assert self._eff_vecs is not None
     592           0 :         return self._eff_vecs
     593             : 
     594           1 :     def get_effective_basis(self) -> BasisPair:
     595             :         """Get the effective basis of the pair system."""
     596           0 :         raise NotImplementedError("The get effective basis method is not implemented yet.")
     597             : 
     598           1 :     def _create_effective_hamiltonian(self) -> None:
     599             :         """Calculate the perturbative Hamiltonian up to the given perturbation order."""
     600           1 :         hamiltonian_au = self.system_pair.get_hamiltonian(unit="hartree")
     601           1 :         eff_h_dict_au, eff_vecs = calculate_perturbative_hamiltonian(
     602             :             hamiltonian_au, self.model_inds, self.perturbation_order
     603             :         )
     604           1 :         self._eff_h_dict_au = eff_h_dict_au
     605           1 :         self._eff_vecs = eff_vecs
     606             : 
     607           1 :         self.check_for_resonances()
     608             : 
     609             :     # # # Other stuff # # #
     610           1 :     @cached_property
     611           1 :     def model_inds(self) -> list[int]:
     612             :         """The indices of the corresponding KetPairs of the given ket_tuples in the basis of the pair system."""
     613           1 :         model_inds = []
     614           1 :         for kets in self.ket_tuples:
     615           1 :             overlap = self.basis_pair.get_overlaps(kets)
     616           1 :             inds = np.argsort(overlap)[::-1]
     617           1 :             model_inds.append(int(inds[0]))
     618             : 
     619           1 :             if overlap[inds[0]] > 0.6:
     620           1 :                 continue
     621           0 :             if overlap[inds[0]] == 0:
     622           0 :                 raise ValueError(f"The pairstate {kets} is not part of the basis of the pair system.")
     623           0 :             logger.critical(
     624             :                 "The ket_pair %s only has an overlap of %.3f with its corresponding pair_state."
     625             :                 " The most perturbing states are:",
     626             :                 kets,
     627             :                 overlap[inds[0]],
     628             :             )
     629           0 :             for i in inds[1:5]:
     630           0 :                 logger.error("  - %s with overlap %.3e", self.system_pair.basis.kets[i], overlap[i])
     631             : 
     632           1 :         return model_inds
     633             : 
     634           1 :     def check_for_resonances(self, required_overlap: float = 0.95) -> None:
     635             :         r"""Check if states of the model space have strong resonances with states outside the model space."""
     636             :         # Get the effective eigenvectors without potential warning
     637           1 :         if self._eff_vecs is None:
     638           0 :             self._create_effective_hamiltonian()
     639           0 :             assert self._eff_vecs is not None
     640           1 :         eff_vecs = self._eff_vecs
     641             : 
     642           1 :         overlaps = (eff_vecs.multiply(eff_vecs.conj())).real  # elementwise multiplication
     643             : 
     644           1 :         model_inds = self.model_inds
     645           1 :         for i, m_ind in enumerate(model_inds):
     646           1 :             overlaps_i = overlaps[i, :]
     647             : 
     648           1 :             inf_data_inds = np.isinf(overlaps_i.data)
     649           1 :             if inf_data_inds.any():
     650           1 :                 indices = overlaps_i.indices[np.argwhere(inf_data_inds).flatten()]
     651           1 :                 logger.critical(
     652             :                     "Detected 'inf' entries in the effective eigenvectors.\n"
     653             :                     " This might happen, if you forgot to include a degenerate state in the model space.\n"
     654             :                     " Consider adding the following states to the model space:"
     655             :                 )
     656           1 :                 for index in indices:
     657           1 :                     logger.critical("  - %s has infinite admixture", self.system_pair.basis.kets[index])
     658           1 :                 continue
     659             : 
     660           1 :             overlaps_i /= np.sum(overlaps_i.data)  # normalize the overlaps to 1
     661           1 :             if overlaps_i[0, m_ind] >= required_overlap:
     662           1 :                 continue
     663           1 :             logger.error(
     664             :                 "The ket %s has only %.3f overlap with its corresponding effective eigenvector.\n"
     665             :                 " Thus, the calculation might lead to unexpected or wrong results.\n"
     666             :                 " Consider adding the most perturbing states to the model space.\n"
     667             :                 " The most perturbing states are:",
     668             :                 self.system_pair.basis.kets[m_ind],
     669             :                 overlaps_i[0, m_ind],
     670             :             )
     671           1 :             print_above_admixture = (1 - overlaps_i[0, m_ind]) * 0.05
     672           1 :             indices = list(sparse.find(overlaps_i >= print_above_admixture)[1])
     673           1 :             indices = sorted(indices, key=lambda index, ov=overlaps_i: ov[0, index], reverse=True)  # type: ignore [misc]
     674           1 :             for index in indices:
     675           1 :                 if index != m_ind:
     676           1 :                     admixture = overlaps_i[0, index]
     677           1 :                     logger.error("  - %s with overlap %.3e", self.system_pair.basis.kets[index], admixture)
     678             : 
     679             : 
     680           1 : class EffectiveSystemPairReal(EffectiveSystemPair):
     681           1 :     _basis_atom_class = BasisAtomReal
     682           1 :     _basis_pair_class = BasisPairReal
     683           1 :     _system_atom_class = SystemAtomReal
     684           1 :     _system_pair_class = SystemPairReal
     685             : 
     686             : 
     687           1 : @lru_cache(maxsize=20)
     688           1 : def get_basis_atom_with_cache(
     689             :     species: str, n: tuple[int, int], l: tuple[int, int], m: tuple[int, int], *, use_real: bool
     690             : ) -> BasisAtom:
     691             :     """Get a BasisAtom object potentially by using a cache to avoid recomputing it."""
     692           1 :     if use_real:
     693           1 :         return BasisAtomReal(species, n=n, l=l, m=m)
     694           1 :     return BasisAtom(species, n=n, l=l, m=m)

Generated by: LCOV version 1.16