Line data Source code
1 : # SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : from abc import ABC
6 1 : from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
7 :
8 1 : import numpy as np
9 1 : from attr import dataclass
10 :
11 1 : from pairinteraction import (
12 : _wrapped,
13 : complex as pi_complex,
14 : real as pi_real,
15 : )
16 1 : from pairinteraction_gui.config.system_config import RangesKeys
17 :
18 : if TYPE_CHECKING:
19 : from typing_extensions import Self
20 :
21 : from pairinteraction.units import NDArray
22 : from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
23 :
logger = logging.getLogger(__name__)

# FIXME: typing every kwargs dictionary as ``dict[str, Any]`` is a hacky solution;
# it would be nicer to use TypedDict in the future.

# Unit of each range key, shown in brackets in the x-axis label (LaTeX math markup where needed).
UnitFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "V/cm", "Ey": "V/cm", "Ez": "V/cm",
    "Bx": "Gauss", "By": "Gauss", "Bz": "Gauss",
    "Distance": r"$\mu$m",
    "Angle": r"$^\circ$",
}

# Variable name substituted for ``$X_VARIABLE_NAME`` when code is generated from the parameters.
VariableNameFromRangeKey: dict[RangesKeys, str] = {
    "Ex": "efield_x", "Ey": "efield_y", "Ez": "efield_z",
    "Bx": "bfield_x", "By": "bfield_y", "Bz": "bfield_z",
    "Distance": "distance",
    "Angle": "angle",
}

