LCOV - code coverage report
Current view: top level - src/pairinteraction - cli.py (source / functions) Hit Total Coverage
Test: coverage.info Lines: 18 165 10.9 %
Date: 2026-04-30 10:43:26 Functions: 0 11 0.0 %

          Line data    Source code
       1             : # SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : # SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4           1 : import argparse
       5           1 : import sys
       6           1 : from pathlib import Path
       7           1 : from typing import TYPE_CHECKING, cast
       8             : 
       9           1 : from colorama import Fore, Style
      10             : 
      11           1 : from pairinteraction import __version__, configure_logging
      12             : 
      13             : if TYPE_CHECKING:
      14             :     from collections.abc import Callable
      15             : 
      16             : 
      17           1 : class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
      18             :     """Show default arguments and keep manual line breaks in the epilog."""
      19             : 
      20             : 
      21           1 : def main() -> int:
      22             :     """Entry point for the PairInteraction CLI."""
      23           0 :     parser = argparse.ArgumentParser(
      24             :         description=("PairInteraction CLI\n\nRun 'pairinteraction' without a command to launch the GUI."),
      25             :         formatter_class=HelpFormatter,
      26             :         epilog=(
      27             :             "Examples:\n"
      28             :             "  pairinteraction\n"
      29             :             "  pairinteraction --log-level INFO\n"
      30             :             "  pairinteraction --log-level INFO test\n"
      31             :             "  pairinteraction database list\n"
      32             :             "  pairinteraction database download Rb Cs\n"
      33             :             "  pairinteraction database download https://github.com/pairinteraction/database-sqdt/releases/download/v1.2/Rb_v1.2.zip\n"
      34             :             "  pairinteraction database remove\n"
      35             :             "  pairinteraction config reset-gui\n"
      36             :             "  pairinteraction config paths\n"
      37             :             "\n"
      38             :             "Command-specific help:\n"
      39             :             "  pairinteraction test --help\n"
      40             :             "  pairinteraction database --help\n"
      41             :             "  pairinteraction config --help"
      42             :         ),
      43             :     )
      44           0 :     parser.add_argument("--version", action="version", version=f"PairInteraction v{__version__}")
      45           0 :     parser.add_argument(
      46             :         "--reload",
      47             :         action="store_true",
      48             :         help="launch the GUI with automatic theme reload during development",
      49             :     )
      50           0 :     parser.add_argument(
      51             :         "--log-level",
      52             :         default="WARNING",
      53             :         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
      54             :         help="set the logging level",
      55             :     )
      56             : 
      57             :     # Launch GUI (default action)
      58           0 :     parser.set_defaults(func=lambda args: start_gui(reload=args.reload))
      59             : 
      60           0 :     subparsers = parser.add_subparsers(dest="command", title="commands", metavar="{test,database,config}")
      61             : 
      62             :     # Removed launch command
      63           0 :     gui_parser = subparsers.add_parser(
      64             :         "gui",
      65             :         add_help=False,
      66             :     )
      67           0 :     gui_parser.set_defaults(
      68             :         func=lambda _args: parser.error(
      69             :             "The 'gui' subcommand no longer exists. To launch the GUI, run 'pairinteraction' without a command."
      70             :         )
      71             :     )
      72             : 
      73             :     # Test command
      74           0 :     test_parser = subparsers.add_parser(
      75             :         "test",
      76             :         formatter_class=HelpFormatter,
      77             :         help="run tests",
      78             :     )
      79           0 :     test_parser.set_defaults(func=lambda _args: run_unit_tests())
      80             : 
      81             :     # Database command group
      82           0 :     database_parser = subparsers.add_parser(
      83             :         "database",
      84             :         formatter_class=HelpFormatter,
      85             :         help="manage and inspect the database",
      86             :     )
      87           0 :     database_subparsers = database_parser.add_subparsers(dest="database_command", title="database commands")
      88             : 
      89             :     # Database list command
      90           0 :     db_list_parser = database_subparsers.add_parser(
      91             :         "list",
      92             :         formatter_class=HelpFormatter,
      93             :         help="list local and remote database table versions",
      94             :     )
      95           0 :     db_list_parser.set_defaults(func=lambda _args: list_databases())
      96             : 
      97             :     # Database download command
      98           0 :     db_download_parser = database_subparsers.add_parser(
      99             :         "download",
     100             :         formatter_class=HelpFormatter,
     101             :         help="download database tables for one or more species",
     102             :     )
     103           0 :     db_download_parser.add_argument("species", nargs="+", help="list of species to download data for / list of URLs")
     104           0 :     db_download_parser.set_defaults(func=lambda args: download_databases(args.species))
     105             : 
     106             :     # Database remove command
     107           0 :     db_remove_parser = database_subparsers.add_parser(
     108             :         "remove",
     109             :         formatter_class=HelpFormatter,
     110             :         help="delete the cached database directory",
     111             :     )
     112           0 :     db_remove_parser.set_defaults(func=lambda _args: remove_database_cache())
     113             : 
     114           0 :     database_parser.set_defaults(func=lambda _args: print_help(database_parser))
     115             : 
     116             :     # Config command group
     117           0 :     config_parser = subparsers.add_parser(
     118             :         "config",
     119             :         formatter_class=HelpFormatter,
     120             :         help="manage GUI settings and inspect paths",
     121             :     )
     122           0 :     config_subparsers = config_parser.add_subparsers(dest="config_command", title="config commands")
     123             : 
     124             :     # Config reset gui command
     125           0 :     config_reset_gui_parser = config_subparsers.add_parser(
     126             :         "reset-gui",
     127             :         formatter_class=HelpFormatter,
     128             :         help="delete GUI settings file to restore defaults",
     129             :     )
     130           0 :     config_reset_gui_parser.set_defaults(func=lambda _args: reset_gui_settings())
     131             : 
     132             :     # Config list paths command
     133           0 :     config_list_paths_parser = config_subparsers.add_parser(
     134             :         "paths",
     135             :         formatter_class=HelpFormatter,
     136             :         help="show config and cache directories",
     137             :     )
     138           0 :     config_list_paths_parser.set_defaults(func=lambda _args: show_paths())
     139             : 
     140           0 :     config_parser.set_defaults(func=lambda _args: print_help(config_parser))
     141             : 
     142           0 :     args = parser.parse_args()
     143             : 
     144           0 :     if args.command is not None and args.reload:
     145           0 :         parser.error("--reload can only be used when launching the GUI")
     146             : 
     147           0 :     configure_logging(args.log_level)
     148             : 
     149           0 :     return cast("Callable[[argparse.Namespace], int]", args.func)(args)
     150             : 
     151             : 
     152           1 : def print_help(parser: argparse.ArgumentParser) -> int:
     153             :     """Print help."""
     154           0 :     parser.print_help()
     155           0 :     return 0
     156             : 
     157             : 
     158           1 : def start_gui(*, reload: bool = False) -> int:
     159             :     """Launch the GUI."""
     160           0 :     from pairinteraction_gui import main as gui_main
     161             : 
     162           0 :     print("Launching the GUI...")
     163           0 :     gui_main(enable_theme_hot_reload=reload)
     164           0 :     return 0
     165             : 
     166             : 
     167           1 : def reset_gui_settings() -> int:
     168             :     """Delete the GUI settings file to restore default values."""
     169           0 :     from pairinteraction._backend import get_cache_directory
     170             : 
     171           0 :     settings_file = get_cache_directory() / "gui_settings.ini"
     172             : 
     173           0 :     if not settings_file.exists():
     174           0 :         print("No GUI settings file found. Nothing to delete.")
     175           0 :         return 0
     176             : 
     177           0 :     confirmation = input(f"Are you sure you want to delete the GUI settings file {settings_file}? (y/N): ")
     178           0 :     if confirmation.lower() not in ["y", "yes"]:
     179           0 :         print(Fore.YELLOW + "Aborted deletion of GUI settings." + Style.RESET_ALL)
     180           0 :         return 0
     181             : 
     182           0 :     try:
     183           0 :         settings_file.unlink()
     184           0 :     except Exception as e:
     185           0 :         print(Fore.RED + f"Error while deleting GUI settings file: {e}" + Style.RESET_ALL)
     186           0 :         return 1
     187             : 
     188           0 :     print(Fore.GREEN + "GUI settings deleted. Default values will be used on next launch." + Style.RESET_ALL)
     189           0 :     return 0
     190             : 
     191             : 
     192           1 : def run_unit_tests() -> int:
     193             :     """Run the C++ module unit tests."""
     194           0 :     from pairinteraction import run_unit_tests
     195             : 
     196           0 :     print("Running the C++ module unit tests...")
     197           0 :     exit_code = run_unit_tests(download_missing=True)
     198           0 :     if exit_code:
     199           0 :         print(Fore.RED + "Tests failed." + Style.RESET_ALL)
     200             :     else:
     201           0 :         print(Fore.GREEN + "Tests passed." + Style.RESET_ALL)
     202           0 :     return exit_code
     203             : 
     204             : 
     205           1 : def _download_database_from_url(url: str, tables_dir: Path) -> int:
     206           0 :     import shutil
     207           0 :     import tempfile
     208           0 :     from urllib.request import urlretrieve
     209           0 :     from zipfile import ZipFile
     210             : 
     211           0 :     from packaging.version import Version
     212             : 
     213           0 :     try:
     214           0 :         with tempfile.TemporaryDirectory() as td:
     215           0 :             tmp = Path(td) / "tables.zip"
     216             : 
     217           0 :             try:
     218           0 :                 msg = f"Downloading {url}..."
     219           0 :                 print(msg, end="", flush=True)
     220             : 
     221           0 :                 def _hook(blocks: int, block_size: int, total_size: int, _msg: str = msg) -> None:
     222           0 :                     if total_size > 0:
     223           0 :                         pct = min(100, int(blocks * block_size * 100 / total_size))
     224           0 :                         print(f"\r{_msg} {pct:3}%", end="", flush=True)
     225             : 
     226           0 :                 urlretrieve(url, tmp, reporthook=_hook)  # noqa: S310
     227             :             finally:
     228           0 :                 print()
     229             : 
     230           0 :             with ZipFile(tmp) as z:
     231           0 :                 roots = {Path(n).parts[0] for n in z.namelist() if Path(n).parts}
     232           0 :                 root = roots.pop() if len(roots) == 1 else None
     233             : 
     234           0 :                 if not root or f"{root}/" not in z.namelist():
     235           0 :                     raise ValueError("The ZIP archive must contain exactly one top-level folder.")  # noqa: TRY301
     236           0 :                 species, version_str = root.rsplit("_v", 1)
     237           0 :                 Version(version_str)  # validate version
     238             : 
     239           0 :                 to_delete = list(tables_dir.glob(f"{species}*"))
     240           0 :                 confirmation = input(
     241             :                     f"Do you want delete the tables in {', '.join(p.name for p in to_delete)} and "
     242             :                     "replace them with the downloaded tables? (y/N): "
     243             :                 )
     244           0 :                 if confirmation.lower() not in ["y", "yes"]:
     245           0 :                     print(Fore.YELLOW + "Aborted replacing tables." + Style.RESET_ALL)
     246           0 :                     return 1
     247           0 :                 for p in to_delete:
     248           0 :                     shutil.rmtree(p)
     249             : 
     250           0 :             shutil.unpack_archive(tmp, tables_dir, format="zip")
     251             : 
     252           0 :     except Exception as e:
     253           0 :         print(Fore.RED + f"Failed: {e}" + Style.RESET_ALL)
     254           0 :         return 1
     255             : 
     256             :     else:
     257           0 :         print(Fore.GREEN + "Successful." + Style.RESET_ALL)
     258           0 :         return 0
     259             : 
     260             : 
     261           1 : def download_databases(species_list: list[str]) -> int:
     262             :     """Download the required data files for the specified species."""
     263           0 :     from urllib.parse import urlparse
     264             : 
     265           0 :     import pairinteraction as pi
     266           0 :     from pairinteraction._backend import get_cache_directory
     267             : 
     268           0 :     database_dir = get_cache_directory() / "database"
     269           0 :     tables_dir = database_dir / "tables"
     270           0 :     tables_dir.mkdir(parents=True, exist_ok=True)
     271           0 :     database = pi.Database(download_missing=True, use_cache=False, database_dir=database_dir)
     272             : 
     273           0 :     is_wigner_downloaded = False
     274           0 :     exit_code = 0
     275             : 
     276           0 :     for species in species_list:
     277             :         # If species is a URL, download and unzip to database/tables
     278           0 :         if urlparse(species).scheme in {"http", "https"}:
     279           0 :             print("Check for tables...")
     280           0 :             exit_code |= _download_database_from_url(species, tables_dir)
     281           0 :             continue
     282             : 
     283           0 :         try:
     284           0 :             print(f"Check for tables for {species}...")
     285             : 
     286             :             # We make use of the fact that all tables of a species get downloaded
     287             :             # automatically when we create a BasisAtom object.
     288           0 :             basis = pi.BasisAtom(species, n=(50, 51), l=(0, 2), database=database)
     289             : 
     290             :             # We calculate matrix elements to ensure that the Wigner table is
     291             :             # downloaded as well.
     292           0 :             if not is_wigner_downloaded:
     293           0 :                 basis.get_matrix_elements(basis, "electric_dipole", 0)
     294           0 :                 is_wigner_downloaded = True
     295             : 
     296           0 :             print(Fore.GREEN + "Successful." + Style.RESET_ALL)
     297           0 :         except Exception as e:
     298           0 :             exit_code = 1
     299           0 :             print(Fore.RED + f"Failed: {e}" + Style.RESET_ALL)
     300             : 
     301           0 :     return exit_code
     302             : 
     303             : 
     304           1 : def list_databases() -> int:
     305             :     """Print a table of local and remote database table versions."""
     306           0 :     from pairinteraction.database import print_database_info
     307             : 
     308           0 :     print_database_info()
     309           0 :     return 0
     310             : 
     311             : 
     312           1 : def show_paths() -> int:
     313             :     """Show config and cache directories."""
     314           0 :     from pairinteraction._backend import get_cache_directory, get_config_directory
     315             : 
     316           0 :     print("Config directory:", get_config_directory())
     317           0 :     print("Cache directory:", get_cache_directory())
     318           0 :     print("Database directory:", get_cache_directory() / "database/tables")
     319           0 :     return 0
     320             : 
     321             : 
     322           1 : def remove_database_cache() -> int:
     323             :     """Delete the cached database directory."""
     324           0 :     import shutil
     325             : 
     326           0 :     from pairinteraction._backend import get_cache_directory
     327             : 
     328           0 :     database_dir = get_cache_directory() / "database"
     329             : 
     330           0 :     confirmation = input(f"Are you sure you want to delete all downloaded database tables in {database_dir}? (y/N): ")
     331           0 :     if confirmation.lower() not in ["y", "yes"]:
     332           0 :         print(Fore.YELLOW + "Aborted deletion of database directory." + Style.RESET_ALL)
     333           0 :         return 0
     334             : 
     335           0 :     print(f"Deleting cached database directory {database_dir}...")
     336           0 :     try:
     337           0 :         shutil.rmtree(database_dir)
     338           0 :     except Exception as e:
     339           0 :         print(Fore.RED + f"Error while deleting database directory: {e}" + Style.RESET_ALL)
     340           0 :         return 1
     341             : 
     342           0 :     print(Fore.GREEN + "Database directory deleted." + Style.RESET_ALL)
     343           0 :     return 0
     344             : 
     345             : 
     346           1 : if __name__ == "__main__":
     347           0 :     sys.exit(main())

Generated by: LCOV version 1.16