From fdf93e3931fba4616b658688210f114d90740bb1 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 11:53:27 +0000 Subject: [PATCH 01/15] Use JAX function for generating precomputes --- benchmarks/spherical.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index c0f529e0..f25b29bc 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -5,7 +5,7 @@ from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip import s2fft -from s2fft.recursions.price_mcewen import generate_precomputes +from s2fft.recursions.price_mcewen import generate_precomputes_jax from s2fft.sampling import s2_samples as samples L_VALUES = [8, 16, 32, 64, 128, 256] @@ -17,6 +17,10 @@ SPMD_VALUES = [False] +def _jax_arrays_to_numpy(precomps): + return [np.asarray(p) for p in precomps] + + def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") @@ -31,7 +35,11 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): Spin=spin, Reality=reality, ) - precomps = generate_precomputes(L, spin, sampling, forward=True, L_lower=L_lower) + precomps = generate_precomputes_jax( + L, spin, sampling, forward=True, L_lower=L_lower + ) + if method == "numpy": + precomps = _jax_arrays_to_numpy(precomps) return {"f": f, "precomps": precomps} @@ -71,7 +79,11 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): skip("GPU distribution only valid for JAX.") rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) - precomps = generate_precomputes(L, spin, sampling, forward=False, L_lower=L_lower) + precomps = generate_precomputes_jax( + L, spin, sampling, forward=False, L_lower=L_lower + ) + if method == "numpy": + precomps = _jax_arrays_to_numpy(precomps) return {"flm": flm, "precomps": precomps} From 
268e6218119daf66fb0d044cf9e6c68b7b8d841c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 14:34:24 +0000 Subject: [PATCH 02/15] Add benchmarks for precompute versions of spherical transforms --- benchmarks/precompute_spherical.py | 102 +++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 benchmarks/precompute_spherical.py diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py new file mode 100644 index 00000000..3654dc02 --- /dev/null +++ b/benchmarks/precompute_spherical.py @@ -0,0 +1,102 @@ +"""Benchmarks for spherical transforms.""" + +import numpy as np +import pyssht +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip + +import s2fft +import s2fft.precompute_transforms +from s2fft.sampling import s2_samples as samples + +L_VALUES = [8, 16, 32, 64, 128, 256] +SPIN_VALUES = [0] +SAMPLING_VALUES = ["mw"] +METHOD_VALUES = ["numpy", "jax"] +REALITY_VALUES = [True] + + +def setup_forward(method, L, sampling, spin, reality): + if reality and spin != 0: + skip("Reality only valid for scalar fields (spin=0).") + rng = np.random.default_rng() + flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) + f = pyssht.inverse( + samples.flm_2d_to_1d(flm, L), + L, + Method=sampling.upper(), + Spin=spin, + Reality=reality, + ) + kernel_function = ( + s2fft.precompute_transforms.construct.spin_spherical_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.spin_spherical_kernel + ) + kernel = kernel_function( + L=L, spin=spin, reality=reality, sampling=sampling, forward=True + ) + return {"f": f, "kernel": kernel} + + +@benchmark( + setup_forward, + method=METHOD_VALUES, + L=L_VALUES, + sampling=SAMPLING_VALUES, + spin=SPIN_VALUES, + reality=REALITY_VALUES, +) +def forward(f, kernel, method, L, sampling, spin, reality): + flm = s2fft.precompute_transforms.spherical.forward( + f=f, + L=L, + spin=spin, + kernel=kernel, + 
sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + flm.block_until_ready() + + +def setup_inverse(method, L, sampling, spin, reality): + if reality and spin != 0: + skip("Reality only valid for scalar fields (spin=0).") + rng = np.random.default_rng() + flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) + kernel_function = ( + s2fft.precompute_transforms.construct.spin_spherical_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.spin_spherical_kernel + ) + kernel = kernel_function( + L=L, spin=spin, reality=reality, sampling=sampling, forward=False + ) + return {"flm": flm, "kernel": kernel} + + +@benchmark( + setup_inverse, + method=METHOD_VALUES, + L=L_VALUES, + sampling=SAMPLING_VALUES, + spin=SPIN_VALUES, + reality=REALITY_VALUES, +) +def inverse(flm, kernel, method, L, sampling, spin, reality): + f = s2fft.precompute_transforms.spherical.inverse( + flm=flm, + L=L, + spin=spin, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + f.block_until_ready() + + +if __name__ == "__main__": + results = parse_args_collect_and_run_benchmarks() From 1f5182a60a5916054537128bd57ae9e250ec2676 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 14:35:33 +0000 Subject: [PATCH 03/15] Show defaults in benchmark help text --- benchmarks/benchmarking.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 5a9496fb..2cdaa8ec 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -147,7 +147,9 @@ def _parse_parameter_overrides(parameter_overrides): def _parse_cli_arguments(): """Parse command line arguments passed for controlling benchmark runs""" - parser = argparse.ArgumentParser("Run benchmarks") + parser = argparse.ArgumentParser( + "Run benchmarks", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( 
"-number-runs", type=int, From 24d0441d91b3273f1953fbaba59d9edccbc028df Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 14:35:51 +0000 Subject: [PATCH 04/15] Add flag for running once and discarding to eliminate JIT overhead --- benchmarks/benchmarking.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 2cdaa8ec..47db52bc 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -176,6 +176,15 @@ def _parse_cli_arguments(): parser.add_argument( "-output-file", type=Path, help="File path to write JSON formatted results to." ) + parser.add_argument( + "--run-once-and-discard", + action="store_true", + help=( + "Run benchmark function once first without recording time to " + "ignore the effect of any initial one-off costs such as just-in-time " + "compilation." + ), + ) return parser.parse_args() @@ -206,6 +215,7 @@ def run_benchmarks( number_repeats, print_results=True, parameter_overrides=None, + run_once_and_discard=False, ): """Run a set of benchmarks. @@ -219,6 +229,9 @@ def run_benchmarks( print_results: Whether to print benchmark results to stdout. parameter_overrides: Dictionary specifying any overrides for parameter values set in `benchmark` decorator. + run_once_and_discard: Whether to run benchmark function once first without + recording time to ignore the effect of any initial one-off costs such as + just-in-time compilation. 
Returns: Dictionary containing timing (and potentially memory usage) results for each @@ -236,6 +249,8 @@ def run_benchmarks( try: precomputes = benchmark.setup(**parameter_set) benchmark_function = partial(benchmark, **precomputes, **parameter_set) + if run_once_and_discard: + benchmark_function() run_times = [ time / number_runs for time in timeit.repeat( @@ -300,6 +315,7 @@ def parse_args_collect_and_run_benchmarks(module=None): number_runs=args.number_runs, number_repeats=args.repeats, parameter_overrides=parameter_overrides, + run_once_and_discard=args.run_once_and_discard, ) if args.output_file is not None: with open(args.output_file, "w") as f: From 3adac862d325dc3d8f9cb739d0e8633e569e77dc Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 14:41:58 +0000 Subject: [PATCH 05/15] Use module docstring for benchmark help description --- benchmarks/benchmarking.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 47db52bc..5b24f4b4 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -145,10 +145,10 @@ def _parse_parameter_overrides(parameter_overrides): ) -def _parse_cli_arguments(): +def _parse_cli_arguments(description): """Parse command line arguments passed for controlling benchmark runs""" parser = argparse.ArgumentParser( - "Run benchmarks", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "-number-runs", @@ -305,11 +305,11 @@ def parse_args_collect_and_run_benchmarks(module=None): Dictionary containing timing (and potentially memory usage) results for each parameters set of each benchmark function. 
""" - args = _parse_cli_arguments() - parameter_overrides = _parse_parameter_overrides(args.parameter_overrides) if module is None: frame = inspect.stack()[1] module = inspect.getmodule(frame[0]) + args = _parse_cli_arguments(module.__doc__) + parameter_overrides = _parse_parameter_overrides(args.parameter_overrides) results = run_benchmarks( benchmarks=collect_benchmarks(module), number_runs=args.number_runs, From 0c87cd64b997ab0664a3357a2c73a8e2a58a2461 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 14:42:57 +0000 Subject: [PATCH 06/15] Make benchmark module docstring consistent --- benchmarks/precompute_spherical.py | 2 +- benchmarks/spherical.py | 2 +- benchmarks/wigner.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py index 3654dc02..1c2ee0c2 100644 --- a/benchmarks/precompute_spherical.py +++ b/benchmarks/precompute_spherical.py @@ -1,4 +1,4 @@ -"""Benchmarks for spherical transforms.""" +"""Benchmarks for precompute spherical transforms.""" import numpy as np import pyssht diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index f25b29bc..78231a43 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -1,4 +1,4 @@ -"""Benchmarks for spherical transforms.""" +"""Benchmarks for on-the-fly spherical transforms.""" import numpy as np import pyssht diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py index 717501e3..a1f7d1ef 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -1,4 +1,4 @@ -"""Benchmarks for Wigner transforms.""" +"""Benchmarks for on-the-fly Wigner-d transforms.""" import numpy as np from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip From b3e328398ee8d7aa97069bf159acfd7f7d557605 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 15:15:52 +0000 Subject: [PATCH 07/15] Add precompute Wigner benchmarks --- benchmarks/precompute_wigner.py | 101 
++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 benchmarks/precompute_wigner.py diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py new file mode 100644 index 00000000..a3d1a988 --- /dev/null +++ b/benchmarks/precompute_wigner.py @@ -0,0 +1,101 @@ +"""Benchmarks for precompute Wigner-d transforms.""" + +import numpy as np +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip + +import s2fft +import s2fft.precompute_transforms +from s2fft.base_transforms import wigner as base_wigner +from s2fft.sampling import s2_samples as samples + +L_VALUES = [16, 32, 64, 128, 256] +N_VALUES = [2] +L_LOWER_VALUES = [0] +SAMPLING_VALUES = ["mw"] +METHOD_VALUES = ["numpy", "jax"] +REALITY_VALUES = [True] + +def setup_forward(method, L, N, L_lower, sampling, reality): + rng = np.random.default_rng() + flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) + f = base_wigner.inverse( + flmn, + L, + N, + L_lower=L_lower, + sampling=sampling, + reality=reality, + ) + kernel_function = ( + s2fft.precompute_transforms.construct.wigner_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.wigner_kernel + ) + kernel = kernel_function( + L=L, N=N, reality=reality, sampling=sampling, forward=True + ) + return {"f": f, "kernel": kernel} + + +@benchmark( + setup_forward, + method=METHOD_VALUES, + L=L_VALUES, + N=N_VALUES, + L_lower=L_LOWER_VALUES, + sampling=SAMPLING_VALUES, + reality=REALITY_VALUES, +) +def forward(f, kernel, method, L, N, L_lower, sampling, reality): + flmn = s2fft.precompute_transforms.wigner.forward( + f=f, + L=L, + N=N, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + flmn.block_until_ready() + + +def setup_inverse(method, L, N, L_lower, sampling, reality): + rng = np.random.default_rng() + flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) + kernel_function 
= ( + s2fft.precompute_transforms.construct.wigner_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.wigner_kernel + ) + kernel = kernel_function( + L=L, N=N, reality=reality, sampling=sampling, forward=False + ) + return {"flmn": flmn, "kernel": kernel} + + +@benchmark( + setup_inverse, + method=METHOD_VALUES, + L=L_VALUES, + N=N_VALUES, + L_lower=L_LOWER_VALUES, + sampling=SAMPLING_VALUES, + reality=REALITY_VALUES, +) +def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality): + f = s2fft.precompute_transforms.wigner.inverse( + flmn=flmn, + L=L, + N=N, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + f.block_until_ready() + + +if __name__ == "__main__": + results = parse_args_collect_and_run_benchmarks() From dfaa56971e33bc9d66f8b21971c82524e5fc5646 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 15:40:51 +0000 Subject: [PATCH 08/15] Record all benchmark results to JSON to not just last param set --- benchmarks/benchmarking.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 5b24f4b4..45d1be0a 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -239,7 +239,7 @@ def run_benchmarks( """ results = {} for benchmark in benchmarks: - results[benchmark.__name__] = {} + results[benchmark.__name__] = [] if print_results: print(benchmark.__name__) parameters = benchmark.parameters.copy() @@ -257,7 +257,7 @@ def run_benchmarks( benchmark_function, number=number_runs, repeat=number_repeats ) ] - results[benchmark.__name__] = {**parameter_set, "times / s": run_times} + results_entry = {**parameter_set, "times / s": run_times} if MEMORY_PROFILER_AVAILABLE: baseline_memory = memory_profiler.memory_usage(max_usage=True) peak_memory = ( @@ -270,7 +270,8 @@ def run_benchmarks( ) - baseline_memory ) - results[benchmark.__name__]["peak_memory / MiB"] = peak_memory 
+ results_entry["peak_memory / MiB"] = peak_memory + results[benchmark.__name__].append(results_entry) if print_results: print( ( @@ -279,9 +280,9 @@ def run_benchmarks( else " " ) + f"min(time): {min(run_times):>#7.2g}s, " - + f"max(time): {max(run_times):>#7.2g}s, " + + f"max(time): {max(run_times):>#7.2g}s" + ( - f"peak mem.: {peak_memory:>#7.2g}MiB" + f", peak mem.: {peak_memory:>#7.2g}MiB" if MEMORY_PROFILER_AVAILABLE else "" ) From 54db6d218acc98681dc39bf095a7cc6c37f5d68b Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 15:42:07 +0000 Subject: [PATCH 09/15] Record system info and date time in JSON output --- benchmarks/benchmarking.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 45d1be0a..dfd9fd78 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -60,11 +60,14 @@ def mean(x, n): """ import argparse +import datetime import inspect import json +import platform import timeit from ast import literal_eval from functools import partial +from importlib.metadata import PackageNotFoundError, version from itertools import product from pathlib import Path @@ -80,6 +83,14 @@ class SkipBenchmarkException(Exception): """Exception to be raised to skip benchmark for some parameter set.""" +def _get_version_or_none(package_name): + """Get installed version of package or `None` if package not found.""" + try: + return version(package_name) + except PackageNotFoundError: + return None + + def skip(message): """Skip benchmark for a particular parameter set with explanatory message. 
@@ -319,6 +330,25 @@ def parse_args_collect_and_run_benchmarks(module=None): run_once_and_discard=args.run_once_and_discard, ) if args.output_file is not None: + package_versions = { + f"{package}_version": _get_version_or_none(package) + for package in ("s2fft", "jax", "numpy") + } + system_info = { + "architecture": platform.architecture(), + "machine": platform.machine(), + "node": platform.node(), + "processor": platform.processor(), + "python_version": platform.python_version(), + "release": platform.release(), + "system": platform.system(), + **package_versions, + } with open(args.output_file, "w") as f: - json.dump(results, f) + output = { + "date_time": datetime.datetime.now().isoformat(), + "system_info": system_info, + "results": results, + } + json.dump(output, f, indent=True) return results From f4e9458cea0b2467323d40d15e2c8025f8a0af68 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 15:52:26 +0000 Subject: [PATCH 10/15] Expose recursion and mode parameters for precompute benchmarks --- benchmarks/precompute_spherical.py | 25 +++++++++++++++++++------ benchmarks/precompute_wigner.py | 16 ++++++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py index 1c2ee0c2..1af2a2aa 100644 --- a/benchmarks/precompute_spherical.py +++ b/benchmarks/precompute_spherical.py @@ -13,9 +13,10 @@ SAMPLING_VALUES = ["mw"] METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] +RECURSION_VALUES = ["auto"] -def setup_forward(method, L, sampling, spin, reality): +def setup_forward(method, L, sampling, spin, reality, recursion): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") rng = np.random.default_rng() @@ -33,7 +34,12 @@ def setup_forward(method, L, sampling, spin, reality): else s2fft.precompute_transforms.construct.spin_spherical_kernel ) kernel = kernel_function( - L=L, spin=spin, reality=reality, sampling=sampling, forward=True 
+ L=L, + spin=spin, + reality=reality, + sampling=sampling, + forward=True, + recursion=recursion, ) return {"f": f, "kernel": kernel} @@ -45,8 +51,9 @@ def setup_forward(method, L, sampling, spin, reality): sampling=SAMPLING_VALUES, spin=SPIN_VALUES, reality=REALITY_VALUES, + recursion=RECURSION_VALUES, ) -def forward(f, kernel, method, L, sampling, spin, reality): +def forward(f, kernel, method, L, sampling, spin, reality, recursion): flm = s2fft.precompute_transforms.spherical.forward( f=f, L=L, @@ -60,7 +67,7 @@ def forward(f, kernel, method, L, sampling, spin, reality): flm.block_until_ready() -def setup_inverse(method, L, sampling, spin, reality): +def setup_inverse(method, L, sampling, spin, reality, recursion): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") rng = np.random.default_rng() @@ -71,7 +78,12 @@ def setup_inverse(method, L, sampling, spin, reality): else s2fft.precompute_transforms.construct.spin_spherical_kernel ) kernel = kernel_function( - L=L, spin=spin, reality=reality, sampling=sampling, forward=False + L=L, + spin=spin, + reality=reality, + sampling=sampling, + forward=False, + recursion=recursion, ) return {"flm": flm, "kernel": kernel} @@ -83,8 +95,9 @@ def setup_inverse(method, L, sampling, spin, reality): sampling=SAMPLING_VALUES, spin=SPIN_VALUES, reality=REALITY_VALUES, + recursion=RECURSION_VALUES, ) -def inverse(flm, kernel, method, L, sampling, spin, reality): +def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): f = s2fft.precompute_transforms.spherical.inverse( flm=flm, L=L, diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py index a3d1a988..aa64ae69 100644 --- a/benchmarks/precompute_wigner.py +++ b/benchmarks/precompute_wigner.py @@ -14,8 +14,10 @@ SAMPLING_VALUES = ["mw"] METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] +MODE_VALUES = ["auto"] -def setup_forward(method, L, N, L_lower, sampling, reality): + +def setup_forward(method, L, N, 
L_lower, sampling, reality, mode): rng = np.random.default_rng() flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) f = base_wigner.inverse( @@ -32,7 +34,7 @@ def setup_forward(method, L, N, L_lower, sampling, reality): else s2fft.precompute_transforms.construct.wigner_kernel ) kernel = kernel_function( - L=L, N=N, reality=reality, sampling=sampling, forward=True + L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode ) return {"f": f, "kernel": kernel} @@ -45,8 +47,9 @@ def setup_forward(method, L, N, L_lower, sampling, reality): L_lower=L_LOWER_VALUES, sampling=SAMPLING_VALUES, reality=REALITY_VALUES, + mode=MODE_VALUES, ) -def forward(f, kernel, method, L, N, L_lower, sampling, reality): +def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): flmn = s2fft.precompute_transforms.wigner.forward( f=f, L=L, @@ -60,7 +63,7 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality): flmn.block_until_ready() -def setup_inverse(method, L, N, L_lower, sampling, reality): +def setup_inverse(method, L, N, L_lower, sampling, reality, mode): rng = np.random.default_rng() flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) kernel_function = ( @@ -69,7 +72,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality): else s2fft.precompute_transforms.construct.wigner_kernel ) kernel = kernel_function( - L=L, N=N, reality=reality, sampling=sampling, forward=False + L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode ) return {"flmn": flmn, "kernel": kernel} @@ -82,8 +85,9 @@ def setup_inverse(method, L, N, L_lower, sampling, reality): L_lower=L_LOWER_VALUES, sampling=SAMPLING_VALUES, reality=REALITY_VALUES, + mode=MODE_VALUES, ) -def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality): +def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): f = s2fft.precompute_transforms.wigner.inverse( flmn=flmn, L=L, From 
e917f050fc3b0cc25b2e3d994fe9313e3894c7a4 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 16:13:11 +0000 Subject: [PATCH 11/15] Record CPU and GPU info in benchmarks output if available --- benchmarks/benchmarking.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index dfd9fd78..e8a2325e 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -91,6 +91,26 @@ def _get_version_or_none(package_name): return None +def _get_cpu_info(): + """Get details of CPU from cpuinfo if available or None if not.""" + try: + import cpuinfo + + return cpuinfo.get_cpu_info() + except ImportError: + return None + + +def _get_gpu_info(): + """Get details of GPU devices available from JAX or None if JAX not available.""" + try: + import jax + + return [d.device_kind for d in jax.devices() if d.platform == "gpu"] + except ImportError: + return None + + def skip(message): """Skip benchmark for a particular parameter set with explanatory message. 
@@ -342,11 +362,14 @@ def parse_args_collect_and_run_benchmarks(module=None): "python_version": platform.python_version(), "release": platform.release(), "system": platform.system(), + "cpu_info": _get_cpu_info(), + "gpu_info": _get_gpu_info(), **package_versions, } with open(args.output_file, "w") as f: output = { "date_time": datetime.datetime.now().isoformat(), + "benchmark_module": module.__name__, "system_info": system_info, "results": results, } From 742fdadc797060bb4d6108fc33c8f4b01108f1be Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 17:15:20 +0000 Subject: [PATCH 12/15] Record GPU memory + CUDA info in results --- benchmarks/benchmarking.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index e8a2325e..faea7c1b 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -101,12 +101,47 @@ def _get_cpu_info(): return None +def _get_gpu_memory_mebibytes(device): + """Try to get GPU memory available in mebibytes (MiB).""" + memory_stats = device.memory_stats() + if memory_stats is None: + return None + bytes_limit = memory_stats.get("bytes_limit") + return bytes_limit // 2**20 if bytes_limit is not None else None + + def _get_gpu_info(): """Get details of GPU devices available from JAX or None if JAX not available.""" try: import jax - return [d.device_kind for d in jax.devices() if d.platform == "gpu"] + return [ + { + "kind": d.device_kind, + "memory_available / MiB": _get_gpu_memory_mebibytes(d), + } + for d in jax.devices() + if d.platform == "gpu" + ] + except ImportError: + return None + + +def _get_cuda_info(): + """Try to get information on versions of CUDA libraries.""" + try: + from jax._src.lib import cuda_versions + + if cuda_versions is None: + return None + return { + "cuda_runtime_version": cuda_versions.cuda_runtime_get_version(), + "cuda_runtime_build_version": 
cuda_versions.cuda_runtime_build_version(), + "cudnn_version": cuda_versions.cudnn_get_version(), + "cudnn_build_version": cuda_versions.cudnn_build_version(), + "cufft_version": cuda_versions.cufft_get_version(), + "cufft_build_version": cuda_versions.cufft_build_version(), + } except ImportError: return None @@ -364,6 +399,7 @@ def parse_args_collect_and_run_benchmarks(module=None): "system": platform.system(), "cpu_info": _get_cpu_info(), "gpu_info": _get_gpu_info(), + "cuda_info": _get_cuda_info(), **package_versions, } with open(args.output_file, "w") as f: From 942cf9ef3801b515c142a7111d69c8b4186c972b Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 27 Nov 2024 17:55:01 +0000 Subject: [PATCH 13/15] Fix Wigner benchmarks --- benchmarks/wigner.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py index a1f7d1ef..412c404f 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -16,12 +16,9 @@ SAMPLING_VALUES = ["mw"] METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] -SPMD_VALUES = [False] -def setup_forward(method, L, L_lower, N, sampling, reality, spmd): - if spmd and method != "jax": - skip("GPU distribution only valid for JAX.") +def setup_forward(method, L, L_lower, N, sampling, reality): rng = np.random.default_rng() flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) f = base_wigner.inverse( @@ -51,27 +48,23 @@ def setup_forward(method, L, L_lower, N, sampling, reality, spmd): N=N_VALUES, sampling=SAMPLING_VALUES, reality=REALITY_VALUES, - spmd=SPMD_VALUES, ) -def forward(f, precomps, method, L, L_lower, N, sampling, reality, spmd): +def forward(f, precomps, method, L, L_lower, N, sampling, reality): flmn = s2fft.transforms.wigner.forward( f=f, L=L, - L_lower=L_lower, N=N, - precomps=precomps, sampling=sampling, - reality=reality, method=method, - spmd=spmd, + reality=reality, + precomps=precomps, + L_lower=L_lower, ) 
if method == "jax": flmn.block_until_ready() -def setup_inverse(method, L, L_lower, N, sampling, reality, spmd): - if spmd and method != "jax": - skip("GPU distribution only valid for JAX.") +def setup_inverse(method, L, L_lower, N, sampling, reality): rng = np.random.default_rng() flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) generate_precomputes = ( @@ -93,19 +86,17 @@ def setup_inverse(method, L, L_lower, N, sampling, reality, spmd): N=N_VALUES, sampling=SAMPLING_VALUES, reality=REALITY_VALUES, - spmd=SPMD_VALUES, ) -def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality, spmd): - f = s2fft.transforms.spherical.inverse( - flm=flmn, +def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): + f = s2fft.transforms.wigner.inverse( + flmn=flmn, L=L, - L_lower=L_lower, N=N, - precomps=precomps, sampling=sampling, reality=reality, method=method, - spmd=spmd, + precomps=precomps, + L_lower=L_lower, ) if method == "jax": f.block_until_ready() From 64a2c719991b8f3238fe8c2f779feb991d44b409 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 29 Nov 2024 11:02:55 +0000 Subject: [PATCH 14/15] Removing unused imports --- benchmarks/precompute_wigner.py | 3 +-- benchmarks/wigner.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py index aa64ae69..01918ca7 100644 --- a/benchmarks/precompute_wigner.py +++ b/benchmarks/precompute_wigner.py @@ -1,12 +1,11 @@ """Benchmarks for precompute Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks import s2fft import s2fft.precompute_transforms from s2fft.base_transforms import wigner as base_wigner -from s2fft.sampling import s2_samples as samples L_VALUES = [16, 32, 64, 128, 256] N_VALUES = [2] diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py 
index 412c404f..d6180961 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -1,7 +1,7 @@ """Benchmarks for on-the-fly Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks import s2fft from s2fft.base_transforms import wigner as base_wigner From 77fa3a3b146d21fb1e3328904b52fe7dfea29fa9 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 29 Nov 2024 11:18:31 +0000 Subject: [PATCH 15/15] Update benchmarking README to reflect new modules and functionality --- benchmarks/README.md | 45 ++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index bbcd39f3..ab449949 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -8,15 +8,24 @@ and/or systems. If the [`memory_profiler` package](https://github.com/pythonprofilers/memory_profiler) is installed an estimate of the peak (main) memory usage of the benchmarked functions will also be recorded. +If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/) +is installed, additional information about the CPU of the system the benchmarks are run +on will be recorded in the JSON output. ## Description The benchmark scripts are as follows: - * `wigner.py` contains benchmarks for Wigner transforms (forward and inverse) - * `spherical.py` contains benchmarks for spherical transforms (forward and inverse) - + * `spherical.py` contains benchmarks for on-the-fly implementations of spherical + transforms (forward and inverse). + * `precompute_spherical.py` contains benchmarks for precompute implementations of + spherical transforms (forward and inverse). + * `wigner.py` contains benchmarks for on-the-fly implementations of Wigner-d + transforms (forward and inverse).
+ * `precompute_wigner.py` contains benchmarks for precompute implementations of + Wigner-d transforms (forward and inverse). + The `benchmarking.py` module contains shared utility functions for defining and running the benchmarks. @@ -29,22 +38,26 @@ the JSON formatted benchmark results to. Pass a `--help` argument to the script display the usage message: ``` -usage: Run benchmarks [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS] - [-parameter-overrides [PARAMETER_OVERRIDES [PARAMETER_OVERRIDES ...]]] - [-output-file OUTPUT_FILE] +usage: spherical.py [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS] + [-parameter-overrides [PARAMETER_OVERRIDES ...]] [-output-file OUTPUT_FILE] + [--run-once-and-discard] + +Benchmarks for on-the-fly spherical transforms. -optional arguments: +options: -h, --help show this help message and exit -number-runs NUMBER_RUNS - Number of times to run the benchmark in succession in each - timing run. - -repeats REPEATS Number of times to repeat the benchmark runs. - -parameter-overrides [PARAMETER_OVERRIDES [PARAMETER_OVERRIDES ...]] - Override for values to use for benchmark parameter. A parameter - name followed by space separated list of values to use. May be - specified multiple times to override multiple parameters. + Number of times to run the benchmark in succession in each timing run. (default: 10) + -repeats REPEATS Number of times to repeat the benchmark runs. (default: 3) + -parameter-overrides [PARAMETER_OVERRIDES ...] + Override for values to use for benchmark parameter. A parameter name followed by space + separated list of values to use. May be specified multiple times to override multiple + parameters. (default: None) -output-file OUTPUT_FILE - File path to write JSON formatted results to. + File path to write JSON formatted results to. (default: None) + --run-once-and-discard + Run benchmark function once first without recording time to ignore the effect of any initial + one-off costs such as just-in-time compilation. 
(default: False) ``` For example to run the spherical transform benchmarks using only the JAX implementations, @@ -52,7 +65,7 @@ running on a CPU (in double-precision) for `L` values 64, 128, 256, 512 and 1024 would run from the root of the repository: ```sh -JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py -p L 64 128 256 512 1024 -p method jax +JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py --run-once-and-discard -p L 64 128 256 512 1024 -p method jax ``` Note the usage of environment variables `JAX_PLATFORM_NAME` and `JAX_ENABLE_X64` to