Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import logging
6 1 : from abc import ABC
7 1 : from typing import TYPE_CHECKING, Any, Generic, TypeVar
8 :
9 1 : import numpy as np
10 1 : from attr import dataclass
11 :
12 1 : import pairinteraction as pi
13 :
14 : if TYPE_CHECKING:
15 : from collections.abc import Mapping
16 :
17 : from typing_extensions import Self
18 :
19 : from pairinteraction.system import SystemAtom, SystemAtomReal, SystemPair, SystemPairReal
20 : from pairinteraction.units import NDArray
21 : from pairinteraction_gui.config.basis_config import QuantumNumberRestrictions
22 : from pairinteraction_gui.config.ket_config import QuantumNumbers
23 : from pairinteraction_gui.config.system_config import RangesKeys
24 : from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
25 :
logger = logging.getLogger(__name__)

# Type variable for the GUI page a Parameters object is built from (one-atom or two-atom page).
PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")

# Unit shown next to each scannable range key in axis labels (LaTeX-style math where needed).
UnitFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "V/cm",
    "Ey": "V/cm",
    "Ez": "V/cm",
    "Bx": "Gauss",
    "By": "Gauss",
    "Bz": "Gauss",
    "Distance": r"$\mu$m",
    "Angle": r"$^\circ$",
}

