Line data Source code
1 : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : # SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 1 : import argparse
5 1 : import sys
6 1 : from pathlib import Path
7 1 : from typing import TYPE_CHECKING, cast
8 :
9 1 : from colorama import Fore, Style
10 :
11 1 : from pairinteraction import __version__, configure_logging
12 :
13 : if TYPE_CHECKING:
14 : from collections.abc import Callable
15 :
16 :
class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    """Help formatter used by every parser of the CLI.

    Combines ``ArgumentDefaultsHelpFormatter`` (appends "(default: ...)" to argument help)
    with ``RawDescriptionHelpFormatter`` (keeps the line breaks of the description and the
    epilog exactly as written).
    """
19 :
20 :
def _build_parser() -> argparse.ArgumentParser:
    """Create the top-level CLI parser together with all command groups and their handlers.

    Every (sub)parser stores its handler in ``func``; the handler takes the parsed namespace
    and returns a process exit code.
    """
    examples = (
        "Examples:\n"
        " pairinteraction\n"
        " pairinteraction --log-level INFO\n"
        " pairinteraction --log-level INFO test\n"
        " pairinteraction database list\n"
        " pairinteraction database download Rb Cs\n"
        " pairinteraction database download https://github.com/pairinteraction/database-sqdt/releases/download/v1.2/Rb_v1.2.zip\n"
        " pairinteraction database remove\n"
        " pairinteraction config reset-gui\n"
        " pairinteraction config paths\n"
        "\n"
        "Command-specific help:\n"
        " pairinteraction test --help\n"
        " pairinteraction database --help\n"
        " pairinteraction config --help"
    )
    parser = argparse.ArgumentParser(
        description="PairInteraction CLI\n\nRun 'pairinteraction' without a command to launch the GUI.",
        formatter_class=HelpFormatter,
        epilog=examples,
    )
    parser.add_argument("--version", action="version", version=f"PairInteraction v{__version__}")
    parser.add_argument(
        "--reload",
        action="store_true",
        help="launch the GUI with automatic theme reload during development",
    )
    parser.add_argument(
        "--log-level",
        default="WARNING",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="set the logging level",
    )

    # Without a subcommand the GUI is launched.
    parser.set_defaults(func=lambda ns: start_gui(reload=ns.reload))

    commands = parser.add_subparsers(dest="command", title="commands", metavar="{test,database,config}")

    # "gui" used to be a subcommand. It is registered without help= so that it stays hidden from
    # the help output, but old invocations still get a pointer to the new way of launching the GUI.
    legacy_gui = commands.add_parser("gui", add_help=False)
    legacy_gui.set_defaults(
        func=lambda _ns: parser.error(
            "The 'gui' subcommand no longer exists. To launch the GUI, run 'pairinteraction' without a command."
        )
    )

    # pairinteraction test
    test_cmd = commands.add_parser("test", formatter_class=HelpFormatter, help="run tests")
    test_cmd.set_defaults(func=lambda _ns: run_unit_tests())

    # pairinteraction database {list,download,remove}
    database_group = commands.add_parser(
        "database", formatter_class=HelpFormatter, help="manage and inspect the database"
    )
    database_cmds = database_group.add_subparsers(dest="database_command", title="database commands")

    list_cmd = database_cmds.add_parser(
        "list", formatter_class=HelpFormatter, help="list local and remote database table versions"
    )
    list_cmd.set_defaults(func=lambda _ns: list_databases())

    download_cmd = database_cmds.add_parser(
        "download", formatter_class=HelpFormatter, help="download database tables for one or more species"
    )
    download_cmd.add_argument("species", nargs="+", help="list of species to download data for / list of URLs")
    download_cmd.set_defaults(func=lambda ns: download_databases(ns.species))

    remove_cmd = database_cmds.add_parser(
        "remove", formatter_class=HelpFormatter, help="delete the cached database directory"
    )
    remove_cmd.set_defaults(func=lambda _ns: remove_database_cache())

    # "pairinteraction database" on its own just shows the group's help.
    database_group.set_defaults(func=lambda _ns: print_help(database_group))

    # pairinteraction config {reset-gui,paths}
    config_group = commands.add_parser(
        "config", formatter_class=HelpFormatter, help="manage GUI settings and inspect paths"
    )
    config_cmds = config_group.add_subparsers(dest="config_command", title="config commands")

    reset_cmd = config_cmds.add_parser(
        "reset-gui", formatter_class=HelpFormatter, help="delete GUI settings file to restore defaults"
    )
    reset_cmd.set_defaults(func=lambda _ns: reset_gui_settings())

    paths_cmd = config_cmds.add_parser(
        "paths", formatter_class=HelpFormatter, help="show config and cache directories"
    )
    paths_cmd.set_defaults(func=lambda _ns: show_paths())

    # "pairinteraction config" on its own just shows the group's help.
    config_group.set_defaults(func=lambda _ns: print_help(config_group))

    return parser


