Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 0 : import logging
5 0 : from abc import ABC
6 0 : from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
7 :
8 0 : from attr import dataclass
9 :
10 0 : from pairinteraction import (
11 : _wrapped,
12 : complex as pi_complex,
13 : real as pi_real,
14 : )
15 0 : from pairinteraction_gui.config.system_config import RangesKeys
16 :
17 : if TYPE_CHECKING:
18 : from typing_extensions import Self
19 :
20 : from pairinteraction.units import NDArray
21 : from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
22 :
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# FIXME: having all kwargs dictionaries being Any is a hacky solution, it would be nice to use TypedDict in the future

# Display unit of each scannable quantity. Parameters.get_x_label appends the unit of the x-axis
# quantity in square brackets to the axis label. The "Distance" and "Angle" entries use LaTeX math
# markup (presumably rendered by matplotlib mathtext -- confirm).
UnitFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "V/cm",
    "Ey": "V/cm",
    "Ez": "V/cm",
    "Bx": "Gauss",
    "By": "Gauss",
    "Bz": "Gauss",
    "Distance": r"$\mu$m",
    "Angle": r"$^\circ$",
}

# Variable name used for each quantity in generated scripts. Parameters.to_replacement_dict
# substitutes the entry of the x-axis quantity for the "$X_VARIABLE_NAME" placeholder.
VariableNameFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "efield_x",
    "Ey": "efield_y",
    "Ez": "efield_z",
    "Bx": "bfield_x",
    "By": "bfield_y",
    "Bz": "bfield_z",
    "Distance": "distance",
    "Angle": "angle",
}

