LCOV - code coverage report
Current view: top level - tests - test_database.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 103 103 100.0 %
Date: 2025-09-29 10:28:29 Functions: 6 12 50.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : """Test receiving matrix elements from the databases."""
       5             : 
       6           1 : from __future__ import annotations
       7             : 
       8           1 : from pathlib import Path
       9           1 : from typing import TYPE_CHECKING
      10             : 
      11           1 : import duckdb
      12           1 : import numpy as np
      13           1 : import pairinteraction as pi
      14           1 : import pytest
      15           1 : from packaging.version import Version
      16           1 : from sympy.physics.wigner import wigner_3j
      17             : 
      18           1 : from tests.constants import GAUSS_IN_ATOMIC_UNITS, HARTREE_IN_GHZ, SPECIES_TO_NUCLEAR_SPIN, SUPPORTED_SPECIES
      19             : 
      20             : if TYPE_CHECKING:
      21             :     from collections.abc import Generator
      22             : 
      23             : 
      24           1 : def fetch_id(n: int, l: float, f: float, s: float, connection: duckdb.DuckDBPyConnection, table: str | Path) -> int:
      25           1 :     result = connection.execute(
      26             :         f"SELECT id FROM '{table}' WHERE n = {n} AND f = {f} ORDER BY (exp_s - {s})^2+(exp_l - {l})^2 LIMIT 1"  # noqa: S608
      27             :     ).fetchone()
      28           1 :     return result[0] if result else -1
      29             : 
      30             : 
      31           1 : def fetch_wigner_element(
      32             :     f_initial: float,
      33             :     f_final: float,
      34             :     m_initial: float,
      35             :     m_final: float,
      36             :     kappa: int,
      37             :     q: int,
      38             :     connection: duckdb.DuckDBPyConnection,
      39             :     table: str | Path,
      40             : ) -> float:
      41           1 :     result = connection.execute(
      42             :         f"SELECT val FROM '{table}' WHERE f_initial = {f_initial} AND f_final = {f_final} AND "  # noqa: S608
      43             :         f"m_initial = {m_initial} AND m_final = {m_final} AND kappa = {kappa} AND q = {q}"
      44             :     ).fetchone()
      45           1 :     return result[0] if result else 0
      46             : 
      47             : 
      48           1 : def fetch_reduced_matrix_element(
      49             :     id_initial: int, id_final: int, connection: duckdb.DuckDBPyConnection, table: str | Path
      50             : ) -> float:
      51           1 :     result = connection.execute(
      52             :         f"SELECT val FROM '{table}' WHERE id_initial = {id_initial} AND id_final = {id_final}"  # noqa: S608
      53             :     ).fetchone()
      54           1 :     return result[0] if result else 0
      55             : 
      56             : 
      57           1 : @pytest.fixture(scope="module")
      58           1 : def connection() -> Generator[duckdb.DuckDBPyConnection]:
      59           1 :     with duckdb.connect(":memory:") as connection:
      60           1 :         yield connection
      61             : 
      62             : 
      63           1 : @pytest.mark.parametrize("swap_states", [False, True])
      64           1 : def test_database(connection: duckdb.DuckDBPyConnection, swap_states: bool) -> None:  # noqa: PLR0915
      65             :     """Test receiving matrix elements from the databases."""
      66           1 :     database = pi.Database.get_global_database()
      67           1 :     bfield_in_gauss = 1500
      68             : 
      69             :     # Define initial and final quantum states
      70           1 :     n_initial, n_final = 54, 54
      71           1 :     l_initial, l_final = 1, 1
      72           1 :     f_initial, f_final = 1, 0
      73           1 :     m_initial, m_final = 0, 0
      74           1 :     s_initial, s_final = 0.6, 1.0
      75             : 
      76             :     # Swap states if required by the test parameter
      77           1 :     if swap_states:
      78           1 :         n_initial, n_final = n_final, n_initial
      79           1 :         l_initial, l_final = l_final, l_initial
      80           1 :         f_initial, f_final = f_final, f_initial
      81           1 :         m_initial, m_final = m_final, m_initial
      82           1 :         s_initial, s_final = s_final, s_initial
      83             : 
      84             :     # Get the Zeeman interaction operator from the database using pairinteraction
      85           1 :     ket_initial = pi.KetAtom("Yb174_mqdt", n=n_initial, l=l_initial, f=f_initial, m=m_initial, s=s_initial)
      86           1 :     ket_final = pi.KetAtom("Yb174_mqdt", n=n_final, l=l_final, f=f_final, m=m_final, s=s_final)
      87           1 :     basis = pi.BasisAtom("Yb174_mqdt", additional_kets=[ket_initial, ket_final])
      88           1 :     operator = (
      89             :         pi.SystemAtom(basis)
      90             :         .set_magnetic_field([0, 0, bfield_in_gauss], unit="G")
      91             :         .set_diamagnetism_enabled(True)
      92             :         .get_hamiltonian(unit="GHz")
      93             :     ).toarray()
      94           1 :     operator -= np.diag(np.sort([ket_initial.get_energy(unit="GHz"), ket_final.get_energy(unit="GHz")]))
      95           1 :     expected_operator = np.array([[3.58588117, 1.66420213], [1.66420213, 4.16645123]])
      96           1 :     assert np.allclose(operator, expected_operator, rtol=1e-3)
      97             : 
      98             :     # Get the latest parquet files from the database directory
      99           1 :     parquet_files: dict[str, Path] = {}
     100           1 :     parquet_versions: dict[str, Version] = {}
     101           1 :     for path in list(Path(database.database_dir).rglob("*.parquet")):
     102           1 :         species, version_str = path.parent.name.rsplit("_v", 1)
     103           1 :         table = path.stem
     104           1 :         name = f"{species}_{table}"
     105           1 :         version = Version(version_str)
     106           1 :         if name not in parquet_files or version > parquet_versions[name]:
     107           1 :             parquet_files[name] = path
     108           1 :             parquet_versions[name] = version
     109           1 :     assert "misc_wigner" in parquet_files
     110           1 :     assert "Yb174_mqdt_states" in parquet_files
     111           1 :     assert "Yb174_mqdt_matrix_elements_mu" in parquet_files
     112           1 :     assert "Yb174_mqdt_matrix_elements_q" in parquet_files
     113           1 :     assert "Yb174_mqdt_matrix_elements_q0" in parquet_files
     114             : 
     115             :     # Obtain the ids of the initial and final states
     116           1 :     id_initial = fetch_id(n_initial, l_initial, f_initial, s_initial, connection, parquet_files["Yb174_mqdt_states"])
     117           1 :     assert id_initial == 362 if swap_states else 363
     118             : 
     119           1 :     id_final = fetch_id(n_final, l_final, f_final, s_final, connection, parquet_files["Yb174_mqdt_states"])
     120           1 :     assert id_final == 363 if swap_states else 362
     121             : 
     122             :     # Obtain a matrix element of the magnetic dipole operator (for the chosen kets, it is non-zero iff initial != final)
     123           1 :     kappa, q = 1, 0
     124           1 :     wigner_element = fetch_wigner_element(
     125             :         f_initial, f_final, m_initial, m_final, kappa, q, connection, parquet_files["misc_wigner"]
     126             :     )
     127           1 :     assert np.isclose(
     128             :         wigner_element,
     129             :         float((-1) ** (f_final - m_final) * wigner_3j(f_final, kappa, f_initial, -m_final, q, m_initial)),
     130             :     )
     131             : 
     132           1 :     me_mu = fetch_reduced_matrix_element(
     133             :         id_initial, id_final, connection, parquet_files["Yb174_mqdt_matrix_elements_mu"]
     134             :     )
     135           1 :     matrix_element = -wigner_element * me_mu * bfield_in_gauss * GAUSS_IN_ATOMIC_UNITS * HARTREE_IN_GHZ
     136           1 :     assert np.isclose(matrix_element, operator[0, 1], rtol=1e-3)
     137             : 
     138             :     # Obtain a matrix element of the diamagnetic operator (for the chosen kets, it is non-zero iff initial == final)
     139           1 :     n_final = n_initial
     140           1 :     l_final = l_initial
     141           1 :     f_final = f_initial
     142           1 :     m_final = m_initial
     143           1 :     s_final = s_initial
     144           1 :     id_final = id_initial
     145             : 
     146           1 :     kappa, q = 0, 0
     147           1 :     wigner_element = fetch_wigner_element(
     148             :         f_initial, f_final, m_initial, m_final, kappa, q, connection, parquet_files["misc_wigner"]
     149             :     )
     150           1 :     assert np.isclose(
     151             :         wigner_element,
     152             :         float((-1) ** (f_final - m_final) * wigner_3j(f_final, kappa, f_initial, -m_final, q, m_initial)),
     153             :     )
     154             : 
     155           1 :     me_q0 = fetch_reduced_matrix_element(
     156             :         id_initial, id_final, connection, parquet_files["Yb174_mqdt_matrix_elements_q0"]
     157             :     )
     158           1 :     matrix_element = 1 / 12 * wigner_element * me_q0 * (bfield_in_gauss * GAUSS_IN_ATOMIC_UNITS) ** 2 * HARTREE_IN_GHZ
     159             : 
     160           1 :     kappa, q = 2, 0
     161           1 :     wigner_element = fetch_wigner_element(
     162             :         f_initial, f_final, m_initial, m_final, kappa, q, connection, parquet_files["misc_wigner"]
     163             :     )
     164           1 :     assert np.isclose(
     165             :         wigner_element,
     166             :         float((-1) ** (f_final - m_final) * wigner_3j(f_final, kappa, f_initial, -m_final, q, m_initial)),
     167             :     )
     168             : 
     169           1 :     me_q = fetch_reduced_matrix_element(id_initial, id_final, connection, parquet_files["Yb174_mqdt_matrix_elements_q"])
     170           1 :     matrix_element -= 1 / 12 * wigner_element * me_q * (bfield_in_gauss * GAUSS_IN_ATOMIC_UNITS) ** 2 * HARTREE_IN_GHZ
     171           1 :     assert np.isclose(matrix_element, operator[0, 0] if swap_states else operator[1, 1], rtol=1e-3)
     172             : 
     173             : 
     174           1 : @pytest.mark.parametrize("species", SUPPORTED_SPECIES)
     175           1 : def test_obtaining_kets(species: str) -> None:
     176             :     """Test obtaining kets from the database."""
     177           1 :     is_mqdt = species.endswith("_mqdt")
     178           1 :     is_single_valence_electron = species in ["Rb"]
     179           1 :     is_triplet = species in ["Sr88_triplet"]
     180             : 
     181           1 :     quantum_number_i = SPECIES_TO_NUCLEAR_SPIN[species] if is_mqdt else 0
     182           1 :     quantum_number_s = 0.5 if is_single_valence_electron else (1 if is_triplet else 0)
     183           1 :     quantum_number_f = quantum_number_i + quantum_number_s
     184           1 :     quantum_number_m = quantum_number_i + quantum_number_s
     185             : 
     186             :     # Obtain a ket from the database
     187           1 :     ket = pi.KetAtom(species, n=60, l=0, f=quantum_number_f, m=quantum_number_m, s=quantum_number_s)
     188             : 
     189             :     # Check the result
     190           1 :     assert ket.species == species
     191           1 :     assert ket.n == 60 if not is_mqdt else abs(ket.n - 60) < 1
     192           1 :     assert ket.l == 0 if not is_mqdt else abs(ket.l - 0) < 1
     193           1 :     assert ket.f == quantum_number_f
     194           1 :     assert ket.m == quantum_number_m
     195           1 :     assert ket.s == quantum_number_s if not is_mqdt else abs(ket.s - quantum_number_s) < 1
     196             : 
     197             :     # TODO check repr(ket) (once the mqdt databases are updated)

Generated by: LCOV version 1.16