Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 1 : from __future__ import annotations
4 :
5 1 : import contextlib
6 1 : import copy
7 1 : import logging
8 1 : from functools import cached_property, lru_cache
9 1 : from typing import TYPE_CHECKING, Literal, overload
10 :
11 1 : import numpy as np
12 1 : from scipy import sparse
13 :
14 1 : from pairinteraction.basis import BasisAtom, BasisAtomReal, BasisPair, BasisPairReal
15 1 : from pairinteraction.diagonalization import diagonalize
16 1 : from pairinteraction.perturbative.perturbation_theory import calculate_perturbative_hamiltonian
17 1 : from pairinteraction.system import SystemAtom, SystemAtomReal, SystemPair, SystemPairReal
18 1 : from pairinteraction.units import QuantityArray, QuantityScalar
19 :
20 : if TYPE_CHECKING:
21 : from collections.abc import Sequence
22 :
23 : from scipy.sparse import csr_matrix
24 : from typing_extensions import Self
25 :
26 : from pairinteraction.ket import KetAtomTuple
27 : from pairinteraction.units import ArrayLike, NDArray, PintArray, PintFloat
28 :
29 :
# Logger of this module (named after the module itself).
logger = logging.getLogger(name=__name__)

# Identifiers of the lazily created parts of an EffectiveSystemPair; used to track which parts
# have been created and which have been set explicitly by the user.
BasisSystemLiteral = Literal["basis_atoms", "system_atoms", "basis_pair", "system_pair"]
33 :
34 :
class EffectiveSystemPair:
    """Class for creating an effective SystemPair object and calculating the effective Hamiltonian.

    Given a subspace spanned by tuples of `KetAtom` objects (ket_tuples),
    this class automatically generates appropriate `BasisAtom`, `SystemAtom` objects as well as a `BasisPair` and
    `SystemPair` object to calculate the effective Hamiltonian in the subspace via perturbation theory.

    This class also allows setting magnetic and electric fields similar to the `SystemAtom` class,
    as well as the angle and distance between the two atoms like in the `SystemPair` class.

    The parts ``basis_atoms``, ``system_atoms``, ``basis_pair`` and ``system_pair`` are created lazily on first
    access. Changing a parameter deletes the already created parts affected by it; if such a part was set
    explicitly by the user, a RuntimeError is raised instead.

    Examples:
        >>> import pairinteraction as pi
        >>> ket_atoms = {
        ...     "+": pi.KetAtom("Rb", n=59, l=0, j=0.5, m=0.5),
        ...     "0": pi.KetAtom("Rb", n=58, l=1, j=1.5, m=1.5),
        ...     "-": pi.KetAtom("Rb", n=58, l=0, j=0.5, m=0.5),
        ... }
        >>> ket_tuples = [
        ...     (ket_atoms["+"], ket_atoms["-"]),
        ...     (ket_atoms["0"], ket_atoms["0"]),
        ...     (ket_atoms["-"], ket_atoms["+"]),
        ... ]
        >>> eff_system = pi.EffectiveSystemPair(ket_tuples)
        >>> eff_system = eff_system.set_distance(10, angle_degree=45, unit="micrometer")
        >>> eff_h = eff_system.get_effective_hamiltonian(unit="MHz")
        >>> eff_h -= np.eye(3) * eff_system.get_pair_energies("MHz")[1]
        >>> print(np.round(eff_h, 0), "MHz")
        [[292.   3.   0.]
         [  3.   0.   3.]
         [  0.   3. 292.]] MHz

    """

    _basis_atom_class = BasisAtom
    _basis_pair_class = BasisPair
    _system_atom_class = SystemAtom
    _system_pair_class = SystemPair

    def __init__(self, ket_tuples: Sequence[KetAtomTuple]) -> None:
        """Initialize the effective system from the ket tuples spanning the model space.

        Args:
            ket_tuples: Tuples of two `KetAtom` objects, each defining one pair state of the model space.
                All first kets must share one species, and all second kets must share one species.

        Raises:
            ValueError: If no ket tuple is given, a tuple does not contain exactly two kets,
                or the kets of one atom have different species.

        """
        # Materialize once, so that any iterable is accepted and can be traversed repeatedly.
        tuples = [tuple(kets) for kets in ket_tuples]
        if not tuples:
            raise ValueError("At least one ket tuple must be given")
        if not all(len(ket_tuple) == 2 for ket_tuple in tuples):
            raise ValueError("All ket tuples must contain exactly two kets")
        for i in range(2):
            species = tuples[0][i].species
            if not all(ket_tuple[i].species == species for ket_tuple in tuples):
                raise ValueError(f"All kets for atom={i} must have the same species")

        # Perturbation attributes
        self._ket_tuples = tuples
        self._perturbation_order = 2

        # BasisAtom and SystemAtom attributes (None means "use the default")
        self._delta_n: int | None = None
        self._delta_l: int | None = None
        self._delta_m: int | None = None
        self._electric_field: PintArray | None = None
        self._magnetic_field: PintArray | None = None
        self._diamagnetism_enabled: bool | None = None

        # BasisPair and SystemPair attributes (None means "use the default")
        self._minimum_number_of_ket_pairs: int | None = None
        self._maximum_number_of_ket_pairs: int | None = None
        self._interaction_order: int | None = None
        self._distance_vector: PintArray | None = None

        # Results of the perturbative calculation (None means "not computed yet")
        self._eff_h_dict_au: dict[int, NDArray] | None = None
        self._eff_vecs: csr_matrix | None = None

        # Parts that were set explicitly by the user and therefore must not be deleted/recreated
        self._user_set_parts: set[BasisSystemLiteral] = set()

    def copy(self: Self) -> Self:
        """Create a copy of the EffectiveSystemPair object (before it has been created).

        Raises:
            RuntimeError: If the single-atom bases have already been created.

        """
        if self._is_created("basis_atoms"):
            raise RuntimeError(
                "Cannot copy the EffectiveSystemPair object after it has been created. "
                "Please create a new object instead."
            )
        new = copy.copy(self)
        # copy.copy is shallow: give the copy its own mutable containers, otherwise e.g. setting basis_atoms
        # on one object would mark it as user-set on the other object as well.
        new._user_set_parts = set(self._user_set_parts)
        new._ket_tuples = list(self._ket_tuples)
        return new

    def _is_created(self: Self, what: BasisSystemLiteral = "basis_atoms") -> bool:
        """Check if some part of the effective Hamiltonian has already been created."""
        return hasattr(self, "_" + what)

    def _ensure_not_created(self: Self, what: BasisSystemLiteral = "basis_atoms") -> None:
        """Ensure that some part of the effective Hamiltonian has not been created yet.

        Raises:
            RuntimeError: If the part ``what`` already exists.

        """
        if self._is_created(what):
            raise RuntimeError(
                f"Cannot change parameters for {what} after it has already been created. "
                f"Please set all parameters for {what} before accessing it (or creating the effective Hamiltonian)."
            )

    def _delete_created(self: Self, what: BasisSystemLiteral = "basis_atoms") -> None:
        """Delete the created part of the effective Hamiltonian and all parts built on top of it.

        Args:
            what: The part of the effective Hamiltonian to delete.
                Default is "basis_atoms", which means delete all parts that have been created.

        Raises:
            RuntimeError: If one of the parts to delete was set by the user. In this case nothing is deleted.

        """
        # Parts ordered from the most derived to the most basic one; deleting `what` also deletes everything
        # that was built from it.
        parts_order: list[BasisSystemLiteral] = ["system_pair", "basis_pair", "system_atoms", "basis_atoms"]
        parts_to_delete = parts_order[: parts_order.index(what) + 1]

        # Validate first, so that a refused deletion leaves the object untouched.
        for part in parts_to_delete:
            if part in self._user_set_parts:
                raise RuntimeError(
                    f"Cannot delete {part} because it has been set by the user. "
                    "Please create a new EffectiveSystemPair object instead."
                )

        self._eff_h_dict_au = None
        self._eff_vecs = None
        with contextlib.suppress(AttributeError):
            del self.model_inds  # reset the cached_property

        for part in parts_to_delete:
            with contextlib.suppress(AttributeError):
                delattr(self, "_" + part)

    # # # Perturbation methods and attributes # # #
    @property
    def ket_tuples(self) -> list[KetAtomTuple]:
        """The tuples of kets, which form the model space for the effective Hamiltonian."""
        return self._ket_tuples  # type: ignore [return-value]

    @property
    def perturbation_order(self) -> int:
        """The perturbation order for the effective Hamiltonian."""
        return self._perturbation_order

    def set_perturbation_order(self: Self, order: int) -> Self:
        """Set the perturbation order for the effective Hamiltonian."""
        # The default single-atom basis size depends on the perturbation order, so everything is recreated.
        self._delete_created()
        self._perturbation_order = order
        return self

    # # # BasisAtom methods and attributes # # #
    @property
    def basis_atoms(self) -> tuple[BasisAtom, BasisAtom]:
        """The basis objects for the single-atom systems."""
        if not self._is_created("basis_atoms"):
            self._create_basis_atoms()
        return self._basis_atoms  # type: ignore [return-value]

    @basis_atoms.setter
    def basis_atoms(self, basis_atoms: tuple[BasisAtom, BasisAtom]) -> None:
        self._ensure_not_created()
        if self._delta_n is not None or self._delta_l is not None or self._delta_m is not None:
            logger.warning("Setting basis_atoms will overwrite parameters defined for basis_atoms.")
        self._user_set_parts.add("basis_atoms")
        self._basis_atoms = tuple(basis_atoms)

    def set_delta_n(self: Self, delta_n: int) -> Self:
        """Set the delta_n value for single-atom basis."""
        self._delete_created()
        self._delta_n = delta_n
        return self

    def set_delta_l(self: Self, delta_l: int) -> Self:
        """Set the delta_l value for single-atom basis."""
        self._delete_created()
        self._delta_l = delta_l
        return self

    def set_delta_m(self: Self, delta_m: int) -> Self:
        """Set the delta_m value for single-atom basis."""
        self._delete_created()
        self._delta_m = delta_m
        return self

    def _create_basis_atoms(self) -> None:
        """Create one single-atom basis per atom, enclosing the kets of the model space plus a margin."""
        delta_n = self._delta_n if self._delta_n is not None else 7
        delta_l = self._delta_l
        if delta_l is None:
            delta_l = self.perturbation_order * (self.interaction_order - 2)
        delta_m = self._delta_m
        # m is only restricted automatically if neither delta_l nor delta_m was set explicitly and the fields
        # point along z; otherwise m_range stays None.
        if delta_m is None and self._delta_l is None and self._are_fields_along_z:
            delta_m = self.perturbation_order * (self.interaction_order - 2)

        # The configured basis class decides whether the cached real variant is used.
        use_real = issubclass(self._basis_atom_class, BasisAtomReal)
        basis_atoms: list[BasisAtom] = []
        for i in range(2):
            kets = [ket_tuple[i] for ket_tuple in self.ket_tuples]
            nlfm = np.transpose([[ket.n, ket.l, ket.f, ket.m] for ket in kets])
            n_range = (int(np.min(nlfm[0])) - delta_n, int(np.max(nlfm[0])) + delta_n)
            l_range = (np.min(nlfm[1]) - delta_l, np.max(nlfm[1]) + delta_l)
            if any(ket.is_calculated_with_mqdt for ket in kets) and self._delta_l is None:
                # for mqdt we increase the default delta_l by 1 to take into account the variance of l
                l_range = (np.min(nlfm[1]) - delta_l - 1, np.max(nlfm[1]) + delta_l + 1)
            m_range = (np.min(nlfm[3]) - delta_m, np.max(nlfm[3]) + delta_m) if delta_m is not None else None
            basis = get_basis_atom_with_cache(kets[0].species, n_range, l_range, m_range, use_real=use_real)
            basis_atoms.append(basis)

        self._basis_atoms = tuple(basis_atoms)

    # # # SystemAtom methods and attributes # # #
    @property
    def system_atoms(self) -> tuple[SystemAtom, SystemAtom]:
        """The system objects for the single-atom systems."""
        if not self._is_created("system_atoms"):
            self._create_system_atoms()
        return self._system_atoms

    @system_atoms.setter
    def system_atoms(self, system_atoms: tuple[SystemAtom, SystemAtom]) -> None:
        self._ensure_not_created()
        if (
            self._electric_field is not None
            or self._magnetic_field is not None
            or self._diamagnetism_enabled is not None
        ):
            logger.warning("Setting system_atoms will overwrite parameters defined for system_atoms.")
        self._user_set_parts.add("system_atoms")
        self._system_atoms: tuple[SystemAtom, SystemAtom] = tuple(system_atoms)  # type: ignore [assignment]
        self.basis_atoms = tuple(system.basis for system in system_atoms)  # type: ignore [assignment]

    @property
    def electric_field(self) -> PintArray:
        """The electric field for the single-atom systems (zero if not set)."""
        if self._electric_field is None:
            # Return the default without storing it via set_electric_field: that would delete (or, for
            # user-set parts, refuse to delete) already created parts just because the field was read.
            return QuantityArray.convert_user_to_pint([0, 0, 0], "V/cm", "electric_field")
        return self._electric_field

    def set_electric_field(
        self: Self,
        electric_field: PintArray | ArrayLike,
        unit: str | None = None,
    ) -> Self:
        """Set the electric field for the single-atom systems.

        Args:
            electric_field: The electric field to set for the systems.
            unit: The unit of the electric field, e.g. "V/cm".
                Default None expects a `pint.Quantity`.

        """
        self._delete_created()
        self._electric_field = QuantityArray.convert_user_to_pint(electric_field, unit, "electric_field")
        return self

    @property
    def magnetic_field(self) -> PintArray:
        """The magnetic field for the single-atom systems (zero if not set)."""
        if self._magnetic_field is None:
            # Default is returned without storing it, see `electric_field`.
            return QuantityArray.convert_user_to_pint([0, 0, 0], "gauss", "magnetic_field")
        return self._magnetic_field

    def set_magnetic_field(
        self: Self,
        magnetic_field: PintArray | ArrayLike,
        unit: str | None = None,
    ) -> Self:
        """Set the magnetic field for the single-atom systems.

        Args:
            magnetic_field: The magnetic field to set for the systems.
            unit: The unit of the magnetic field, e.g. "gauss".
                Default None expects a `pint.Quantity`.

        """
        self._delete_created()
        self._magnetic_field = QuantityArray.convert_user_to_pint(magnetic_field, unit, "magnetic_field")
        return self

    @property
    def _are_fields_along_z(self) -> bool:
        """Whether the x and y components of both the magnetic and the electric field vanish."""
        return all(x == 0 for x in [*self.magnetic_field[:2], *self.electric_field[:2]])  # type: ignore [index]

    @property
    def diamagnetism_enabled(self) -> bool:
        """Whether diamagnetism is enabled for the single-atom systems (False if not set)."""
        if self._diamagnetism_enabled is None:
            return False
        return self._diamagnetism_enabled

    def set_diamagnetism_enabled(self: Self, enable: bool = True) -> Self:
        """Enable or disable diamagnetism for the system.

        Args:
            enable: Whether to enable or disable diamagnetism.

        """
        self._delete_created("system_atoms")
        self._diamagnetism_enabled = enable
        return self

    def _create_system_atoms(self) -> None:
        """Create and diagonalize one single-atom system per single-atom basis."""
        system_atoms: list[SystemAtom] = []
        for basis_atom in self.basis_atoms:
            system = self._system_atom_class(basis_atom)
            system.set_diamagnetism_enabled(self.diamagnetism_enabled)
            system.set_electric_field(self.electric_field)
            system.set_magnetic_field(self.magnetic_field)
            system_atoms.append(system)
        diagonalize(system_atoms)

        self._system_atoms = tuple(system_atoms)  # type: ignore [assignment]

    @overload
    def get_pair_energies(self, unit: None = None) -> list[PintFloat]: ...

    @overload
    def get_pair_energies(self, unit: str) -> list[float]: ...

    def get_pair_energies(self, unit: str | None = None) -> list[float] | list[PintFloat]:
        """Get the pair energies of the ket tuples for infinite distance (i.e. no interaction).

        Args:
            unit: The unit to which to convert the energies to.
                Default None will return a list of `pint.Quantity`.

        Returns:
            The energies as list of float if a unit was given, otherwise as list of `pint.Quantity`.

        """
        return [  # type: ignore [return-value]
            sum(
                system.get_corresponding_energy(ket, unit=unit)
                for system, ket in zip(self.system_atoms, ket_tuple, strict=True)
            )
            for ket_tuple in self.ket_tuples
        ]

    # # # BasisPair methods and attributes # # #
    @property
    def basis_pair(self) -> BasisPair:
        """The basis pair object for the pair system."""
        if not self._is_created("basis_pair"):
            self._create_basis_pair()
        return self._basis_pair

    @basis_pair.setter
    def basis_pair(self, basis_pair: BasisPair) -> None:
        self._ensure_not_created()
        if self._minimum_number_of_ket_pairs is not None or self._maximum_number_of_ket_pairs is not None:
            logger.warning("Setting basis_pair will overwrite parameters defined for basis_pair.")
        self._user_set_parts.add("basis_pair")
        self._basis_pair = basis_pair
        self.system_atoms = basis_pair.system_atoms

    def set_minimum_number_of_ket_pairs(self: Self, number_of_kets: int) -> Self:
        """Set the minimum number of ket pairs in the basis pair.

        Args:
            number_of_kets: The minimum number of ket pairs to set in the basis pair, by default we use 2000.

        Raises:
            ValueError: If the value is larger than the maximum number of ket pairs.

        """
        # Validate before deleting anything, so an invalid value leaves the object unchanged.
        if self._maximum_number_of_ket_pairs is not None and number_of_kets > self._maximum_number_of_ket_pairs:
            raise ValueError("The minimum number of ket pairs cannot be larger than the maximum number of ket pairs.")
        self._delete_created("basis_pair")
        self._minimum_number_of_ket_pairs = number_of_kets
        return self

    def set_maximum_number_of_ket_pairs(self: Self, number_of_kets: int) -> Self:
        """Set the maximum number of ket pairs in the basis pair.

        Args:
            number_of_kets: The maximum number of ket pairs to set in the basis pair.

        Raises:
            ValueError: If the value is smaller than the minimum number of ket pairs.

        """
        # Validate before deleting anything, so an invalid value leaves the object unchanged.
        if self._minimum_number_of_ket_pairs is not None and number_of_kets < self._minimum_number_of_ket_pairs:
            raise ValueError("The maximum number of ket pairs cannot be smaller than the minimum number of ket pairs.")
        self._delete_created("basis_pair")
        self._maximum_number_of_ket_pairs = number_of_kets
        return self

    def _create_basis_pair(self) -> None:
        """Create a pair basis in an energy window around the model-space pair energies.

        The half-width of the window is adapted by a bisection-like search until the number of kets lies between
        the minimum (default: 2000, capped by the maximum) and the maximum number of ket pairs (default: no limit).
        """
        max_number_of_kets: float = (
            np.inf if self._maximum_number_of_ket_pairs is None else self._maximum_number_of_ket_pairs
        )
        min_number_of_kets: float = (
            min(2_000, max_number_of_kets)
            if self._minimum_number_of_ket_pairs is None
            else self._minimum_number_of_ket_pairs
        )

        pair_energies_au = self.get_pair_energies(unit="hartree")
        min_energy_au = min(pair_energies_au)
        max_energy_au = max(pair_energies_au)

        mhz_au = QuantityScalar.convert_user_to_au(1, "MHz", "energy")
        delta_energy_au = 100 * mhz_au  # initial half-width of the energy window
        min_delta, max_delta = 0.1 * mhz_au, 1.0  # search bounds for the half-width (0.1 MHz ... 1 hartree)

        while True:
            basis_pair = self._basis_pair_class(
                self.system_atoms,
                energy=(min_energy_au - delta_energy_au, max_energy_au + delta_energy_au),
                energy_unit="hartree",
            )
            number_of_kets = basis_pair.number_of_kets

            if max_delta - min_delta < mhz_au:
                # The search interval collapsed below 1 MHz, so the half-width cannot be refined any further.
                if not min_number_of_kets <= number_of_kets <= max_number_of_kets:
                    logger.warning(
                        "Could not find a pair basis with between %s and %s kets, using a pair basis with %d kets.",
                        min_number_of_kets,
                        max_number_of_kets,
                        number_of_kets,
                    )
                break

            if number_of_kets < min_number_of_kets:
                # Too few kets: enlarge the window (at most doubling it, staying below max_delta).
                min_delta = delta_energy_au
                delta_energy_au = min(2 * delta_energy_au, (delta_energy_au + max_delta) / 2)
            elif number_of_kets > max_number_of_kets:
                # Too many kets: shrink the window (at most halving it, staying above min_delta).
                max_delta = delta_energy_au
                delta_energy_au = max(0.5 * delta_energy_au, (delta_energy_au + min_delta) / 2)
            else:
                break

        self._basis_pair = basis_pair
        logger.debug("The pair basis for the perturbative calculations consists of %d kets.", basis_pair.number_of_kets)

    # # # SystemPair methods and attributes # # #
    @property
    def system_pair(self) -> SystemPair:
        """The system pair object for the pair system."""
        if not self._is_created("system_pair"):
            self._create_system_pair()
        return self._system_pair

    @system_pair.setter
    def system_pair(self, system_pair: SystemPair) -> None:
        self._ensure_not_created()
        if self._interaction_order is not None or self._distance_vector is not None:
            logger.warning("Setting system_pair will overwrite parameters defined for system_pair.")
        self._user_set_parts.add("system_pair")
        self._system_pair = system_pair
        self.basis_pair = system_pair.basis

    @property
    def interaction_order(self) -> int:
        """The interaction order for the pair system (3 if not set)."""
        if self._interaction_order is None:
            # Default is returned without storing it, see `electric_field`.
            return 3
        return self._interaction_order

    def set_interaction_order(self: Self, order: int) -> Self:
        """Set the interaction order of the pair system.

        Args:
            order: The interaction order to set for the pair system.
                The order must be 3, 4, or 5.

        """
        # The default single-atom basis size depends on the interaction order, so everything is recreated.
        self._delete_created()
        self._interaction_order = order
        return self

    @property
    def distance_vector(self) -> PintArray:
        """The distance vector between the atoms in the pair system ([0, 0, inf] micrometer if not set)."""
        if self._distance_vector is None:
            # Default is returned without storing it, see `electric_field`.
            return QuantityArray.convert_user_to_pint([0, 0, np.inf], "micrometer", "distance")
        return self._distance_vector

    def set_distance(
        self: Self,
        distance: float | PintFloat,
        angle_degree: float = 0,
        unit: str | None = None,
    ) -> Self:
        """Set the distance between the atoms using the specified distance and angle.

        Args:
            distance: The distance to set between the atoms in the given unit.
            angle_degree: The angle between the distance vector and the z-axis in degrees.
                90 degrees corresponds to the x-axis.
                Defaults to 0, which corresponds to the z-axis.
            unit: The unit of the distance, e.g. "micrometer".
                Default None expects a `pint.Quantity`.

        """
        angle_rad = np.deg2rad(angle_degree)
        distance_vector = [np.sin(angle_rad) * distance, 0, np.cos(angle_rad) * distance]
        return self.set_distance_vector(distance_vector, unit)

    def set_distance_vector(
        self: Self,
        distance: ArrayLike | PintArray,
        unit: str | None = None,
    ) -> Self:
        """Set the distance vector between the atoms.

        Args:
            distance: The distance vector to set between the atoms in the given unit.
            unit: The unit of the distance, e.g. "micrometer".
                Default None expects a `pint.Quantity`.

        """
        self._delete_created("system_pair")
        self._distance_vector = QuantityArray.convert_user_to_pint(distance, unit, "distance")
        return self

    def set_angle(
        self: Self,
        angle: float = 0,
        unit: Literal["degree", "radian"] = "degree",
    ) -> Self:
        """Set the angle between the atoms in degrees, keeping the current distance.

        Args:
            angle: The angle between the distance vector and the z-axis (by default in degrees).
                90 degrees corresponds to the x-axis.
                Defaults to 0, which corresponds to the z-axis.
            unit: The unit of the angle, either "degree" or "radian", by default "degree".

        Raises:
            ValueError: If the unit is neither "degree" nor "radian".

        """
        # Raise instead of assert, so the check is not stripped when running with `python -O`.
        if unit not in ("radian", "degree"):
            raise ValueError(f"Unit {unit} is not supported for angle.")
        if unit == "radian":
            angle = np.rad2deg(angle)
        distance_mum: float = np.linalg.norm(self.distance_vector.to("micrometer").magnitude)  # type: ignore [assignment]
        return self.set_distance(distance_mum, angle, "micrometer")

    def _create_system_pair(self) -> None:
        """Create the pair system from the pair basis, the distance vector and the interaction order."""
        system_pair = self._system_pair_class(self.basis_pair)
        system_pair.set_distance_vector(self.distance_vector)
        system_pair.set_interaction_order(self.interaction_order)
        self._system_pair = system_pair

    # # # Effective Hamiltonian methods and attributes # # #
    @overload
    def get_effective_hamiltonian(self, return_order: int | None = None, unit: None = None) -> PintArray: ...

    @overload
    def get_effective_hamiltonian(self, return_order: int | None = None, *, unit: str) -> NDArray: ...

    def get_effective_hamiltonian(
        self, return_order: int | None = None, unit: str | None = None
    ) -> NDArray | PintArray:
        """Get the effective Hamiltonian of the pair system.

        Args:
            return_order: The order of the perturbation to return.
                Default None, returns the sum up to the perturbation order set in the class.
            unit: The unit in which to return the effective Hamiltonian.
                If None, returns a pint array.

        Returns:
            The effective Hamiltonian of the pair system in the given unit.
            If unit is None, returns a pint array, otherwise returns a numpy array.

        Raises:
            ValueError: If ``return_order`` is not contained in the computed orders.

        """
        if self._eff_h_dict_au is None:
            self._create_effective_hamiltonian()
        assert self._eff_h_dict_au is not None
        if return_order is None:
            h_eff_au: NDArray = sum(self._eff_h_dict_au.values())  # type: ignore [assignment]
        elif return_order in self._eff_h_dict_au:
            h_eff_au = self._eff_h_dict_au[return_order]
        else:
            raise ValueError(
                f"The perturbation order {return_order} is not available in the effective Hamiltonian "
                f"with the specified perturbation_order {self.perturbation_order}."
            )
        return QuantityArray.convert_au_to_user(np.real_if_close(h_eff_au), "energy", unit)

    def get_effective_basisvectors(self) -> csr_matrix:
        """Get the eigenvectors of the perturbative Hamiltonian."""
        if len(self.model_inds) > 1 and self.perturbation_order > 2:
            logger.warning("For more than one state and perturbation_order > 2 the effective basis might be wrong.")
        if self._eff_vecs is None:
            self._create_effective_hamiltonian()
        assert self._eff_vecs is not None
        return self._eff_vecs

    def get_effective_basis(self) -> BasisPair:
        """Get the effective basis of the pair system."""
        raise NotImplementedError("The get effective basis method is not implemented yet.")

    def _compute_effective_hamiltonian(self) -> None:
        """Run perturbation theory and store the effective Hamiltonian and eigenvectors (without any checks)."""
        hamiltonian_au = self.system_pair.get_hamiltonian(unit="hartree")
        eff_h_dict_au, eff_vecs = calculate_perturbative_hamiltonian(
            hamiltonian_au, self.model_inds, self.perturbation_order
        )
        self._eff_h_dict_au = eff_h_dict_au
        self._eff_vecs = eff_vecs

    def _create_effective_hamiltonian(self) -> None:
        """Calculate the perturbative Hamiltonian up to the given perturbation order and check for resonances."""
        self._compute_effective_hamiltonian()
        self.check_for_resonances()

    # # # Other stuff # # #
    @cached_property
    def model_inds(self) -> list[int]:
        """The indices of the corresponding KetPairs of the given ket_tuples in the basis of the pair system.

        Raises:
            ValueError: If a ket tuple has zero overlap with every state of the pair basis.

        """
        model_inds: list[int] = []
        for kets in self.ket_tuples:
            overlap = self.basis_pair.get_overlaps(kets)
            inds = np.argsort(overlap)[::-1]  # basis indices sorted by decreasing overlap
            best = int(inds[0])
            model_inds.append(best)

            if overlap[best] > 0.6:
                continue
            if overlap[best] == 0:
                raise ValueError(f"The pairstate {kets} is not part of the basis of the pair system.")
            logger.critical(
                "The ket_pair %s only has an overlap of %.3f with its corresponding pair_state."
                " The most perturbing states are:",
                kets,
                overlap[best],
            )
            # Look up the kets in the same basis the overlaps were computed in (avoids creating system_pair).
            basis_kets = self.basis_pair.kets
            for i in inds[1:5]:
                logger.error(" - %s with overlap %.3e", basis_kets[i], overlap[i])

        return model_inds

    def check_for_resonances(self, required_overlap: float = 0.95) -> None:
        r"""Check if states of the model space have strong resonances with states outside the model space.

        For every ket of the model space, the squared moduli of its effective eigenvector are inspected.
        Infinite entries are reported via ``logger.critical``. Otherwise the entries are normalized to sum 1 and,
        if the weight on the ket itself is below ``required_overlap``, the ket and its most admixed states are
        reported via ``logger.error``.

        Args:
            required_overlap: Minimal normalized weight of a ket in its own effective eigenvector.

        """
        # Compute the effective eigenvectors if needed, without running this check twice
        # (and without the warning of get_effective_basisvectors).
        if self._eff_vecs is None:
            self._compute_effective_hamiltonian()
        assert self._eff_vecs is not None
        eff_vecs = self._eff_vecs

        weights = eff_vecs.multiply(eff_vecs.conj()).real  # elementwise |c|^2

        for i, m_ind in enumerate(self.model_inds):
            row = np.asarray(weights[i, :].toarray()).ravel()

            inf_indices = np.flatnonzero(np.isinf(row))
            if inf_indices.size > 0:
                logger.critical(
                    "Detected 'inf' entries in the effective eigenvectors.\n"
                    " This might happen, if you forgot to include a degenerate state in the model space.\n"
                    " Consider adding the following states to the model space:"
                )
                kets = self.system_pair.basis.kets
                for index in inf_indices:
                    logger.critical(" - %s has infinite admixture", kets[index])
                continue

            row /= np.sum(row)  # normalize the overlaps to 1
            own_overlap = row[m_ind]
            if own_overlap >= required_overlap:
                continue

            kets = self.system_pair.basis.kets
            logger.error(
                "The ket %s has only %.3f overlap with its corresponding effective eigenvector.\n"
                " Thus, the calculation might lead to unexpected or wrong results.\n"
                " Consider adding the most perturbing states to the model space.\n"
                " The most perturbing states are:",
                kets[m_ind],
                own_overlap,
            )
            # Only list states carrying at least 5% of the missing weight, strongest first.
            print_above_admixture = (1 - own_overlap) * 0.05
            candidates = np.flatnonzero(row >= print_above_admixture)
            for index in sorted(candidates, key=row.__getitem__, reverse=True):
                if index != m_ind:
                    logger.error(" - %s with overlap %.3e", kets[index], row[index])
