Line data Source code
1 : # SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : # ruff: noqa: E402
5 1 : from __future__ import annotations
6 :
7 1 : import os
8 1 : from typing import TYPE_CHECKING
9 :
10 : if TYPE_CHECKING:
11 : from collections.abc import Callable
12 :
13 :
def _setup_dynamic_libraries() -> None:  # noqa: C901, PLR0915
    """Make the shared libraries that pairinteraction depends on available at import time.

    On Windows, ``fix_ssl`` possibly extends the DLL search path so that an SSL library can be found.
    If ``Info.with_mkl`` is true, ``load_mkl`` makes the MKL libraries available. On Linux, TBB and MKL
    are preloaded with ``ctypes.CDLL``. On Windows, the MKL directory is added to the DLL search path.
    This function is called before ``pairinteraction._backend`` is imported further down in this
    module, presumably so that the compiled backend can resolve these libraries.

    Raises:
        RuntimeError: If MKL is required but the ``mkl`` package is not installed, the ``mkl_core``
            library cannot be found, or (on Linux) the ``tbb`` library cannot be found.

    """
    import platform
    import sys
    from importlib.metadata import PackageNotFoundError, files, version
    from pathlib import Path
    from warnings import warn

    from pairinteraction._info import Info

    # ---------------------------------------------------------------------------------------
    # Helper functions
    # ---------------------------------------------------------------------------------------
    def is_package_installed(package: str) -> bool:
        """Check whether a package is installed."""
        # ``version`` raises PackageNotFoundError if no distribution metadata exists for ``package``
        try:
            version(package)
        except PackageNotFoundError:
            return False
        else:
            return True

    def is_running_under_pyinstaller() -> bool:
        """Return whether we run inside a PyInstaller bundle (``sys.frozen`` and ``sys._MEIPASS`` set).

        NOTE(review): if ``sys.frozen`` exists but is falsy, that falsy value itself is returned
        (not necessarily ``False``), because of how ``and`` evaluates.
        """
        return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")

    def find_library_file_in_package(substring: str, package: str) -> Path | None:
        """Find the library file if it got installed together with the package.

        Returns the resolved path of the first file in the distribution's file list whose stem
        contains ``substring``. Returns ``None`` if no file matches or if the distribution has
        no file list (``files`` returns ``None``). The result depends on the order of the file
        list, so a broad ``substring`` may match a different file than intended.
        """
        if (package_files := files(package)) is not None:
            for p in package_files:
                if substring in p.stem:
                    return Path(p.locate()).resolve()
        return None

    def find_library_file_in_directory(substring: str, directory: Path = Path(__file__).parent) -> Path | None:
        """Find the library file if it is in the specified directory.

        Only the direct children of ``directory`` are searched (non-recursively). The default
        ``directory`` is evaluated once, when this helper is defined.
        """
        for p in directory.glob("*"):
            if substring in p.stem:
                return p.resolve()
        return None

    def load_candidate(candidate: Path, loader: Callable[[str], object]) -> None:
        """Best-effort load of ``candidate`` via ``loader``; any failure only emits a RuntimeWarning."""
        try:
            loader(str(candidate))
        except Exception as e:
            # stacklevel=2 attributes the warning to the line that called load_candidate
            warn(f"Unable to load {candidate.name}: {e}", RuntimeWarning, stacklevel=2)

    # os.add_dll_directory only exists on Windows, hence the getattr with a None default
    add_dll_directory: Callable[[str], object] | None = getattr(os, "add_dll_directory", None)

    # ---------------------------------------------------------------------------------------
    # Load shared libraries
    # ---------------------------------------------------------------------------------------

    def fix_ssl() -> None:
        """Fix SSL library loading issues under Windows.

        Returns early if an SSL library appears to ship with the installed package or with the
        PyInstaller bundle. Otherwise, every existing ``vcpkg_installed/x64-windows/bin`` directory
        is added to the DLL search path. These directories are looked up below the current working
        directory, three of its parents, and ``Path(__file__).parents[2]`` and ``parents[3]``.
        """
        # Only called on Windows, where os.add_dll_directory is available
        assert add_dll_directory is not None

        # If PairInteraction was installed, the SSL library might have been installed together with it
        if (
            is_package_installed("pairinteraction")
            and find_library_file_in_package("ssl", "pairinteraction") is not None
        ):
            return

        # If PairInteraction is running under PyInstaller, the SSL library might have been bundled with the executable
        if is_running_under_pyinstaller() and find_library_file_in_directory("ssl", Path(__file__).parent) is not None:
            return

        # Else, PairInteraction is probably running in development mode and we add a bunch of directories to
        # the DLL search path to avoid loading issues
        possible_dirs = [
            Path.cwd(),
            # look in cwd.parents
            Path.cwd().parent,
            Path.cwd().parent.parent,
            Path.cwd().parent.parent.parent,
            # and __file__.parents[2] (for editable installs)
            Path(__file__).resolve().parent.parent.parent,
            Path(__file__).resolve().parent.parent.parent.parent,
        ]
        possible_paths = [d / "vcpkg_installed" / "x64-windows" / "bin" for d in possible_dirs]
        for path in possible_paths:
            if path.is_dir():
                add_dll_directory(str(path))

    def load_mkl(system: str) -> None:
        """Make the MKL (and, on Linux, TBB) libraries available for the given ``platform.system()``.

        Raises:
            RuntimeError: If the ``mkl`` package is not installed, ``mkl_core`` cannot be found,
                or (on Linux) the ``tbb`` library cannot be found.

        """
        import ctypes
        from functools import partial

        if not is_package_installed("mkl"):
            raise RuntimeError("The 'mkl' library is not installed.")

        path = find_library_file_in_package("mkl_core", "mkl")
        if path is None:
            raise RuntimeError("The 'mkl_core' library could not be found.")

        # All other MKL libraries are expected in the same directory as mkl_core
        mkl_lib_dir = path.parent

        # Loading order matters; the list is iterated front to back below
        mkl_lib_file_names = [
            "mkl_core",  # must be loaded first
            "mkl_tbb_thread",  # must be loaded second
            "mkl_avx2",
            "mkl_avx512",
            "mkl_def",
            "mkl_intel_lp64",
            "mkl_mc3",
            "mkl_rt",
            "mkl_vml_avx2",
            "mkl_vml_avx512",
            "mkl_vml_cmpt",
            "mkl_vml_def",
            "mkl_vml_mc3",
            "mkl_sequential",  # needed for pytest with cpp coverage
        ]

        if system == "Linux":
            # Under linux, the libraries must always be loaded manually in the address space
            # Look for TBB in the 'tbb' package first, then fall back to files shipped with pairinteraction
            tbb_lib_file: Path | None = None
            if is_package_installed("tbb"):
                tbb_lib_file = find_library_file_in_package("tbb", "tbb")
            if tbb_lib_file is None and is_package_installed("pairinteraction"):
                tbb_lib_file = find_library_file_in_package("tbb", "pairinteraction")
            if tbb_lib_file is None:
                raise RuntimeError("The 'tbb' library could not be found.")
            # RTLD_GLOBAL makes the symbols visible to libraries loaded afterwards (presumably the
            # MKL libraries and the backend extension); RTLD_LAZY defers symbol resolution
            load_candidate(tbb_lib_file, partial(ctypes.CDLL, mode=os.RTLD_LAZY | os.RTLD_GLOBAL))

            # Missing or unloadable MKL libraries only produce a RuntimeWarning (see load_candidate)
            for lib in mkl_lib_file_names:
                candidate = mkl_lib_dir / f"lib{lib}.so.2"
                load_candidate(candidate, partial(ctypes.CDLL, mode=os.RTLD_LAZY | os.RTLD_GLOBAL))

        elif system == "Windows":
            assert add_dll_directory is not None

            # Modify the dll search path
            # (the returned handle is discarded; the directory stays registered unless it is closed)
            add_dll_directory(str(mkl_lib_dir))

        else:
            warn(f"Cannot load MKL libraries on unsupported system {system}.", RuntimeWarning, stacklevel=2)

    system = platform.system()
    if system == "Windows":
        fix_ssl()
    # NOTE(review): presumably Info.with_mkl reports whether the backend was built with MKL - confirm
    if Info.with_mkl:
        load_mkl(system)


