Draft: Refactor JobSet for Pathways #918

Draft · wants to merge 64 commits into base: main

Changes shown from 56 of the 64 commits.

Commits
582329a
Adding support for Pathways proxy
jesus-orozco Sep 9, 2024
8d3c643
Update pathways-utils dependency and fix formatting
jesus-orozco Oct 1, 2024
0e61b76
Move pathways package to its own dependency tree and pin it to a spec…
jesus-orozco Oct 7, 2024
4260c38
Relocate pathwaysutils import
jesus-orozco Oct 9, 2024
ae80a36
Create custom jobset for pathways
jesus-orozco Oct 9, 2024
9960027
Updates to pathways jobset creation
jesus-orozco Oct 16, 2024
f02d345
Merge branch 'main' into feature/jax_pathways
jesus-orozco Oct 17, 2024
6c3c083
Merge branch 'apple:main' into feature/jax_pathways
jesus-orozco Oct 25, 2024
ac6bcd2
Update pathwaysutils source to pypi
jesus-orozco Oct 25, 2024
43ac5e4
Merge branch 'main' into feature/pathways_workload
jesus-orozco Oct 28, 2024
05f311d
Merge branch 'apple:main' into feature/jax_pathways
jesus-orozco Oct 28, 2024
8972933
trillium testing baseline
jesus-orozco Nov 6, 2024
a952c6f
revert dockerfile for upstream merge
jesus-orozco Nov 6, 2024
34571aa
Merge branch 'apple:main' into trillium_testing
jesus-orozco Nov 6, 2024
0e8ae86
fixed pdbs 3 for fuji 70b
jesus-orozco Nov 6, 2024
d9d458a
testing pdbs 3 with 2 v6e-256 slices
jesus-orozco Nov 6, 2024
73da333
use maxtext xla flags only
jesus-orozco Nov 7, 2024
01bcf8f
new baseline for pdbs=3 without ffn_dim
jesus-orozco Nov 7, 2024
b1ba4fe
try xla sc offload flags
jesus-orozco Nov 7, 2024
6b24cd2
revert AR + SC offload flags
jesus-orozco Nov 7, 2024
aa06778
output jobset to yaml file
jesus-orozco Nov 7, 2024
efabbfd
calculate batch size based on flags
jesus-orozco Nov 7, 2024
cd556c7
enable ffn 3.5 and test pdbs 3 with 4 slices
jesus-orozco Nov 7, 2024
90b0d26
retry 4 slices with pdbs 3
jesus-orozco Nov 7, 2024
4974458
test xla_tpu_enable_sparse_core_collective_offload_all_reduce
jesus-orozco Nov 8, 2024
caff395
remove xla_tpu_enable_sparse_core_collective_offload_all_reduce
jesus-orozco Nov 8, 2024
8277980
dynamic global batch size based on pdbs and slices
jesus-orozco Nov 8, 2024
e23a4d0
enable xla_enable_async_all_reduce
jesus-orozco Nov 11, 2024
1797c52
Merge branch 'main' into feature/jax_pathways
jesus-orozco Nov 11, 2024
4868815
Refactor pathways config flag
jesus-orozco Nov 11, 2024
fe62afd
Merge branch 'apple:main' into feature/pathways_workload
jesus-orozco Nov 11, 2024
27ceea0
install libtpu nightly
jesus-orozco Nov 11, 2024
5c80fc8
Merge branch 'apple:main' into trillium_testing
jesus-orozco Nov 11, 2024
200ac48
custom remat policy for fuji-70b
jesus-orozco Nov 15, 2024
3bfda15
sparscore xla flags and nothing_saveable remat policy
jesus-orozco Nov 15, 2024
51e2e90
calculate batch size with jax devices
jesus-orozco Nov 18, 2024
b894304
update remat policy offloading
jesus-orozco Nov 18, 2024
ce5f1a2
Merge branch 'main' into feature/pathways_workload
jesus-orozco Nov 18, 2024
b51e67e
Update axlearn/cloud/gcp/job.py
jesus-orozco Nov 18, 2024
6537274
Update job.py with dynamic module imports
jesus-orozco Nov 18, 2024
7dbb1b9
Update job.py - remove pathways from dynamic import error message
jesus-orozco Nov 18, 2024
af7c746
Merge remote-tracking branch 'origin/feature/jax_pathways' into pathw…
jesus-orozco Nov 22, 2024
5704ce5
pathways jobset updates
jesus-orozco Nov 22, 2024
bf82554
merge trillium changes
jesus-orozco Nov 22, 2024
4f58291
Merge branch 'apple:main' into pathways_trillium
jesus-orozco Nov 25, 2024
d206110
Install pathwaysutils
jesus-orozco Nov 25, 2024
4152db3
Disable force eval
jesus-orozco Nov 25, 2024
0c33403
Launch pathways on trainer_main
jesus-orozco Nov 25, 2024
315fa6b
add v6e mesh rules
jesus-orozco Dec 5, 2024
765f64b
pin libtpu version
jesus-orozco Dec 5, 2024
f387807
update pathways jobset definition
jesus-orozco Dec 5, 2024
735ff37
Merge branch 'apple:main' into pathways_trillium
jesus-orozco Dec 6, 2024
2e0dd4b
dump xla flags to gcs
jesus-orozco Dec 9, 2024
e9d8341
Refactor jobset to align with new pathways structure
jesus-orozco Dec 12, 2024
357a322
Apply formatting
jesus-orozco Dec 12, 2024
89ac1ea
refactor pathways jobset to new spec
jesus-orozco Jan 10, 2025
d63dd6a
Rebase to axlearn main
jiya-zhang Jan 14, 2025
4393453
revert changes to gitignore and remove dockerignore
jesus-orozco Jan 14, 2025
9ba599b
revert changes to dependencies
jesus-orozco Jan 14, 2025
ca5c883
revert changes to trainer and model configs
jesus-orozco Jan 14, 2025
c7bc5df
remove unnecessary updates to gke tpu job for pathways workloads
jesus-orozco Jan 14, 2025
a0bf9df
Revert updates to fuji model config
jesus-orozco Jan 14, 2025
9f19159
Update pathways container specs for gke tpu job
jesus-orozco Jan 15, 2025
8f71508
Update gke tpu job to bypass jobset coordinator
jesus-orozco Jan 16, 2025
Files changed
5 changes: 5 additions & 0 deletions .dockerignore
@@ -0,0 +1,5 @@
logs
jobsets
.venv
.circleci
.vscode
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# IGNORE
jobsets

# test results
test-results

4 changes: 3 additions & 1 deletion Dockerfile
@@ -93,8 +93,10 @@ RUN apt-get install -y google-perftools
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Ensure we install the TPU version, even if building locally.
# Jax will fallback to CPU when run on a machine without TPU.
RUN pip install .[core,tpu]
RUN pip install .[core,tpu,pathways]
RUN if [ -n "$EXTRAS" ]; then pip install .[$EXTRAS]; fi
RUN pip install -U --pre libtpu-nightly==0.1.dev20241203+nightly requests \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
COPY . .

################################################################################
379 changes: 319 additions & 60 deletions axlearn/cloud/gcp/job.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion axlearn/cloud/gcp/tpu.py
@@ -718,7 +718,7 @@ def infer_tpu_workers(tpu_type: str) -> int:
tpu_version, tpu_cores = match.groups()
if tpu_version in {"v3", "v4", "v5p"}:
return int(tpu_cores) // 8
if tpu_version in {"v5litepod"}:
if tpu_version in {"v5litepod", "v6e"}:
Reviewer (Contributor): Can you rebase main? v6e is already supported

Author: Done

return int(tpu_cores) // 4
except Exception as e: # pylint: disable=broad-except
logging.error("Failed to parse tpu_type %s: %s", tpu_type, e)
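A quick sanity check of the new branch, written as a sketch (the import path follows the file shown above; the arithmetic assumes 4 chips per v6e host, which is what the `// 4` implies):

from axlearn.cloud.gcp.tpu import infer_tpu_workers

# v6e follows the v5litepod rule: worker (host) count = chip count / 4.
assert infer_tpu_workers("v6e-256") == 64
assert infer_tpu_workers("v5litepod-16") == 4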
2 changes: 1 addition & 1 deletion axlearn/cloud/gcp/utils.py
@@ -76,7 +76,7 @@ def running_from_vm() -> bool:
capture_output=True,
text=True,
)
return (out.returncode == 0) and "Metadata-Flavor: Google" in out.stdout
return False # (out.returncode == 0) and "Metadata-Flavor: Google" in out.stdout


def running_from_k8s() -> bool:
68 changes: 65 additions & 3 deletions axlearn/common/compiler_options.py
@@ -44,6 +44,68 @@ def default_xla_options(
xla_enable_async_all_gather="true", # Allow async all-gather.
xla_enable_async_collective_permute="true", # Allow async collective permute.
)
if version == "v6e":
options.update(
# improved performance for v6e
xla_tpu_scoped_vmem_limit_kib="98304",
# maxtext xla flags
# xla_enable_async_all_reduce="true",
# CF_FOR_ALL_GATHER
xla_tpu_enable_async_collective_fusion="true",
xla_tpu_enable_async_collective_fusion_fuse_all_gather="true",
xla_tpu_enable_async_collective_fusion_multiple_steps="true",
xla_tpu_overlap_compute_collective_tc="true",
xla_enable_async_all_gather="true",
# sparsecore offloading AR
xla_sc_disable_megacore_partitioning="true",
# xla_tpu_enable_async_collective_fusion_fuse_all_gather="false",
# xla_tpu_enable_all_gather_offload_tracing="true",
xla_tpu_use_tc_device_shape_on_sc="true",
# xla_tpu_enable_sparse_core_collective_offload_all_gather="true",
xla_sc_enable_instruction_fusion="false",
xla_sc_disjoint_spmem="false",
tpu_use_continuations="true",
xla_jf_crs_combiner_threshold_count="10",
xla_tpu_enable_sparse_core_collective_offload_all_reduce="true",
# Flag to enable some advanced scheduling features.
xla_tpu_enable_all_experimental_scheduler_features="true",
# Flag to enable memory tracking scheduling. The default AUTO only enables
# it in some situations. Not needed if
# xla_tpu_enable_all_experimental_scheduler_features is set to true already.
xla_tpu_enable_scheduler_memory_pressure_tracking="ENABLED",
# Flag controlling the maximum number of overlapping host offloadings.
xla_tpu_host_transfer_overlap_limit=24,
# Flag to enable the aggressive removal of opt-barriers.
xla_tpu_aggressive_opt_barrier_removal="ENABLED",
# Flag to enable more aggressive scheduling for async ops, such as pushing
# the async start to the beginning of the loop body.
xla_lhs_prioritize_async_depth_over_stall="ENABLED",
# For multi-slice configurations,
# Flag to enable pipelining of cross-DCN all-gathers.
xla_tpu_enable_ag_backward_pipelining="true",
xla_should_allow_loop_variant_parameter_in_chain="ENABLED",
xla_should_add_loop_invariant_op_in_chain="ENABLED",
# Flag controlling the maximum number of overlapping cross-DCN send/recv.
xla_max_concurrent_host_send_recv=100,
# If you are seeing OOM (out-of-memory) error, or bad performance when HBM memory
# usage is close to HBM capacity, tuning these two flags might help:
# Flag controlling the HBM memory limit as a percentage of the total HBM size.
# Default value is 95. Can tune up or down to give more or less memory for the
# scheduler. The scheduler favors more on less memory usage when it's under
# memory pressure, instead of hiding latency by overlapping more computations
# and communications.
# xla_tpu_scheduler_percent_shared_memory_limit=xx,
# Flag controlling the number of times the scheduler is run if the scheduled
# peak memory usage exceeds the initial memory limit, by setting memory limit
# to 90% of the previous memory limit each time. Default value is 1. Sometimes
# when the scheduler thinks it goes out memory, it may not actually happen due
# to other factors controlled by other compiler passes, or the initial memory
# limit is already set too low. Cutting the memory limit to 90% of previous one
# though, may make the scheduler weighting too much on the memory usage instead
# of latency side.
xla_latency_hiding_scheduler_rerun=0,
)
options["2a886c8_chip_config_name"] = "megachip_tccontrol"
if num_slices > 1:
# Support multiple TPU slices connected over a data center network.
options.update(
@@ -58,8 +120,8 @@
)

# Validate options. Will never fail if this function is implemented correctly.
for k, v in options.items():
assert v in [True, False, "true", "false"], (k, v)
# for k, v in options.items():
# assert v in [True, False, "true", "false"], (k, v)

return options

@@ -166,4 +228,4 @@ def infer_xsc_compiler_options(
return options


_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p")
_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p", "v6e")
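For context, a minimal sketch of how these options end up as libtpu flags. The call pattern mirrors the launch.py hunk below; the instance type and slice count are only examples:

import os

from axlearn.common import compiler_options

options = compiler_options.default_xla_options(
    instance_type="tpu-v6e-256", num_slices=2, backend="tpu"
)
# Serializes roughly to "--xla_tpu_scoped_vmem_limit_kib=98304
# --xla_tpu_enable_async_collective_fusion=true ..." and is picked up by libtpu at init.
os.environ["LIBTPU_INIT_ARGS"] = compiler_options.xla_flags_from_options(options)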
5 changes: 4 additions & 1 deletion axlearn/common/launch.py
@@ -23,6 +23,7 @@
instance_type=instance_type, num_slices=num_tpu_slices, backend="tpu"
)
os.environ["LIBTPU_INIT_ARGS"] = compiler_options.xla_flags_from_options(libtpu_init_options)
print("LIBTPU_INIT_ARGS: ", os.environ["LIBTPU_INIT_ARGS"], file=sys.stderr)
except compiler_options.NotTpuError as e:
# Log this when setup() is called.
tpu_flags_exc = e
@@ -132,7 +133,9 @@ def setup():
logging.info("Devices: %s", devices)
local_devices = jax.local_devices()
logging.info("Local Devices: %s", local_devices)
if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
if FLAGS.jax_backend != "proxy" and (
not devices or not all(device.platform == FLAGS.jax_backend for device in devices)
):
raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
if FLAGS.data_dir:
# TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR.
10 changes: 10 additions & 0 deletions axlearn/common/launch_trainer.py
@@ -68,6 +68,16 @@
None,
"The mesh selector string. See `SpmdTrainer.Config.mesh_rules` for details.",
)
flags.DEFINE_string(
"pdbs",
None,
"Per device batch size (Overrides global batch size).",
)
flags.DEFINE_integer(
"slices",
1,
"Number of slices for the TPU job.",
)

FLAGS = flags.FLAGS

2 changes: 2 additions & 0 deletions axlearn/common/launch_trainer_main.py
@@ -2,6 +2,8 @@

"""Main function for launching the trainer."""

# Temp hack to bypass invalid backend error
import pathwaysutils
from absl import app, flags

from axlearn.common import launch, launch_trainer, measurement
21 changes: 11 additions & 10 deletions axlearn/common/trainer.py
@@ -200,6 +200,9 @@ class Config(Module.Config):
# The provided config should instantiate to a thunk that returns the context manager.
context_manager: Optional[ConfigOr[Callable[[], ContextManager]]] = None

# The Global Batch Size
train_batch_size: Optional[int] = None

def __init__(
self,
cfg: Config,
@@ -569,13 +572,11 @@ def run(
self.vlog(3, "Start step %s", self.step)
output = self._run_step(
utils.host_to_global_device_array(input_batch),
force_run_evals=(
force_run_eval_sets_at_max_step if self.step >= cfg.max_step else None
),
force_run_evals=None,
)
self.vlog(3, "Done step %s", self.step)
num_steps += 1
if num_steps % 100 == 0:
if num_steps % 5 == 0:
now = time.perf_counter()
average_step_time = (now - start_time) / num_steps
self._step_log("Average step time: %s seconds", average_step_time)
@@ -1020,12 +1021,12 @@ def _run_step(
# Run the compiled function.
self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch)

if self.step % 100 == 0 or 0 <= self.step <= 5:
self._step_log(
"loss=%s aux=%s",
outputs["loss"],
jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
)
# if self.step % 100 == 0 or 0 <= self.step <= 5:
self._step_log(
"loss=%s aux=%s",
outputs["loss"],
jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
)

self.summary_writer(self.step, {"loss": outputs["loss"], **outputs["summaries"]})
# Aggregate summaries across evalers.
7 changes: 5 additions & 2 deletions axlearn/common/utils_spmd.py
@@ -52,7 +52,8 @@ def setup(
if initialization_timeout is not None:
init_kwargs["initialization_timeout"] = initialization_timeout

if jax_backend == "tpu":
# TPU resources orchestrated by Pathways use 'proxy' as the JAX backend
if jax_backend in ("tpu", "proxy"):
if not (
distributed_coordinator is None and num_processes is None and process_id is None
):
@@ -92,5 +93,7 @@ def setup(
# local_device_ids arg allows us to maintain expected behavior
init_kwargs["local_device_ids"] = list(range(8))

jax.distributed.initialize(**init_kwargs)
# When using Pathways proxy for TPU backend, jax distributed init is not needed
if jax_backend != "proxy":
jax.distributed.initialize(**init_kwargs)
_jax_distributed_initialized = True
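A small sketch of the two code paths after this change (the call shape follows the setup() signature shown above; whether Pathways needs any additional client-side initialization is outside this diff):

from axlearn.common import utils_spmd

# Plain TPU slice: behaves as before, including jax.distributed.initialize().
utils_spmd.setup(jax_backend="tpu")

# Pathways-managed TPUs: JAX talks to the proxy backend, so the distributed
# coordinator initialization is skipped entirely.
utils_spmd.setup(jax_backend="proxy")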
84 changes: 82 additions & 2 deletions axlearn/experiments/text/gpt/fuji.py
@@ -15,8 +15,10 @@
import itertools
from typing import Any, Optional, Union

from absl import flags
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
from axlearn.common import causal_lm, config
from axlearn.common.attention import (
BaseStackedTransformerLayer,
@@ -54,6 +56,8 @@
from axlearn.experiments.text.gpt.common import scaled_hidden_dim
from axlearn.experiments.trainer_config_utils import TrainerConfigFn

FLAGS = flags.FLAGS

MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B")


@@ -122,6 +126,10 @@ def get_trainer_kwargs(
max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
max_sequence_length = MAX_SEQUENCE_LENGTH[version]
train_batch_size = tokens_per_batch // max_sequence_length
if FLAGS.pdbs:
import jax

train_batch_size = len(jax.devices()) * int(FLAGS.pdbs)

# Whether to use grouped query attention.
num_kv_heads = None
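A worked example of the per-device batch size override above, with illustrative numbers only (4 slices of v6e-256 visible to JAX):

num_devices = 4 * 256                  # what len(jax.devices()) would report across all slices
pdbs = 3                               # --pdbs flag value (a string flag, hence the int() cast)
train_batch_size = num_devices * pdbs  # 3072 sequences per step, replacing the
                                       # tokens_per_batch-derived value computed earlier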
@@ -287,6 +295,25 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
# tpu-v6e.
(
"tpu-v6e-.*",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=jax_remat_policies.nothing_saveable,
),
}
),
],
),
),
),
)
elif model_size == "8B":
@@ -367,9 +394,40 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
# tpu-v6e.
(
"tpu-v6e-.*",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=jax_remat_policies.nothing_saveable,
),
}
),
],
),
),
),
)
elif model_size == "70B":
remat_policy_70b = config_for_function(
jax_remat_policies.save_and_offload_only_these_names
).set(
names_which_can_be_saved=[],
names_which_can_be_offloaded=[
"FlashAttention.q_proj",
"FlashAttention.k_proj",
"FlashAttention.v_proj",
],
offload_src="device",
offload_dst="pinned_host",
)
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=80,
@@ -387,6 +445,8 @@
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
max_step=max_step,
# eval_every_n_steps=500,
save_every_n_steps=100,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
mesh_rules=(
# TPU V5e maximum per device batch is 1.
@@ -398,13 +458,13 @@
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=128)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=offload_dots_saveable_policy,
policy=jax_remat_policies.dots_saveable,
),
}
),
@@ -417,6 +477,26 @@
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
# tpu-v6e.
(
"tpu-v6e-.*",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=remat_policy_70b,
# policy=jax_remat_policies.nothing_saveable,
),
}
),
],
),
),
),
)
else: