Commit: Changes in preparation for jax 0.4.21. (#235)
Showing 9 changed files with 74 additions and 88 deletions. The hunks below are from one of the changed files.
```diff
@@ -5,15 +5,15 @@
 import os
 import sys
 
-tpu_type = os.environ.get("TPU_TYPE", "none")
+instance_type = os.environ.get("TPU_TYPE", "none")
 
 # Set LIBTPU_INIT_ARGS before importing jax!
 libtpu_init_args = [
     "--xla_tpu_spmd_rng_bit_generator_unsafe=1",  # SPMD partition-aware RngBitGenerator.
     "--xla_tpu_enable_latency_hiding_scheduler=true",  # Try to schedule ops efficiently.
     "--xla_tpu_perform_spmd_cse_prevention=false",  # b/229655601: prevent OOM on gpt2-small-repeat.
 ]
-if tpu_type.startswith("v4-"):
+if instance_type.startswith("v4-"):
     libtpu_init_args += [
         # Per [email protected], the following flags are not supported by V3.
         "--xla_enable_async_all_gather=true",  # Allow async all-gather.
```
```diff
@@ -44,11 +44,9 @@
 # tpu_library_init_fns.inc:98] TpuEmbeddingEngine_ExecutePartitioner not available in this library.
 import jax  # jax must be imported before tensorflow!
 
-# NOTE: calling JAX distributed APIs (e.g. jax.default_backend(), jax.process_index() or
-# jax.process_count()) on GPU causes JAX to only view one process' GPUs.
 print(f"jax version={jax.__version__}", file=sys.stderr)
-if tpu_type != "none":
-    print(f"instance_type={tpu_type} num_slices={num_tpu_slices}", file=sys.stderr)
+if instance_type != "none":
+    print(f"instance_type={instance_type} num_slices={num_tpu_slices}", file=sys.stderr)
 
 import logging as pylogging
```
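The NOTE removed above warned about ordering on GPU: before `jax.distributed.initialize()` runs, each process sees only its own GPUs, so queries like `jax.process_count()` return per-host answers. A hedged sketch of the safe ordering on a two-host GPU cluster (the address and counts are placeholders):

```python
import jax

# Placeholder values; in the launcher these come from --distributed_coordinator,
# --num_processes, and --process_id.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=2,
    process_id=0,
)

# Only after initialize() do these reflect the whole cluster rather than one host.
print(jax.process_count())  # 2
print(jax.devices())  # GPUs across both processes.
```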
```diff
@@ -76,48 +74,29 @@
     "If 'FAKE', uses fake inputs.",
 )
 flags.DEFINE_integer("jax_profiler_port", None, "If not None, the profiler port.")
-flags.DEFINE_string(
-    "jax_backend", None, "If not None, ensures that trainer runs on the specified XLA backend."
-)
+flags.DEFINE_string("jax_backend", None, "Specifies the XLA backend to use.", required=True)
 flags.DEFINE_string(
     "distributed_coordinator",
     None,
-    "Set this None for tpu backend but it is required for multi-gpu environment",
+    "Distributed coordinator IP address. Must be None on tpu, otherwise required.",
 )
+flags.DEFINE_integer(
+    "num_processes", None, "Total number of hosts (nodes). Must be None on tpu, otherwise required."
+)
 flags.DEFINE_integer(
-    "num_processes", None, "Total number of hosts (nodes). Set this None for tpu backend."
+    "process_id", None, "Rank of the current process. Must be None on tpu, otherwise required."
 )
-flags.DEFINE_integer("process_id", None, "Host process id. Set this None for tpu backend.")
 flags.DEFINE_string(
     "mesh_selector",
     None,
     "The mesh selector string. See `SpmdTrainer.Config.mesh_rules` for details.",
 )
-# TODO(markblee): Remove this flag.
-flags.DEFINE_boolean(
-    "filter_info_logs",
-    None,
-    "If None (default), info log only on process 0 on TPUs, and on all processes on GPUs. "
-    "If True, info log only on process 0. "
-    "If False, info log on all processes.",
-)
 
 FLAGS = flags.FLAGS
 
 
 def setup():
-    # Decide whether to filter logs.
-    if FLAGS.filter_info_logs is not None:
-        filter_info_logs = FLAGS.filter_info_logs
-    else:
-        # Infer from platform. For multi-node multi-gpu environment, filtering makes it so that only
-        # one process' devices are visible, so we disable it by default.
-        filter_info_logs = FLAGS.jax_backend is None or FLAGS.jax_backend != "gpu"
-
-    if filter_info_logs:
+    if FLAGS.jax_backend != "gpu":
         logging.get_absl_handler().addFilter(InfoLogOnlyOnMaster())
 
     setup_spmd(
         distributed_coordinator=FLAGS.distributed_coordinator,
         num_processes=FLAGS.num_processes,
```
```diff
@@ -133,9 +112,8 @@ def setup():
     logging.info("Devices: %s", devices)
     local_devices = jax.local_devices()
     logging.info("Local Devices: %s", local_devices)
-    if FLAGS.jax_backend is not None:
-        if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
-            raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
+    if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
+        raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
     if FLAGS.data_dir:
         # TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR.
         os.environ["DATA_DIR"] = FLAGS.data_dir
```
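With `--jax_backend` now required, the outer `is not None` guard became dead code, so the platform check runs unconditionally. `Device.platform` is the lowercase platform name ("cpu", "gpu", or "tpu"), which is what makes the direct string comparison against the flag value work. A self-contained illustration, using "cpu" so it runs anywhere:

```python
import jax

expected_backend = "cpu"  # Stand-in for FLAGS.jax_backend.
devices = jax.devices()
if not devices or not all(device.platform == expected_backend for device in devices):
    raise RuntimeError(f"Expected backend {expected_backend}. Got {devices}.")
print(f"All {len(devices)} device(s) match backend {expected_backend!r}.")
```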