def main() -> int:
    """Entry point for the PairInteraction CLI.

    Parses ``sys.argv``, configures logging and dispatches to the handler selected by the
    command line (the GUI when no command is given).

    Returns:
        The exit code produced by the selected handler.

    """
    parser = _build_parser()
    namespace = parser.parse_args()

    # --reload only affects the GUI, which is launched when no command is given.
    if namespace.reload and namespace.command is not None:
        parser.error("--reload can only be used when launching the GUI")

    configure_logging(namespace.log_level)

    handler = cast("Callable[[argparse.Namespace], int]", namespace.func)
    return handler(namespace)
150 :
151 :
def print_help(parser: argparse.ArgumentParser) -> int:
    """Show the usage/help text of ``parser`` on stdout.

    Serves as the handler for command groups that are invoked without a subcommand.

    Returns:
        Always 0.

    """
    parser.print_help()
    success = 0
    return success
156 :
157 :
def start_gui(*, reload: bool = False) -> int:
    """Open the PairInteraction GUI.

    Args:
        reload: Forwarded to the GUI as ``enable_theme_hot_reload``.

    Returns:
        0 once the GUI's ``main`` has returned.

    """
    # Imported lazily so that non-GUI commands do not load the GUI package.
    from pairinteraction_gui import main as launch_gui

    print("Launching the GUI...")
    launch_gui(enable_theme_hot_reload=reload)
    return 0
165 :
166 :
def reset_gui_settings() -> int:
    """Remove ``gui_settings.ini`` from the cache directory so the GUI starts with defaults.

    Asks the user for confirmation before deleting.

    Returns:
        1 if deleting the file raised an error, otherwise 0 (also when there is no settings
        file or the user declines).

    """
    from pairinteraction._backend import get_cache_directory

    ini_path = get_cache_directory() / "gui_settings.ini"

    if ini_path.exists():
        answer = input(f"Are you sure you want to delete the GUI settings file {ini_path}? (y/N): ")
        if answer.lower() in ("y", "yes"):
            try:
                ini_path.unlink()
            except Exception as err:
                print(Fore.RED + f"Error while deleting GUI settings file: {err}" + Style.RESET_ALL)
                return 1
            print(Fore.GREEN + "GUI settings deleted. Default values will be used on next launch." + Style.RESET_ALL)
            return 0
        print(Fore.YELLOW + "Aborted deletion of GUI settings." + Style.RESET_ALL)
        return 0

    print("No GUI settings file found. Nothing to delete.")
    return 0
190 :
191 :
def run_unit_tests() -> int:
    """Run the unit tests of the C++ module and print a colored verdict.

    ``download_missing=True`` is passed to the test runner so missing data may be fetched.

    Returns:
        The exit code reported by the test runner (non-zero means failure).

    """
    # Aliased so the imported runner does not shadow this function's own name.
    from pairinteraction import run_unit_tests as run_cpp_tests

    print("Running the C++ module unit tests...")
    status = run_cpp_tests(download_missing=True)
    color, verdict = (Fore.RED, "Tests failed.") if status else (Fore.GREEN, "Tests passed.")
    print(color + verdict + Style.RESET_ALL)
    return status
203 :
204 :
def _download_database_from_url(url: str, tables_dir: Path) -> int:
    """Download a zipped table folder from ``url`` and unpack it into ``tables_dir``.

    The archive must contain exactly one top-level folder named ``<species>_v<version>``
    (for example ``Rb_v1.2``). Existing folders of the same species in ``tables_dir``
    (``<species>_v*``) are deleted after the user confirms; then the archive is unpacked.

    Args:
        url: HTTP(S) URL of the ZIP archive.
        tables_dir: Directory into which the archive's top-level folder is unpacked.

    Returns:
        0 on success; 1 if downloading, validating or unpacking failed, or the user aborted.

    """
    import shutil
    import tempfile
    from urllib.request import urlretrieve
    from zipfile import ZipFile

    from packaging.version import Version

    try:
        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td) / "tables.zip"

            try:
                msg = f"Downloading {url}..."
                print(msg, end="", flush=True)

                def _hook(blocks: int, block_size: int, total_size: int, _msg: str = msg) -> None:
                    # urlretrieve passes -1 as total_size when the size is unknown; no progress then.
                    if total_size > 0:
                        pct = min(100, int(blocks * block_size * 100 / total_size))
                        print(f"\r{_msg} {pct:3}%", end="", flush=True)

                urlretrieve(url, tmp, reporthook=_hook)  # noqa: S310
            finally:
                # Terminate the progress line even if the download failed.
                print()

            with ZipFile(tmp) as z:
                names = z.namelist()

            roots = {Path(n).parts[0] for n in names if Path(n).parts}
            root = roots.pop() if len(roots) == 1 else None
            # The single top-level entry must be a folder. Accept either an explicit directory
            # entry or members nested below it, since not every zip tool writes directory entries.
            if root is None or (f"{root}/" not in names and all(len(Path(n).parts) < 2 for n in names)):
                raise ValueError("The ZIP archive must contain exactly one top-level folder.")  # noqa: TRY301

            species, sep, version_str = root.rpartition("_v")
            # An empty species would turn the glob below into "_v*"-less wildcards over all tables.
            if not sep or not species:
                raise ValueError(  # noqa: TRY301
                    f"The top-level folder '{root}' of the ZIP archive must be named '<species>_v<version>'."
                )
            Version(version_str)  # raises InvalidVersion (a ValueError subclass) if malformed

            # Match only versioned folders of exactly this species (e.g. "Rb_v1.1"), not folders of
            # other species that merely start with the same characters.
            to_delete = sorted(p for p in tables_dir.glob(f"{species}_v*") if p.is_dir())
            if to_delete:
                confirmation = input(
                    f"Do you want to delete the tables in {', '.join(p.name for p in to_delete)} and "
                    "replace them with the downloaded tables? (y/N): "
                )
                if confirmation.lower() not in ["y", "yes"]:
                    print(Fore.YELLOW + "Aborted replacing tables." + Style.RESET_ALL)
                    return 1
                for p in to_delete:
                    shutil.rmtree(p)

            shutil.unpack_archive(tmp, tables_dir, format="zip")

    except Exception as e:
        print(Fore.RED + f"Failed: {e}" + Style.RESET_ALL)
        return 1

    else:
        print(Fore.GREEN + "Successful." + Style.RESET_ALL)
        return 0
