Skip to content

Commit

Permalink
Adds XLA SDC check configuration to the trainer. (#765)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgunter authored Oct 21, 2024
1 parent 20a87a6 commit b3f1d3a
Show file tree
Hide file tree
Showing 7 changed files with 405 additions and 172 deletions.
62 changes: 61 additions & 1 deletion axlearn/common/compiler_options.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright © 2024 Apple Inc.
"""Runtime and compiler options for JAX/XLA."""

# This module must not depend on any jax/axlearn modules so that
# importing this module does not result in initializing jax.
import re
from typing import Union
from typing import Any, Union


def default_xla_options(
Expand Down Expand Up @@ -104,4 +105,63 @@ def infer_tpu_version(tpu_type: str) -> str:
return tpu_version


def infer_xsc_compiler_options(
    *,
    halt_on_detection: bool = True,
    repeat_count: int = 1,
    replicate_llo: bool = False,
) -> dict[str, Union[str, Any]]:
    """Builds the compiler options that run a compiled function with the XLA SDC check enabled.

    Defaults are as advised by: <[email protected]>.

    To see additional XSC logging, enable the following environment variables at start time:

    ```bash
    export TPU_MIN_LOG_LEVEL=0
    export TPU_VMODULE=tpu_configuration_ops_impl=3
    export TF_CPP_MIN_LOG_LEVEL=0
    ```

    TODO(tom_gunter): Update with link to documentation once public.

    Args:
        halt_on_detection: Whether to halt the program and raise a Python exception on detection.
        repeat_count: Number of times to repeatedly call the program and validate outputs.
        replicate_llo: LLO sequence duplication, useful for single-core chips (e.g. v5e, v6e).

    Returns:
        A dictionary of compiler options that enable SDC checks.
    """
    # Flags for the XLA SDC checker.
    sdc_checker_flags = {
        # Turn the SDC checker on.
        "xla_tpu_enable_sdc_checker": True,
        # How many times the function call is repeated.
        "xla_tpu_sdc_check_repeat_count": repeat_count,
        # Raise a Python exception when an error is detected.
        "xla_tpu_sdc_check_halt_on_detection": halt_on_detection,
        # Duplicate LLO sequences.
        "xla_tpu_sdc_replicate_llo": replicate_llo,
        # On platforms with 2 cores per device, alternate primary/secondary core on each re-run.
        "xla_tpu_sdc_checker_alternate_megacore_cores": True,
    }
    # Flags for the XLA ICI SDC checker.
    # N.B. ICI checker only runs once after first program compilation.
    ici_checker_flags = {
        # Run the interconnect checker on the first program call.
        "xla_tpu_ici_sdc_test_run_on_program_start": True,
        # Max distance between send/recv neighbours.
        "xla_tpu_ici_sdc_test_max_distance": 1,
        # Number of repeated send/recv before checking for equivalence.
        "xla_tpu_ici_sdc_test_pipeline_depth": 4,
        # Size of the random+checksum buffer to send/recv, in 4KiB chunks.
        "xla_tpu_ici_sdc_test_buffer_size_chunks": 32,
        # Number of packets to split the buffer into.
        "xla_tpu_ici_sdc_test_packet_size_chunks": 4,
        # Number of times to repeat the create-buffer/send/recv/verify loop.
        "xla_tpu_ici_sdc_test_iterations": 10,
        # LLO log recording, which prints performance (bandwidth/latency) stats; left disabled.
        "xla_tpu_enable_log_recorder": False,
    }
    # Merge in the same key order as the flags are listed above.
    return {**sdc_checker_flags, **ici_checker_flags}


_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p")
22 changes: 22 additions & 0 deletions axlearn/common/compiler_options_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
"""Tests for compiler_options.py."""

import jax
import jax.numpy as jnp
import pytest
Expand All @@ -26,3 +27,24 @@ def atest_xla_flags_from_options(self):
options = dict(a="true", b="false", c=True, d=False, long_option_name=True)
result = compiler_options.xla_flags_from_options(options)
self.assertEqual(result, "--a=true --b=false --c=1 --d=0 --long_option_name=1")

def test_xsc_compiler_options(self):
    """Checks that non-default arguments are reflected in the full XSC option set."""
    options = compiler_options.infer_xsc_compiler_options(
        halt_on_detection=False, repeat_count=2, replicate_llo=True
    )
    expected_options = dict(
        xla_tpu_enable_sdc_checker=True,
        xla_tpu_sdc_check_repeat_count=2,
        xla_tpu_sdc_check_halt_on_detection=False,
        xla_tpu_sdc_replicate_llo=True,
        xla_tpu_sdc_checker_alternate_megacore_cores=True,
        xla_tpu_ici_sdc_test_run_on_program_start=True,
        xla_tpu_ici_sdc_test_max_distance=1,
        xla_tpu_ici_sdc_test_pipeline_depth=4,
        xla_tpu_ici_sdc_test_buffer_size_chunks=32,
        xla_tpu_ici_sdc_test_packet_size_chunks=4,
        xla_tpu_ici_sdc_test_iterations=10,
        xla_tpu_enable_log_recorder=False,
    )
    # Compare the whole mapping: iterating only over the returned options would not notice an
    # expected option that the function stopped returning.
    self.assertEqual(options, expected_options)
195 changes: 75 additions & 120 deletions axlearn/common/debug_utils.py
Original file line number Diff line number Diff line change
@@ -1,159 +1,114 @@
"""Utilities for debugging training."""

import functools
import gc
from typing import Any, Callable, Protocol
from typing import Any, Callable, Optional, Protocol, Union

import chex
import jax
from absl import logging
from jax._src.pjit import pjit
from jax.experimental import checkify


class JitFn(Protocol):
"""A "pjit-like" function suitable for using in place of JAX's implementation of `pjit`."""
class _CheckifyCompiledFnWrapper:
"""Performs error handling on a "checkified" compiled function during the call."""

def __call__(self, fun: Callable, *args, **kwargs) -> Callable:
"""Return a wrapped version of fun using pjit or similar.
def __init__(self, compiled: jax.stages.Compiled):
self._compiled = compiled

The arguments accepted are the same as JAX's `pjit`.
"""


# pylint: disable=redefined-outer-name,
# pylint: disable-next=protected-access
JaxException = jax._src.checkify.JaxException # type: ignore[module-attr]
def __call__(self, *args, **kwargs) -> Any:
"""Calls the compiled function and raises on detected checkify error."""
err, result = self._compiled(*args, **kwargs)
checkify.check_error(err)
return result


def checkify_pjit(errors: frozenset[JaxException], pjit: JitFn = pjit) -> JitFn:
"""Produce a checkified version of pjit.
class _CheckifyLoweredFnWrapper:
"""Wraps a lowered checkified function."""

See the docstring of `checkify_and_rerun_on_nonfinite` for a usage example.
def __init__(
self,
lowered: jax.stages.Lowered,
):
self._lowered = lowered

Args:
errors: The checkify checks to run.
pjit: The pjit function to wrap.
def compile(
self, compiler_options: Optional[dict[str, Union[str, bool]]] = None
) -> _CheckifyCompiledFnWrapper:
"""Compile the function with provided options."""
compiled = self._lowered.compile(compiler_options=compiler_options)
return _CheckifyCompiledFnWrapper(compiled)

Returns:
A checkified version of pjit.
"""

def pjit_checkify_decorator(fun, *args, **kwargs):
checkified_fun = checkify.checkify(pjit(fun, *args, **kwargs), errors=errors)
class CheckifyJitFnWrapper:
"""A checkify wrapper for use in place of JAX's implementation of `pjit` in some cases."""

def pjit_checkify_wrapper(*args, **kwargs):
err, result = checkified_fun(*args, **kwargs)
checkify.check_error(err)
return result
def __init__(
self,
checkified_jit_handle: jax.stages.Wrapped,
):
self._checkified_jit_handle = checkified_jit_handle

return pjit_checkify_wrapper
def __call__(
self, *args, compiler_options: Optional[dict[str, Union[str, bool]]] = None, **kwargs
) -> Any:
"""Lowers, compiles, and runs the function with provided arguments and keyword arguments."""
lowered = self.lower(*args, **kwargs)
compiled = lowered.compile(compiler_options)
return compiled(*args, **kwargs)

return pjit_checkify_decorator
def lower(self, *args, **kwargs) -> _CheckifyLoweredFnWrapper:
"""Traces and lowers the function using the provided arguments."""
lowered = self._checkified_jit_handle.lower(*args, **kwargs)
return _CheckifyLoweredFnWrapper(lowered)


def checkify_and_rerun_for_float_errors(pjit: JitFn = pjit) -> JitFn:
"""Produce a pjit-like transformation that runs jax checkify float checks.
class CheckifyJitFn(Protocol):
"""Mirrors the call signature of JAX's `pjit` definition."""

Args:
pjit: The pjit function to wrap.
def __call__(self, fun: Callable, *args, **kwargs) -> CheckifyJitFnWrapper:
"""Return a jit-fn wrapped version of `fun`.
Returns:
A checkified version of pjit.
"""
return checkify_and_rerun_on_nonfinite(checkify.float_checks, pjit=pjit)
The arguments accepted are the same as JAX's `pjit`.
"""


def checkify_and_rerun_on_nonfinite(
errors: frozenset[JaxException], *, pjit: JitFn = pjit
) -> JitFn:
"""Produce a pjit-like transformation that detects if the output contains nonfinite values
and if found, reruns with additional error instrumentation.
# pylint: disable=redefined-outer-name,
# pylint: disable-next=protected-access
JaxException = jax._src.checkify.JaxException # type: ignore[module-attr]

This is similar to `jax_debug_nans` but it works properly with jit and pjit.
Despite claims in the jax documentation, there are cases where `jax_debug_nans` failed to locate
the nans when using jit.

Note: Unlike ordinary pjit, this prevents donating the arguments.
def checkify_pjit(errors: frozenset[JaxException]) -> CheckifyJitFn:
"""Produce a checkified version of pjit.
Example:
```
pjit_with_rerun = checkify_and_rerun_on_nonfinite(errors=checkify.float_checks)
```py
pjit_with_nan_check = checkify_pjit(errors=checkify.nan_checks)
@pjit_with_rerun
def fn(x,y):
return x / y
@pjit_with_nan_check
def fn(x):
return jnp.log(x)
assert fn(8,2) == 4
assert fn(jnp.exp(1)) == 1.0
# Raises a JaxRuntimeError with the source of the division by 0.
fn(3,0)
```
# Raises a JaxRuntimeError with the source of the undefined logarithm call.
fn(-1)
```
Args:
errors: The checkify error checks to enable when rerunning.
errors: The checkify checks to run.
Returns:
A function suitable for `SpmdTrainer.Config.dynamic_rerun`.
Raises:
JaxRuntimeException: If a nonfinite value is found in the original run and the checkify
checks fail in the rerun.
A checkified version of pjit.
"""

def pjit_and_rerun_decorator(fun, *args, donate_argnums: Any = None, **kwargs):
# donate_argnums cannot be used because we need to be able to rerun on the original inputs.
if donate_argnums:
logging.warning(
"Ignoring donate_argnums=%s because it is incompatible with rerunning.",
donate_argnums,
)
jit_fun = pjit(fun, *args, **kwargs)

@functools.wraps(fun)
def pjit_and_rerun_wrapper(*args, **kwargs):
# Ensure leftover arrays from previous invocations are collected.
gc.collect()
# Run function first time.
result = jit_fun(*args, **kwargs)
try:
chex.assert_tree_all_finite(result)
except AssertionError as e:
logging.error("Got nonfinite results from pjit function %s", e)
checkified_fun = checkify.checkify(jit_fun, errors=errors)
# Run function second time.
err, result = checkified_fun(*args, **kwargs)
checkify.check_error(err)
# If no error raised from above line:
logging.warning("Bad call failed to reproduce the issue when rerun. Continuing...")
return result

return pjit_and_rerun_wrapper

return pjit_and_rerun_decorator
def pjit_checkify_decorator(
fun, *args, out_shardings: Any = None, **kwargs
) -> CheckifyJitFnWrapper:
# We need to update out_shardings to handle the checkify result.
out_shardings = (None, out_shardings)
checkified_fun = CheckifyJitFnWrapper(
pjit(
checkify.checkify(fun, errors=errors), *args, out_shardings=out_shardings, **kwargs
),
)
return checkified_fun


def noop_pjit() -> JitFn:
    """Builds a stand-in for `pjit` that does not jit anything.

    Returns:
        A decorator with a pjit-like call signature. It ignores every argument other than `fun`
        and returns `fun` itself, unwrapped.
    """

    def _return_fun_unchanged(fun, *_ignored_args, **_ignored_kwargs):
        # No wrapping or compilation happens; the caller gets its own function back.
        return fun

    return _return_fun_unchanged


def checking_leaks_pjit(*, pjit: JitFn = pjit) -> JitFn:
    """Produces a pjit-like transformation with jax tracer leak detection.

    Args:
        pjit: The pjit-like function used to wrap `fun`.

    Returns:
        A decorator that applies `pjit` to a function and returns a callable which invokes the
        result inside a `jax.checking_leaks()` context.
    """

    def _decorate(fun, *jit_args, **jit_kwargs):
        wrapped_by_pjit = pjit(fun, *jit_args, **jit_kwargs)

        def _call_under_leak_check(*call_args, **call_kwargs):
            # Leak checking is only active for the duration of this call.
            with jax.checking_leaks():
                outputs = wrapped_by_pjit(*call_args, **call_kwargs)
            return outputs

        # Carry over the metadata of the pjit-wrapped function onto the returned callable.
        return functools.wraps(wrapped_by_pjit)(_call_under_leak_check)

    return _decorate
return pjit_checkify_decorator
Loading

0 comments on commit b3f1d3a

Please sign in to comment.