# Type of the GUI page a Parameters object is created from. The constraints are given as string
# forward references because the page classes are only imported under TYPE_CHECKING.
PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
50 :
51 :
@dataclass
class Parameters(ABC, Generic[PageType]):
    """Snapshot of the GUI inputs needed to build, diagonalize and export a calculation.

    The per-atom fields (``species``, ``quantum_numbers``, ``quantum_number_deltas``) hold one
    entry per atom. ``ranges`` maps every scanned quantity to its value at each step.
    """

    # Species of each atom, one entry per atom.
    species: tuple[str, ...]
    # Quantum numbers of the ket of interest, one dict per atom.
    quantum_numbers: tuple[dict[str, float], ...]
    # Half-widths of the allowed quantum-number windows, one dict per atom
    # (turned into (min, max) bounds by get_quantum_number_restrictions).
    quantum_number_deltas: tuple[dict[str, float], ...]
    # Value of each scanned quantity at every step; all lists must have the same length.
    ranges: dict[RangesKeys, list[float]]
    # Extra keyword arguments for the diagonalization (rendered into scripts via $DIAGONALIZE_KWARGS).
    diagonalize_kwargs: dict[str, str]
    # Optional (lower, upper) energy window in GHz, relative to the energy of interest.
    diagonalize_relative_energy_range: Union[tuple[float, float], None]

    def __post_init__(self) -> None:
        """Validate the lengths of the ranges and of the per-atom tuples.

        Raises:
            ValueError: If the ranges do not all have the same number of steps, or if a per-atom
                tuple does not have ``n_atoms`` entries.
        """
        steps = self.steps
        if any(len(values) != steps for values in self.ranges.values()):
            raise ValueError("All ranges must have the same number of steps")

        per_atom_entries = (self.species, self.quantum_numbers, self.quantum_number_deltas)
        if any(len(entries) != self.n_atoms for entries in per_atom_entries):
            raise ValueError("All tuples must have the same length as the number of atoms")

    def __attrs_post_init__(self) -> None:
        """Run the validation when this class is built by ``attr.dataclass``.

        This module decorates the class with ``dataclass`` imported from ``attr``. attrs only
        calls ``__attrs_post_init__``, so without this hook ``__post_init__`` would never run.
        The stdlib ``dataclasses`` only calls ``__post_init__``, so validation runs exactly once
        with either decorator.
        """
        self.__post_init__()

    @classmethod
    def from_page(cls, page: PageType) -> "Self":
        """Create a Parameters object from the current state of a GUI page.

        In fast mode the ``lapacke_evr`` diagonalizer with ``float32`` is requested. The relative
        energy window is only set when the page's energy-range option is checked.
        """
        n_atoms = page.ket_config.n_atoms

        species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
        quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))
        quantum_number_deltas = tuple(page.basis_config.get_quantum_number_deltas(atom) for atom in range(n_atoms))

        ranges = page.system_config.get_ranges_dict()

        diagonalize_kwargs: dict[str, str] = {}
        if page.plotwidget.fast_mode.isChecked():
            diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
            diagonalize_kwargs["float_type"] = "float32"

        diagonalize_relative_energy_range = None
        if page.plotwidget.energy_range.isChecked():
            diagonalize_relative_energy_range = page.plotwidget.energy_range.values()

        return cls(
            species,
            quantum_numbers,
            quantum_number_deltas,
            ranges,
            diagonalize_kwargs,
            diagonalize_relative_energy_range,
        )

    @property
    def is_real(self) -> bool:
        """Return True if Ey and By are zero at every step (missing ranges count as zero)."""
        return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))

    @property
    def steps(self) -> int:
        """Return the number of steps, i.e. the length of the first range (0 if there are no ranges)."""
        first_range = next(iter(self.ranges.values()), None)
        return 0 if first_range is None else len(first_range)

    @property
    def n_atoms(self) -> int:
        """Return the number of atoms."""
        return len(self.species)

    def get_efield(self, step: int) -> list[float]:
        """Return [Ex, Ey, Ez] at the given step; components without a range are 0."""
        efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]

    def get_bfield(self, step: int) -> list[float]:
        """Return [Bx, By, Bz] at the given step; components without a range are 0."""
        bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]

    def get_species(self, atom: Optional[int] = None) -> str:
        """Return the species of the given atom (the index may be omitted for a single atom)."""
        return self.species[self._check_atom(atom)]

    def get_quantum_numbers(self, atom: Optional[int] = None) -> dict[str, Any]:
        """Return the quantum numbers of the given atom (the index may be omitted for a single atom)."""
        return self.quantum_numbers[self._check_atom(atom)]

    def get_quantum_number_restrictions(self, atom: Optional[int] = None) -> dict[str, Any]:
        """Return ``{name: (value - delta, value + delta)}`` for every quantum-number delta of the atom.

        Raises:
            ValueError: If a delta refers to a quantum number that was not specified.
        """
        atom = self._check_atom(atom)
        numbers = self.quantum_numbers[atom]
        qn_restrictions: dict[str, tuple[float, float]] = {}
        for key, delta in self.quantum_number_deltas[atom].items():
            if key not in numbers:
                raise ValueError(f"Quantum number delta {key} not found in quantum numbers.")
            qn_restrictions[key] = (numbers[key] - delta, numbers[key] + delta)
        return qn_restrictions

    def _check_atom(self, atom: Optional[int] = None) -> int:
        """Resolve the atom index: an explicit index is returned as is, None means atom 0 for one atom.

        Raises:
            ValueError: If no index is given although there are several atoms.
        """
        if atom is not None:
            return atom
        if self.n_atoms == 1:
            return 0
        raise ValueError("Atom index is required for multiple atoms")

    def get_diagonalize_energy_range(self, energy_of_interest: float) -> dict[str, Any]:
        """Return diagonalization kwargs for the energy window around ``energy_of_interest`` (in GHz).

        Returns an empty dict when no relative energy range is set.
        """
        if self.diagonalize_relative_energy_range is None:
            return {}
        kwargs: dict[str, Any] = {"energy_unit": "GHz"}
        kwargs["energy_range"] = (
            energy_of_interest + self.diagonalize_relative_energy_range[0],
            energy_of_interest + self.diagonalize_relative_energy_range[1],
        )
        return kwargs

    def get_x_values(self) -> list[float]:
        """Return the x values for the plot: the range of the quantity that changes the most."""
        max_key = self._get_ranges_max_diff_key()
        return self.ranges[max_key]

    def get_x_label(self) -> str:
        """Return the x-axis label, noting any other quantities that also change over the sweep."""
        max_key = self._get_ranges_max_diff_key()
        x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"

        non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
        if non_constant_keys:
            x_label += f" ({', '.join(non_constant_keys)} did also change)"

        return x_label

    def _get_ranges_max_diff_key(self) -> RangesKeys:
        """Return the key whose range has the largest |last - first| (the first such key on ties)."""
        range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
        return max(range_diffs, key=range_diffs.__getitem__)

    @staticmethod
    def _offset_expression(base: str, offset: float) -> str:
        """Return source text for ``base + offset`` with the sign folded into the operator."""
        if offset < 0:
            return f"{base} - {-offset}"
        return f"{base} + {offset}"

    def to_replacement_dict(self) -> dict[str, str]:
        """Return the placeholder -> source-text mapping used to fill a script template.

        Placeholders: ``$PI_DTYPE``, ``$X_VARIABLE_NAME``, ``$X_LABEL``; per atom ``$SPECIES_i``,
        ``$QUANTUM_NUMBERS_i`` and ``$QUANTUM_NUMBERS_RESTRICTIONS_i``; ``$STEPS``; per range
        ``$<KEY>_MIN``/``$<KEY>_MAX`` (plus ``$<KEY>_VALUE`` for constant ranges);
        ``$DIAGONALIZE_KWARGS`` and ``$DIAGONALIZE_ENERGY_RANGE_KWARGS``.
        """
        max_key = self._get_ranges_max_diff_key()
        replacements: dict[str, str] = {
            "$PI_DTYPE": "real" if self.is_real else "complex",
            "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
            "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
        }

        for atom in range(self.n_atoms):
            replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
            replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
            replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
                self.get_quantum_number_restrictions(atom)
            )

        replacements["$STEPS"] = str(self.steps)
        for key, values in self.ranges.items():
            prefix = f"${key.upper()}"
            replacements[f"{prefix}_MIN"] = str(values[0])
            replacements[f"{prefix}_MAX"] = str(values[-1])
            if values[0] == values[-1]:
                replacements[f"{prefix}_VALUE"] = str(values[0])

        replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)

        energy_range_kwargs = ""
        if self.diagonalize_relative_energy_range is not None:
            lower = self.diagonalize_relative_energy_range[0]
            upper = self.diagonalize_relative_energy_range[1]
            # Fold each sign into the operator so the generated code reads e.g.
            # "(ket_energy - 5.0, ket_energy + 5.0)" instead of "(ket_energy + -5.0, ket_energy - -5.0)".
            energy_range_kwargs = (
                f", energy_range=({self._offset_expression('ket_energy', lower)}, "
                f"{self._offset_expression('ket_energy', upper)}), "
                'energy_unit="GHz"'
            )
        replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = energy_range_kwargs

        return replacements