259 :
260 :
def download_databases(species_list: list[str]) -> int:
    """Download database tables for every entry of ``species_list``.

    Entries with an ``http``/``https`` scheme are treated as zipped table archives and are
    downloaded via ``_download_database_from_url``. Any other entry is treated as a species
    name; its tables are fetched by building a small ``BasisAtom`` with a database that has
    ``download_missing=True``. Matrix elements are computed once, after the first basis that
    builds successfully, so that the Wigner table is fetched as well.

    Returns:
        0 if every entry succeeded, non-zero otherwise.

    """
    from urllib.parse import urlparse

    import pairinteraction as pi
    from pairinteraction._backend import get_cache_directory

    database_dir = get_cache_directory() / "database"
    tables_dir = database_dir / "tables"
    tables_dir.mkdir(parents=True, exist_ok=True)
    database = pi.Database(download_missing=True, use_cache=False, database_dir=database_dir)

    wigner_fetched = False
    status = 0

    for entry in species_list:
        if urlparse(entry).scheme in {"http", "https"}:
            # Zipped table archive: download and unpack it into the tables directory.
            print("Check for tables...")
            status |= _download_database_from_url(entry, tables_dir)
        else:
            try:
                print(f"Check for tables for {entry}...")

                # Creating a BasisAtom downloads all tables of the species as a side effect.
                basis = pi.BasisAtom(entry, n=(50, 51), l=(0, 2), database=database)

                # Computing a matrix element additionally triggers the Wigner table download.
                if not wigner_fetched:
                    basis.get_matrix_elements(basis, "electric_dipole", 0)
                    wigner_fetched = True

                print(Fore.GREEN + "Successful." + Style.RESET_ALL)
            except Exception as err:
                status = 1
                print(Fore.RED + f"Failed: {err}" + Style.RESET_ALL)

    return status
302 :
303 :
def list_databases() -> int:
    """Show the versions of the local and remote database tables.

    Returns:
        Always 0.

    """
    from pairinteraction.database import print_database_info as show_info

    show_info()
    return 0
310 :
311 :
def show_paths() -> int:
    """Print the config, cache and database-table directories used by PairInteraction.

    Returns:
        Always 0.

    """
    from pairinteraction._backend import get_cache_directory, get_config_directory

    # Each path is looked up right before it is printed.
    for label, locate in (
        ("Config directory:", get_config_directory),
        ("Cache directory:", get_cache_directory),
        ("Database directory:", lambda: get_cache_directory() / "database/tables"),
    ):
        print(label, locate())
    return 0
320 :
321 :
def remove_database_cache() -> int:
    """Delete the cached database directory after asking the user for confirmation.

    Returns:
        0 if the directory was deleted, does not exist, or the user aborted;
        1 if deleting it failed.

    """
    import shutil

    from pairinteraction._backend import get_cache_directory

    database_dir = get_cache_directory() / "database"

    # If no tables have been downloaded yet there is nothing to confirm or delete; without this
    # check the user would be asked to confirm and then see an error from rmtree.
    if not database_dir.exists():
        print("No database directory found. Nothing to delete.")
        return 0

    confirmation = input(f"Are you sure you want to delete all downloaded database tables in {database_dir}? (y/N): ")
    if confirmation.lower() not in ["y", "yes"]:
        print(Fore.YELLOW + "Aborted deletion of database directory." + Style.RESET_ALL)
        return 0

    print(f"Deleting cached database directory {database_dir}...")
    try:
        shutil.rmtree(database_dir)
    except OSError as e:
        print(Fore.RED + f"Error while deleting database directory: {e}" + Style.RESET_ALL)
        return 1

    print(Fore.GREEN + "Database directory deleted." + Style.RESET_ALL)
    return 0
344 :
345 :
if __name__ == "__main__":
    # Propagate the CLI's return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
|