[DO-NOT-MERGE] PR encompassing all changes needed to support neuron on Axlearn #919

Status: Open. Wants to merge 34 commits into base: main.

Commits (34):
783b5ab  Support fine grained activation sharding. (#21)  (patrick-toulme, Nov 22, 2024)
2bd3956  boilerplate code for neuron  (apoorvtintin, Nov 23, 2024)
608b63d  Intermediate commit for others  (apoorvtintin, Nov 23, 2024)
0e39d18  fuji 70B v2 runs  (apoorvtintin, Dec 3, 2024)
1445b4d  add flash attention and remat  (apoorvtintin, Dec 4, 2024)
6c70c6a  Maybe shard bug fixes  (patrick-toulme, Dec 4, 2024)
c8e79df  Support fine grained embedding table sharding.  (patrick-toulme, Dec 5, 2024)
eb4edf6  rotary replaced  (apoorvtintin, Dec 5, 2024)
01a5ff5  use custom remat function  (apoorvtintin, Dec 6, 2024)
b83d68c  clean flash attention, enable GQA since it works  (apoorvtintin, Dec 6, 2024)
8062384  fix evaler  (apoorvtintin, Dec 7, 2024)
ce67d13  fixed non shared lm head  (apoorvtintin, Dec 7, 2024)
4cd7d09  fix remat  (apoorvtintin, Dec 8, 2024)
d6bae35  Fix Pytest skipping.  (lipovsek-aws, Dec 10, 2024)
8ee79cc  fix neuron imports to make unit tests run (#27)  (lipovsek-aws, Dec 10, 2024)
8de3b9c  flash attention & input data sharding test  (apoorvtintin, Dec 10, 2024)
3c03930  clean up flash attention  (apoorvtintin, Dec 11, 2024)
28b97a3  Logits in FP32  (patrick-toulme, Dec 11, 2024)
a80b674  fix regressions and messages (#29)  (lipovsek-aws, Dec 11, 2024)
482b7fa  make full model default  (apoorvtintin, Dec 11, 2024)
e9e3279  fix host array test  (apoorvtintin, Dec 11, 2024)
d1d7da8  duplicate the kernel if we cannot shard on num_heads (#30)  (aws-zhehongb, Dec 11, 2024)
dfa013e  enable more neuron attention test  (apoorvtintin, Dec 11, 2024)
adadc7c  set num_layers to original value for 7B  (apoorvtintin, Dec 11, 2024)
01ebc3f  add fsdp tp meshes for all fuji models  (apoorvtintin, Dec 11, 2024)
041a54e  revert train_batch size change  (apoorvtintin, Dec 11, 2024)
c855cd6  put sharding change for embedding under a flag  (apoorvtintin, Dec 11, 2024)
4a78272  fixif input sharding is none for evaler  (apoorvtintin, Dec 11, 2024)
fa52bae  fix host_array_test.py and avoid disapearing tests (#34)  (lipovsek-aws, Dec 12, 2024)
326444b  added back skipping to test_pipeline_summary_writer (#33)  (lipovsek-aws, Dec 12, 2024)
68c6ee9  Name linear1_i for remat (#35)  (indhub, Dec 12, 2024)
f4a68f9  import neuron_attention earlier (#38)  (apoorvtintin, Dec 19, 2024)
28d1bbb  import jax_neuronx to enable buffer donation  (apoorvtintin, Dec 27, 2024)
fbe3d5f  Use different remat strategy on 8 nodes  (apoorvtintin, Jan 15, 2025)
13 changes: 5 additions & 8 deletions axlearn/cloud/gcp/tpu_health_check.py
@@ -31,6 +31,7 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Literal, Optional, Union
import pytest

import tensorflow as tf
from absl import flags, logging
@@ -61,10 +62,9 @@ def _parse_spec_and_check_if_should_skip(
timeout = float(check_split[1])
break
else:
logging.info(
"Skipping %s slice health check because check spec is %s.", check_type, check_spec
pytest.skip(
reason=f"Skipping {check_type} slice health check because check spec is {check_spec}.",
)
return None

# These environment variables are set by GKE.
if "MEGASCALE_NUM_SLICES" not in os.environ or "NODE_NAME" not in os.environ:
@@ -74,12 +74,9 @@

total_slices = int(os.environ["MEGASCALE_NUM_SLICES"])
if total_slices < num_slices_lower_bound:
logging.info(
"Skipping %s slice health check since num_slices < %d.",
check_type,
num_slices_lower_bound,
pytest.skip(
reason=f"Skipping {check_type} slice health check since num_slices < {num_slices_lower_bound}.",
)
return None
return timeout


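For context, this hunk replaces a log-and-return with `pytest.skip`, so callers inside tests are reported as skipped rather than passing vacuously. A minimal standalone sketch of that pattern (hypothetical helper and test names, not code from this PR):

```python
# Minimal sketch (hypothetical names): pytest.skip raises an exception that
# marks the calling test as SKIPPED, whereas logging and returning None lets
# the test keep running and pass vacuously.
import pytest


def _parse_timeout_or_skip(check_spec: str) -> float:
    if check_spec == "none":
        # Passing the message positionally works across pytest versions.
        pytest.skip(f"Skipping health check because check spec is {check_spec}.")
    return 180.0  # timeout in seconds


def test_health_check_timeout():
    timeout = _parse_timeout_or_skip("none")  # reported as skipped, not passed
    assert timeout > 0
```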
128 changes: 111 additions & 17 deletions axlearn/common/attention.py
@@ -111,6 +111,7 @@
get_or_none,
shapes,
split_prng_key,
maybe_shard,
)

NEG_INF = -1e15
@@ -1260,24 +1261,46 @@ def apply_rotary_position_embeddings(
"""
# sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
# # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape)
# # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
# rotate_half_query = jnp.reshape(
# jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
# )
# query = query * cos_pos + rotate_half_query * sin_pos
# # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
# rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
# key = key * cos_pos + rotate_half_key * sin_pos
# if rotary_value:
# # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
# rotate_half_value = jnp.reshape(
# jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape
# )
# value = value * cos_pos + rotate_half_value * sin_pos
# return query, key, value

def _rotate_half(x: jnp.ndarray) -> jnp.ndarray:
halves = jnp.split(x, 2, axis=-1)
return jnp.concatenate((-halves[1], halves[0]), axis=-1)

sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape)
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
rotate_half_query = jnp.reshape(
jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
)
rotate_half_query = _rotate_half(query)

query = query * cos_pos + rotate_half_query * sin_pos
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
rotate_half_key = _rotate_half(key)
key = key * cos_pos + rotate_half_key * sin_pos
if rotary_value:
# rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
rotate_half_value = jnp.reshape(
jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape
)
rotate_half_value = _rotate_half(value)
value = value * cos_pos + rotate_half_value * sin_pos
return query, key, value

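The hunk above swaps the interleaved rotate-half construction (kept as the commented-out block) for a block-half `_rotate_half`. The two conventions pair different elements, so they produce the same rotation only if the sin/cos layout is arranged to match; whether that holds depends on code outside this hunk. A small standalone sketch of the difference (not this PR's code):

```python
# Sketch comparing the two "rotate half" conventions in this hunk.
# Interleaved: pairs (x0, x1), (x2, x3), ...   -> [-x1, x0, -x3, x2, ...]
# Block-half:  pairs (x0, x_{d/2}), (x1, ...)  -> [-x_{d/2}, ..., x0, x1, ...]
import jax.numpy as jnp


def rotate_half_interleaved(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.reshape(jnp.stack([-x[..., 1::2], x[..., ::2]], axis=-1), x.shape)


def rotate_half_block(x: jnp.ndarray) -> jnp.ndarray:
    first, second = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-second, first), axis=-1)


x = jnp.arange(8.0)
print(rotate_half_interleaved(x))  # [-1.  0. -3.  2. -5.  4. -7.  6.]
print(rotate_half_block(x))        # [-4. -5. -6. -7.  0.  1.  2.  3.]
```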
@@ -2525,6 +2548,12 @@ class Config(BaseLayer.Config):
# Ref: https://github.com/google/praxis/blob/main/praxis/layers/transformers.py#L1129
# TODO (bwzhang@) Adding a unittest for the hybridnorm.
structure: str = "prenorm"
# If not None, how to partition pre norm activation values.
prenorm_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition pre attention activation values.
preattention_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition post attention activation values.
postattention_partition_spec: Optional[tuple[Optional[str]]] = None

def __init__(self, cfg: Config, *, parent: Module):
super().__init__(cfg, parent=parent)
@@ -2648,10 +2677,14 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
return dict(attention=atten_state), atten_output

if cfg.structure == "prenorm":
target = maybe_shard(target, cfg.prenorm_partition_spec)
skip_input = target # pre-norm: where normalization happens within the residual part.
norm_target = self.norm(target)
norm_target = maybe_shard(norm_target, cfg.preattention_partition_spec)
atten_state, atten_output = attention_thunk(norm_target)
atten_output = maybe_shard(atten_output, cfg.postattention_partition_spec)
data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
data = self._remat_name(data, 'residual_add')
elif cfg.structure == "postnorm":
# This is the structure used by the original Transformer, BERT, and RoBERTa.
atten_state, atten_output = attention_thunk(target)
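`maybe_shard` here comes from `axlearn.common.utils` (added to the imports above). A minimal sketch of what such a helper is assumed to do, namely apply a sharding constraint only when a spec is configured; the axis names in the usage comment are hypothetical:

```python
# Minimal sketch (assumption about the helper's behavior, not the PR's exact
# implementation): constrain sharding only when a partition spec is given.
from typing import Optional, Sequence

import jax
from jax.sharding import PartitionSpec


def maybe_shard(x: jax.Array, partition_spec: Optional[Sequence]) -> jax.Array:
    """Returns `x` unchanged if no spec is given, else constrains its sharding."""
    if partition_spec is None:
        return x
    # Effective inside jit when a device mesh with these axis names is active.
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*partition_spec))


# Usage mirroring the prenorm branch above (hypothetical axis names):
# target = maybe_shard(target, ("fsdp", "model", None))
```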
@@ -2878,6 +2911,13 @@ class Config(BaseLayer.Config):
# TODO(tlei3): deprecate this feature since we use TensorStats.
add_value_rms_norm_summary: Sequence[str] = []

# If not None, how to partition pre norm activation values.
prenorm_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition pre MLP activation values.
premlp_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition post MLP activation values.
postmlp_partition_spec: Optional[tuple[Optional[str]]] = None

def __init__(self, cfg: Config, *, parent: Module):
super().__init__(cfg, parent=parent)
cfg: TransformerFeedForwardLayer.Config = self.config
@@ -2942,17 +2982,21 @@ def _linear2(x):
remat_pt1 = "activation"
remat_pt2 = "linear2"
if cfg.structure == "prenorm":
inputs = maybe_shard(inputs, cfg.prenorm_partition_spec)
x = self.norm(inputs)
x = maybe_shard(x, cfg.premlp_partition_spec)
x = self._linear1_activation(x)
x = self._remat_name(x, remat_pt1)
x = self.dropout1(x)
x = _linear2(x)
x = self._remat_name(x, remat_pt2)
x = maybe_shard(x, cfg.postmlp_partition_spec)
x = self.dropout2(x)
x = self.stochastic_depth(x)
if cfg.residual_weight != 1:
x *= cfg.residual_weight
x += inputs
x=self._remat_name(x, 'mlp_residual')
elif cfg.structure == "postnorm":
x = self._linear1_activation(inputs)
x = self._remat_name(x, remat_pt1)
@@ -2998,7 +3042,7 @@ def _linear1_activation(self, x: Tensor) -> Tensor:
if isinstance(cfg.activation, tuple):
activations = [
self._get_activation(
self.children[f"linear1_{i}"](x), activation_fn_name=activation
self._remat_name(self.children[f"linear1_{i}"](x), f"linear1_{i}"), activation_fn_name=activation
)
for i, activation in enumerate(cfg.activation)
]
@@ -3477,7 +3521,8 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names)
# Encourage the right activation sharding.
ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
# ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, None)

if not isinstance(cfg, Sequence):
cfg = [cfg]
@@ -3488,6 +3533,21 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
set_attn_partition_specs(layer_cfg.cross_attention.attention)
if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
set_ffn_partition_specs(layer_cfg.feed_forward)

# Neuron backend needs fine grained activation sharding.
if jax.default_backend() == 'neuron':
prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None)
preattention_partition_spec = (fsdp_axis_names, None, None)
postattention_partition_spec = (fsdp_axis_names, tp_axis_names, None)

layer_cfg.self_attention.set(
prenorm_partition_spec=prenorm_partition_spec,
preattention_partition_spec=preattention_partition_spec,
postattention_partition_spec=postattention_partition_spec)
layer_cfg.feed_forward.set(
prenorm_partition_spec=prenorm_partition_spec,
premlp_partition_spec=preattention_partition_spec,
postmlp_partition_spec=postattention_partition_spec)
# pytype: enable=attribute-error


@@ -4071,6 +4131,20 @@ def forward(

# TODO(sneha): extend_step

def save_only_these(*names_to_save):
# Save only the values tagged with one of the specified names; everything else (including unnamed tensors) is recomputed.
names_to_save = frozenset(names_to_save)
def policy(prim, *_, **params):
if 'name' in params and params['name'] in names_to_save:
print(f"[WIP] Saving {params['name']}")
return True
elif 'name' in params:
print(f"[WIP] Not saving tensor: {params['name']}")
return False
else:
print("[WIP] Not saving unnamed tensor")
return False
return policy
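`save_only_these` relies on remat name tags: only tensors that were given a remat name (the `_remat_name` calls added in earlier hunks) carry a `name` param when the policy inspects a primitive. A standalone sketch of the same idea in plain JAX, using `checkpoint_name` and hypothetical tensor names:

```python
# Standalone sketch of a name-based remat policy, analogous to `save_only_these`
# above. The tensor names ("q_proj", "ffn_act") are hypothetical.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def save_only_these(*names_to_save):
    names_to_save = frozenset(names_to_save)

    def policy(prim, *_, **params):
        # Only values tagged via checkpoint_name carry a "name" param here.
        return params.get("name") in names_to_save

    return policy


def block(x, wq, w1):
    q = checkpoint_name(x @ wq, "q_proj")   # saved under this policy
    h = checkpoint_name(q @ w1, "ffn_act")  # recomputed in the backward pass
    return jnp.tanh(h).sum()


f = jax.checkpoint(block, policy=save_only_these("q_proj"))
grads = jax.grad(f)(jnp.ones((4, 8)), jnp.ones((8, 8)), jnp.ones((8, 8)))
```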

def build_remat_spec(
stack_cfg: Union[
@@ -4106,20 +4180,40 @@ def build_remat_spec(
if stack_cfg.klass is PipelinedTransformerLayer:
return None

backend = jax.default_backend()
checkpoints = []
if self_attention:
attention_name = stack_cfg.layer.self_attention.attention.klass.__name__
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)

if backend != "neuron":
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)
elif jax.device_count() > (64 * 8):
checkpoints.extend(
[f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ["TransformerAttentionLayer.residual_add", "TransformerFeedForwardLayer.mlp_residual"]
)
else:
checkpoints.extend(
[f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']]
)
if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])
if backend != "neuron":
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])
elif jax.device_count() > (64 * 8):
checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0", "linear1_1"]])
else:
checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0"]])

if backend != "neuron":
policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
)
else:
policy = config_for_function(save_only_these).set(
names_to_save=checkpoints
)

policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
)
if offload_dst:
policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set(
names_which_can_be_saved=[],
64 changes: 64 additions & 0 deletions axlearn/common/attention_test.py
@@ -5041,6 +5041,70 @@ def test_set_double_shard_weights_config_for_list_of_configs(
(fsdp_axis_names, tp_axis_names, None),
)

@parameterized.product(
self_attention_input_linear_cfg=(
QKVLinear.default_config(),
FusedQKVLinear.default_config(),
),
cross_attention_cfg=(None, TransformerAttentionLayer.default_config()),
batch_axis_names=("data", ("replica", "data", "fsdp")),
fsdp_axis_names=("fsdp",),
tp_axis_names=("model",),
seq_axis_names=("seq",),
)
def test_set_activation_shardings_config_for_list_of_configs(
self,
self_attention_input_linear_cfg,
cross_attention_cfg,
batch_axis_names,
fsdp_axis_names,
tp_axis_names,
seq_axis_names,
):
cfg_layer: TransformerLayer.Config = TransformerLayer.default_config().set(
cross_attention=cross_attention_cfg
)
cfg_layer.self_attention.structure = "prenorm"
cfg_layer.feed_forward.structure = "prenorm"
cfg_layer.self_attention.attention.input_linear = self_attention_input_linear_cfg
cfg_layers = [cfg_layer, cfg_layer]

cfg_layer.self_attention.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None,)
cfg_layer.self_attention.preattention_partition_spec = (fsdp_axis_names, None, None)
cfg_layer.self_attention.postattention_partition_spec = (fsdp_axis_names, tp_axis_names, None)

cfg_layer.feed_forward.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None,)
cfg_layer.feed_forward.premlp_partition_spec = (fsdp_axis_names, None, None)
cfg_layer.feed_forward.postmlp_partition_spec = (fsdp_axis_names, tp_axis_names, None)

for cfg in cfg_layers:
self_atten = cfg.self_attention
feed_forward = cfg.feed_forward
self.assertSequenceEqual(
self_atten.prenorm_partition_spec,
(fsdp_axis_names, tp_axis_names, None),
)
self.assertSequenceEqual(
self_atten.preattention_partition_spec,
(fsdp_axis_names, None, None),
)
self.assertSequenceEqual(
self_atten.postattention_partition_spec,
(fsdp_axis_names, tp_axis_names, None),
)

self.assertSequenceEqual(
feed_forward.prenorm_partition_spec,
(fsdp_axis_names, tp_axis_names, None),
)
self.assertSequenceEqual(
feed_forward.premlp_partition_spec,
(fsdp_axis_names, None, None),
)
self.assertSequenceEqual(
feed_forward.postmlp_partition_spec,
(fsdp_axis_names, tp_axis_names, None),
)

class PositionalEmbeddingTest(TestCase):
"""Tests PositionalEmbedding."""
2 changes: 2 additions & 0 deletions axlearn/common/causal_lm.py
@@ -298,6 +298,8 @@ def _metrics(
target_labels: Tensor,
target_num_bytes: Optional[Tensor],
) -> dict[str, Tensor]:
if logits.dtype in (jnp.bfloat16, jnp.float16):
logits = logits.astype(jnp.float32)
live_targets = (target_labels != self.decoder.config.pad_token_id) & (target_labels >= 0)
num_targets = live_targets.sum()
accuracy = (
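The two added lines upcast half-precision logits before the metric math. A small illustration of the assumed motivation: log-softmax over a large vocabulary is noticeably coarser in bfloat16 than in float32 (the sketch below uses made-up values, not data from this PR):

```python
# Sketch of the assumed motivation: compute log-softmax / cross-entropy in
# float32 even when the model produces bfloat16 logits.
import jax
import jax.numpy as jnp

vocab = 32_000
logits_bf16 = jax.random.normal(jax.random.PRNGKey(0), (vocab,), dtype=jnp.bfloat16)

# Loss for one target id, computed directly in bfloat16 vs. after upcasting.
loss_bf16 = -jax.nn.log_softmax(logits_bf16)[123]
loss_f32 = -jax.nn.log_softmax(logits_bf16.astype(jnp.float32))[123]
print(loss_bf16, loss_f32)  # bfloat16 keeps only ~2-3 significant decimal digits
```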