224 :
225 :
@dataclass
class Results(ABC):
    """Outcome of a sweep: shifted eigenenergies, overlaps with the ket of interest, and state labels."""

    # Eigenenergies (in GHz) of every system, each shifted by -energy_offset.
    energies: list["NDArray"]
    # Offset that was subtracted from all energies.
    energy_offset: float
    # Overlaps of each system's eigenbasis with the ket(s) of interest.
    ket_overlaps: list["NDArray"]
    # Labels taken from the kets of the eigenbasis of the *last* system in the list.
    state_labels_0: list[str]

    @classmethod
    def from_calculate(
        cls,
        system_list: Union[
            list[pi_real.SystemPair], list[pi_complex.SystemPair], list[pi_real.SystemAtom], list[pi_complex.SystemAtom]
        ],
        ket: Union[_wrapped.KetAtom, tuple[_wrapped.KetAtom, ...]],
        energy_offset: float,
    ) -> "Self":
        """Collect results from a list of already diagonalized systems.

        Args:
            system_list: The diagonalized systems.
            ket: The ket (or tuple of kets) whose overlaps with each eigenbasis are recorded.
            energy_offset: Value subtracted from every eigenenergy (in GHz).

        Returns:
            A new Results instance.
        """
        shifted = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
        overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]  # type: ignore [arg-type]

        # NOTE(review): label i is taken from basis ket i of the last system's eigenbasis for
        # i < number_of_states; confirm this is the intended label of eigenstate i.
        final_basis = system_list[-1].get_eigenbasis()
        final_kets = [final_basis.kets[index] for index in range(final_basis.number_of_states)]
        labels = [final_ket.get_label("ket") for final_ket in final_kets]

        return cls(energies=shifted, energy_offset=energy_offset, ket_overlaps=overlaps, state_labels_0=labels)
250 :
251 :
def as_string(value: str, *, raw_string: bool = False) -> str:
    """Return ``value`` as a double-quoted Python string literal for generated source code.

    Ordinary strings come out as ``"value"``. Characters that would break the literal or change
    its meaning (``"``, backslashes, newlines and other control characters) are escaped, so
    evaluating the result always yields ``value``.

    Args:
        value: The text to embed.
        raw_string: Prefer a raw literal ``r"value"`` (useful for LaTeX labels with backslashes).
            If ``value`` cannot be written as a raw literal (it contains ``"``, a line break, or
            ends in an odd number of backslashes), an escaped regular literal is returned instead.

    Returns:
        Source text of a string literal that evaluates to ``value``.
    """
    import json

    if raw_string:
        trailing_backslashes = len(value) - len(value.rstrip("\\"))
        representable_raw = (
            '"' not in value and "\n" not in value and "\r" not in value and trailing_backslashes % 2 == 0
        )
        if representable_raw:
            return 'r"' + value + '"'

    # A JSON string is also a valid Python double-quoted literal with the same value.
    return json.dumps(value, ensure_ascii=False)
257 :
258 :
def dict_to_repl(d: dict[str, Any]) -> str:
    """Render ``d`` as a keyword-argument suffix for generated code.

    Each item becomes ``, key=value``. String values are quoted with ``as_string``; all other
    values are inserted using their ``str()`` form. An empty dict yields ``""``.

    Example: ``{"n": 3, "l": (1.0, 2.0)}`` -> ``", n=3, l=(1.0, 2.0)"``.
    """
    return "".join(
        f", {key}={as_string(value) if isinstance(value, str) else value}" for key, value in d.items()
    )
|