Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : from abc import ABC
7 1 : from typing import TYPE_CHECKING, Any, Generic, TypeVar
8 :
9 1 : from attr import dataclass
10 :
11 1 : import pairinteraction as pi
12 :
13 : if TYPE_CHECKING:
14 : from collections.abc import Mapping, Sequence
15 :
16 : from typing_extensions import Self
17 :
18 : from pairinteraction.system import SystemBase
19 : from pairinteraction.units import NDArray
20 : from pairinteraction_gui.config.basis_config import QuantumNumberRestrictions
21 : from pairinteraction_gui.config.ket_config import QuantumNumbers
22 : from pairinteraction_gui.config.system_config import RangesKeys
23 : from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
24 :
logger = logging.getLogger(__name__)

# Constrained type variable for the GUI page type that Parameters.from_page reads its settings from.
PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")

# Unit shown next to each sweepable range key in plot axis labels (LaTeX-style markup where needed).
UnitFromRangeKey: dict[RangesKeys, str] = {
    # electric-field components
    "Ex": "V/cm",
    "Ey": "V/cm",
    "Ez": "V/cm",
    # magnetic-field components
    "Bx": "Gauss",
    "By": "Gauss",
    "Bz": "Gauss",
    # geometry of the atom pair
    "Distance": r"$\mu$m",
    "Angle": r"$^\circ$",
}

