Improvements to benchmarking system #248

Merged: 15 commits, Dec 11, 2024
45 changes: 29 additions & 16 deletions benchmarks/README.md
@@ -8,15 +8,24 @@ and/or systems.
If the [`memory_profiler` package](https://github.com/pythonprofilers/memory_profiler)
is installed an estimate of the peak (main) memory usage of the benchmarked functions
will also be recorded.
If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/)
is installed, additional information about the CPU of the system the benchmarks are run on
will be recorded in the JSON output.

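For example, both optional packages can be installed from PyPI with:

```sh
# Optional extras: memory_profiler for peak memory estimates,
# py-cpuinfo for recording CPU details in the JSON output
pip install memory_profiler py-cpuinfo
```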

## Description

The benchmark scripts are as follows:

* `wigner.py` contains benchmarks for Wigner transforms (forward and inverse)
* `spherical.py` contains benchmarks for spherical transforms (forward and inverse)

* `spherical.py` contains benchmarks for on-the-fly implementations of spherical
transforms (forward and inverse).
* `precompute_spherical.py` contains benchmarks for precompute implementations of
spherical transforms (forward and inverse).
* `wigner.py` contains benchmarks for on-the-fly implementations of Wigner-d
transforms (forward and inverse).
* `precompute_wigner.py` contains benchmarks for precompute implementations of
Wigner-d transforms (forward and inverse).

The `benchmarking.py` module contains shared utility functions for defining and running
the benchmarks.

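As a rough illustration of how those utilities fit together (a sketch, not a documented
interface; the import path, working directory and override values are assumptions), a
benchmark module can also be collected and run programmatically:

```python
# Illustrative sketch only: run the spherical transform benchmarks from Python
# instead of via the command line. Assumes it is executed from within the
# benchmarks/ directory so that `spherical` and `benchmarking` are importable,
# and that the chosen parameter override values are valid for these benchmarks.
import spherical

from benchmarking import collect_benchmarks, run_benchmarks

results = run_benchmarks(
    benchmarks=collect_benchmarks(spherical),
    number_runs=10,  # calls per timing run (matches the CLI default)
    number_repeats=3,  # timing runs per parameter set (matches the CLI default)
    parameter_overrides={"L": [64, 128]},  # assumed parameter name and values
    run_once_and_discard=True,  # warm up once (e.g. JIT compilation) before timing
)
```

The returned dictionary has the same structure as the `results` field of the JSON
output sketched further below.
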
@@ -29,30 +38,34 @@ the JSON formatted benchmark results to. Pass a `--help` argument to the script
display the usage message:

```
usage: Run benchmarks [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS]
[-parameter-overrides [PARAMETER_OVERRIDES [PARAMETER_OVERRIDES ...]]]
[-output-file OUTPUT_FILE]
usage: spherical.py [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS]
[-parameter-overrides [PARAMETER_OVERRIDES ...]] [-output-file OUTPUT_FILE]
[--run-once-and-discard]

Benchmarks for on-the-fly spherical transforms.

optional arguments:
options:
-h, --help show this help message and exit
-number-runs NUMBER_RUNS
Number of times to run the benchmark in succession in each
timing run.
-repeats REPEATS Number of times to repeat the benchmark runs.
-parameter-overrides [PARAMETER_OVERRIDES [PARAMETER_OVERRIDES ...]]
Override for values to use for benchmark parameter. A parameter
name followed by space separated list of values to use. May be
specified multiple times to override multiple parameters.
Number of times to run the benchmark in succession in each timing run. (default: 10)
-repeats REPEATS Number of times to repeat the benchmark runs. (default: 3)
-parameter-overrides [PARAMETER_OVERRIDES ...]
Override for values to use for benchmark parameter. A parameter name followed by space
separated list of values to use. May be specified multiple times to override multiple
parameters. (default: None)
-output-file OUTPUT_FILE
File path to write JSON formatted results to.
File path to write JSON formatted results to. (default: None)
--run-once-and-discard
Run benchmark function once first without recording time to ignore the effect of any initial
one-off costs such as just-in-time compilation. (default: False)
```

For example, to run the spherical transform benchmarks using only the JAX implementations,
running on a CPU (in double precision) for `L` values 64, 128, 256, 512 and 1024, we
would run from the root of the repository:

```sh
JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py -p L 64 128 256 512 1024 -p method jax
JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py --run-once-and-discard -p L 64 128 256 512 1024 -p method jax
```
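
When an `-output-file` path is supplied, the results are written as JSON together with
system details (CPU, GPU and CUDA information where available) and package versions. The
sketch below illustrates the rough shape of that file; the benchmark name, parameter names
and all values are placeholders, several `system_info` fields are omitted, and a
`peak_memory / MiB` entry is added to each result only when `memory_profiler` is installed:

```json
{
  "date_time": "2024-12-11T12:00:00",
  "benchmark_module": "spherical",
  "system_info": {
    "machine": "x86_64",
    "python_version": "3.11.9",
    "cpu_info": null,
    "gpu_info": [],
    "cuda_info": null,
    "s2fft_version": "1.0.2",
    "jax_version": "0.4.30",
    "numpy_version": "1.26.4"
  },
  "results": {
    "forward": [
      {"L": 64, "method": "jax", "times / s": [0.0012, 0.0011, 0.0011]},
      {"L": 128, "method": "jax", "times / s": [0.0049, 0.0047, 0.0047]}
    ]
  }
}
```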

Note the usage of environment variables `JAX_PLATFORM_NAME` and `JAX_ENABLE_X64` to
128 changes: 118 additions & 10 deletions benchmarks/benchmarking.py
@@ -60,11 +60,14 @@ def mean(x, n):
"""

import argparse
import datetime
import inspect
import json
import platform
import timeit
from ast import literal_eval
from functools import partial
from importlib.metadata import PackageNotFoundError, version
from itertools import product
from pathlib import Path

@@ -80,6 +83,69 @@ class SkipBenchmarkException(Exception):
"""Exception to be raised to skip benchmark for some parameter set."""


def _get_version_or_none(package_name):
"""Get installed version of package or `None` if package not found."""
try:
return version(package_name)
except PackageNotFoundError:
return None


def _get_cpu_info():
"""Get details of CPU from cpuinfo if available or None if not."""
try:
import cpuinfo

return cpuinfo.get_cpu_info()
except ImportError:
return None


def _get_gpu_memory_mebibytes(device):
"""Try to get GPU memory available in mebibytes (MiB)."""
memory_stats = device.memory_stats()
if memory_stats is None:
return None
bytes_limit = memory_stats.get("bytes_limit")
return bytes_limit // 2**20 if bytes_limit is not None else None


def _get_gpu_info():
"""Get details of GPU devices available from JAX or None if JAX not available."""
try:
import jax

return [
{
"kind": d.device_kind,
"memory_available / MiB": _get_gpu_memory_mebibytes(d),
}
for d in jax.devices()
if d.platform == "gpu"
]
except ImportError:
return None


def _get_cuda_info():
"""Try to get information on versions of CUDA libraries."""
try:
from jax._src.lib import cuda_versions

if cuda_versions is None:
return None
return {
"cuda_runtime_version": cuda_versions.cuda_runtime_get_version(),
"cuda_runtime_build_version": cuda_versions.cuda_runtime_build_version(),
"cudnn_version": cuda_versions.cudnn_get_version(),
"cudnn_build_version": cuda_versions.cudnn_build_version(),
"cufft_version": cuda_versions.cufft_get_version(),
"cufft_build_version": cuda_versions.cufft_build_version(),
}
except ImportError:
return None


def skip(message):
"""Skip benchmark for a particular parameter set with explanatory message.

@@ -145,9 +211,11 @@ def _parse_parameter_overrides(parameter_overrides):
)


def _parse_cli_arguments():
def _parse_cli_arguments(description):
"""Parse command line arguments passed for controlling benchmark runs"""
parser = argparse.ArgumentParser("Run benchmarks")
parser = argparse.ArgumentParser(
description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-number-runs",
type=int,
@@ -174,6 +242,15 @@ def _parse_cli_arguments():
parser.add_argument(
"-output-file", type=Path, help="File path to write JSON formatted results to."
)
parser.add_argument(
"--run-once-and-discard",
action="store_true",
help=(
"Run benchmark function once first without recording time to "
"ignore the effect of any initial one-off costs such as just-in-time "
"compilation."
),
)
return parser.parse_args()


@@ -204,6 +281,7 @@ def run_benchmarks(
number_repeats,
print_results=True,
parameter_overrides=None,
run_once_and_discard=False,
):
"""Run a set of benchmarks.

@@ -217,14 +295,17 @@
print_results: Whether to print benchmark results to stdout.
parameter_overrides: Dictionary specifying any overrides for parameter values
set in `benchmark` decorator.
run_once_and_discard: Whether to run benchmark function once first without
recording time to ignore the effect of any initial one-off costs such as
just-in-time compilation.

Returns:
Dictionary containing timing (and potentially memory usage) results for each
parameter set of each benchmark function.
"""
results = {}
for benchmark in benchmarks:
results[benchmark.__name__] = {}
results[benchmark.__name__] = []
if print_results:
print(benchmark.__name__)
parameters = benchmark.parameters.copy()
@@ -234,13 +315,15 @@
try:
precomputes = benchmark.setup(**parameter_set)
benchmark_function = partial(benchmark, **precomputes, **parameter_set)
if run_once_and_discard:
benchmark_function()
run_times = [
time / number_runs
for time in timeit.repeat(
benchmark_function, number=number_runs, repeat=number_repeats
)
]
results[benchmark.__name__] = {**parameter_set, "times / s": run_times}
results_entry = {**parameter_set, "times / s": run_times}
if MEMORY_PROFILER_AVAILABLE:
baseline_memory = memory_profiler.memory_usage(max_usage=True)
peak_memory = (
@@ -253,7 +336,8 @@
)
- baseline_memory
)
results[benchmark.__name__]["peak_memory / MiB"] = peak_memory
results_entry["peak_memory / MiB"] = peak_memory
results[benchmark.__name__].append(results_entry)
if print_results:
print(
(
@@ -262,9 +346,9 @@
else " "
)
+ f"min(time): {min(run_times):>#7.2g}s, "
+ f"max(time): {max(run_times):>#7.2g}s, "
+ f"max(time): {max(run_times):>#7.2g}s"
+ (
f"peak mem.: {peak_memory:>#7.2g}MiB"
f", peak mem.: {peak_memory:>#7.2g}MiB"
if MEMORY_PROFILER_AVAILABLE
else ""
)
@@ -288,18 +372,42 @@ def parse_args_collect_and_run_benchmarks(module=None):
Dictionary containing timing (and potentially memory usage) results for each
parameter set of each benchmark function.
"""
args = _parse_cli_arguments()
parameter_overrides = _parse_parameter_overrides(args.parameter_overrides)
if module is None:
frame = inspect.stack()[1]
module = inspect.getmodule(frame[0])
args = _parse_cli_arguments(module.__doc__)
parameter_overrides = _parse_parameter_overrides(args.parameter_overrides)
results = run_benchmarks(
benchmarks=collect_benchmarks(module),
number_runs=args.number_runs,
number_repeats=args.repeats,
parameter_overrides=parameter_overrides,
run_once_and_discard=args.run_once_and_discard,
)
if args.output_file is not None:
package_versions = {
f"{package}_version": _get_version_or_none(package)
for package in ("s2fft", "jax", "numpy")
}
system_info = {
"architecture": platform.architecture(),
"machine": platform.machine(),
"node": platform.node(),
"processor": platform.processor(),
"python_version": platform.python_version(),
"release": platform.release(),
"system": platform.system(),
"cpu_info": _get_cpu_info(),
"gpu_info": _get_gpu_info(),
"cuda_info": _get_cuda_info(),
**package_versions,
}
with open(args.output_file, "w") as f:
json.dump(results, f)
output = {
"date_time": datetime.datetime.now().isoformat(),
"benchmark_module": module.__name__,
"system_info": system_info,
"results": results,
}
json.dump(output, f, indent=True)
return results