Line data Source code
# SPDX-FileCopyrightText: 2024 PairInteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations

import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Iterator

    from pairinteraction.units import NDArray


REFERENCE_PATHS = {
    "stark_map": Path(__file__).parent.parent / "data" / "reference_stark_map",
    "pair_potential": Path(__file__).parent.parent / "data" / "reference_pair_potential",
}


def compare_eigensystem_to_reference(
    reference_path: Path,
    eigenenergies: NDArray,
    overlaps: NDArray | None = None,
    eigenvectors: NDArray | None = None,
    kets: list[str] | None = None,
) -> None:
    """Assert that a computed eigensystem agrees with the reference data stored in ``reference_path``.

    Args:
        reference_path: Directory containing ``eigenenergies.txt`` and, depending on which optional
            arguments are given, ``overlaps.txt`` and ``kets.txt``.
        eigenenergies: Two-dimensional array of shape ``(n_systems, n_kets)``, one row per system.
            Compared against ``eigenenergies.txt`` with default ``assert_allclose`` tolerances.
        overlaps: Optional two-dimensional array with one row per system. Every row must sum to one,
            and the array must match ``overlaps.txt`` up to an absolute tolerance of ``1e-8``.
        eigenvectors: Optional eigenvectors. Only their cumulative normalization is checked (see below).
        kets: Optional ket labels, compared exactly against the tab-separated ``kets.txt``.

    Raises:
        AssertionError: If any of the provided quantities disagrees with the reference data.
        ValueError: If ``eigenenergies`` is not two-dimensional (its shape cannot be unpacked).
    """
    n_systems, n_kets = eigenenergies.shape

    # ``ndmin=2`` keeps the (rows, columns) layout of the reference file even when it has a single row or
    # column; by default ``np.loadtxt`` squeezes such files to 1-D and the shape check fails spuriously.
    np.testing.assert_allclose(eigenenergies, np.loadtxt(reference_path / "eigenenergies.txt", ndmin=2))

    if overlaps is not None:
        # Ensure that the overlaps sum up to one
        np.testing.assert_allclose(np.sum(overlaps, axis=1), np.ones(n_systems))
        np.testing.assert_allclose(overlaps, np.loadtxt(reference_path / "overlaps.txt", ndmin=2), atol=1e-8)

    if kets is not None:
        np.testing.assert_equal(kets, np.loadtxt(reference_path / "kets.txt", dtype=str, delimiter="\t"))

    if eigenvectors is not None:
        # Because of degeneracies, checking the eigenvectors against reference data is complicated.
        # Thus, we only check their cumulative normalization: for every system, the squared magnitudes
        # of all eigenvector components must add up to n_kets (as they would for n_kets unit-norm
        # vectors). Orthogonality is NOT checked.
        # NOTE(review): comparing against an (n_systems,) array implies ``eigenvectors`` is 2-D with one
        # row per system, presumably the flattened eigenvector matrix -- confirm with callers.
        vectors = np.asarray(eigenvectors)
        cumulative_norm = (vectors * vectors.conj()).sum(axis=1)
        np.testing.assert_allclose(cumulative_norm, n_kets * np.ones(n_systems))


@contextlib.contextmanager
def no_log_propagation(logger: logging.Logger | str) -> Iterator[None]:
    """Context manager to temporarily disable log propagation for a given logger.

    Args:
        logger: The logger itself, or its name (resolved via ``logging.getLogger``).

    The previous value of ``logger.propagate`` is restored on exit, also when the body raises.
    """
    if isinstance(logger, str):
        logger = logging.getLogger(logger)
    old_value = logger.propagate
    try:
        logger.propagate = False
        yield
    finally:
        logger.propagate = old_value