LCOV - code coverage report
Current view: top level - tests - utils.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 40 40 100.0 %
Date: 2026-03-03 11:15:30 Functions: 2 2 100.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : from __future__ import annotations
       5             : 
       6           1 : import contextlib
       7           1 : import logging
       8           1 : from pathlib import Path
       9           1 : from typing import TYPE_CHECKING, Callable, Protocol
      10             : 
      11           1 : import numpy as np
      12             : 
      13             : if TYPE_CHECKING:
      14             :     from collections.abc import Iterator
      15             : 
      16             :     import pairinteraction as pi
      17             :     import pint
      18             :     from pairinteraction.units import NDArray
      19             : 
      20             : 
      21           1 : REFERENCE_PATHS = {
      22             :     "stark_map": Path(__file__).parent.parent / "data" / "reference_stark_map",
      23             :     "pair_potential": Path(__file__).parent.parent / "data" / "reference_pair_potential",
      24             : }
      25             : 
      26             : 
      27           1 : def compare_eigensystem_to_reference(
      28             :     reference_path: Path,
      29             :     eigenenergies: NDArray,
      30             :     overlaps: NDArray | None = None,
      31             :     eigenvectors: NDArray | None = None,
      32             :     kets: list[str] | None = None,
      33             : ) -> None:
      34           1 :     n_systems, n_kets = eigenenergies.shape
      35           1 :     np.testing.assert_allclose(eigenenergies, np.loadtxt(reference_path / "eigenenergies.txt"))
      36             : 
      37           1 :     if overlaps is not None:
      38             :         # Ensure that the overlaps sum up to one
      39           1 :         np.testing.assert_allclose(np.sum(overlaps, axis=1), np.ones(n_systems))
      40           1 :         np.testing.assert_allclose(overlaps, np.loadtxt(reference_path / "overlaps.txt"), atol=1e-8)
      41             : 
      42           1 :     if kets is not None:
      43           1 :         np.testing.assert_equal(kets, np.loadtxt(reference_path / "kets.txt", dtype=str, delimiter="\t"))
      44             : 
      45           1 :     if eigenvectors is not None:
      46             :         # Because of degeneracies, checking the eigenvectors against reference data is complicated.
      47             :         # Thus, we only check their normalization and orthogonality.
      48           1 :         cumulative_norm = (np.array(eigenvectors) * np.array(eigenvectors).conj()).sum(axis=1)
      49           1 :         np.testing.assert_allclose(cumulative_norm, n_kets * np.ones(n_systems))
      50             : 
      51             : 
      52           1 : @contextlib.contextmanager
      53           1 : def no_log_propagation(logger: logging.Logger | str) -> Iterator[None]:
      54             :     """Context manager to temporarily disable log propagation for a given logger."""
      55           1 :     if isinstance(logger, str):
      56           1 :         logger = logging.getLogger(logger)
      57           1 :     old_value = logger.propagate
      58           1 :     try:
      59           1 :         logger.propagate = False
      60           1 :         yield
      61             :     finally:
      62           1 :         logger.propagate = old_value
      63             : 
      64             : 
      65           1 : class PairinteractionModule(Protocol):
      66           1 :     ureg: pint.UnitRegistry
      67           1 :     Database: type[pi.Database]
      68           1 :     KetAtom: type[pi.KetAtom]
      69           1 :     BasisAtom: type[pi.BasisAtom]
      70           1 :     SystemAtom: type[pi.SystemAtom]
      71           1 :     KetPair: type[pi.KetPair]
      72           1 :     BasisPair: type[pi.BasisPair]
      73           1 :     SystemPair: type[pi.SystemPair]
      74           1 :     EffectiveSystemPair: type[pi.EffectiveSystemPair]
      75           1 :     C3: type[pi.C3]
      76           1 :     C6: type[pi.C6]
      77           1 :     diagonalize: Callable[..., None]

Generated by: LCOV version 1.16