# The GUI page types that ``Parameters.from_page`` can read from.
PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
51 :
52 :
@dataclass
class Parameters(ABC, Generic[PageType]):
    """Parameters of a calculation, collected from a GUI page.

    ``species``, ``quantum_numbers`` and ``quantum_number_deltas`` hold one entry per atom, and every list in
    ``ranges`` holds one value per calculation step. Both invariants are validated on construction.
    """

    species: tuple[str, ...]
    quantum_numbers: tuple[dict[str, float], ...]
    quantum_number_deltas: tuple[dict[str, float], ...]
    ranges: dict[RangesKeys, list[float]]
    # extra keyword arguments for the diagonalization (from_page sets diagonalizer/float_type in fast mode)
    diagonalize_kwargs: dict[str, str]
    # (lower, upper) window added to an energy of interest, in GHz; None means no restriction
    diagonalize_relative_energy_range: Union[tuple[float, float], None]
    # how many (evenly spaced) steps get state labels, see Results.from_calculate
    number_state_labels: int

    def __post_init__(self) -> None:
        """Validate that the ranges and per-atom tuples have consistent lengths.

        Raises:
            ValueError: If the ranges differ in length, or a per-atom tuple does not have one entry per atom.

        """
        # Check if all ranges have the same number of steps
        if not all(len(v) == self.steps for v in self.ranges.values()):
            raise ValueError("All ranges must have the same number of steps")

        # Check if all tuples have the same length
        if not all(
            len(tup) == self.n_atoms for tup in [self.species, self.quantum_numbers, self.quantum_number_deltas]
        ):
            raise ValueError("All tuples must have the same length as the number of atoms")

    def __attrs_post_init__(self) -> None:
        """Run ``__post_init__`` when the class is generated by ``attr.dataclass``.

        This module imports ``dataclass`` from ``attr``, and attrs only invokes ``__attrs_post_init__`` after
        ``__init__``. Without this hook the validation in ``__post_init__`` would never run.
        """
        self.__post_init__()

    @classmethod
    def from_page(cls, page: PageType) -> "Self":
        """Create a Parameters object from the current state of a GUI page.

        Args:
            page: The page whose ket, basis, system and calculation configuration is read.

        Returns:
            A new, validated Parameters instance.

        """
        n_atoms = page.ket_config.n_atoms

        species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
        quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))

        quantum_number_deltas = tuple(page.basis_config.get_quantum_number_deltas(atom) for atom in range(n_atoms))

        ranges = page.system_config.get_ranges_dict()

        diagonalize_kwargs: dict[str, str] = {}
        if page.calculation_config.fast_mode.isChecked():
            # fast mode: single-precision LAPACKE solver
            diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
            diagonalize_kwargs["float_type"] = "float32"

        diagonalize_relative_energy_range = None
        if page.calculation_config.energy_range.isChecked():
            diagonalize_relative_energy_range = page.calculation_config.energy_range.values()

        return cls(
            species,
            quantum_numbers,
            quantum_number_deltas,
            ranges,
            diagonalize_kwargs,
            diagonalize_relative_energy_range,
            page.calculation_config.number_state_labels.value(default=0),
        )

    @property
    def is_real(self) -> bool:
        """Return True if the Ey and By ranges are zero at every step (missing ranges count as zero)."""
        return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))

    @property
    def steps(self) -> int:
        """Return the number of steps, taken from the first range (``ranges`` must not be empty)."""
        return len(next(iter(self.ranges.values())))

    @property
    def n_atoms(self) -> int:
        """Return the number of atoms, i.e. the number of species."""
        return len(self.species)

    def get_efield(self, step: int) -> list[float]:
        """Return ``[Ex, Ey, Ez]`` at the given step; components without a range are 0."""
        efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]

    def get_bfield(self, step: int) -> list[float]:
        """Return ``[Bx, By, Bz]`` at the given step; components without a range are 0."""
        bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]

    def get_species(self, atom: Optional[int] = None) -> str:
        """Return the species of the given atom (``atom`` may be omitted for a single atom)."""
        return self.species[self._check_atom(atom)]

    def get_quantum_numbers(self, atom: Optional[int] = None) -> dict[str, Any]:
        """Return the quantum numbers of the given atom (``atom`` may be omitted for a single atom)."""
        return self.quantum_numbers[self._check_atom(atom)]

    def get_quantum_number_restrictions(self, atom: Optional[int] = None) -> dict[str, Any]:
        """Return ``{name: (value - delta, value + delta)}`` for every quantum number delta of the atom.

        Raises:
            ValueError: If a delta is given for a quantum number that the atom's ket does not define.

        """
        atom = self._check_atom(atom)
        quantum_numbers = self.quantum_numbers[atom]
        qn_restrictions: dict[str, tuple[float, float]] = {}
        for key, delta in self.quantum_number_deltas[atom].items():
            if key not in quantum_numbers:
                raise ValueError(f"Quantum number delta {key} not found in quantum numbers.")
            qn_restrictions[key] = (quantum_numbers[key] - delta, quantum_numbers[key] + delta)
        return qn_restrictions

    def _check_atom(self, atom: Optional[int] = None) -> int:
        """Resolve the atom index: return ``atom`` if given, else 0 for a single-atom system.

        Raises:
            ValueError: If ``atom`` is None but there is more than one atom.

        """
        if atom is not None:
            return atom
        if self.n_atoms == 1:
            return 0
        raise ValueError("Atom index is required for multiple atoms")

    def get_diagonalize_energy_range(self, energy_of_interest: float) -> dict[str, Any]:
        """Return the ``energy_range``/``energy_unit`` kwargs around ``energy_of_interest`` (GHz), or {} if unset."""
        if self.diagonalize_relative_energy_range is None:
            return {}
        kwargs: dict[str, Any] = {"energy_unit": "GHz"}
        kwargs["energy_range"] = (
            energy_of_interest + self.diagonalize_relative_energy_range[0],
            energy_of_interest + self.diagonalize_relative_energy_range[1],
        )
        return kwargs

    def get_x_values(self) -> list[float]:
        """Return the x values for the plot: the range with the largest span."""
        max_key = self._get_ranges_max_diff_key()
        return self.ranges[max_key]

    def get_x_label(self) -> str:
        """Return the x-axis label for the plot, noting any other ranges that also change."""
        max_key = self._get_ranges_max_diff_key()
        x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"

        non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
        if non_constant_keys:
            x_label += f" ({', '.join(non_constant_keys)} did also change)"

        return x_label

    def _get_ranges_max_diff_key(self) -> RangesKeys:
        """Return the key whose range has the largest ``|last - first|`` (first such key on ties)."""
        range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
        return max(range_diffs, key=range_diffs.__getitem__)

    @staticmethod
    def _offset_expression(base: str, offset: float) -> str:
        """Return source text for ``base + offset`` with the sign folded into the operator."""
        if offset < 0:
            return f"{base} - {-offset}"
        return f"{base} + {offset}"

    def to_replacement_dict(self) -> dict[str, str]:
        """Return a mapping from ``$PLACEHOLDER`` names to source-code snippets describing these parameters."""
        max_key = self._get_ranges_max_diff_key()
        replacements: dict[str, str] = {
            "$PI_DTYPE": "real" if self.is_real else "complex",
            "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
            "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
        }

        for atom in range(self.n_atoms):
            replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
            replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
            replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
                self.get_quantum_number_restrictions(atom)
            )

        replacements["$STEPS"] = str(self.steps)
        for key, values in self.ranges.items():
            replacements[f"${key.upper()}_MIN"] = str(values[0])
            replacements[f"${key.upper()}_MAX"] = str(values[-1])
            if values[0] == values[-1]:
                replacements[f"${key.upper()}_VALUE"] = str(values[0])

        replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)

        if self.diagonalize_relative_energy_range is not None:
            r_energy = self.diagonalize_relative_energy_range
            # emit e.g. "ket_energy - 10" instead of "ket_energy + -10" for readable generated code
            lower_expr = self._offset_expression("ket_energy", r_energy[0])
            upper_expr = self._offset_expression("ket_energy", r_energy[1])
            replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
                f', energy_range=({lower_expr}, {upper_expr}), energy_unit="GHz"'
            )
        else:
            replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""

        return replacements