# Variable name used for each range key in the generated script's source text.
VariableNameFromRangeKey: dict[RangesKeys, str] = {
    # electric-field components
    "Ex": "efield_x",
    "Ey": "efield_y",
    "Ez": "efield_z",
    # magnetic-field components
    "Bx": "bfield_x",
    "By": "bfield_y",
    "Bz": "bfield_z",
    # geometry of the atom pair
    "Distance": "distance",
    "Angle": "angle",
}
50 :
51 :
@dataclass
class Parameters(ABC, Generic[PageType]):
    """Calculation parameters collected from a GUI page.

    Per-atom settings (species, quantum numbers, basis restrictions) are stored as tuples with one
    entry per atom. The swept field/geometry values are stored in ``ranges``, together with the
    diagonalization options. ``to_replacement_dict`` turns the object into the placeholder mapping
    used to render an exported script.
    """

    species: tuple[str, ...]  # one species name per atom
    quantum_numbers: tuple[QuantumNumbers, ...]  # one quantum-number mapping per atom
    quantum_number_restrictions: tuple[QuantumNumberRestrictions, ...]  # one restriction mapping per atom
    ranges: dict[RangesKeys, list[float]]  # swept values per range key; all lists must share one length
    diamagnetism_enabled: bool
    diagonalize_kwargs: dict[str, str]  # extra keyword arguments forwarded to the diagonalization
    diagonalize_relative_energy_range: tuple[float, float] | None  # (lower, upper) offsets in GHz, or None

    def __post_init__(self) -> None:
        """Validate the consistency of the stored parameters.

        Raises:
            ValueError: If the ranges do not all have the same number of steps, or if ``species``,
                ``quantum_numbers`` and ``quantum_number_restrictions`` do not all have ``n_atoms`` entries.
        """
        # An empty ``ranges`` has nothing to compare (and ``steps`` would not be defined for it).
        if self.ranges:
            n_steps = self.steps
            if any(len(values) != n_steps for values in self.ranges.values()):
                raise ValueError("All ranges must have the same number of steps")

        n_atoms = self.n_atoms
        per_atom_entries = (self.species, self.quantum_numbers, self.quantum_number_restrictions)
        if any(len(entries) != n_atoms for entries in per_atom_entries):
            raise ValueError("All tuples must have the same length as the number of atoms")

    def __attrs_post_init__(self) -> None:
        """Run the validation in ``__post_init__`` after attrs has initialized the instance."""
        # ``@dataclass`` in this module is attrs' decorator (``from attr import dataclass``), which only
        # invokes ``__attrs_post_init__``; without this hook ``__post_init__`` would never be executed.
        self.__post_init__()

    @classmethod
    def from_page(cls, page: PageType) -> Self:
        """Create a Parameters object from the current state of a GUI page.

        Species and quantum numbers come from ``page.ket_config``, restrictions from ``page.basis_config``,
        ranges and the diamagnetism flag from ``page.system_config``, and diagonalization options from
        ``page.calculation_config``.
        """
        n_atoms = page.ket_config.n_atoms

        species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
        quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))

        quantum_number_restrictions = tuple(
            page.basis_config.get_quantum_number_restrictions(atom) for atom in range(n_atoms)
        )

        ranges = page.system_config.get_ranges_dict()
        diamagnetism_enabled = page.system_config.diamagnetism.isChecked()

        diagonalize_kwargs: dict[str, str] = {}
        if page.calculation_config.fast_mode.isChecked():
            # Fast mode: use the lapacke_evr diagonalizer with single-precision floats.
            diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
            diagonalize_kwargs["float_type"] = "float32"

        diagonalize_relative_energy_range = None
        if page.calculation_config.energy_range.isChecked():
            diagonalize_relative_energy_range = page.calculation_config.energy_range.values()

        return cls(
            species,
            quantum_numbers,
            quantum_number_restrictions,
            ranges,
            diamagnetism_enabled,
            diagonalize_kwargs,
            diagonalize_relative_energy_range,
        )

    @property
    def is_real(self) -> bool:
        """Whether the exported script may use the real dtype.

        True when every Ey and every By value is zero; a missing key counts as zero.
        """
        return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))

    @property
    def steps(self) -> int:
        """Return the number of steps (the length of the first range)."""
        return len(next(iter(self.ranges.values())))

    @property
    def n_atoms(self) -> int:
        """Return the number of atoms (the number of species)."""
        return len(self.species)

    def get_efield(self, step: int) -> list[float]:
        """Return ``[Ex, Ey, Ez]`` at the given step; components without a range are 0."""
        efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]

    def get_bfield(self, step: int) -> list[float]:
        """Return ``[Bx, By, Bz]`` at the given step; components without a range are 0."""
        bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]

    def get_species(self, atom: int | None = None) -> str:
        """Return the species of the given atom (the index may be omitted for a single atom)."""
        return self.species[self._check_atom(atom)]

    def get_quantum_numbers(self, atom: int | None = None) -> QuantumNumbers:
        """Return the quantum numbers of the given atom (the index may be omitted for a single atom)."""
        return self.quantum_numbers[self._check_atom(atom)]

    def get_ket_atom(self, atom: int | None = None) -> pi.KetAtom:
        """Build a ``pi.KetAtom`` from the species and quantum numbers of the given atom."""
        return pi.KetAtom(self.get_species(atom), **self.get_quantum_numbers(atom))

    def get_quantum_number_restrictions(self, atom: int | None = None) -> QuantumNumberRestrictions:
        """Return the quantum number restrictions of the given atom (index optional for a single atom)."""
        return self.quantum_number_restrictions[self._check_atom(atom)]

    def _check_atom(self, atom: int | None = None) -> int:
        """Resolve the atom index to use.

        Returns ``atom`` unchanged when given, or 0 when it is omitted and there is exactly one atom.

        Raises:
            ValueError: If ``atom`` is omitted while there are several atoms.
        """
        if atom is not None:
            return atom
        if self.n_atoms == 1:
            return 0
        raise ValueError("Atom index is required for multiple atoms")

    def get_diagonalize_energy_range_kwargs(self, energy_of_interest: float) -> dict[str, Any]:
        """Return diagonalization kwargs restricting the energies to a window around ``energy_of_interest``.

        The window is ``energy_of_interest`` shifted by the relative range, in GHz. An empty dict is returned
        when no relative energy range is set.
        """
        if self.diagonalize_relative_energy_range is None:
            return {}
        kwargs: dict[str, Any] = {"energy_range_unit": "GHz"}
        kwargs["energy_range"] = (
            energy_of_interest + self.diagonalize_relative_energy_range[0],
            energy_of_interest + self.diagonalize_relative_energy_range[1],
        )
        return kwargs

    def get_x_values(self) -> list[float]:
        """Return the x values for the plot: the range that changes the most."""
        max_key = self._get_ranges_max_diff_key()
        return self.ranges[max_key]

    def get_x_label(self) -> str:
        """Return the x-axis label for the plot.

        The label names the range that changes the most, with its unit. It also mentions any other ranges
        that are not constant.
        """
        max_key = self._get_ranges_max_diff_key()
        x_label = f"{max_key} ({UnitFromRangeKey[max_key]})"

        non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
        if non_constant_keys:
            x_label += f" ({', '.join(non_constant_keys)} did also change)"

        return x_label

    def _get_ranges_max_diff_key(self) -> RangesKeys:
        """Return the range key with the largest |last - first| difference (the first such key on ties).

        Raises:
            ValueError: If there are no ranges.
        """
        range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
        return max(range_diffs, key=range_diffs.__getitem__)

    def to_replacement_dict(self) -> dict[str, str]:
        """Return the placeholder -> source-text mapping used to render an exported script.

        Keys are template placeholders such as ``$PI_DTYPE`` or ``$SPECIES_0``. Values are snippets of
        Python source.

        Raises:
            RuntimeError: If a relative energy range is set and there are more than two atoms.
        """
        max_key = self._get_ranges_max_diff_key()
        replacements: dict[str, str] = {
            "$PI_DTYPE": "real" if self.is_real else "complex",
            "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
            "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
            "$DIAMAGNETISM_ENABLED": str(self.diamagnetism_enabled),
        }

        for atom in range(self.n_atoms):
            replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
            replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
            replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
                self.get_quantum_number_restrictions(atom)
            )

        replacements["$STEPS"] = str(self.steps)
        for key, values in self.ranges.items():
            replacements[f"${key.upper()}_MIN"] = str(values[0])
            replacements[f"${key.upper()}_MAX"] = str(values[-1])
            # A constant range additionally exposes its single value.
            if values[0] == values[-1]:
                replacements[f"${key.upper()}_VALUE"] = str(values[0])

        replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)
        replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = self._energy_range_repl()

        return replacements

    def _energy_range_repl(self) -> str:
        """Return the ``energy_range`` keyword snippet for the exported script, or '' if no range is set.

        The window is written relative to ``ket_energy`` for one atom and ``pair_energy`` for two atoms.

        Raises:
            RuntimeError: If a range is set and the atom count is neither 1 nor 2.
        """
        r_energy = self.diagonalize_relative_energy_range
        if r_energy is None:
            return ""
        reference = {1: "ket_energy", 2: "pair_energy"}.get(self.n_atoms)
        if reference is None:
            raise RuntimeError("Energy range kwargs not implemented for more than two atoms")
        # NOTE(review): the upper bound is emitted as "- {-upper}", which is numerically equal to "+ upper".
        return f', energy_range=({reference} + {r_energy[0]}, {reference} - {-r_energy[1]}), energy_range_unit="GHz"'