# Executed once at import time
_setup_dynamic_libraries()
159 :
160 :
def _setup_ca_bundle() -> None:
    """Point the backend at the CA certificate bundle that ``certifi`` provides."""
    from certifi import where

    from pairinteraction._backend import set_ca_bundle_path

    ca_bundle_path = where()
    set_ca_bundle_path(ca_bundle_path)


_setup_ca_bundle()
170 :
171 :
172 : # ---------------------------------------------------------------------------------------
173 : # Configure PairInteraction for running tests with a local database if requested
174 : # ---------------------------------------------------------------------------------------
def _setup_test_mode(download_missing: bool = False, database_dir: str | None = None) -> None:
    """Initialize the global database so that the tests run against a local database.

    Args:
        download_missing: Forwarded as the first argument of ``Database.initialize_global_database``
            (presumably whether missing tables may be downloaded - TODO confirm).
        database_dir: Directory of the local database. If ``None``, the first ``data/database``
            directory that contains a ``wigner.parquet`` file (at any depth) is used. It is looked up
            below the current working directory, three of its parents, and
            ``Path(__file__).parents[2]`` (for editable installs), in that order.

    Raises:
        FileNotFoundError: If ``database_dir`` is ``None`` and no database directory could be found.

    """
    from pathlib import Path

    from pairinteraction.database import Database

    if database_dir is None:
        possible_dirs = [
            Path.cwd(),
            # look in cwd.parents
            Path.cwd().parent,
            Path.cwd().parent.parent,
            Path.cwd().parent.parent.parent,
            # and __file__.parents[2] (for editable installs)
            Path(__file__).resolve().parent.parent.parent,
        ]
        possible_paths = [d / "data" / "database" for d in possible_dirs]

        # The first candidate that actually contains database tables wins
        for path in possible_paths:
            if any(path.rglob("wigner.parquet")):
                database_dir = str(path)
                break

    if database_dir is None:
        raise FileNotFoundError("Could not find database directory")

    # NOTE(review): the meaning of the positional ``True`` is not visible here - confirm and pass by keyword
    Database.initialize_global_database(download_missing, True, database_dir)


