Line data Source code
# SPDX-FileCopyrightText: 2024 PairInteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

from __future__ import annotations

import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Protocol

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Iterator

    import pairinteraction as pi
    import pint
    from pairinteraction.units import NDArray


_DATA_DIR = Path(__file__).parent.parent / "data"

REFERENCE_PATHS = {
    "stark_map": _DATA_DIR / "reference_stark_map",
    "pair_potential": _DATA_DIR / "reference_pair_potential",
}


def compare_eigensystem_to_reference(
    reference_path: Path,
    eigenenergies: NDArray,
    overlaps: NDArray | None = None,
    eigenvectors: NDArray | None = None,
    kets: list[str] | None = None,
) -> None:
    """Assert that computed eigensystem data agrees with reference files.

    Args:
        reference_path: Directory containing ``eigenenergies.txt`` and, depending on which
            optional arguments are given, ``overlaps.txt`` and ``kets.txt``.
        eigenenergies: 2D array of shape ``(n_systems, n_kets)``. Compared to
            ``eigenenergies.txt`` with the default tolerances of ``np.testing.assert_allclose``.
        overlaps: Optional array with one row per system. Every row must sum to one, and the
            array is compared to ``overlaps.txt`` with ``atol=1e-8``.
        eigenvectors: Optional array with one row per system. It is not compared to reference
            data; only the total squared magnitude of each row is checked (see below).
        kets: Optional list of ket strings, compared exactly to the tab-delimited ``kets.txt``.

    Raises:
        AssertionError: If any of the checks fails.
    """
    n_systems, n_kets = eigenenergies.shape

    # ndmin=2 keeps the 2D (rows, columns) layout of the reference file even when it has a single
    # row (n_systems == 1) or a single column (n_kets == 1). Without it, loadtxt squeezes the data
    # to 1D and assert_allclose fails with a shape mismatch.
    reference_energies = np.loadtxt(reference_path / "eigenenergies.txt", ndmin=2)
    np.testing.assert_allclose(eigenenergies, reference_energies)

    if overlaps is not None:
        # Ensure that the overlaps sum up to one
        np.testing.assert_allclose(np.sum(overlaps, axis=1), np.ones(n_systems))
        reference_overlaps = np.loadtxt(reference_path / "overlaps.txt", ndmin=2)
        np.testing.assert_allclose(overlaps, reference_overlaps, atol=1e-8)

    if kets is not None:
        # Tab is the only delimiter, so ket strings may contain spaces.
        reference_kets = np.loadtxt(reference_path / "kets.txt", dtype=str, delimiter="\t", ndmin=1)
        np.testing.assert_equal(kets, reference_kets)

    if eigenvectors is not None:
        # Because of degeneracies, checking the eigenvectors against reference data is complicated.
        # Instead, only check that the squared magnitudes in each row sum to n_kets. That is the
        # value expected if each row holds the coefficients of n_kets unit-normalized eigenvectors
        # (presumably flattened per system; TODO confirm against callers).
        # NOTE(review): this does not verify orthogonality, nor the normalization of each
        # individual eigenvector.
        vectors = np.asarray(eigenvectors)
        cumulative_norm = (np.abs(vectors) ** 2).sum(axis=1)
        np.testing.assert_allclose(cumulative_norm, n_kets * np.ones(n_systems))


@contextlib.contextmanager
def no_log_propagation(logger: logging.Logger | str) -> Iterator[None]:
    """Context manager to temporarily disable log propagation for a given logger.

    Args:
        logger: The logger itself, or its name (resolved via ``logging.getLogger``).

    The previous value of ``logger.propagate`` is restored on exit, also when the
    body of the ``with`` statement raises.
    """
    if isinstance(logger, str):
        logger = logging.getLogger(logger)
    old_value = logger.propagate
    try:
        logger.propagate = False
        yield
    finally:
        logger.propagate = old_value


class PairinteractionModule(Protocol):
    """Structural type listing the attributes accessed on a pairinteraction module object.

    Presumably used to annotate fixtures or parameters that hand a pairinteraction
    module to tests. TODO confirm against callers.
    """

    ureg: pint.UnitRegistry
    Database: type[pi.Database]
    KetAtom: type[pi.KetAtom]
    BasisAtom: type[pi.BasisAtom]
    SystemAtom: type[pi.SystemAtom]
    KetPair: type[pi.KetPair]
    BasisPair: type[pi.BasisPair]
    SystemPair: type[pi.SystemPair]
    EffectiveSystemPair: type[pi.EffectiveSystemPair]
    C3: type[pi.C3]
    C6: type[pi.C6]
    diagonalize: Callable[..., None]