From f524c0acf99b96da23eea2c049498ce34b01dccb Mon Sep 17 00:00:00 2001 From: Isaack Karanja Date: Thu, 3 Oct 2024 17:37:06 +0000 Subject: [PATCH] Add --megascale_abort_on_hangs flag for multi-slice TPU jobs * Introduce flag to terminate jobs on MegaScale Runtime Errors * Enable auto-restart of jax process when errors occur * Prevent silent hangs in multi-slice TPU configurations * Reduce time to recovery for failed jobs * ref: https://github.com/apple/axlearn/pull/716 * co-authored by Nick Stogner --- axlearn/common/compiler_options.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index 3ba82b36..5228ee32 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -52,6 +52,11 @@ def default_xla_options( # concurrently with gradient computation for the following layer. xla_tpu_enable_data_parallel_all_reduce_opt="true", xla_tpu_data_parallel_opt_different_sized_ops="true", + # If MegaScale Runtime Error is encountered when running multi-slice jobs, + # enabling this flag will allow for termination of the job, triggering + # the process to exit. This is set to true to prevent the job from + # silently hanging and to reduce time to recovery. + megascale_abort_on_hangs="true", ) # Validate options. Will never fail if this function is implemented correctly.