Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import logging
5 1 : from abc import ABC
6 1 : from collections.abc import Mapping
7 1 : from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
8 :
9 1 : import numpy as np
10 1 : from attr import dataclass
11 :
12 1 : from pairinteraction import (
13 : _wrapped,
14 : complex as pi_complex,
15 : real as pi_real,
16 : )
17 1 : from pairinteraction_gui.config.system_config import RangesKeys
18 :
19 : if TYPE_CHECKING:
20 : from typing_extensions import Self
21 :
22 : from pairinteraction.units import NDArray
23 : from pairinteraction_gui.config.basis_config import QuantumNumberRestrictions
24 : from pairinteraction_gui.config.ket_config import QuantumNumbers
25 : from pairinteraction_gui.page import OneAtomPage, TwoAtomsPage
26 :
logger = logging.getLogger(__name__)

# Unit appended to a range key when building the plot's x label.
# The raw strings contain LaTeX-style math markup ($...$).
UnitFromRangeKey: dict[RangesKeys, str] = {
    # electric field components
    "Ex": "V/cm",
    "Ey": "V/cm",
    "Ez": "V/cm",
    # magnetic field components
    "Bx": "Gauss",
    "By": "Gauss",
    "Bz": "Gauss",
    # distance and angle
    "Distance": r"$\mu$m",
    "Angle": r"$^\circ$",
}

# Variable name substituted for the $X_VARIABLE_NAME template placeholder,
# keyed by the range key that is plotted on the x axis.
VariableNameFromRangeKey: dict[RangesKeys, str] = {
    # electric field components
    "Ex": "efield_x",
    "Ey": "efield_y",
    "Ez": "efield_z",
    # magnetic field components
    "Bx": "bfield_x",
    "By": "bfield_y",
    "Bz": "bfield_z",
    # distance and angle
    "Distance": "distance",
    "Angle": "angle",
}