if os.getenv("PAIRINTERACTION_TEST_MODE", "0") == "1":
    # Pass the settings directly instead of binding them to module-level names first, so that
    # they do not leak into the public ``pairinteraction`` namespace.
    _setup_test_mode(
        download_missing=bool(int(os.getenv("PAIRINTERACTION_TEST_DOWNLOAD_MISSING", "0"))),
        database_dir=os.getenv("PAIRINTERACTION_TEST_DATABASE_DIR"),
    )
207 :
208 :
209 : # ---------------------------------------------------------------------------------------
210 : # Decorate all functions in _backend with a decorator that flushes pending logs
211 : # ---------------------------------------------------------------------------------------
def _setup_logging() -> None:
    """Wrap every function of ``pairinteraction._backend`` so that pending log messages get flushed."""
    from pairinteraction import _backend as backend_module
    from pairinteraction.custom_logging import decorate_module_with_flush_logs as add_flush_logs

    add_flush_logs(backend_module)


_setup_logging()
# Both setup helpers are only needed once at import time; remove them from the package namespace.
del _setup_logging, _setup_ca_bundle
222 :
223 : # ---------------------------------------------------------------------------------------
224 : # Import pairinteraction
225 : # ---------------------------------------------------------------------------------------
226 1 : from pairinteraction import (
227 : green_tensor,
228 : perturbative,
229 : real,
230 : visualization,
231 : )
232 1 : from pairinteraction._backend import (
233 : VERSION_MAJOR as _VERSION_MAJOR,
234 : VERSION_MINOR as _VERSION_MINOR,
235 : VERSION_PATCH as _VERSION_PATCH,
236 : run_unit_tests,
237 : )
238 1 : from pairinteraction.basis import BasisAtom, BasisPair
239 1 : from pairinteraction.custom_logging import configure_logging
240 1 : from pairinteraction.database import Database, print_database_info
241 1 : from pairinteraction.diagonalization import diagonalize
242 1 : from pairinteraction.ket import KetAtom, KetPair
243 1 : from pairinteraction.perturbative import C3, C6, EffectiveSystemPair
244 1 : from pairinteraction.state import StateAtom, StatePair
245 1 : from pairinteraction.system import SystemAtom, SystemPair
246 1 : from pairinteraction.units import ureg
247 :
__all__ = [
    "C3",
    "C6",
    "BasisAtom",
    "BasisPair",
    "Database",
    "EffectiveSystemPair",
    "KetAtom",
    "KetPair",
    "StateAtom",
    "StatePair",
    "SystemAtom",
    "SystemPair",
    "configure_logging",
    "diagonalize",
    "green_tensor",
    "perturbative",
    "print_database_info",
    "real",
    "run_unit_tests",
    "ureg",
    "visualization",
]

# The package version is assembled from the version components exported by pairinteraction._backend.
__version__ = ".".join(map(str, (_VERSION_MAJOR, _VERSION_MINOR, _VERSION_PATCH)))


# ---------------------------------------------------------------------------------------
# Clean up namespace
# ---------------------------------------------------------------------------------------
del _VERSION_MAJOR
del _VERSION_MINOR
del _VERSION_PATCH
# _setup_test_mode is kept on purpose, since it is used in tests/conftest.py
del _setup_dynamic_libraries
|