{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Benchmark of Numerov integration with and without using numba.njit" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import timeit\n", "\n", "import numpy as np\n", "\n", "from ryd_numerov.rydberg import RydbergState\n", "\n", "test_cases: list[tuple[str, int, int, bool]] = [\n", " # species, n, l, use_njit\n", " (\"H\", 100, 80, True),\n", " (\"H\", 100, 80, False),\n", "]" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def run_benchmark(number: int = 10) -> list[dict]:\n", " \"\"\"Run benchmark for different quantum states.\n", "\n", " Args:\n", " number: Number of times to run each test for averaging\n", "\n", " \"\"\"\n", " # run the integration once to compile the numba function\n", " species, n, l, use_njit = test_cases[0]\n", " atom = RydbergState(species, n, l, j=l + 0.5)\n", " atom.create_wavefunction(_use_njit=True)\n", "\n", " results = []\n", " for species, n, l, use_njit in test_cases:\n", " # Setup the test function\n", " stmt = (\n", " \"atom = RydbergState(species, n, l, j=l+0.5)\\n\"\n", " \"atom.create_grid(dz=1e-3)\\n\"\n", " \"atom.create_wavefunction(_use_njit=use_njit)\"\n", " )\n", "\n", " # Time the integration multiple times and take average/std\n", " globals_dict = {\"RydbergState\": RydbergState, \"species\": species, \"n\": n, \"l\": l, \"use_njit\": use_njit}\n", " times = timeit.repeat(stmt=stmt, number=1, repeat=number, globals=globals_dict)\n", " avg_time = np.mean(times)\n", " std_time = np.std(times)\n", "\n", " results.append({\"species\": species, \"n\": n, \"l\": l, \"use_njit\": use_njit, \"time\": avg_time, \"std\": std_time})\n", "\n", " return results" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using python implementation of Numerov integration, this is much slower!\n", "Using python implementation of Numerov integration, this is much slower!\n", "Using python implementation of Numerov integration, this is much slower!\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Benchmark Results:\n", "----------------------------------------------------------------------\n", " species n l use_njit time (ms) std (ms)\n", "----------------------------------------------------------------------\n", " H 100 80 True 11.88 0.73\n", " H 100 80 False 117.49 0.15\n" ] } ], "source": [ "results = run_benchmark(number=3)\n", "\n", "print(\"\\nBenchmark Results:\")\n", "print(\"-\" * 70)\n", "print(f\"{'species':>8} {'n':>3} {'l':>3} {'use_njit':>10} {'time (ms)':>10} {'std (ms)':>10}\")\n", "print(\"-\" * 70)\n", "for r in results:\n", " print(\n", " f\"{r['species']:>8} {r['n']:>3} {r['l']:>3} {r['use_njit']!s:>10} \"\n", " f\"{r['time'] * 1000:10.2f} {r['std'] * 1000:10.2f}\"\n", " )" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 2 }