# Constrained type variable over the two page kinds. The page classes are only
# imported under TYPE_CHECKING, so they are referenced by name here.
PageType = TypeVar("PageType", "OneAtomPage", "TwoAtomsPage")
52 :
53 :
@dataclass
class Parameters(ABC, Generic[PageType]):
    """Calculation parameters collected from a GUI page.

    The tuple-valued fields hold one entry per atom. ``ranges`` maps every swept
    quantity to its values, one value per step.
    """

    species: tuple[str, ...]
    quantum_numbers: tuple["QuantumNumbers", ...]
    quantum_number_restrictions: tuple["QuantumNumberRestrictions", ...]
    ranges: dict[RangesKeys, list[float]]
    diamagnetism_enabled: bool
    # Extra keyword arguments for the diagonalization (e.g. "diagonalizer", "float_type").
    diagonalize_kwargs: dict[str, str]
    # (lower, upper) energy window in GHz relative to an energy of interest, or None.
    diagonalize_relative_energy_range: Union[tuple[float, float], None]
    # Number of (evenly spaced) steps for which state labels are generated.
    number_state_labels: int

    def __attrs_post_init__(self) -> None:
        """Run :meth:`__post_init__` after the attrs-generated ``__init__``.

        ``dataclass`` is ``attr.dataclass`` here, which only invokes
        ``__attrs_post_init__`` and never ``__post_init__``. Without this hook the
        validation below was silently skipped.
        """
        self.__post_init__()

    def __post_init__(self) -> None:
        """Validate that the ranges and the per-atom tuples are consistent.

        Raises:
            ValueError: If the ranges do not all have the same number of steps, or
                if the per-atom tuples do not all have one entry per atom.

        """
        # Check if all ranges have the same number of steps
        if not all(len(v) == self.steps for v in self.ranges.values()):
            raise ValueError("All ranges must have the same number of steps")

        # Check if all tuples have the same length
        per_atom_fields = (self.species, self.quantum_numbers, self.quantum_number_restrictions)
        if not all(len(tup) == self.n_atoms for tup in per_atom_fields):
            raise ValueError("All tuples must have the same length as the number of atoms")

    @classmethod
    def from_page(cls, page: PageType) -> "Self":
        """Create a Parameters object from the current state of the page's widgets."""
        n_atoms = page.ket_config.n_atoms

        species = tuple(page.ket_config.get_species(atom) for atom in range(n_atoms))
        quantum_numbers = tuple(page.ket_config.get_quantum_numbers(atom) for atom in range(n_atoms))

        quantum_number_restrictions = tuple(
            page.basis_config.get_quantum_number_restrictions(atom) for atom in range(n_atoms)
        )

        ranges = page.system_config.get_ranges_dict()
        diamagnetism_enabled = page.system_config.diamagnetism.isChecked()

        # "Fast mode" selects a specific diagonalizer with single precision.
        diagonalize_kwargs = {}
        if page.calculation_config.fast_mode.isChecked():
            diagonalize_kwargs["diagonalizer"] = "lapacke_evr"
            diagonalize_kwargs["float_type"] = "float32"

        diagonalize_relative_energy_range = None
        if page.calculation_config.energy_range.isChecked():
            diagonalize_relative_energy_range = page.calculation_config.energy_range.values()

        return cls(
            species,
            quantum_numbers,
            quantum_number_restrictions,
            ranges,
            diamagnetism_enabled,
            diagonalize_kwargs,
            diagonalize_relative_energy_range,
            page.calculation_config.number_state_labels.value(default=0),
        )

    @property
    def is_real(self) -> bool:
        """Return True if Ey and By are zero at every step (missing keys count as zero).

        In that case the "real" variant of the library is used instead of "complex".
        """
        return all(e == 0 for e in self.ranges.get("Ey", [0])) and all(b == 0 for b in self.ranges.get("By", [0]))

    @property
    def steps(self) -> int:
        """Return the number of steps, taken from the first range."""
        return len(next(iter(self.ranges.values())))

    @property
    def n_atoms(self) -> int:
        """Return the number of atoms (one species per atom)."""
        return len(self.species)

    def get_efield(self, step: int) -> list[float]:
        """Return [Ex, Ey, Ez] at the given step; components without a range are 0."""
        efield_keys: list[RangesKeys] = ["Ex", "Ey", "Ez"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in efield_keys]

    def get_bfield(self, step: int) -> list[float]:
        """Return [Bx, By, Bz] at the given step; components without a range are 0."""
        bfield_keys: list[RangesKeys] = ["Bx", "By", "Bz"]
        return [self.ranges[key][step] if key in self.ranges else 0 for key in bfield_keys]

    def get_species(self, atom: Optional[int] = None) -> str:
        """Return the species of the given atom (atom may be omitted for a single atom)."""
        return self.species[self._check_atom(atom)]

    def get_quantum_numbers(self, atom: Optional[int] = None) -> "QuantumNumbers":
        """Return the quantum numbers of the given atom (atom may be omitted for a single atom)."""
        return self.quantum_numbers[self._check_atom(atom)]

    def get_ket_atom(self, atom: Optional[int] = None) -> _wrapped.KetAtom:
        """Return the KetAtom built from the species and quantum numbers of the given atom."""
        return _wrapped.KetAtom(self.get_species(atom), **self.get_quantum_numbers(atom))

    def get_quantum_number_restrictions(self, atom: Optional[int] = None) -> "QuantumNumberRestrictions":
        """Return the quantum number restrictions of the given atom."""
        return self.quantum_number_restrictions[self._check_atom(atom)]

    def _check_atom(self, atom: Optional[int] = None) -> int:
        """Resolve the atom index; None is only allowed when there is exactly one atom.

        Raises:
            ValueError: If ``atom`` is None and there is more than one atom.

        """
        if atom is not None:
            return atom
        if self.n_atoms == 1:
            return 0
        raise ValueError("Atom index is required for multiple atoms")

    def get_diagonalize_energy_range_kwargs(self, energy_of_interest: float) -> dict[str, Any]:
        """Return the energy-range kwargs for the diagonalization.

        The relative window is shifted by ``energy_of_interest`` (in GHz). An empty
        dict is returned when no window was selected.
        """
        if self.diagonalize_relative_energy_range is None:
            return {}
        kwargs: dict[str, Any] = {"energy_range_unit": "GHz"}
        kwargs["energy_range"] = (
            energy_of_interest + self.diagonalize_relative_energy_range[0],
            energy_of_interest + self.diagonalize_relative_energy_range[1],
        )
        return kwargs

    def get_x_values(self) -> list[float]:
        """Return the x values for the plot (the range that changes the most)."""
        max_key = self._get_ranges_max_diff_key()
        return self.ranges[max_key]

    def get_x_label(self) -> str:
        """Return the x label for the plot, mentioning other ranges that also change."""
        max_key = self._get_ranges_max_diff_key()
        x_label = f"{max_key} [{UnitFromRangeKey[max_key]}]"

        non_constant_keys = [key for key, values in self.ranges.items() if key != max_key and values[0] != values[-1]]
        if non_constant_keys:
            x_label += f" ({', '.join(non_constant_keys)} did also change)"

        return x_label

    def _get_ranges_max_diff_key(self) -> RangesKeys:
        """Return the key whose range has the largest |last - first| (first key wins ties).

        Raises:
            ValueError: If there are no ranges.

        """
        range_diffs: dict[RangesKeys, float] = {key: abs(r[-1] - r[0]) for key, r in self.ranges.items()}
        return max(range_diffs, key=range_diffs.__getitem__)

    def to_replacement_dict(self) -> dict[str, str]:
        """Return the mapping from ``$`` placeholders to text used to fill code templates.

        Per-atom placeholders carry the atom index as suffix. Per-range placeholders
        are built from the upper-cased range key.
        """
        max_key = self._get_ranges_max_diff_key()
        replacements: dict[str, str] = {
            "$PI_DTYPE": "real" if self.is_real else "complex",
            "$X_VARIABLE_NAME": VariableNameFromRangeKey[max_key],
            "$X_LABEL": as_string(self.get_x_label(), raw_string=True),
            "$DIAMAGNETISM_ENABLED": str(self.diamagnetism_enabled),
        }

        for atom in range(self.n_atoms):
            replacements[f"$SPECIES_{atom}"] = as_string(self.get_species(atom))
            replacements[f"$QUANTUM_NUMBERS_{atom}"] = dict_to_repl(self.get_quantum_numbers(atom))
            replacements[f"$QUANTUM_NUMBERS_RESTRICTIONS_{atom}"] = dict_to_repl(
                self.get_quantum_number_restrictions(atom)
            )

        replacements["$STEPS"] = str(self.steps)
        for key, values in self.ranges.items():
            replacements[f"${key.upper()}_MIN"] = str(values[0])
            replacements[f"${key.upper()}_MAX"] = str(values[-1])
            # A *_VALUE placeholder is only provided when first and last values coincide.
            if values[0] == values[-1]:
                replacements[f"${key.upper()}_VALUE"] = str(values[0])

        replacements["$DIAGONALIZE_KWARGS"] = dict_to_repl(self.diagonalize_kwargs)

        if self.diagonalize_relative_energy_range is not None:
            r_energy = self.diagonalize_relative_energy_range
            # NOTE(review): "ket_energy - {-upper}" evaluates to ket_energy + upper; the
            # rendered text is kept unchanged so generated scripts stay identical.
            replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = (
                f', energy_range=(ket_energy + {r_energy[0]}, ket_energy - {-r_energy[1]}), energy_range_unit="GHz"'
            )
        else:
            replacements["$DIAGONALIZE_ENERGY_RANGE_KWARGS"] = ""

        return replacements