678 :
679 :
class EffectiveSystemPairReal(EffectiveSystemPair):
    """`EffectiveSystemPair` variant that builds its bases and systems from the `*Real` classes.

    Only the classes used to construct the single-atom and pair bases/systems are swapped;
    all behavior is inherited from `EffectiveSystemPair`.
    """

    # pair-level classes
    _system_pair_class = SystemPairReal
    _basis_pair_class = BasisPairReal
    # single-atom classes
    _system_atom_class = SystemAtomReal
    _basis_atom_class = BasisAtomReal
685 :
686 :
@lru_cache(maxsize=20)
def get_basis_atom_with_cache(
    species: str,
    n: tuple[int, int],
    l: tuple[float, float],
    m: tuple[float, float] | None,
    *,
    use_real: bool,
) -> BasisAtom:
    """Get a BasisAtom object, reusing a cached one for identical arguments.

    Up to 20 bases are kept in an LRU cache keyed by all arguments, so the arguments must be hashable.

    Args:
        species: The atomic species, e.g. "Rb".
        n: Range (min, max) of the principal quantum number.
        l: Range (min, max) of the orbital angular momentum quantum number.
        m: Range (min, max) of the magnetic quantum number, or None. None is passed through unchanged
            to the basis constructor (presumably meaning that m is not restricted).
        use_real: Whether to create a `BasisAtomReal` instead of a `BasisAtom`.

    Returns:
        The (possibly cached) basis object.

    """
    basis_class = BasisAtomReal if use_real else BasisAtom
    return basis_class(species, n=n, l=l, m=m)
|