231 :
232 :
@dataclass
class Results(ABC):
    """Outcome of a calculation: shifted eigenenergies, overlaps with the ket, and the systems used."""

    energies: list[NDArray]  # per system: eigenenergies in GHz with energy_offset subtracted
    energy_offset: float  # value (GHz) subtracted from every eigenenergy
    ket_overlaps: list[NDArray]  # per system: overlaps of the eigenbasis with the ket
    systems: list[SystemBase[Any]]  # the systems, in the order they were passed in

    @classmethod
    def from_calculate(
        cls,
        parameters: Parameters[Any],
        system_list: Sequence[SystemBase[Any]],
        ket: pi.KetAtom | tuple[pi.KetAtom, ...],
        energy_offset: float,
    ) -> Self:
        """Build a Results object from already diagonalized systems.

        Args:
            parameters: The calculation parameters (not read here).
            system_list: Diagonalized systems, one per step.
            ket: The ket (or tuple of kets) whose overlaps with each eigenbasis are recorded.
            energy_offset: Energy in GHz subtracted from all eigenenergies.
        """
        systems = list(system_list)

        # First all energies, then all overlaps, over the same systems in order.
        shifted_energies = [sys.get_eigenenergies("GHz") - energy_offset for sys in systems]
        overlaps = [sys.get_eigenbasis().get_overlaps(ket) for sys in systems]

        return cls(
            energies=shifted_energies,
            energy_offset=energy_offset,
            ket_overlaps=overlaps,
            systems=systems,
        )
254 :
255 :
def as_string(value: str, *, raw_string: bool = False) -> str:
    """Return ``value`` as a double-quoted Python string literal for generated source code.

    Args:
        value: The text to quote.
        raw_string: If True, emit a raw literal (``r"..."``) with ``value`` copied verbatim. This keeps
            LaTeX-style backslashes intact. The caller must ensure the value contains no double quote
            and does not end in an odd number of backslashes, since raw literals cannot escape those.

    Returns:
        The literal as source text. For example, ``Rb`` becomes ``"Rb"`` (quotes included).
    """
    if raw_string:
        return 'r"' + value + '"'
    # Escape every character that would otherwise end or corrupt a double-quoted literal, so the
    # generated code evaluates back to exactly ``value``. Plain text is emitted unchanged.
    escape_table = str.maketrans({"\\": "\\\\", '"': '\\"', "\n": "\\n", "\r": "\\r", "\t": "\\t"})
    return '"' + value.translate(escape_table) + '"'
261 :
262 :
def dict_to_repl(d: Mapping[str, Any]) -> str:
    """Render a mapping as a trailing keyword-argument snippet for generated code.

    Each item becomes ``, key=value``. String values are quoted with ``as_string``; any other value
    uses its default ``format`` text. An empty mapping yields an empty string.
    """
    pieces = []
    for name, value in d.items():
        rendered = as_string(value) if isinstance(value, str) else f"{value}"
        pieces.append(f", {name}={rendered}")
    return "".join(pieces)
|