228 :
229 :
@dataclass
class Results(ABC):
    """Diagonalization results collected over all steps of a parameter sweep."""

    # Eigenenergies per step in GHz, with energy_offset already subtracted.
    energies: list["NDArray"]
    # The offset (GHz) that was subtracted from every eigenenergy.
    energy_offset: float
    # Overlaps of each step's eigenbasis with the ket of interest.
    ket_overlaps: list["NDArray"]
    # Step index -> labels of the eigenstates, only for the labelled steps.
    state_labels: dict[int, list[str]]

    @classmethod
    def from_calculate(
        cls,
        parameters: Parameters[Any],
        system_list: Union[
            list[pi_real.SystemPair], list[pi_complex.SystemPair], list[pi_real.SystemAtom], list[pi_complex.SystemAtom]
        ],
        ket: Union[_wrapped.KetAtom, tuple[_wrapped.KetAtom, ...]],
        energy_offset: float,
    ) -> "Self":
        """Build a Results object from the diagonalized systems of every step.

        Args:
            parameters: Provides the number of steps and how many of them get labels.
            system_list: One diagonalized system per step.
            ket: The ket whose overlaps with the eigenstates are recorded.
            energy_offset: Subtracted from all eigenenergies (GHz).

        """
        shifted_energies = [sys.get_eigenenergies("GHz") - energy_offset for sys in system_list]
        overlaps = [sys.get_eigenbasis().get_overlaps(ket) for sys in system_list]  # type: ignore [arg-type]

        # Evenly spaced (truncated) step indices for which state labels are produced.
        sample_points = np.linspace(0, parameters.steps - 1, parameters.number_state_labels)
        labelled_steps = [int(point) for point in sample_points]

        eigenstates_by_step: dict[int, Any] = {}
        for step in labelled_steps:
            eigenstates_by_step[step] = system_list[step].get_eigenbasis().states

        labels_by_step: dict[int, list[str]] = {}
        for step, eigenstates in eigenstates_by_step.items():
            labels_by_step[step] = [state.get_label() for state in eigenstates]

        return cls(
            energies=shifted_energies,
            energy_offset=energy_offset,
            ket_overlaps=overlaps,
            state_labels=labels_by_step,
        )
256 :
257 :
def as_string(value: str, *, raw_string: bool = False) -> str:
    """Return ``value`` as a double-quoted Python string literal for generated code.

    Args:
        value: The text the literal must evaluate to.
        raw_string: Prefer a raw literal (``r"..."``), which keeps backslashes in
            e.g. LaTeX markup readable. If a raw literal cannot represent ``value``
            (it contains a double quote or line break, or ends in a backslash), an
            escaped regular literal is returned instead.

    Returns:
        Source text that evaluates to exactly ``value``.

    """
    # A raw literal cannot contain the closing quote, a line break, or end in a backslash.
    raw_is_valid = '"' not in value and "\n" not in value and "\r" not in value and not value.endswith("\\")
    if raw_string and raw_is_valid:
        return 'r"' + value + '"'
    # Escape the backslash first so the escapes added afterwards are not doubled.
    escaped = value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n").replace("\r", "\\r")
    return '"' + escaped + '"'
263 :
264 :
def dict_to_repl(d: Mapping[str, Any]) -> str:
    """Render ``d`` as a keyword-argument suffix for generated code.

    Each item becomes ``, key=value``. String values are quoted with
    ``as_string``; other values use their default string formatting. An empty
    mapping yields ``""``.
    """
    return "".join(f", {k}={as_string(v) if isinstance(v, str) else v}" for k, v in d.items())
|