226 0 : return replacements
227 :
228 :
@dataclass
class Results(ABC):
    """Results of a calculation.

    Holds one array of eigenenergies (GHz, shifted by ``energy_offset``) and one array of ket overlaps per step,
    plus the eigenstate labels for a subset of evenly spaced steps.
    """

    energies: list["NDArray"]
    energy_offset: float
    ket_overlaps: list["NDArray"]
    # step index -> labels of all eigenstates at that step
    state_labels: dict[int, list[str]]

    @classmethod
    def from_calculate(
        cls,
        parameters: Parameters[Any],
        system_list: Union[
            list[pi_real.SystemPair], list[pi_complex.SystemPair], list[pi_real.SystemAtom], list[pi_complex.SystemAtom]
        ],
        ket: Union[_wrapped.KetAtom, tuple[_wrapped.KetAtom, ...]],
        energy_offset: float,
    ) -> "Self":
        """Build a Results object from diagonalized systems.

        Args:
            parameters: Supplies the number of steps and how many steps receive state labels.
            system_list: One diagonalized system per step.
            ket: The ket(s) whose overlaps with each eigenbasis are recorded.
            energy_offset: Value (GHz) subtracted from every eigenenergy.

        Returns:
            The assembled Results instance.

        """
        shifted_energies = [system.get_eigenenergies("GHz") - energy_offset for system in system_list]
        overlaps = [system.get_eigenbasis().get_overlaps(ket) for system in system_list]  # type: ignore [arg-type]

        # evenly spaced step indices (truncated to int); duplicates collapse in the dict below
        label_steps = list(map(int, np.linspace(0, parameters.steps - 1, parameters.number_state_labels)))
        eigenstates_at_step = {step: system_list[step].get_eigenbasis().states for step in label_steps}
        labels_at_step = {
            step: [state.get_label() for state in eigenstates] for step, eigenstates in eigenstates_at_step.items()
        }

        return cls(
            energies=shifted_energies,
            energy_offset=energy_offset,
            ket_overlaps=overlaps,
            state_labels=labels_at_step,
        )
254 1 : return cls(energies, energy_offset, ket_overlaps, state_labels)
255 :
256 :
# Characters that must be escaped inside a double-quoted (non-raw) Python string literal.
_STRING_LITERAL_ESCAPES = str.maketrans({"\\": "\\\\", '"': '\\"', "\n": "\\n", "\r": "\\r"})


def as_string(value: str, *, raw_string: bool = False) -> str:
    """Return ``value`` as a double-quoted Python string literal for use in generated code.

    Args:
        value: The text the literal should evaluate to.
        raw_string: If True, prefer a raw literal (``r"..."``) so that backslashes, e.g. in LaTeX markup,
            are written verbatim.

    Returns:
        Source text of a string literal that evaluates to ``value``. A raw literal cannot contain a double
        quote or a line break and cannot end with a backslash; in those cases an escaped regular literal
        is returned even if ``raw_string`` is True.

    """
    if raw_string and not value.endswith("\\") and not any(char in value for char in '"\n\r'):
        return 'r"' + value + '"'
    return '"' + value.translate(_STRING_LITERAL_ESCAPES) + '"'
262 :
263 :
def dict_to_repl(d: dict[str, Any]) -> str:
    """Render ``d`` as a keyword-argument suffix for generated code.

    Every item becomes ``, key=value``. String values are turned into quoted literals with ``as_string``;
    any other value is inserted via plain f-string formatting. An empty dict yields an empty string.
    """
    fragments = (f", {key}={as_string(val) if isinstance(val, str) else val}" for key, val in d.items())
    return "".join(fragments)
|