Line data Source code
# SPDX-FileCopyrightText: 2024 PairInteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

from __future__ import annotations

import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Protocol

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Iterator

    import pairinteraction as pi
    from pairinteraction.units import NDArray


# Directories holding the reference data, resolved relative to this file.
REFERENCE_PATHS = {
    "stark_map": Path(__file__).parent.parent / "data" / "reference_stark_map",
    "pair_potential": Path(__file__).parent.parent / "data" / "reference_pair_potential",
}


def compare_eigensystem_to_reference(
    reference_path: Path,
    eigenenergies: NDArray,
    overlaps: NDArray | None = None,
    eigenvectors: NDArray | None = None,
    kets: list[str] | None = None,
) -> None:
    """Assert that a computed eigensystem agrees with reference data stored on disk.

    Args:
        reference_path: Directory containing ``eigenenergies.txt`` and, when the
            corresponding argument is given, ``overlaps.txt`` and ``kets.txt``.
            A plain string path is accepted as well.
        eigenenergies: Two-dimensional array; its shape is unpacked as
            ``(n_systems, n_kets)``.
        overlaps: Optional array whose rows must each sum to one and which must
            match ``overlaps.txt`` within an absolute tolerance of ``1e-8``.
        eigenvectors: Optional array-like; only the sum of squared magnitudes
            along axis 1 is checked, it is not compared against stored data.
        kets: Optional ket labels, compared for equality with ``kets.txt``.

    Raises:
        AssertionError: If any of the performed comparisons fails.

    """
    reference_dir = Path(reference_path)
    n_systems, n_kets = eigenenergies.shape
    np.testing.assert_allclose(eigenenergies, np.loadtxt(reference_dir / "eigenenergies.txt"))

    if overlaps is not None:
        # Ensure that the overlaps sum up to one
        np.testing.assert_allclose(np.sum(overlaps, axis=1), np.ones(n_systems))
        np.testing.assert_allclose(overlaps, np.loadtxt(reference_dir / "overlaps.txt"), atol=1e-8)

    if kets is not None:
        np.testing.assert_equal(kets, np.loadtxt(reference_dir / "kets.txt", dtype=str, delimiter="\t"))

    if eigenvectors is not None:
        # Because of degeneracies, checking the eigenvectors against reference data is complicated.
        # Thus, only a cumulative normalization is checked: the squared magnitudes along axis 1
        # must sum to n_kets for every system.
        # NOTE(review): presumably each row is the flattened eigenvector matrix of one system,
        # so this is a necessary (not sufficient) condition for orthonormality - confirm with callers.
        vectors = np.asarray(eigenvectors)
        cumulative_norm = (vectors * vectors.conj()).sum(axis=1)
        np.testing.assert_allclose(cumulative_norm, n_kets * np.ones(n_systems))


@contextlib.contextmanager
def no_log_propagation(logger: logging.Logger | str) -> Iterator[None]:
    """Context manager to temporarily disable log propagation for a given logger.

    Args:
        logger: A logger instance, or a logger name that is resolved with
            ``logging.getLogger``.

    The previous value of ``propagate`` is restored on exit, also when the
    managed block raises.

    """
    if isinstance(logger, str):
        logger = logging.getLogger(logger)
    old_value = logger.propagate
    try:
        logger.propagate = False
        yield
    finally:
        logger.propagate = old_value


class PairinteractionModule(Protocol):
    """Structural type listing the attributes a pairinteraction-like module must provide."""

    # Classes exposed by the module.
    Database: type[pi.Database]
    KetAtom: type[pi.KetAtom]
    BasisAtom: type[pi.BasisAtom]
    SystemAtom: type[pi.SystemAtom]
    KetPair: type[pi.KetPair]
    BasisPair: type[pi.BasisPair]
    SystemPair: type[pi.SystemPair]
    GreenTensor: type[pi.GreenTensor]
    EffectiveSystemPair: type[pi.EffectiveSystemPair]
    C3: type[pi.C3]
    C6: type[pi.C6]
    # Callable taking arbitrary arguments and returning None.
    diagonalize: Callable[..., None]