Line data Source code
# SPDX-FileCopyrightText: 2024 PairInteraction Developers
# SPDX-License-Identifier: LGPL-3.0-or-later

import contextlib
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

import numpy as np

if TYPE_CHECKING:
    from pairinteraction.units import NDArray


# Directories holding the reference data used by the comparison helpers below.
REFERENCE_PATHS = {
    "stark_map": Path(__file__).parent.parent / "data" / "reference_stark_map",
    "pair_potential": Path(__file__).parent.parent / "data" / "reference_pair_potential",
}


def compare_eigensystem_to_reference(
    reference_path: Union[Path, str],
    eigenenergies: "NDArray",
    overlaps: Optional["NDArray"] = None,
    eigenvectors: Optional["NDArray"] = None,
    kets: Optional[list[str]] = None,
) -> None:
    """Compare a computed eigensystem against reference data stored on disk.

    Args:
        reference_path: Directory containing ``eigenenergies.txt`` and, depending on which
            optional arguments are given, ``overlaps.txt`` and ``kets.txt``.
        eigenenergies: 2D array of shape ``(n_systems, n_kets)``, compared against
            ``eigenenergies.txt``.
        overlaps: Optional 2D array with one row per system. Every row must sum to one, and the
            array is compared against ``overlaps.txt`` with ``atol=1e-8``.
        eigenvectors: Optional eigenvectors whose first axis indexes the systems. They are not
            compared against reference data; only their summed squared norm per system is checked.
        kets: Optional list of ket labels, compared exactly against the tab-delimited ``kets.txt``.

    Raises:
        AssertionError: If any of the performed checks fails.

    """
    reference_path = Path(reference_path)
    n_systems, n_kets = eigenenergies.shape

    # ndmin=2 keeps the exact row/column shape of the file. Without it, a reference file with a
    # single row (one system) is squeezed to 1D and fails assert_allclose's strict shape check.
    np.testing.assert_allclose(eigenenergies, np.loadtxt(reference_path / "eigenenergies.txt", ndmin=2))

    if overlaps is not None:
        # Ensure that the overlaps sum up to one
        np.testing.assert_allclose(np.sum(overlaps, axis=1), np.ones(n_systems))
        np.testing.assert_allclose(overlaps, np.loadtxt(reference_path / "overlaps.txt", ndmin=2), atol=1e-8)

    if kets is not None:
        np.testing.assert_equal(kets, np.loadtxt(reference_path / "kets.txt", dtype=str, delimiter="\t"))

    if eigenvectors is not None:
        # Because of degeneracies, checking the eigenvectors against reference data is complicated.
        # Thus, we only perform a sanity check: per system, the summed squared norm of all
        # eigenvectors must equal the number of eigenvectors (n_kets). Note that this does NOT verify
        # the normalization of each individual vector, nor their mutual orthogonality.
        vectors = np.asarray(eigenvectors)
        np.testing.assert_equal(len(vectors), n_systems)
        # Flatten everything but the system axis so that both (n_systems, n) and
        # (n_systems, n_kets, n_states) layouts are supported.
        cumulative_norm = np.sum(np.abs(vectors.reshape(n_systems, -1)) ** 2, axis=1)
        np.testing.assert_allclose(cumulative_norm, n_kets * np.ones(n_systems))


@contextlib.contextmanager
def no_log_propagation(logger: Union[logging.Logger, str]) -> Iterator[None]:
    """Context manager to temporarily disable log propagation for a given logger.

    Args:
        logger: The logger instance, or its name (resolved via ``logging.getLogger``).

    The previous value of ``logger.propagate`` is restored on exit, including when the body of
    the ``with`` statement raises.

    """
    if isinstance(logger, str):
        logger = logging.getLogger(logger)
    old_value = logger.propagate
    try:
        logger.propagate = False
        yield
    finally:
        logger.propagate = old_value