LCOV - code coverage report
Current view: top level - tests - test_database.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 100 100 100.0 %
Date: 2025-12-08 07:47:12 Functions: 6 12 50.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : from __future__ import annotations
       5             : 
       6           1 : from typing import TYPE_CHECKING
       7             : 
       8           1 : import duckdb
       9           1 : import numpy as np
      10           1 : import pytest
      11           1 : from packaging.version import Version
      12           1 : from sympy.physics.wigner import wigner_3j
      13             : 
      14           1 : from tests.constants import GAUSS_IN_ATOMIC_UNITS, HARTREE_IN_GHZ, SPECIES_TO_NUCLEAR_SPIN, SUPPORTED_SPECIES
      15             : 
      16             : if TYPE_CHECKING:
      17             :     from collections.abc import Generator
      18             :     from pathlib import Path
      19             : 
      20             :     from .utils import PairinteractionModule
      21             : 
      22             : 
      23           1 : def fetch_id(n: int, l: float, f: float, s: float, connection: duckdb.DuckDBPyConnection, table: str | Path) -> int:
      24           1 :     result = connection.execute(
      25             :         f"SELECT id FROM '{table}' WHERE n = {n} AND f = {f} ORDER BY (exp_s - {s})^2+(exp_l - {l})^2 LIMIT 1"  # noqa: S608
      26             :     ).fetchone()
      27           1 :     return result[0] if result else -1
      28             : 
      29             : 
      30           1 : def fetch_wigner_element(
      31             :     f_initial: float,
      32             :     f_final: float,
      33             :     m_initial: float,
      34             :     m_final: float,
      35             :     kappa: int,
      36             :     q: int,
      37             :     connection: duckdb.DuckDBPyConnection,
      38             :     table: str | Path,
      39             : ) -> float:
      40           1 :     result = connection.execute(
      41             :         f"SELECT val FROM '{table}' WHERE f_initial = {f_initial} AND f_final = {f_final} AND "  # noqa: S608
      42             :         f"m_initial = {m_initial} AND m_final = {m_final} AND kappa = {kappa} AND q = {q}"
      43             :     ).fetchone()
      44           1 :     return result[0] if result else 0
      45             : 
      46             : 
      47           1 : def fetch_reduced_matrix_element(
      48             :     id_initial: int, id_final: int, connection: duckdb.DuckDBPyConnection, table: str | Path
      49             : ) -> float:
      50           1 :     result = connection.execute(
      51             :         f"SELECT val FROM '{table}' WHERE id_initial = {id_initial} AND id_final = {id_final}"  # noqa: S608
      52             :     ).fetchone()
      53           1 :     return result[0] if result else 0
      54             : 
      55             : 
      56           1 : @pytest.fixture(scope="module")
      57           1 : def connection() -> Generator[duckdb.DuckDBPyConnection]:
      58           1 :     with duckdb.connect(":memory:") as connection:
      59           1 :         yield connection
      60             : 
      61             : 
      62           1 : @pytest.mark.parametrize("swap_states", [False, True])
      63           1 : def test_database(pi_module: PairinteractionModule, connection: duckdb.DuckDBPyConnection, swap_states: bool) -> None:  # noqa: PLR0915
      64             :     """Test receiving matrix elements from the databases."""
      65           1 :     database = pi_module.Database.get_global_database()
      66           1 :     bfield_in_gauss = 1500
      67             : 
      68             :     # Define initial and final quantum states
      69           1 :     n_initial, n_final = 54, 54
      70           1 :     l_initial, l_final = 1, 1
      71           1 :     f_initial, f_final = 1, 0
      72           1 :     m_initial, m_final = 0, 0
      73           1 :     s_initial, s_final = 0.6, 1.0
      74             : 
      75             :     # Swap states if required by the test parameter
      76           1 :     if swap_states:
      77           1 :         n_initial, n_final = n_final, n_initial
      78           1 :         l_initial, l_final = l_final, l_initial
      79           1 :         f_initial, f_final = f_final, f_initial
      80           1 :         m_initial, m_final = m_final, m_initial
      81           1 :         s_initial, s_final = s_final, s_initial
      82             : 
      83             :     # Get the Zeeman interaction operator from the database using pairinteraction
      84           1 :     ket_initial = pi_module.KetAtom("Yb174_mqdt", n=n_initial, l=l_initial, f=f_initial, m=m_initial, s=s_initial)
      85           1 :     ket_final = pi_module.KetAtom("Yb174_mqdt", n=n_final, l=l_final, f=f_final, m=m_final, s=s_final)
      86           1 :     basis = pi_module.BasisAtom("Yb174_mqdt", additional_kets=[ket_initial, ket_final])
      87           1 :     operator = (
      88             :         pi_module.SystemAtom(basis)
      89             :         .set_magnetic_field([0, 0, bfield_in_gauss], unit="G")
      90             :         .set_diamagnetism_enabled(True)
      91             :         .get_hamiltonian(unit="GHz")
      92             :     ).toarray()
      93           1 :     operator -= np.diag(np.sort([ket_initial.get_energy(unit="GHz"), ket_final.get_energy(unit="GHz")]))
      94           1 :     expected_operator = np.array([[3.58588117, 1.66420213], [1.66420213, 4.16645123]])
      95           1 :     assert np.allclose(operator, expected_operator, rtol=1e-3)
      96             : 
      97             :     # Get the latest parquet files from the database directory
      98           1 :     parquet_files: dict[str, Path] = {}
      99           1 :     parquet_versions: dict[str, Version] = {}
     100           1 :     for path in (database.database_dir / "tables").glob("*/*.parquet"):
     101           1 :         species, version_str = path.parent.name.rsplit("_v", 1)
     102           1 :         table = path.stem
     103           1 :         name = f"{species}_{table}"
     104           1 :         version = Version(version_str)
     105           1 :         if name not in parquet_files or version > parquet_versions[name]:
     106           1 :             parquet_files[name] = path
     107           1 :             parquet_versions[name] = version
     108           1 :     assert "misc_wigner" in parquet_files
     109           1 :     assert "Yb174_mqdt_states" in parquet_files
     110           1 :     assert "Yb174_mqdt_matrix_elements_mu" in parquet_files
     111           1 :     assert "Yb174_mqdt_matrix_elements_q" in parquet_files
     112           1 :     assert "Yb174_mqdt_matrix_elements_q0" in parquet_files
     113             : 
     114             :     # Obtain the ids of the initial and final states
     115           1 :     id_initial = fetch_id(n_initial, l_initial, f_initial, s_initial, connection, parquet_files["Yb174_mqdt_states"])
     116           1 :     id_final = fetch_id(n_final, l_final, f_final, s_final, connection, parquet_files["Yb174_mqdt_states"])
     117           1 :     assert id_initial == id_final + (-1 if swap_states else +1), f"Got {id_initial=} {id_final=}"
     118             : 
     119             :     # Obtain a matrix element of the magnetic dipole operator (for the chosen kets, it is non-zero iff initial != final)
     120           1 :     kappa, q = 1, 0
     121           1 :     wigner_element = fetch_wigner_element(
     122             :         f_initial, f_final, m_initial, m_final, kappa, q, connection, parquet_files["misc_wigner"]
     123             :     )
     124           1 :     assert np.isclose(
     125             :         wigner_element,
     126             :         float((-1) ** (f_final - m_final) * wigner_3j(f_final, kappa, f_initial, -m_final, q, m_initial)),
     127             :     )
     128             : 
     129           1 :     me_mu = fetch_reduced_matrix_element(
     130             :         id_initial, id_final, connection, parquet_files["Yb174_mqdt_matrix_elements_mu"]
     131             :     )
     132           1 :     matrix_element = -wigner_element * me_mu * bfield_in_gauss * GAUSS_IN_ATOMIC_UNITS * HARTREE_IN_GHZ
     133           1 :     assert np.isclose(matrix_element, operator[0, 1], rtol=1e-3)
     134             : 
     135             :     # Obtain a matrix element of the diamagnetic operator (for the chosen kets, it is non-zero iff initial == final)
     136           1 :     n_final = n_initial
     137           1 :     l_final = l_initial
     138           1 :     f_final = f_initial
     139           1 :     m_final = m_initial
     140           1 :     s_final = s_initial
     141           1 :     id_final = id_initial
     142             : 
     143           1 :     kappa, q = 0, 0
     144           1 :     wigner_element = fetch_wigner_element(
     145             :         f_initial, f_final, m_initial, m_final, kappa, q, connection, parquet_files["misc_wigner"]
     146             :     )
     147           1 :     assert np.isclose(
     148             :         wigner_element,
     149             :         float((-1) ** (f_final - m_final) * wigner_3j(f_final, kappa, f_initial, -m_final, q, m_initial)),
     150             :     )
     151             : 
     152           1 :     me_q0 = fetch_reduced_matrix_element(
     153             :         id_initial, id_final, connection, parquet_files["Yb174_mqdt_matrix_elements_q0"]
     154             :     )
     155           1 :     matrix_element = 1 / 12 * wigner_element * me_q0 * (bfield_in_gauss * GAUSS_IN_ATOMIC_UNITS) ** 2 * HARTREE_IN_GHZ
     156             : 
     157           1 :     kappa, q = 2, 0
     158           1 :     wigner_element = fetch_wigner_element(
     159             :         f_initial, f_final, m_initial, m_final, kappa, q, connection, parquet_files["misc_wigner"]
     160             :     )
     161           1 :     assert np.isclose(
     162             :         wigner_element,
     163             :         float((-1) ** (f_final - m_final) * wigner_3j(f_final, kappa, f_initial, -m_final, q, m_initial)),
     164             :     )
     165             : 
     166           1 :     me_q = fetch_reduced_matrix_element(id_initial, id_final, connection, parquet_files["Yb174_mqdt_matrix_elements_q"])
     167           1 :     matrix_element -= 1 / 12 * wigner_element * me_q * (bfield_in_gauss * GAUSS_IN_ATOMIC_UNITS) ** 2 * HARTREE_IN_GHZ
     168           1 :     assert np.isclose(matrix_element, operator[0, 0] if swap_states else operator[1, 1], rtol=1e-3)
     169             : 
     170             : 
     171           1 : @pytest.mark.parametrize("species", SUPPORTED_SPECIES)
     172           1 : def test_obtaining_kets(pi_module: PairinteractionModule, species: str) -> None:
     173             :     """Test obtaining kets from the database."""
     174           1 :     is_mqdt = species.endswith("_mqdt")
     175           1 :     is_single_valence_electron = species in ["Rb"]
     176           1 :     is_triplet = species in ["Sr88_triplet"]
     177             : 
     178           1 :     quantum_number_i = SPECIES_TO_NUCLEAR_SPIN[species] if is_mqdt else 0
     179           1 :     quantum_number_s = 0.5 if is_single_valence_electron else (1 if is_triplet else 0)
     180           1 :     quantum_number_f = quantum_number_i + quantum_number_s
     181           1 :     quantum_number_m = quantum_number_i + quantum_number_s
     182             : 
     183             :     # Obtain a ket from the database
     184           1 :     ket = pi_module.KetAtom(species, n=60, l=0, f=quantum_number_f, m=quantum_number_m, s=quantum_number_s)
     185             : 
     186             :     # Check the result
     187           1 :     assert ket.species == species
     188           1 :     assert ket.n == 60 if not is_mqdt else abs(ket.n - 60) < 1
     189           1 :     assert ket.l == 0 if not is_mqdt else abs(ket.l - 0) < 1
     190           1 :     assert ket.f == quantum_number_f
     191           1 :     assert ket.m == quantum_number_m
     192           1 :     assert ket.s == quantum_number_s if not is_mqdt else abs(ket.s - quantum_number_s) < 1
     193             : 
     194             :     # TODO check repr(ket) (once the mqdt databases are updated)

Generated by: LCOV version 1.16