# Python variable name used for each range key in the generated example script.
VariableNameFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "efield_x",
    "Ey": "efield_y",
    "Ez": "efield_z",
    "Bx": "bfield_x",
    "By": "bfield_y",
    "Bz": "bfield_z",
    "Distance": "distance",
    "Angle": "angle",
}
51 :
52 :
@dataclass
class Parameters(ABC, Generic[PageType]):
    """Snapshot of the simulation parameters selected on a GUI page.

    Instances are usually created with :meth:`from_page`. They are used both to set up the
    calculation (fields, kets, diagonalization options) and to fill the placeholders of a
    generated Python script via :meth:`to_replacement_dict`.
    """

    # One entry per atom (all three tuples must have the same length).
    species: tuple[str, ...]
    quantum_numbers: tuple[QuantumNumbers, ...]
    quantum_number_restrictions: tuple[QuantumNumberRestrictions, ...]
    # Scanned parameter key (e.g. "Ex", "Distance") -> its value at every step; all lists equally long.
    ranges: dict[RangesKeys, list[float]]
    diamagnetism_enabled: bool
    # Extra keyword arguments for the diagonalization (filled in fast mode only).
    diagonalize_kwargs: dict[str, str]
    # (lower, upper) offsets in GHz relative to the energy of interest, or None for no restriction.
    diagonalize_relative_energy_range: tuple[float, float] | None
    number_state_labels: int

    def __attrs_post_init__(self) -> None:
        """Validate the consistency of the fields after initialization.

        The class is decorated with ``attr.dataclass``, which invokes ``__attrs_post_init__``
        (it never calls ``__post_init__``), so the validation has to live in this hook.

        Raises:
            ValueError: If the ranges do not all have the same number of steps, or if the
                per-atom tuples do not all have one entry per atom.
        """
        # Check if all ranges have the same number of steps
        if not all(len(values) == self.steps for values in self.ranges.values()):
            raise ValueError("All ranges must have the same number of steps")

        # Check if all tuples have the same length
        if not all(
            len(tup) == self.n_atoms for tup in (self.species, self.quantum_numbers, self.quantum_number_restrictions)
        ):
            raise ValueError("All tuples must have the same length as the number of atoms")

    # Keep the standard-library dataclass hook name as well, so the validation also runs
    # if the decorator is ever switched to ``dataclasses.dataclass``.
    __post_init__ = __attrs_post_init__

    @classmethod
    def from_page(cls, page: PageType) -> Self:
        """Create a Parameters object from the current widget state of a GUI page.

        Args:
            page: The one- or two-atom page whose ket, basis, system and calculation
                configuration widgets are read.

        Returns:
            A new Parameters instance.
        """
        n_atoms = page.ket_config.n_atoms

        species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
        quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))

        quantum_number_restrictions = tuple(
            page.basis_config.get_quantum_number_restrictions(atom) for atom in range(n_atoms)
        )

        ranges = page.system_config.get_ranges_dict()
        diamagnetism_enabled = page.system_config.diamagnetism.isChecked()

        diagonalize_kwargs = {}
        if page.calculation_config.fast_mode.isChecked():
            # Fast mode selects the "lapacke_evr" diagonalizer with single-precision floats.
            diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
            diagonalize_kwargs["float_type"] = "float32"

        diagonalize_relative_energy_range = None
        if page.calculation_config.energy_range.isChecked():
            diagonalize_relative_energy_range = page.calculation_config.energy_range.values()

        return cls(
            species,
            quantum_numbers,
            quantum_number_restrictions,
            ranges,
            diamagnetism_enabled,
            diagonalize_kwargs,
            diagonalize_relative_energy_range,
            page.calculation_config.number_state_labels.value(default=0),
        )

    @property
    def is_real(self) -> bool:
        """Return True if the Ey and By ranges are absent or zero at every step."""
        return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))

    @property
    def steps(self) -> int:
        """Return the number of steps (the length of the first range)."""
        return len(next(iter(self.ranges.values())))

    @property
    def n_atoms(self) -> int:
        """Return the number of atoms (the number of species)."""
        return len(self.species)

    def get_efield(self, step: int) -> list[float]:
        """Return ``[Ex, Ey, Ez]`` at the given step; components without a range are 0."""
        efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]

    def get_bfield(self, step: int) -> list[float]:
        """Return ``[Bx, By, Bz]`` at the given step; components without a range are 0."""
        bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]

    def get_species(self, atom: int | None = None) -> str:
        """Return the species of the given atom (the index may be omitted for a single atom)."""
        return self.species[self._check_atom(atom)]

    def get_quantum_numbers(self, atom: int | None = None) -> QuantumNumbers:
        """Return the quantum numbers of the given atom (the index may be omitted for a single atom)."""
        return self.quantum_numbers[self._check_atom(atom)]

    def get_ket_atom(self, atom: int | None = None) -> pi.KetAtom:
        """Return the ``pi.KetAtom`` built from the species and quantum numbers of the given atom."""
        return pi.KetAtom(self.get_species(atom), **self.get_quantum_numbers(atom))

    def get_quantum_number_restrictions(self, atom: int | None = None) -> QuantumNumberRestrictions:
        """Return the basis quantum-number restrictions of the given atom."""
        return self.quantum_number_restrictions[self._check_atom(atom)]

    def _check_atom(self, atom: int | None = None) -> int:
        """Resolve the atom index.

        Returns ``atom`` unchanged if given, otherwise 0 when there is exactly one atom.

        Raises:
            ValueError: If no index is given but there are multiple atoms.
        """
        if atom is not None:
            return atom
        if self.n_atoms == 1:
            return 0
        raise ValueError("Atom index is required for multiple atoms")

    def get_diagonalize_energy_range_kwargs(self, energy_of_interest: float) -> dict[str, Any]:
        """Return the energy-range kwargs for the diagonalization.

        Args:
            energy_of_interest: Reference energy in GHz; the relative range is added to it.

        Returns:
            ``{}`` if no energy range is set, otherwise ``energy_range`` (absolute bounds)
            and ``energy_range_unit="GHz"``.
        """
        if self.diagonalize_relative_energy_range is None:
            return {}
        kwargs: dict[str, Any] = {"energy_range_unit": "GHz"}
        kwargs["energy_range"] = (
            energy_of_interest + self.diagonalize_relative_energy_range[0],
            energy_of_interest + self.diagonalize_relative_energy_range[1],
        )
        return kwargs

    def get_x_values(self) -> list[float]:
        """Return the values of the range that changes most; they are used as the plot's x axis."""
        max_key = self._get_ranges_max_diff_key()
        return self.ranges[max_key]

    def get_x_label(self) -> str:
        """Return the x-axis label of the plot.

        The label names the range that changes most together with its unit. Any other
        non-constant ranges are mentioned in parentheses.
        """
        max_key = self._get_ranges_max_diff_key()
        x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"

        non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
        if non_constant_keys:
            x_label += f" ({', '.join(non_constant_keys)} did also change)"

        return x_label

    def _get_ranges_max_diff_key(self) -> RangesKeys:
        """Return the range key with the largest |last - first| difference (first one wins on ties)."""
        range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
        # Every candidate key comes from range_diffs itself, so a plain lookup suffices.
        return max(range_diffs, key=range_diffs.__getitem__)

    @staticmethod
    def _format_energy_offset(offset: float) -> str:
        """Return the source expression ``ket_energy +/- |offset|``, with the sign folded into the operator."""
        operator = "-" if offset < 0 else "+"
        return f"ket_energy {operator} {abs(offset)}"

    def to_replacement_dict(self) -> dict[str, str]:
        """Return the placeholder -> source-text mapping used to fill the script template.

        Values are snippets of Python source (quoted strings, keyword-argument suffixes,
        numbers) rather than plain values.
        """
        max_key = self._get_ranges_max_diff_key()
        replacements: dict[str, str] = {
            "$PI_DTYPE": "real" if self.is_real else "complex",
            "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
            "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
            "$DIAMAGNETISM_ENABLED": str(self.diamagnetism_enabled),
        }

        for atom in range(self.n_atoms):
            replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
            replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
            replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
                self.get_quantum_number_restrictions(atom)
            )

        replacements["$STEPS"] = str(self.steps)
        for key, values in self.ranges.items():
            replacements[f"${key.upper()}_MIN"] = str(values[0])
            replacements[f"${key.upper()}_MAX"] = str(values[-1])
            if values[0] == values[-1]:
                replacements[f"${key.upper()}_VALUE"] = str(values[0])

        replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)

        if self.diagonalize_relative_energy_range is not None:
            r_energy = self.diagonalize_relative_energy_range
            # Emit e.g. "ket_energy - 0.5" instead of "ket_energy + -0.5"; the value is identical.
            lower = self._format_energy_offset(r_energy[0])
            upper = self._format_energy_offset(r_energy[1])
            replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
                f', energy_range=({lower}, {upper}), energy_range_unit="GHz"'
            )
        else:
            replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""

        return replacements
