-
Notifications
You must be signed in to change notification settings - Fork 282
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds XLA SDC check configuration to the trainer. (#765)
- Loading branch information
Showing
7 changed files
with
405 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
# Copyright © 2024 Apple Inc. | ||
"""Runtime and compiler options for JAX/XLA.""" | ||
|
||
# This module must not depend on any jax/axlearn modules so that | ||
# importing this module does not result in initializing jax. | ||
import re | ||
from typing import Union | ||
from typing import Any, Union | ||
|
||
|
||
def default_xla_options( | ||
|
@@ -104,4 +105,63 @@ def infer_tpu_version(tpu_type: str) -> str: | |
return tpu_version | ||
|
||
|
||
def infer_xsc_compiler_options(
    *,
    halt_on_detection: bool = True,
    repeat_count: int = 1,
    replicate_llo: bool = False,
) -> dict[str, Union[str, Any]]:
    """Infers compiler options for running compiled function with XLA SDC check enabled.

    Defaults are as advised by: <[email protected]>.

    To see additional XSC logging, enable the following environment variables at start time:
    ```bash
    export TPU_MIN_LOG_LEVEL=0
    export TPU_VMODULE=tpu_configuration_ops_impl=3
    export TF_CPP_MIN_LOG_LEVEL=0
    ```

    TODO(tom_gunter): Update with link to documentation once public.

    Args:
        halt_on_detection: Whether to halt the program and raise a Python exception on detection.
        repeat_count: Number of times to repeatedly call the program and validate outputs.
        replicate_llo: LLO sequence duplication, useful for single-core chips (e.g. v5e, v6e).

    Returns:
        A dictionary of compiler options that enable SDC checks. Only the three keyword
        arguments are configurable; every other entry is a fixed default.
    """
    return {
        # XLA SDC Checker flags:
        # Enable the SDC checker.
        "xla_tpu_enable_sdc_checker": True,
        # Number of times to repeat the function call.
        "xla_tpu_sdc_check_repeat_count": repeat_count,
        # Raise Python exception on error.
        "xla_tpu_sdc_check_halt_on_detection": halt_on_detection,
        # Duplicate LLO sequences.
        "xla_tpu_sdc_replicate_llo": replicate_llo,
        # Alternate primary/secondary core for each re-run for platforms with 2 cores per device.
        "xla_tpu_sdc_checker_alternate_megacore_cores": True,
        # XLA ICI SDC Checker flags:
        # N.B. ICI checker only runs once after first program compilation.
        # Enable the interconnect checker on first program call.
        "xla_tpu_ici_sdc_test_run_on_program_start": True,
        # Max distance between send/recv neighbours.
        "xla_tpu_ici_sdc_test_max_distance": 1,
        # Number of repeated send/recv before checking for equivalence.
        "xla_tpu_ici_sdc_test_pipeline_depth": 4,
        # Size of the random+checksum buffer to send/recv in 4KiB chunks.
        "xla_tpu_ici_sdc_test_buffer_size_chunks": 32,
        # Number of packets to split buffer into.
        "xla_tpu_ici_sdc_test_packet_size_chunks": 4,
        # Number of times to repeat the create-buffer/send/recv/verify loop.
        "xla_tpu_ici_sdc_test_iterations": 10,
        # LLO log recording prints performance (bandwidth/latency) stats; it is left
        # disabled here.
        "xla_tpu_enable_log_recorder": False,
    }
|
||
|
||
_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,159 +1,114 @@ | ||
"""Utilities for debugging training.""" | ||
|
||
import functools | ||
import gc | ||
from typing import Any, Callable, Protocol | ||
from typing import Any, Callable, Optional, Protocol, Union | ||
|
||
import chex | ||
import jax | ||
from absl import logging | ||
from jax._src.pjit import pjit | ||
from jax.experimental import checkify | ||
|
||
|
||
class JitFn(Protocol): | ||
"""A "pjit-like" function suitable for using in place of JAX's implementation of `pjit`.""" | ||
class _CheckifyCompiledFnWrapper: | ||
"""Performs error handling on a "checkified" compiled function during the call.""" | ||
|
||
def __call__(self, fun: Callable, *args, **kwargs) -> Callable: | ||
"""Return a wrapped version of fun using pjit or similar. | ||
def __init__(self, compiled: jax.stages.Compiled): | ||
self._compiled = compiled | ||
|
||
The arguments accepted are the same as JAX's `pjit`. | ||
""" | ||
|
||
|
||
# pylint: disable=redefined-outer-name, | ||
# pylint: disable-next=protected-access | ||
JaxException = jax._src.checkify.JaxException # type: ignore[module-attr] | ||
def __call__(self, *args, **kwargs) -> Any: | ||
"""Calls the compiled function and raises on detected checkify error.""" | ||
err, result = self._compiled(*args, **kwargs) | ||
checkify.check_error(err) | ||
return result | ||
|
||
|
||
def checkify_pjit(errors: frozenset[JaxException], pjit: JitFn = pjit) -> JitFn: | ||
"""Produce a checkified version of pjit. | ||
class _CheckifyLoweredFnWrapper: | ||
"""Wraps a lowered checkified function.""" | ||
|
||
See the docstring of `checkify_and_rerun_on_nonfinite` for a usage example. | ||
def __init__( | ||
self, | ||
lowered: jax.stages.Lowered, | ||
): | ||
self._lowered = lowered | ||
|
||
Args: | ||
errors: The checkify checks to run. | ||
pjit: The pjit function to wrap. | ||
def compile( | ||
self, compiler_options: Optional[dict[str, Union[str, bool]]] = None | ||
) -> _CheckifyCompiledFnWrapper: | ||
"""Compile the function with provided options.""" | ||
compiled = self._lowered.compile(compiler_options=compiler_options) | ||
return _CheckifyCompiledFnWrapper(compiled) | ||
|
||
Returns: | ||
A checkified version of pjit. | ||
""" | ||
|
||
def pjit_checkify_decorator(fun, *args, **kwargs): | ||
checkified_fun = checkify.checkify(pjit(fun, *args, **kwargs), errors=errors) | ||
class CheckifyJitFnWrapper: | ||
"""A checkify wrapper for use in place of JAX's implementation of `pjit` in some cases.""" | ||
|
||
def pjit_checkify_wrapper(*args, **kwargs): | ||
err, result = checkified_fun(*args, **kwargs) | ||
checkify.check_error(err) | ||
return result | ||
def __init__( | ||
self, | ||
checkified_jit_handle: jax.stages.Wrapped, | ||
): | ||
self._checkified_jit_handle = checkified_jit_handle | ||
|
||
return pjit_checkify_wrapper | ||
def __call__( | ||
self, *args, compiler_options: Optional[dict[str, Union[str, bool]]] = None, **kwargs | ||
) -> Any: | ||
"""Lowers, compiles, and runs the function with provided arguments and keyword arguments.""" | ||
lowered = self.lower(*args, **kwargs) | ||
compiled = lowered.compile(compiler_options) | ||
return compiled(*args, **kwargs) | ||
|
||
return pjit_checkify_decorator | ||
def lower(self, *args, **kwargs) -> _CheckifyLoweredFnWrapper: | ||
"""Traces and lowers the function using the provided arguments.""" | ||
lowered = self._checkified_jit_handle.lower(*args, **kwargs) | ||
return _CheckifyLoweredFnWrapper(lowered) | ||
|
||
|
||
def checkify_and_rerun_for_float_errors(pjit: JitFn = pjit) -> JitFn: | ||
"""Produce a pjit-like transformation that runs jax checkify float checks. | ||
class CheckifyJitFn(Protocol): | ||
"""Mirrors the call signature of JAX's `pjit` definition.""" | ||
|
||
Args: | ||
pjit: The pjit function to wrap. | ||
def __call__(self, fun: Callable, *args, **kwargs) -> CheckifyJitFnWrapper: | ||
"""Return a jit-fn wrapped version of `fun`. | ||
Returns: | ||
A checkified version of pjit. | ||
""" | ||
return checkify_and_rerun_on_nonfinite(checkify.float_checks, pjit=pjit) | ||
The arguments accepted are the same as JAX's `pjit`. | ||
""" | ||
|
||
|
||
def checkify_and_rerun_on_nonfinite( | ||
errors: frozenset[JaxException], *, pjit: JitFn = pjit | ||
) -> JitFn: | ||
"""Produce a pjit-like transformation that detects if the output contains nonfinite values | ||
and if found, reruns with additional error instrumentation. | ||
# pylint: disable=redefined-outer-name, | ||
# pylint: disable-next=protected-access | ||
JaxException = jax._src.checkify.JaxException # type: ignore[module-attr] | ||
|
||
This is similar to `jax_debug_nans` but it works properly with jit and pjit. | ||
Despite claims in the jax documentation, there are cases where `jax_debug_nans` failed to locate | ||
the nans when using jit. | ||
|
||
Note: Unlike ordinary pjit, this prevents donating the arguments. | ||
def checkify_pjit(errors: frozenset[JaxException]) -> CheckifyJitFn: | ||
"""Produce a checkified version of pjit. | ||
Example: | ||
``` | ||
pjit_with_rerun = checkify_and_rerun_on_nonfinite(errors=checkify.float_checks) | ||
```py | ||
pjit_with_nan_check = checkify_pjit(errors=checkify.nan_checks) | ||
@pjit_with_rerun | ||
def fn(x,y): | ||
return x / y | ||
@pjit_with_nan_check | ||
def fn(x): | ||
return jnp.log(x) | ||
assert fn(8,2) == 4 | ||
assert fn(jnp.exp(1)) == 1.0 | ||
# Raises a JaxRuntimeError with the source of the division by 0. | ||
fn(3,0) | ||
``` | ||
# Raises a JaxRuntimeError with the source of the undefined logarithm call. | ||
fn(-1) | ||
``` | ||
Args: | ||
errors: The checkify error checks to enable when rerunning. | ||
errors: The checkify checks to run. | ||
Returns: | ||
A function suitable for `SpmdTrainer.Config.dynamic_rerun`. | ||
Raises: | ||
JaxRuntimeException: If a nonfinite value is found in the original run and the checkify | ||
checks fail in the rerun. | ||
A checkified version of pjit. | ||
""" | ||
|
||
def pjit_and_rerun_decorator(fun, *args, donate_argnums: Any = None, **kwargs): | ||
# donate_argnums cannot be used because we need to be able to rerun on the original inputs. | ||
if donate_argnums: | ||
logging.warning( | ||
"Ignoring donate_argnums=%s because it is incompatible with rerunning.", | ||
donate_argnums, | ||
) | ||
jit_fun = pjit(fun, *args, **kwargs) | ||
|
||
@functools.wraps(fun) | ||
def pjit_and_rerun_wrapper(*args, **kwargs): | ||
# Ensure leftover arrays from previous invocations are collected. | ||
gc.collect() | ||
# Run function first time. | ||
result = jit_fun(*args, **kwargs) | ||
try: | ||
chex.assert_tree_all_finite(result) | ||
except AssertionError as e: | ||
logging.error("Got nonfinite results from pjit function %s", e) | ||
checkified_fun = checkify.checkify(jit_fun, errors=errors) | ||
# Run function second time. | ||
err, result = checkified_fun(*args, **kwargs) | ||
checkify.check_error(err) | ||
# If no error raised from above line: | ||
logging.warning("Bad call failed to reproduce the issue when rerun. Continuing...") | ||
return result | ||
|
||
return pjit_and_rerun_wrapper | ||
|
||
return pjit_and_rerun_decorator | ||
def pjit_checkify_decorator( | ||
fun, *args, out_shardings: Any = None, **kwargs | ||
) -> CheckifyJitFnWrapper: | ||
# We need to update out_shardings to handle the checkify result. | ||
out_shardings = (None, out_shardings) | ||
checkified_fun = CheckifyJitFnWrapper( | ||
pjit( | ||
checkify.checkify(fun, errors=errors), *args, out_shardings=out_shardings, **kwargs | ||
), | ||
) | ||
return checkified_fun | ||
|
||
|
||
def noop_pjit() -> JitFn: | ||
"""Produces a noop function that does not jit.""" | ||
|
||
def no_pjit_decorator(fun, *args, **kwargs): | ||
del args, kwargs | ||
return fun | ||
|
||
return no_pjit_decorator | ||
|
||
|
||
def checking_leaks_pjit(*, pjit: JitFn = pjit) -> JitFn: | ||
"""Produces a pjit-like transformation with jax tracer leak detection.""" | ||
|
||
def checking_leaks_pjit_decorator(fun, *args, **kwargs): | ||
jit_fun = pjit(fun, *args, **kwargs) | ||
|
||
@functools.wraps(jit_fun) | ||
def checking_leaks_pjit_wrapper(*args, **kwargs): | ||
with jax.checking_leaks(): | ||
return jit_fun(*args, **kwargs) | ||
|
||
return checking_leaks_pjit_wrapper | ||
|
||
return checking_leaks_pjit_decorator | ||
return pjit_checkify_decorator |
Oops, something went wrong.