227 :
228 :
@dataclass
class Results(ABC):
    """Diagonalization results for every step of a parameter scan."""

    # Eigenenergies per step in GHz, with energy_offset subtracted.
    energies: list[NDArray]
    # Offset (GHz) that was subtracted from all eigenenergies.
    energy_offset: float
    # Overlaps of each step's eigenbasis with the ket of interest.
    ket_overlaps: list[NDArray]
    # Step index -> labels of the eigenstates at that step (only for a subset of steps).
    state_labels: dict[int, list[str]]

    @classmethod
    def from_calculate(
        cls,
        parameters: Parameters[Any],
        system_list: list[SystemPairReal] | list[SystemPair] | list[SystemAtomReal] | list[SystemAtom],
        ket: pi.KetAtom | tuple[pi.KetAtom, ...],
        energy_offset: float,
    ) -> Self:
        """Build a Results object from the already diagonalized systems.

        Args:
            parameters: Scan parameters; ``steps`` and ``number_state_labels`` determine at
                which (evenly spaced) steps the state labels are collected.
            system_list: One diagonalized system per step.
            ket: Ket(s) whose overlaps with each eigenbasis are recorded.
            energy_offset: Value in GHz subtracted from every eigenenergy.

        Returns:
            The populated Results instance.
        """
        shifted_energies = [sys_.get_eigenenergies("GHz") - energy_offset for sys_ in system_list]
        overlaps = [sys_.get_eigenbasis().get_overlaps(ket) for sys_ in system_list]  # type: ignore [arg-type]

        # Evenly spaced step indices (truncated towards zero) at which labels are recorded.
        sampled = np.linspace(0, parameters.steps - 1, parameters.number_state_labels)
        eigenstates_by_step = {int(x): system_list[int(x)].get_eigenbasis().states for x in sampled}
        labels_by_step = {
            step_index: [state.get_label() for state in eigenstates]
            for step_index, eigenstates in eigenstates_by_step.items()
        }

        return cls(
            energies=shifted_energies,
            energy_offset=energy_offset,
            ket_overlaps=overlaps,
            state_labels=labels_by_step,
        )
253 :
254 :
def as_string(value: str, *, raw_string: bool = False) -> str:
    """Return a double-quoted Python string literal that evaluates to ``value``.

    The result is pasted into generated Python source code. Characters that would end or
    alter the literal (backslashes, double quotes, line breaks) are therefore escaped.

    Args:
        value: The text to embed.
        raw_string: Emit a raw literal (``r"..."``), which keeps backslashes readable, e.g.
            for LaTeX-style axis labels. A raw literal cannot hold a double quote or a line
            break, and cannot end with a backslash. For such values an escaped regular
            literal with the same value is returned instead.

    Returns:
        The literal as source text, e.g. ``'"Rb"'`` or ``'r"..."'``.
    """
    fits_raw = '"' not in value and "\n" not in value and "\r" not in value and not value.endswith("\\")
    if raw_string and fits_raw:
        return 'r"' + value + '"'
    # Escape the backslash first so the escapes added afterwards are not doubled.
    escaped = value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n").replace("\r", "\\r")
    return '"' + escaped + '"'
260 :
261 :
def dict_to_repl(d: Mapping[str, Any]) -> str:
    """Render a mapping as a keyword-argument suffix for generated code.

    Every item becomes ``, key=value``. String values are quoted with ``as_string``; all
    other values are inserted via f-string formatting. An empty mapping yields ``""``.
    """
    return "".join(
        f", {name}={as_string(item) if isinstance(item, str) else item}" for name, item in d.items()
    )
|