From 9393db34b02c358f36e8556dd71de5bde209edb1 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 17:40:42 -0800 Subject: [PATCH 01/12] Allow external positions to be inputed in RoPE embedding layer Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`. --- axlearn/common/attention.py | 79 +++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index b53ef8b67..7aaa2b9cf 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -74,6 +74,7 @@ import enum import functools import math +import re from collections.abc import Sequence from enum import Enum, unique from typing import Any, Callable, NamedTuple, Optional, Protocol, Union @@ -81,6 +82,9 @@ import einops import jax from jax import numpy as jnp +from jax._src.ad_checkpoint import name_p +from jax._src.interpreters import partial_eval as pe +from jax.core import Primitive from axlearn.common import ops, param_init from axlearn.common.attention_bias import ( @@ -134,16 +138,13 @@ from axlearn.common.utils import ( Nested, NestedTensor, - OffloadPolicy, PartitionSpec, - SavePattern, Tensor, TensorSpec, VDict, check_numerics, flatten_items, get_or_none, - save_and_offload_only_these_names_regex, shapes, split_prng_key, ) @@ -1216,18 +1217,37 @@ class Config(BaseLayer.Config): dim: Required[int] = REQUIRED # The dimensionality of the positional embedding. theta: float = 10000.0 # The scale of base frequency. - def forward(self, positions: Tensor) -> Tensor: + def default_query_positions(self, max_seq_len: int) -> Tensor: + """Compute default `positions` value to be inputed into forward when `positions` is + not provided to the corresponding QKVLinear class such as `RoFormerQKVLinear` + """ + return jnp.arange(max_seq_len)[None] # [batch_size=1, max_seq_len]. + + def forward( + self, positions: Optional[Tensor] = None, max_seq_len: Optional[int] = None + ) -> Tensor: """ TODO(bwzhang): 1. verify the performance under float32. Args: positions: A tensor representing the token position IDs. The shape is [batch_size, seq_len]. + max_seq_len: Max length of sequence, required if positions is not provided Returns: Rotary Positional Embedding. Shape is [seq_len, dim]. + + Raises: + ValueError: If positions is None and max_seq_len is None. """ cfg = self.config + if positions is None: + if max_seq_len is None: + raise ValueError( + "Must provide `max_seq_len` for computing default query positions if " + "`positions` is None." 
+ ) + positions = self.default_query_positions(max_seq_len) return _rotary_sinusoidal_positional_embeddings( positions=positions, dim=cfg.dim, theta=cfg.theta ) @@ -1300,7 +1320,7 @@ class RoFormerQKVLinear(BaseQKVLinear): class Config(BaseQKVLinear.Config): """Configures RoFormerQKVLinear.""" - rope_pos_emb_layer: InstantiableConfig = ( + rope_pos_emb_layer: RoFormerSinusoidalPositionalEmbedding.Config = ( RoFormerSinusoidalPositionalEmbedding.default_config() ) input_linear: BaseQKVLinear.Config = QKVLinear.default_config() @@ -1342,9 +1362,10 @@ def forward( cfg = self.config # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim]. query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state) - if query_positions is None: - query_positions = jnp.arange(query.shape[1])[None] - sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(query_positions).astype(query.dtype) + seq_len = query.shape[1] + sinusoidal_pos_emb = self.rope_pos_emb_layer.forward( + positions=query_positions, max_seq_len=seq_len + ).astype(query.dtype) # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim] sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2) @@ -3982,42 +4003,42 @@ def forward( # TODO(sneha): extend_step +OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]] +_SavePattern = Union[str, re.Pattern, None] + + # Adapted from jax source code to support regex. Reference: # https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 -# TODO(kelvin-zou): deprecated, keep it here to minimize distruption to the golden configs. -# Please use axlearn.common.utils.extended_checkpoint_policies instead. def _save_and_offload_only_these_names_regex( *, - names_which_can_be_saved: SavePattern, - names_which_can_be_offloaded: SavePattern, + names_which_can_be_saved: _SavePattern, + names_which_can_be_offloaded: _SavePattern, offload_src: str, offload_dst: str, ) -> OffloadPolicy: - return save_and_offload_only_these_names_regex( - names_which_can_be_saved=names_which_can_be_saved, - names_which_can_be_offloaded=names_which_can_be_offloaded, - offload_src=offload_src, - offload_dst=offload_dst, - ) + def policy(prim, *_, **params): + if prim is name_p: + if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]): + return pe.Saveable + if names_which_can_be_offloaded and re.fullmatch( + names_which_can_be_offloaded, params["name"] + ): + return pe.Offloadable(src=offload_src, dst=offload_dst) + return pe.Recompute # not saveable unless it's in the allow-list + + return policy -# Regex patterns for matching remat names -class RematRegexSavePatterns(enum.Enum): - QKV_PROJ = r".*[kqv]_proj" - O_PROJ = r".*o_proj" - CONTEXT = r".*context" - LINEAR1_X = r".*linear1_[01]" - LINEAR2_X = r".*linear2_[01]" - SELF_ATTENTION = ".*([qkvo]_proj|context)" - FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X]) +SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)" +FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*" def build_remat_spec( stack_cfg: Union[ BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore ], - save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value, - offload_pattern: SavePattern = None, + save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN, + offload_pattern: _SavePattern = None, offload_dst: str = "pinned_host", ) -> Optional[RematSpec]: """Configures how the Transformer or Conformer stack will save the linearization 
points. From 3b7c847b88d3544d2584da58a5f23256dab33685 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 17:41:43 -0800 Subject: [PATCH 02/12] Update attention_test.py --- axlearn/common/attention_test.py | 192 +++++++++++++++++++------------ 1 file changed, 119 insertions(+), 73 deletions(-) diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 9ed01ca94..8528fe95e 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -41,6 +41,7 @@ from axlearn.common import attention, attention_bias, test_utils, utils from axlearn.common.attention import ( + FEED_FORWARD_SAVE_PATTERN, BaseStackedTransformerLayer, BaseTransformerLayer, BottleNeckAdapterTransformerLayer, @@ -57,7 +58,6 @@ PipelinedTransformerLayer, QKVLinear, QLinear, - RematRegexSavePatterns, RepeatedTransformerLayer, RoFormerQKVLinear, StackedTransformerLayer, @@ -65,6 +65,7 @@ TransformerFeedForwardLayer, TransformerLayer, _next_power_of_two, + _save_and_offload_only_these_names_regex, apply_attention_logit_biases, apply_rotary_position_embeddings, build_remat_spec, @@ -123,7 +124,6 @@ VDict, as_tensor, flatten_items, - save_and_offload_only_these_names_regex, shapes, ) @@ -812,9 +812,73 @@ def test_rope_emb(self, batch_size, max_len, dim): .set(name="test_rope_emb", dim=dim) .instantiate(parent=None) ) - test_output = test_layer.forward(positions) + test_output = test_layer.forward(positions=positions) np.testing.assert_allclose(np.expand_dims(ref_output, 0), test_output, atol=5e-7) + @parameterized.parameters( + (None, True), + (10, False), + ) + def test_rope_emb_no_pos(self, max_len, should_raise): + test_layer = ( + attention.RoFormerSinusoidalPositionalEmbedding.default_config() + .set(name="test_rope_emb", dim=10) + .instantiate(parent=None) + ) + if should_raise: + with self.assertRaises(ValueError): + test_layer.forward(max_seq_len=max_len) + else: + test_layer.forward(max_seq_len=max_len) + + @parameterized.parameters( + (2, 10, 32, 4), + ) + def test_default_rope_emb(self, batch_size, max_len, dim, num_heads): + rng = np.random.default_rng(seed=123) + query = jnp.asarray(rng.random([batch_size, max_len, dim])) + key = jnp.asarray(rng.random([batch_size, max_len, dim])) + value = jnp.asarray(rng.random([batch_size, max_len, dim])) + per_head_dim = dim // num_heads + + emb_layer_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( + dim=per_head_dim, + ) + linear_layer_cfg = attention.RoFormerQKVLinear.default_config().set( + query_dim=dim, + key_dim=dim, + value_dim=dim, + num_heads=num_heads, + per_head_dim=per_head_dim, + rope_pos_emb_layer=emb_layer_cfg, + rotary_value=False, + name="test_rope_linear", + ) + rope_linear_layer = linear_layer_cfg.instantiate(parent=None) + state = rope_linear_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + + rope_emb_layer = emb_layer_cfg.set(name="test_rope_emb").instantiate(parent=None) + default_positions = rope_emb_layer.default_query_positions(max_len) + + input_dict = dict(query=query, key=key, value=value) + + layer_outputs_no_position, _ = F( + rope_linear_layer, + inputs=input_dict, + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + layer_outputs, _ = F( + rope_linear_layer, + inputs=dict(**input_dict, query_positions=default_positions), + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + # test RoFormerQKVLinear uses default positions in RoFormerSinusoidalPositionalEmbedding + 
np.testing.assert_allclose(layer_outputs_no_position, layer_outputs, atol=1e-5) + def _compare_against_roformer_attention( self, ref, @@ -887,7 +951,7 @@ def test_rope_self_attention(self, rotary_value: bool, override_positions: bool) if override_positions else jnp.expand_dims(jnp.arange(max_sequence_length), 0) ) - ref_rope_emb = as_torch_tensor(rope_emb_layer.forward(positions)).unsqueeze(1) + ref_rope_emb = as_torch_tensor(rope_emb_layer.forward(positions=positions)).unsqueeze(1) layer = attention.TransformerAttentionLayer.default_config().set( source_dim=model_dim, target_dim=model_dim, @@ -1075,9 +1139,13 @@ def test_against_llama_for_precompute_freqs_cis(self, theta: float): attention.QKVLinear.default_config(), attention.GroupedQKVLinear.default_config(), ), + has_query_positions=(True, False), ) def test_roformer_qkv_linear( - self, dtype: jnp.dtype, input_linear: attention.BaseQKVLinear.Config + self, + dtype: jnp.dtype, + input_linear: attention.BaseQKVLinear.Config, + has_query_positions: bool, ): seq_len = 6 batch_size = 2 @@ -1116,6 +1184,14 @@ def test_roformer_qkv_linear( jax.random.PRNGKey(0) ) input_batch = dict(query=query, key=key, value=value) + if has_query_positions: + input_batch["query_positions"] = jax.random.permutation( + jax.random.PRNGKey(1), + jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0), + axis=1, + independent=True, + ) + layer_outputs, _ = F( roformer_qkv_linear, inputs=utils.cast_floats(input_batch, to_dtype=dtype), @@ -2168,17 +2244,24 @@ def test_data_types(self, dtype: jnp.dtype, per_dim_scale: Optional[PerDimScale. lambda query_len, kv_len: _random_mask(jax.random.PRNGKey(1), query_len, kv_len), ), kv_length_multiplier=(0.5, 1, 2), + has_query_positions=(False, True), ) def test_causal( self, base_cfg: attention.MultiheadAttention.Config, attention_logit_biases_fn: Callable[[int, int], Tensor], kv_length_multiplier: float, + has_query_positions: bool, ): """Tests that base_cfg(causal=True) is equivalent to applying a causal mask.""" - if kv_length_multiplier != 1 and isinstance( - base_cfg.input_linear, - (FusedGroupedQKVLinear.Config, RoFormerQKVLinear.Config, FusedQKVLinear.Config), + if ( + has_query_positions + and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config) + or kv_length_multiplier != 1 + and isinstance( + base_cfg.input_linear, + (FusedGroupedQKVLinear.Config, RoFormerQKVLinear.Config, FusedQKVLinear.Config), + ) ): pytest.skip(reason="Incompatible test setting that does not need testing.") @@ -2202,6 +2285,14 @@ def test_causal( query = jnp.zeros([batch_size, query_len, model_dim], dtype=jnp.float32) outputs = [] + if has_query_positions: + query_positions = jax.random.permutation( + jax.random.PRNGKey(1), + jnp.arange(query_len)[None, :].repeat(batch_size, axis=0), + axis=1, + independent=True, + ) + for layer in (ref_layer, test_layer): inputs = dict(query=query) kv_len = int(kv_length_multiplier * query_len) @@ -2223,6 +2314,8 @@ def test_causal( attention_logit_biases, causal_biases ) inputs["attention_logit_biases"] = attention_logit_biases + if has_query_positions: + inputs["query_positions"] = query_positions layer_outputs, _ = F( layer, @@ -2261,16 +2354,21 @@ def test_causal( lambda seq_len: None, lambda seq_len: _random_mask(jax.random.PRNGKey(1), seq_len, seq_len), ), + has_query_positions=(False, True), ) def test_sliding_window( self, base_cfg: attention.MultiheadAttention.Config, attention_logit_biases_fn: Callable[[int], Tensor], + has_query_positions: bool, ): """ Tests that base_cfg with sliding 
window causal mask fns is equivalent to applying a causal sliding window mask. """ + if has_query_positions and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config): + return + model_dim = 16 num_heads = 4 ref_cfg = base_cfg.clone( @@ -2296,6 +2394,15 @@ def test_sliding_window( batch_size, seq_len = 2, 4 query = jnp.zeros([batch_size, seq_len, model_dim], dtype=jnp.float32) outputs = [] + + if has_query_positions: + query_positions = jax.random.permutation( + jax.random.PRNGKey(1), + jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0), + axis=1, + independent=True, + ) + for layer in (ref_layer, test_layer): attention_logit_biases = attention_logit_biases_fn(seq_len) if layer is ref_layer: @@ -2305,6 +2412,8 @@ def test_sliding_window( attention_logit_biases, ) inputs = dict(query=query, attention_logit_biases=attention_logit_biases) + if has_query_positions: + inputs["query_positions"] = query_positions layer_outputs, _ = F( layer, state=layer_params, @@ -3445,8 +3554,8 @@ def f(x, layer_params): _, save_name_backward = jax.linearize( jax.remat( f, - policy=save_and_offload_only_these_names_regex( - names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value, + policy=_save_and_offload_only_these_names_regex( + names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN, names_which_can_be_offloaded=None, offload_src="device", offload_dst="pinned_host", @@ -3901,69 +4010,6 @@ def f(x, layer_params): 5, ) - def test_build_remat_spec_neuron(self): - model_dim, num_heads = 6, 2 - cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) - cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - cfg.feed_forward.hidden_dim = model_dim * 4 - cfg.vlog = 5 - - layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - - batch_size, tgt_len = 2, 5 - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) - - def f(x, layer_params): - forward_outputs, _ = F( - layer, - inputs=dict( - data=x, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - return forward_outputs - - # Ignore type errors. - spec: Any = build_remat_spec(mock.MagicMock()) - - policy = ( - config_for_function(save_and_offload_only_these_names_regex) - .set( - names_which_can_be_saved="|".join( - [ - RematRegexSavePatterns.QKV_PROJ.value, - RematRegexSavePatterns.LINEAR1_X.value, - ] - ), - names_which_can_be_offloaded=None, - offload_src=None, - offload_dst=None, - ) - .instantiate() - ) - - _, default_policy_backward = jax.linearize( - jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse), - jnp.asarray(target), - layer_params, - ) - _, full_remat_backward = jax.linearize( - jax.remat(f), - jnp.asarray(target), - layer_params, - ) - - # Eliminated the remat of qkv_proj and linear1_0 = 4 dots. 
- self.assertEqual( - str(full_remat_backward).count(" dot_general") - - str(default_policy_backward).count(" dot_general"), - 4, - ) - class TestStackModel(BaseLayer): """A dummy transformer stack.""" From 3664b6bace364a8303425f8a83dd8119c4f7b958 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 17:42:37 -0800 Subject: [PATCH 03/12] Update dit.py --- axlearn/common/dit.py | 76 +++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 47 deletions(-) diff --git a/axlearn/common/dit.py b/axlearn/common/dit.py index 24fcc4e21..711f0457c 100644 --- a/axlearn/common/dit.py +++ b/axlearn/common/dit.py @@ -13,8 +13,6 @@ from typing import Optional, Union -import chex -import einops import jax import jax.numpy as jnp @@ -33,21 +31,7 @@ def modulate(*, x, shift, scale): - """Modulates the input x tensor. - - Note: shift and scale must have the same shape. - - Args: - x: input tensor with shape [batch_size, num_length, input_dim]. - shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim]. - scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim]. - - Returns: - A tensor with shape [batch_size, num_length, input_dim]. - """ - chex.assert_equal_shape((shift, scale)) - chex.assert_equal_rank((x, shift, scale)) - return x * (1 + scale) + shift + return x * (1 + jnp.expand_dims(scale, 1)) + jnp.expand_dims(shift, 1) class TimeStepEmbedding(BaseLayer): @@ -227,18 +211,15 @@ def forward(self, input: Tensor) -> Tensor: """Generate the parameters for modulation. Args: - input: A tensor with shape [batch_size, dim] or [batch_size, num_length, dim]. + input: A tensor with shape [batch_size, ..., dim]. Returns: A list of tensors with length num_outputs. - Each tensor has shape [batch_size, 1|num_length, dim]. + Each tensor has shape [batch_size, ..., dim]. """ cfg = self.config x = get_activation_fn(cfg.activation)(input) output = self.linear(x) - assert output.ndim in (2, 3) - if output.ndim == 2: - output = einops.rearrange(output, "b d -> b 1 d") output = jnp.split(output, cfg.num_outputs, axis=-1) return output @@ -311,16 +292,14 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor) Args: input: input tensor with shape [batch_size, num_length, input_dim]. - shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim]. - scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim]. + shift: shifting the norm tensor with shape [batch_size, input_dim]. + scale: scaling the norm tensor with shape [batch_size, input_dim]. gate: applying before the residual addition with shape - [batch_size, 1|num_length, input_dim]. + [batch_size, input_dim]. Returns: A tensor with shape [batch_size, num_length, input_dim]. """ - chex.assert_equal_shape((shift, scale, gate)) - chex.assert_equal_rank((input, shift)) cfg = self.config remat_pt1 = "linear1_0" remat_pt2 = "linear2" @@ -346,7 +325,7 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor) x = self.postnorm(x) x = self.dropout2(x) - x = x * gate + x = x * jnp.expand_dims(gate, 1) x += input return x @@ -404,18 +383,19 @@ def forward( shift: Optional[Tensor] = None, scale: Optional[Tensor] = None, gate: Optional[Tensor] = None, + query_positions: Optional[Tensor] = None, attention_logit_biases: Optional[Tensor] = None, ) -> Tensor: """The forward function of DiTAttentionLayer. Args: input: input tensor with shape [batch_size, num_length, target_dim]. 
- shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length, - target_dim] and scale should be provided. - scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length, - target_dim] and shift should be provided. + shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and + scale should be provided. + scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and + shift should be provided. gate: If provided, applying before the residual addition with shape - [batch_size, 1|num_length, target_dim]. + [batch_size, target_dim]. attention_logit_biases: Optional Tensor representing the self attention biases. Returns: @@ -439,7 +419,9 @@ def forward( if shift is not None and scale is not None: x = modulate(x=x, shift=shift, scale=scale) - x = self.attention(query=x, attention_logit_biases=attention_logit_biases).data + x = self.attention( + query=x, query_positions=query_positions, attention_logit_biases=attention_logit_biases + ).data if cfg.structure == "postnorm": x = self.norm(x) @@ -447,7 +429,7 @@ def forward( x = self.postnorm(x) if gate is not None: - x = x * gate + x = x * jnp.expand_dims(gate, 1) output = input + x return output @@ -484,12 +466,12 @@ def extend_step( results of previous attentions, and index used for fast decoding. Contains "attention" cached states. target: target tensor with shape [batch_size, step_length, target_dim]. - shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length, - target_dim] and scale should be provided. - scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length, - target_dim] and shift should be provided. + shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and + scale should be provided. + scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and + shift should be provided. gate: If provided, applying before the residual addition with shape - [batch_size, 1|num_length, target_dim]. + [batch_size, target_dim]. Returns: A tuple (cached_states, output): @@ -525,7 +507,7 @@ def extend_step( x = self.postnorm(x) if gate is not None: - x = x * gate + x = x * jnp.expand_dims(gate, 1) output = target + x return dict(attention=attn_states), output @@ -563,8 +545,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor: Args: input: input tensor with shape [batch_size, num_length, input_dim]. - condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length, - input_dim] for generating layer norm shift, scale, and gate. + condition: tensor with shape [batch_size, input_dim] for generating + layer norm shift, scale, and gate. Returns: A tensor with shape [batch_size, num_length, input_dim]. @@ -605,8 +587,8 @@ def extend_step( results of previous attentions, and index used for fast decoding. Contains "attention" cached states. target: target tensor with shape [batch_size, step_length, input_dim]. - condition: tensor with shape [batch_size, input_dim] or [batch_size, step_length, - input_dim] for generating layer norm shift, scale, and gate. + condition: tensor with shape [batch_size, input_dim] for generating + layer norm shift, scale, and gate. Returns: A tuple (cached_states, output): @@ -660,8 +642,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor: Args: input: input tensor with shape [batch_size, num_length, input_dim]. 
- condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length, - input_dim] for generating layer norm shift and scale. + condition: tensor with shape [batch_size, input_dim] for generating + layer norm shift and scale. Returns: A tensor with shape [batch_size, num_length, output_dim]. From c967c076d3048ffab91d8f0a3333e54b9f28e9ef Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 17:46:58 -0800 Subject: [PATCH 04/12] Update attention.py --- axlearn/common/attention.py | 8846 ++++++++++++++++++++--------------- 1 file changed, 5005 insertions(+), 3841 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 7aaa2b9cf..bca07178f 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -6,3633 +6,4162 @@ # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # -# google-research/t5x: -# Copyright 2022 The T5X Authors. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 (the "License"). -# -# huggingface/transformers: -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"). -# -# facebookresearch/deit: -# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"). -# -# tensorflow/models: -# Copyright 2023 The TensorFlow Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"). -# -# google/praxis: -# Copyright 2022 The Pax Authors. -# Licensed under the Apache License, Version 2.0 (the "License"). -# # ofirpress/attention_with_linear_biases: # Copyright (c) Facebook, Inc. and its affiliates. -# Licensed under the MIT license. - -"""Attention layers with pjit partition specs. - -On `attention_logit_biases`: -* For methods that take a tensor, a biases Tensor can have one of the following shapes: - * [target_length, source_length] - * [batch, target_length, source_length] - * [batch, num_heads, target_length, source_length]. -* Each value represents a bias to be added to the attention logits - (therefore a -inf represents a disconnected position pair). -* biases=None represents an all-zero tensor, i.e., all position pairs are connected. -* For methods that take a BaseAttentionBias, the value() will always be None or a 4d Tensor with - the above semantics. - -TODO(apghml) Convert everything to take an instance of BaseAttentionBias rather than a Tensor. - -On `segment_ids`: -* A tensor of shape [batch, target_length] with values in [0, num_segments]. -* Tokens are only allowed to attend to other tokens within the same segment. -* segment_ids == 0 represents paddings. -* None represents an all-one tensor, i.e. all positions are in the same segment. - -On `positions`: -* A tensor of shape [batch, target_length]. Note that this is conceptually different from - `time_step`. To disambiguate: - * `positions`: A [batch, target_length] tensor indicating the position ids of each input token - during training (i.e. in `forward`). - * `time_step`: a [batch] tensor indicating the current decode position of each sample during - decoding (i.e. in `init_states` and `extend_step`). -* In most typical cases, the values of `positions` are integers in [0, target_length - 1]. - However, this should not be assumed by the implementation in order to support other positional - encoding schemes, e.g. 
RandPos (https://arxiv.org/pdf/2305.16843), where positions are - non-consecutive integers that can be larger than target_length - 1. -* None represents jnp.arange(target_length). -* When the accompanying argument is `query`, the `positions` argument is named as - `query_position`. Similarly, when the argument `target`, it is named as `target_positions`. - -TODO(changlan): Merge the use of `positions` and `time_step` to reduce cognitive complexity. - -""" - -# pylint: disable=abstract-method,too-many-lines -import enum -import functools +# +# facebookresearch/llama: +# Copyright (c) Facebook, Inc. and its affiliates. + +"""Tests attention layers.""" + +import contextlib +import copy +import itertools + +# pylint: disable=too-many-lines,duplicate-code,no-self-use import math -import re from collections.abc import Sequence -from enum import Enum, unique -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +from itertools import combinations +from typing import Any, Callable, Optional, Union +from unittest import mock -import einops import jax +import numpy as np +import optax +import pytest +import torch +from absl import logging +from absl.testing import absltest, parameterized +from jax import nn from jax import numpy as jnp -from jax._src.ad_checkpoint import name_p -from jax._src.interpreters import partial_eval as pe -from jax.core import Primitive - -from axlearn.common import ops, param_init +from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies +from transformers.models.roberta import modeling_roberta as hf_roberta +from transformers.models.roformer import modeling_roformer as hf_roformer +from transformers.models.xlnet import modeling_xlnet as hf_xlnet + +from axlearn.common import attention, attention_bias, test_utils, utils +from axlearn.common.attention import ( + BaseStackedTransformerLayer, + BaseTransformerLayer, + BottleNeckAdapterTransformerLayer, + FusedGroupedQKVLinear, + FusedQKVLinear, + KVState, + LearnedPositionalEmbedding, + MultiheadAttentionXL, + MultiheadInputLinear, + MultiheadOutputLinear, + MultiheadRelativePositionLinear, + ParallelTransformerLayer, + PerDimScale, + PipelinedTransformerLayer, + QKVLinear, + QLinear, + RematRegexSavePatterns, + RepeatedTransformerLayer, + RoFormerQKVLinear, + StackedTransformerLayer, + TransformerAttentionLayer, + TransformerFeedForwardLayer, + TransformerLayer, + _next_power_of_two, + apply_attention_logit_biases, + apply_rotary_position_embeddings, + build_remat_spec, + compute_padding_biases, + rel_pos_to_abs_pos, + scaled_hidden_dim, + set_double_shard_weights_config, + sinusoidal_positional_embeddings, + update_data_with_skip_connection, + xl_attention_logits, +) from axlearn.common.attention_bias import ( NEG_INF, - BaseAttentionBias, - CausalAttentionBias, - MaskFn, - MaskFnAttentionBias, - SegmentIdAttentionBias, - as_attention_bias, + bool_to_bias, causal_mask, - make_segment_mask, + make_causal_biases, + make_sliding_window_causal_biases, + sliding_window_causal_mask, ) from axlearn.common.base_layer import ( BaseLayer, + DefaultTensorStats, FactorizationSpec, - NestedParameterSpec, ParameterSpec, RematSpec, ) from axlearn.common.config import ( - REQUIRED, - ConfigOr, - FunctionConfigBase, InstantiableConfig, - Required, + UnknownFieldError, config_class, config_for_function, - maybe_instantiate, -) -from axlearn.common.layers import ( - Dropout, - LayerNorm, - Linear, - StochasticDepth, - get_activation_fn, - get_stochastic_depth_linear_rate, + maybe_set_config, ) -from 
axlearn.common.module import Module, child_context +from axlearn.common.decoder import Decoder, TransformerTextEmbeddings +from axlearn.common.layers import RMSNorm, set_bias_recursively +from axlearn.common.module import InvocationContext, Module +from axlearn.common.module import functional as F +from axlearn.common.module import new_output_collection, set_current_context +from axlearn.common.optimizer_base import OptParam +from axlearn.common.optimizers import adafactor_optimizer +from axlearn.common.param_converter import as_torch_tensor from axlearn.common.param_init import ( PARAM_REGEXP_WEIGHT, - ConstantInitializer, DefaultInitializer, FanAxes, WeightInitializer, - constant_initializer, ) -from axlearn.common.pipeline import Pipeline -from axlearn.common.quantized_dot_general.layers import DenseGeneralBaseLayer -from axlearn.common.repeat import Repeat +from axlearn.common.pipeline import BaseSchedule, GPipeSchedule, StreamSchedule +from axlearn.common.test_utils import TestCase, assert_allclose, dummy_segments_positions +from axlearn.common.torch_utils import parameters_from_torch_layer from axlearn.common.utils import ( Nested, - NestedTensor, PartitionSpec, Tensor, TensorSpec, VDict, - check_numerics, + as_tensor, flatten_items, - get_or_none, + save_and_offload_only_these_names_regex, shapes, - split_prng_key, ) -class ForwardMode(enum.Enum): - """ForwardMode describes the type of computation to be done in a forward pass through a layer. - - FORWARD: Used for a standard forward pass. - INIT_STATES: Used for initializing the decoding cache. Typically means that the method signature - matches EXTEND_STEP, possibly without an input cache state, and returning a prefilled cache - along with the layer outputs. - EXTEND_STEP: Used for incremental decoding. Typically means that the method signature consumes - cache state and emits cache state along with layer outputs. - """ - - FORWARD = 0 - INIT_STATES = 1 - EXTEND_STEP = 2 +def all_subsets(given_set): + "Generate all subsets of a list `given_set`." + s = list(given_set) + return list( + itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1)) + ) -class KVState(NamedTuple): - """Represents key/value projections, of shape [batch, source_length, num_kv_heads, head_dim].""" +def make_index_position_biases(query_len: int, kv_len: int) -> Tensor: + """Generates attention logit biases where query indices cannot attend to larger key indices. - k_proj: Tensor - v_proj: Tensor + Args: + query_len: The sequence length. + kv_len: The key's length. + Returns: + A float tensor of shape [query_len, kv_len] where the value at + [i, j] = -inf if i < j, 0 otherwise. + """ -class BaseTransformerLayer(BaseLayer): - """An abstract class to define the common interface of all *TransformerLayer classes, including: + return bool_to_bias( + causal_mask( + jnp.arange(query_len)[:, None], + jnp.arange(kv_len)[None, :], + ) + ) - * All subclasses must have `input_dim` in its Config; - * The common Output structure; - * The common method signature for `forward()`, `init_states()`, and `extend_step()`. 
- """ - @config_class - class Config(BaseLayer.Config): - """Configures BaseTransformerLayer.""" +def _random_mask(prng_key, tgt_len, src_len): + """Returns a float mask of shape [tgt_len, src_len].""" + key1, key2 = jax.random.split(prng_key) + mask = jnp.logical_not( + jax.random.randint(key1, minval=0, maxval=2, shape=[tgt_len, src_len]) + + + # Ensure that every tgt position attends to at least one src position, otherwise + # torch_modules.MultiheadAttention will generate NaN. + nn.one_hot(jax.random.randint(key2, minval=0, maxval=src_len, shape=[tgt_len]), src_len) + ) + return mask.astype(jnp.float32) * NEG_INF + + +class MaskTest(absltest.TestCase): + """Tests mask implementations.""" + + def test_causal_mask(self): + expected = jnp.array([[0.0, NEG_INF, NEG_INF], [0.0, 0.0, NEG_INF], [0.0, 0.0, 0.0]]) + actual = attention_bias.make_causal_biases(3) + self.assertTrue(jnp.all(actual <= expected)) + + def test_segment_mask(self): + expected = jnp.array( + [ # batch + [ # num_heads + [ + [NEG_INF, NEG_INF, NEG_INF, 0.0], + [NEG_INF, NEG_INF, NEG_INF, 0.0], + [0.0, 0.0, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, 0.0, NEG_INF], + ] + ] + ] + ) + actual = attention_bias.make_segment_mask( + target_segments=jnp.asarray([[1, 1, 2, 0]]), + source_segments=jnp.asarray([[2, 2, 0, 1]]), + ) + self.assertTrue(jnp.all(actual <= expected)) + + def test_apply_attention_logit_biases(self): + batch_size = 10 + num_heads = 12 + dim = 32 + logits = jnp.asarray(np.random.random(size=[batch_size, num_heads, dim])) + + # Testing for biases = None + masked_logit = apply_attention_logit_biases(logits, attention_logit_biases=None) + self.assertEqual(masked_logit.dtype, logits.dtype) + np.testing.assert_array_equal(logits, masked_logit) + + # Testing for biases = random_float_biases + for logit_float_type in [jnp.bfloat16, jnp.float32, jnp.float16]: + for mask_float_type in [jnp.bfloat16, jnp.float32, jnp.float16]: + logits = jnp.asarray(np.random.random(size=[batch_size, num_heads, dim])).astype( + logit_float_type + ) + random_float_biases = jnp.asarray( + np.random.random(size=[batch_size, num_heads, dim]) + ).astype(mask_float_type) + masked_logit = apply_attention_logit_biases( + logits, attention_logit_biases=random_float_biases + ) + self.assertEqual(masked_logit.dtype, logits.dtype) + np.testing.assert_array_equal( + masked_logit, logits + random_float_biases.astype(logits.dtype) + ) - input_dim: Required[int] = REQUIRED # Input feature dim. - class Output(NamedTuple): - """BaseTransformerLayer output. +class CausalAttentionLogitBiasLayerTest(TestCase): + """Tests CausalAttentionLogitBiasLayer.""" - Fields: - data: [batch, target_length, input_dim]. The layer output. Always present. + @parameterized.parameters( + # Test the mask with all padding tokens. + dict( + token_ids=[[0, 0, 0], [0, 0, 0]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [0, 0, 0], + ], + ] + * 2, + ), + # Test the mask with all valid tokens. + dict( + token_ids=[[1, 2, 3], [4, 5, 6]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [0, 0, 0], + ], + ] + * 2, + ), + # Test the mask with some valid tokens and some padding tokens. + dict( + token_ids=[[10, 0, 0], [12, 33, 0]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [0, 0, 0], + ], + ] + * 2, + ), + # Test the mask with additional padding biases. 
+ dict( + token_ids=[[10, 0, 0], [12, 33, 0]], + apply_padding_mask=True, + expected=[ + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + ], + ), + # Test the mask with valid tokens AND paddings in between. + dict( + token_ids=[[10, 0, 11], [12, 33, 0]], + apply_padding_mask=True, + expected=[ + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + [0, NEG_INF, 0], + ], + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + ], + ), + # Test a basic case with positions. + dict( + token_ids=[[10, 11, 12], [13, 14, 15]], + segment_ids=[[1, 1, 2], [1, 2, 2]], + positions=[[0, 1, 0], [0, 0, 1]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, 0], + ], + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, NEG_INF], + [NEG_INF, 0, 0], + ], + ], + ), + # Test a case where some segments are empty. + dict( + token_ids=[[10, 11, 12], [13, 14, 15]], + segment_ids=[[1, 2, 2], [2, 2, 2]], + positions=[[0, 0, 1], [0, 1, 2]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, NEG_INF], + [NEG_INF, 0, 0], + ], + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [0, 0, 0], + ], + ], + ), + # Test with positions and padding. + # Note: we deliberately allow the last token to be 0, to test that without + # apply_padding_mask, a 0-token is not necessarily padding if its segment_id != 0. + dict( + token_ids=[[10, 11, 0], [13, 14, 0]], + segment_ids=[[1, 1, 0], [1, 2, 2]], + positions=[[0, 1, 0], [0, 0, 1]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, 0], + ], + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, NEG_INF], + [NEG_INF, 0, 0], + ], + ], + ), + # Test with segment IDs but not positions. + # This should have the same result as the previous test. + dict( + token_ids=[[10, 11, 0], [13, 14, 0]], + segment_ids=[[1, 1, 0], [1, 2, 2]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, 0], + ], + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, NEG_INF], + [NEG_INF, 0, 0], + ], + ], + ), + # Test with positions and padding, and apply the padding mask. 
+ dict( + token_ids=[[10, 11, 0], [13, 14, 0]], + segment_ids=[[1, 1, 0], [1, 2, 0]], + positions=[[0, 1, 0], [0, 0, 1]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + ], + apply_padding_mask=True, + ), + ) + def test_causal_attention_mask_layer( + self, + token_ids: list, + expected: list, + segment_ids: Optional[Tensor] = None, + positions: Optional[Tensor] = None, + apply_padding_mask: Optional[bool] = False, + ): + causal_attention_mask_layer = ( + attention.CausalAttentionLogitBiasLayer.default_config() + .set(name="test_causal_attention_mask") + .instantiate(parent=None) + ) + if token_ids is not None: + token_ids = np.asarray(token_ids) + if positions is None: + positions = np.arange(token_ids.shape[-1])[None, :] + else: + positions = np.asarray(positions) + if segment_ids is None: + segment_ids = np.ones_like(token_ids) + else: + segment_ids = np.asarray(segment_ids) + actual = causal_attention_mask_layer.forward(segment_ids=segment_ids, positions=positions) + if apply_padding_mask: + actual += compute_padding_biases(token_ids, pad_token_id=0) + assert_allclose(jnp.exp(actual.squeeze(1)), jnp.exp(np.asarray(expected))) + + +class FullAttentionLogitBiasLayerTest(TestCase): + """Tests FullAttentionLogitBiasLayer.""" + + @parameterized.parameters( + # Test the mask with all padding tokens. + dict( + token_ids=[[0, 0, 0], [0, 0, 0]], + expected=[ + [ + [NEG_INF, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + [ + [NEG_INF, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + ], + ), + # Test the mask with all valid tokens. + dict( + token_ids=[[1, 2, 3], [4, 5, 6]], + expected=[ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ], + ), + # Test the mask with some valid tokens and some padding tokens. + dict( + token_ids=[[10, 0, 0], [12, 33, 0]], + expected=[ + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + [ + [0, 0, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + ], + ), + # Test a basic case with segment IDs. + dict( + token_ids=[[10, 11, 12], [13, 14, 15]], + segment_ids=[[1, 1, 2], [1, 2, 2]], + positions=[[0, 1, 0], [0, 0, 1]], + expected=[ + [ + [0, 0, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, 0], + ], + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, 0], + [NEG_INF, 0, 0], + ], + ], + ), + # Test a case where some segments are empty. + dict( + token_ids=[[10, 11, 12], [13, 14, 15]], + segment_ids=[[1, 1, 2], [2, 2, 2]], + positions=[[0, 1, 0], [0, 1, 2]], + expected=[ + [ + [0, 0, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, 0], + ], + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ], + ], + ), + # Test with segment IDs and padding. 
+ dict( + token_ids=[[10, 11, 0], [13, 14, 0]], + segment_ids=[[1, 1, 0], [1, 2, 0]], + positions=[[0, 1, 0], [0, 0, 1]], + expected=[ + [ + [0, 0, NEG_INF], + [0, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + [ + [0, NEG_INF, NEG_INF], + [NEG_INF, 0, NEG_INF], + [NEG_INF, NEG_INF, NEG_INF], + ], + ], + ), + ) + def test_full_attention_mask_layer( + self, + token_ids: list, + expected: list, + segment_ids: Optional[Tensor] = None, + positions: Optional[Tensor] = None, + ): + full_attention_mask_layer = ( + attention.FullAttentionLogitBiasLayer.default_config() + .set(name="test_full_attention_mask") + .instantiate(parent=None) + ) + if token_ids is not None: + token_ids = np.asarray(token_ids) + if positions is None: + positions = np.arange(token_ids.shape[-1])[None, :] + else: + positions = np.asarray(positions) + if segment_ids is None: + segment_ids = token_ids != 0 + else: + segment_ids = np.asarray(segment_ids) + actual = full_attention_mask_layer.forward(segment_ids=segment_ids, positions=positions) + actual += compute_padding_biases(token_ids, pad_token_id=0) + assert_allclose(jnp.exp(np.asarray(expected)), jnp.exp(actual.squeeze(1))) + + +class ALiBiAttentionLogitBiasLayerTest(TestCase): + """Tests ALiBiAttentionLogitBiasLayer.""" + + def ref_alibi_implementation(self, batch_size, num_heads, max_len): + # Slopes is in jax DeviceArray. Switch it to torch tensor as the ref code. + slopes = torch.Tensor(attention.alibi_get_slopes(num_heads)) + alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_len).unsqueeze(0).unsqueeze( + 0 + ).expand(num_heads, -1, -1) + alibi = alibi.view(num_heads, 1, max_len) + + # Post processing to translate alibi matrix into the jax format. + # Alibi matrix shape [batch_size, num_heads, max_len, max_len]. + alibi = alibi.unsqueeze(0).expand(batch_size, -1, max_len, -1) + # Translate from pytorch to jax. + alibi = as_tensor(alibi) + return alibi + + def test_alibi_attention_mask(self): + num_heads = 12 + batch_size = 2 + max_len = 3 + + # Test alibi implementation. + alibi_attention_mask_layer = ( + attention.ALiBiAttentionLogitBiasLayer.default_config() + .set(name="test_alibi_attention_mask", num_heads=num_heads) + .instantiate(parent=None) + ) - self_attention_probs: The attention probabilities returned by the self-attention layer. - Shape: [..., target_length, target_length]. + # Casual attention mask which will be applied to ref alibi mask. + ref_causal_attention_mask_layer = ( + attention.CausalAttentionLogitBiasLayer.default_config() + .set(name="ref_causal_attention_mask") + .instantiate(parent=None) + ) - self_attention_probs[..., i, j] represents self-attention probability on - input data[..., j, :] when computing output data[..., i, :]. - self_attention_probs.sum(axis=-1) equals to all 1's. + token_ids = as_tensor(np.random.randint(low=1, high=20, size=[batch_size, max_len])) + segment_ids = jnp.ones_like(token_ids) + positions = jnp.arange(max_len)[None, :] - Present if "self_attention_probs" is in `return_aux`. + ref_alibi_mask = self.ref_alibi_implementation(batch_size, num_heads, max_len) + # Reshape causal_mask to [batch_size, num_heads, max_len, max_len]. + ref_causal_mask = ref_causal_attention_mask_layer.forward( + segment_ids=segment_ids, positions=positions + ) + ref_causal_mask = jnp.repeat(ref_causal_mask, num_heads, axis=1) - self_attention_kv_state: The KV state used in self-attention. - Present if "self_attention_kv_state" is in `return_aux`. + # Prepare the ref and the test alibi mask. 
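+        # Folding the causal biases into the reference ALiBi mask ensures that both the
+        # reference and the test masks carry causal + ALiBi biases of shape
+        # [batch_size, num_heads, max_len, max_len] before the softmax comparison below.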
+ ref_alibi_mask = attention.apply_attention_logit_biases(ref_alibi_mask, ref_causal_mask) + test_alibi_mask = alibi_attention_mask_layer.forward( + segment_ids=segment_ids, positions=positions + ) - cross_attention_probs: The attention probabilities returned by the cross-attention - layer. Shape: [..., target_length, source_length]. + # Ref and test alibi mask should be the same after applying it into a QK attention matrix. + # e.g. softmax(QK + ref_alibi_mask) == softmax(QK + test_alibi_mask). + random_qk_matrix = jnp.asarray( + np.random.random(size=[batch_size, num_heads, max_len, max_len]) + ) - If not None, cross_attention_probs[..., i, j] represents attention probability on - cross_attention_data[..., j, :] when computing output data[..., i, :]. - cross_attention_probs.sum(axis=-1) equals to all 1's. + ref_alibi_softmax = jax.nn.softmax(random_qk_matrix + ref_alibi_mask, axis=-1) + test_alibi_softmax = jax.nn.softmax(random_qk_matrix + test_alibi_mask, axis=-1) + + # The ref alibi implementation relies on the softmax property of invariance to translation. + # e.g. ref_alibi = [[0, -inf, -inf], [0, 1, -inf], [0, 1, 2]] + # test_alibi = [[0, -inf, -inf], [-1, 0, -inf], [-2, -1, 0]] + # softmax(qk + test_alibi) = softmax (qk + [[0, -inf, -inf], [-1, 0, -inf], [-2, -1, 0]]) + # = softmax (qk + [[0, -inf, -inf], [0, 1, -inf+1], [0, 1, 2]]) + # As the numerical -inf is not perfect -inf defined in math. + # Therefore, a very limit difference between those two after softmax, due to (-inf + x). + # The rtol is set to 5e-7 to tolerate this difference. + np.testing.assert_allclose(ref_alibi_softmax, test_alibi_softmax, rtol=5e-07) + + @parameterized.product( + [ + dict(num_segments=1, max_len=3), + dict(num_segments=3, max_len=3), + dict(num_segments=3, max_len=8), + ], + ) + def test_packing(self, max_len: int, num_segments: int): + # With packed inputs of shape [batch, seq_len], we form a block-diagonal matrix of shape + # [batch, num_heads, seq_len, seq_len], where each (unpacked) input has blocks of shape + # [batch, num_heads, segment_len, segment_len] (segment_len <= seq_len). + # We test this by comparing each block against a freshly computed alibi mask of the same + # shape, ensuring that packing is equivalent to treating each unpacked input separately. + num_heads = 12 + batch_size = 2 + + # Test alibi implementation. + alibi_attention_mask_layer = ( + attention.ALiBiAttentionLogitBiasLayer.default_config() + .set(name="test_alibi_attention_mask", num_heads=num_heads) + .instantiate(parent=None) + ) - Present if "cross_attention_probs" is in `return_aux`. - """ + # Construct inputs of shape [batch_size, max_len]. + input_segment_ids, positions = dummy_segments_positions( + batch_size, max_len, num_segments=num_segments + ) - data: Tensor - self_attention_probs: Optional[Tensor] = None - self_attention_kv_state: Optional[KVState] = None - cross_attention_probs: Optional[Tensor] = None + # Compute the test alibi mask of shape [batch, num_heads, seq_len, seq_len]. + test_alibi_batch = alibi_attention_mask_layer.forward( + segment_ids=input_segment_ids, positions=positions + ) + # Apply segment mask and softmax (see notes above). + test_alibi_batch = jax.nn.softmax(test_alibi_batch, axis=-1) + + for batch in range(batch_size): + test_alibi = test_alibi_batch[batch] + input_segments = input_segment_ids[batch] + + # Compute the reference alibi mask(s) for each segment separately. + for segment in range(num_segments): + # [seq_len]. 
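+                # Boolean mask selecting the positions that belong to the current segment.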
+ segment_mask = input_segments == segment + segment_len = int(jnp.sum(segment_mask, dtype=jnp.int32)) + + # Skip the segment if empty. + if segment_len == 0: + continue + + # Select the submatrix in test_alibi corresponding to the current segment. + # [seq_len, seq_len]. + segment_mask = jnp.logical_and(segment_mask[:, None], segment_mask[None, :]) + # [num_heads, seq_len, seq_len]. + segment_mask = jnp.repeat(segment_mask[None, ...], num_heads, 0) + # [num_heads, segment_len, segment_len]. + test_alibi_segment = test_alibi[segment_mask.astype(jnp.bool_)].reshape( + (num_heads, segment_len, segment_len) + ) - def forward( - self, - data: Tensor, - *, - self_attention_kv_state: Optional[KVState] = None, - self_attention_logit_biases: Optional[Tensor] = None, - cross_attention_data: Optional[Tensor] = None, - cross_attention_logit_biases: Optional[Tensor] = None, - target_segment_ids: Optional[Tensor] = None, - target_positions: Optional[Tensor] = None, - return_aux: Optional[set[str]] = None, - ) -> Output: - """Computes transformer layer outputs given full-sequence inputs. - - For incremental computation, use init_states() and extend_step(). - - See comments at the beginning of this file for semantics of *_attention_logit_biases. - - Args: - data: A Tensor of shape [batch, target_length, input_dim]. - self_attention_kv_state: An optional KVState used for self-attention. - self_attention_logit_biases: An optional Tensor representing the self-attention biases. - cross_attention_data: An optional Tensor representing cross-attention data of shape - [source_batch, source_length, source_dim]. - cross_attention_logit_biases: An optional Tensor representing the cross-attention - biases. - target_segment_ids: See ``segment_ids`` in the file comments. - target_positions: See ``positions`` in the file comments. - return_aux: A set of auxiliary output fields to return. Each element must be an - optional field of `Output`, e.g., - `return_aux = {"self_attention_probs", "self_attention_kv_state"}` means that - `Output.{self_attention_probs, self_attention_kv_state}` will be populated. - - Returns: - BaseTransformerLayer.Output. - """ - raise NotImplementedError(type(self)) + # Construct the ref_alibi for the current segment. + # [num_heads, segment_len]. + ref_alibi = self.ref_alibi_implementation(1, num_heads, segment_len).squeeze(0) + ref_causal_mask = jnp.repeat( + make_causal_biases(segment_len)[None, ...], num_heads, 0 + ) + ref_alibi = attention.apply_attention_logit_biases(ref_alibi, ref_causal_mask) + ref_alibi = jax.nn.softmax(ref_alibi, axis=-1) - def init_states( - self, - *, - time_step: Optional[Tensor], - data: Union[Tensor, TensorSpec], - self_attention_kv_state: Optional[KVState] = None, - self_attention_logit_biases: Optional[Tensor] = None, - cross_attention_data: Optional[Tensor] = None, - cross_attention_logit_biases: Optional[Tensor] = None, - ) -> tuple[Nested[Tensor], Optional[Output]]: - """Initializes cached states for incremental computation. - - The method supports initializing an empty cache as well as prefilling: - * To initialize an empty cache, specify `time_step=None`. - In this case, `data` is allowed to be a TensorSpec. - * To prefill, provide `time_step` and `data` as Tensors. - - Args: - time_step: A Tensor of shape [batch]. Each value is an index into the length dimension - indicating where decoding will start from. - data: A Tensor of shape [batch, target_length, input_dim]. For batch index `i`, only - `data[i, :time_step[i], ...]` will affect subsequent decoding. 
- self_attention_kv_state: An optional KVState used for self-attention. - self_attention_logit_biases: An optional Tensor representing the self-attention biases. - cross_attention_data: An optional Tensor representing cross-attention data of shape - [batch, source_length, source_dim]. - cross_attention_logit_biases: An optional Tensor representing the cross-attention - biases. - - Returns: - A tuple (init_states, output): - * init_states: A nested tree of Tensors, which can be used as `cached_states` for the - initial call of `extend_step()`. - * output: In the prefill case, a BaseTransformerLayer.Output instance, where: - .data is of the same shape as `data`; - .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; - .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. - Otherwise, if initializing cache from scratch, output will be None. - """ - raise NotImplementedError(type(self)) + np.testing.assert_allclose(ref_alibi, test_alibi_segment, rtol=5e-07) - def extend_step( - self, - cached_states: NestedTensor, - data: Tensor, - *, - self_attention_kv_state: Optional[KVState] = None, - self_attention_logit_biases: Optional[Tensor] = None, - cross_attention_data: Optional[Tensor] = None, - cross_attention_logit_biases: Optional[Tensor] = None, - ) -> tuple[NestedTensor, Output]: - """Computes incremental outputs. - - Args: - cached_states: A NestedTensor returned by `init_states()` or a previous invocation of - `extend_step()`. - data: A Tensor of shape [target_batch_size, target_step_length, input_dim], where - `target_step_length` is usually 1. For self-attention, `data` will be used as the - `query` sequence and will be appended to key and value sequences. - self_attention_kv_state: An optional KVState used for self-attention. - self_attention_logit_biases: An optional Tensor of shape - [..., target_step_length, target_max_len], where `target_step_length` must match - the shape of `data` and `target_max_len` must match the value given for - `init_states()`. - cross_attention_data: An optional Tensor of shape [..., source_length, source_dim]. - cross_attention_logit_biases: An optional Tensor of shape - [..., target_step_length, source_length], where `target_step_length` must match - the shape of `data`. - - Returns: - (updated_cached_states, output), where: - `updated_cached_states` represents the new cached states incorporating `data`; - `output` represents the output for the given input data. `output.data` will have the - same shape as the input data. - """ - raise NotImplementedError(type(self)) +class SymmetricALiBiAttentionLogitBiasLayerTest(TestCase): + """Tests SymmetricALiBiAttentionLogitBiasLayer.""" -class LearnedPositionalEmbedding(BaseLayer): - """TODO(ruoming): Remove LearnedPositionalEmbedding. We can just use the Embedding layer.""" + def test_alibi_attention_mask(self): + num_heads = 8 + batch_size = 2 + max_len = 3 - @config_class - class Config(BaseLayer.Config): - """Configures LearnedPositionalEmbedding.""" - - dim: Required[int] = REQUIRED # Input feature dim. - shape: Required[Sequence[int]] = REQUIRED # The sequence shape. - - # Similar initialization code for Embedding. - # pylint: disable=duplicate-code - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, "model") - # By default, initialize to Gaussian with std=1/sqrt(dim), e.g., 0.036 when dim=768. 
- # - # This is the same as: - # https://github.com/pytorch/fairseq/blob/master/fairseq/modules/positional_embedding.py#L26 - # - # BERT uses std=0.02 regardless of dim: - # https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/modeling.py#L492-L495 - cfg.param_init = DefaultInitializer.default_config().set( - init_by_param_name={ - PARAM_REGEXP_WEIGHT: WeightInitializer.default_config().set( - fan="fan_out", distribution="normal" - ) - } + # Test alibi implementation. + alibi_attention_mask_layer = ( + attention.SymmetricALiBiAttentionLogitBiasLayer.default_config() + .set(name="test_symmetric_alibi_attention_mask", num_heads=num_heads) + .instantiate(parent=None) ) - return cfg - - # pylint: enable=duplicate-code - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - return dict( - weight=ParameterSpec( - shape=[1] + list(cfg.shape) + [cfg.dim], - mesh_axes=cfg.param_partition_spec, - ) + # [num_heads] + slopes = jnp.array(attention.alibi_get_slopes(num_heads)) + + # [max_len, max_len] + base_alibi_mask = jnp.array( + [ + [0, -1, -2], + [-1, 0, -1], + [-2, -1, 0], + ], + dtype=jnp.float32, ) - def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: - if not name.endswith("weight"): - return None - if len(parameter_spec.shape) != 3: - raise NotImplementedError( - "_compute_fan_axes requires weight parameters to have exactly 3 axes " - f"shape({name}) = {parameter_spec.shape}" - ) - return FanAxes(batch_axis=0, in_axis=1, out_axis=2) + # [heads, max_len, max_len] + expected_logits_bias = slopes[:, jnp.newaxis, jnp.newaxis] * base_alibi_mask + # [batch, heads, max_len, max_len] + expected_logits_bias = expected_logits_bias[jnp.newaxis, ...].repeat(batch_size, axis=0) - def embeddings(self) -> Tensor: - """Returns weights of shape cfg.shape + [dim].""" - return self.parameters["weight"].squeeze(0) + segment_ids = jnp.ones((batch_size, max_len)) + positions = jnp.arange(max_len)[None, :] + actual_logits_bias = alibi_attention_mask_layer( + segment_ids=segment_ids, positions=positions + ) - def forward(self, positions: Tensor) -> Tensor: - """ - Args: - positions: An integer tensor with arbitrary shape [...]. + assert_allclose(actual_logits_bias, expected_logits_bias) - Returns: - Embeddings with shape [..., *cfg.dim[1:], dim]. - """ - embeddings = self.embeddings() - return embeddings[positions] +class RoFormerSinusoidalPositionalEmbeddingTest(TestCase): + """Tests RoFormerSinusoidalPositionalEmbedding.""" -def sinusoidal_positional_embeddings( - positions: Tensor, *, dim: int, min_timescale: float = 1, max_timescale: float = 10000 -) -> Tensor: - """Sinusoidal positional embeddings. + @parameterized.product( + tensor_dimensions=( + (2, 3, 10, 32), + (2, 3, 8, 32), + (2, 4, 6, 32), + (2, 4, 8, 16), + (2, 5, 8, 48), + (2, 5, 8, 64), + ), + rotary_key=(True, False), + rotary_value=(True, False), + ) + def test_apply_rotary_position_embeddings( + self, tensor_dimensions: tuple[int, int, int, int], rotary_key: bool, rotary_value: bool + ): + # Unittest against the apply_rotary_position_embeddings in HF. 
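+        # Rotary embeddings rotate pairs of query/key channels by position-dependent angles;
+        # when `rotary_key` or `rotary_value` is False, the reference below is expected to
+        # leave the corresponding tensor unchanged.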
+ batch_size, num_heads, max_len, dim = tensor_dimensions + + token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len]) + sinusoidal_pos_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim) + sinusoidal_pos = sinusoidal_pos_layer(as_torch_tensor(token_ids).shape)[None, None, :, :] + query = np.random.random([batch_size, num_heads, max_len, dim]) + key = np.random.random([batch_size, num_heads, max_len, dim]) + value = np.random.random([batch_size, num_heads, max_len, dim]) + ref_layer = hf_roformer.RoFormerSelfAttention.apply_rotary_position_embeddings + test_layer = apply_rotary_position_embeddings + if rotary_value: + ref_q_proj, ref_k_proj, ref_v_proj = ref_layer( + sinusoidal_pos, + as_torch_tensor(query), + as_torch_tensor(key), + as_torch_tensor(value), + ) + else: + # If rotary_value is set to False, value keeps unchanged. + # pylint: disable-next=unbalanced-tuple-unpacking + ref_q_proj, ref_k_proj = ref_layer( + sinusoidal_pos, as_torch_tensor(query), as_torch_tensor(key) + ) + ref_v_proj = as_torch_tensor(value) + if not rotary_key: + ref_k_proj = as_torch_tensor(key) - Proposed in the original Transformer paper: https://arxiv.org/abs/1706.03762. + test_q_proj, test_k_proj, test_v_proj = test_layer( + sinusoidal_pos=as_tensor(sinusoidal_pos), + query=query, + key=key, + value=value, + rotary_key=rotary_key, + rotary_value=rotary_value, + ) + np.testing.assert_allclose(test_q_proj, ref_q_proj, atol=5e-7) + np.testing.assert_allclose(test_k_proj, ref_k_proj, atol=5e-7) + np.testing.assert_allclose(test_v_proj, ref_v_proj, atol=5e-7) + + @parameterized.parameters( + (2, 10, 32), + (2, 8, 32), + (2, 6, 32), + (2, 8, 16), + (2, 8, 48), + (2, 8, 64), + ) + def test_rope_emb(self, batch_size, max_len, dim): + # Token id is in the np format for easier transition. + token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len]) + positions = jnp.expand_dims(jnp.arange(token_ids.shape[-1], dtype=jnp.int32), 0) + ref_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim) + ref_output = ref_layer(as_torch_tensor(token_ids).shape) + # Set up the RoPE AXLearn configs. + test_layer = ( + attention.RoFormerSinusoidalPositionalEmbedding.default_config() + .set(name="test_rope_emb", dim=dim) + .instantiate(parent=None) + ) + test_output = test_layer.forward(positions=positions) + np.testing.assert_allclose(np.expand_dims(ref_output, 0), test_output, atol=5e-7) - Reference: - https://github.com/tensorflow/lingvo/blob/d2f1e1b3cccdac8f73ae20f86afb03560b1c176d/lingvo/core/layers.py#L2775-L2923 + @parameterized.parameters( + (None, True), + (10, False), + ) + def test_rope_emb_no_pos(self, max_len, should_raise): + test_layer = ( + attention.RoFormerSinusoidalPositionalEmbedding.default_config() + .set(name="test_rope_emb", dim=10) + .instantiate(parent=None) + ) + if should_raise: + with self.assertRaises(ValueError): + test_layer.forward(max_seq_len=max_len) + else: + test_layer.forward(max_seq_len=max_len) - The inputs to the sinusoid functions will be positions / timescale(k) - for dimension 0 <= k < num_timescales = dim // 2, where: - timescale(k) = geometric interpolation between min_timescale and max_timescale, i.e., - log(timescale(k) / min_timescale) / log(max_timescale / min_timescale) = - k / num_timescales. - Specifically: timescale(0) = min_timescale and timescale(num_timescales) = max_timescale. 
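+ # The test below exercises the new default-positions path end to end: running
+ # RoFormerQKVLinear without `query_positions` should match running it with the
+ # positions returned by `default_query_positions(max_len)` passed in explicitly.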
+ @parameterized.parameters( + (2, 10, 32, 4), + ) + def test_default_rope_emb(self, batch_size, max_len, dim, num_heads): + rng = np.random.default_rng(seed=123) + query = jnp.asarray(rng.random([batch_size, max_len, dim])) + key = jnp.asarray(rng.random([batch_size, max_len, dim])) + value = jnp.asarray(rng.random([batch_size, max_len, dim])) + per_head_dim = dim // num_heads + + emb_layer_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( + dim=per_head_dim, + ) + linear_layer_cfg = attention.RoFormerQKVLinear.default_config().set( + query_dim=dim, + key_dim=dim, + value_dim=dim, + num_heads=num_heads, + per_head_dim=per_head_dim, + rope_pos_emb_layer=emb_layer_cfg, + rotary_value=False, + name="test_rope_linear", + ) + rope_linear_layer = linear_layer_cfg.instantiate(parent=None) + state = rope_linear_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - Args: - positions: An integer tensor of any shape [...]. Each value represents an - absolute or relative position. - dim: the embedding dimension. Must be divisible by 2. - min_timescale: The minimum timescale (used for channel 0 and dim // 2). - max_timescale: The maximum timescale (used for channel dim // 2 - 1 and dim - 1). + rope_emb_layer = emb_layer_cfg.set(name="test_rope_emb").instantiate(parent=None) + default_positions = rope_emb_layer.default_query_positions(max_len) - Returns: - Embeddings of shape [..., dim]. + input_dict = dict(query=query, key=key, value=value) - Raises: - NotImplementedError: If dim is not divisible by 2. - """ - if dim % 2 != 0: - raise NotImplementedError(f"dim ({dim}) must be divisible by 2") - num_timescales = dim // 2 + layer_outputs_no_position, _ = F( + rope_linear_layer, + inputs=input_dict, + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + layer_outputs, _ = F( + rope_linear_layer, + inputs=dict(**input_dict, query_positions=default_positions), + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + # test RoFormerQKVLinear uses default positions in RoFormerSinusoidalPositionalEmbedding + np.testing.assert_allclose(layer_outputs_no_position, layer_outputs, atol=1e-5) - # To ensure results match other libraries, it is important to calculate - # log_timescale_increment using float64 calculations. This has no - # runtime cost. 
- log_timescale_increment = math.log(max_timescale / min_timescale) / max(1, num_timescales - 1) + def _compare_against_roformer_attention( + self, + ref, + layer, + tgt_len, + batch_size, + ref_rope_emb, + positions, + ): + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_param_shapes = jax.tree.map(lambda x: x.shape, layer_params) + print(f"layer state={layer_param_shapes}") + layer_params = parameters_from_torch_layer(ref) + model_dim, num_heads = layer.config.target_dim, layer.config.attention.num_heads + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) + null_mask = jnp.zeros([tgt_len, tgt_len]) + rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) + + for mask in (None, null_mask, rand_mask): + if mask is not None: + mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) + layer_outputs, _ = F( + layer, + inputs=dict( + target=jnp.asarray(target), + attention_logit_biases=mask, + target_positions=positions, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + attn_mask = None if mask is None else as_torch_tensor(mask) + print("ref_rope_emb", ref_rope_emb.shape) + print("target", target.shape) + (ref_outputs,) = ref.forward( + torch.as_tensor(target, dtype=torch.float32), + attention_mask=attn_mask, + sinusoidal_pos=ref_rope_emb, + output_attentions=False, + ) + assert_allclose(layer_outputs.data, as_tensor(ref_outputs)) + + @parameterized.product(rotary_value=[True, False], override_positions=[True, False]) + def test_rope_self_attention(self, rotary_value: bool, override_positions: bool): + model_dim = 32 + num_heads = 4 + max_sequence_length = 12 + batch_size = 2 + rope_mha_cfg = attention.MultiheadAttention.default_config().set( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + input_linear=RoFormerQKVLinear.default_config().set(rotary_value=rotary_value), + ) + rope_emb_layer = ( + attention.RoFormerSinusoidalPositionalEmbedding.default_config() + .set(name="test_rope_emb", dim=model_dim // num_heads) + .instantiate(parent=None) + ) + positions = ( + jax.random.randint( + jax.random.PRNGKey(0), + shape=(batch_size, max_sequence_length), + minval=0, + maxval=max_sequence_length, + ) + if override_positions + else jnp.expand_dims(jnp.arange(max_sequence_length), 0) + ) + ref_rope_emb = as_torch_tensor(rope_emb_layer.forward(positions=positions)).unsqueeze(1) + layer = attention.TransformerAttentionLayer.default_config().set( + source_dim=model_dim, + target_dim=model_dim, + name="rope_trans_attn", + attention=rope_mha_cfg, + structure="postnorm", + ) + layer = layer.instantiate(parent=None) + roformer_config = hf_roformer.RoFormerConfig( + hidden_size=model_dim, + num_attention_heads=num_heads, + attention_probs_dropout_prob=0, + hidden_dropout_prob=0, + rotary_value=rotary_value, + ) + print(f"roformer_config={roformer_config}") + ref = hf_roformer.RoFormerAttention(roformer_config) + self._compare_against_roformer_attention( + ref, + layer, + max_sequence_length, + batch_size, + ref_rope_emb, + positions if override_positions else None, + ) - # [num_timescales]. - inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales) * -log_timescale_increment) - # [..., num_timescales]. 
- scaled_time = jnp.expand_dims(positions, -1) * inv_timescales +def llama_reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """LLaMA reshape for broadcast function. - # [..., dim]. - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1) - return signal + Ref: + https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L55-L60 + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [ + d if i == 1 or i == ndim - 1 else 1 # pylint: disable=consider-using-in + for i, d in enumerate(x.shape) + ] + return freqs_cis.view(*shape) -class SinusoidalPositionalEmbedding(BaseLayer): - """Sinusoidal positional embeddings. +def llama_apply_rotary_emb( + *, + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """LLaMA apply rotary embeddings to input tensors using the given frequency tensor. - See sinusoidal_positional_embeddings()'s comments. + Ref: + https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L63-L73 """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = llama_reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) - @config_class - class Config(BaseLayer.Config): - """Configures SinusoidalPositionalEmbedding.""" - - dim: Required[int] = REQUIRED - min_timescale: float = 1 - max_timescale: float = 10000 - - def forward(self, positions: Tensor) -> Tensor: - """Looks up positional embeddings by positions.""" - cfg: SinusoidalPositionalEmbedding.Config = self.config - return sinusoidal_positional_embeddings( - positions, dim=cfg.dim, min_timescale=cfg.min_timescale, max_timescale=cfg.max_timescale - ) +class RefLLaMAAttention(torch.nn.Module): + """Reference Implementation of LLaMA-1. -class BaseMultiheadLinear(DenseGeneralBaseLayer): - """The linear layer used for multi-head attention. + Ref: + https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L76 - It uses einsum for efficient computation on TPU to avoid reshaping. + The modifications are removing the dependency of ColumnParallelLinear and RowParallelLinear. """ - @config_class - class Config(DenseGeneralBaseLayer.Config): - """Configures BaseMultiheadLinear.""" - - model_dim: Required[int] = REQUIRED # Feature dim. - num_heads: Required[int] = REQUIRED # Number of attention heads. - per_head_dim: Required[int] = REQUIRED # Dimension per head. - bias: bool = True # Whether the linear modules have biases. - - @classmethod - def default_config(cls) -> Config: - cfg = super().default_config() - # Shard the 'num_heads' axis by the 'model' dim of the mesh. 
- cfg.param_partition_spec = (None, "model", None) - return cfg + def __init__(self, n_heads: int, dim: int, max_batch_size: int, max_seq_len: int): + super().__init__() - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - params = dict( - weight=ParameterSpec( - shape=(cfg.model_dim, cfg.num_heads, cfg.per_head_dim), - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=("row", None, "col")), - ) + self.n_local_heads = n_heads + self.head_dim = dim // n_heads + + self.wq = torch.nn.Linear( + dim, + n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.Linear( + dim, + n_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.Linear( + dim, + n_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.Linear( + n_heads * self.head_dim, + dim, + bias=False, ) - if cfg.bias: - params["bias"] = self._bias_spec - return params - @property - def _einsum_expr(self): - raise NotImplementedError(type(self)) + self.cache_k = torch.zeros((max_batch_size, max_seq_len, self.n_local_heads, self.head_dim)) + self.cache_v = torch.zeros((max_batch_size, max_seq_len, self.n_local_heads, self.head_dim)) - def forward(self, inputs: Tensor) -> Tensor: - params = self.parameters - outputs = self.einsum_maybe_quantized( - self._einsum_expr, activation=inputs, kernel=params["weight"] + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + xq, xk = llama_apply_rotary_emb(xq=xq, xk=xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) + scores = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + return self.wo(output) + + +class RoFormerSinusoidalPositionalEmbeddingAgainstLLaMATest(TestCase): + def llama_ref_precompute_freqs_cis( + self, *, dim: int, end: int, theta: float = 10000.0 + ) -> torch.Tensor: + """Reference LLaMA-1 implementation. 
+ + Ref: + https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47-L52 + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + @parameterized.parameters([10000.0, 1000000.0]) + def test_against_llama_for_precompute_freqs_cis(self, theta: float): + max_len = 100 + dim = 32 + positions = jnp.expand_dims(jnp.arange(max_len), 0) + axlearn_rope_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( + dim=dim, + theta=theta, + ) + axlearn_rope_layer = axlearn_rope_cfg.set(name="rope").instantiate(parent=None) + axlearn_rope, _ = F( + axlearn_rope_layer, + inputs=dict(positions=positions), + state=axlearn_rope_layer.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(0) + ), + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + llama_rope = self.llama_ref_precompute_freqs_cis(dim=dim, end=max_len, theta=theta) + axlearn_imag, axlearn_real = jnp.split(axlearn_rope, 2, axis=-1) + llama_real, llama_imag = llama_rope.real, llama_rope.imag + # [0] is added, as axlearn_real and axlearn_imag has a batch_size=1 dimension. + assert_allclose(llama_real, as_tensor(axlearn_real)[0]) + assert_allclose(llama_imag, as_tensor(axlearn_imag)[0]) + + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + input_linear=( + None, + attention.QKVLinear.default_config(), + attention.GroupedQKVLinear.default_config(), + ), + has_query_positions=(True, False), + ) + def test_roformer_qkv_linear( + self, + dtype: jnp.dtype, + input_linear: attention.BaseQKVLinear.Config, + has_query_positions: bool, + ): + seq_len = 6 + batch_size = 2 + model_dim = 16 + num_heads = 4 + per_head_dim = model_dim // num_heads + roformer_qkv_linear_kwargs = { + "name": "roformer_qkv_linear", + "query_dim": model_dim, + "key_dim": model_dim, + "value_dim": model_dim, + "num_heads": num_heads, + "per_head_dim": per_head_dim, + "rotary_value": False, + } + num_kv_heads = num_heads + if input_linear is not None: + if isinstance(input_linear, attention.GroupedQKVLinear.Config): + num_kv_heads = num_heads // 2 + input_linear = input_linear.set(num_kv_heads=num_kv_heads) + roformer_qkv_linear_kwargs["input_linear"] = input_linear + + roformer_qkv_linear = ( + RoFormerQKVLinear.default_config() + .set(**roformer_qkv_linear_kwargs) + .instantiate(parent=None) ) - return outputs + params.get("bias", 0) - - def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: - raise NotImplementedError(type(self)) + # Check that we see the num kv heads is propagated from child input linear. 
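+ # (With a GroupedQKVLinear input linear the test configures num_heads // 2 KV heads,
+ # so a mismatch here would mean the child's num_kv_heads is not propagated.)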
+ self.assertEqual(roformer_qkv_linear.num_kv_heads, num_kv_heads) -class MultiheadInputLinear(BaseMultiheadLinear): - """Multi-head input linear layer.""" + query = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, model_dim)) + key = jax.random.uniform(jax.random.PRNGKey(2), shape=(batch_size, seq_len, model_dim)) + value = jax.random.uniform(jax.random.PRNGKey(3), shape=(batch_size, seq_len, model_dim)) + roformer_qkv_linear_state = roformer_qkv_linear.initialize_parameters_recursively( + jax.random.PRNGKey(0) + ) + input_batch = dict(query=query, key=key, value=value) + if has_query_positions: + input_batch["query_positions"] = jax.random.permutation( + jax.random.PRNGKey(1), + jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0), + axis=1, + independent=True, + ) - @property - def _einsum_expr(self): - return "btd,dnh->btnh" + layer_outputs, _ = F( + roformer_qkv_linear, + inputs=utils.cast_floats(input_batch, to_dtype=dtype), + state=utils.cast_floats(roformer_qkv_linear_state, to_dtype=dtype), + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + self.assertEqual(layer_outputs.query.dtype, dtype) + self.assertEqual(layer_outputs.key.dtype, dtype) + self.assertEqual(layer_outputs.value.dtype, dtype) + + def test_against_llama_for_apply_rotary_emb(self): + max_len = 100 + dim = 32 + batch_size = 4 + positions = jnp.expand_dims(jnp.arange(max_len), 0) + axlearn_rope_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( + dim=dim + ) + axlearn_rope_layer = axlearn_rope_cfg.set(name="rope").instantiate(parent=None) + axlearn_rope, _ = F( + axlearn_rope_layer, + inputs=dict(positions=positions), + state=axlearn_rope_layer.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(0) + ), + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + llama_rope = self.llama_ref_precompute_freqs_cis(dim=dim, end=max_len) + rng = np.random.default_rng(seed=123) + query = rng.random([batch_size, max_len, dim]) + key = rng.random([batch_size, max_len, dim]) + value = rng.random([batch_size, max_len, dim]) + llama_q, llama_k = llama_apply_rotary_emb( + xq=torch.Tensor(query), xk=torch.Tensor(key), freqs_cis=llama_rope + ) + axlearn_q, axlearn_k, _ = attention.apply_rotary_position_embeddings( + query=jnp.asarray(query), + key=jnp.asarray(key), + value=jnp.asarray(value), + sinusoidal_pos=axlearn_rope, + rotary_key=True, + rotary_value=False, + ) - @property - def _bias_spec(self): - cfg = self.config - return ParameterSpec( - shape=(cfg.num_heads, cfg.per_head_dim), - mesh_axes=cfg.param_partition_spec[-2:], + assert_allclose(as_tensor(llama_q.reshape(batch_size, max_len, -1)), axlearn_q, atol=5e-6) + assert_allclose(as_tensor(llama_k.reshape(batch_size, max_len, -1)), axlearn_k, atol=5e-6) + + def test_against_llama_for_attention(self): + max_len = 100 + dim = 32 + batch_size = 4 + n_heads = 4 + rng = np.random.default_rng(seed=123) + x = rng.random([batch_size, max_len, dim]) + ref_llama = RefLLaMAAttention( + n_heads=n_heads, dim=dim, max_batch_size=batch_size, max_seq_len=max_len + ) + llama_rope = self.llama_ref_precompute_freqs_cis(dim=dim // n_heads, end=max_len) + llama_output = ref_llama.forward(torch.Tensor(x), 0, llama_rope, mask=None) + + rope_mha_cfg = attention.MultiheadAttention.default_config().set( + query_dim=dim, + key_dim=dim, + value_dim=dim, + num_heads=n_heads, + input_linear=RoFormerQKVLinear.default_config().set( + rotary_value=False, + ), ) - # pylint: disable-next=no-self-use - def _compute_fan_axes(self, name: 
str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: - if name == "weight": - return FanAxes(in_axis=0, out_axis=(1, 2)) - else: - return None + rope_mha = rope_mha_cfg.set(name="rope").instantiate(parent=None) + state = rope_mha.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + state["i_proj"]["i_proj"]["q_proj"]["weight"] = jnp.asarray( + ref_llama.wq.weight.transpose(0, 1) + .reshape(dim, n_heads, dim // n_heads) + .detach() + .numpy() + ) + state["i_proj"]["i_proj"]["k_proj"]["weight"] = jnp.asarray( + ref_llama.wk.weight.transpose(0, 1) + .reshape(dim, n_heads, dim // n_heads) + .detach() + .numpy() + ) + state["i_proj"]["i_proj"]["v_proj"]["weight"] = jnp.asarray( + ref_llama.wv.weight.transpose(0, 1) + .reshape(dim, n_heads, dim // n_heads) + .detach() + .numpy() + ) + state["o_proj"]["weight"] = jnp.asarray( + ref_llama.wo.weight.reshape(dim, n_heads, dim // n_heads).detach().numpy() + ) -class MultiheadOutputLinear(BaseMultiheadLinear): - """Multi-head output linear layer.""" + axlearn_output, _ = F( + rope_mha, + inputs=dict(query=jnp.asarray(x)), + state=state, + is_training=False, + prng_key=jax.random.PRNGKey(0), + ) + assert_allclose( + as_tensor(llama_output.reshape(batch_size, max_len, -1)), axlearn_output.data + ) - @property - def _einsum_expr(self): - return "btnh,dnh->btd" - @property - def _bias_spec(self): - cfg = self.config - return ParameterSpec( - shape=(cfg.model_dim,), - mesh_axes=cfg.param_partition_spec[:1], +class MultiheadLinearInitTest(TestCase): + """Tests MultiheadLinear initialization.""" + + @parameterized.parameters( + ( + MultiheadInputLinear, + FanAxes(in_axis=0, out_axis=(1, 2)), + { + "fan_in": 4, + "fan_out": 8 * 6, + "fan_avg": (4 + 8 * 6) / 2, + }, + ), + ( + MultiheadOutputLinear, + FanAxes(in_axis=(1, 2), out_axis=0), + { + "fan_in": 8 * 6, + "fan_out": 4, + "fan_avg": (8 * 6 + 4) / 2, + }, + ), + ( + MultiheadRelativePositionLinear, + FanAxes(in_axis=0, out_axis=(1, 2)), + { + "fan_in": 4, + "fan_out": 8 * 6, + "fan_avg": (4 + 8 * 6) / 2, + }, + ), + ) + def test_compute_fan_axes(self, cls, fan_axes, fans): + for dist in ("uniform", "normal", "truncated_normal"): + for scale in (1.0, 2.0): + for fan_type in ("fan_in", "fan_out", "fan_avg"): + cfg = cls.default_config().set( + name="test", model_dim=4, num_heads=8, per_head_dim=6 + ) + cfg.param_init = DefaultInitializer.default_config().set( + init_by_param_name={ + PARAM_REGEXP_WEIGHT: WeightInitializer.default_config().set( + fan=fan_type, scale=scale, distribution=dist + ) + } + ) + layer: BaseLayer = cfg.instantiate(parent=None) + # pylint: disable-next=protected-access + param_spec_map = layer._create_layer_parameter_specs() + self.assertEqual( + # pylint: disable-next=protected-access + layer._compute_fan_axes("weight", param_spec_map["weight"]), + fan_axes, + ) + layer_params = layer.initialize_parameters_recursively(jax.random.PRNGKey(1)) + weight = layer_params["weight"] + self.assertEqual(weight.dtype, jnp.float32) + fan = fans[fan_type] + expected_std = scale / math.sqrt(fan) + actual_std = np.std(weight) + self.assertBetween(actual_std, expected_std / 1.5, expected_std * 1.5) + + +class QKVLinearTest(TestCase): + """Tests QKVLinear, FusedQKVLinear, and associated layers.""" + + @parameterized.product( + test_cls=[ + attention.FusedQKVLinear, + attention.GroupedQKVLinear, + attention.FusedGroupedQKVLinear, + ], + with_positions=[True, False], + ) + def test_qkv_equality(self, test_cls: type[attention.BaseQKVLinear], with_positions: bool): + """Tests that the 
QKVLinear variants are equivalent when num_kv_heads=num_heads.""" + with utils.numeric_checks(True): + model_dim = 12 + num_heads = 4 + per_head_dim = model_dim // num_heads + layer_kwargs = dict( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + per_head_dim=per_head_dim, + ) + base_cfg = QKVLinear.default_config().set(**layer_kwargs) + test_cfg = test_cls.default_config().set(**layer_kwargs) + maybe_set_config(test_cfg, num_kv_heads=num_heads) + base_layer = base_cfg.set(name="base").instantiate(parent=None) + test_layer = test_cfg.set(name="test").instantiate(parent=None) + + # Construct base layer state. + base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) + + # Map state to fused version. + if test_cls == attention.FusedQKVLinear: + weight = jnp.array( + [base_state[el]["weight"] for el in ("q_proj", "k_proj", "v_proj")] + ) + bias = jnp.array([base_state[el]["bias"] for el in ("q_proj", "k_proj", "v_proj")]) + test_state = {"qkv_proj": dict(weight=weight, bias=bias)} + elif test_cls == attention.FusedGroupedQKVLinear: + # Concatenate along the num_heads dim. + weight = jnp.concatenate( + [base_state[el]["weight"] for el in ("q_proj", "k_proj", "v_proj")], axis=1 + ) + bias = jnp.concatenate( + [base_state[el]["bias"] for el in ("q_proj", "k_proj", "v_proj")], axis=0 + ) + test_state = {"qkv_proj": dict(weight=weight, bias=bias)} + else: + test_state = base_state + + # Construct test inputs. + batch_size, src_len, tgt_len = 2, 6, 6 + query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) + key = jax.random.uniform(jax.random.PRNGKey(1), [batch_size, src_len, model_dim]) + value = jax.random.uniform(jax.random.PRNGKey(2), [batch_size, src_len, model_dim]) + + # In the fused GQA case, we assume query=key=value. + if test_cls == attention.FusedGroupedQKVLinear: + key = value = None + + positions = jnp.ones((1, tgt_len)) if with_positions else None + inputs = dict(query=query, key=key, value=value, query_positions=positions) + outputs = {} + layer_names = ("base", "test") + for name, layer, state in zip( + layer_names, (base_layer, test_layer), (base_state, test_state) + ): + outputs[name], _ = F( + layer, + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + ) + for layer_a, layer_b in combinations(layer_names, 2): + # Check that the outputs are close for all pairs. 
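+ # (All variants should agree here because the grouped configs are set up with
+ # num_kv_heads == num_heads.)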
+ self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) + + @parameterized.parameters( + dict(layer_cls=attention.QKVLinear, expected=4), + dict(layer_cls=attention.FusedQKVLinear, expected=4), + dict( + layer_cls=attention.QKVLinear, + num_kv_heads=2, + expected=UnknownFieldError("num_kv_heads"), + ), + dict( + layer_cls=attention.FusedQKVLinear, + num_kv_heads=2, + expected=UnknownFieldError("num_kv_heads"), + ), + dict( + layer_cls=attention.GroupedQKVLinear, + num_kv_heads=3, + expected=ValueError("should divide"), + ), + dict( + layer_cls=attention.FusedGroupedQKVLinear, + num_kv_heads=3, + expected=ValueError("should divide"), + ), + dict(layer_cls=attention.GroupedQKVLinear, num_kv_heads=2, expected=2), + dict(layer_cls=attention.FusedGroupedQKVLinear, num_kv_heads=2, expected=2), + ) + def test_num_kv_heads( + self, + layer_cls: type[attention.BaseQKVLinear], + expected: Union[int, Exception], + num_kv_heads: Optional[int] = None, + ): + model_dim = 12 + num_heads = 4 + per_head_dim = model_dim // num_heads + common_kwargs = dict( + query_dim=model_dim, key_dim=model_dim, value_dim=model_dim, per_head_dim=per_head_dim ) + cfg = layer_cls.default_config().set(name="test", num_heads=num_heads, **common_kwargs) - # pylint: disable-next=no-self-use - def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: - if name == "weight": - return FanAxes(in_axis=(1, 2), out_axis=0) + if isinstance(expected, Exception): + ctx = self.assertRaisesRegex(type(expected), str(expected)) else: - return None - - -def apply_attention_logit_biases( - logits: Tensor, attention_logit_biases: Optional[Tensor] = None -) -> Tensor: - """Applies `attention_logit_biases` on `logits`. + ctx = contextlib.nullcontext() + + with ctx: + if num_kv_heads is not None: + cfg.set(num_kv_heads=num_kv_heads) + layer = cfg.instantiate(parent=None) + self.assertEqual(expected, layer.num_kv_heads) + + @parameterized.parameters( + (QKVLinear.default_config(), QLinear.default_config()), + ( + RoFormerQKVLinear.default_config().set( + input_linear=QKVLinear.default_config(), rotary_value=False + ), + RoFormerQKVLinear.default_config().set( + input_linear=QLinear.default_config(), rotary_value=False + ), + ), + ) + def test_qlinear(self, base_cfg, test_cfg): + """Tests that QLinear is equivalent to QKVLinear with the same kv_state.""" + with utils.numeric_checks(True): + model_dim = 12 + num_heads = 3 + per_head_dim = model_dim // num_heads + layer_kwargs = dict( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + per_head_dim=per_head_dim, + ) + base_cfg = base_cfg.set(**layer_kwargs) + test_cfg = test_cfg.set(**layer_kwargs) + maybe_set_config(test_cfg, num_kv_heads=num_heads) + base_layer = base_cfg.set(name="base").instantiate(parent=None) + test_layer = test_cfg.set(name="test").instantiate(parent=None) + + # Construct base layer state. + base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) + # Map state to QLinear. + if "q_proj" in base_state: + test_state = {"q_proj": base_state["q_proj"]} + elif "i_proj" in base_state: + test_state = {"i_proj": {"q_proj": base_state["i_proj"]["q_proj"]}} + else: + raise ValueError("Cannot find expected q_proj state.") + + # Construct test inputs. 
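+ # The same query/key/value are fed to both layers; after the base pass, the
+ # projected key/value outputs are wrapped in a KVState and reused for the
+ # QLinear-based pass below.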
+ batch_size, src_len, tgt_len = 2, 6, 6 + query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) + key = jax.random.uniform(jax.random.PRNGKey(1), [batch_size, src_len, model_dim]) + value = jax.random.uniform(jax.random.PRNGKey(2), [batch_size, src_len, model_dim]) + + outputs = {} + layer_names = ("base", "test") + kv_kwargs = {"key": key, "value": value} + for name, layer, state in zip( + layer_names, (base_layer, test_layer), (base_state, test_state) + ): + outputs[name], _ = F( + layer, + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=dict(query=query, **kv_kwargs), + ) + if name == "base": + kv_kwargs = { + "kv_state": KVState(k_proj=outputs[name].key, v_proj=outputs[name].value) + } + for layer_a, layer_b in combinations(layer_names, 2): + # Check that the outputs are close for all pairs. + self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) + + @parameterized.parameters( + (attention.QKVLinear, 1), + (attention.FusedQKVLinear, 1), + (attention.GroupedQKVLinear, 1), + (attention.FusedGroupedQKVLinear, 1), + (attention.RoFormerQKVLinear, 1), + (attention.QKVLinear, 2), + (attention.FusedQKVLinear, 3), + (attention.GroupedQKVLinear, 4), + (attention.FusedGroupedQKVLinear, 3), + (attention.RoFormerQKVLinear, 2), + ) + def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], extend_step_len): + """Tests that calling QKVLinear.extend_step() multiple times with the + same time_step results in the same output.""" + model_dim = 8 + num_heads = 2 + per_head_dim = model_dim // num_heads + layer_kwargs = dict( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + per_head_dim=per_head_dim, + ) + cfg = layer_cls.default_config().set(**layer_kwargs) + maybe_set_config(cfg, num_kv_heads=num_heads, rotary_value=False) + layer = cfg.set(name="test").instantiate(parent=None) + + # Construct base layer state. + layer_state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) + + # Construct test inputs. + batch_size, tgt_len = 2, 4 + query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) + + fwd_output, _ = F( + layer, + state=layer_state, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict(query=query), + ) - Args: - logits: A float Tensor. - attention_logit_biases: A float Tensor. If None, assume all zeros. + cache_state, init_output = layer.init_states( + time_step=None, query=TensorSpec([batch_size, tgt_len]) + ) + self.assertIsNone(init_output) + step_querys = [] + step_keys = step_values = None + for t in range(0, tgt_len, extend_step_len): + (cache_state, step_output), _ = F( + layer, + state=layer_state, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict(cached_states=cache_state, query=query[:, t : t + extend_step_len]), + method="extend_step", + ) + step_querys.append(step_output.query) + step_keys = step_output.key + step_values = step_output.value - Returns: - logits + attention_logit_biases, in logits.dtype. 
- """ - if attention_logit_biases is None: - return logits - return logits + attention_logit_biases.astype(logits.dtype) + self.assertNestedAllClose(fwd_output.query, jnp.concat(step_querys, axis=1)) + self.assertNestedAllClose(fwd_output.key, step_keys) + self.assertNestedAllClose(fwd_output.value, step_values) + @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) + def test_dtypes_inherited_from_parent(self, dtype: jnp.dtype): + """Test that the dtype is inherited from the parent. -def softmax_with_biases(logits: Tensor, attention_logit_biases: Optional[Tensor] = None) -> Tensor: - """Computes softmax with optional masking. + When neither `Config.cache_dtype` nor `BaseLayer.Config.dtype` are set the dtype should + be inherited from the parent, and the dtype should be preserved in values in the + cached states and outputs. + """ - Args: - logits: A Tensor of any shape. - attention_logit_biases: A Tensor that is broadcastable with logits. - See ``On attention logit biases`` in the file comments. + target_batch_size = 3 + target_max_len = 6 + model_dim = 12 + num_heads = 4 + per_head_dim = model_dim // num_heads + layer_kwargs = dict( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + per_head_dim=per_head_dim, + ) - Returns: - A Tensor of same shape and dtype as logits. - """ - check_numerics(logits) - logits = apply_attention_logit_biases(logits, attention_logit_biases) - logits_dtype = logits.dtype - if logits_dtype in (jnp.bfloat16, jnp.float16): - # Avoid computing softmax in 16-bit floats. - logits = logits.astype(jnp.float32) - probs = jax.nn.softmax(logits, axis=-1) - if probs.dtype != logits_dtype: - probs = probs.astype(logits_dtype) - check_numerics(probs) - return probs - - -def sigmoid_with_biases( - logits: Tensor, - attention_logit_biases: Optional[Tensor] = None, -) -> Tensor: - """Computes sigmoid with optional masking. + class Parent(BaseLayer): + @config_class + class Config(BaseLayer.Config): + qkv_linear: InstantiableConfig = QKVLinear.default_config().set(**layer_kwargs) + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + self._add_child("qkv_linear", cfg.qkv_linear) + + parent_cfg = Parent.default_config().set(name="parent", dtype=dtype) + # Test assumes that dtype is not set in test_cfg. + self.assertIs(parent_cfg.qkv_linear.dtype, None) + parent = parent_cfg.instantiate(parent=None) + qkv_linear = parent.qkv_linear + state = qkv_linear.initialize_parameters_recursively(jax.random.PRNGKey(0)) + + # Check dtypes from init_states. + (cache, init_output), _ = F( + qkv_linear, + prng_key=jax.random.PRNGKey(0), + state=state, + inputs=dict( + time_step=None, + query=TensorSpec([target_batch_size, target_max_len]), + ), + method="init_states", + is_training=False, + ) + self.assertIsNone(init_output) + self.assertEqual(cache["key"].dtype, dtype) + self.assertEqual(cache["value"].dtype, dtype) + + query = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(target_batch_size, target_max_len, model_dim), + dtype=dtype, + ) + # Time step in the middle, so that some of the init_state is masked. 
+ time_step = jnp.full( + shape=target_batch_size, + fill_value=target_max_len // 2, + dtype=jnp.int32, + ) + (init_state, output), _ = F( + qkv_linear, + prng_key=jax.random.PRNGKey(0), + state=state, + inputs=dict(time_step=time_step, query=query), + method="init_states", + is_training=False, + ) + self.assertEqual(init_state["key"].dtype, dtype) + self.assertEqual(init_state["value"].dtype, dtype) + self.assertEqual(output.query.dtype, dtype) + self.assertEqual(output.key.dtype, dtype) + self.assertEqual(output.value.dtype, dtype) + + +class PerDimScaleTest(TestCase): + """Tests PerDimScale.""" + + @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) + def test_per_dim_scale(self, dtype: jnp.dtype): + batch_size, tgt_len, num_head, model_dim = 3, 5, 2, 8 + per_head_dim = model_dim // num_head + layer: PerDimScale = ( + PerDimScale.default_config() + .set( + name="test", + dim=per_head_dim, + ) # We do not set layer dtype. + .instantiate(parent=None) + ) + state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + query = jax.random.normal( + jax.random.PRNGKey(456), [batch_size, tgt_len, num_head, per_head_dim], dtype=dtype + ) + self.assertEqual(dict(param=(per_head_dim,)), shapes(state)) + outputs, _ = F( + layer, + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=(query,), + ) + expected_outputs = query + assert_allclose(outputs, expected_outputs) + self.assertEqual(outputs.dtype, query.dtype) - Args: - logits: A Tensor of any shape. - attention_logit_biases: A Tensor that is broadcastable with logits. - See ``On attention logit biases`` in the file comments. - Returns: - A Tensor of same shape and dtype as logits. - """ - check_numerics(logits) - logits = apply_attention_logit_biases(logits, attention_logit_biases) - # Avoid computing sigmoid in 16-bit floats. - logits_dtype = logits.dtype - if logits_dtype in (jnp.bfloat16, jnp.float16): - logits = logits.astype(jnp.float32) - probs = jax.nn.sigmoid(logits) - check_numerics(probs) - return probs - - -class BaseQKVLinear(BaseLayer): - """A layer that encapsulates mapping input queries, keys, and values to - multi-headed output queries, keys, and values. - """ +class ScaleQueryTest(TestCase): + """Tests ScaleQuery.""" - @config_class - class Config(BaseLayer.Config): - """Configures BaseQKVLinear.""" - - # Input query feature dim. - query_dim: Required[int] = REQUIRED - # Input key feature dim. - key_dim: Required[int] = REQUIRED - # Input value feature dim. - value_dim: Required[int] = REQUIRED - # Number of attention heads. - num_heads: Required[int] = REQUIRED - # Dimension of each attention head. - per_head_dim: Required[int] = REQUIRED - # Autoregressive cache dtype. Should match the step dtype. - # Needs to match the forward dtype for Repeated layers. If None, infer as BaseLayer.dtype(). - cache_dtype: Optional[jnp.dtype] = None - - class Output(NamedTuple): - # [batch, target_length, num_heads, per_head_dim]. - query: Tensor - # [batch, source_length, num_heads, per_head_dim]. - key: Tensor - # [batch, source_length, num_heads, per_head_dim]. 
- value: Tensor - - @property - def num_kv_heads(self): - return self.config.num_heads - - def init_states( + @parameterized.product( + scale_factor=[None, 7], + norm=[None, RMSNorm.default_config()], + per_dim_scale=[ + None, + PerDimScale.default_config(), + ], + ) + def test_scale_query( self, *, - time_step: Optional[Tensor], - query: Union[Tensor, TensorSpec], - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - ) -> tuple[Nested[Tensor], Optional[Output]]: - """Initializes cache for autoregressive cached decoding. - - The method supports initializing an empty cache as well as prefilling: - * To initialize an empty cache, specify `time_step=None`. - In this case, `query` is allowed to be a TensorSpec. - * To prefill, provide `time_step` and `query` as Tensors. - - Args: - time_step: An optional Tensor of shape [batch]. Each value is an index into the length - dimension indicating where decoding will start from. - query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding - to query vector at `time_step` indices. - For batch index `i`, only `query[i, :time_step[i], ...]` will affect subsequent - decoding. - key: An optional Tensor of shape [batch, source_length, source_dim]. - If None, will use `query`. - value: An optional Tensor of shape [batch, source_length, source_dim]. - If None, will use `query`. - kv_state: An optional KVState. If not None, both key and value must be None. - - Returns: - A tuple (init_states, output): - * init_states: A Nested Tensor state of `key`, `value` of shape - [batch, num_heads, per_head_dim, source_length], and `time_step` of shape [batch]. - * output: In the prefill case, an Output instance, where query is of size - [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim - [batch, source_length, num_heads, per_head_dim]. - Otherwise, if initializing cache from scratch, output will be None. - - Raises: - ValueError: If key/value and kv_state are an invalid combination. - ValueError: If query and time_step are an invalid combination. - """ - cfg: BaseQKVLinear.Config = self.config - # Default to base layer dtype for initialization if cache_dtype is None. - dtype = cfg.cache_dtype or self.dtype() - assert dtype is not None - - if kv_state is not None and (key is not None or value is not None): - raise ValueError("kv_state should not be specified together with key/value.") - if time_step is not None and isinstance(query, TensorSpec): - raise ValueError("query must be a Tensor if time_step is provided.") - - output = None - # Always initialize to all 0's; if `time_step` is provided, we invoke `extend_step` below - # which updates the cache with the new `time_step`. - init_state = dict(time_step=jnp.zeros(query.shape[0], dtype=jnp.int32)) - - # If `kv_state` is provided externally, we do not have to maintain key/value in cache. - # Otherwise, initialize the cache from provided query, key, value. 
- if kv_state is None: + scale_factor: Optional[float], + norm: Optional[RMSNorm.Config], + per_dim_scale: Optional[PerDimScale.Config], + ): + kwargs = self._scale_kwargs( + scale_factor=scale_factor, norm=norm, per_dim_scale=per_dim_scale + ) + forward_outputs, _ = F(**kwargs) + + self.assertEqual(forward_outputs.shape, kwargs["inputs"]["proj"].shape) + q_proj_scaled = kwargs["inputs"]["proj"] + if norm is not None: + assert isinstance(norm, RMSNorm.Config) + moment2 = (q_proj_scaled * q_proj_scaled).mean(axis=-1, keepdims=True) + q_proj_scaled = q_proj_scaled * jax.lax.rsqrt(moment2 + norm.eps) + if per_dim_scale is not None: + assert isinstance(per_dim_scale, PerDimScale.Config) + # We overrode the initializer for PerDimScale so we can measure the effect. + q_proj_scaled = q_proj_scaled * jax.nn.softplus(1.0) * 1.442695041 + + if scale_factor is None: + scale_factor = kwargs["module"].config.per_head_dim ** -0.5 + scale_factor = float(scale_factor) + q_proj_scaled = q_proj_scaled * scale_factor + + self.assertNestedAllClose(forward_outputs, q_proj_scaled) + + def _scale_kwargs( + self, + *, + scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]], + norm: Optional[InstantiableConfig], + per_dim_scale: Optional[PerDimScale.Config], + ): + model_dim = 16 + if isinstance(scale_factor, (int, float)): + scale_factor = config_for_function(attention.constant_scale_fn).set(value=scale_factor) + + num_heads = 2 + per_head_dim = model_dim // num_heads + if per_dim_scale is not None: + per_dim_scale = per_dim_scale.set(dim=per_head_dim) + + cfg = attention.ScaleQuery.default_config().set( + name="test", + per_head_dim=per_head_dim, + norm=norm, + scale_factor=scale_factor, + per_dim_scale=per_dim_scale, + ) + layer = cfg.instantiate(parent=None) - def maybe_initialize(kv: Optional[Tensor]): - # [batch, source/target_len, num_kv_heads, per_head_dim]. - if kv is None: - kv = jnp.zeros( - (*query.shape[:2], self.num_kv_heads, cfg.per_head_dim), dtype=dtype - ) - else: - kv = jnp.reshape(kv, (*kv.shape[:2], self.num_kv_heads, cfg.per_head_dim)) - return kv + param_specs = layer.create_parameter_specs_recursively() + layer_params = jax.tree.map( + lambda spec: jnp.ones(spec.shape, dtype=spec.dtype), param_specs + ) - init_state.update(key=maybe_initialize(key), value=maybe_initialize(value)) + batch_size = 3 + tgt_len = 10 + q_proj = jnp.concatenate( + ( + jnp.ones([batch_size, tgt_len // 2, num_heads, per_head_dim]), + jnp.zeros([batch_size, tgt_len // 2, num_heads, per_head_dim]), + ), + axis=1, + ) + kwargs = dict( + module=layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict(proj=q_proj), + ) + return kwargs - # If time_step is not provided, initialize an empty cache (i.e., all 0's). - # Otherwise, treat as prefill case and invoke `extend_step`. - if time_step is not None: - init_state, output = self.extend_step( - init_state, query, key=key, value=value, kv_state=kv_state - ) - # The time_step from `extend_step` includes full query length. 
- init_state["time_step"] = time_step - return init_state, output +class ScaleKeyTest(TestCase): + """Tests ScaleKey.""" - def forward( + @parameterized.product( + scale_factor=[None, 7], + norm=[None, RMSNorm.default_config()], + ) + def test_scale_key( self, - query: Tensor, *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - query_positions: Optional[Tensor] = None, - ) -> Output: - """Computes per-head query, key, and value for the input query, key, value. - - Args: - query: A Tensor of shape [batch, target_length, target_dim]. - key: an optional Tensor of shape [batch, source_length, source_dim]. - If None, will use `query`. - value: An optional Tensor of shape [batch, source_length, source_dim]. - If None, will use `query`. - kv_state: An optional KVState. If not None, both key and value must be None. - query_positions: An optional Tensor of shape [batch, target_length]. - - Returns: - An Output instance, where query is of size - [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim - [batch, source_length, num_heads, per_head_dim]. - """ - raise NotImplementedError(type(self)) - - def extend_step( + scale_factor: Optional[float], + norm: Optional[RMSNorm.Config], + ): + kwargs = self._scale_kwargs(scale_factor=scale_factor, norm=norm) + forward_outputs, _ = F(**kwargs) + + self.assertEqual(forward_outputs.shape, kwargs["inputs"]["proj"].shape) + q_proj_scaled = kwargs["inputs"]["proj"] + if norm is not None: + assert isinstance(norm, RMSNorm.Config) + moment2 = (q_proj_scaled * q_proj_scaled).mean(axis=-1, keepdims=True) + q_proj_scaled = q_proj_scaled * jax.lax.rsqrt(moment2 + norm.eps) + + if scale_factor is None: + scale_factor = 1.0 + scale_factor = float(scale_factor) + q_proj_scaled = q_proj_scaled * scale_factor + self.assertNestedAllClose(forward_outputs, q_proj_scaled) + + def _scale_kwargs( self, - cached_states: NestedTensor, - query: Tensor, *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - ) -> tuple[NestedTensor, Output]: - """Computes the value vector given the query of the current step. - This function is used by autoregressive decoding. - - Based on: - https://github.com/tensorflow/lingvo/blob/5754b2f840ebf0f8c52d87e5d4d76f22e372513e/lingvo/jax/layers/attentions.py#L1249 - - Args: - cached_states: A `NestedTensor` object containing tensors which are the results of - previous attentions, and index used for fast decoding. Contains "key" and "value" of - shape [batch, num_heads, per_head_dim, target_length], and a Tensor "time_step" of - shape [batch]. - query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting - at "time_step" indices. - key: An optional Tensor of shape [batch, source_length, source_dim]. If None, will use - `query`. - value: An optional Tensor of shape [batch, source_length, source_dim]. If None, will - use `query`. - kv_state: An optional KVState. If not None, both key and value must be None. - - Returns: - A `NestedTensor` state of key and value pair along with index updated at `time_step`. - An Output instance, where query is of size - [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim - [batch, source_length, num_heads, per_head_dim]. 
- """ - time_step = cached_states["time_step"] - assert time_step.ndim == 1 - - if kv_state is not None: - if key is not None or value is not None: - raise ValueError("kv_state should not be specified together with key/value") - kv_kwargs = dict(kv_state=kv_state) - else: - kv_kwargs = dict(key=key, value=value) + scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]], + norm: Optional[InstantiableConfig], + ): + model_dim = 16 + if isinstance(scale_factor, (int, float)): + scale_factor = config_for_function(attention.constant_scale_fn).set(value=scale_factor) + + num_heads = 2 + per_head_dim = model_dim // num_heads + + cfg = attention.ScaleKey.default_config().set( + name="test", + per_head_dim=per_head_dim, + norm=norm, + scale_factor=scale_factor, + ) + layer = cfg.instantiate(parent=None) - num_query_steps = query.shape[1] - query_positions = jnp.arange(num_query_steps)[None] - query_positions += time_step[:, None] + param_specs = layer.create_parameter_specs_recursively() + layer_params = jax.tree.map( + lambda spec: jnp.ones(spec.shape, dtype=spec.dtype), param_specs + ) - # Project inputs to key, value and query. Each has shape [B, steps, N, H]. - q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions) - updated_state = dict(time_step=time_step + num_query_steps) - if kv_state is None: - # Update the cache via dynamic slice. [B, S, N, H]. - cached_key = cached_states["key"] - cached_value = cached_states["value"] + batch_size = 4 + tgt_len = 12 + k_proj = jnp.concatenate( + ( + jnp.ones([batch_size, tgt_len // 2, num_heads, per_head_dim]), + jnp.zeros([batch_size, tgt_len // 2, num_heads, per_head_dim]), + ), + axis=1, + ) + kwargs = dict( + module=layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(123), + inputs=dict(proj=k_proj), + ) + return kwargs - # Ensure that we accumulate using the original dtype. - k_proj = k_proj.astype(cached_key.dtype) - v_proj = v_proj.astype(cached_value.dtype) - # TODO(dhwang2): jax.lax.dynamic_update_slice_in_dim is generally faster than advanced - # indexing, but an unusual slowdown was observed, with RLHF sampling taking up to - # 3 hours per run. Investigate and fix it. - # Note: All X_idx are small, so generating them on-demand is not costly. - b, _, n, h = cached_key.shape - b_idx = jnp.arange(b)[:, None, None, None] - t_idx = (jnp.arange(k_proj.shape[1])[None] + time_step[:, None])[:, :, None, None] - n_idx = jnp.arange(n)[None, None, :, None] - h_idx = jnp.arange(h)[None, None, None, :] - k_proj = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_proj) - v_proj = cached_value.at[b_idx, t_idx, n_idx, h_idx].set(v_proj) +def _convert_to_qkv_linear( + base_state: Nested[Tensor], *, input_linear_layer_class: type +) -> Nested[Tensor]: + """Converts the params of a MultiheadAttention layer - updated_state.update(key=k_proj, value=v_proj) - return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj) + ... 
to params of a MultiheadAttention layer with input_linear of the given type.""" + test_state = copy.deepcopy(base_state) + if issubclass( + input_linear_layer_class, (attention.FusedQKVLinear, attention.FusedGroupedQKVLinear) + ): -class QKVLinear(BaseQKVLinear): - """Maps input query, key, and value to multi-headed output query, key, and value.""" + def combine_qkv(param_name: str) -> Tensor: + qkv_params = [ + utils.get_recursively(base_state, f"i_proj/{proj}/{param_name}") + for proj in ("q_proj", "k_proj", "v_proj") + ] + if issubclass(input_linear_layer_class, attention.FusedQKVLinear): + return jnp.stack(qkv_params) + else: + return jnp.concatenate(qkv_params, axis=-2) - @config_class - class Config(BaseQKVLinear.Config): - """Configures QKVLinear.""" + qkv_proj = {"weight": combine_qkv("weight")} + if "bias" in base_state["i_proj"]["q_proj"]: + qkv_proj["bias"] = combine_qkv("bias") + test_state["i_proj"] = VDict({"qkv_proj": qkv_proj}) - # The layer used to project. - layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() + return test_state - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - for name, dim, num_heads in ( - ("q", cfg.query_dim, cfg.num_heads), - ("k", cfg.key_dim, self.num_kv_heads), - ("v", cfg.value_dim, self.num_kv_heads), - ): - proj_cfg = cfg.layer - proj_cfg.model_dim = dim - proj_cfg.num_heads = num_heads - proj_cfg.per_head_dim = cfg.per_head_dim - self._add_child(f"{name}_proj", proj_cfg) - def forward( - self, - query: Tensor, - *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[Tensor] = None, - query_positions: Optional[Tensor] = None, - ) -> BaseQKVLinear.Output: - """Computes attention for the given query, key, value. +class MultiheadAttentionTest(TestCase): + """Tests MultiheadAttention, GroupedQueryAttention, and associated layers.""" - If `key` or `value` are None, will use `query` in place. - - See parent class for full docstring. - """ - if kv_state is not None: - raise ValueError( - "QKVLinear computes key and value projections " - "and does not expect external `kv_state`." 
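+ # Checks that enabling `tensor_stats` on MultiheadAttention records summaries for
+ # the output projection (an "o_proj_outputs" entry is expected in the stats).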
+ def test_add_tensor_stats(self): + model_dim = 12 + num_heads = 4 + cfg = attention.MultiheadAttention.default_config().set( + name="attn", + query_dim=12, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + tensor_stats=DefaultTensorStats.default_config(), + ) + layer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + + batch_size, src_len, tgt_len = 2, 6, 6 + rng = np.random.default_rng(seed=123) + query = jnp.asarray(rng.random([batch_size, tgt_len, model_dim])) + key = jnp.asarray(rng.random([batch_size, src_len, model_dim])) + value = jnp.asarray(rng.random([batch_size, src_len, model_dim])) + attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF + x = dict(query=query, key=key, value=value, attention_logit_biases=attention_logit_biases) + _, output_collection = F( + layer, + inputs=x, + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + if "tensor_stats" in output_collection.summaries: + output_stats = output_collection.summaries["tensor_stats"] + else: + output_stats = {} + expected_stats = ["o_proj_outputs"] + for k in expected_stats: + assert k in output_stats + + def test_invalid_key_value_combinations_raise(self): + model_dim = 12 + num_heads = 4 + layer_kwargs = dict( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + ) + multihead_attention = ( + attention.MultiheadAttention.default_config() + .set(name="test_multihead_attention", **layer_kwargs) + .instantiate(parent=None) + ) + fused_multihead_attention = ( + attention.MultiheadAttention.default_config() + .set( + name="test_fused_multihead_attention", + input_linear=attention.FusedQKVLinear.default_config(), + **layer_kwargs, ) - del query_positions - - key = query if key is None else key - value = query if value is None else value - q_proj = self.q_proj(query) - k_proj = self.k_proj(key) - v_proj = self.v_proj(value) - return self.Output(query=q_proj, key=k_proj, value=v_proj) - - -class GroupedQKVLinear(QKVLinear): - """A variant of QKVLinear that supports configuring a different number of key, value - projections. - - Note that the number of key, value projections must evenly divide the number of query heads. - """ - - @config_class - class Config(QKVLinear.Config): - """Configures GroupedQKVLinear.""" - - # Number of heads for key, value projections. - # It is required that num_heads % num_kv_heads == 0. - num_kv_heads: Required[int] = REQUIRED - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - if cfg.num_heads % cfg.num_kv_heads != 0: - raise ValueError( - f"The number of query subgroups ({cfg.num_kv_heads}) should divide " - f"the number of query heads ({cfg.num_heads})." 
+ .instantiate(parent=None) + ) + rng = np.random.default_rng(seed=123) + inputs = jnp.asarray(rng.random([2, 6, model_dim])) + for layer in (multihead_attention, fused_multihead_attention): + for query, key, value in [(inputs, None, inputs), (inputs, inputs, None)]: + with self.assertRaisesRegex( + ValueError, "key and value must be both None or both set" + ): + layer.forward(query, key=key, value=value) + + @parameterized.parameters(None, PerDimScale.default_config()) + def test_input_linear_variants(self, per_dim_scale): + with utils.numeric_checks(True): + model_dim = 12 + num_heads = 4 + layer_kwargs = dict( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), ) - - @property - def num_kv_heads(self): - return self.config.num_kv_heads - - -class QLinear(BaseQKVLinear): - """Maps input query to multi-headed output query. Assumes external KVState.""" - - @config_class - class Config(BaseQKVLinear.Config): - """Configures QLinear.""" - - # The layer used to project. - layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - proj_cfg = cfg.layer - proj_cfg.model_dim = cfg.query_dim - proj_cfg.num_heads = cfg.num_heads - proj_cfg.per_head_dim = cfg.per_head_dim - self._add_child("q_proj", proj_cfg) - - def forward( - self, - query: Tensor, - *, - kv_state: KVState, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - query_positions: Optional[Tensor] = None, - ) -> BaseQKVLinear.Output: - """Computes projects for the given query. Uses {k,v}_proj from `kv_state`. - - See parent class for full docstring. - """ - if kv_state is None or key is not None or value is not None: - raise ValueError( - f"Only kv_state is expected: key={key}, value={value}, kv_state={kv_state}" + multihead_attention = ( + attention.MultiheadAttention.default_config() + .set(name="test_multihead_attention", **layer_kwargs) + .instantiate(parent=None) + ) + multihead_attention_state = multihead_attention.initialize_parameters_recursively( + jax.random.PRNGKey(0) + ) + fused_multihead_attention = ( + attention.MultiheadAttention.default_config() + .set( + name="test_fused_multihead_attention", + input_linear=attention.FusedQKVLinear.default_config(), + **layer_kwargs, + ) + .instantiate(parent=None) ) - q_proj = self.q_proj(query) - return self.Output(query=q_proj, key=kv_state.k_proj, value=kv_state.v_proj) - - -class FusedQKVLinear(BaseQKVLinear): - """Maps input query, key, and value to multi-headed query, key, and value using a fused weight. - - N.B. Only supports cases where query, key, and value all have the same shape. - """ - - @config_class - class Config(BaseQKVLinear.Config): - """Configures FusedQKVLinear.""" - # The layer used to project. - layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() + def fused_state_from(state): + output_state = {} + for k, v in state.items(): + if k == "i_proj": + weight = jnp.array( + [v[el]["weight"] for el in ("q_proj", "k_proj", "v_proj")] + ) + bias = jnp.array([v[el]["bias"] for el in ("q_proj", "k_proj", "v_proj")]) + output_state[k] = {"qkv_proj": dict(weight=weight, bias=bias)} + else: + output_state[k] = v + return output_state + + # Map state to fused version. 
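+        # `fused_state_from` stacks the separate q/k/v parameters along a new leading axis of
+        # size 3 (e.g. weight [model_dim, num_heads, per_head_dim] -> [3, model_dim, num_heads,
+        # per_head_dim]), matching the single qkv_proj parameter layout of FusedQKVLinear.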
+ fused_multihead_attention_state = fused_state_from(multihead_attention_state) + + batch_size, src_len, tgt_len = 2, 6, 6 + rng = np.random.default_rng(seed=123) + query = jnp.asarray(rng.random([batch_size, tgt_len, model_dim])) + key = jnp.asarray(rng.random([batch_size, src_len, model_dim])) + value = jnp.asarray(rng.random([batch_size, src_len, model_dim])) + attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF + inputs = dict( + query=query, key=key, value=value, attention_logit_biases=attention_logit_biases + ) - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - if not cfg.query_dim == cfg.key_dim == cfg.value_dim: - raise ValueError( - f"All projection dims must be equal for {type(self)}, saw: " - f"query:{cfg.query_dim}, key:{cfg.key_dim}, value:{cfg.value_dim}" + outputs = {} + layer_names = ("multihead_attention", "fused_multihead_attention") + for name, layer, state in zip( + layer_names, + (multihead_attention, fused_multihead_attention), + (multihead_attention_state, fused_multihead_attention_state), + ): + outputs[name], _ = F( + layer, + state=state, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + ) + layer_output_data = outputs[name].data + # No NaN. + self.assertTrue(jnp.all(jnp.isfinite(layer_output_data)), layer_output_data) + for layer_a, layer_b in combinations(layer_names, 2): + # Check that the outputs are close for all pairs. + self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) + + @parameterized.parameters(None, PerDimScale.default_config()) + def test_all_mask(self, per_dim_scale): + with utils.numeric_checks(True): + model_dim = 12 + num_heads = 4 + per_head_dim = model_dim // num_heads + cfg = attention.MultiheadAttention.default_config().set( + name="test", + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), ) - proj_cfg = cfg.layer - proj_cfg.model_dim = cfg.query_dim - proj_cfg.num_heads = cfg.num_heads - proj_cfg.per_head_dim = cfg.per_head_dim - self._add_child("qkv_proj", proj_cfg) - - def create_parameter_specs_recursively(self) -> NestedParameterSpec: - specs = VDict(**super().create_parameter_specs_recursively()) - - def transform_factorization_spec( - spec: Optional[FactorizationSpec], - ) -> Optional[FactorizationSpec]: - if spec is None: - return None - return FactorizationSpec(axes=[None] + list(spec.axes)) - - return jax.tree.map( - lambda spec: ParameterSpec( - dtype=spec.dtype, - shape=(3, *spec.shape), - mesh_axes=PartitionSpec(None, *spec.mesh_axes), - factorization=transform_factorization_spec(spec.factorization), - fan_axes=param_init.maybe_prepend_axis( - spec.fan_axes, axis_type=param_init.FanAxes.AxisType.BATCH_AXIS + layer: attention.MultiheadAttention = cfg.instantiate(parent=None) + self.assertContainsSubset( + dict( + dropout={}, + i_proj={ + **{ + proj: { + "weight": ParameterSpec( + dtype=layer.dtype(), + shape=(model_dim, num_heads, per_head_dim), + mesh_axes=PartitionSpec(None, "model", None), + factorization=FactorizationSpec(axes=("row", None, "col")), + ), + "bias": ParameterSpec( + dtype=layer.dtype(), + shape=(num_heads, per_head_dim), + mesh_axes=PartitionSpec("model", None), + factorization=None, + ), + } + for proj in ("q_proj", "k_proj", "v_proj") + }, + }, + o_proj={ + "bias": ParameterSpec( + dtype=layer.dtype(), + shape=(model_dim,), + mesh_axes=PartitionSpec( + None, + ), + 
factorization=None, + ), + "weight": ParameterSpec( + dtype=layer.dtype(), + shape=(model_dim, num_heads, per_head_dim), + mesh_axes=PartitionSpec(None, "model", None), + factorization=FactorizationSpec(axes=("row", None, "col")), + ), + }, ), - ), - specs, - ) - - def initialize_parameters_recursively( - self, prng_key: Tensor, *, prebuilt: Optional[Nested[Optional[ParameterSpec]]] = None - ) -> NestedTensor: - if self._use_prebuilt_params(prebuilt): - return jax.tree.map(lambda _: None, prebuilt) - - def init(prng_key_i): - return VDict(qkv_proj=self.qkv_proj.initialize_parameters_recursively(prng_key_i)) - - return jax.vmap(init)(split_prng_key(prng_key, 3).keys) - - def forward( - self, - query: Tensor, - *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - query_positions: Optional[Tensor] = None, - ) -> BaseQKVLinear.Output: - """Computes multi-head query, key, and value for the input query, key, value - using a fused weight. - - N.B. Only supports cases where query, key, and value all have the same shape if set. - - See parent class for full docstring. - - Raises: - ValueError: If key and value are not both set or both None; or if kv_state is not None. - """ - if kv_state is not None: - raise ValueError( - "FusedQKVLinear computes key and value projections " - "and does not expect external `kv_state`." + layer.create_parameter_specs_recursively(), ) - del query_positions - - with child_context("qkv_proj"): - params = self.qkv_proj.parameters - if key is None and value is None: - # Computing self attention. - # N.B. this branch (with just the query inputs) is required in - # order to get the best step time on TPU for self-attention. - inputs = query # [batch, target_length, target_dim]. - proj = self.qkv_proj.einsum_maybe_quantized( - "btd,pdnh->pbtnh", activation=inputs, kernel=params["weight"] - ) - elif key is not None and value is not None: - # Compute cross attention but with same target/source shapes. - assert ( - query.shape == key.shape == value.shape # pytype: disable=attribute-error - ), f"Not supported for {type(self)}." - inputs = jnp.stack( - [query, key, value], axis=0 - ) # [q/k/v, batch, target, model_dim]. - proj = self.qkv_proj.einsum_maybe_quantized( - "pbtd,pdnh->pbtnh", activation=inputs, kernel=params["weight"] - ) - else: - raise ValueError("Key and value should be either both None or both set.") - if self.qkv_proj.config.bias: - bias = jnp.expand_dims( - params.get("bias", jnp.array([0], dtype=query.dtype)), - (1, 2), - ) - proj = proj + bias - q_proj, k_proj, v_proj = proj - return self.Output(query=q_proj, key=k_proj, value=v_proj) - -class FusedGroupedQKVLinear(BaseQKVLinear): - """Maps input query, key, and value to multi-headed query, key, and value using a fused weight. - - The main difference from FusedQKVLinear is supporting a different number of key, value heads - than query heads. All of the projection weights are concatenated/fused along the `num_heads` - axis and then split after projection. - """ - - @config_class - class Config(BaseQKVLinear.Config): - """Configures FusedGroupedQKVLinear.""" - - # Number of heads for key, value projections. - # It is required that num_heads % num_kv_heads == 0. - num_kv_heads: Required[int] = REQUIRED - # The layer used to project. 
- layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - if not cfg.query_dim == cfg.key_dim == cfg.value_dim: - raise ValueError( - f"All projection dims must be equal for {type(self)}, saw: " - f"query:{cfg.query_dim}, key:{cfg.key_dim}, value:{cfg.value_dim}" + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + qkv_shapes = dict( + weight=(model_dim, num_heads, per_head_dim), bias=(num_heads, per_head_dim) ) - if cfg.num_heads % cfg.num_kv_heads != 0: - raise ValueError( - f"The number of query subgroups {cfg.num_kv_heads} should divide " - f"the number of query heads {cfg.num_heads}." + expected_scale_query_params = {} + if per_dim_scale: + expected_scale_query_params["per_dim_scale"] = dict(param=(per_head_dim,)) + expected_params = { + "i_proj": {f"{x}_proj": qkv_shapes for x in ("q", "k", "v")}, + "o_proj": dict(weight=(model_dim, num_heads, per_head_dim), bias=(model_dim,)), + "dropout": {}, + "scale_key": {}, + "scale_query": expected_scale_query_params, + } + self.assertEqual( + expected_params, + shapes(layer_params), ) - proj_cfg = cfg.layer - proj_cfg.model_dim = cfg.query_dim - proj_cfg.num_heads = cfg.num_heads + 2 * cfg.num_kv_heads - proj_cfg.per_head_dim = cfg.per_head_dim - self._add_child("qkv_proj", proj_cfg) - - @property - def num_kv_heads(self): - return self.config.num_kv_heads - def forward( - self, - query: Tensor, - *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[Tensor] = None, - query_positions: Optional[Tensor] = None, - ) -> FusedQKVLinear.Output: - """See FusedQKVLinear for full docstring. - - N.B. Only supports cases where key and value are both None. - """ - if kv_state is not None: - raise ValueError( - "FusedGroupedQKVLinear computes key and value projections " - "and does not expect external `kv_state`." + batch_size, src_len, tgt_len = 2, 4, 6 + rng = np.random.default_rng(seed=123) + query = jnp.asarray(rng.random([batch_size, tgt_len, model_dim])) + key = jnp.asarray(rng.random([batch_size, src_len, model_dim])) + value = jnp.asarray(rng.random([batch_size, src_len, model_dim])) + attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF + inputs = dict( + query=query, key=key, value=value, attention_logit_biases=attention_logit_biases ) - if key is not None or value is not None: - raise ValueError("Key and value should be both None.") - del query_positions - cfg = self.config - proj = self.qkv_proj(query) - q_proj, k_proj, v_proj = jnp.split( - proj, [cfg.num_heads, cfg.num_heads + cfg.num_kv_heads], axis=-2 - ) - return self.Output(query=q_proj, key=k_proj, value=v_proj) - - -def _rotary_sinusoidal_positional_embeddings( - *, positions: Tensor, dim: int, theta: float = 10000.0 -) -> Tensor: - """Generate the sin/cos positional embedding. - - Ref: - https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L76-L90 - - Args: - positions: A tensor representing the token position IDs with shape [batch_size, seq_len]. - dim: The dimensionality of the positional embedding. - theta: A parameter to scale the frequencies. - - Returns: - Rotary Positional Embedding with shape [batch_size, seq_len, dim]. 
- """ - if dim % 2 != 0: - raise ValueError(f"dim: {dim} should be a multiplier of 2.") - exponents = jnp.arange(dim).astype(jnp.float32) - pos_array = positions.astype(jnp.float32) - exponents = jnp.power(theta, 2 * (exponents // 2) / dim) - position_enc = jnp.expand_dims(pos_array, 2) / jnp.expand_dims(exponents, [0, 1]) - - rope_part_1 = jnp.sin(position_enc[:, :, 0::2]) - rope_part_2 = jnp.cos(position_enc[:, :, 1::2]) - rope = jnp.concatenate((rope_part_1, rope_part_2), axis=-1) - return rope - - -class RoFormerSinusoidalPositionalEmbedding(BaseLayer): - """Implementation of Rotary Position Embedding (RoPE). - - Ref: - https://github.com/huggingface/transformers/blob/62ceb4/src/transformers/models/roformer/modeling_roformer.py - """ - - @config_class - class Config(BaseLayer.Config): - """Configures RoFormerSinusoidalPositionalEmbedding.""" - - dim: Required[int] = REQUIRED # The dimensionality of the positional embedding. - theta: float = 10000.0 # The scale of base frequency. - - def default_query_positions(self, max_seq_len: int) -> Tensor: - """Compute default `positions` value to be inputed into forward when `positions` is - not provided to the corresponding QKVLinear class such as `RoFormerQKVLinear` - """ - return jnp.arange(max_seq_len)[None] # [batch_size=1, max_seq_len]. - - def forward( - self, positions: Optional[Tensor] = None, max_seq_len: Optional[int] = None - ) -> Tensor: - """ - TODO(bwzhang): 1. verify the performance under float32. - - Args: - positions: A tensor representing the token position IDs. - The shape is [batch_size, seq_len]. - max_seq_len: Max length of sequence, required if positions is not provided - - Returns: - Rotary Positional Embedding. Shape is [seq_len, dim]. - - Raises: - ValueError: If positions is None and max_seq_len is None. - """ - cfg = self.config - if positions is None: - if max_seq_len is None: - raise ValueError( - "Must provide `max_seq_len` for computing default query positions if " - "`positions` is None." - ) - positions = self.default_query_positions(max_seq_len) - return _rotary_sinusoidal_positional_embeddings( - positions=positions, dim=cfg.dim, theta=cfg.theta - ) - - -def apply_rotary_position_embeddings( - *, - query: Tensor, - key: Tensor, - value: Tensor, - sinusoidal_pos: Tensor, - rotary_key: bool, - rotary_value: bool, -) -> tuple[Tensor, Tensor, Tensor]: - """This is a jax implementation (a copy) of the RoPE apply_rotary_position_embeddings. - - Ref: - https://github.com/huggingface/transformers/blob/v4.21.2/src/transformers/models/roformer/modeling_roformer.py#L322-L346 - - Args: - query: Query embeddings with shape [batch_size, seq_len, num_heads, dim]. - key: Key embeddings with shape [batch_size, seq_len, num_heads, dim]. - value: Value embeddings with shape [batch_size, seq_len, num_heads, dim]. - sinusoidal_pos: Rotary position embeddings with shape [batch_size, seq_len, 1, dim]. - rotary_key: Whether to apply rotary position embeddings on key. - rotary_value: Whether to apply rotary position embeddings on value. + layer_outputs, _ = F( + layer, + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + ) + layer_output_data = layer_outputs.data + # No NaN. 
+ self.assertTrue(jnp.all(jnp.isfinite(layer_output_data)), layer_output_data) - Returns: - A tuple of: - Rotary position affined query embeddings with shape [batch_size, seq_len, num_heads, dim] - Rotary position affined key embeddings with shape [batch_size, seq_len, num_heads, dim] - Rotary position affined value embeddings with shape [batch_size, seq_len, num_heads, dim] - if rotary_value == True, else original value embeddings - """ - # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] - # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] - sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) - # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape) - # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape) - # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] - rotate_half_query = jnp.reshape( - jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape + @parameterized.product( + dtype=(jnp.float32, jnp.float16, jnp.bfloat16), + per_dim_scale=(None, PerDimScale.default_config()), ) - query = query * cos_pos + rotate_half_query * sin_pos - - if rotary_key: - # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key = jnp.reshape( - jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape - ) - key = key * cos_pos + rotate_half_key * sin_pos - if rotary_value: - # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] - rotate_half_value = jnp.reshape( - jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape + def test_data_types(self, dtype: jnp.dtype, per_dim_scale: Optional[PerDimScale.Config]): + model_dim = 16 + num_heads = 4 + cfg = attention.MultiheadAttention.default_config().set( + name="test", + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), ) - value = value * cos_pos + rotate_half_value * sin_pos - return query, key, value - - -class RoFormerQKVLinear(BaseQKVLinear): - """RoFormerQKVLinear class - - This class maps the query, key, and value using the RoPE embeddings. - """ + layer = cfg.instantiate(parent=None) - @config_class - class Config(BaseQKVLinear.Config): - """Configures RoFormerQKVLinear.""" + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - rope_pos_emb_layer: RoFormerSinusoidalPositionalEmbedding.Config = ( - RoFormerSinusoidalPositionalEmbedding.default_config() + batch_size, src_len, tgt_len = 2, 4, 6 + query = jnp.zeros([batch_size, tgt_len, model_dim], dtype=dtype) + key = jnp.zeros([batch_size, src_len, model_dim], dtype=dtype) + value = jnp.zeros([batch_size, src_len, model_dim], dtype=dtype) + attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF + inputs = dict( + query=query, key=key, value=value, attention_logit_biases=attention_logit_biases ) - input_linear: BaseQKVLinear.Config = QKVLinear.default_config() - # Whether to apply RoPE rotations to the value embeddings. 
- rotary_value: Required[bool] = REQUIRED - - def __init__(self, cfg: QKVLinear.Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - self._add_child( - "rope_pos_emb_layer", - cfg.rope_pos_emb_layer.set(dim=cfg.per_head_dim), - ) - self._add_child( - "i_proj", - cfg.input_linear.set( - query_dim=cfg.query_dim, - value_dim=cfg.value_dim, - key_dim=cfg.key_dim, - num_heads=cfg.num_heads, - per_head_dim=cfg.per_head_dim, - ), + layer_outputs, _ = F( + layer, + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, ) + self.assertEqual(layer_outputs.data.dtype, dtype) - @property - def num_kv_heads(self): - """Propagate num KV heads from input linear.""" - return self.i_proj.num_kv_heads - - def forward( + @parameterized.product( + base_cfg=( + attention.MultiheadAttention.default_config(), + attention.GroupedQueryAttention.default_config().set( + input_linear=attention.GroupedQKVLinear.default_config().set(num_kv_heads=2) + ), + attention.GroupedQueryAttention.default_config().set( + input_linear=attention.FusedGroupedQKVLinear.default_config().set(num_kv_heads=2) + ), + attention.GroupedQueryAttention.default_config().set( + input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False) + ), + attention.SigmoidAttention.default_config().set( + input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False), + seq_len=4, + ), + attention.SigmoidAttention.default_config().set( + # Used in ALiBi position encoding. + input_linear=FusedQKVLinear.default_config(), + seq_len=4, + ), + ), + attention_logit_biases_fn=( + lambda query_len, kv_len: None, + lambda query_len, kv_len: _random_mask(jax.random.PRNGKey(1), query_len, kv_len), + ), + kv_length_multiplier=(0.5, 1, 2), + has_query_positions=(False, True), + ) + def test_causal( self, - query: Tensor, - *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - query_positions: Optional[Tensor] = None, - ) -> BaseQKVLinear.Output: - cfg = self.config - # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim]. 
- query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state) - seq_len = query.shape[1] - sinusoidal_pos_emb = self.rope_pos_emb_layer.forward( - positions=query_positions, max_seq_len=seq_len - ).astype(query.dtype) - # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim] - sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2) - - i_proj_computes_kv = kv_state is None - query, key, value = apply_rotary_position_embeddings( - sinusoidal_pos=sinusoidal_pos_emb, - query=query, - key=key, - value=value, - rotary_key=i_proj_computes_kv, - rotary_value=i_proj_computes_kv and cfg.rotary_value, + base_cfg: attention.MultiheadAttention.Config, + attention_logit_biases_fn: Callable[[int, int], Tensor], + kv_length_multiplier: float, + has_query_positions: bool, + ): + """Tests that base_cfg(causal=True) is equivalent to applying a causal mask.""" + if ( + has_query_positions + and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config) + or kv_length_multiplier != 1 + and isinstance( + base_cfg.input_linear, + (FusedGroupedQKVLinear.Config, RoFormerQKVLinear.Config, FusedQKVLinear.Config), + ) + ): + pytest.skip(reason="Incompatible test setting that does not need testing.") + + model_dim = 16 + num_heads = 4 + ref_cfg = base_cfg.clone( + name="test", + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, ) + self.assertFalse(ref_cfg.causal) + ref_layer = ref_cfg.instantiate(parent=None) + layer_params = ref_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + + test_cfg = ref_cfg.clone(causal=True) + test_layer = test_cfg.instantiate(parent=None) + + batch_size, query_len = 2, 4 + query = jnp.zeros([batch_size, query_len, model_dim], dtype=jnp.float32) + outputs = [] + + if has_query_positions: + query_positions = jax.random.permutation( + jax.random.PRNGKey(1), + jnp.arange(query_len)[None, :].repeat(batch_size, axis=0), + axis=1, + independent=True, + ) - return self.Output(query, key, value) - - -class PerDimScale(BaseLayer): - """A layer to scale individual dimensions of the input.""" - - @config_class - class Config(BaseLayer.Config): - """Configures PerDimScale.""" - - dim: Required[int] = REQUIRED - - @classmethod - def default_config(cls) -> Config: - cfg: PerDimScale.Config = super().default_config() - cfg.param_init = ConstantInitializer.default_config().set(value=0.0) - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - return { - "param": ParameterSpec(shape=(cfg.dim,), mesh_axes=(None,)), - } - - def forward(self, x: Tensor) -> Tensor: - """Returns x * per_dim_scale.""" - cfg = self.config - assert x.shape[-1] == cfg.dim - # https://github.com/tensorflow/lingvo/blob/3d16483b749a1181330ae9ce318688e7518d63c9/lingvo/jax/layers/attentions.py#L232-L234 - # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number to avoid unnecessary - # XLA op fusion. - r_softplus_0 = 1.442695041 - scale = jax.nn.softplus(self.parameters["param"]) * r_softplus_0 - return (x * scale).astype(x.dtype) - - -ScaleFn = Callable[[int], float] # A function mapping per_head_dim to a scale. - - -def constant_scale_fn(value: float) -> ScaleFn: - """A constant scale function for `MultiheadAttention`. - - Example: - `key_scale = config_for_function(constant_scale_fn).set(value=0.01)` - - Args: - value: The value to scale by. - - Returns: - A `ScaleFn` that always returns `value`. 
- """ - - def constant_function(per_head_dim: int) -> float: - del per_head_dim - return value - - return constant_function - - -def pow_scale_fn(exp: float) -> ScaleFn: - """A scale function for `MultiheadAttention` that computes `per_head_dim ** exp`. - - Example: - `query_scale = config_for_function(pow_scale_fn).set(exp=-0.5)` - - Args: - exp: The exponent. - - Returns: - A `ScaleFn` that computes `per_head_dim ** exp`. - """ - - return functools.partial(pow, exp=exp) - - -class BaseScaleQK(BaseLayer): - """Defines the common interface for scaling projected attention queries or keys. - - * All subclasses must have `per_head_dim` in their config. - """ - - @config_class - class Config(BaseLayer.Config): - """Configures BaseScaleQK.""" - - # The per-head dimension. - per_head_dim: Required[int] = REQUIRED - - def forward(self, proj: Tensor) -> Tensor: - """Scales the projected queries or keys. - - Args: - proj: The projected queries/keys. - Shape: [batch, seq_length, num_heads, per_head_dim]. - - Returns: - A tensor with the same shape as the input. - """ - raise NotImplementedError(type(self)) - - -class ScaleQuery(BaseScaleQK): - """Default implementation for scaling projected queries.""" - - @config_class - class Config(BaseScaleQK.Config): - """Configures ScaleQuery.""" - - # The config for a normalization layer applied along the per-head dim. - # If None, no normalization is applied. - norm: Optional[InstantiableConfig] = None - # The config for a function to compute a query scale muliplier factor. - # If None, then self.default_scale_fn_config. - scale_factor: Optional[InstantiableConfig[ScaleFn]] = None - # A vector to apply per dimension scale to the query projection. - per_dim_scale: Optional[PerDimScale.Config] = None - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - self._scale_factor = self.default_scale_factor_config() - if cfg.scale_factor is not None: - self._scale_factor = cfg.scale_factor - self._scale_factor = self._scale_factor.instantiate() - if cfg.norm is not None: - self._add_child("norm", cfg.norm.set(input_dim=cfg.per_head_dim)) - if cfg.per_dim_scale: - self._add_child("per_dim_scale", cfg.per_dim_scale.set(dim=cfg.per_head_dim)) - - def apply_norm(self, proj: Tensor) -> Tensor: - """Applies the norm to projected queries if configured.""" - if "norm" in self.children: - proj = self.norm(proj) - return proj - - def apply_per_dim_scale(self, proj: Tensor) -> Tensor: - """Applies the per-dim scale to projected queries if configured.""" - if "per_dim_scale" in self.children: - # The Lingvo MultiheadAttention applies a per_dim_scale: - # https://github.com/tensorflow/lingvo/blob/41212226eac7a26491790c2bd476b78493f93ff6/lingvo/core/batch_major_attention.py#L790 - proj = self.per_dim_scale(proj) - return proj - - def apply_scale_factor(self, proj: Tensor) -> Tensor: - """Applies the scale-factor to projected queries.""" - scale = self._scale_factor(self.config.per_head_dim) - return proj * scale - - def forward(self, proj: Tensor) -> Tensor: - """Scales the projected queries.""" - proj = self.apply_norm(proj) - proj = self.apply_per_dim_scale(proj) - proj = self.apply_scale_factor(proj) - # Stop scale constant from being folded with others. - # May increase numerical stability. 
- return ops.forward_optimization_barrier(proj) - - @staticmethod - def default_scale_factor_config() -> InstantiableConfig[ScaleFn]: - """The config for the default function used to compute the query scale.""" - - return config_for_function(pow_scale_fn).set(exp=-0.5) - - -class ScaleKey(BaseScaleQK): - """Default implementation for scaling projected keys.""" - - @config_class - class Config(BaseScaleQK.Config): - """Configures ScaleKey.""" - - # The config for a normalization layer applied along the per-head dim. - # If None, no normalization is applied. - norm: Optional[InstantiableConfig] = None - # The config for a function to compute a key scale muliplier factor. - # If None, then self.default_scale_factor_config. - scale_factor: Optional[InstantiableConfig[ScaleFn]] = None - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - self._scale_factor = self.default_scale_factor_config() - if cfg.scale_factor is not None: - self._scale_factor = cfg.scale_factor - self._scale_factor = self._scale_factor.instantiate() - if cfg.norm is not None: - self._add_child("norm", cfg.norm.set(input_dim=cfg.per_head_dim)) - - def forward(self, proj: Tensor) -> Tensor: - """Scales the projected keys.""" - cfg = self.config - if cfg.norm is not None: - proj = self.norm(proj) - scale = self._scale_factor(cfg.per_head_dim) - proj = proj * scale - # Stop scale constant from being folded with others. - # May increase numerical stability. - return ops.forward_optimization_barrier(proj) - - @staticmethod - def default_scale_factor_config() -> InstantiableConfig[ScaleFn]: - """The config for the default function used to compute the key scale.""" - - return config_for_function(constant_scale_fn).set(value=1) - - -class MultiheadAttention(BaseLayer): - """A basic multi-head attention layer. - - Differences from torch.nn.MultiheadAttention: - - Use of einsum for efficient computation on TPU to avoid reshaping; - - Separate weights for {q,k,v}_proj for proper weight initialization that depends - on fan-out and efficient TPU execution (where split is not free). - """ - - @config_class - class Config(BaseLayer.Config): - """Configures MultiheadAttention.""" - - query_dim: Required[int] = REQUIRED # Input query feature dim. - key_dim: Required[int] = REQUIRED # Input key feature dim. - value_dim: Required[int] = REQUIRED # Input value feature dim. - output_dim: Optional[int] = None # Output feature dim. If None, use query_dim. - hidden_dim: Optional[int] = None # Hidden feature dim. If None, use query_dim. - # Number of attention heads. Must divide hidden_dim evenly. - num_heads: Required[int] = REQUIRED - # Config used to produce Q,K,V projections. - input_linear: BaseQKVLinear.Config = QKVLinear.default_config() - # Config used for the output projection. - output_linear: MultiheadOutputLinear.Config = MultiheadOutputLinear.default_config() - # The dropout layer. - dropout: Dropout.Config = Dropout.default_config() - # Config used to scale projected queries prior to computing logits. - query_scale: BaseScaleQK.Config = ScaleQuery.default_config() - # Config used to scale projected keys prior to computing logits. - key_scale: BaseScaleQK.Config = ScaleKey.default_config() - # Cap the absolute values of logits by tanh. Enabled by setting a positive value. - atten_logit_cap: Optional[float] = None - # A function to compute the boolean mask to apply when computing the attention - # where True means "attend" and False means "do not attend". 
- # Set to `causal_mask` for causal masking. - # When used with certain flash attention implementations, more efficient - # code paths may be used. (See the FlashAttention docstring for more details.) - # This field may not be specified if `causal` (deprecated) is specified. - # If `attention_logit_biases` argument is also specified, both masks are combined with AND. - mask: ConfigOr[Optional[MaskFn]] = None - # Deprecated. Use `mask=causal_mask` instead. - # If True, applies causal masking. `key` and `value` must be None. - # May not be specified if `mask` is already specified. - # If `attention_logit_biases` argument is also specified, both masks are combined with AND. - # TODO (apghml) Eliminate this field in favor of `mask`. - causal: Optional[bool] = None - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - if cfg.causal and cfg.mask is not None: - raise NotImplementedError("Cannot specify `causal` when using `mask`.") - if cfg.causal: - self._mask_fn = causal_mask - else: - self._mask_fn = maybe_instantiate(cfg.mask) - # Configure inputs to multi-headed QKV projection. - i_proj_cfg = cfg.input_linear - i_proj_cfg.query_dim = cfg.query_dim - i_proj_cfg.key_dim = cfg.key_dim - i_proj_cfg.value_dim = cfg.value_dim - i_proj_cfg.num_heads = cfg.num_heads - i_proj_cfg.per_head_dim = self.per_head_dim() - self._add_child("i_proj", i_proj_cfg) - # Configure output projection. - o_proj_cfg = cfg.output_linear - o_proj_cfg.model_dim = self.output_dim() - o_proj_cfg.num_heads = cfg.num_heads - o_proj_cfg.per_head_dim = self.per_head_dim() - self._add_child("o_proj", o_proj_cfg) - # Add dropout layer. - self._add_child("dropout", cfg.dropout) - # Add query scaling layer. - self._add_child("scale_query", cfg.query_scale.set(per_head_dim=self.per_head_dim())) - # Add key scaling layer. - self._add_child("scale_key", cfg.key_scale.set(per_head_dim=self.per_head_dim())) - - def output_dim(self): - cfg = self.config - return cfg.output_dim or cfg.query_dim - - def hidden_dim(self): - cfg = self.config - return cfg.hidden_dim or cfg.query_dim - - def per_head_dim(self): - cfg = self.config - hidden_dim = self.hidden_dim() - if hidden_dim % cfg.num_heads != 0: - raise ValueError(f"num_heads ({cfg.num_heads}) must divide hidden_dim ({hidden_dim})") - return hidden_dim // cfg.num_heads - - class Output(NamedTuple): - """Outputs of MultiheadAttention. - - Fields: - data: [batch, target_length, output_dim]. The attention output. Always present. - probs: [batch, num_heads, target_length, source_length]. The attention probabilities. - Populated if "probs" is in `return_aux`. - kv_state: The KV state used for computing the attention outputs. - Populated if "kv_state" is in `return_aux`. - """ - - data: Tensor - probs: Optional[Tensor] = None - kv_state: Optional[KVState] = None - - def _forward_for_mode( + for layer in (ref_layer, test_layer): + inputs = dict(query=query) + kv_len = int(kv_length_multiplier * query_len) + if kv_length_multiplier < 1: + inputs["key"] = query[:, :kv_len] + inputs["value"] = query[:, :kv_len] + elif kv_length_multiplier > 1: + inputs["key"] = jnp.tile(query, [1, int(kv_length_multiplier), 1]) + inputs["value"] = jnp.tile(query, [1, int(kv_length_multiplier), 1]) + + attention_logit_biases = attention_logit_biases_fn(inputs["query"].shape[1], kv_len) + if layer is ref_layer: + # Apply causal mask on top of the logit biases for `ref_layer`. 
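+                # `ref_layer` is configured with causal=False, so causal masking is emulated
+                # here via explicit logit biases; `test_layer` relies on causal=True instead,
+                # and both are expected to produce identical outputs.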
+ causal_biases = make_index_position_biases(inputs["query"].shape[1], kv_len=kv_len) + if attention_logit_biases is None: + attention_logit_biases = causal_biases + else: + attention_logit_biases = apply_attention_logit_biases( + attention_logit_biases, causal_biases + ) + inputs["attention_logit_biases"] = attention_logit_biases + if has_query_positions: + inputs["query_positions"] = query_positions + + layer_outputs, _ = F( + layer, + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + ) + outputs.append(layer_outputs) + # The outputs are equivalent. + self.assertNestedAllClose(outputs[0], outputs[1]) + + @parameterized.product( + base_cfg=( + attention.MultiheadAttention.default_config(), + attention.GroupedQueryAttention.default_config().set( + input_linear=attention.GroupedQKVLinear.default_config().set(num_kv_heads=2) + ), + attention.GroupedQueryAttention.default_config().set( + input_linear=attention.FusedGroupedQKVLinear.default_config().set(num_kv_heads=2) + ), + attention.GroupedQueryAttention.default_config().set( + input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False) + ), + attention.SigmoidAttention.default_config().set( + input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False), + seq_len=4, + ), + attention.SigmoidAttention.default_config().set( + # Used in ALiBi position encoding. + input_linear=FusedQKVLinear.default_config(), + seq_len=4, + ), + ), + attention_logit_biases_fn=( + lambda seq_len: None, + lambda seq_len: _random_mask(jax.random.PRNGKey(1), seq_len, seq_len), + ), + has_query_positions=(False, True), + ) + def test_sliding_window( self, - *, - mode: ForwardMode, - query: Union[Tensor, TensorSpec], - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - attention_logit_biases: Union[None, Tensor, BaseAttentionBias] = None, - segment_ids: Optional[Tensor] = None, - query_positions: Optional[Tensor] = None, - cached_states: Optional[NestedTensor] = None, - return_aux: Optional[set[str]] = None, - ) -> tuple[Nested[Tensor], Optional[Output]]: - """Computes attention for the given query, key, value, and attention logit biases. - - If key and value are both None, computes self-attention using query. - - Args: - mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for - details. - query: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. - key: An optional Tensor of shape [batch, source_length, source_dim]. - value: An optional Tensor of shape [batch, source_length, source_dim]. - kv_state: An optional KVState. If specified, both `key` and `value` should be None. - attention_logit_biases: See ``On attention logit biases`` in the file comments. - segment_ids: See ``On segment_ids`` in the file comments. - query_positions: See ``On positions`` in the file comments. - cached_states: Optional NestedTensor as produced by `init_states`. - return_aux: See comments on `Output`. - - Returns: - A tuple (cached_states, output): - * cached_states: An optional NestedTensor of cache states, depending on `mode`. - * output: An optional Output instance, where .data is of the same shape as query and - .probs is of shape [batch, num_heads, target_length, source_length]. - If initializing cache from scratch, output will be None. - - Raises: - ValueError: If key & value are an invalid combination. - ValueError: If `mode` is unsupported. 
+ base_cfg: attention.MultiheadAttention.Config, + attention_logit_biases_fn: Callable[[int], Tensor], + has_query_positions: bool, + ): + """ + Tests that base_cfg with sliding window causal mask fns is equivalent to applying a + causal sliding window mask. """ - # Validate key & value combination. - if (key is None) != (value is None): - raise ValueError( - "key and value must be both None or both set, " - f"key:{type(key)}, value:{type(value)}" + if has_query_positions and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config): + return + + model_dim = 16 + num_heads = 4 + ref_cfg = base_cfg.clone( + name="test", + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + ) + self.assertFalse(ref_cfg.causal) + ref_layer = ref_cfg.instantiate(parent=None) + layer_params = ref_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + + sliding_window_size = 2 + test_cfg = ref_cfg.clone( + causal=False, + mask=config_for_function(sliding_window_causal_mask).set( + sliding_window_size=sliding_window_size + ), + ) + test_layer = test_cfg.instantiate(parent=None) + + batch_size, seq_len = 2, 4 + query = jnp.zeros([batch_size, seq_len, model_dim], dtype=jnp.float32) + outputs = [] + + if has_query_positions: + query_positions = jax.random.permutation( + jax.random.PRNGKey(1), + jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0), + axis=1, + independent=True, ) - if kv_state is not None: - if key is not None or value is not None: - raise ValueError("kv_state should not be specified together with key/value") - kv_kwargs = dict(kv_state=kv_state) - else: - kv_kwargs = dict(key=key, value=value) - if mode == ForwardMode.FORWARD: - i_proj_state, i_proj_output = ( - None, - self.i_proj(query, query_positions=query_positions, **kv_kwargs), - ) - elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - assert query_positions is None - i_proj_state, i_proj_output = self.i_proj.init_states( - time_step=cached_states["i_proj"], query=query, **kv_kwargs - ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - assert query_positions is None - i_proj_state, i_proj_output = self.i_proj.extend_step( - cached_states["i_proj"], query, **kv_kwargs - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") - - if i_proj_output is None: - assert mode == ForwardMode.INIT_STATES - return dict(i_proj=i_proj_state), None - - q_proj, k_proj, v_proj = i_proj_output - kv_state = KVState(k_proj=k_proj, v_proj=v_proj) - q_proj = self._remat_name(q_proj, "q_proj") - k_proj = self._remat_name(k_proj, "k_proj") - v_proj = self._remat_name(v_proj, "v_proj") - self.vlog(3, "atten.q_proj=%s", q_proj.sum()) - self.vlog(3, "atten.k_proj=%s", k_proj.sum()) - self.vlog(3, "atten.v_proj=%s", v_proj.sum()) - attention_logit_biases = as_attention_bias(attention_logit_biases) - if self._mask_fn is not None: - target_positions = None - if mode == ForwardMode.EXTEND_STEP: - target_positions = cached_states["i_proj"]["time_step"] - if self._mask_fn is causal_mask: - # Needed for legacy flash attention implementations that don't have - # sparse mask support. - # E.g., the legacy tpu flash attention, all current gpu flash attention - # implementations. 
-                attention_logit_biases += CausalAttentionBias(
-                    shape=(q_proj.shape[1], k_proj.shape[1]),
-                    target_positions=target_positions,
-                    dtype=q_proj.dtype,
-                )
-            else:
-                attention_logit_biases += MaskFnAttentionBias(
-                    self._mask_fn,
-                    shape=(q_proj.shape[1], k_proj.shape[1]),
-                    target_positions=target_positions,
-                    dtype=q_proj.dtype,
+        for layer in (ref_layer, test_layer):
+            attention_logit_biases = attention_logit_biases_fn(seq_len)
+            if layer is ref_layer:
+                # Apply causal and sliding window mask on top of the logit biases for `ref_layer`.
+                attention_logit_biases = apply_attention_logit_biases(
+                    make_sliding_window_causal_biases(seq_len, sliding_window_size),
+                    attention_logit_biases,
                 )
-            if segment_ids is not None:
-                attention_logit_biases += SegmentIdAttentionBias(segment_ids)
-        context, probs = self._compute_attention(
-            q_proj=q_proj,
-            k_proj=k_proj,
-            v_proj=v_proj,
-            attention_logit_biases=attention_logit_biases,
+            inputs = dict(query=query, attention_logit_biases=attention_logit_biases)
+            if has_query_positions:
+                inputs["query_positions"] = query_positions
+            layer_outputs, _ = F(
+                layer,
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(456),
+                inputs=inputs,
+            )
+            outputs.append(layer_outputs)
+        # The outputs are equivalent.
+        self.assertNestedAllClose(outputs[0], outputs[1])
+
+    @parameterized.product(
+        dtype=(jnp.float32, jnp.float16, jnp.bfloat16),
+        per_dim_scale=(None, PerDimScale.default_config()),
+        atten_logit_cap=(0.0, 20.0),
+        input_linear=(
+            None,  # Use the default linear.
+            attention.QKVLinear.default_config(),
+            attention.FusedQKVLinear.default_config(),
+            attention.GroupedQKVLinear.default_config().set(num_kv_heads=4),
+            attention.FusedGroupedQKVLinear.default_config().set(num_kv_heads=4),
+        ),
+        bias=(True, False),
+    )
+    def test_gqa_forward(
+        self,
+        dtype: jnp.dtype,
+        per_dim_scale: Optional[PerDimScale.Config],
+        atten_logit_cap: float,
+        input_linear: attention.BaseQKVLinear.Config,
+        bias: bool,
+    ):
+        """When num_kv_heads=num_heads, GQA should be equivalent to MHA."""
+        model_dim = 16
+        num_heads = 4
+        layer_kwargs = dict(
+            query_dim=model_dim,
+            key_dim=model_dim,
+            value_dim=model_dim,
+            num_heads=num_heads,
+            query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale),
+            atten_logit_cap=atten_logit_cap,
+            dtype=dtype,
         )
-        self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :])
-        self.vlog(3, "atten.context=%s", context.sum())
-
-        # [batch, target_length, output_dim].
-        o_proj = self.o_proj(context)
-        outputs = self._remat_name(o_proj, "o_proj")
-        self._add_tensor_stats("o_proj_outputs", outputs)
-        return_aux = return_aux or set()
-        output = self.Output(
-            data=outputs,
-            probs=probs if "probs" in return_aux else None,
-            kv_state=kv_state if "kv_state" in return_aux else None,
+        init_key = jax.random.PRNGKey(123)
+        # Initialize MultiheadAttention.
+        base_cfg = attention.MultiheadAttention.default_config().set(**layer_kwargs)
+        set_bias_recursively(base_cfg, bias=bias)
+        base_layer = base_cfg.set(name="base").instantiate(parent=None)
+        base_state = base_layer.initialize_parameters_recursively(prng_key=init_key)
+        # Initialize GroupedQueryAttention.
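+        # Every parameterized input_linear here ends up with num_kv_heads == num_heads (4), so
+        # the grouped-query layer is expected to reproduce the MultiheadAttention outputs.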
+ cfg = attention.GroupedQueryAttention.default_config().set(**layer_kwargs) + if input_linear is not None: + cfg.set(input_linear=input_linear) + set_bias_recursively(cfg, bias=bias) + test_layer = cfg.set(name="test").instantiate(parent=None) + logging.info("base_state=%s", shapes(base_state)) + # We convert 'base_state' to 'test_state' because JAX does not ensure that RNG behavior + # remains the same with vs. without vmap. So test_layer initialization may behave + # differently even with the same seed. + test_state = _convert_to_qkv_linear( + base_state, input_linear_layer_class=cfg.input_linear.klass ) - return dict(i_proj=i_proj_state), output - - def _compute_attention( - self, - *, - q_proj: Tensor, - k_proj: Tensor, - v_proj: Tensor, - attention_logit_biases: BaseAttentionBias, - ) -> tuple[Tensor, Tensor]: - """Computes attention context and probs. - - Args: - q_proj: [batch_size, target_length, num_heads, per_head_dim]. - k_proj: [batch_size, source_length, num_heads, per_head_dim]. - v_proj: [batch_size, source_length, num_heads, per_head_dim]. - attention_logit_biases: See ``On attention logit biases`` in the file comments. - - Returns: - The context of shape [batch_size, target_length, num_heads, per_head_dim], - and probs of shape [batch, num_heads, target_length, source_length]. - """ - logits = self._compute_logits(q_proj, k_proj) - logits = self._cap_logits(logits) - self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) - probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases.value()) - probs = self.dropout(probs) - context = self._compute_context(probs, v_proj) - context = self._remat_name(context, "context") - return context, probs + logging.info("transformed_test_state=%s", shapes(test_state)) + + # Dummy inputs. + batch_size, tgt_len = 2, 6 + inputs = dict( + query=jax.random.normal( + jax.random.PRNGKey(124), + [batch_size, tgt_len, model_dim], + dtype=dtype, + ), + key=None, + value=None, + attention_logit_biases=attention_bias.make_causal_biases(tgt_len), + ) + # Get outputs. + forward_key = jax.random.PRNGKey(456) + base_outputs, _ = F( + base_layer, + state=base_state, + is_training=False, + prng_key=forward_key, + inputs=inputs, + ) + test_outputs, _ = F( + test_layer, + state=test_state, + is_training=False, + prng_key=forward_key, + inputs=inputs, + ) + self.assertNestedAllClose(base_outputs, test_outputs) - def forward( + def _test_extend_step( self, - query: Tensor, + attention_cfg: attention.MultiheadAttention.Config, *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - attention_logit_biases: Optional[Tensor] = None, - segment_ids: Optional[Tensor] = None, - query_positions: Optional[Tensor] = None, - return_aux: Optional[set[str]] = None, - ) -> Output: - """Computes attention for the given query, key, value, and attention logit biases. - - If key and value are both None, computes self-attention using query. - - Args: - query: A Tensor of shape [batch, target_length, target_dim]. - key: An optional Tensor of shape [batch, source_length, source_dim]. - value: An optional Tensor of shape [batch, source_length, source_dim]. - kv_state: An optional KVState. If not None, both key and value must be None. - attention_logit_biases: See ``On attention logit biases`` in the file comments. - segment_ids: See `On segment_ids` in the file comments. - query_positions: See ``On positions`` in the file comments. - return_aux: See comments on `Output`. 
- - Returns: - An Output instance, where .data is of the same shape as query and .probs is of shape - [batch, num_heads, target_length, source_length]. - - Raises: - ValueError: If key & value are an invalid combination. - """ - _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, - query=query, - key=key, - value=value, - kv_state=kv_state, - attention_logit_biases=attention_logit_biases, - segment_ids=segment_ids, - query_positions=query_positions, - return_aux=return_aux, + model_dim: int, + num_heads: int, + dtype: jnp.dtype, + bias: bool, + extend_step_len: int, + ): + cfg = attention_cfg.set( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, ) - return output + cfg.input_linear.set(dtype=dtype, cache_dtype=None) + set_bias_recursively(cfg, bias=bias) + layer: attention.MultiheadAttention = cfg.set(name="test").instantiate(parent=None) - def _cap_logits(self, logits: Tensor) -> Tensor: - """Caps the logits with tanh.""" - cfg = self.config - if not cfg.atten_logit_cap or cfg.atten_logit_cap <= 0.0: - return logits - cap = jnp.array(cfg.atten_logit_cap, dtype=logits.dtype) - return cap * jnp.tanh(logits / cap) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: - """Compute attention logits. - - Args: - q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. - k_proj: key tensor, [batch, source_length, num_heads, per_head_dim]. - - Returns: - logits: [batch, num_heads, target_length, source_length]. - """ - q_proj = self.scale_query(q_proj) - k_proj = self.scale_key(k_proj) - return jnp.einsum("btnh,bsnh->bnts", q_proj, k_proj) - - def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: - """Compute attention context. - - Args: - probs: probs tensor, [batch, num_heads, target_length, source_length]. - v_proj: value tensor, [batch, source_length, num_heads, per_head_dim]. - - Returns: - context: [batch, target_length, num_heads, per_head_dim]. - """ - return jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) - - def init_states( - self, - *, - time_step: Optional[Tensor], - query: Union[Tensor, TensorSpec], - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - attention_logit_biases: Optional[Tensor], - return_aux: Optional[set[str]] = None, - ) -> tuple[Nested[Tensor], Optional[Output]]: - """Initializes cache for autoregressive cached decoding. - - The method supports initializing an empty cache as well as prefilling: - * To initialize an empty cache, specify `time_step=None`. - In this case, `query` is allowed to be a TensorSpec. - * To prefill, provide `time_step` and `query` as Tensors. - - Args: - time_step: A Tensor of shape [batch]. Each value is an index into the length dimension - indicating where decoding will start from. - query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding - to query projection input vector up to `time_step`. For batch index `i`, only - `query[i, :time_step[i], ...]` will affect subsequent decoding. - key: Same description as `query`, but for the key projection input vector. - Key and value have to both be tensors or both be None. - If they are tensors, key and value are used as the unique input to the - input projection. Otherwise, query is used as the key and value input. - value: Same description as `query`, but for the value projection input vector. 
- See the above comment for `key` for additional constraints. - kv_state: An optional KVState. - attention_logit_biases: See ``On attention logit biases`` in the file comments. - return_aux: See comments on `Output`. - - Returns: - A tuple (init_states, output): - * init_states: A Nested Tensor state of key and value pair along with index updated at - `time_step`. - * output: In the prefill case, an Output instance, where .data is of the same shape as - query and .probs is of shape [batch, num_heads, target_length, source_length]. - Otherwise, if initializing cache from scratch, output will be None. - """ - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - query=query, - key=key, - value=value, - cached_states=dict(i_proj=time_step), - kv_state=kv_state, - attention_logit_biases=attention_logit_biases, - return_aux=return_aux, + batch_size, tgt_len = 2, 6 + head_dim = model_dim // num_heads + query = jax.random.normal( + jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim], dtype=dtype ) - - def extend_step( - self, - cached_states: NestedTensor, - query: Tensor, - *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - attention_logit_biases: Optional[Tensor], - return_aux: Optional[set[str]] = None, - ) -> tuple[NestedTensor, Output]: - """Computes the value vector given the query of the current step. - This function is used by autoregressive decoding. - - Based on: - https://github.com/tensorflow/lingvo/blob/5754b2f840ebf0f8c52d87e5d4d76f22e372513e/lingvo/jax/layers/attentions.py#L1249 - - Args: - cached_states: A `NestedTensor` object containing tensors which are the results of - previous attentions, and index used for fast decoding. Contains "key" and "value" of - shape [B, N, H, T], and a Tensor "time_step" of shape [B]. - query: Tensor of shape [B, 1, D] corresponding to query projection input vector - at "time_step" indices. - key: Tensor of shape [B, 1, D] corresponding to key projection input vector at - "time_step" indices. Key and value have to both be tensors or both be None. - If they are tensors, key and value are used as the unique input to the - input projection. Otherwise, query is used as the key and value input. - value: Tensor of shape [B, 1, D] corresponding to value projection input vector - at "time_step" indices. See the above comment for `key` for additional - constraints. - kv_state: An optional KVState. - attention_logit_biases: See ``On attention logit biases`` in the file comments. - Additionally, target_length is expected to be 1 since this is per time step. - The biases should already include causal masking for decoding, plus other biases - if necessary. - return_aux: See comments on `Output`. - - Returns: - A `NestedTensor` state of key and value pair along with index updated at `time_step`. - An Output instance, where .data is of the same shape as query, .probs is of shape - [batch, num_heads, 1, source_length]. - """ - return self._forward_for_mode( - mode=ForwardMode.EXTEND_STEP, + key = value = kv_state = None + if attention_cfg.klass == attention.GroupedQueryAttention: + pass + elif attention_cfg.input_linear.klass == QLinear: + kv_state = KVState( + k_proj=jax.random.normal( + jax.random.PRNGKey(124), [batch_size, tgt_len, num_heads, head_dim], dtype=dtype + ), + v_proj=jax.random.normal( + jax.random.PRNGKey(125), [batch_size, tgt_len, num_heads, head_dim], dtype=dtype + ), + ) + else: + # Make key and value distinct from query. 
Otherwise, it is equivalent + # to the query only case. + key = value = query + 0.1 + attention_logit_biases = attention_bias.make_causal_biases(tgt_len) + return_aux = {"probs"} + inputs = dict( query=query, key=key, value=value, - cached_states=cached_states, kv_state=kv_state, attention_logit_biases=attention_logit_biases, return_aux=return_aux, ) - - @staticmethod - def default_query_scale_config() -> InstantiableConfig[ScaleFn]: - """The config for the default function used to compute the query scale.""" - - return config_for_function(pow_scale_fn).set(exp=-0.5) - - @staticmethod - def default_key_scale_config() -> InstantiableConfig[ScaleFn]: - """The config for the default function used to compute the key scale.""" - - return config_for_function(constant_scale_fn).set(value=1) - - -class GroupedQueryAttention(MultiheadAttention): - """A Grouped-Query Attention (GQA) layer. - - Query projections are divided into K groups along the `num_heads` dimension. Projections in the - same query subgroup share one common key/value head. This reduces the size of the KV-cache by a - factor of `num_heads/num_kv_heads`. - - When `input_linear` is a `GroupedQKVLinear` layer with `num_kv_heads=1`, GQA reduces to - multi-query attention (MQA). - When `input_linear` is a `QKVLinear` layer (i.e. `num_kv_heads=num_heads`), GQA is equivalent to - multi-head attention (MHA). - - Note that in some cases fused variants `FusedQKVLinear` or `FusedGroupedQKVLinear` can be used - as drop-in replacements for `QKVLinear` or `GroupedQKVLinear` respectively (see corresponding - layer docstrings for details). - - Reference: https://arxiv.org/abs/2305.13245 - """ - - @property - def num_kv_heads(self): - return self.i_proj.num_kv_heads - - def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: - """Compute attention logits. - - Args: - q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. - k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim]. - - Returns: - logits: [batch, num_heads, target_length, source_length]. - """ - kv_heads = k_proj.shape[-2] - num_head_group = self.config.num_heads // kv_heads - if num_head_group == 1: - return super()._compute_logits(q_proj=q_proj, k_proj=k_proj) - - q_proj = self.scale_query(q_proj) - k_proj = self.scale_key(k_proj) - q_proj = einops.rearrange(q_proj, "b t (k g) h -> b t k g h", k=kv_heads, g=num_head_group) - k_proj = einops.rearrange(k_proj, "b s k h -> b s k 1 h") - logits = jnp.einsum("btkgh,bsk1h->bkgts", q_proj, k_proj) - return einops.rearrange(logits, "b k g t s -> b (k g) t s") - - def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: - """Compute attention context. - - Args: - probs: probs tensor, [batch, num_heads, target_length, source_length]. - v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim]. - - Returns: - context: [batch, target_length, num_heads, per_head_dim]. - """ - kv_heads = v_proj.shape[-2] - num_head_group = self.config.num_heads // kv_heads - if num_head_group == 1: - return super()._compute_context(probs=probs, v_proj=v_proj) - - probs = einops.rearrange(probs, "b (k g) t s -> b k g t s", k=kv_heads, g=num_head_group) - v_proj = einops.rearrange(v_proj, "b s k h -> b s k 1 h") - context = jnp.einsum("bkgts,bsk1h->btkgh", probs, v_proj) - return einops.rearrange(context, "b t k g h -> b t (k g) h") - - -class SigmoidAttention(MultiheadAttention): - """A multi-head sigmoid-based attention layer, instead of softmax. - - TODO(floris_weers): Add paper reference. 
- """ - - @config_class - class Config(MultiheadAttention.Config): - """Configures SigmoidAttention.""" - - seq_len: Required[int] = REQUIRED # Maximum sequence length used. - - def _compute_attention( - self, - *, - q_proj: Tensor, - k_proj: Tensor, - v_proj: Tensor, - attention_logit_biases: BaseAttentionBias, - ) -> tuple[Tensor, Tensor]: - """See `MultiheadAttention._compute_attention` for details.""" - cfg = self.config - logits = self._compute_logits(q_proj, k_proj) - logits = self._cap_logits(logits) - self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) - - attention_logit_biases = attention_logit_biases.value() - if attention_logit_biases is None: - attention_logit_biases = 0 - # To approximate softmax, we subtract a bias dependent on sequence length. - attention_logit_biases = attention_logit_biases - jnp.log(cfg.seq_len) - probs = sigmoid_with_biases( - logits, - attention_logit_biases=attention_logit_biases, + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, ) - probs = self.dropout(probs) - - context = self._compute_context(probs, v_proj) - context = self._remat_name(context, "context") - return context, probs - - -def rel_pos_to_abs_pos(x: Tensor) -> Tensor: - """Converts a (T, relative_pos_offset) Tensor to a (T, abs_position) tensor. - - For example, t = 3: - ..abc abc - .def. => def - ghi.. ghi - - Input shape: [t, 2t - 1]: - ..abc - .def. - ghi.. - - 1. Reshape to [t * (2t - 1)] - ..abc.def.ghi.. - - 2. Trim by [t-1:-1], producing shape [t * (2t - 2)]. - abc.def.ghi. - - 3. Reshape to [t, 2t - 2]: - abc. - def. - ghi. - 4. Trim by [:, :-(t-2)] - abc - def - ghi - - Args: - x: a Tensor of shape [T, 2*T - 1], where x[i, j] represents the bias between query[i] and - absolute position k = i + j - (T - 1), if 0 <= k < T, otherwise the value is not used. - T is expected to be >= 1. - - Returns: - y: a Tensor of shape [T, T], s.t. y[i, k] = x[i, j] where k = i + j - (T - 1), - if 0 <= k < T. - """ - t, offset_length = x.shape - assert offset_length == 2 * t - 1 - if t <= 1: - return x - # [t * (2t - 1)]. - x = x.reshape([-1]) - # [t * (2t - 2)]. - x = x[t - 1 : -1] - # [t, 2t - 2]. - x = x.reshape([t, -1]) - # [t, t]. When t = 2, do not trim. - if t > 2: - x = x[:, : -(t - 2)] - return x - - -class MultiheadRelativePositionLinear(BaseMultiheadLinear): - """Multi-head relative position linear layer.""" - - @property - def _einsum_expr(self): - return "ld,dnh->lnh" - - @property - def _bias_spec(self): - cfg = self.config - return ParameterSpec( - shape=(cfg.num_heads, cfg.per_head_dim), - mesh_axes=cfg.param_partition_spec[-2:], + initial_state, initial_output = layer.init_states( + time_step=None, + query=TensorSpec([batch_size, tgt_len]), + kv_state=kv_state, + # This is unused for initializing state from scratch. + attention_logit_biases=None, ) - - # pylint: disable-next=no-self-use - def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: - if name == "weight": - return FanAxes(in_axis=0, out_axis=(1, 2)) + self.assertIsNone(initial_output) + if kv_state is None: + for k in ["key", "value"]: + # Check that the cache dtype is inferred as the layer dtype. + self.assertEqual(initial_state["i_proj"][k].dtype, dtype) else: - return None - - -def xl_attention_logits( - q_proj: Tensor, k_proj: Tensor, relative_pos_emb: Tensor, u: Tensor, v: Tensor -): - """Computes Transformer XL self-attention logits. 
- - Note that this implementation follows XLNet implementation and is different from the lingvo - implementation in that here the relative_pos_emb index is computed from key_i - query_i, - while lingvo computes from query_i - key_i. - - Args: - q_proj: A Tensor of shape [batch, target_length, num_heads, per_head_dim], representing - projected queries. - k_proj: A Tensor of shape [batch, target_length, num_heads, per_head_dim], representing - projected keys. - relative_pos_emb: A Tensor of shape [num_embeddings, num_heads, per_head_dim], representing - projected relative positional embeddings, where num_embeddings = 2 * target_length - 1. - relative_pos_emb[key_i - query_i + target_length - 1] represents positional - embeddings between query[:, query_i] and key[:, key_i] and is usually computed from - sinusoidal_positional_embeddings(query_i - key_i), i.e., - relative_pos_emb[0] represents query_i = target_length - 1 and key_i = 0. - relative_pos_emb[-1] represents query_i = 0 and key_i = target_length - 1. - u: A Tensor of shape [num_heads, per_head_dim]. - The trainable `u` in https://arxiv.org/pdf/1901.02860.pdf 3.3 used for term 'ac'. - v: A Tensor of shape [num_heads, per_head_dim]. - The trainable `v` in https://arxiv.org/pdf/1901.02860.pdf 3.3 used for term 'bd'. - - Returns: - A tensor of shape [batch, num_heads, target_length, target_length] representing attention - logits. logit[:, :, i, j] represents the logit for query[i] and key[j]. - """ - term_ac = jnp.einsum("btnh,bsnh->bnts", q_proj + u, k_proj) - term_bd = jnp.einsum("btnh,lnh->bntl", q_proj + v, relative_pos_emb) - # Apply vmap twice to map over both `batch` and `num_heads`. - term_bd = jax.vmap(jax.vmap(rel_pos_to_abs_pos))(term_bd) - return term_ac + term_bd - - -class MultiheadAttentionXL(MultiheadAttention): - """Multi-head self-attention with relative positional embeddings. - - The default config matches XL-Net implementation with `per_dim_scale=None` and - `scale_position=LOGIT`. - To match with Lingvo implementation, enable `per_dim_scale` - and set `scale_position=QUERY`. Note the positional embeddings are in descending - order, which is different from Lingvo's implementation. - - Reference: - https://github.com/zihangdai/xlnet/blob/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/modeling.py - https://github.com/huggingface/transformers/blob/224bde91caff4ccfd12277ab5e9bf97c61e22ee9/src/transformers/models/xlnet/modeling_xlnet.py#L204 - https://github.com/tensorflow/lingvo/blob/a1326a09641a6ec7d775a51148012551756d888d/lingvo/core/batch_major_attention.py#L1345 - https://github.com/tensorflow/lingvo/blob/f02fed838836bcc513d51c95d482247b119543fb/lingvo/core/attention_util.py#L464-L513 - """ - - @unique - class ScalePosition(Enum): - # Applies query scale-factor to the logits. - LOGIT = 0 - # Applies query scale-factor to the queries. - QUERY = 1 - - @config_class - class Config(MultiheadAttention.Config): - """Configures MultiheadAttentionXL.""" - - pos_emb_dim: Optional[int] = None # Positional embedding dim. If None, use query_dim. - # Config for computing relative position embeddings for range [-seq_len + 1, seq_len - 1]. - relative_pos_emb: SinusoidalPositionalEmbedding.Config = ( - SinusoidalPositionalEmbedding.default_config() - ) - # Config used for the R projection. 
- relative_pos_linear: MultiheadRelativePositionLinear.Config = ( - MultiheadRelativePositionLinear.default_config().set(bias=False) - ) - scale_position: Required["MultiheadAttentionXL.ScalePosition"] = REQUIRED - - @classmethod - def default_config(cls) -> Config: - cfg: MultiheadAttentionXL.Config = super().default_config() - cfg.scale_position = MultiheadAttentionXL.ScalePosition.LOGIT - # pylint: disable=no-member - cfg.input_linear = FusedQKVLinear.default_config() - cfg.input_linear.layer.bias = False - # pylint: enable=no-member - return cfg - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg: MultiheadAttentionXL.Config = self.config - if not cfg.query_dim == cfg.key_dim == cfg.value_dim: - raise ValueError( - f"MultiheadAttentionXL requires query_dim ({cfg.query_dim}) == " - f"key_dim ({cfg.key_dim}) == value_dim ({cfg.value_dim})" + self.assertNotIn("key", initial_state["i_proj"]) + self.assertNotIn("value", initial_state["i_proj"]) + inputs = dict(cached_states=initial_state, kv_state=kv_state, return_aux=return_aux) + decoder_output = [] + decoder_probs = [] + for t in range(0, tgt_len, extend_step_len): + inputs["query"] = query[:, t : t + extend_step_len, :] + if key is not None: + inputs["key"] = key[:, t : t + extend_step_len, :] + if value is not None: + inputs["value"] = value[:, t : t + extend_step_len, :] + inputs["attention_logit_biases"] = attention_logit_biases[t : t + extend_step_len, :] + (cached_states, extend_step_outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + method="extend_step", ) - pos_emb_dim = cfg.pos_emb_dim or cfg.query_dim - self._add_child("relative_pos_emb", cfg.relative_pos_emb.set(dim=pos_emb_dim)) - self._add_child( - "r_proj", - cfg.relative_pos_linear.clone( - model_dim=pos_emb_dim, num_heads=cfg.num_heads, per_head_dim=self.per_head_dim() - ), + inputs["cached_states"] = cached_states + decoder_output.append(extend_step_outputs.data) + decoder_probs.append(extend_step_outputs.probs) + decoder_output = jnp.concatenate(decoder_output, axis=1) + decoder_probs = jnp.concatenate(decoder_probs, axis=2) + assert_allclose(decoder_output, forward_outputs.data, atol=1e-6) + assert_allclose(decoder_probs, forward_outputs.probs, atol=1e-6) + + @parameterized.product( + dtype=(jnp.float32, jnp.float16, jnp.bfloat16), + per_dim_scale=(None, PerDimScale.default_config()), + atten_logit_cap=(0.0, 20.0), + bias=(True, False), + input_linear=(QKVLinear, RoFormerQKVLinear, QLinear), + extend_step_len=(1, 4), + ) + def test_extend_step( + self, + dtype: jnp.dtype, + per_dim_scale: Optional[PerDimScale.Config], + atten_logit_cap: float, + input_linear: attention.BaseQKVLinear, + bias: bool, + extend_step_len: int, + ): + model_dim = 16 + num_heads = 4 + if input_linear == attention.RoFormerQKVLinear: + input_linear = input_linear.default_config().set(rotary_value=False) + else: + input_linear = input_linear.default_config() + cfg = attention.MultiheadAttention.default_config().set( + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + atten_logit_cap=atten_logit_cap, + input_linear=input_linear, ) - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - params = super()._create_layer_parameter_specs() - params["u_bias"] = params["v_bias"] = ParameterSpec( - shape=(cfg.num_heads, self.per_head_dim()), - initializer=constant_initializer(0), - 
mesh_axes=cfg.relative_pos_linear.param_partition_spec[-2:], + self._test_extend_step( + cfg, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + bias=bias, + extend_step_len=extend_step_len, ) - return params - def forward( - self, - query: Tensor, - *, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - **kwargs, - ) -> MultiheadAttention.Output: - if key is not None or value is not None: - raise ValueError("Both key and value must be None for MultiheadAttentionXL") - return super().forward(query, **kwargs) - - def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: - cfg = self.config - with child_context("apply_query_norm", module=self): - # We apply the query norm (if configured) to the projection (not the logits). - q_proj = self.scale_query.apply_norm(q_proj) - - with child_context("apply_per_dim_scale", module=self): - q_proj = self.scale_query.apply_per_dim_scale(q_proj) - - if cfg.scale_position == MultiheadAttentionXL.ScalePosition.QUERY: - with child_context("apply_scale_factor_queries", module=self): - q_proj = self.scale_query.apply_scale_factor(q_proj) - - seq_len = q_proj.shape[1] - # [2*seq_len - 1, pos_emb_dim]. - # - # Following the XLNet implementation - # https://github.com/zihangdai/xlnet/blob/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/modeling.py#L215 - # https://github.com/huggingface/transformers/blob/28d0048218ad7bce69510b16024510afba0daed2/src/transformers/models/xlnet/modeling_xlnet.py#L1030 - # the positions are from descending from seq_len - 1 to -seq_len + 1. - pos_emb = self.relative_pos_emb(jnp.arange(seq_len - 1, -seq_len, -1, dtype=jnp.int32)) - # [2*seq_len - 1, num_heads, per_head_dim]. - r_proj = self.r_proj(pos_emb) - - # Apply key scaling. - k_proj = self.scale_key(k_proj) - - logits = xl_attention_logits( - q_proj=q_proj, - k_proj=k_proj, - relative_pos_emb=r_proj, - u=self.parameters["u_bias"], - v=self.parameters["v_bias"], - ) - if cfg.scale_position == MultiheadAttentionXL.ScalePosition.LOGIT: - # In the original XL-Net code, it applies scale on AC + BD: - # - # https://github.com/zihangdai/xlnet/blob/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/modeling.py#L148 - with child_context("apply_scale_factor_logits", module=self): - logits = self.scale_query.apply_scale_factor(logits) - return logits - - def extend_step( + @parameterized.product( + dtype=(jnp.float32, jnp.float16, jnp.bfloat16), + per_dim_scale=(None, PerDimScale.default_config()), + atten_logit_cap=(0.0, 20.0), + num_kv_heads=(1, 2, 4), + input_linear=(attention.GroupedQKVLinear, attention.FusedGroupedQKVLinear), + bias=(True, False), + extend_step_len=(1, 4), + ) + def test_gqa_extend_step( self, - cached_states: NestedTensor, - query: Tensor, - **kwargs, - ) -> tuple[NestedTensor, MultiheadAttention.Output]: - raise NotImplementedError(type(self)) - - -class TransformerAttentionLayer(BaseLayer): - """A Transformer attention layer with normalization and a skip connection. 
+ dtype: jnp.dtype, + per_dim_scale: Optional[PerDimScale.Config], + atten_logit_cap: float, + num_kv_heads: int, + input_linear: type[attention.BaseQKVLinear], + bias: bool, + extend_step_len: int, + ): + model_dim = 16 + num_heads = 4 + cfg = attention.GroupedQueryAttention.default_config().set( + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + atten_logit_cap=atten_logit_cap, + input_linear=input_linear.default_config().set(num_kv_heads=num_kv_heads), + ) + self._test_extend_step( + cfg, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + bias=bias, + extend_step_len=extend_step_len, + ) - Can be used for either self-attention or cross-attention. - """ + def _test_prefill_states( + self, + attention_cfg: attention.MultiheadAttention.Config, + *, + model_dim: int, + num_heads: int, + dtype: jnp.dtype, + bias: bool, + num_kv_heads: Optional[int] = None, + ): + cfg = attention_cfg.set( + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + ) + cfg.input_linear.set(dtype=dtype, cache_dtype=None) + set_bias_recursively(cfg, bias=bias) + layer: attention.MultiheadAttention = cfg.set(name="test").instantiate(parent=None) - @config_class - class Config(BaseLayer.Config): - """Configures TransformerAttentionLayer.""" - - target_dim: Required[int] = REQUIRED # Input target feature dim. - source_dim: Required[int] = REQUIRED # Input source feature dim. - norm: InstantiableConfig = LayerNorm.default_config() # The normalization layer config. - attention: InstantiableConfig = ( - MultiheadAttention.default_config() - ) # The attention layer config. - dropout: InstantiableConfig = Dropout.default_config() # The dropout layer config. - # The stochastic depth layer config. - # Pytorch reference: - # https://github.com/facebookresearch/deit/blob/main/models_v2.py#L58 - # Tensorflow reference: - # https://github.com/tensorflow/models/blob/master/official/projects/vit/modeling/nn_blocks.py#L86-L92 - stochastic_depth: InstantiableConfig = StochasticDepth.default_config() - # The inner structure of the layer: prenorm or postnorm. See - # https://arxiv.org/abs/2002.04745 for background. - # The structure also support hybridnorm, which uses two norms in the residual branch. - # hybridnorm: TransformerAttentionLayer(x) = x + layernorm_2(attention(layernorm_1(x))) - # Ref: https://github.com/google/praxis/blob/main/praxis/layers/transformers.py#L1129 - # TODO (bwzhang@) Adding a unittest for the hybridnorm. - structure: str = "prenorm" + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - if cfg.structure in ["prenorm", "postnorm"]: - self._add_child("norm", cfg.norm.set(input_dim=cfg.target_dim)) - elif cfg.structure == "hybridnorm": - self._add_child("prenorm", cfg.norm.set(input_dim=cfg.target_dim)) - self._add_child("postnorm", cfg.norm.set(input_dim=cfg.target_dim)) + batch_size, tgt_len = 3, 6 + query = jax.random.normal( + jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim], dtype=dtype + ) + if attention_cfg.klass == attention.GroupedQueryAttention: + key = value = None else: - raise NotImplementedError(cfg.structure) - self._add_child( - "attention", - cfg.attention.set( - query_dim=cfg.target_dim, - key_dim=cfg.source_dim, - value_dim=cfg.source_dim, - output_dim=cfg.target_dim, + # Make key and value distinct from query. 
Otherwise, it is equivalent + # to the query only case. + key = value = query + 0.1 + attention_logit_biases = attention_bias.make_causal_biases(tgt_len) + return_aux = {"probs"} + + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict( + query=query, + key=key, + value=value, + attention_logit_biases=attention_logit_biases, + return_aux=return_aux, ), ) - self._add_child("dropout", cfg.dropout) - self._add_child("stochastic_depth", cfg.stochastic_depth) - - class Output(NamedTuple): - """Outputs of TransformerAttentionLayer. - Fields: - data: [batch, target_length, output_dim]. The attention output. Always present. - probs: The attention probabilities returned by the attention layer. - Populated if "probs" is in return_aux. - kv_state: The KV state used to compute output. - Populated if "kv_state" is in return_aux. - """ - - data: Tensor - probs: Optional[Tensor] = None - kv_state: Optional[KVState] = None + time_step = jnp.arange(batch_size) + (initial_states, initial_output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict( + time_step=time_step, + query=query, + key=key, + value=value, + attention_logit_biases=attention_logit_biases, + return_aux=return_aux, + ), + method="init_states", + ) - def _forward_for_mode( - self, - *, - mode: ForwardMode, - target: Union[Tensor, TensorSpec], - source: Optional[Union[Tensor, KVState]] = None, - attention_logit_biases: Optional[Tensor] = None, - segment_ids: Optional[Tensor] = None, - target_positions: Optional[Tensor] = None, - cached_states: Optional[NestedTensor] = None, - return_aux: Optional[set[str]] = None, - ) -> tuple[Optional[Nested[Tensor]], Optional[Output]]: - """Computes either self-attention or cross-attention for the given target and source. - - Args: - mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for - details. - target: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. - source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. - If None, uses norm(target) as source (self-attention). - attention_logit_biases: See ``On attention logit biases`` in the file comments. - segment_ids: See ``On segment_ids`` in the file comments. - target_positions: See ``On positions`` in the file comments. - cached_states: Optional NestedTensor as produced by `init_states`. - return_aux: See comments on `Output`. - - Returns: - A tuple (cached_states, output): - * cached_states: An optional Nested Tensor of cache states, depending on `mode`. - * output: An optional Output instance, where .data is of the same shape as query and - .probs is of shape [batch, num_heads, target_length, source_length]. - If initializing cache from scratch, output will be None. - - Raises: - ValueError: If `mode` is unsupported. - NotImplementedError: If `cfg.structure` is not supported. - """ - cfg = self.config + # Check time_step and shapes of state. 
+ self.assertEqual(["i_proj"], list(initial_states.keys())) + self.assertTrue(jnp.all(time_step == initial_states["i_proj"]["time_step"])) + for proj in ["key", "value"]: + self.assertEqual( + (batch_size, tgt_len, num_kv_heads or num_heads, model_dim // num_heads), + initial_states["i_proj"][proj].shape, + ) + self.assertEqual( + dtype, + initial_states["i_proj"][proj].dtype, + ) - if source is None: - kv_kwargs = {} - elif isinstance(source, KVState): - kv_kwargs = {"kv_state": source} - elif isinstance(source, Tensor): - kv_kwargs = {"key": source, "value": source} + # Zero-out outputs starting from initial time_step, and test that we can recover the full + # outputs by calling extend_step starting from time_step. + # [batch, tgt_len]. + time_step_mask = jnp.arange(tgt_len) < time_step[:, None] + # [batch, tgt_len, model_dim]. + decoder_output = initial_output.data * time_step_mask[..., None] + # [batch, tgt_len, model_dim] --> [batch, model_dim, tgt_len]. + decoder_output = jnp.moveaxis(decoder_output, -2, -1) + + # [batch, num_heads, tgt_len, src_len]. + if initial_output.probs is None: + decoder_probs = None else: - raise NotImplementedError(source) - kv_kwargs["return_aux"] = return_aux - - def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: - if mode == ForwardMode.FORWARD: - atten_state, atten_output = ( - None, - self.attention( - query=target, - **kv_kwargs, - attention_logit_biases=attention_logit_biases, - segment_ids=segment_ids, - query_positions=target_positions, - ), - ) - elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - assert segment_ids is None - assert target_positions is None - atten_state, atten_output = self.attention.init_states( - time_step=cached_states["attention"], - query=target, - **kv_kwargs, - attention_logit_biases=attention_logit_biases, + decoder_probs = initial_output.probs * time_step_mask[:, None, :, None] + # [batch, num_heads, tgt_len, src_len] --> [batch, num_heads, src_len, tgt_len]. + decoder_probs = jnp.moveaxis(decoder_probs, -2, -1) + + # Call extend_step from time_step, ensuring that outputs match. + inputs = dict(cached_states=initial_states, return_aux=return_aux) + while jnp.any(time_step < tgt_len): + # [batch, tgt_len=1, model_dim]. + inputs["query"] = jnp.take_along_axis( + query, time_step[:, None, None], axis=1, mode="clip" + ) + if key is not None: + inputs["key"] = jnp.take_along_axis( + key, time_step[:, None, None], axis=1, mode="clip" ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - assert segment_ids is None - assert target_positions is None - atten_state, atten_output = self.attention.extend_step( - cached_states["attention"], - target, - **kv_kwargs, - attention_logit_biases=attention_logit_biases, + inputs["value"] = jnp.take_along_axis( + value, time_step[:, None, None], axis=1, mode="clip" ) - else: - raise ValueError(f"Unrecognized mode {mode}.") - return atten_state, atten_output - - if mode == ForwardMode.INIT_STATES: - assert cached_states is not None - if cached_states["attention"] is None: - atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype)) - return dict(attention=atten_state), atten_output - - if cfg.structure == "prenorm": - skip_input = target # pre-norm: where normalization happens within the residual part. 
- norm_target = self.norm(target) - atten_state, atten_output = attention_thunk(norm_target) - data = skip_input + self.stochastic_depth(self.dropout(atten_output.data)) - elif cfg.structure == "postnorm": - # This is the structure used by the original Transformer, BERT, and RoBERTa. - atten_state, atten_output = attention_thunk(target) - # Post-norm: norm applied on the sum of input and attention output. - data = self.norm(target + self.stochastic_depth(self.dropout(atten_output.data))) - elif cfg.structure == "hybridnorm": - skip_input = target # pre-norm: where normalization happens within the residual part. - norm_target = self.prenorm(target) - atten_state, atten_output = attention_thunk(norm_target) - data = skip_input + self.stochastic_depth( - self.dropout(self.postnorm(atten_output.data)) + # [batch=1, tgt_len=1, tgt_len]. + inputs["attention_logit_biases"] = jnp.take_along_axis( + attention_logit_biases[None, :, :], time_step[:, None, None], axis=1, mode="clip" + ) + (updated_state, outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + method="extend_step", ) + inputs["cached_states"] = updated_state + + # [batch, model_dim, tgt_len=1] + curr_outputs = jnp.moveaxis(outputs.data, -2, -1) + # [batch, num_heads, src_len, tgt_len=1] + curr_probs = jnp.moveaxis(outputs.probs, -2, -1) + + # [batch, 1, tgt_len]. + oh_indices = jax.nn.one_hot(time_step, tgt_len)[:, None, :] + decoder_output = decoder_output + curr_outputs * oh_indices + # [batch, 1, 1, tgt_len]. + oh_indices = oh_indices[..., None, :] + decoder_probs = decoder_probs + curr_probs * oh_indices + time_step = time_step + 1 + + # [batch, model_dim, tgt_len] --> [batch, tgt_len, model_dim]. + decoder_output = jnp.moveaxis(decoder_output, -1, -2) + # [batch, num_heads, src_len, tgt_len] --> [batch, num_heads, tgt_len, src_len]. + decoder_probs = jnp.moveaxis(decoder_probs, -1, -2) + + assert_allclose(decoder_output, forward_outputs.data) + assert_allclose(decoder_probs, forward_outputs.probs) + + @parameterized.product( + dtype=(jnp.float32, jnp.float16, jnp.bfloat16), + per_dim_scale=(None, PerDimScale.default_config()), + atten_logit_cap=(0.0, 20.0), + bias=(True, False), + input_linear=(attention.QKVLinear, attention.RoFormerQKVLinear), + ) + def test_prefill_states( + self, + dtype: jnp.dtype, + per_dim_scale: Optional[PerDimScale.Config], + atten_logit_cap: float, + bias: bool, + input_linear: attention.BaseQKVLinear, + ): + model_dim = 16 + num_heads = 4 + if input_linear == attention.RoFormerQKVLinear: + input_linear = input_linear.default_config().set(rotary_value=False) else: - raise NotImplementedError(cfg.structure) - return dict(attention=atten_state), self.Output( - data=data, probs=atten_output.probs, kv_state=atten_output.kv_state + input_linear = input_linear.default_config() + cfg = attention.MultiheadAttention.default_config().set( + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + atten_logit_cap=atten_logit_cap, + input_linear=input_linear, ) - - def forward( - self, - *, - target: Tensor, - source: Optional[Union[Tensor, KVState]] = None, - attention_logit_biases: Optional[Tensor] = None, - segment_ids: Optional[Tensor] = None, - target_positions: Optional[Tensor] = None, - return_aux: Optional[set[str]] = None, - ) -> Output: - """Computes attention with target as query and source as key and value. - - Args: - target: A Tensor of shape [batch, target_length, target_dim]. 
- source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. - If None, uses norm(target) as source (self-attention) - attention_logit_biases: See ``On attention logit biases`` in the file comments. - segment_ids: See ``segment_ids`` in the file comments. - target_positions: See ``positions`` in the file comments. - return_aux: See comments on `Output`. - - Returns: - An Output instance, where .data is of the same shape as target and .probs is of shape - [batch, num_heads, target_length, source_length]. - - Raises: - NotImplementedError: If cfg.structure is unsupported. - """ - _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, - target=target, - source=source, - attention_logit_biases=attention_logit_biases, - segment_ids=segment_ids, - target_positions=target_positions, - cached_states=None, - return_aux=return_aux, + self._test_prefill_states( + cfg, model_dim=model_dim, num_heads=num_heads, dtype=dtype, bias=bias ) - return output - def init_states( + @parameterized.product( + dtype=(jnp.float32, jnp.float16, jnp.bfloat16), + per_dim_scale=(None, PerDimScale.default_config()), + atten_logit_cap=(0.0, 20.0), + num_kv_heads=(1, 2, 4), + input_linear=(attention.GroupedQKVLinear, attention.FusedGroupedQKVLinear), + bias=(True, False), + ) + def test_gqa_prefill_states( self, - *, - time_step: Optional[Tensor], - target: Union[Tensor, TensorSpec], - source: Optional[Union[Tensor, KVState]] = None, - attention_logit_biases: Optional[Tensor] = None, - return_aux: Optional[set[str]] = None, - ) -> tuple[Nested[Tensor], Optional[Output]]: - """Initializes cache for autoregressive cached decoding. - - The method supports initializing an empty cache as well as prefilling: - * To initialize an empty cache, specify `time_step=None`. - In this case, `target` is allowed to be a TensorSpec. - * To prefill, provide `time_step` and `target` as Tensors. - - Args: - time_step: A Tensor of shape [batch]. Each value is an index into the length dimension - indicating where decoding will start from. - target: Tensor of shape [batch, target_length, target_dim] corresponding to query vector - at `time_step` indices. For batch index `i`, only `target[i, :time_step[i], ...]` - will affect subsequent decoding. - source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. - If None, uses norm(target) as source (self-attention) - attention_logit_biases: See ``On attention logit biases`` in the file comments. - return_aux: See comments on `Output`. - - Returns: - A tuple (init_states, output): - * init_states: A Nested Tensor state depending on the `attention` layer implementation. - * output: In the prefill case, an Output instance, where .data is of the same shape as - query, .probs is of shape [batch, num_heads, target_length, source_length]. - Otherwise, if initializing cache from scratch, output will be None. 
- """ - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - target=target, - source=source, - cached_states=dict(attention=time_step), - attention_logit_biases=attention_logit_biases, - return_aux=return_aux, + dtype: jnp.dtype, + per_dim_scale: Optional[PerDimScale.Config], + atten_logit_cap: float, + num_kv_heads: int, + input_linear: type[attention.BaseQKVLinear], + bias: bool, + ): + model_dim = 16 + num_heads = 4 + cfg = attention.GroupedQueryAttention.default_config().set( + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + atten_logit_cap=atten_logit_cap, + input_linear=input_linear.default_config().set(num_kv_heads=num_kv_heads), ) - - def extend_step( - self, - cached_states: NestedTensor, - target: Tensor, - *, - source: Optional[Union[Tensor, KVState]] = None, - attention_logit_biases: Optional[Tensor] = None, - return_aux: Optional[set[str]] = None, - ) -> tuple[Nested[Tensor], Output]: - """Computes the value vector given the query of the current step. - This function is used by autoregressive decoding. - - Args: - cached_states: A `NestedTensor` object containing tensors which are the - results of previous attentions, and index used for fast decoding. Contains - "attention" cached states. - target: Tensor of shape [B, 1, D] corresponding to query vector at index - time_step. - source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. - If None, uses norm(target) as source (self-attention) - attention_logit_biases: See ``On attention logit biases`` in the file comments. - Additionally, target_length is expected to be 1 since this is per time step. - attention_logit_biases should have already taken care of causal masking for - decoding, plus other maskings necessary. - return_aux: See comments on `Output`. - - Returns: - A `NestedTensor` state of key and value pair along with index updated at `time_step`. - An Output instance, where .data is of the same shape as query, .probs is of shape - [batch, num_heads, 1, source_length]. - - Raises: - NotImplementedError: If cfg.structure is not supported. - """ - return self._forward_for_mode( # pytype: disable=bad-return-type - mode=ForwardMode.EXTEND_STEP, - target=target, - source=source, - cached_states=cached_states, - attention_logit_biases=attention_logit_biases, - return_aux=return_aux, + self._test_prefill_states( + cfg, + model_dim=model_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + dtype=dtype, + bias=bias, ) + def test_gqa_against_mha(self): + model_dim = 16 + num_heads = 4 + num_kv_heads = 2 + ref_cfg = attention.MultiheadAttention.default_config().set( + name="mha", + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = attention.GroupedQueryAttention.default_config().set( + name="gqa", + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + input_linear=attention.GroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads), + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. 
+        prng_key = jax.random.PRNGKey(123)
+        prng_key, init_key, data_key = jax.random.split(prng_key, num=3)
+        state = ref_layer.initialize_parameters_recursively(init_key)
+
+        batch, seq_len = 2, 10
+        per_head_dim = ref_layer.per_head_dim()
+        q = jax.random.uniform(data_key, (batch, seq_len, num_heads, per_head_dim))
+        k = jax.random.uniform(data_key, (batch, seq_len, num_kv_heads, per_head_dim))
+        v = jax.random.uniform(data_key, (batch, seq_len, num_kv_heads, per_head_dim))
+        attention_logit_biases = attention_bias.ZeroAttentionBias()
+
+        (test_context, test_probs), _ = F(
+            test_layer,
+            method="_compute_attention",
+            state=state,
+            is_training=False,
+            prng_key=prng_key,
+            inputs=dict(
+                q_proj=q, k_proj=k, v_proj=v, attention_logit_biases=attention_logit_biases
+            ),
+        )
 
-def scaled_hidden_dim(scale: float = 4) -> FunctionConfigBase:
-    def scale_fn(input_dim: int, *, scale: float) -> int:
-        return round(input_dim * scale)
+        k = jnp.repeat(k, num_heads // num_kv_heads, axis=2)
+        v = jnp.repeat(v, num_heads // num_kv_heads, axis=2)
+
+        (ref_context, ref_probs), _ = F(
+            ref_layer,
+            method="_compute_attention",
+            state=state,
+            is_training=False,
+            prng_key=prng_key,
+            inputs=dict(
+                q_proj=q, k_proj=k, v_proj=v, attention_logit_biases=attention_logit_biases
+            ),
+        )
 
-    return config_for_function(scale_fn).set(scale=scale)
+        assert_allclose(ref_context, test_context)
+        assert_allclose(ref_probs, test_probs)
 
+    def _scale_query_kwargs(
+        self,
+        *,
+        query_scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]],
+        key_scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]],
+    ):
+        model_dim = 16
+        if isinstance(query_scale_factor, (int, float)):
+            query_scale_factor = config_for_function(attention.constant_scale_fn).set(
+                value=query_scale_factor
+            )
+        if isinstance(key_scale_factor, (int, float)):
+            key_scale_factor = config_for_function(attention.constant_scale_fn).set(
+                value=key_scale_factor
+            )
 
-class TransformerFeedForwardLayer(BaseLayer):
-    """A Transformer feed-forward layer."""
+        cfg = attention.MultiheadAttention.default_config().set(
+            name="test",
+            query_dim=model_dim,
+            key_dim=model_dim,
+            value_dim=model_dim,
+            num_heads=2,
+            query_scale=attention.ScaleQuery.default_config().set(scale_factor=query_scale_factor),
+            key_scale=attention.ScaleKey.default_config().set(scale_factor=key_scale_factor),
+        )
+        cfg.input_linear.layer.bias = False
+        cfg.output_linear.bias = False
+        layer = cfg.instantiate(parent=None)
 
-    @config_class
-    class Config(BaseLayer.Config):
-        """Configures TransformerFeedForwardLayer."""
-
-        input_dim: Required[int] = REQUIRED  # Input feature dim.
-        # The hidden dim.
-        # It should be given either as an integer or a function config that instantiates
-        # a dim-to-dim function, e.g., scaled_hidden_dim(4).
-        hidden_dim: Required[Union[int, FunctionConfigBase]] = REQUIRED
-        # Config for the first linear layer.
-        linear1: InstantiableConfig = Linear.default_config().set(
-            param_partition_spec=[None, "model"]
-        )
-        # Config for the second linear layer.
-        linear2: InstantiableConfig = Linear.default_config().set(
-            param_partition_spec=["model", None]
-        )
-        norm: InstantiableConfig = LayerNorm.default_config()  # The normalization layer config.
-
-        # The activation function(s).
-        #
-        # If a single string, the activation applied on the output of linear1. 
- # - # If a tuple of two strings, this layer will contain separate child Linear layers, one for - # each activation function, according to cfg.linear1 with `hidden_dim` as the output dim. - # The activation outputs will be multiplied element-wise to produce the inputs for linear2. - # See the implementation in _linear1_activation(). - # This supports the gated linear activations proposed by Shazeer in - # https://arxiv.org/abs/2002.05202. - activation: Union[str, tuple[str, str]] = "nn.relu" - - # The dropout layer config. - dropout: InstantiableConfig = Dropout.default_config() - - # The stochastic depth layer config. - # Pytorch reference: - # https://github.com/facebookresearch/deit/blob/main/models_v2.py#L59 - # Tensorflow reference: - # https://github.com/tensorflow/models/blob/master/official/projects/vit/modeling/nn_blocks.py#L103-L119 - stochastic_depth: InstantiableConfig = StochasticDepth.default_config() - - # The inner structure of the layer: "prenorm", "postnorm", "hybridnorm", "nonorm". - # * prenorm: y = x + feedforward(norm(x)) - # * postnorm: y = norm(x + feedforward(x)) - # * hybridnorm: y = postnorm(x + feedforward(prenorm(x))) - # * nonorm: y = feedforward(x) # no residual, which is usually applied externally. - # - # References: - # prenorm/postnorm: https://arxiv.org/abs/2002.04745. - # hybridnorm: https://github.com/google/praxis/blob/main/praxis/layers/transformers.py#L273 - # nonorm: see ParallelTransformerLayer. - structure: str = "prenorm" - - # outputs = inputs + residual_weight * x. - residual_weight: float = 1.0 - - # Auxiliary stats. - - # If True, add "dead_neurons/{activation}" stats for activation functions that have - # zones of near-zero gradients, e.g., x < 0 for ReLU. - # - # A "neuron" `i` is considered dead if all of x[..., i] (across batch/seq) fall within the - # dead zone. - # - # Only supported for a subset of activation functions, including relu, gelu, and silu. - add_dead_neuron_summary: Optional[bool] = None - - # Adds summary of RMS norms of the specified values. Supported value are: - # - "inputs": inputs of the layer. - # - "linear1_outputs": outputs of linear1. - # - "linear2_outputs": outputs of linear2. - # TODO(tlei3): deprecate this feature since we use TensorStats. - add_value_rms_norm_summary: Sequence[str] = [] + param_specs = layer.create_parameter_specs_recursively() + layer_params = jax.tree.map( + lambda spec: jnp.ones(spec.shape, dtype=spec.dtype), param_specs + ) - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg: TransformerFeedForwardLayer.Config = self.config - if cfg.structure in ["prenorm", "postnorm"]: - self._add_child("norm", cfg.norm.set(input_dim=cfg.input_dim)) - elif cfg.structure == "hybridnorm": - self._add_child("prenorm", cfg.norm.set(input_dim=cfg.input_dim)) - self._add_child("postnorm", cfg.norm.set(input_dim=cfg.input_dim)) - elif cfg.structure == "nonorm": - pass - else: - raise NotImplementedError(cfg.structure) + batch_size = 3 + tgt_len = 10 # Must be even. 
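+        # The query below is half all-ones tokens followed by half all-zeros tokens. With all-ones
+        # parameters, logits between two all-ones tokens are large while logits involving an
+        # all-zeros token are zero, so softmax over (tgt_len // 2) large logits and (tgt_len // 2)
+        # zero logits reduces to sigmoid(logit) / (tgt_len // 2). The expected probabilities in
+        # the tests below rely on this closed form.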
+ query = jnp.concatenate( + ( + jnp.ones([batch_size, tgt_len // 2, model_dim]), + jnp.zeros([batch_size, tgt_len // 2, model_dim]), + ), + axis=1, + ) + kwargs = dict( + module=layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict(query=query), + ) + return kwargs - if isinstance(cfg.hidden_dim, int): - hidden_dim = cfg.hidden_dim - else: - hidden_dim = cfg.hidden_dim.set(input_dim=cfg.input_dim).instantiate() - if isinstance(cfg.activation, tuple): - assert len(cfg.activation) == 2, cfg.activation - # Create a linear1 projection for each activation. - for i in range(len(cfg.activation)): - self._add_child( - f"linear1_{i}", - cfg.linear1.set(input_dim=cfg.input_dim, output_dim=hidden_dim), - ) - else: - assert isinstance(cfg.activation, str), cfg.activation - self._add_child( - "linear1", - cfg.linear1.set(input_dim=cfg.input_dim, output_dim=hidden_dim), + @parameterized.product(query_scale_factor=[None, 7], key_scale_factor=[None, 11]) + def test_scale_query_key( + self, *, query_scale_factor: Optional[float], key_scale_factor: Optional[float] + ): + kwargs = self._scale_query_kwargs( + query_scale_factor=query_scale_factor, key_scale_factor=key_scale_factor + ) + kwargs["inputs"]["return_aux"] = {"probs"} + forward_outputs, _ = F(**kwargs) + if query_scale_factor is None: + query_scale_factor = kwargs["module"].per_head_dim() ** -0.5 + if key_scale_factor is None: + key_scale_factor = 1 + query_scale_factor = float(query_scale_factor) + key_scale_factor = float(key_scale_factor) + self.assertNestedAllClose( + forward_outputs.probs[0, 0, 0, 0], + # All ones matrix times all ones vector has l2 norm dim ** 1.5. + # Half of input tokens are all ones, half are all zeros. + jax.nn.sigmoid( + kwargs["inputs"]["query"].shape[-1] ** 3 * query_scale_factor * key_scale_factor, ) - self._add_child( - "linear2", - cfg.linear2.set(input_dim=hidden_dim, output_dim=cfg.input_dim), - ) - if cfg.structure in ["prenorm", "hybridnorm", "nonorm"]: - self._add_child("dropout1", cfg.dropout) - self._add_child("dropout2", cfg.dropout) - elif cfg.structure in ["postnorm"]: - self._add_child("dropout", cfg.dropout) - else: - raise NotImplementedError(cfg.structure) + / (kwargs["inputs"]["query"].shape[1] // 2), + ) - self._add_child("stochastic_depth", cfg.stochastic_depth) - # TODO(tlei3): deprecate this check since we will use TensorStats to handle what - # tensors are logged. - for value in cfg.add_value_rms_norm_summary: - if value not in ["inputs", "linear1_outputs", "linear2_outputs"]: - raise NotImplementedError(f"add_value_rms_norm_summary: {value}") + def test_scale_query_key_dim_dependence(self): + query_scale_factor = config_for_function(attention.pow_scale_fn).set(exp=1) + key_scale_factor = config_for_function(attention.pow_scale_fn).set(exp=-1) + kwargs = self._scale_query_kwargs( + query_scale_factor=query_scale_factor, key_scale_factor=key_scale_factor + ) + kwargs["inputs"]["return_aux"] = {"probs"} + forward_outputs, _ = F(**kwargs) + self.assertNestedAllClose( + forward_outputs.probs[0, 0, 0, 0], + # All ones matrix times all ones vector has l2 norm dim ** 1.5. + # Half of input tokens are all ones, half are all zeros. + jax.nn.sigmoid(float(kwargs["inputs"]["query"].shape[-1] ** 3)) + / (kwargs["inputs"]["query"].shape[1] // 2), + ) - def forward(self, inputs: Tensor) -> Tensor: - cfg = self.config + def test_scale_query_key_barrier(self): + """Tests that the scale factors are not combined. 
- def _linear2(x): - """Applies linear2, optionally logging RMS norm of the output.""" - x = self.linear2(x) - self._add_tensor_stats("linear2_outputs", x) - return x - - self._add_tensor_stats("inputs", inputs) - - remat_pt2 = "linear2" - if cfg.structure == "prenorm": - x = self.norm(inputs) - x = self._linear1_activation(x) - x = self.dropout1(x) - x = _linear2(x) - x = self._remat_name(x, remat_pt2) - x = self.dropout2(x) - x = self.stochastic_depth(x) - if cfg.residual_weight != 1: - x *= cfg.residual_weight - x += inputs - elif cfg.structure == "postnorm": - x = self._linear1_activation(inputs) - x = _linear2(x) - x = self._remat_name(x, remat_pt2) - x = self.dropout(x) - x = self.stochastic_depth(x) - if cfg.residual_weight != 1: - x *= cfg.residual_weight - x = self.norm(x + inputs) - elif cfg.structure == "hybridnorm": - x = self.prenorm(inputs) - x = self._linear1_activation(x) - x = self.dropout1(x) - x = _linear2(x) - x = self._remat_name(x, remat_pt2) - x = self.postnorm(x) - x = self.dropout2(x) - x = self.stochastic_depth(x) - if cfg.residual_weight != 1: - x *= cfg.residual_weight - x += inputs - elif cfg.structure == "nonorm": - x = inputs - x = self._linear1_activation(x) - x = self.dropout1(x) - x = _linear2(x) - x = self._remat_name(x, remat_pt2) - x = self.dropout2(x) - x = self.stochastic_depth(x) - # We still apply `residual_weight`, since there is usually a residual link outside of - # this layer, e.g., in ParallelTransformerLayer. - if cfg.residual_weight != 1: - x *= cfg.residual_weight - else: - raise NotImplementedError(cfg.structure) - return x + Note that even without the barrier, it's not clear that they would be combined. + (They aren't on CPU even without the barrier.) + """ + query_scale_factor = 7 + key_scale_factor = 11 + kwargs = self._scale_query_kwargs( + query_scale_factor=query_scale_factor, key_scale_factor=key_scale_factor + ) - def _linear1_activation(self, x: Tensor) -> Tensor: - cfg = self.config - if isinstance(cfg.activation, tuple): - activations = [ - self._get_activation( - self._remat_name(self.children[f"linear1_{i}"](x), f"linear1_{i}"), - activation_fn_name=activation, - ) - for i, activation in enumerate(cfg.activation) - ] - assert len(activations) == 2, cfg.activation - outputs = activations[0] * activations[1] - self._add_tensor_stats("linear1_0_outputs", activations[0]) - self._add_tensor_stats("linear1_1_outputs", activations[1]) - self._add_tensor_stats("linear1_outputs", outputs) - return outputs - else: - x = self.linear1(x) - x = self._remat_name(x, "linear1_0") - x = self._get_activation(x, activation_fn_name=cfg.activation) - self._add_tensor_stats("linear1_outputs", x) - return x + # Check optimized HLO scales by query_scale_factor and key_scale_factor as separate + # multiplications. This only checks the default backend, so it doesn't check + # what happens on gpu/tpu unless jax is configured to use them. 
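+        # The fast-math options below give the XLA CPU backend the freedom to fold the two
+        # constant scale factors into a single multiplication; the assertions then verify that
+        # each factor still appears on its own in the optimized HLO and that their product does
+        # not, i.e. that the two scale applications were not combined.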
+ f = jax.jit(F, static_argnames=("module", "is_training")) + compile_options = dict( + xla_cpu_enable_fast_math=True, + xla_cpu_fast_math_honor_nans=False, + xla_cpu_fast_math_honor_infs=False, + xla_cpu_fast_math_honor_functions=False, + xla_cpu_fast_math_honor_division=False, + ) + hlo = f.lower(**kwargs).compile(compile_options).as_text() + hlo = test_utils.clean_hlo(hlo) + self.assertIn(str(query_scale_factor), hlo) + self.assertIn(str(key_scale_factor), hlo) + self.assertNotIn(str(query_scale_factor * key_scale_factor), hlo) + + @parameterized.parameters( + [ + ( + 1.0, + jax.nn.sigmoid((1.0 * 1.0) * 2 - jnp.log(6)), + 6, + ), + ( + 1.0, + jax.nn.sigmoid((1.0 * 1.0) * 2 - jnp.log(4)), + 4, + ), + ( + 2.0, + jax.nn.sigmoid((2.0 * 2.0) * 2 - jnp.log(6)), + 6, + ), + ] + ) + def test_sigmoid_compute_attention(self, qkv_value: float, expected_value: float, seq_len: int): + model_dim = 16 + num_heads = 4 + batch_size = 2 + init_key = jax.random.PRNGKey(123) + + cfg = attention.SigmoidAttention.default_config().set( + seq_len=seq_len, + query_dim=model_dim, + key_dim=model_dim, + value_dim=model_dim, + num_heads=num_heads, + query_scale=attention.ScaleQuery.default_config(), + atten_logit_cap=0.0, + dtype=jnp.float32, + ) + sigmoid_attention = cfg.set(name="sigmoid_attention").instantiate(parent=None) + state = sigmoid_attention.initialize_parameters_recursively(prng_key=init_key) + + qkv_shape = [batch_size, seq_len, num_heads, num_heads] + inputs = dict( + q_proj=jnp.full(qkv_shape, fill_value=qkv_value), + k_proj=jnp.full(qkv_shape, fill_value=qkv_value), + v_proj=jnp.full(qkv_shape, fill_value=qkv_value), + attention_logit_biases=attention_bias.CausalAttentionBias(shape=(seq_len, seq_len)), + ) - def _get_activation(self, x: Tensor, activation_fn_name: str) -> Tensor: - """Applies activation function on 'x' and optionally counts the number of dead neurons. + # Get outputs. + forward_key = jax.random.PRNGKey(456) - Args: - x: A tensor of shape [B, S, H]. - activation_fn_name: The name of the activation fn. + (_, probs), _ = F( + sigmoid_attention, + method="_compute_attention", + state=state, + is_training=False, + prng_key=forward_key, + inputs=inputs, + ) - Returns: - activation_fn(x). - """ - cfg = self.config - if cfg.add_dead_neuron_summary: - if activation_fn_name in ["quick_gelu", "exact_gelu"]: - # To make GELU be sufficiently small. - threshold = -4.0 - elif activation_fn_name in ["nn.silu", "nn.sigmoid"]: - # nn.silu(jnp.array(-10.)) = -0.00045398 - # nn.sigmoid(jnp.array(-10.)) = 4.5397872e-05 - threshold = -10.0 - elif activation_fn_name in ["nn.relu", "squared_relu"]: - threshold = 0 - else: - threshold = None - if threshold is not None: - max_hidden_units = jnp.max(x, axis=(0, 1)) - num_dead_units = jnp.count_nonzero( - jnp.less(max_hidden_units, threshold).astype(jnp.int32) - ) - self.add_summary( - f"dead_neurons/{activation_fn_name}", - num_dead_units, - ) - return get_activation_fn(activation_fn_name)(x) + output_shape = [batch_size, num_heads, seq_len, seq_len] + indexes = jnp.arange(seq_len) + # Zeros outside of the causal triangle. + causal_biases = jax.lax.ge(indexes[:, None], indexes[None, :]) + expected_output = jnp.full(output_shape, fill_value=expected_value) * causal_biases + self.assertNestedAllClose(probs, expected_output) -class TransformerLayer(BaseTransformerLayer): - """A Transformer layer. 
- Unlike torch.nn.TransformerLayer, this allows components to be customized, e.g., replacing - vanilla attention with relative positional attention from TransformerXL/DeBERTa or replacing - feed-forward with a mixture-of-expert feed-forward layer. - """ +def oracle_xl_attention_logits( + query: np.ndarray, + key: np.ndarray, + relative_pos_emb: np.ndarray, + content_bias: np.ndarray, + positional_bias: np.ndarray, +) -> np.ndarray: + """Computes expected attention logits using non-vectorized approach. - @config_class - class Config(BaseTransformerLayer.Config): - """Configures TransformerLayer.""" + Reference: + https://github.com/tensorflow/lingvo/blob/41212226eac7a26491790c2bd476b78493f93ff6/lingvo/core/attention_util_test.py#L48-L73. - self_attention: InstantiableConfig = TransformerAttentionLayer.default_config() - # If not None, the cross-attention layer config. - cross_attention: Optional[InstantiableConfig] = None - feed_forward: InstantiableConfig = TransformerFeedForwardLayer.default_config() + Note that this implementation follows XLNet implementation and is different from the lingvo + implementation in that here the relative_pos_emb index is computed from key_i - query_i, + while lingvo computes from query_i - key_i. - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg: TransformerLayer.Config = self.config - self._add_child( - "self_attention", - cfg.self_attention.set(target_dim=cfg.input_dim, source_dim=cfg.input_dim), + See comments on xl_attention_logits(). + """ + batch, seqlen, num_heads, _ = query.shape + tgtlen, srclen = seqlen, seqlen + + logits = np.zeros((batch, num_heads, tgtlen, srclen)) + + for b in range(batch): + for n in range(num_heads): + for i in range(tgtlen): + for j in range(srclen): + offset = seqlen - 1 + pos_emb = relative_pos_emb[j - i + offset] + logits[b][n][i][j] = np.dot(query[b][i][n], key[b][j][n]) + logits[b][n][i][j] += np.dot(query[b][i][n], pos_emb[n]) + logits[b][n][i][j] += np.dot(content_bias[n], key[b][j][n]) + logits[b][n][i][j] += np.dot(positional_bias[n], pos_emb[n]) + return logits + + +class TransformerXLTest(TestCase): + """Tests TransformerXL.""" + + @parameterized.parameters(5, 2, 1) + def test_rel_pos_to_abs_pos(self, seq_len): + # rel_offset[:, i] = i - (seq_len - 1), i.e., in range [-seq_len + 1, seq_len - 1]. + rel_offset = jnp.tile(jnp.arange(-seq_len + 1, seq_len)[None, :], [seq_len, 1]) + # abs_pos[i, j] = j - i. 
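+        # For example, with seq_len = 3, every row of rel_offset is [-2, -1, 0, 1, 2] and the
+        # expected absolute-position matrix is
+        #   [[ 0,  1,  2],
+        #    [-1,  0,  1],
+        #    [-2, -1,  0]],
+        # i.e. entry [i, j] equals j - i.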
+ abs_pos = rel_pos_to_abs_pos(rel_offset) + expected = jnp.arange(seq_len)[None, :] - jnp.arange(seq_len)[:, None] + assert_allclose(abs_pos, expected) + + def test_xl_attention_logits(self): + num_heads, per_head_dim = 4, 3 + batch_size, tgt_len = 2, 5 + q = jax.random.normal( + jax.random.PRNGKey(100), + [batch_size, tgt_len, num_heads, per_head_dim], + dtype=jnp.float32, + ) + k = jax.random.normal( + jax.random.PRNGKey(101), + [batch_size, tgt_len, num_heads, per_head_dim], + dtype=jnp.float32, + ) + relative_pos_emb = jax.random.normal( + jax.random.PRNGKey(102), [2 * tgt_len - 1, num_heads, per_head_dim], dtype=jnp.float32 + ) + u = jax.random.normal(jax.random.PRNGKey(103), [num_heads, per_head_dim], dtype=jnp.float32) + v = jax.random.normal(jax.random.PRNGKey(104), [num_heads, per_head_dim], dtype=jnp.float32) + expected = oracle_xl_attention_logits( + query=q, key=k, relative_pos_emb=relative_pos_emb, content_bias=u, positional_bias=v + ) + actual = xl_attention_logits( + q_proj=q, k_proj=k, relative_pos_emb=relative_pos_emb, u=u, v=v + ) + assert_allclose(actual, expected) + + @parameterized.product( + per_dim_scale=(None, PerDimScale.default_config()), + scale_position=( + MultiheadAttentionXL.ScalePosition.LOGIT, + MultiheadAttentionXL.ScalePosition.QUERY, + ), + ) + def test_per_dim_scale(self, per_dim_scale, scale_position): + model_dim = 6 + num_heads = 2 + cfg = attention.TransformerAttentionLayer.default_config().set( + name="test", + target_dim=model_dim, + source_dim=model_dim, + structure="postnorm", + attention=MultiheadAttentionXL.default_config().set( + num_heads=num_heads, + query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + scale_position=scale_position, + ), + ) + cfg.attention.output_linear.bias = False + cfg.attention.vlog = 5 + + layer: attention.TransformerAttentionLayer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(jax.random.PRNGKey(123)) + batch_size, tgt_len = 2, 5 + target = jax.random.normal( + jax.random.PRNGKey(100), [batch_size, tgt_len, model_dim], dtype=jnp.float32 ) - self._add_child("feed_forward", cfg.feed_forward.set(input_dim=cfg.input_dim)) - if cfg.cross_attention is not None: - self._add_child("cross_attention", cfg.cross_attention.set(target_dim=cfg.input_dim)) - def _forward_for_mode( - self, - *, - mode: ForwardMode, - data: Union[Tensor, TensorSpec], - self_attention_kv_state: Optional[KVState] = None, - self_attention_logit_biases: Optional[Tensor] = None, - cross_attention_data: Optional[Tensor] = None, - cross_attention_logit_biases: Optional[Tensor] = None, - target_segment_ids: Optional[Tensor] = None, - target_positions: Optional[Tensor] = None, - cached_states: Optional[NestedTensor] = None, - return_aux: Optional[set[str]] = None, - ) -> tuple[Optional[NestedTensor], Optional[BaseTransformerLayer.Output]]: - """Computes transformer layer outputs and self/cross-attention probabilities. - - Args: - mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for - details. - data: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. - self_attention_kv_state: An optional KVState used for self-attention. - self_attention_logit_biases: An optional Tensor representing the self-attention biases. - cross_attention_data: An optional Tensor of shape [batch, source_length, source_dim]. - cross_attention_logit_biases: An optional Tensor representing the cross-attention - biases. 
- target_segment_ids: See ``segment_ids`` in the file comments. - target_positions: See ``positions`` in the file comments. - cached_states: Optional NestedTensor as produced by `init_states`. - return_aux: See comments on BaseTransformerLayer.forward. - - Returns: - A tuple (cached_states, output): - * cached_states: An optional Nested Tensor of cache states, depending on `mode`. - * output: An optional Output instance, where .data is of the same shape as `data`, - .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; - .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. - If initializing cache from scratch, output will be None. - - Raises: - ValueError: If `mode` is unsupported. - """ - if isinstance(data, Tensor): - self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error - self_attention_return_aux = set() - cross_attention_return_aux = set() - if return_aux: - if "self_attention_probs" in return_aux: - self_attention_return_aux.add("probs") - if "self_attention_kv_state" in return_aux: - self_attention_return_aux.add("kv_state") - if "cross_attention_probs" in return_aux: - cross_attention_return_aux.add("probs") - if mode == ForwardMode.FORWARD: - self_atten_state, self_atten_outputs = ( - None, - self.self_attention( - target=data, - segment_ids=target_segment_ids, - target_positions=target_positions, - source=self_attention_kv_state, - attention_logit_biases=self_attention_logit_biases, - return_aux=self_attention_return_aux, - ), - ) - elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - if target_segment_ids is not None: - raise NotImplementedError("target_segment_ids is not supported in INIT_STATES.") - if target_positions is not None: - raise NotImplementedError("target_positions is not supported in INIT_STATES.") - self_atten_state, self_atten_outputs = self.self_attention.init_states( - time_step=cached_states["self_attention"], - target=data, - source=self_attention_kv_state, - attention_logit_biases=self_attention_logit_biases, - return_aux=self_attention_return_aux, - ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - if target_segment_ids is not None: - raise NotImplementedError("target_segment_ids is not supported in EXTEND_STEP.") - if target_positions is not None: - raise NotImplementedError("target_positions is not supported in EXTEND_STEP.") - self_atten_state, self_atten_outputs = self.self_attention.extend_step( - cached_states=cached_states["self_attention"], - target=data, - source=self_attention_kv_state, - attention_logit_biases=self_attention_logit_biases, - return_aux=self_attention_return_aux, - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") - - if self_atten_outputs is None: - assert mode == ForwardMode.INIT_STATES - return dict(self_attention=self_atten_state), self_atten_outputs - - data = self_atten_outputs.data - self.vlog(3, "self_attention.output=%s", data.sum()) - if cross_attention_data is not None: - cross_atten_outputs = self.cross_attention( - target=data, - source=cross_attention_data, - attention_logit_biases=cross_attention_logit_biases, - return_aux=cross_attention_return_aux, + layer_params["attention"]["u_bias"] = jax.random.normal( + jax.random.PRNGKey(0), [num_heads, model_dim // num_heads] + ) + layer_params["attention"]["v_bias"] = jax.random.normal( + jax.random.PRNGKey(1), [num_heads, model_dim // num_heads] + ) + if per_dim_scale: + 
layer_params["attention"]["scale_query"]["per_dim_scale"]["param"] = jax.random.normal( + jax.random.PRNGKey(2), [model_dim // num_heads] ) - data = cross_atten_outputs.data - cross_attention_probs = cross_atten_outputs.probs - else: - cross_attention_probs = None - data = self.feed_forward(data) - self.vlog(3, "transformer.output=%s", data.sum()) - # TODO(markblee): Support module outputs in decoding. - if mode == ForwardMode.FORWARD: - self.add_module_output("output", data) - return dict(self_attention=self_atten_state), BaseTransformerLayer.Output( - data=data, - self_attention_probs=self_atten_outputs.probs, - self_attention_kv_state=self_atten_outputs.kv_state, - cross_attention_probs=cross_attention_probs, + layer_outputs, _ = F( + layer, + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(456), + inputs=dict(target=target), ) - - def forward( - self, - data: Tensor, - **kwargs, - ) -> BaseTransformerLayer.Output: - _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, data=data, cached_states=None, **kwargs + expected_vals = { + str(None): { + MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.683887, + MultiheadAttentionXL.ScalePosition.QUERY.value: 48.598305, + }, + str(PerDimScale.default_config()): { + MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.790010, + MultiheadAttentionXL.ScalePosition.QUERY.value: 48.858986, + }, + } + assert_allclose( + expected_vals[str(per_dim_scale)][scale_position.value], + jnp.abs(layer_outputs.data).sum(), ) - return output - def init_states( - self, - time_step: Optional[Tensor], - data: Union[Tensor, TensorSpec], - **kwargs, - ) -> tuple[Nested[Tensor], Optional[BaseTransformerLayer.Output]]: - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - cached_states=dict(self_attention=time_step), - data=data, - **kwargs, - ) - - def extend_step( - self, - cached_states: NestedTensor, - data: Tensor, - **kwargs, - ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: - return self._forward_for_mode( # pytype:disable=bad-return-type - mode=ForwardMode.EXTEND_STEP, - cached_states=cached_states, - data=data, - **kwargs, + def test_multihead_attention_xl(self): + model_dim = 6 + num_heads = 2 + per_head_dim = model_dim // num_heads + cfg = attention.TransformerAttentionLayer.default_config().set( + name="test", + target_dim=model_dim, + source_dim=model_dim, + structure="postnorm", + attention=MultiheadAttentionXL.default_config().set(num_heads=num_heads), + ) + cfg.attention.output_linear.bias = False + cfg.attention.vlog = 5 + layer: attention.TransformerAttentionLayer = cfg.instantiate(parent=None) + layer.initialize_parameters_recursively(jax.random.PRNGKey(123)) + ref_cfg = hf_xlnet.XLNetConfig( + n_head=num_heads, + d_model=model_dim, + d_head=model_dim // num_heads, + dropout=0, + layer_norm_eps=cfg.norm.eps, + ) + ref = hf_xlnet.XLNetRelativeAttention(ref_cfg) + # XLNetRelativeAttention is not properly initialized. 
+ with torch.no_grad(): + for var in ("q", "k", "v", "o", "r"): + getattr(ref, var).copy_( + torch.normal(0, np.sqrt(model_dim), [model_dim, num_heads, per_head_dim]) + ) + for var in ("r_w_bias", "r_r_bias"): + getattr(ref, var).copy_( + torch.normal(0, np.sqrt(model_dim), [num_heads, model_dim // num_heads]) + ) + batch_size, tgt_len = 2, 5 + target = jax.random.normal( + jax.random.PRNGKey(100), [batch_size, tgt_len, model_dim], dtype=jnp.float32 + ) + num_tokens = jax.random.randint( + jax.random.PRNGKey(101), + minval=2, + maxval=tgt_len + 1, + shape=[batch_size], + ) + # [batch_size, tgt_len]. + is_valid_token = jnp.arange(tgt_len)[None, :] < num_tokens[:, None] + # [batch_size, 1, tgt_len, tgt_len]. + attention_logit_biases = jnp.expand_dims( + NEG_INF * (1 - jnp.einsum("bt,bs->bts", is_valid_token, is_valid_token)), 1 + ) + # [2 * tgt_len, model_dim]. + rel_pos_emb = sinusoidal_positional_embeddings( + jnp.arange(tgt_len, -tgt_len, -1), dim=model_dim + ) + ref_inputs = dict( + g=None, + h=target.transpose([1, 0, 2]), # [qlen, bsz, d_model]. + r=rel_pos_emb[:, None, :], # [rlen, 1, d_model]. + attn_mask_g=None, + # [qlen, klen, bsz, n_head]. + attn_mask_h=attention_logit_biases.transpose([2, 3, 0, 1]) < 0, + seg_mat=None, + ) + logging.info("ref_inputs=%s", ref_inputs) + + test_outputs, ref_outputs = self._compute_layer_outputs( + test_layer=layer, + ref_layer=ref, + test_inputs=dict(target=target, attention_logit_biases=attention_logit_biases), + ref_inputs=as_torch_tensor(ref_inputs), + parameters_from_ref_layer=parameters_from_torch_layer, + require_same_num_params=False, + ) + logging.info("test_outputs=%s", test_outputs) + logging.info("ref_outputs=%s", ref_outputs) + self.assertNestedAllClose( + test_outputs.data, as_tensor(ref_outputs[0]).transpose([1, 0, 2]), atol=6e-6 ) -class ParallelTransformerLayer(BaseTransformerLayer): - """A Transformer layer with parallel self-attention and feed-forward layers: - - x = norm(inputs) - outputs = inputs + self_atten(x) + ffn(x) +class TransformerAttentionLayerTest(TestCase): + @parameterized.parameters([False, True]) + def test_forward_vs_extend_step(self, with_source: bool): + init_prng, target_prng, source_prng = jax.random.split(jax.random.PRNGKey(0), 3) - TODO(rpang): experiment to understand whether we should use separate normalization layers - for self_atten and ffn as in PaLM. + model_dim = 8 + layer_kwargs = dict(target_dim=model_dim, source_dim=model_dim) + cfg: TransformerAttentionLayer.Config = TransformerAttentionLayer.default_config().set( + **layer_kwargs + ) + cfg.attention.set(num_heads=2, mask=causal_mask) + layer: TransformerAttentionLayer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=init_prng) - References: - https://github.com/kingoflolz/mesh-transformer-jax - PaLM: https://arxiv.org/abs/2204.02311 - """ + batch, decode_len = 2, 6 + target = jax.random.uniform(target_prng, shape=[batch, decode_len, model_dim]) + input_kwargs = {} - @config_class - class Config(BaseTransformerLayer.Config): - norm: InstantiableConfig = LayerNorm.default_config() # The normalization layer config. 
- self_attention: MultiheadAttention.Config = MultiheadAttention.default_config() - feed_forward: TransformerFeedForwardLayer.Config = ( - TransformerFeedForwardLayer.default_config().set(structure="nonorm") - ) + if with_source: + input_kwargs.update( + source=jax.random.uniform(source_prng, shape=[batch, decode_len, model_dim]) + ) - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg: TransformerLayer.Config = self.config - self._add_child("norm", cfg.norm.set(input_dim=cfg.input_dim)) - self._add_child( - "self_attention", - cfg.self_attention.set( - query_dim=cfg.input_dim, - key_dim=cfg.input_dim, - value_dim=cfg.input_dim, - output_dim=cfg.input_dim, - ), + forward_outputs, _ = F( + layer, + inputs=dict(target=jnp.asarray(target), **input_kwargs), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), ) - self._add_child("feed_forward", cfg.feed_forward.set(input_dim=cfg.input_dim)) - def forward( - self, - *, - data: Tensor, - self_attention_logit_biases: Optional[Tensor] = None, - target_segment_ids: Optional[Tensor] = None, - ) -> BaseTransformerLayer.Output: - """Computes transformer layer outputs and self/cross-attention probabilities. + for start_time_step in (-1, 0, 2, decode_len): + if start_time_step < 0: + (cached_states, init_outputs), _ = F( + layer, + inputs=dict( + time_step=None, + target=TensorSpec(target.shape, target.dtype), + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="init_states", + ) + self.assertIsNone(init_outputs) + data = jnp.zeros([batch, decode_len, model_dim]) + start_time_step = 0 + else: + (cached_states, prefill_outputs), _ = F( + layer, + inputs=dict( + time_step=jnp.array([start_time_step] * batch, dtype=jnp.int32), + target=target, + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="init_states", + ) + data = prefill_outputs.data - Args: - data: A Tensor of shape [batch, target_length, target_dim]. - self_attention_logit_biases: An optional Tensor representing the self-attention biases. - target_segment_ids: See ``segment_ids`` in the file comments. + data = jnp.einsum("btd->tbd", data) - Returns: - An Output instance, where .data is of the same shape as `data`, .self_attention_probs is - of shape [batch, num_heads, target_length, target_length]. + for time_step in range(start_time_step, decode_len): + extend_kwargs = {} + for k, v in input_kwargs.items(): + extend_kwargs[k] = jnp.asarray(v[:, time_step : time_step + 1, :]) - Raises: - ValueError: If `mode` is unsupported. 
- """ - inputs = data - data = self.norm(data) - self_atten_outputs = self.self_attention( - query=data, - key=data, - value=data, - attention_logit_biases=self_attention_logit_biases, - segment_ids=target_segment_ids, - ) - feed_forward_outputs = self.feed_forward(data) - outputs = inputs + self_atten_outputs.data + feed_forward_outputs - return BaseTransformerLayer.Output( - data=outputs, - self_attention_probs=self_atten_outputs.probs, - self_attention_kv_state=self_atten_outputs.kv_state, - cross_attention_probs=None, - ) + (cached_states, extend_outputs), _ = F( + layer, + inputs=dict( + target=jnp.asarray(target[:, time_step : time_step + 1, :]), + cached_states=cached_states, + **extend_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="extend_step", + ) + data = data.at[time_step].set(jnp.squeeze(extend_outputs.data, axis=1)) + data = jnp.einsum("tbd->btd", data) -def _next_power_of_two(n: float) -> int: - if n <= 1: - return 2 - return 1 << int(math.log2(n - 1)) + 1 + # Prefill + extend_step == forward. + assert_allclose(forward_outputs.data, data) -class BottleNeckAdapterTransformerLayer(BaseTransformerLayer): - """TransformerLayer with bottleneck adaptor for fine-tuning. - Figure 3(a) in https://arxiv.org/pdf/2110.04366.pdf - """ +class TransformerFeedForwardLayerTest(TestCase): + @parameterized.parameters( + dict(rms_norm_summary=[]), + dict(rms_norm_summary=["linear2_outputs"]), + dict(rms_norm_summary=["final_outputs"], expected_raise_regex="add_value_rms_norm_summary"), + ) + def test_add_value_rms_norm_summary( + self, rms_norm_summary: list[str], *, expected_raise_regex=None + ): + batch, seq_len, dim = 2, 3, 4 + cfg = TransformerFeedForwardLayer.default_config().set( + name="ffn", + input_dim=dim, + hidden_dim=dim * 4, + add_value_rms_norm_summary=rms_norm_summary, + tensor_stats=DefaultTensorStats.default_config(), + ) + if expected_raise_regex is not None: + with self.assertRaisesRegex(NotImplementedError, expected_raise_regex): + layer = cfg.instantiate(parent=None) + return + layer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim]) + y, output_collection = F( + layer, + inputs=dict(inputs=x), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + self.assertSequenceEqual(x.shape, y.shape) + self.assertNestedAllClose(2.663487, jnp.sum(y)) + if "tensor_stats" in output_collection.summaries: + output_stats = output_collection.summaries["tensor_stats"] + else: + output_stats = {} + for k in rms_norm_summary: + assert k in output_stats + + @parameterized.parameters( + dict(activation_fn="nn.relu"), + dict(activation_fn=("nn.relu", "linear")), + dict(activation_fn=("linear", "quick_gelu")), + dict(activation_fn=("linear", "exact_gelu")), + dict(activation_fn=("linear", "nn.silu")), + ) + def test_add_dead_neuron_summary(self, activation_fn: Union[str, list[str]]): + batch, seq_len, dim = 2, 3, 4 + cfg = TransformerFeedForwardLayer.default_config().set( + name="ffn", + input_dim=dim, + hidden_dim=dim * 4, + activation=activation_fn, + add_dead_neuron_summary=True, + ) + layer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim]) + y, output_collection = F( + layer, + inputs=dict(inputs=x), + state=layer_params, + 
is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + self.assertSequenceEqual(x.shape, y.shape) + if isinstance(activation_fn, str): + activation_fn = [activation_fn] + self.assertSetEqual( + {k for k in output_collection.summaries.keys() if k.startswith("dead_neurons/")}, + { + f"dead_neurons/{k}" + for k in activation_fn + if k in ("nn.relu", "quick_gelu", "exact_gelu", "nn.silu") + }, + ) - @config_class - class Config(BaseTransformerLayer.Config): - """Configures BottleNeckAdapterTransformerLayer.""" + def test_linear_remat(self): + batch, seq_len, dim = 2, 3, 4 + cfg = TransformerFeedForwardLayer.default_config().set( + name="ffn", + input_dim=dim, + hidden_dim=dim * 4, + add_value_rms_norm_summary=[], + tensor_stats=DefaultTensorStats.default_config(), + activation=("nn.relu", "nn.relu"), + ) + layer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim]) + + def f(x, layer_params): + y, _ = F( + layer, + inputs=dict(inputs=x), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + return y + + _, save_name_backward = jax.linearize( + jax.remat( + f, + policy=save_and_offload_only_these_names_regex( + names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value, + names_which_can_be_offloaded=None, + offload_src="device", + offload_dst="pinned_host", + ), + ), + x, + layer_params, + ) + _, save_dots_backward = jax.linearize( + jax.remat(f, policy=jax_remat_policies.dots_saveable), x, layer_params + ) - # The transformer layer to which an adapter will be added. - layer: BaseTransformerLayer.Config = TransformerLayer.default_config() + self.assertEqual(str(save_name_backward).count(" dot_general"), 6) + self.assertEqual( + str(save_name_backward).count(" dot_general"), + str(save_dots_backward).count(" dot_general"), + ) - # The adapter, which in this case is a bottleneck layer composed of - # a downward and an upward projection. - adapter: TransformerFeedForwardLayer.Config = TransformerFeedForwardLayer.default_config() - # The ratio by which the input dimension will be - # reduced in the downward projection in the adapter. 
- bottleneck_ratio: float = 0.5 +class BaseTransformerTest(TestCase): + def _test_decoder_with_transformer(self, transformer_cfg: BaseTransformerLayer.Config): + prefix_length = jnp.asarray([0, 2]) + batch_size, num_decodes, seq_len, vocab_size = prefix_length.shape[0], 3, 7, 6 + bos_id = eos_id = 1 + pad_token_id = 0 - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - self._add_child("layer", cfg.layer) - self._add_child( - "adapter", - cfg.adapter.set( - input_dim=cfg.layer.input_dim, - hidden_dim=_next_power_of_two(cfg.layer.input_dim * cfg.bottleneck_ratio), - structure="postnorm", + cfg = Decoder.default_config().set( + transformer=transformer_cfg.clone(name="transformer"), + dim=transformer_cfg.input_dim, + vocab_size=vocab_size, + emb=TransformerTextEmbeddings.default_config().set( + pos_emb=LearnedPositionalEmbedding.default_config().set(shape=(seq_len,)) ), + # output_norm=LayerNorm.default_config().set(eps=layer_norm_epsilon), + # dropout_rate=dropout_rate, + pad_token_id=pad_token_id, + eos_token_id=eos_id, ) - def _forward_for_mode( - self, - *, - mode: ForwardMode, - data: Union[Tensor, TensorSpec], - cached_states: Optional[NestedTensor] = None, - **kwargs, - ) -> tuple[Optional[Nested[Tensor]], Optional[Tensor]]: - """Computes transformer layer outputs and self/cross-attention probabilities. - - Args: - mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for - details. - data: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional NestedTensor as produced by `init_states`. - - Returns: - A tuple (cached_states, output): - * cached_states: An optional NestedTensor of cache states, depending on `mode`. - * output: An Output instance, where .data is of the same shape as `data`; - .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; - .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. - If initializing cache from scratch, output will be None. - - Raises: - ValueError: If `mode` is unsupported. 
- """ - if isinstance(data, Tensor): - self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error - if mode == ForwardMode.FORWARD: - output = self.layer.forward(data=data, **kwargs) - elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - cached_states, output = self.layer.init_states( - time_step=cached_states["layer"], - data=data, - **kwargs, - ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - cached_states, output = self.layer.extend_step( - cached_states=cached_states, - data=data, - **kwargs, - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") - - if output is None: - assert mode == ForwardMode.INIT_STATES and cached_states["layer"] is None - return cached_states, output + decoder: Decoder = cfg.set(name="decoder").instantiate(parent=None) + decoder_state = decoder.initialize_parameters_recursively(jax.random.PRNGKey(0)) - skip_input = output.data - data = self.adapter(output.data) - data += skip_input - self.vlog(3, "adapted_transformer.output=%s", data.sum()) - return cached_states, output._replace(data=data) - - def forward( - self, - data: Tensor, - **kwargs, - ) -> BaseTransformerLayer.Output: - _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, - data=data, - cached_states=None, - **kwargs, + prefix = jax.random.randint( + jax.random.PRNGKey(124), + shape=[batch_size, seq_len], + # Prefix can consist of any tokens, including pad and eos. + minval=0, + maxval=vocab_size, ) - return output + # Explicitly fill positions >= prefix_length with pad_token_id. + # Note that each batch example may have a different prefix length. + # [batch_size, seq_len]. + prefix_mask = jnp.arange(seq_len) < prefix_length[:, None] + prefix = prefix * prefix_mask + pad_token_id * (1 - prefix_mask) + # Set last token to a non-pad token, to fix the prefix length. 
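+        # one_hot(prefix_length - 1, seq_len) marks the last prefix position of each row; the
+        # following lines zero out that position and write bos_id there.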
+ oh_indices = jax.nn.one_hot(prefix_length - 1, seq_len, dtype=prefix.dtype) + prefix = prefix * (1 - oh_indices) + bos_id * oh_indices + inputs = dict( + input_batch=dict(prefix=prefix), + max_sequence_length=seq_len, + # cross_attention_data=None, + # cross_attention_logit_biases=None, + num_decodes=num_decodes, + ) + outputs, _ = F( + decoder, + inputs=inputs, + state=decoder_state, + is_training=False, + prng_key=jax.random.PRNGKey(2), + method="sample_decode", + ) + sequences = outputs.sequences + self.assertEqual(sequences.shape, (batch_size, num_decodes, seq_len)) - def init_states( + def _test_forward_vs_extend_step( self, + cfg: BaseTransformerLayer.Config, *, - time_step: Optional[Tensor], - data: Union[Tensor, TensorSpec], - **kwargs, - ) -> tuple[Nested[Tensor], Optional[BaseTransformerLayer.Output]]: - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - cached_states=dict(layer=time_step), - data=data, - **kwargs, - ) - - def extend_step( - self, - cached_states: NestedTensor, - data: Tensor, - **kwargs, - ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: - return self._forward_for_mode( # pytype: disable=bad-return-type - mode=ForwardMode.EXTEND_STEP, - cached_states=cached_states, - data=data, - **kwargs, + input_kwargs: Optional[dict[str, Any]] = None, + ): + """Tests that {init,prefill}_states + extend_step is equivalent to forward for `cfg`.""" + if input_kwargs is None: + input_kwargs = {} + layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + + batch_size, tgt_len = 2, 5 + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) + + forward_outputs, _ = F( + layer, + inputs=dict( + data=jnp.asarray(target), + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), ) + for start_time_step in (-1, 0, 2, tgt_len): + if start_time_step > tgt_len: + continue + print(f"start_time_step={start_time_step} layer={type(layer)}") + if start_time_step < 0: + (cached_states, init_outputs), _ = F( + layer, + inputs=dict( + time_step=None, + data=TensorSpec([batch_size, tgt_len]), + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="init_states", + ) + self.assertIsNone(init_outputs) + decoder_output = jnp.zeros_like(target) + start_time_step = 0 + else: + (cached_states, prefill_outputs), _ = F( + layer, + inputs=dict( + time_step=jnp.array([start_time_step] * batch_size, dtype=jnp.int32), + data=jnp.asarray(target), + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="init_states", + ) + decoder_output = prefill_outputs.data + # Transpose to [tgt_len, batch_size, model_dim]. + decoder_output = jnp.einsum("bsd->sbd", decoder_output) + for time_step in range(start_time_step, tgt_len): + (cached_states, extend_step_outputs), _ = F( + layer, + inputs=dict( + data=jnp.asarray(target[:, time_step : time_step + 1, :]), + cached_states=cached_states, + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="extend_step", + ) + decoder_output = decoder_output.at[time_step].set( + jnp.squeeze(extend_step_outputs.data, axis=1) + ) + # Transpose to [batch_size, tgt_len, model_dim]. + decoder_output = jnp.einsum("sbd->bsd", decoder_output) + # Prefill + extend_step == forward. 
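+            # start_time_step=-1 decodes every position from an empty cache, tgt_len is pure
+            # prefill with no extend_step calls, and intermediate values prefill a prefix and
+            # decode the remainder; all of them must reproduce the forward() outputs.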
+ assert_allclose(forward_outputs.data, decoder_output) -def set_double_shard_weights_config( - cfg: Union[TransformerLayer.Config, Sequence[TransformerLayer.Config]], - *, - batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"), - fsdp_axis_names: Union[str, Sequence[str]] = "fsdp", - tp_axis_names: Union[str, Sequence[str]] = "model", - seq_axis_names: Union[str, Sequence[str]] = "seq", -): - """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes. - - Args: - cfg: (A sequence of) Transformer layer config to apply sharding spec to. - batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors. - fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors. - tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors. - seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors. - """ - - # pytype: disable=attribute-error - def set_attn_partition_specs(attn_layer: MultiheadAttention.Config): - # Shard weights. - input_linear_cfg = attn_layer.input_linear - if hasattr(input_linear_cfg, "input_linear"): - input_linear_cfg = input_linear_cfg.input_linear - input_linear_cfg.layer.param_partition_spec = (fsdp_axis_names, tp_axis_names, None) - attn_layer.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None) - - def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): - # Shard weights. - ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names) - ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) - # Encourage the right activation sharding. - ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) - ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) - - if not isinstance(cfg, Sequence): - cfg = [cfg] - - for layer_cfg in cfg: - set_attn_partition_specs(layer_cfg.self_attention.attention) - if layer_cfg.cross_attention is not None: - set_attn_partition_specs(layer_cfg.cross_attention.attention) - if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config): - set_ffn_partition_specs(layer_cfg.feed_forward) - # pytype: enable=attribute-error - - -class BaseStackedTransformerLayer(BaseTransformerLayer): - """The common interface of all stacked transformer layer classes. - - Note that BaseStackedTransformerLayer is a subclass of BaseTransformerLayer and therefore - can be used where a BaseTransformerLayer is expected. - - The Output returned by BaseStackedTransformerLayer has the following fields: - * .data is of the same shape as query, from the output of the final layer; - * .self_attention_kv_state is of shape [batch, target_length, num_heads, head_dim], - from the self-attention KV state of the final layer; - * .probs is of shape [num_layers, batch, num_heads, target_length, source_length], - from all layers of the stack; - """ - - @config_class - class Config(BaseTransformerLayer.Config): - """Configures BaseStackedTransformerLayer.""" - # The number of layers in the stack. - num_layers: Required[int] = REQUIRED - # Config for each layer in the stack. - # The layer must be a subclass of BaseTransformerLayer. 
- layer: BaseTransformerLayer.Config = TransformerLayer.default_config() - peak_stochastic_depth_rate: Optional[float] = None +class TransformerTest(BaseTransformerTest): + """Tests TransformerLayer.""" + def _compare_against_roberta_attention( + self, ref: hf_roberta.RobertaAttention, layer: TransformerAttentionLayer + ): + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_param_shapes = jax.tree.map(lambda x: x.shape, layer_params) + print(f"layer state={layer_param_shapes}") + layer_params = parameters_from_torch_layer(ref) + batch_size, tgt_len = 2, 6 + model_dim, num_heads = layer.config.target_dim, layer.config.attention.num_heads + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) + null_mask = jnp.zeros([tgt_len, tgt_len]) + rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) + for mask in (None, null_mask, rand_mask): + if mask is not None: + mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) + layer_outputs, _ = F( + layer, + inputs=dict(target=jnp.asarray(target), attention_logit_biases=mask), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + attn_mask = None if mask is None else as_torch_tensor(mask) + (ref_outputs,) = ref.forward( + torch.as_tensor(target, dtype=torch.float32), + attention_mask=attn_mask, + output_attentions=False, + ) + assert_allclose(layer_outputs.data, as_tensor(ref_outputs)) + + def test_against_roberta_attention(self): + model_dim = 16 + num_heads = 4 + cfg = attention.TransformerAttentionLayer.default_config().set( + name="test", + target_dim=model_dim, + source_dim=model_dim, + structure="postnorm", + ) + cfg.attention.set(num_heads=num_heads) + layer = cfg.instantiate(parent=None) + roberta_config = hf_roberta.RobertaConfig( + hidden_size=model_dim, + num_attention_heads=num_heads, + attention_probs_dropout_prob=0, + hidden_dropout_prob=0, + classifier_dropout=0, + ) + print(f"roberta_config={roberta_config}") + ref = hf_roberta.RobertaAttention(roberta_config) + self._compare_against_roberta_attention(ref, layer) + + def _compare_against_roberta_layer(self, ref: hf_roberta.RobertaLayer, layer: TransformerLayer): + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = parameters_from_torch_layer(ref) + batch_size, tgt_len = 2, 6 + model_dim, num_heads = ( + layer.config.input_dim, + layer.config.self_attention.attention.num_heads, + ) + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) + null_mask = jnp.zeros([tgt_len, tgt_len]) + rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) + for mask in (None, null_mask, rand_mask): + if mask is not None: + mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) + layer_outputs, output_collection = F( + layer, + inputs=dict(data=jnp.asarray(target), self_attention_logit_biases=mask), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + drop_output_collections=(), + ) + if layer_outputs.self_attention_probs is not None: + self.assertEqual( + (batch_size, num_heads, tgt_len, tgt_len), + layer_outputs.self_attention_probs.shape, + ) + attn_mask = None if mask is None else as_torch_tensor(mask) + (ref_outputs,) = ref.forward( + torch.as_tensor(target, dtype=torch.float32), + attention_mask=attn_mask, + output_attentions=False, + ) + assert_allclose(layer_outputs.data, 
as_tensor(ref_outputs)) + self.assertNestedEqual(layer_outputs.data, output_collection.module_outputs["output"]) + + def test_against_roberta_layer(self): + model_dim = 16 + num_heads = 4 + cfg = TransformerLayer.default_config().set(name="test", input_dim=model_dim) + cfg.self_attention.set(structure="postnorm") + cfg.feed_forward.set( + structure="postnorm", activation="nn.silu", hidden_dim=scaled_hidden_dim(4) + ) + cfg.feed_forward.linear1.set(bias=True) + cfg.feed_forward.linear2.set(bias=True) + cfg.self_attention.attention.set(num_heads=num_heads) + cfg.self_attention.attention.input_linear.layer.set(bias=True) + cfg.self_attention.attention.output_linear.set(bias=True) + layer: TransformerLayer = cfg.instantiate(parent=None) + roberta_config = hf_roberta.RobertaConfig( + hidden_size=model_dim, + num_attention_heads=num_heads, + attention_probs_dropout_prob=0, + hidden_dropout_prob=0, + classifier_dropout=0, + # Jax's gelu uses an approximation by default and is slightly different from + # torch.nn.gelu. + hidden_act="silu", + ) + ref = hf_roberta.RobertaLayer(roberta_config) + self._compare_against_roberta_layer(ref, layer) + + def test_decoding(self): + model_dim, num_heads = 6, 2 + cfg = TransformerLayer.default_config().set(input_dim=model_dim) + cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + cfg.feed_forward.hidden_dim = model_dim * 4 + cfg.vlog = 5 + self._test_forward_vs_extend_step(cfg) + + def test_self_attention_kv_state(self): + """Tests TransformerLayer with explicit self_attention_kv_state. + + Creates a base TransformerLayer and a test TransformerLayer with QLinear. Uses the kv_state + of the base layer as the explicit kv_state for the test layer. Checks that the outputs are + identical. + """ + model_dim = 16 + num_heads = 4 + base_cfg = TransformerLayer.default_config().set(name="test", input_dim=model_dim) + base_cfg.feed_forward.set(hidden_dim=scaled_hidden_dim(4)) + base_cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + base_layer: TransformerLayer = base_cfg.instantiate(parent=None) + base_layer_params = base_layer.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(0) + ) -class UpdateDataFn(Protocol): - """A function for updating the constituent layers' input in a StackTransformerLayer.""" + test_cfg = base_cfg.clone() + test_cfg.self_attention.attention.input_linear = QLinear.default_config() + test_layer: TransformerLayer = test_cfg.instantiate(parent=None) + # Let test_layer_params to be identical to base_layer_params except removing {k,v}_proj. + test_layer_params = copy.deepcopy(base_layer_params) + for k in ("k_proj", "v_proj"): + test_layer_params["self_attention"]["attention"]["i_proj"].pop(k) + self.assertEqual( + shapes(test_layer_params), + shapes(test_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))), + ) - def __call__( - self, data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output] - ) -> Tensor: - """Returns a new Tensor with the same shape as `data`, reflecting some desired updates. 
+ batch_size, tgt_len = 2, 5 + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) + base_layer_outputs, _ = F( + base_layer, + inputs=dict(data=jnp.asarray(target), return_aux={"self_attention_kv_state"}), + state=base_layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + test_layer_outputs, _ = F( + test_layer, + # Explicitly pass `self_attention_kv_state` from `base_layer_outputs` as inputs to + # test_layer. + inputs=dict( + data=jnp.asarray(target), + self_attention_kv_state=base_layer_outputs.self_attention_kv_state, + return_aux={"self_attention_kv_state"}, + ), + state=test_layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + assert_allclose(base_layer_outputs.data, test_layer_outputs.data) + + # Tests prefill_state and extend_step. + self._test_forward_vs_extend_step( + test_cfg, + input_kwargs=dict( + # Explicitly pass `self_attention_kv_state`. + self_attention_kv_state=base_layer_outputs.self_attention_kv_state, + ), + ) - Args: - data: A Tensor denoting the input data to the upcoming layer. - all_layer_outputs: A list of BaseTransformerLayer.Output that is appended with - the output of each constituent layer in the stack. - Returns: - A new Tensor with the same shape as `data`. - """ +class ParallelTransformerTest(TestCase): + """Tests ParallelTransformerLayer.""" + + def test_with_golden_value(self): + """A test of ParallelTransformerLayer by comparing results to a golden value.""" + model_dim = 16 + num_heads = 4 + cfg = ParallelTransformerLayer.default_config().set(name="test", input_dim=model_dim) + cfg.feed_forward.set(hidden_dim=scaled_hidden_dim(4)) + cfg.self_attention.set(num_heads=num_heads) + cfg.norm = RMSNorm.default_config() + set_bias_recursively(cfg, bias=False) + layer: TransformerLayer = cfg.instantiate(parent=None) + + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + self.assertEqual( + { + "feed_forward": { + "dropout1": {}, + "dropout2": {}, + "linear1": {"weight": (16, 64)}, + "linear2": {"weight": (64, 16)}, + "stochastic_depth": {}, + }, + "norm": {"scale": (16,)}, + "self_attention": { + "dropout": {}, + "i_proj": { + "k_proj": {"weight": (16, 4, 4)}, + "q_proj": {"weight": (16, 4, 4)}, + "v_proj": {"weight": (16, 4, 4)}, + }, + "o_proj": {"weight": (16, 4, 4)}, + "scale_key": {}, + "scale_query": {}, + }, + }, + utils.shapes(layer_params), + ) + batch_size, tgt_len = 2, 6 + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) + mask = attention_bias.make_causal_biases(tgt_len) + mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) + layer_outputs, _ = F( + layer, + inputs=dict(data=jnp.asarray(target), self_attention_logit_biases=mask), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + self.assertEqual(target.shape, layer_outputs.data.shape) + self.assertNestedAllClose(0.609666, np.mean(layer_outputs.data)) + + def test_build_remat_spec(self): + model_dim, num_heads = 6, 2 + cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) + cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + cfg.feed_forward.hidden_dim = model_dim * 4 + cfg.vlog = 5 + + layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + + batch_size, tgt_len = 2, 5 + rng = 
np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) + + def f(x, layer_params): + forward_outputs, _ = F( + layer, + inputs=dict( + data=x, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + return forward_outputs -def update_data_with_skip_connection(skip_connections: dict[int, int]) -> UpdateDataFn: - """Creates a function that adds skip connection to the input data tensor. + # Ignore type errors. + spec: Any = build_remat_spec(mock.MagicMock()) - Args: - skip_connections: A dictionary where keys and values represent 0-indexed layer indices. - For a (k, v) pair, the output of the v-th layer will be added to the input - of the k-th layer. + _, default_policy_backward = jax.linearize( + jax.remat(f, policy=spec.policy.instantiate(), prevent_cse=spec.prevent_cse), + jnp.asarray(target), + layer_params, + ) + _, full_remat_backward = jax.linearize( + jax.remat(f), + jnp.asarray(target), + layer_params, + ) + # Eliminated the remat of qkv_proj, context and o_proj = 5 dots. This assumes + # FlashAttention is not enabled. + self.assertEqual( + str(full_remat_backward).count(" dot_general") + - str(default_policy_backward).count(" dot_general"), + 5, + ) - Returns: - A function that implements skip connections, following the UpdateDataFn protocol, . - """ + def test_build_remat_spec_neuron(self): + model_dim, num_heads = 6, 2 + cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) + cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + cfg.feed_forward.hidden_dim = model_dim * 4 + cfg.vlog = 5 + + layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + + batch_size, tgt_len = 2, 5 + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) + + def f(x, layer_params): + forward_outputs, _ = F( + layer, + inputs=dict( + data=x, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + return forward_outputs + + # Ignore type errors. + spec: Any = build_remat_spec(mock.MagicMock()) + + policy = ( + config_for_function(save_and_offload_only_these_names_regex) + .set( + names_which_can_be_saved="|".join( + [ + RematRegexSavePatterns.QKV_PROJ.value, + RematRegexSavePatterns.LINEAR1_X.value, + ] + ), + names_which_can_be_offloaded=None, + offload_src=None, + offload_dst=None, + ) + .instantiate() + ) - def update_data(data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output]) -> Tensor: - layer_index = len(all_layer_outputs) - if layer_index in skip_connections: - data += all_layer_outputs[skip_connections[layer_index]].data - return data + _, default_policy_backward = jax.linearize( + jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse), + jnp.asarray(target), + layer_params, + ) + _, full_remat_backward = jax.linearize( + jax.remat(f), + jnp.asarray(target), + layer_params, + ) - return update_data + # Eliminated the remat of qkv_proj and linear1_0 = 4 dots. 
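+        # (That is: q_proj, k_proj and v_proj are one dot_general each and are matched by
+        # RematRegexSavePatterns.QKV_PROJ, and linear1_0 is one more matched by LINEAR1_X,
+        # so the full-remat backward pass recomputes 4 dot_generals that the policy saves.)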
+ self.assertEqual( + str(full_remat_backward).count(" dot_general") + - str(default_policy_backward).count(" dot_general"), + 4, + ) -class StackedTransformerLayer(BaseStackedTransformerLayer): - """A simple implementation of BaseStackedTransformerLayer.""" +class TestStackModel(BaseLayer): + """A dummy transformer stack.""" @config_class - class Config(BaseStackedTransformerLayer.Config): - """Configures StackedTransformerLayer.""" - - # If `layer` is a Config, it will be stacked cfg.num_layers times. If `layer` is a - # sequence of Configs, the sequence length should match cfg.num_layers. - layer: Union[ - BaseTransformerLayer.Config, Sequence[BaseTransformerLayer.Config] - ] = TransformerLayer.default_config() - # If set, implements the UpdateDataFn protocol to update individual layers' input - # data in some specified way. This operation is applied before calling every layer. - data_merger: Optional[InstantiableConfig[UpdateDataFn]] = None - - def __init__(self, cfg: Config, *, parent: Optional[Module]): + class Config(BaseLayer.Config): + stack: Optional[BaseStackedTransformerLayer.Config] = None # The transformer stack. + output_self_attention_kv_state: bool = False + + def __init__(self, cfg: Config, *, parent: Module): super().__init__(cfg, parent=parent) cfg = self.config - self._update_data = maybe_instantiate(cfg.data_merger) - - if isinstance(cfg.layer, Sequence): - layer_cfgs = cfg.layer - if len(layer_cfgs) != cfg.num_layers: - raise ValueError( - f"Number of layer configs ({len(layer_cfgs)}) must match " - f"cfg.num_layers ({cfg.num_layers})." - ) - else: - layer_cfgs = [cfg.layer] * cfg.num_layers - self._layers = [] - for i, layer_cfg in enumerate(layer_cfgs): - if layer_cfg.input_dim is not REQUIRED: - raise ValueError( - f"Do not set Config.layer.input_dim. Set Config.input_dim instead: {layer_cfg}" - ) - layer_cfg = layer_cfg.clone(input_dim=cfg.input_dim) - if cfg.peak_stochastic_depth_rate: - layer_rate = get_stochastic_depth_linear_rate( - cfg.peak_stochastic_depth_rate, - stage_order=i + 1, - num_stages=cfg.num_layers, - ) - layer_cfg.self_attention.stochastic_depth.rate = layer_rate - layer_cfg.feed_forward.stochastic_depth.rate = layer_rate - self._layers.append(self._add_child(f"layer{i}", layer_cfg)) - - def initialize_parameters_recursively( - self, prng_key: Tensor, *, prebuilt: Optional[Nested[Optional[ParameterSpec]]] = None - ) -> NestedTensor: - cfg = self.config # type: StackedTransformerLayer.Config - prng_key = split_prng_key(prng_key, cfg.num_layers) - state = {} - for i in range(cfg.num_layers): - layer = self._layers[i] - key = jax.tree.map(lambda x, index=i: x[index], prng_key.keys) - state[layer.name] = layer.initialize_parameters_recursively( - key, prebuilt=get_or_none(prebuilt, layer.name) - ) - return state + self._add_child("stack", cfg.stack) - def _forward_for_mode( - self, - *, - mode: ForwardMode, - data: Union[Tensor, TensorSpec], - cached_states: Optional[Nested[Tensor]] = None, - **layer_kwargs, - ) -> tuple[list[Optional[Nested[Tensor]]], Optional[TransformerLayer.Output]]: - """Computes transformer stack outputs. - - Args: - mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for - details. - data: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. - cached_states: Optional Nested Tensor as produced by `init_states`. 
- - Returns: - A tuple (updated_cache_states, outputs): - * updated_cached_states: An optional NestedTensor of cache states, depending on `mode`; - * outputs: An optional instance of Output (see comments on BaseStackedTransformerLayer). - - Raises: - ValueError: If `mode` is unsupported. - """ - all_layer_outputs = [] - all_layer_states = [] - - # True iff we are initializing an empty cache (i.e., not prefilling). - cache_init = mode == ForwardMode.INIT_STATES and cached_states is None - - for i, layer in enumerate(self._layers): - # Prepare inputs to the current layer. - if self._update_data is not None: - data = self._update_data(data, all_layer_outputs) - # TODO(markblee): Consider folding into _update_data. - self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs) - - if mode == ForwardMode.FORWARD: - layer_states, layer_outputs = None, layer(data, **layer_kwargs) - elif mode == ForwardMode.INIT_STATES: - # cached_states is allowed to be None in the case where we initialize from scratch. - layer_states, layer_outputs = layer.init_states( - time_step=cached_states, - data=data, - **layer_kwargs, - ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - layer_states, layer_outputs = layer.extend_step( - cached_states=cached_states[i], - data=data, - **layer_kwargs, - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") + def forward(self, data, **layer_kwargs): + cfg = self.config - all_layer_states.append(layer_states) + # [batch, length, dim]. + output = self.stack(data, **layer_kwargs) + x = output.data + x_mean = jnp.mean(x, axis=1, keepdims=True) + # [batch, length]. + x_var = jnp.sum((x - x_mean) ** 2, axis=-1) + loss = jnp.mean(x_var) + if cfg.output_self_attention_kv_state: + return loss, {"mean": x_mean, "self_attention_kv_state": output.self_attention_kv_state} + return loss, {"mean": x_mean} + + +def _recursive_stack(inputs: Nested[Tensor], axis=0): + def stack(*xs): + return jnp.stack(xs, axis=axis) + + return {"layer": utils.vectorized_tree_map(stack, *inputs.values())} + + +def _convert_from_stacked_params( + layer_params: Nested[Tensor], *, target_stack_cfg: BaseStackedTransformerLayer.Config +) -> Nested[Tensor]: + """Converts params of a StackedTransformerLayer to params for `target_stack_cfg`.""" + # First stack to params of a RepeatedTransformerLayer. + layer_params = {"stack": {"repeat": VDict(_recursive_stack(layer_params["stack"]))}} + if target_stack_cfg.klass == RepeatedTransformerLayer: + return layer_params + elif target_stack_cfg.klass == PipelinedTransformerLayer: + pipeline_stage_cfg = target_stack_cfg.stage + num_layers_per_stage = target_stack_cfg.num_layers // target_stack_cfg.num_stages + + def reshape(x): + """Reshapes x from [num_layers, ...] 
to [num_stages, num_layers_per_stage, ...].""" + x_shape = list(x.shape) + return jnp.reshape(x, [target_stack_cfg.num_stages, num_layers_per_stage] + x_shape[1:]) + + pipeline_params = jax.tree.map(reshape, layer_params["stack"].pop("repeat")) + + if pipeline_stage_cfg.klass == RepeatedTransformerLayer: + layer_params["stack"]["pipeline"] = VDict({"layer": {"repeat": pipeline_params}}) + elif pipeline_stage_cfg.klass == StackedTransformerLayer: + layer_params["stack"]["pipeline"] = VDict( + { + "layer": { + f"layer{i}": jax.tree.map(lambda x, i=i: x[:, i], pipeline_params["layer"]) + for i in range(num_layers_per_stage) + } + } + ) + else: + raise NotImplementedError(target_stack_cfg) + return layer_params + else: + raise NotImplementedError(target_stack_cfg) - # If initializing the cache from scratch, layer_outputs will be None. Further, `data` - # can be effectively treated as a TensorSpec, and thus does not need to be carried - # across layers. - if layer_outputs is None: - assert cache_init - continue - all_layer_outputs.append(layer_outputs) - data = layer_outputs.data +class NonUniformStack(StackedTransformerLayer): + def _aggregate_layer_outputs( + self, layer_outputs: Sequence[BaseTransformerLayer.Output] + ) -> BaseTransformerLayer.Output: + return BaseTransformerLayer.Output( + # Use data and self_attention_kv_state from the final layer outputs. + data=layer_outputs[-1].data, + self_attention_kv_state=layer_outputs[-1].self_attention_kv_state, + # Do not aggregate *_attention_probs. + self_attention_probs=None, + cross_attention_probs=None, + ) - outputs = None if cache_init else self._aggregate_layer_outputs(all_layer_outputs) - return all_layer_states, outputs - def init_states( - self, - *, - time_step: Optional[Tensor], - data: Union[Tensor, TensorSpec], - **layer_kwargs, - ) -> tuple[list[Nested[Tensor]], Optional[TransformerLayer.Output]]: - """See `BaseTransformerLayer.init_states` for details.""" - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - cached_states=time_step, - data=data, - **layer_kwargs, - ) +class TestStackedTransformerLayerWithKVState(NonUniformStack): + """A class with a simple override of _update_layer_kwargs for unit testing.""" def _update_layer_kwargs( self, @@ -3640,609 +4169,1244 @@ def _update_layer_kwargs( *, all_layer_outputs: list[BaseTransformerLayer.Output], ): - """Updates `layer_kwargs` using other args. + layer_index = len(all_layer_outputs) + if layer_index == 1: + layer_kwargs["self_attention_kv_state"] = all_layer_outputs[-1].self_attention_kv_state + elif layer_index == 2: + layer_kwargs["self_attention_kv_state"] = None - This method is called before we invoke each layer in `self._layers`. - The updated `layer_kwargs` will be passed to the layer invocation. - Args: - layer_kwargs: a dictionary of arguments that can be used by individual layers. - all_layer_outputs: a list of BaseTransformerLayer.Output that is appended with - the output of each constituent layer in the stack. - """ - pass # Do nothing by default. 
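A minimal illustrative sketch, not part of the patch, of the layout conversion performed by `_recursive_stack` and `_convert_from_stacked_params` above: per-layer parameter trees are stacked along a new leading [num_layers] axis (the layout a scan-based RepeatedTransformerLayer expects), and for a pipeline that axis is folded into [num_stages, num_layers_per_stage]. The sketch uses plain `jax.tree.map` on toy dictionaries instead of the `VDict`/`vectorized_tree_map` machinery used by the tests.

import jax
import jax.numpy as jnp

# Toy per-layer params, as produced by a two-layer StackedTransformerLayer.
per_layer = {
    "layer0": {"linear": {"weight": jnp.ones((4, 8))}},
    "layer1": {"linear": {"weight": 2 * jnp.ones((4, 8))}},
}
# Stack corresponding leaves along a new leading [num_layers] axis.
repeated = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *per_layer.values())
assert repeated["linear"]["weight"].shape == (2, 4, 8)
# For num_stages=2, fold the layer axis into [num_stages, num_layers_per_stage].
pipelined = jax.tree.map(lambda x: x.reshape((2, 1) + x.shape[1:]), repeated)
assert pipelined["linear"]["weight"].shape == (2, 1, 4, 8)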
+class TestStackedTransformerLayerWithSkipConnection(StackedTransformerLayer): + """A class that outputs all layers' output for unit testing.""" def _aggregate_layer_outputs( self, layer_outputs: Sequence[BaseTransformerLayer.Output], - ) -> BaseTransformerLayer.Output: - """Aggregates outputs from the stack.""" - data = layer_outputs[-1].data - self_attention_kv_state = layer_outputs[-1].self_attention_kv_state - aux_outputs = [ - output._replace(data=None, self_attention_kv_state=None) for output in layer_outputs - ] - # Stack auxiliary outputs along axis 0. - outputs = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *aux_outputs) - return outputs._replace(data=data, self_attention_kv_state=self_attention_kv_state) - - def forward( - self, - data: Tensor, - **layer_kwargs, - ) -> TransformerLayer.Output: - _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, - data=data, - cached_states=None, - **layer_kwargs, - ) - return output - - def extend_step( - self, - cached_states: list[NestedTensor], - data: Tensor, - **layer_kwargs, - ) -> tuple[list[Nested[Tensor]], TransformerLayer.Output]: - return self._forward_for_mode( # pytype: disable=bad-return-type - mode=ForwardMode.EXTEND_STEP, - cached_states=cached_states, - data=data, - **layer_kwargs, - ) + ) -> Sequence[BaseTransformerLayer.Output]: + return layer_outputs -class _TransformerRepeat(Repeat): - """A Repeat layer with layer=TransformerLayer.""" +class StackedTransformerTest(BaseTransformerTest): + """Tests StackedTransformerLayer.""" - @config_class - class Config(Repeat.Config): - """Configures _TransformerRepeat.""" - - # The additional fields of BaseTransformerLayer.Output that should propagate as input to - # the next layer. - # - # For example, carry=("data", "self_attention_kv_state") means that both `data` and - # `self_attention_kv_state` will propagate between layers. - # - # If None, only "data" is propagated. - carry: Optional[Sequence[str]] = None - - def _forward_for_mode( + def _stack_config( self, + stack_cfg, *, - mode: ForwardMode, - data: Union[Tensor, TensorSpec], - cached_states: Optional[Nested[Tensor]] = None, - **layer_kwargs, - ) -> tuple[Optional[Nested[Tensor]], Optional[TransformerLayer.Output]]: - """Computes transformer stack outputs. - - Args: - mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for - details. - data: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional Nested Tensor as produced by `init_states`. - layer_kwargs: Additional kwargs to each layer. - - Returns: - A tuple (updated_cache_states, outputs): - * updated_cached_states: An optional NestedTensor of cache states, depending on `mode`; - * outputs: An optional instance of Output (see comments on BaseStackedTransformerLayer). - - Raises: - ValueError: If `mode` is unsupported. - """ - cfg: _TransformerRepeat.Config = self.config - - # True iff we are initializing an empty cache (i.e., not prefilling). - cache_init = mode == ForwardMode.INIT_STATES and cached_states is None - - if cached_states is not None: - for path, value in flatten_items(cached_states): - assert value.shape[0] == cfg.num_layers, f"{path}={shapes(value)}" - - def layer_fn(carry, x_i): - if mode == ForwardMode.FORWARD: - layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs) - elif mode == ForwardMode.INIT_STATES: - # Note that x_i can be None if initializing an empty cache. This corresponds to the - # case where `cached_states=None`. 
- layer_states, layer_outputs = self.layer.init_states( - time_step=x_i, **carry, **layer_kwargs - ) - elif mode == ForwardMode.EXTEND_STEP: - assert x_i is not None - layer_states, layer_outputs = self.layer.extend_step( - cached_states=x_i, **carry, **layer_kwargs - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") - - ys = {} - if layer_states is not None: - ys["cached_states"] = layer_states - - # If initializing the cache from scratch, layer_outputs will be None. - if layer_outputs is None: - assert cache_init - return carry, ys + num_layers, + model_dim, + num_heads, + dtype, + remat_spec, + output_self_attention_kv_state=False, + ) -> TestStackModel.Config: + if isinstance(stack_cfg, type): + stack_cfg = stack_cfg.default_config() + if callable(remat_spec): + remat_spec = remat_spec(stack_cfg) + cfg = TestStackModel.default_config().set( + name="test", + stack=stack_cfg.set( + input_dim=model_dim, + num_layers=num_layers, + vlog=5, + dtype=dtype, + layer=TransformerLayer.default_config().set(remat_spec=remat_spec), + ), + output_self_attention_kv_state=output_self_attention_kv_state, + ) + layer_cfg = cfg.stack.layer + layer_cfg.self_attention.attention.set(num_heads=num_heads) + layer_cfg.feed_forward.hidden_dim = model_dim * 4 + layer_cfg.vlog = 5 + return cfg - ys.update({k: v for k, v in layer_outputs._asdict().items() if k not in carry}) - return {k: getattr(layer_outputs, k) for k in carry}, ys + @parameterized.product( + transformer_type=[StackedTransformerLayer, RepeatedTransformerLayer], + # Also tests stack-of-stacks and repeat-of-stacks. + layer_type=[TransformerLayer, StackedTransformerLayer], + ) + def test_transformer_extend_step(self, transformer_type, layer_type): + batch_size, src_len, tgt_len = 10, 4, 6 + num_dec_layers, model_dim, num_heads = 3, 16, 4 + + cfg: BaseStackedTransformerLayer.Config = transformer_type.default_config().set( + name="test", + input_dim=model_dim, + num_layers=num_dec_layers, + ) + cross_atten_cfg = TransformerAttentionLayer.default_config().set( + source_dim=model_dim * 2, + structure="postnorm", + ) + cross_atten_cfg.attention.set(num_heads=num_heads) - if cfg.carry is None: - carry = {"data": data} + # Prepare layer config. + if layer_type == StackedTransformerLayer: + cfg.layer = layer_type.default_config().set(num_layers=2) + layer_cfg = cfg.layer.layer else: - layer_kwargs["data"] = data - carry = {k: layer_kwargs.pop(k) for k in cfg.carry} - - repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states) - carry = repeat_outputs.carry - ys = repeat_outputs.ys - updated_states = ys.pop("cached_states", None) - - if cache_init: - assert ys == {} - return updated_states, None - - for k in ("data", "self_attention_kv_state"): - if k in carry: - continue - v = ys.pop(k, None) - if v is not None: - # Take the output from the last layer. 
- if isinstance(v, KVState): - v = KVState(k_proj=v.k_proj[-1], v_proj=v.v_proj[-1]) - else: - v = v[-1] - carry[k] = v - return updated_states, TransformerLayer.Output(**carry, **ys) - - def forward( - self, - data: Tensor, - **layer_kwargs, - ) -> TransformerLayer.Output: - _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, - data=data, - cached_states=None, - **layer_kwargs, - ) - return output - - def init_states( - self, - *, - time_step: Optional[Tensor], - data: Union[Tensor, TensorSpec], - **layer_kwargs, - ) -> tuple[Nested[Tensor], Optional[TransformerLayer.Output]]: - cfg: _TransformerRepeat.Config = self.config - # time_step is allowed to be None if initializing an empty cache. - if time_step is not None: - time_step = jnp.tile(time_step, [cfg.num_layers, 1]) - - # In the repeat case, scan requires a Tensor rather than ShapeDtypeStruct. - # Use vmap rather than materializing the Tensor. - if isinstance(data, TensorSpec): + layer_cfg = cfg.layer + layer_cfg.self_attention.attention.set(num_heads=num_heads) + layer_cfg.cross_attention = cross_atten_cfg + layer_cfg.feed_forward.hidden_dim = model_dim * 4 - def layer_fn(_): - return self.layer.init_states(time_step=time_step, data=data, **layer_kwargs) + # Instantiate transformer stack. + layer: BaseStackedTransformerLayer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - return jax.vmap(layer_fn)(jnp.empty(cfg.num_layers)) + target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) + source = jax.random.normal(jax.random.PRNGKey(456), [batch_size, src_len, model_dim * 2]) - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - data=data, - cached_states=time_step, - **layer_kwargs, + self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) + cross_attention_logit_biases = ( + jnp.array(np.random.randint(0, 2, [tgt_len, src_len])) * NEG_INF ) - - def extend_step( - self, - cached_states: NestedTensor, - data: Tensor, - **layer_kwargs, - ) -> tuple[NestedTensor, TransformerLayer.Output]: - return self._forward_for_mode( # pytype: disable=bad-return-type - mode=ForwardMode.EXTEND_STEP, - data=data, - cached_states=cached_states, - **layer_kwargs, + return_aux = {"self_attention_probs", "cross_attention_probs"} + + forward_outputs, _ = F( + layer, + inputs=dict( + data=target, + self_attention_logit_biases=self_attention_logit_biases, + cross_attention_data=source, + cross_attention_logit_biases=cross_attention_logit_biases, + return_aux=return_aux, + ), + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(0), + ) + initial_state, initial_output = layer.init_states( + time_step=None, + data=TensorSpec([batch_size, tgt_len]), + ) + self.assertIsNone(initial_output) + inputs = dict( + cached_states=initial_state, cross_attention_data=source, return_aux=return_aux + ) + decoder_output = jnp.zeros(shape=[tgt_len, batch_size, model_dim]) + + # [num_dec_layers, [num_stacked_layers,] batch_size, num_heads, tgt_len, tgt_len] --> + # [tgt_len, num_dec_layers, [num_stacked_layers,] batch_size, num_heads, tgt_len]. + # The layer being stacked can itself be a stack, in which case we have an extra dim. + decoder_self_attention_probs = jnp.moveaxis( + jnp.zeros_like(forward_outputs.self_attention_probs), + -2, + 0, + ) + # [tgt_len, num_dec_layers, [num_stacked_layers,] batch_size, num_heads, src_len]. 
+ decoder_cross_attention_probs = jnp.moveaxis( + jnp.zeros_like(forward_outputs.cross_attention_probs), + -2, + 0, + ) + for t in range(tgt_len): + inputs["data"] = jnp.expand_dims(target[:, t, :], axis=1) + inputs["self_attention_logit_biases"] = self_attention_logit_biases[ + jnp.newaxis, jnp.newaxis, t, : + ] + inputs["cross_attention_logit_biases"] = cross_attention_logit_biases[ + jnp.newaxis, jnp.newaxis, t, : + ] + (updated_states, layer_outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + method="extend_step", + ) + # Check that updated_states are VDicts for the Repeated layer. + if transformer_type is RepeatedTransformerLayer: + jax.tree.map( + lambda v: self.assertIsInstance(v, utils.VDict), + updated_states, + is_leaf=lambda v: isinstance(v, dict), + ) + inputs["cached_states"] = updated_states + decoder_output = decoder_output.at[t].set(jnp.squeeze(layer_outputs.data, axis=1)) + decoder_self_attention_probs = decoder_self_attention_probs.at[t].set( + jnp.squeeze(layer_outputs.self_attention_probs, axis=-2) + ) + decoder_cross_attention_probs = decoder_cross_attention_probs.at[t].set( + jnp.squeeze(layer_outputs.cross_attention_probs, axis=-2) + ) + decoder_out_transposed = jnp.transpose(decoder_output, [1, 0, 2]) + decoder_self_attention_probs_transposed = jnp.moveaxis(decoder_self_attention_probs, 0, -2) + decoder_cross_attention_probs_transposed = jnp.moveaxis( + decoder_cross_attention_probs, 0, -2 ) + assert_allclose(decoder_out_transposed, forward_outputs.data, atol=1e-6) + assert_allclose( + decoder_self_attention_probs_transposed, forward_outputs.self_attention_probs, atol=1e-6 + ) + assert_allclose( + decoder_cross_attention_probs_transposed, + forward_outputs.cross_attention_probs, + atol=1e-6, + ) -class RepeatedTransformerLayer(BaseStackedTransformerLayer): - """An implementation of BaseStackedTransformerLayer with a scan loop. + @parameterized.product( + transformer_type=[StackedTransformerLayer, RepeatedTransformerLayer], + # Also tests stack-of-stacks and repeat-of-stacks. + layer_type=[TransformerLayer, StackedTransformerLayer], + ) + # pylint: disable-next=too-many-statements + def test_transformer_prefill_states(self, transformer_type, layer_type): + batch_size, src_len, tgt_len = 10, 4, 6 + num_dec_layers, model_dim, num_heads = 3, 16, 4 + + cfg = transformer_type.default_config().set( + name="test", + input_dim=model_dim, + num_layers=num_dec_layers, + ) + cross_atten_cfg = TransformerAttentionLayer.default_config().set( + source_dim=model_dim * 2, + structure="postnorm", + ) + cross_atten_cfg.attention.set(num_heads=num_heads) - Compared with StackedTransformerLayer, the size of the XLA program for RepeatedTransformerLayer - does not grow proportional to the number of layers. In practice, this significantly reduces - XLA compilation overhead of large models with many layers. - """ + # Prepare layer config. + if layer_type == StackedTransformerLayer: + cfg.layer = layer_type.default_config().set(num_layers=2) + layer_cfg = cfg.layer.layer + else: + layer_cfg = cfg.layer + layer_cfg.self_attention.attention.set(num_heads=num_heads) + layer_cfg.cross_attention = cross_atten_cfg + layer_cfg.feed_forward.hidden_dim = model_dim * 4 - @config_class - class Config(BaseStackedTransformerLayer.Config): - """Configures RepeatedTransformerLayer.""" + # Instantiate transformer stack. 
+ layer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - repeat: Repeat.Config = _TransformerRepeat.default_config() + target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) + source = jax.random.normal(jax.random.PRNGKey(456), [batch_size, src_len, model_dim * 2]) - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) - cfg = self.config # type: RepeatedTransformerLayer.Config - repeat_cfg = cfg.repeat.set( - layer=cfg.layer.set(input_dim=cfg.input_dim), - num_layers=cfg.num_layers, - ) - self._add_child("repeat", repeat_cfg) - - def initialize_parameters_recursively( - self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None - ) -> NestedTensor: - # We need to call self.repeat.initialize_parameters_recursively() with the same prng_key - # to ensure initialization parity with StackedTransformerLayer. - return dict( - repeat=self.repeat.initialize_parameters_recursively( - prng_key, prebuilt=get_or_none(prebuilt, "repeat") - ) + self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) + cross_attention_logit_biases = ( + jnp.array(np.random.randint(0, 2, [tgt_len, src_len])) * NEG_INF + ) + return_aux = {"self_attention_probs", "cross_attention_probs"} + + forward_outputs, _ = F( + layer, + inputs=dict( + data=target, + self_attention_logit_biases=self_attention_logit_biases, + cross_attention_data=source, + cross_attention_logit_biases=cross_attention_logit_biases, + return_aux=return_aux, + ), + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(0), + ) + # Initialize state. + time_step = jnp.arange(batch_size) + (initial_states, initial_output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=dict( + time_step=time_step, + data=target, + self_attention_logit_biases=self_attention_logit_biases, + cross_attention_data=source, + cross_attention_logit_biases=cross_attention_logit_biases, + return_aux=return_aux, + ), + method="init_states", ) - def forward( - self, - data: Tensor, - **layer_kwargs, - ) -> TransformerLayer.Output: - return self.repeat(data, **layer_kwargs) + # Zero-out outputs starting from initial time_step, and test that we can recover the full + # outputs by calling extend_step starting from time_step. + time_step_mask = jnp.arange(tgt_len) < time_step[:, None] + # [batch, tgt_len, model_dim]. + decoder_output = initial_output.data * time_step_mask[..., None] + # [num_layers, batch, num_heads, tgt_len, tgt_len]. + decoder_self_attention_probs = ( + initial_output.self_attention_probs * time_step_mask[None, :, None, :, None] + ) + # [num_layers, batch, num_heads, tgt_len, src_len]. + decoder_cross_attention_probs = ( + initial_output.cross_attention_probs * time_step_mask[None, :, None, :, None] + ) - def init_states(self, *args, **kwargs): - cached_states, output = self.repeat.init_states(*args, **kwargs) - return VDict(repeat=cached_states), output + # Transpose for simpler updates during extend_step. + # [batch, tgt_len, model_dim] --> [batch, model_dim, tgt_len]. + decoder_output = jnp.moveaxis(decoder_output, -2, -1) + # [..., tgt_len, src_len] --> [..., src_len, tgt_len]. 
+ decoder_self_attention_probs = jnp.moveaxis(decoder_self_attention_probs, -2, -1) + decoder_cross_attention_probs = jnp.moveaxis(decoder_cross_attention_probs, -2, -1) - def extend_step( - self, - cached_states: NestedTensor, - data: Tensor, - **layer_kwargs, - ) -> tuple[list[NestedTensor], TransformerLayer.Output]: - repeat_cached_states, output = self.repeat.extend_step( - cached_states=cached_states["repeat"], - data=data, - **layer_kwargs, + # Call extend_step from time_step, ensuring that outputs match. + inputs = dict( + cached_states=initial_states, cross_attention_data=source, return_aux=return_aux + ) + while jnp.any(time_step < tgt_len): + # [batch, tgt_len=1, model_dim]. + inputs["data"] = jnp.take_along_axis( + target, time_step[:, None, None], axis=1, mode="clip" + ) + # [batch=1, tgt_len=1, tgt_len]. + inputs["self_attention_logit_biases"] = jnp.take_along_axis( + self_attention_logit_biases[None, :, :], + time_step[:, None, None], + axis=1, + mode="clip", + ) + # [batch=1, tgt_len=1, src_len]. + inputs["cross_attention_logit_biases"] = jnp.take_along_axis( + cross_attention_logit_biases[None, :, :], + time_step[:, None, None], + axis=1, + mode="clip", + ) + (updated_states, layer_outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(456), + inputs=inputs, + method="extend_step", + ) + # Check that updated_states are VDicts for the Repeated layer. + if transformer_type is RepeatedTransformerLayer: + jax.tree.map( + lambda v: self.assertIsInstance(v, utils.VDict), + updated_states, + is_leaf=lambda v: isinstance(v, dict), + ) + inputs["cached_states"] = updated_states + + # [batch, model_dim, tgt_len=1] + curr_outputs = jnp.moveaxis(layer_outputs.data, -2, -1) + # [..., tgt_len, tgt_len=1] + curr_self_attention_probs = jnp.moveaxis(layer_outputs.self_attention_probs, -2, -1) + # [..., src_len, tgt_len=1] + curr_cross_attention_probs = jnp.moveaxis(layer_outputs.cross_attention_probs, -2, -1) + + # [batch, 1, tgt_len]. + oh_indices = jax.nn.one_hot(time_step, tgt_len)[:, None, :] + decoder_output = decoder_output + curr_outputs * oh_indices + # [num_layers=1, batch, num_heads=1, tgt_len=1, tgt_len]. + oh_indices = oh_indices[None, :, None, :, :] + decoder_self_attention_probs = ( + decoder_self_attention_probs + curr_self_attention_probs * oh_indices + ) + decoder_cross_attention_probs = ( + decoder_cross_attention_probs + curr_cross_attention_probs * oh_indices + ) + time_step = time_step + 1 + + # [batch, model_dim, tgt_len] --> [batch, tgt_len, model_dim]. + decoder_output = jnp.moveaxis(decoder_output, -1, -2) + # [..., src_len, tgt_len] --> [..., tgt_len, src_len]. 
+ decoder_self_attention_probs = jnp.moveaxis(decoder_self_attention_probs, -1, -2) + decoder_cross_attention_probs = jnp.moveaxis(decoder_cross_attention_probs, -1, -2) + + assert_allclose(decoder_output, forward_outputs.data) + assert_allclose(decoder_self_attention_probs, forward_outputs.self_attention_probs) + assert_allclose(decoder_cross_attention_probs, forward_outputs.cross_attention_probs) + + def test_skip_connection(self): + batch_size = 2 + seq_len = 6 + num_heads = 2 + input_dim = 4 + hidden_dim = 8 + num_layers = 5 + layer_with_skip_input = 3 + + cfg = TestStackedTransformerLayerWithSkipConnection.default_config().set( + name="test", input_dim=input_dim, num_layers=num_layers ) - return VDict(repeat=repeat_cached_states), output + transformer_cfg = TransformerLayer.default_config() + transformer_cfg.self_attention.attention.num_heads = num_heads + transformer_cfg.feed_forward.hidden_dim = hidden_dim + cfg.layer = transformer_cfg -class _TransformerPipeline(Pipeline): - """Transformer pipeline layer.""" + test_cfg = cfg.clone().set( + data_merger=config_for_function(update_data_with_skip_connection).set( + skip_connections={layer_with_skip_input: 1} + ) + ) - def forward( - self, - data: Tensor, - *, - return_aux: Optional[set[str]] = None, - **kwargs, - ) -> TransformerLayer.Output: - carry_in = dict(data=data) - return_aux = return_aux or set() - - # Even though attention logit biases do not change across layers, we - # include them in the carry so that they are aligned with the microbatches. - carry_in.update(kwargs) - carry_in = self._to_microbatches(carry_in) - self.vlog(3, "carry_in=%s", shapes(carry_in)) - - def layer_fn(carry, _): - layer_outputs: TransformerLayer.Output = self.layer(**carry) - carry.pop("data") - return dict(**carry, data=layer_outputs.data), { - k: v if k in return_aux else None - for k, v in layer_outputs._asdict().items() - if k != "data" - } + base_layer = cfg.instantiate(parent=None) + test_layer = test_cfg.instantiate(parent=None) - pipeline_outputs: Pipeline.Output = self._run(layer_fn, carry_in) - carry_out = self._from_microbatches(pipeline_outputs.carry["data"]) + random_inputs = jax.random.uniform( + jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim) + ) + state = base_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + base_output, _ = F( + base_layer, + is_training=True, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict(data=random_inputs), + ) + test_output, _ = F( + test_layer, + is_training=True, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict(data=random_inputs), + ) - ys = pipeline_outputs.ys - self.vlog(3, "ys=%s", shapes(ys)) - return TransformerLayer.Output(data=carry_out, **ys) + for i in range(layer_with_skip_input): + self.assertNestedAllClose( + base_output[i].data, + test_output[i].data, + ) + for i in range(layer_with_skip_input, num_layers): + self.assertNotAlmostEqual( + jnp.min(jnp.abs(base_output[i].data - test_output[i].data)), + 0.0, + ) + def test_update_layer_kwargs(self): + batch_size = 2 + seq_len = 6 + num_heads = 2 + input_dim = 4 + per_head_dim = input_dim // num_heads + hidden_dim = 8 + num_layers = 3 + + # Create a StackedTransformerLayer by specifying a sequence of non-uniform layer configs. 
+ cfg = TestStackedTransformerLayerWithKVState.default_config().set(name="test") + cfg.input_dim = input_dim + cfg.num_layers = num_layers + cfg.layer = [] + for i in range(num_layers): + transformer_cfg = TransformerLayer.default_config() + transformer_cfg.self_attention.attention.num_heads = num_heads + transformer_cfg.feed_forward.hidden_dim = hidden_dim + + if i == 1: + transformer_cfg.self_attention.attention.input_linear = QLinear.default_config() + + cfg.layer.append(transformer_cfg) + + layer: StackedTransformerLayer = cfg.instantiate(parent=None) + inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) + state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + outputs, _ = F( + layer, + is_training=True, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict(data=inputs, return_aux={"self_attention_kv_state"}), + ) + self.assertEqual( + BaseTransformerLayer.Output( + data=(batch_size, seq_len, input_dim), + self_attention_probs=None, + self_attention_kv_state=KVState( + k_proj=(batch_size, seq_len, num_heads, per_head_dim), + v_proj=(batch_size, seq_len, num_heads, per_head_dim), + ), + cross_attention_probs=None, + ), + shapes(outputs), + ) -class PipelinedTransformerLayer(BaseStackedTransformerLayer): - """An implementation of BaseStackedTransformerLayer with pipeline model parallelism.""" + def test_stack_vs_repeat(self): + self._compare_layers(StackedTransformerLayer, RepeatedTransformerLayer) + + def test_stack_vs_repeat_bfloat16(self): + # FIXME(rpang): fix the following test, which is caused by different behaviors of bfloat16 + # to float32 casting. + # self._compare_layers(StackedTransformerLayer, RepeatedTransformerLayer, + # dtype=jnp.bfloat16) + pass + + def test_stack_vs_repeat_remat_everything_saveable(self): + self._compare_layers( + StackedTransformerLayer, + RepeatedTransformerLayer, + remat_spec=RematSpec(policy=jax_remat_policies.everything_saveable), + ) - @config_class - class Config(BaseStackedTransformerLayer.Config): - """Configures PipelinedTransformerLayer.""" - - # The number of pipeline stages. Must evenly divide `num_layers`. - num_stages: Required[int] = REQUIRED - # The number of pipeline microbatches. Must evenly divide batch size. - num_microbatches: Required[int] = REQUIRED - # Config for each stage in the pipeline. - stage: BaseLayer.Config = StackedTransformerLayer.default_config().set(layer=None) - # Config for the pipeline implementation, such as pipeline schedule. 
- pipeline: _TransformerPipeline.Config = _TransformerPipeline.default_config() - - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) - cfg = self.config # type: PipelinedTransformerLayer.Config - if cfg.num_layers % cfg.num_stages != 0: - raise ValueError(f"num_stages {cfg.num_stages} must divide num_layers {cfg.num_layers}") - num_layers_per_stage = cfg.num_layers // cfg.num_stages - stage_cfg = cfg.stage.set( - input_dim=cfg.input_dim, layer=cfg.layer, num_layers=num_layers_per_stage - ) - pipeline_cfg = cfg.pipeline.set( - layer=stage_cfg, num_layers=cfg.num_stages, num_microbatches=cfg.num_microbatches - ) - self._add_child("pipeline", pipeline_cfg) - - def initialize_parameters_recursively( - self, prng_key: Tensor, *, prebuilt: Optional[Nested[Optional[ParameterSpec]]] = None - ) -> NestedTensor: - cfg = self.config # type: PipelinedTransformerLayer.Config - # We pre-split all num_layers keys to ensure initialization parity with - # StackedTransformerLayer. - prng_key = split_prng_key(prng_key, (cfg.num_stages, cfg.num_layers // cfg.num_stages)) - return dict( - pipeline=self.pipeline.initialize_parameters_recursively( - prng_key, prebuilt=get_or_none(prebuilt, "pipeline") - ) + def test_stack_vs_repeat_with_build_remat_spec(self): + self._compare_layers( + StackedTransformerLayer, + RepeatedTransformerLayer, + remat_spec=build_remat_spec, ) - def forward( + @parameterized.product( + stage_cls=[StackedTransformerLayer, RepeatedTransformerLayer], + schedule_cls=[GPipeSchedule, StreamSchedule], + remat_spec=[None, RematSpec(policy=jax_remat_policies.everything_saveable)], + ) + def test_stack_vs_pipeline( self, - data: Tensor, - **kwargs, - ) -> TransformerLayer.Output: - return self.pipeline(data, **kwargs) - - # TODO(sneha): extend_step - - -OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]] -_SavePattern = Union[str, re.Pattern, None] - - -# Adapted from jax source code to support regex. Reference: -# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 -def _save_and_offload_only_these_names_regex( - *, - names_which_can_be_saved: _SavePattern, - names_which_can_be_offloaded: _SavePattern, - offload_src: str, - offload_dst: str, -) -> OffloadPolicy: - def policy(prim, *_, **params): - if prim is name_p: - if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]): - return pe.Saveable - if names_which_can_be_offloaded and re.fullmatch( - names_which_can_be_offloaded, params["name"] - ): - return pe.Offloadable(src=offload_src, dst=offload_dst) - return pe.Recompute # not saveable unless it's in the allow-list + stage_cls: type[BaseTransformerLayer], + schedule_cls: type[BaseSchedule], + remat_spec: Optional[RematSpec], + ): + pipelined_cfg: PipelinedTransformerLayer.Config = PipelinedTransformerLayer.default_config() + pipelined_cfg.stage = stage_cls.default_config().set(layer=None) + pipelined_cfg.pipeline.schedule = schedule_cls.default_config() + + # If using StreamSchedule, we expect `num_microbatches` to be divisible by `num_stages`. 
+ if schedule_cls is StreamSchedule: + # num_microbatches = 6, num_stages = 3, microbatch_size = 2 + batch_size, num_layers = 12, 6 + else: + # num_microbatches = 5, num_stages = 3, microbatch_size = 2 + batch_size, num_layers = 10, 6 + + pipelined_cfg.num_microbatches = batch_size // 2 + pipelined_cfg.num_stages = num_layers // 2 + self._compare_layers( + StackedTransformerLayer, + pipelined_cfg, + remat_spec=remat_spec, + batch_size=batch_size, + num_layers=num_layers, + ) - return policy + # pylint: disable-next=too-many-statements,too-many-branches + def _compare_layers( + self, + *stack_configs, + dtype=jnp.float32, + remat_spec=None, + batch_size: int = 10, + num_layers: int = 6, + ): + assert stack_configs[0] == StackedTransformerLayer, stack_configs[0] + with utils.numeric_checks(False): + tgt_len, model_dim, num_heads = 5, 8, 4 + target = jax.random.normal( + jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim], dtype=dtype + ) + rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) + rand_mask = jnp.tile(rand_mask[None, None, :, :], (batch_size, num_heads, 1, 1)) + + all_params = [] + all_outputs = [] + all_gradients = [] + all_updates = [] + stacked_layer_params = None + for stack_cfg in stack_configs: + cfg = self._stack_config( + stack_cfg, + num_layers=num_layers, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + remat_spec=remat_spec, + ) + cls = cfg.stack.klass + layer: TestStackModel = cfg.instantiate(parent=None) + + param_specs = layer.create_parameter_specs_recursively() + logging.info( + "%s.factorization_specs=%s", + cls, + jax.tree.map(lambda x: x.factorization, param_specs), + ) + layer_params = layer.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(123) + ) + logging.info( + "%s.params=%s", + cls, + [ + f"{path}={value.dtype}({value.shape})" + for path, value in flatten_items(layer_params) + ], + ) + if cls == StackedTransformerLayer: + stacked_layer_params = copy.deepcopy(layer_params) + else: + layer_params = _convert_from_stacked_params( + stacked_layer_params, target_stack_cfg=cfg.stack + ) + logging.info( + "Converted: %s.params=%s", + cls, + [ + f"{path}={value.dtype}({value.shape})" + for path, value in flatten_items(layer_params) + ], + ) -SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)" -FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*" + def _loss(layer_params, data, mask, layer=layer): + layer_outputs, layer_output_collection = F( + layer, + inputs=dict( + data=data, self_attention_logit_biases=mask, target_segment_ids=None + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + loss, aux = layer_outputs + return loss, (aux, layer_output_collection) + value, grads = jax.value_and_grad(_loss, has_aux=True)( + layer_params, jnp.asarray(target), rand_mask + ) + loss, (aux, layer_output_collection) = value + layer_outputs = (loss, aux) + + # Note that we do not compare summaries across stack layer types because: + # (1) attention layers do not emit summaries yet; + # (2) pipelines emit per-microbatch summaries which have a different structure + # than summaries from other stack layers. 
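
The gradient plumbing a few lines above uses `jax.value_and_grad(..., has_aux=True)` so that `_loss` can return auxiliary outputs next to the scalar loss. In isolation the pattern looks like this (a minimal sketch with a made-up `loss_fn`, unrelated to the layers under test):

    import jax
    import jax.numpy as jnp

    def loss_fn(w, x):
        y = x @ w
        return jnp.mean(y**2), {"y": y}  # (scalar loss, auxiliary outputs).

    # Returns ((loss, aux), grads), where grads are taken w.r.t. the first argument `w`.
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(jnp.ones(3), jnp.eye(3))
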
+ summaries = layer_output_collection.summaries + logging.info( + "layer_outputs=%s summaries=%s", + shapes(flatten_items(layer_outputs)), + shapes(flatten_items(summaries)), + ) + logging.info( + "global_grad_norm=%s, grads=%s", + optax.global_norm(grads), + shapes(flatten_items(grads)), + ) -def build_remat_spec( - stack_cfg: Union[ - BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore - ], - save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN, - offload_pattern: _SavePattern = None, - offload_dst: str = "pinned_host", -) -> Optional[RematSpec]: - """Configures how the Transformer or Conformer stack will save the linearization points. + optimizer = adafactor_optimizer( + learning_rate=0.1, + b1=0.9, + b2=0.98, + multiply_by_parameter_scale=False, + clipping_threshold=1.0, + eps=1e-2, + ) + opt_params = jax.tree.map( + lambda spec, p: OptParam( + value=p, + factorization_spec=spec.factorization, + weight_decay_scale=spec.weight_decay_scale, + ), + param_specs, + layer_params, + ) + opt_state = optimizer.init(opt_params) + logging.info("opt_state=%s", shapes(opt_state)) + updates, opt_state = optimizer.update(grads, opt_state, opt_params) - We try to save activations from the forward pass that are inefficient to recompute on the - backward pass. We choose the linearization points in the MultiHeadAttention layer, as that - demonstrated (empirically) the best throughput, allowing us to train with a batch size of 16 on - gpt2-10b with adamw and full sharding across 4 TPU v4 chips and a RepeatedTransformerLayer, - with 1.8x the step time of a stacked layer with a batch size of 8 and the same sharding config. + def rms_norm(x): + return jnp.sqrt(jnp.mean(x**2)) - For conformer model, we start from the same remat policy as language models. - TODO(zhiyunlu): investigate Conformer model's memory/step-time tradeoffs. Possibly we - need to save points in the LConv module. + if cls == StackedTransformerLayer: + update_norms = jax.tree.map(rms_norm, updates) + else: + update_norms = jax.vmap(lambda x, norm=rms_norm: jax.tree.map(norm, x))(updates) + logging.info( + "global_update_norm=%s update_norms=%s", + optax.global_norm(updates), + dict(utils.flatten_items(update_norms)), + ) - Args: - stack_cfg: A transformer config. - save_pattern: Activation regex pattern to save in HBM. - offload_pattern: Activation regex pattern to offload to `offload_dst`. - offload_dst: Destination of remat checkptoing offloading. Relevant Maxtext example: - https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L230. + if cls == StackedTransformerLayer: + for x in (layer_params, grads, updates): + x["stack"] = _recursive_stack(x["stack"]) + + if cls == RepeatedTransformerLayer: + for x in (layer_params, grads, updates): + x["stack"] = x["stack"]["repeat"] + + if cls == PipelinedTransformerLayer: + for x in (layer_params, grads, updates): + logging.info("x=%s", shapes(x)) + if cfg.stack.stage.klass == StackedTransformerLayer: + # First stack within each stage. + x["stack"]["pipeline"]["layer"] = _recursive_stack( + x["stack"]["pipeline"]["layer"], axis=1 + ) + logging.info("x=%s", shapes(x)) + elif cfg.stack.stage.klass == RepeatedTransformerLayer: + x["stack"]["pipeline"]["layer"] = x["stack"]["pipeline"]["layer"][ + "repeat" + ] + else: + raise NotImplementedError(cfg.stack.stage.klass) + + # Then reshape across stages. 
+ x["stack"] = jax.tree.map( + lambda x: x.reshape([num_layers] + list(x.shape[2:])), + x["stack"]["pipeline"]["layer"], + ) + + all_params.append(layer_params) + all_outputs.append(layer_outputs) + all_gradients.append(grads) + all_updates.append(updates) + + if cls == StackedTransformerLayer: + one_layer = layer.stack.layer0 + elif cls == RepeatedTransformerLayer: + one_layer = layer.stack.repeat.layer + else: + one_layer = None - Returns: - None (if no rematerialization is needed) or a RematSpec. - """ - # TODO(markblee): Switch to using isinstance everywhere. - if stack_cfg.klass is PipelinedTransformerLayer: - return None - - policy = config_for_function(_save_and_offload_only_these_names_regex).set( - names_which_can_be_saved=save_pattern, - names_which_can_be_offloaded=offload_pattern, - offload_src="device", - offload_dst=offload_dst, + # pylint: disable=protected-access + if one_layer is not None: + logging.info( + "%s._remat_methods = %s", one_layer.path(), one_layer._remat_methods + ) + if remat_spec is not None: + self.assertSequenceEqual( + one_layer._remat_methods, ["forward"], msg=one_layer.path() + ) + else: + self.assertEmpty(one_layer._remat_methods, msg=one_layer.path()) + # pylint: enable=protected-access + + self.assertNestedAllClose(all_params[0], all_params[1]) + self.assertNestedAllClose(all_outputs[0], all_outputs[1]) + self.assertNestedAllClose(all_gradients[0], all_gradients[1]) + self.assertNestedAllClose(all_updates[0], all_updates[1]) + + @parameterized.parameters(StackedTransformerLayer, RepeatedTransformerLayer) + def test_stacked_decoding(self, stack_cls): + model_dim, num_heads = 6, 2 + cfg = stack_cls.default_config().set(num_layers=5, input_dim=model_dim) + layer_cfg = cfg.layer + layer_cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + layer_cfg.feed_forward.hidden_dim = model_dim * 4 + self._test_forward_vs_extend_step(cfg) + self._test_decoder_with_transformer(cfg) + + @parameterized.product( + outer_stack_cls=(StackedTransformerLayer, RepeatedTransformerLayer), + inner_stack_cls=(StackedTransformerLayer, RepeatedTransformerLayer), ) + def test_nested_stacked_decoding(self, outer_stack_cls, inner_stack_cls): + model_dim, num_heads = 6, 2 + cfg = outer_stack_cls.default_config().set(num_layers=2, input_dim=model_dim) + cfg.layer = inner_stack_cls.default_config().set(num_layers=3) + layer_cfg = cfg.layer.layer + layer_cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + layer_cfg.feed_forward.hidden_dim = model_dim * 4 + self._test_forward_vs_extend_step(cfg) + self._test_decoder_with_transformer(cfg) + + @parameterized.parameters(None, 0.0, 0.2, 1.0) + def test_stochastic_depth(self, rate): + batch_size, tgt_len = 10, 6 + num_dec_layers, model_dim, num_heads = 3, 16, 4 + model_dim = 16 + num_heads = 4 + cfg = StackedTransformerLayer.default_config().set( + name="test", + input_dim=model_dim, + num_layers=num_dec_layers, + peak_stochastic_depth_rate=rate, + ) + layer_cfg = cfg.layer + layer_cfg.self_attention.attention.set(num_heads=num_heads) + layer_cfg.feed_forward.hidden_dim = model_dim * 4 + + if rate is None or 0 <= rate < 1: + layer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) + F( + layer, + inputs=dict(data=target), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + else: + with self.assertRaises(ValueError): + 
cfg.instantiate(parent=None) + + @parameterized.product(is_training=(True, False)) + def test_stacked_transformer_with_seq_layer_cfgs(self, is_training): + batch_size = 2 + seq_len = 16 + input_dim = 4 + hidden_dim = 16 + num_layers = 4 + num_heads = 4 + + # Create a StackedTransformerLayer by specifying a sequence of layer configs. + cfg = StackedTransformerLayer.default_config().set(name="test") + cfg.input_dim = input_dim + cfg.num_layers = num_layers + transformer_cfg = TransformerLayer.default_config() + transformer_cfg.self_attention.attention.num_heads = num_heads + transformer_cfg.feed_forward.hidden_dim = hidden_dim + cfg.layer = (transformer_cfg,) * num_layers + layer: StackedTransformerLayer = cfg.instantiate(parent=None) + inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) + state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + outputs, _ = F( + layer, + is_training=is_training, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict(data=inputs), + ) + # Create a ref StackedTransformerLayer with repeating the default layer cfg. + ref_cfg = StackedTransformerLayer.default_config().set(name="test") + ref_cfg.input_dim = input_dim + ref_cfg.num_layers = num_layers + ref_cfg.layer.self_attention.attention.num_heads = num_heads + ref_cfg.layer.feed_forward.hidden_dim = hidden_dim + ref_layer: StackedTransformerLayer = ref_cfg.instantiate(parent=None) + ref_outputs, _ = F( + ref_layer, + is_training=is_training, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict(data=inputs), + ) + assert_allclose(outputs.data, ref_outputs.data) + assert_allclose(outputs.self_attention_probs, ref_outputs.self_attention_probs) + + @parameterized.product(is_training=(True, False)) + def test_stacked_transformer_with_non_uniform_layers(self, is_training): + """Tests that a custom StackedTransformerLayer can support non-uniform layers.""" + batch_size = 2 + seq_len = 16 + input_dim = 4 + hidden_dim = 16 + num_layers = 2 + + # Create a StackedTransformerLayer by specifying a sequence of non-uniform layer configs. + cfg = NonUniformStack.default_config().set(name="test") + cfg.input_dim = input_dim + cfg.num_layers = num_layers + cfg.layer = [] + for i in range(num_layers): + transformer_cfg = TransformerLayer.default_config() + # Different numbers of heads between the layers. + transformer_cfg.self_attention.attention.num_heads = 2 if i == 0 else 1 + transformer_cfg.feed_forward.hidden_dim = hidden_dim + cfg.layer.append(transformer_cfg) + layer: StackedTransformerLayer = cfg.instantiate(parent=None) + inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) + state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + outputs, _ = F( + layer, + is_training=is_training, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict(data=inputs, return_aux={"self_attention_kv_state"}), + ) + self.assertEqual( + BaseTransformerLayer.Output( + data=(2, 16, 4), + self_attention_probs=None, + self_attention_kv_state=KVState(k_proj=(2, 16, 1, 4), v_proj=(2, 16, 1, 4)), + cross_attention_probs=None, + ), + shapes(outputs), + ) - return RematSpec( - prevent_cse=stack_cfg.klass is StackedTransformerLayer, - # If we are running inside a jax.lax.scan (Repeated/Pipelined transformers - # or Repeated Conformers) we can enable common subexpression elimination optimizations. 
- policy=policy, + @parameterized.parameters( + [None, False], + [("data",), False], + [("data",), True], + [("data", "self_attention_kv_state"), True], ) + def test_repeated_layer_with_custom_carry(self, repeat_carry, precomputed_kv_state): + """Tests RepeatedTransformerLayer with customized `carry`.""" + batch_size = 1 + seq_len = 16 + input_dim = 4 + num_heads = 2 + head_dim = input_dim // num_heads + num_layers = 3 + + cfg = self._stack_config( + RepeatedTransformerLayer, + num_layers=num_layers, + model_dim=input_dim, + num_heads=num_heads, + dtype=jnp.float32, + remat_spec=None, + output_self_attention_kv_state=True, + ) + cfg.stack.repeat.carry = repeat_carry + cfg.stack.layer.remat_spec = build_remat_spec(cfg.stack) + if precomputed_kv_state: + kv_shape = (batch_size, seq_len, num_heads, head_dim) + kv_state = KVState( + k_proj=jax.random.normal(key=jax.random.PRNGKey(1), shape=kv_shape), + v_proj=jax.random.normal(key=jax.random.PRNGKey(2), shape=kv_shape), + ) + cfg.stack.layer.self_attention.attention.input_linear = QLinear.default_config() + expected_output = 1.8719857 + else: + kv_state = None + # carry=None and carry=("data",) are equivalent. + expected_output = 5.3901253 + + layer = cfg.instantiate(parent=None) + state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) + outputs, _ = F( + layer, + is_training=True, + prng_key=jax.random.PRNGKey(123), + state=state, + inputs=dict( + data=inputs, + self_attention_kv_state=kv_state, + return_aux={"self_attention_kv_state"}, + ), + ) + self.assertNestedAllClose(expected_output, outputs[0]) + if precomputed_kv_state: + self.assertNestedAllClose(kv_state, outputs[1]["self_attention_kv_state"]) + else: + self.assertIsInstance(outputs[1]["self_attention_kv_state"], KVState) + def test_pipeline_return_aux(self): + batch_size, num_heads, seq_len, dim = 2, 3, 4, 6 -class AttentionLogitBiasLayer(BaseLayer): - """Base attention logit bias layer. + class DummyTransformerLayer(TransformerLayer): + def forward(self, data, **kwargs): + return TransformerLayer.Output( + data=data, + self_attention_probs=jnp.empty([batch_size, num_heads, seq_len, seq_len]), + self_attention_kv_state=KVState( + k_proj=jnp.empty([batch_size, seq_len, num_heads, dim]), + v_proj=jnp.empty([batch_size, seq_len, num_heads, dim]), + ), + ) - The attention logit bias layer should have input_ids as input. 
- """ + cfg: PipelinedTransformerLayer.Config = PipelinedTransformerLayer.default_config().set( + num_stages=2, + num_microbatches=2, + num_layers=2, + input_dim=dim, + layer=DummyTransformerLayer.default_config(), + ) + cfg.layer.self_attention.attention.set(num_heads=num_heads) + cfg.layer.feed_forward.hidden_dim = scaled_hidden_dim(4) + + with test_utils.bind_layer(cfg) as layer: + data = jax.random.uniform(layer.prng_key, shape=[2, 3, 4]) + out = layer(data, return_aux={"self_attention_kv_state"}) + self.assertNestedAllClose(data, out.data) + self.assertIsNone(out.self_attention_probs) + self.assertIsNotNone(out.self_attention_kv_state) + + @parameterized.parameters( + ([],), + (["self_attention"],), + (["feed_forward"],), + (["self_attention", "feed_forward"],), + ) + def test_initialize_parameters_recursively(self, prebuilt_layers: list[str]): + """Tests initialize_parameters_recursively with various prebuilt layers.""" + input_dim = 4 + num_heads = 2 + num_layers = 3 + + cfg = self._stack_config( + RepeatedTransformerLayer, + num_layers=num_layers, + model_dim=input_dim, + num_heads=num_heads, + dtype=jnp.float32, + remat_spec=None, + output_self_attention_kv_state=True, + ) + cfg.stack.layer.remat_spec = build_remat_spec(cfg.stack) + layer = cfg.instantiate(parent=None) + param_specs = layer.create_parameter_specs_recursively() + initialized_from_scratch = layer.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(123) + ) + jax.tree_util.tree_map_with_path( + lambda path, spec, param: self.assertEqual(param.shape, spec.shape, path), + param_specs, + initialized_from_scratch, + ) - def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: - """Produces attention logit biases. - - Args: - segment_ids: An integer Tensor of shape [batch_size, seq_len] with values in - [0, num_segments). Tokens are only allowed to attend to other tokens within the same - segment. segment_ids == 0 represents paddings. - positions: An Tensor of broadcastable shape to `input_ids` with values in [0, seq_len). - This can be used to produce biases for packed inputs. - - Returns: - A float attention logit biases of shape [batch_size, 1, seq_len, seq_len] or - [batch_size, num_heads, seq_len, seq_len]. - Output[b,i,j] is -inf iff attention is disabled with query=input[b, i] and - key=input[b, j]. - """ - raise NotImplementedError(type(self)) + def has_prebuilt_layers(path): + for prebuilt_layer in prebuilt_layers: + for part in path: + if prebuilt_layer == part.key: + return True + return False + # ParameterSpec for a prebuilt param, None otherwise. + prebuilt_specs = jax.tree_util.tree_map_with_path( + lambda path, spec: spec if has_prebuilt_layers(path) else None, param_specs + ) + if prebuilt_layers: + self.assertNotEmpty(jax.tree_util.tree_leaves(prebuilt_specs)) + initialized_state = layer.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(123), prebuilt=prebuilt_specs + ) -def compute_padding_biases(input_ids: Tensor, *, pad_token_id: Optional[int]) -> Tensor: - """Compute the logits bias to disable attention to/from paddings. + def validate_initialized(path, spec, initialized, prebuilt): + if prebuilt is None: + self.assertEqual(spec.shape, initialized.shape, path) + else: + self.assertIsNone(initialized) - Args: - input_ids: A Tensor of shape [batch_size, seq_len]. - pad_token_id: An int representing the padded token ID or None. 
+ jax.tree_util.tree_map_with_path( + validate_initialized, param_specs, initialized_state, prebuilt_specs + ) - Returns: - A float logit biases of shape [batch_size, 1, seq_len, seq_len]. - """ - if pad_token_id is None: - batch_size, seq_len = input_ids.shape - return jnp.zeros([batch_size, 1, seq_len, seq_len]) - padding_bias = (input_ids == pad_token_id) * NEG_INF - return padding_bias[:, None, None, :] + padding_bias[:, None, :, None] +class ConfigHelperTest(TestCase): + """Tests config utils.""" + + @parameterized.product( + self_attention_input_linear_cfg=( + QKVLinear.default_config(), + FusedQKVLinear.default_config(), + RoFormerQKVLinear.default_config().set(input_linear=FusedQKVLinear.default_config()), + ), + cross_attention_cfg=(None, TransformerAttentionLayer.default_config()), + batch_axis_names=("data", ("replica", "data", "fsdp")), + fsdp_axis_names=("fsdp",), + tp_axis_names=("model",), + seq_axis_names=("seq",), + ) + def test_set_double_shard_weights_config( + self, + self_attention_input_linear_cfg, + cross_attention_cfg, + batch_axis_names, + fsdp_axis_names, + tp_axis_names, + seq_axis_names, + ): + cfg: TransformerLayer.Config = TransformerLayer.default_config().set( + cross_attention=cross_attention_cfg + ) + cfg.self_attention.attention.input_linear = self_attention_input_linear_cfg + set_double_shard_weights_config( + cfg, + batch_axis_names=batch_axis_names, + fsdp_axis_names=fsdp_axis_names, + tp_axis_names=tp_axis_names, + seq_axis_names=seq_axis_names, + ) -class CausalAttentionLogitBiasLayer(AttentionLogitBiasLayer): - """Causal attention logit bias layer.""" + ff_layer = cfg.feed_forward + self.assertSequenceEqual( + ff_layer.linear1.param_partition_spec, (fsdp_axis_names, tp_axis_names) + ) + self.assertSequenceEqual( + ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names) + ) + self.assertSequenceEqual( + ff_layer.linear1.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), + ) + self.assertSequenceEqual( + ff_layer.linear2.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), + ) - def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: - """Refer to AttentionLogitBiasLayer.forward for docstring.""" - # Note: padding tokens are not explicitly masked. - causal_bias = (positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF - return apply_attention_logit_biases( - causal_bias, make_segment_mask(source_segments=segment_ids, target_segments=segment_ids) + self_atten = cfg.self_attention.attention + input_linear = self_atten.input_linear + if isinstance(self_attention_input_linear_cfg, RoFormerQKVLinear.Config): + input_linear = input_linear.input_linear + # Shard weights. + self.assertSequenceEqual( + input_linear.layer.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + self.assertSequenceEqual( + self_atten.output_linear.param_partition_spec, (fsdp_axis_names, tp_axis_names, None) ) + if cross_attention_cfg is None: + self.assertIsNone(cfg.cross_attention) + else: + cross_atten = cfg.cross_attention.attention + # Shard weights. 
+ self.assertSequenceEqual( + cross_atten.input_linear.layer.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + self.assertSequenceEqual( + cross_atten.output_linear.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) -class FullAttentionLogitBiasLayer(AttentionLogitBiasLayer): - """Full attention logit bias layer.""" + @parameterized.product( + self_attention_input_linear_cfg=( + QKVLinear.default_config(), + FusedQKVLinear.default_config(), + ), + cross_attention_cfg=(None, TransformerAttentionLayer.default_config()), + batch_axis_names=("data", ("replica", "data", "fsdp")), + fsdp_axis_names=("fsdp",), + tp_axis_names=("model",), + seq_axis_names=("seq",), + ) + def test_set_double_shard_weights_config_for_list_of_configs( + self, + self_attention_input_linear_cfg, + cross_attention_cfg, + batch_axis_names, + fsdp_axis_names, + tp_axis_names, + seq_axis_names, + ): + cfg_layer: TransformerLayer.Config = TransformerLayer.default_config().set( + cross_attention=cross_attention_cfg + ) + cfg_layer.self_attention.attention.input_linear = self_attention_input_linear_cfg + cfg_layers = [cfg_layer, cfg_layer] + set_double_shard_weights_config( + cfg_layers, + batch_axis_names=batch_axis_names, + fsdp_axis_names=fsdp_axis_names, + tp_axis_names=tp_axis_names, + seq_axis_names=seq_axis_names, + ) - def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: - """Refer to AttentionLogitBiasLayer.forward for docstring.""" - del positions - return make_segment_mask(source_segments=segment_ids, target_segments=segment_ids) + for cfg in cfg_layers: + ff_layer = cfg.feed_forward + self.assertSequenceEqual( + ff_layer.linear1.param_partition_spec, (fsdp_axis_names, tp_axis_names) + ) + self.assertSequenceEqual( + ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names) + ) + self.assertSequenceEqual( + ff_layer.linear1.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), + ) + self.assertSequenceEqual( + ff_layer.linear2.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), + ) + self_atten = cfg.self_attention.attention + # Shard weights. + self.assertSequenceEqual( + self_atten.input_linear.layer.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + self.assertSequenceEqual( + self_atten.output_linear.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) -def alibi_get_slopes(num_heads: int) -> list: - """Get the slopes for different attention heads defined in ALiBi paper. + if cross_attention_cfg is None: + self.assertIsNone(cfg.cross_attention) + else: + cross_atten = cfg.self_attention.attention + # Shard weights. + self.assertSequenceEqual( + cross_atten.input_linear.layer.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + self.assertSequenceEqual( + cross_atten.output_linear.param_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) - This is a direct copy from ALiBi codebase. - Ref: - https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca/fairseq/models/transformer.py#L742-L752 - Args: - num_heads: An integer for the number of attention heads. +class PositionalEmbeddingTest(TestCase): + """Tests PositionalEmbedding.""" - Returns: - A tensor of slopes with shape of [num_heads]. Each value represents - a slope for one attention head. - """ + def test_learned_positional_embedding_1d(self): + """ + Simple test that LearnedPositionalEmbedding returns expected outputs for a 1d sequence. 
+ """ + positions = np.arange(10) + dim = 8 + pos_emb_cfg = LearnedPositionalEmbedding.default_config().set( + name="test", + dim=dim, + shape=(len(positions),), + ) + pos_emb = pos_emb_cfg.instantiate(parent=None) - def get_slopes_power_of_2(n: int) -> list: - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] + state = pos_emb.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - if math.log2(num_heads).is_integer(): - return get_slopes_power_of_2(num_heads) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + alibi_get_slopes(2 * closest_power_of_2)[0::2][: num_heads - closest_power_of_2] + outputs, _ = F( + pos_emb, + is_training=True, + prng_key=jax.random.PRNGKey(1), + state=state, + inputs={"positions": positions}, ) + context = InvocationContext( + name="root", + parent=None, + module=pos_emb, + state=state, + output_collection=new_output_collection(), + is_training=True, + prng_key=jax.random.PRNGKey(2), + ) + with set_current_context(context): + embeddings_tensor = pos_emb.embeddings() + assert embeddings_tensor.shape == (len(positions), dim) -class ALiBiAttentionLogitBiasLayer(CausalAttentionLogitBiasLayer): - """attention logit bias layer in ALiBi. + for position in positions: + assert_allclose(outputs[position], embeddings_tensor[position]) - Ref: https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca - """ - @config_class - class Config(CausalAttentionLogitBiasLayer.Config): - """Configures ALiBiAttentionLogitBiasLayer.""" +@pytest.mark.parametrize("x, output", [(300, 512), (127.1, 128), (128, 128), (0.1, 2)]) +def test_next_power_of_two(x, output): + assert _next_power_of_two(x) == output - num_heads: Required[int] = REQUIRED - def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: - """Produces an attention logit biases of shape [batch_size, num_heads, seq_len, seq_len]. +class BottleNeckAdapterTransformerLayerTest(TestCase): + """Tests BottleNeckAdapterTransformerLayer.""" - The ALiBi bias is defined as below: - 1. Create a lower triangle matrix with the value of: - bias = [-(i-1), ..., -2, -1, 0] * slopes - 2. Apply the casual biases. - bias = apply_apply_attention_logit_biases(bias, causal_bias) + @parameterized.parameters( + {"bottleneck_ratio": 0.1}, + {"bottleneck_ratio": 0.5}, + {"bottleneck_ratio": 1.0}, + ) + def test_forward(self, bottleneck_ratio): + batch_size, tgt_len, model_dim, num_heads = 2, 3, 32, 1 - Refer to AttentionLogitBiasLayer.forward for docstring. - """ - cfg = self.config - slopes = jnp.asarray(alibi_get_slopes(cfg.num_heads)) - # Create the lower triangle matrix w/ value [-(i-1), ..., -2, -1, 0] for each segment. - alibi_bias = jnp.expand_dims(positions, [1]) - jnp.expand_dims(positions, [2]) - # Add head dim. - alibi_bias = jnp.expand_dims(alibi_bias, [1]) - # Multiply w/ the slopes. - alibi_bias = alibi_bias * jnp.expand_dims(slopes, [0, 2, 3]) - bias = super().forward(segment_ids=segment_ids, positions=positions) - # Combine the biases. - return apply_attention_logit_biases(alibi_bias, bias) - - -class SymmetricALiBiAttentionLogitBiasLayer(FullAttentionLogitBiasLayer): - """Symmetric full attention version of ALiBiAttentionLogitBiasLayer. - - Main implementation differences between this one and `ALiBiAttentionLogitBiasLayer` (above): - 1. Muliplies alibi slopes by -1. - 2. Computes absolute value of relative positions. - 3. 
Multiplies results of steps 1 and 2 to get symmetric bias matrix. - - Originally proposed here by an author of the ALiBi paper: - https://github.com/ofirpress/attention_with_linear_biases/issues/5 - """ + layer_cfg = TransformerLayer.default_config().set(name="layer", input_dim=model_dim) + layer_cfg.self_attention.attention.set(num_heads=num_heads) + layer_cfg.feed_forward.hidden_dim = model_dim - @config_class - class Config(FullAttentionLogitBiasLayer.Config): - """Configures SymmetricALiBiAttentionLogitBiasLayer.""" + adapter_cfg = BottleNeckAdapterTransformerLayer.default_config().set( + input_dim=model_dim, name="adapter", bottleneck_ratio=bottleneck_ratio + ) + adapter_cfg.layer = layer_cfg - num_heads: Required[int] = REQUIRED + adapter = adapter_cfg.instantiate(parent=None) - def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: - cfg = self.config + state = adapter.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - slopes = -1 * jnp.asarray(alibi_get_slopes(cfg.num_heads)) + data = jax.random.normal(jax.random.PRNGKey(1), [batch_size, tgt_len, model_dim]) + self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) - # Create the lower triangle matrix w/ value [-(i-1), ..., -2, -1, 0] for each segment. - alibi_bias = jnp.abs(positions[:, jnp.newaxis, :] - positions[:, :, jnp.newaxis]) + outputs, _ = F( + adapter, + is_training=True, + prng_key=jax.random.PRNGKey(2), + state=state, + inputs=dict( + data=data, + self_attention_logit_biases=self_attention_logit_biases, + ), + ) - # Add head dim. - alibi_bias = alibi_bias[:, jnp.newaxis, :, :] + # Output shape is left unchanged. + assert outputs.data.shape == (2, 3, 32) - # Multiply w/ the slopes. - alibi_bias = alibi_bias * jnp.expand_dims(slopes, [0, 2, 3]) - bias = super().forward(segment_ids=segment_ids, positions=positions) - # Combine the biases. 
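
The symmetric ALiBi bias deleted in this hunk reduces to negated per-head slopes multiplied by absolute relative distances (the removed layer then combines the result with the segment-mask bias from `super().forward`). A minimal standalone sketch, assuming a power-of-two head count and arbitrarily small shapes:

    import math

    import jax.numpy as jnp

    def alibi_slopes(num_heads: int) -> jnp.ndarray:
        # Geometric slopes; e.g. num_heads=4 -> [1/4, 1/16, 1/64, 1/256].
        start = 2 ** (-(2 ** -(math.log2(num_heads) - 3)))
        return jnp.asarray([start * start**i for i in range(num_heads)])

    positions = jnp.arange(5)[None, :]  # [batch=1, seq_len=5].
    # [batch, seq_len, seq_len]: absolute relative distance |i - j|.
    distances = jnp.abs(positions[:, None, :] - positions[:, :, None])
    # [batch, num_heads, seq_len, seq_len]: larger distances get more negative biases.
    alibi_bias = -alibi_slopes(4)[None, :, None, None] * distances[:, None, :, :]
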
- return apply_attention_logit_biases(alibi_bias, bias) +if __name__ == "__main__": + with utils.numeric_checks(True): + absltest.main() From 60b7d32dcec64b9f87d44ba61cee0910f7e6674e Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 18:53:35 -0800 Subject: [PATCH 05/12] Update attention_test.py --- axlearn/common/attention_test.py | 71 ++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 8528fe95e..bca07178f 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -41,7 +41,6 @@ from axlearn.common import attention, attention_bias, test_utils, utils from axlearn.common.attention import ( - FEED_FORWARD_SAVE_PATTERN, BaseStackedTransformerLayer, BaseTransformerLayer, BottleNeckAdapterTransformerLayer, @@ -58,6 +57,7 @@ PipelinedTransformerLayer, QKVLinear, QLinear, + RematRegexSavePatterns, RepeatedTransformerLayer, RoFormerQKVLinear, StackedTransformerLayer, @@ -65,7 +65,6 @@ TransformerFeedForwardLayer, TransformerLayer, _next_power_of_two, - _save_and_offload_only_these_names_regex, apply_attention_logit_biases, apply_rotary_position_embeddings, build_remat_spec, @@ -124,6 +123,7 @@ VDict, as_tensor, flatten_items, + save_and_offload_only_these_names_regex, shapes, ) @@ -3554,8 +3554,8 @@ def f(x, layer_params): _, save_name_backward = jax.linearize( jax.remat( f, - policy=_save_and_offload_only_these_names_regex( - names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN, + policy=save_and_offload_only_these_names_regex( + names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value, names_which_can_be_offloaded=None, offload_src="device", offload_dst="pinned_host", @@ -4010,6 +4010,69 @@ def f(x, layer_params): 5, ) + def test_build_remat_spec_neuron(self): + model_dim, num_heads = 6, 2 + cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) + cfg.self_attention.attention.set(num_heads=num_heads, causal=True) + cfg.feed_forward.hidden_dim = model_dim * 4 + cfg.vlog = 5 + + layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + + batch_size, tgt_len = 2, 5 + rng = np.random.default_rng(seed=123) + target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) + + def f(x, layer_params): + forward_outputs, _ = F( + layer, + inputs=dict( + data=x, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + return forward_outputs + + # Ignore type errors. + spec: Any = build_remat_spec(mock.MagicMock()) + + policy = ( + config_for_function(save_and_offload_only_these_names_regex) + .set( + names_which_can_be_saved="|".join( + [ + RematRegexSavePatterns.QKV_PROJ.value, + RematRegexSavePatterns.LINEAR1_X.value, + ] + ), + names_which_can_be_offloaded=None, + offload_src=None, + offload_dst=None, + ) + .instantiate() + ) + + _, default_policy_backward = jax.linearize( + jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse), + jnp.asarray(target), + layer_params, + ) + _, full_remat_backward = jax.linearize( + jax.remat(f), + jnp.asarray(target), + layer_params, + ) + + # Eliminated the remat of qkv_proj and linear1_0 = 4 dots. 
+ self.assertEqual( + str(full_remat_backward).count(" dot_general") + - str(default_policy_backward).count(" dot_general"), + 4, + ) + class TestStackModel(BaseLayer): """A dummy transformer stack.""" From 8bdbac1608ff4c1eb35e5a26b440fcb5062a3c95 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 18:54:06 -0800 Subject: [PATCH 06/12] Update attention.py --- axlearn/common/attention.py | 8843 +++++++++++++++-------------------- 1 file changed, 3839 insertions(+), 5004 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index bca07178f..b2cc40f61 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -6,4162 +6,3632 @@ # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # -# ofirpress/attention_with_linear_biases: -# Copyright (c) Facebook, Inc. and its affiliates. +# google-research/t5x: +# Copyright 2022 The T5X Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# +# huggingface/transformers: +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# +# facebookresearch/deit: +# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# +# tensorflow/models: +# Copyright 2023 The TensorFlow Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# +# google/praxis: +# Copyright 2022 The Pax Authors. +# Licensed under the Apache License, Version 2.0 (the "License"). # -# facebookresearch/llama: +# ofirpress/attention_with_linear_biases: # Copyright (c) Facebook, Inc. and its affiliates. - -"""Tests attention layers.""" - -import contextlib -import copy -import itertools - -# pylint: disable=too-many-lines,duplicate-code,no-self-use +# Licensed under the MIT license. + +"""Attention layers with pjit partition specs. + +On `attention_logit_biases`: +* For methods that take a tensor, a biases Tensor can have one of the following shapes: + * [target_length, source_length] + * [batch, target_length, source_length] + * [batch, num_heads, target_length, source_length]. +* Each value represents a bias to be added to the attention logits + (therefore a -inf represents a disconnected position pair). +* biases=None represents an all-zero tensor, i.e., all position pairs are connected. +* For methods that take a BaseAttentionBias, the value() will always be None or a 4d Tensor with + the above semantics. + +TODO(apghml) Convert everything to take an instance of BaseAttentionBias rather than a Tensor. + +On `segment_ids`: +* A tensor of shape [batch, target_length] with values in [0, num_segments]. +* Tokens are only allowed to attend to other tokens within the same segment. +* segment_ids == 0 represents paddings. +* None represents an all-one tensor, i.e. all positions are in the same segment. + +On `positions`: +* A tensor of shape [batch, target_length]. Note that this is conceptually different from + `time_step`. To disambiguate: + * `positions`: A [batch, target_length] tensor indicating the position ids of each input token + during training (i.e. in `forward`). + * `time_step`: a [batch] tensor indicating the current decode position of each sample during + decoding (i.e. in `init_states` and `extend_step`). +* In most typical cases, the values of `positions` are integers in [0, target_length - 1]. 
+ However, this should not be assumed by the implementation in order to support other positional + encoding schemes, e.g. RandPos (https://arxiv.org/pdf/2305.16843), where positions are + non-consecutive integers that can be larger than target_length - 1. +* None represents jnp.arange(target_length). +* When the accompanying argument is `query`, the `positions` argument is named as + `query_position`. Similarly, when the argument `target`, it is named as `target_positions`. + +TODO(changlan): Merge the use of `positions` and `time_step` to reduce cognitive complexity. + +""" + +# pylint: disable=abstract-method,too-many-lines +import enum +import functools import math from collections.abc import Sequence -from itertools import combinations -from typing import Any, Callable, Optional, Union -from unittest import mock +from enum import Enum, unique +from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +import einops import jax -import numpy as np -import optax -import pytest -import torch -from absl import logging -from absl.testing import absltest, parameterized -from jax import nn from jax import numpy as jnp -from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies -from transformers.models.roberta import modeling_roberta as hf_roberta -from transformers.models.roformer import modeling_roformer as hf_roformer -from transformers.models.xlnet import modeling_xlnet as hf_xlnet - -from axlearn.common import attention, attention_bias, test_utils, utils -from axlearn.common.attention import ( - BaseStackedTransformerLayer, - BaseTransformerLayer, - BottleNeckAdapterTransformerLayer, - FusedGroupedQKVLinear, - FusedQKVLinear, - KVState, - LearnedPositionalEmbedding, - MultiheadAttentionXL, - MultiheadInputLinear, - MultiheadOutputLinear, - MultiheadRelativePositionLinear, - ParallelTransformerLayer, - PerDimScale, - PipelinedTransformerLayer, - QKVLinear, - QLinear, - RematRegexSavePatterns, - RepeatedTransformerLayer, - RoFormerQKVLinear, - StackedTransformerLayer, - TransformerAttentionLayer, - TransformerFeedForwardLayer, - TransformerLayer, - _next_power_of_two, - apply_attention_logit_biases, - apply_rotary_position_embeddings, - build_remat_spec, - compute_padding_biases, - rel_pos_to_abs_pos, - scaled_hidden_dim, - set_double_shard_weights_config, - sinusoidal_positional_embeddings, - update_data_with_skip_connection, - xl_attention_logits, -) + +from axlearn.common import ops, param_init from axlearn.common.attention_bias import ( NEG_INF, - bool_to_bias, + BaseAttentionBias, + CausalAttentionBias, + MaskFn, + MaskFnAttentionBias, + SegmentIdAttentionBias, + as_attention_bias, causal_mask, - make_causal_biases, - make_sliding_window_causal_biases, - sliding_window_causal_mask, + make_segment_mask, ) from axlearn.common.base_layer import ( BaseLayer, - DefaultTensorStats, FactorizationSpec, + NestedParameterSpec, ParameterSpec, RematSpec, ) from axlearn.common.config import ( + REQUIRED, + ConfigOr, + FunctionConfigBase, InstantiableConfig, - UnknownFieldError, + Required, config_class, config_for_function, - maybe_set_config, + maybe_instantiate, +) +from axlearn.common.layers import ( + Dropout, + LayerNorm, + Linear, + StochasticDepth, + get_activation_fn, + get_stochastic_depth_linear_rate, ) -from axlearn.common.decoder import Decoder, TransformerTextEmbeddings -from axlearn.common.layers import RMSNorm, set_bias_recursively -from axlearn.common.module import InvocationContext, Module -from axlearn.common.module import functional as F -from 
axlearn.common.module import new_output_collection, set_current_context -from axlearn.common.optimizer_base import OptParam -from axlearn.common.optimizers import adafactor_optimizer -from axlearn.common.param_converter import as_torch_tensor +from axlearn.common.module import Module, child_context from axlearn.common.param_init import ( PARAM_REGEXP_WEIGHT, + ConstantInitializer, DefaultInitializer, FanAxes, WeightInitializer, + constant_initializer, ) -from axlearn.common.pipeline import BaseSchedule, GPipeSchedule, StreamSchedule -from axlearn.common.test_utils import TestCase, assert_allclose, dummy_segments_positions -from axlearn.common.torch_utils import parameters_from_torch_layer +from axlearn.common.pipeline import Pipeline +from axlearn.common.quantized_dot_general.layers import DenseGeneralBaseLayer +from axlearn.common.repeat import Repeat from axlearn.common.utils import ( Nested, + NestedTensor, + OffloadPolicy, PartitionSpec, + SavePattern, Tensor, TensorSpec, VDict, - as_tensor, + check_numerics, flatten_items, + get_or_none, save_and_offload_only_these_names_regex, shapes, + split_prng_key, ) -def all_subsets(given_set): - "Generate all subsets of a list `given_set`." - s = list(given_set) - return list( - itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1)) - ) - - -def make_index_position_biases(query_len: int, kv_len: int) -> Tensor: - """Generates attention logit biases where query indices cannot attend to larger key indices. - - Args: - query_len: The sequence length. - kv_len: The key's length. +class ForwardMode(enum.Enum): + """ForwardMode describes the type of computation to be done in a forward pass through a layer. - Returns: - A float tensor of shape [query_len, kv_len] where the value at - [i, j] = -inf if i < j, 0 otherwise. + FORWARD: Used for a standard forward pass. + INIT_STATES: Used for initializing the decoding cache. Typically means that the method signature + matches EXTEND_STEP, possibly without an input cache state, and returning a prefilled cache + along with the layer outputs. + EXTEND_STEP: Used for incremental decoding. Typically means that the method signature consumes + cache state and emits cache state along with layer outputs. """ - return bool_to_bias( - causal_mask( - jnp.arange(query_len)[:, None], - jnp.arange(kv_len)[None, :], - ) - ) + FORWARD = 0 + INIT_STATES = 1 + EXTEND_STEP = 2 -def _random_mask(prng_key, tgt_len, src_len): - """Returns a float mask of shape [tgt_len, src_len].""" - key1, key2 = jax.random.split(prng_key) - mask = jnp.logical_not( - jax.random.randint(key1, minval=0, maxval=2, shape=[tgt_len, src_len]) - + - # Ensure that every tgt position attends to at least one src position, otherwise - # torch_modules.MultiheadAttention will generate NaN. 
- nn.one_hot(jax.random.randint(key2, minval=0, maxval=src_len, shape=[tgt_len]), src_len) - ) - return mask.astype(jnp.float32) * NEG_INF - - -class MaskTest(absltest.TestCase): - """Tests mask implementations.""" - - def test_causal_mask(self): - expected = jnp.array([[0.0, NEG_INF, NEG_INF], [0.0, 0.0, NEG_INF], [0.0, 0.0, 0.0]]) - actual = attention_bias.make_causal_biases(3) - self.assertTrue(jnp.all(actual <= expected)) - - def test_segment_mask(self): - expected = jnp.array( - [ # batch - [ # num_heads - [ - [NEG_INF, NEG_INF, NEG_INF, 0.0], - [NEG_INF, NEG_INF, NEG_INF, 0.0], - [0.0, 0.0, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, 0.0, NEG_INF], - ] - ] - ] - ) - actual = attention_bias.make_segment_mask( - target_segments=jnp.asarray([[1, 1, 2, 0]]), - source_segments=jnp.asarray([[2, 2, 0, 1]]), - ) - self.assertTrue(jnp.all(actual <= expected)) - - def test_apply_attention_logit_biases(self): - batch_size = 10 - num_heads = 12 - dim = 32 - logits = jnp.asarray(np.random.random(size=[batch_size, num_heads, dim])) - - # Testing for biases = None - masked_logit = apply_attention_logit_biases(logits, attention_logit_biases=None) - self.assertEqual(masked_logit.dtype, logits.dtype) - np.testing.assert_array_equal(logits, masked_logit) - - # Testing for biases = random_float_biases - for logit_float_type in [jnp.bfloat16, jnp.float32, jnp.float16]: - for mask_float_type in [jnp.bfloat16, jnp.float32, jnp.float16]: - logits = jnp.asarray(np.random.random(size=[batch_size, num_heads, dim])).astype( - logit_float_type - ) - random_float_biases = jnp.asarray( - np.random.random(size=[batch_size, num_heads, dim]) - ).astype(mask_float_type) - masked_logit = apply_attention_logit_biases( - logits, attention_logit_biases=random_float_biases - ) - self.assertEqual(masked_logit.dtype, logits.dtype) - np.testing.assert_array_equal( - masked_logit, logits + random_float_biases.astype(logits.dtype) - ) +class KVState(NamedTuple): + """Represents key/value projections, of shape [batch, source_length, num_kv_heads, head_dim].""" + k_proj: Tensor + v_proj: Tensor -class CausalAttentionLogitBiasLayerTest(TestCase): - """Tests CausalAttentionLogitBiasLayer.""" - @parameterized.parameters( - # Test the mask with all padding tokens. - dict( - token_ids=[[0, 0, 0], [0, 0, 0]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [0, 0, 0], - ], - ] - * 2, - ), - # Test the mask with all valid tokens. - dict( - token_ids=[[1, 2, 3], [4, 5, 6]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [0, 0, 0], - ], - ] - * 2, - ), - # Test the mask with some valid tokens and some padding tokens. - dict( - token_ids=[[10, 0, 0], [12, 33, 0]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [0, 0, 0], - ], - ] - * 2, - ), - # Test the mask with additional padding biases. - dict( - token_ids=[[10, 0, 0], [12, 33, 0]], - apply_padding_mask=True, - expected=[ - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - ], - ), - # Test the mask with valid tokens AND paddings in between. - dict( - token_ids=[[10, 0, 11], [12, 33, 0]], - apply_padding_mask=True, - expected=[ - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - [0, NEG_INF, 0], - ], - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - ], - ), - # Test a basic case with positions. 
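# The bias semantics described in the module docstring above (0 for connected position pairs,
# a large negative value for disconnected pairs, None meaning all pairs are connected) can be
# exercised with the helpers covered by the removed tests. A minimal sketch; shapes and
# segment ids below are chosen only for illustration:
import jax.numpy as jnp

from axlearn.common.attention import apply_attention_logit_biases
from axlearn.common.attention_bias import make_causal_biases, make_segment_mask

# [target_length=4, source_length=4]: position i may only attend to positions j <= i.
causal_biases = make_causal_biases(4)
# [batch=1, 1, target_length=4, source_length=4]: attention only within matching segments.
segment_biases = make_segment_mask(
    target_segments=jnp.asarray([[1, 1, 2, 0]]),
    source_segments=jnp.asarray([[1, 2, 2, 0]]),
)
# Biases are added to the attention logits before the softmax.
logits = jnp.zeros([1, 1, 4, 4])  # [batch, num_heads, target_length, source_length].
masked_logits = apply_attention_logit_biases(
    logits, attention_logit_biases=causal_biases + segment_biases
)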
- dict( - token_ids=[[10, 11, 12], [13, 14, 15]], - segment_ids=[[1, 1, 2], [1, 2, 2]], - positions=[[0, 1, 0], [0, 0, 1]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, 0], - ], - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, NEG_INF], - [NEG_INF, 0, 0], - ], - ], - ), - # Test a case where some segments are empty. - dict( - token_ids=[[10, 11, 12], [13, 14, 15]], - segment_ids=[[1, 2, 2], [2, 2, 2]], - positions=[[0, 0, 1], [0, 1, 2]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, NEG_INF], - [NEG_INF, 0, 0], - ], - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [0, 0, 0], - ], - ], - ), - # Test with positions and padding. - # Note: we deliberately allow the last token to be 0, to test that without - # apply_padding_mask, a 0-token is not necessarily padding if its segment_id != 0. - dict( - token_ids=[[10, 11, 0], [13, 14, 0]], - segment_ids=[[1, 1, 0], [1, 2, 2]], - positions=[[0, 1, 0], [0, 0, 1]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, 0], - ], - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, NEG_INF], - [NEG_INF, 0, 0], - ], - ], - ), - # Test with segment IDs but not positions. - # This should have the same result as the previous test. - dict( - token_ids=[[10, 11, 0], [13, 14, 0]], - segment_ids=[[1, 1, 0], [1, 2, 2]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, 0], - ], - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, NEG_INF], - [NEG_INF, 0, 0], - ], - ], - ), - # Test with positions and padding, and apply the padding mask. - dict( - token_ids=[[10, 11, 0], [13, 14, 0]], - segment_ids=[[1, 1, 0], [1, 2, 0]], - positions=[[0, 1, 0], [0, 0, 1]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - ], - apply_padding_mask=True, - ), - ) - def test_causal_attention_mask_layer( - self, - token_ids: list, - expected: list, - segment_ids: Optional[Tensor] = None, - positions: Optional[Tensor] = None, - apply_padding_mask: Optional[bool] = False, - ): - causal_attention_mask_layer = ( - attention.CausalAttentionLogitBiasLayer.default_config() - .set(name="test_causal_attention_mask") - .instantiate(parent=None) - ) - if token_ids is not None: - token_ids = np.asarray(token_ids) - if positions is None: - positions = np.arange(token_ids.shape[-1])[None, :] - else: - positions = np.asarray(positions) - if segment_ids is None: - segment_ids = np.ones_like(token_ids) - else: - segment_ids = np.asarray(segment_ids) - actual = causal_attention_mask_layer.forward(segment_ids=segment_ids, positions=positions) - if apply_padding_mask: - actual += compute_padding_biases(token_ids, pad_token_id=0) - assert_allclose(jnp.exp(actual.squeeze(1)), jnp.exp(np.asarray(expected))) - - -class FullAttentionLogitBiasLayerTest(TestCase): - """Tests FullAttentionLogitBiasLayer.""" - - @parameterized.parameters( - # Test the mask with all padding tokens. - dict( - token_ids=[[0, 0, 0], [0, 0, 0]], - expected=[ - [ - [NEG_INF, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - [ - [NEG_INF, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - ], - ), - # Test the mask with all valid tokens. - dict( - token_ids=[[1, 2, 3], [4, 5, 6]], - expected=[ - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ], - ), - # Test the mask with some valid tokens and some padding tokens. 
- dict( - token_ids=[[10, 0, 0], [12, 33, 0]], - expected=[ - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - [ - [0, 0, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - ], - ), - # Test a basic case with segment IDs. - dict( - token_ids=[[10, 11, 12], [13, 14, 15]], - segment_ids=[[1, 1, 2], [1, 2, 2]], - positions=[[0, 1, 0], [0, 0, 1]], - expected=[ - [ - [0, 0, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, 0], - ], - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, 0], - [NEG_INF, 0, 0], - ], - ], - ), - # Test a case where some segments are empty. - dict( - token_ids=[[10, 11, 12], [13, 14, 15]], - segment_ids=[[1, 1, 2], [2, 2, 2]], - positions=[[0, 1, 0], [0, 1, 2]], - expected=[ - [ - [0, 0, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, 0], - ], - [ - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - ], - ], - ), - # Test with segment IDs and padding. - dict( - token_ids=[[10, 11, 0], [13, 14, 0]], - segment_ids=[[1, 1, 0], [1, 2, 0]], - positions=[[0, 1, 0], [0, 0, 1]], - expected=[ - [ - [0, 0, NEG_INF], - [0, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - [ - [0, NEG_INF, NEG_INF], - [NEG_INF, 0, NEG_INF], - [NEG_INF, NEG_INF, NEG_INF], - ], - ], - ), - ) - def test_full_attention_mask_layer( - self, - token_ids: list, - expected: list, - segment_ids: Optional[Tensor] = None, - positions: Optional[Tensor] = None, - ): - full_attention_mask_layer = ( - attention.FullAttentionLogitBiasLayer.default_config() - .set(name="test_full_attention_mask") - .instantiate(parent=None) - ) - if token_ids is not None: - token_ids = np.asarray(token_ids) - if positions is None: - positions = np.arange(token_ids.shape[-1])[None, :] - else: - positions = np.asarray(positions) - if segment_ids is None: - segment_ids = token_ids != 0 - else: - segment_ids = np.asarray(segment_ids) - actual = full_attention_mask_layer.forward(segment_ids=segment_ids, positions=positions) - actual += compute_padding_biases(token_ids, pad_token_id=0) - assert_allclose(jnp.exp(np.asarray(expected)), jnp.exp(actual.squeeze(1))) - - -class ALiBiAttentionLogitBiasLayerTest(TestCase): - """Tests ALiBiAttentionLogitBiasLayer.""" - - def ref_alibi_implementation(self, batch_size, num_heads, max_len): - # Slopes is in jax DeviceArray. Switch it to torch tensor as the ref code. - slopes = torch.Tensor(attention.alibi_get_slopes(num_heads)) - alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_len).unsqueeze(0).unsqueeze( - 0 - ).expand(num_heads, -1, -1) - alibi = alibi.view(num_heads, 1, max_len) - - # Post processing to translate alibi matrix into the jax format. - # Alibi matrix shape [batch_size, num_heads, max_len, max_len]. - alibi = alibi.unsqueeze(0).expand(batch_size, -1, max_len, -1) - # Translate from pytorch to jax. - alibi = as_tensor(alibi) - return alibi - - def test_alibi_attention_mask(self): - num_heads = 12 - batch_size = 2 - max_len = 3 - - # Test alibi implementation. - alibi_attention_mask_layer = ( - attention.ALiBiAttentionLogitBiasLayer.default_config() - .set(name="test_alibi_attention_mask", num_heads=num_heads) - .instantiate(parent=None) - ) +class BaseTransformerLayer(BaseLayer): + """An abstract class to define the common interface of all *TransformerLayer classes, including: - # Casual attention mask which will be applied to ref alibi mask. 
- ref_causal_attention_mask_layer = ( - attention.CausalAttentionLogitBiasLayer.default_config() - .set(name="ref_causal_attention_mask") - .instantiate(parent=None) - ) + * All subclasses must have `input_dim` in its Config; + * The common Output structure; + * The common method signature for `forward()`, `init_states()`, and `extend_step()`. + """ - token_ids = as_tensor(np.random.randint(low=1, high=20, size=[batch_size, max_len])) - segment_ids = jnp.ones_like(token_ids) - positions = jnp.arange(max_len)[None, :] + @config_class + class Config(BaseLayer.Config): + """Configures BaseTransformerLayer.""" - ref_alibi_mask = self.ref_alibi_implementation(batch_size, num_heads, max_len) - # Reshape causal_mask to [batch_size, num_heads, max_len, max_len]. - ref_causal_mask = ref_causal_attention_mask_layer.forward( - segment_ids=segment_ids, positions=positions - ) - ref_causal_mask = jnp.repeat(ref_causal_mask, num_heads, axis=1) + input_dim: Required[int] = REQUIRED # Input feature dim. - # Prepare the ref and the test alibi mask. - ref_alibi_mask = attention.apply_attention_logit_biases(ref_alibi_mask, ref_causal_mask) - test_alibi_mask = alibi_attention_mask_layer.forward( - segment_ids=segment_ids, positions=positions - ) + class Output(NamedTuple): + """BaseTransformerLayer output. - # Ref and test alibi mask should be the same after applying it into a QK attention matrix. - # e.g. softmax(QK + ref_alibi_mask) == softmax(QK + test_alibi_mask). - random_qk_matrix = jnp.asarray( - np.random.random(size=[batch_size, num_heads, max_len, max_len]) - ) + Fields: + data: [batch, target_length, input_dim]. The layer output. Always present. - ref_alibi_softmax = jax.nn.softmax(random_qk_matrix + ref_alibi_mask, axis=-1) - test_alibi_softmax = jax.nn.softmax(random_qk_matrix + test_alibi_mask, axis=-1) - - # The ref alibi implementation relies on the softmax property of invariance to translation. - # e.g. ref_alibi = [[0, -inf, -inf], [0, 1, -inf], [0, 1, 2]] - # test_alibi = [[0, -inf, -inf], [-1, 0, -inf], [-2, -1, 0]] - # softmax(qk + test_alibi) = softmax (qk + [[0, -inf, -inf], [-1, 0, -inf], [-2, -1, 0]]) - # = softmax (qk + [[0, -inf, -inf], [0, 1, -inf+1], [0, 1, 2]]) - # As the numerical -inf is not perfect -inf defined in math. - # Therefore, a very limit difference between those two after softmax, due to (-inf + x). - # The rtol is set to 5e-7 to tolerate this difference. - np.testing.assert_allclose(ref_alibi_softmax, test_alibi_softmax, rtol=5e-07) - - @parameterized.product( - [ - dict(num_segments=1, max_len=3), - dict(num_segments=3, max_len=3), - dict(num_segments=3, max_len=8), - ], - ) - def test_packing(self, max_len: int, num_segments: int): - # With packed inputs of shape [batch, seq_len], we form a block-diagonal matrix of shape - # [batch, num_heads, seq_len, seq_len], where each (unpacked) input has blocks of shape - # [batch, num_heads, segment_len, segment_len] (segment_len <= seq_len). - # We test this by comparing each block against a freshly computed alibi mask of the same - # shape, ensuring that packing is equivalent to treating each unpacked input separately. - num_heads = 12 - batch_size = 2 - - # Test alibi implementation. - alibi_attention_mask_layer = ( - attention.ALiBiAttentionLogitBiasLayer.default_config() - .set(name="test_alibi_attention_mask", num_heads=num_heads) - .instantiate(parent=None) - ) + self_attention_probs: The attention probabilities returned by the self-attention layer. + Shape: [..., target_length, target_length]. 
- # Construct inputs of shape [batch_size, max_len]. - input_segment_ids, positions = dummy_segments_positions( - batch_size, max_len, num_segments=num_segments - ) + self_attention_probs[..., i, j] represents self-attention probability on + input data[..., j, :] when computing output data[..., i, :]. + self_attention_probs.sum(axis=-1) equals to all 1's. - # Compute the test alibi mask of shape [batch, num_heads, seq_len, seq_len]. - test_alibi_batch = alibi_attention_mask_layer.forward( - segment_ids=input_segment_ids, positions=positions - ) - # Apply segment mask and softmax (see notes above). - test_alibi_batch = jax.nn.softmax(test_alibi_batch, axis=-1) - - for batch in range(batch_size): - test_alibi = test_alibi_batch[batch] - input_segments = input_segment_ids[batch] - - # Compute the reference alibi mask(s) for each segment separately. - for segment in range(num_segments): - # [seq_len]. - segment_mask = input_segments == segment - segment_len = int(jnp.sum(segment_mask, dtype=jnp.int32)) - - # Skip the segment if empty. - if segment_len == 0: - continue - - # Select the submatrix in test_alibi corresponding to the current segment. - # [seq_len, seq_len]. - segment_mask = jnp.logical_and(segment_mask[:, None], segment_mask[None, :]) - # [num_heads, seq_len, seq_len]. - segment_mask = jnp.repeat(segment_mask[None, ...], num_heads, 0) - # [num_heads, segment_len, segment_len]. - test_alibi_segment = test_alibi[segment_mask.astype(jnp.bool_)].reshape( - (num_heads, segment_len, segment_len) - ) + Present if "self_attention_probs" is in `return_aux`. - # Construct the ref_alibi for the current segment. - # [num_heads, segment_len]. - ref_alibi = self.ref_alibi_implementation(1, num_heads, segment_len).squeeze(0) - ref_causal_mask = jnp.repeat( - make_causal_biases(segment_len)[None, ...], num_heads, 0 - ) - ref_alibi = attention.apply_attention_logit_biases(ref_alibi, ref_causal_mask) - ref_alibi = jax.nn.softmax(ref_alibi, axis=-1) + self_attention_kv_state: The KV state used in self-attention. + Present if "self_attention_kv_state" is in `return_aux`. - np.testing.assert_allclose(ref_alibi, test_alibi_segment, rtol=5e-07) + cross_attention_probs: The attention probabilities returned by the cross-attention + layer. Shape: [..., target_length, source_length]. + If not None, cross_attention_probs[..., i, j] represents attention probability on + cross_attention_data[..., j, :] when computing output data[..., i, :]. + cross_attention_probs.sum(axis=-1) equals to all 1's. -class SymmetricALiBiAttentionLogitBiasLayerTest(TestCase): - """Tests SymmetricALiBiAttentionLogitBiasLayer.""" + Present if "cross_attention_probs" is in `return_aux`. + """ - def test_alibi_attention_mask(self): - num_heads = 8 - batch_size = 2 - max_len = 3 + data: Tensor + self_attention_probs: Optional[Tensor] = None + self_attention_kv_state: Optional[KVState] = None + cross_attention_probs: Optional[Tensor] = None - # Test alibi implementation. 
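# The ALiBi comparison in the removed test above leans on softmax being invariant to adding a
# per-row (per-query) constant to the logits; a quick numeric check with arbitrary shapes:
import jax
import jax.numpy as jnp

logits = jax.random.uniform(jax.random.PRNGKey(0), [4, 4])
row_shift = jnp.arange(4.0)[:, None]  # A different constant added to each row.
assert jnp.allclose(
    jax.nn.softmax(logits, axis=-1),
    jax.nn.softmax(logits + row_shift, axis=-1),
    atol=1e-6,
)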
- alibi_attention_mask_layer = ( - attention.SymmetricALiBiAttentionLogitBiasLayer.default_config() - .set(name="test_symmetric_alibi_attention_mask", num_heads=num_heads) - .instantiate(parent=None) - ) + def forward( + self, + data: Tensor, + *, + self_attention_kv_state: Optional[KVState] = None, + self_attention_logit_biases: Optional[Tensor] = None, + cross_attention_data: Optional[Tensor] = None, + cross_attention_logit_biases: Optional[Tensor] = None, + target_segment_ids: Optional[Tensor] = None, + target_positions: Optional[Tensor] = None, + return_aux: Optional[set[str]] = None, + ) -> Output: + """Computes transformer layer outputs given full-sequence inputs. + + For incremental computation, use init_states() and extend_step(). + + See comments at the beginning of this file for semantics of *_attention_logit_biases. + + Args: + data: A Tensor of shape [batch, target_length, input_dim]. + self_attention_kv_state: An optional KVState used for self-attention. + self_attention_logit_biases: An optional Tensor representing the self-attention biases. + cross_attention_data: An optional Tensor representing cross-attention data of shape + [source_batch, source_length, source_dim]. + cross_attention_logit_biases: An optional Tensor representing the cross-attention + biases. + target_segment_ids: See ``segment_ids`` in the file comments. + target_positions: See ``positions`` in the file comments. + return_aux: A set of auxiliary output fields to return. Each element must be an + optional field of `Output`, e.g., + `return_aux = {"self_attention_probs", "self_attention_kv_state"}` means that + `Output.{self_attention_probs, self_attention_kv_state}` will be populated. + + Returns: + BaseTransformerLayer.Output. + """ + raise NotImplementedError(type(self)) - # [num_heads] - slopes = jnp.array(attention.alibi_get_slopes(num_heads)) - - # [max_len, max_len] - base_alibi_mask = jnp.array( - [ - [0, -1, -2], - [-1, 0, -1], - [-2, -1, 0], - ], - dtype=jnp.float32, - ) + def init_states( + self, + *, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], + self_attention_kv_state: Optional[KVState] = None, + self_attention_logit_biases: Optional[Tensor] = None, + cross_attention_data: Optional[Tensor] = None, + cross_attention_logit_biases: Optional[Tensor] = None, + ) -> tuple[Nested[Tensor], Optional[Output]]: + """Initializes cached states for incremental computation. + + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `data` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `data` as Tensors. + + Args: + time_step: A Tensor of shape [batch]. Each value is an index into the length dimension + indicating where decoding will start from. + data: A Tensor of shape [batch, target_length, input_dim]. For batch index `i`, only + `data[i, :time_step[i], ...]` will affect subsequent decoding. + self_attention_kv_state: An optional KVState used for self-attention. + self_attention_logit_biases: An optional Tensor representing the self-attention biases. + cross_attention_data: An optional Tensor representing cross-attention data of shape + [batch, source_length, source_dim]. + cross_attention_logit_biases: An optional Tensor representing the cross-attention + biases. + + Returns: + A tuple (init_states, output): + * init_states: A nested tree of Tensors, which can be used as `cached_states` for the + initial call of `extend_step()`. 
+ * output: In the prefill case, a BaseTransformerLayer.Output instance, where: + .data is of the same shape as `data`; + .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; + .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + Otherwise, if initializing cache from scratch, output will be None. + """ + raise NotImplementedError(type(self)) - # [heads, max_len, max_len] - expected_logits_bias = slopes[:, jnp.newaxis, jnp.newaxis] * base_alibi_mask - # [batch, heads, max_len, max_len] - expected_logits_bias = expected_logits_bias[jnp.newaxis, ...].repeat(batch_size, axis=0) + def extend_step( + self, + cached_states: NestedTensor, + data: Tensor, + *, + self_attention_kv_state: Optional[KVState] = None, + self_attention_logit_biases: Optional[Tensor] = None, + cross_attention_data: Optional[Tensor] = None, + cross_attention_logit_biases: Optional[Tensor] = None, + ) -> tuple[NestedTensor, Output]: + """Computes incremental outputs. + + Args: + cached_states: A NestedTensor returned by `init_states()` or a previous invocation of + `extend_step()`. + data: A Tensor of shape [target_batch_size, target_step_length, input_dim], where + `target_step_length` is usually 1. For self-attention, `data` will be used as the + `query` sequence and will be appended to key and value sequences. + self_attention_kv_state: An optional KVState used for self-attention. + self_attention_logit_biases: An optional Tensor of shape + [..., target_step_length, target_max_len], where `target_step_length` must match + the shape of `data` and `target_max_len` must match the value given for + `init_states()`. + cross_attention_data: An optional Tensor of shape [..., source_length, source_dim]. + cross_attention_logit_biases: An optional Tensor of shape + [..., target_step_length, source_length], where `target_step_length` must match + the shape of `data`. + + Returns: + (updated_cached_states, output), where: + `updated_cached_states` represents the new cached states incorporating `data`; + `output` represents the output for the given input data. `output.data` will have the + same shape as the input data. + """ + raise NotImplementedError(type(self)) - segment_ids = jnp.ones((batch_size, max_len)) - positions = jnp.arange(max_len)[None, :] - actual_logits_bias = alibi_attention_mask_layer( - segment_ids=segment_ids, positions=positions - ) - assert_allclose(actual_logits_bias, expected_logits_bias) +class LearnedPositionalEmbedding(BaseLayer): + """TODO(ruoming): Remove LearnedPositionalEmbedding. We can just use the Embedding layer.""" + @config_class + class Config(BaseLayer.Config): + """Configures LearnedPositionalEmbedding.""" + + dim: Required[int] = REQUIRED # Input feature dim. + shape: Required[Sequence[int]] = REQUIRED # The sequence shape. + + # Similar initialization code for Embedding. + # pylint: disable=duplicate-code + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, "model") + # By default, initialize to Gaussian with std=1/sqrt(dim), e.g., 0.036 when dim=768. 
+ # + # This is the same as: + # https://github.com/pytorch/fairseq/blob/master/fairseq/modules/positional_embedding.py#L26 + # + # BERT uses std=0.02 regardless of dim: + # https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/modeling.py#L492-L495 + cfg.param_init = DefaultInitializer.default_config().set( + init_by_param_name={ + PARAM_REGEXP_WEIGHT: WeightInitializer.default_config().set( + fan="fan_out", distribution="normal" + ) + } + ) + return cfg -class RoFormerSinusoidalPositionalEmbeddingTest(TestCase): - """Tests RoFormerSinusoidalPositionalEmbedding.""" + # pylint: enable=duplicate-code - @parameterized.product( - tensor_dimensions=( - (2, 3, 10, 32), - (2, 3, 8, 32), - (2, 4, 6, 32), - (2, 4, 8, 16), - (2, 5, 8, 48), - (2, 5, 8, 64), - ), - rotary_key=(True, False), - rotary_value=(True, False), - ) - def test_apply_rotary_position_embeddings( - self, tensor_dimensions: tuple[int, int, int, int], rotary_key: bool, rotary_value: bool - ): - # Unittest against the apply_rotary_position_embeddings in HF. - batch_size, num_heads, max_len, dim = tensor_dimensions - - token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len]) - sinusoidal_pos_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim) - sinusoidal_pos = sinusoidal_pos_layer(as_torch_tensor(token_ids).shape)[None, None, :, :] - query = np.random.random([batch_size, num_heads, max_len, dim]) - key = np.random.random([batch_size, num_heads, max_len, dim]) - value = np.random.random([batch_size, num_heads, max_len, dim]) - ref_layer = hf_roformer.RoFormerSelfAttention.apply_rotary_position_embeddings - test_layer = apply_rotary_position_embeddings - if rotary_value: - ref_q_proj, ref_k_proj, ref_v_proj = ref_layer( - sinusoidal_pos, - as_torch_tensor(query), - as_torch_tensor(key), - as_torch_tensor(value), + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + return dict( + weight=ParameterSpec( + shape=[1] + list(cfg.shape) + [cfg.dim], + mesh_axes=cfg.param_partition_spec, ) - else: - # If rotary_value is set to False, value keeps unchanged. - # pylint: disable-next=unbalanced-tuple-unpacking - ref_q_proj, ref_k_proj = ref_layer( - sinusoidal_pos, as_torch_tensor(query), as_torch_tensor(key) + ) + + def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: + if not name.endswith("weight"): + return None + if len(parameter_spec.shape) != 3: + raise NotImplementedError( + "_compute_fan_axes requires weight parameters to have exactly 3 axes " + f"shape({name}) = {parameter_spec.shape}" ) - ref_v_proj = as_torch_tensor(value) - if not rotary_key: - ref_k_proj = as_torch_tensor(key) + return FanAxes(batch_axis=0, in_axis=1, out_axis=2) - test_q_proj, test_k_proj, test_v_proj = test_layer( - sinusoidal_pos=as_tensor(sinusoidal_pos), - query=query, - key=key, - value=value, - rotary_key=rotary_key, - rotary_value=rotary_value, - ) - np.testing.assert_allclose(test_q_proj, ref_q_proj, atol=5e-7) - np.testing.assert_allclose(test_k_proj, ref_k_proj, atol=5e-7) - np.testing.assert_allclose(test_v_proj, ref_v_proj, atol=5e-7) - - @parameterized.parameters( - (2, 10, 32), - (2, 8, 32), - (2, 6, 32), - (2, 8, 16), - (2, 8, 48), - (2, 8, 64), - ) - def test_rope_emb(self, batch_size, max_len, dim): - # Token id is in the np format for easier transition. 
- token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len]) - positions = jnp.expand_dims(jnp.arange(token_ids.shape[-1], dtype=jnp.int32), 0) - ref_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim) - ref_output = ref_layer(as_torch_tensor(token_ids).shape) - # Set up the RoPE AXLearn configs. - test_layer = ( - attention.RoFormerSinusoidalPositionalEmbedding.default_config() - .set(name="test_rope_emb", dim=dim) - .instantiate(parent=None) - ) - test_output = test_layer.forward(positions=positions) - np.testing.assert_allclose(np.expand_dims(ref_output, 0), test_output, atol=5e-7) + def embeddings(self) -> Tensor: + """Returns weights of shape cfg.shape + [dim].""" + return self.parameters["weight"].squeeze(0) - @parameterized.parameters( - (None, True), - (10, False), - ) - def test_rope_emb_no_pos(self, max_len, should_raise): - test_layer = ( - attention.RoFormerSinusoidalPositionalEmbedding.default_config() - .set(name="test_rope_emb", dim=10) - .instantiate(parent=None) - ) - if should_raise: - with self.assertRaises(ValueError): - test_layer.forward(max_seq_len=max_len) - else: - test_layer.forward(max_seq_len=max_len) + def forward(self, positions: Tensor) -> Tensor: + """ + Args: + positions: An integer tensor with arbitrary shape [...]. - @parameterized.parameters( - (2, 10, 32, 4), - ) - def test_default_rope_emb(self, batch_size, max_len, dim, num_heads): - rng = np.random.default_rng(seed=123) - query = jnp.asarray(rng.random([batch_size, max_len, dim])) - key = jnp.asarray(rng.random([batch_size, max_len, dim])) - value = jnp.asarray(rng.random([batch_size, max_len, dim])) - per_head_dim = dim // num_heads - - emb_layer_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( - dim=per_head_dim, - ) - linear_layer_cfg = attention.RoFormerQKVLinear.default_config().set( - query_dim=dim, - key_dim=dim, - value_dim=dim, - num_heads=num_heads, - per_head_dim=per_head_dim, - rope_pos_emb_layer=emb_layer_cfg, - rotary_value=False, - name="test_rope_linear", - ) - rope_linear_layer = linear_layer_cfg.instantiate(parent=None) - state = rope_linear_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + Returns: + Embeddings with shape [..., *cfg.dim[1:], dim]. + """ + embeddings = self.embeddings() + return embeddings[positions] - rope_emb_layer = emb_layer_cfg.set(name="test_rope_emb").instantiate(parent=None) - default_positions = rope_emb_layer.default_query_positions(max_len) - input_dict = dict(query=query, key=key, value=value) +def sinusoidal_positional_embeddings( + positions: Tensor, *, dim: int, min_timescale: float = 1, max_timescale: float = 10000 +) -> Tensor: + """Sinusoidal positional embeddings. - layer_outputs_no_position, _ = F( - rope_linear_layer, - inputs=input_dict, - state=state, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - layer_outputs, _ = F( - rope_linear_layer, - inputs=dict(**input_dict, query_positions=default_positions), - state=state, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - # test RoFormerQKVLinear uses default positions in RoFormerSinusoidalPositionalEmbedding - np.testing.assert_allclose(layer_outputs_no_position, layer_outputs, atol=1e-5) + Proposed in the original Transformer paper: https://arxiv.org/abs/1706.03762. 
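# The RoPE embedding tests removed above exercise the two ways of calling
# RoFormerSinusoidalPositionalEmbedding.forward(): with explicit `positions`, or with only
# `max_seq_len`, in which case the layer derives default positions internally. A minimal
# sketch mirroring that usage; `dim` and lengths are chosen only for illustration:
from jax import numpy as jnp

from axlearn.common import attention

rope_layer = (
    attention.RoFormerSinusoidalPositionalEmbedding.default_config()
    .set(name="rope", dim=16)
    .instantiate(parent=None)
)
explicit = rope_layer.forward(positions=jnp.arange(8)[None])  # [1, 8, 16].
implicit = rope_layer.forward(max_seq_len=8)  # Same values via the default positions.
assert jnp.allclose(explicit, implicit)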
- def _compare_against_roformer_attention( - self, - ref, - layer, - tgt_len, - batch_size, - ref_rope_emb, - positions, - ): - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - layer_param_shapes = jax.tree.map(lambda x: x.shape, layer_params) - print(f"layer state={layer_param_shapes}") - layer_params = parameters_from_torch_layer(ref) - model_dim, num_heads = layer.config.target_dim, layer.config.attention.num_heads - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) - null_mask = jnp.zeros([tgt_len, tgt_len]) - rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) - - for mask in (None, null_mask, rand_mask): - if mask is not None: - mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) - layer_outputs, _ = F( - layer, - inputs=dict( - target=jnp.asarray(target), - attention_logit_biases=mask, - target_positions=positions, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - attn_mask = None if mask is None else as_torch_tensor(mask) - print("ref_rope_emb", ref_rope_emb.shape) - print("target", target.shape) - (ref_outputs,) = ref.forward( - torch.as_tensor(target, dtype=torch.float32), - attention_mask=attn_mask, - sinusoidal_pos=ref_rope_emb, - output_attentions=False, - ) - assert_allclose(layer_outputs.data, as_tensor(ref_outputs)) - - @parameterized.product(rotary_value=[True, False], override_positions=[True, False]) - def test_rope_self_attention(self, rotary_value: bool, override_positions: bool): - model_dim = 32 - num_heads = 4 - max_sequence_length = 12 - batch_size = 2 - rope_mha_cfg = attention.MultiheadAttention.default_config().set( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - input_linear=RoFormerQKVLinear.default_config().set(rotary_value=rotary_value), - ) - rope_emb_layer = ( - attention.RoFormerSinusoidalPositionalEmbedding.default_config() - .set(name="test_rope_emb", dim=model_dim // num_heads) - .instantiate(parent=None) - ) - positions = ( - jax.random.randint( - jax.random.PRNGKey(0), - shape=(batch_size, max_sequence_length), - minval=0, - maxval=max_sequence_length, - ) - if override_positions - else jnp.expand_dims(jnp.arange(max_sequence_length), 0) - ) - ref_rope_emb = as_torch_tensor(rope_emb_layer.forward(positions=positions)).unsqueeze(1) - layer = attention.TransformerAttentionLayer.default_config().set( - source_dim=model_dim, - target_dim=model_dim, - name="rope_trans_attn", - attention=rope_mha_cfg, - structure="postnorm", - ) - layer = layer.instantiate(parent=None) - roformer_config = hf_roformer.RoFormerConfig( - hidden_size=model_dim, - num_attention_heads=num_heads, - attention_probs_dropout_prob=0, - hidden_dropout_prob=0, - rotary_value=rotary_value, - ) - print(f"roformer_config={roformer_config}") - ref = hf_roformer.RoFormerAttention(roformer_config) - self._compare_against_roformer_attention( - ref, - layer, - max_sequence_length, - batch_size, - ref_rope_emb, - positions if override_positions else None, - ) + Reference: + https://github.com/tensorflow/lingvo/blob/d2f1e1b3cccdac8f73ae20f86afb03560b1c176d/lingvo/core/layers.py#L2775-L2923 + The inputs to the sinusoid functions will be positions / timescale(k) + for dimension 0 <= k < num_timescales = dim // 2, where: + timescale(k) = geometric interpolation between min_timescale and max_timescale, i.e., + log(timescale(k) / min_timescale) / log(max_timescale / min_timescale) = + k 
/ num_timescales. + Specifically: timescale(0) = min_timescale and timescale(num_timescales) = max_timescale. -def llama_reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """LLaMA reshape for broadcast function. + Args: + positions: An integer tensor of any shape [...]. Each value represents an + absolute or relative position. + dim: the embedding dimension. Must be divisible by 2. + min_timescale: The minimum timescale (used for channel 0 and dim // 2). + max_timescale: The maximum timescale (used for channel dim // 2 - 1 and dim - 1). - Ref: - https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L55-L60 + Returns: + Embeddings of shape [..., dim]. + + Raises: + NotImplementedError: If dim is not divisible by 2. """ - ndim = x.ndim - assert 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [ - d if i == 1 or i == ndim - 1 else 1 # pylint: disable=consider-using-in - for i, d in enumerate(x.shape) - ] - return freqs_cis.view(*shape) + if dim % 2 != 0: + raise NotImplementedError(f"dim ({dim}) must be divisible by 2") + num_timescales = dim // 2 + # To ensure results match other libraries, it is important to calculate + # log_timescale_increment using float64 calculations. This has no + # runtime cost. + log_timescale_increment = math.log(max_timescale / min_timescale) / max(1, num_timescales - 1) -def llama_apply_rotary_emb( - *, - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """LLaMA apply rotary embeddings to input tensors using the given frequency tensor. + # [num_timescales]. + inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales) * -log_timescale_increment) - Ref: - https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L63-L73 - """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = llama_reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + # [..., num_timescales]. + scaled_time = jnp.expand_dims(positions, -1) * inv_timescales + # [..., dim]. + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1) + return signal -class RefLLaMAAttention(torch.nn.Module): - """Reference Implementation of LLaMA-1. - Ref: - https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L76 +class SinusoidalPositionalEmbedding(BaseLayer): + """Sinusoidal positional embeddings. - The modifications are removing the dependency of ColumnParallelLinear and RowParallelLinear. + See sinusoidal_positional_embeddings()'s comments. 
""" - def __init__(self, n_heads: int, dim: int, max_batch_size: int, max_seq_len: int): - super().__init__() + @config_class + class Config(BaseLayer.Config): + """Configures SinusoidalPositionalEmbedding.""" - self.n_local_heads = n_heads - self.head_dim = dim // n_heads + dim: Required[int] = REQUIRED + min_timescale: float = 1 + max_timescale: float = 10000 - self.wq = torch.nn.Linear( - dim, - n_heads * self.head_dim, - bias=False, - ) - self.wk = torch.nn.Linear( - dim, - n_heads * self.head_dim, - bias=False, - ) - self.wv = torch.nn.Linear( - dim, - n_heads * self.head_dim, - bias=False, - ) - self.wo = torch.nn.Linear( - n_heads * self.head_dim, - dim, - bias=False, + def forward(self, positions: Tensor) -> Tensor: + """Looks up positional embeddings by positions.""" + cfg: SinusoidalPositionalEmbedding.Config = self.config + return sinusoidal_positional_embeddings( + positions, dim=cfg.dim, min_timescale=cfg.min_timescale, max_timescale=cfg.max_timescale ) - self.cache_k = torch.zeros((max_batch_size, max_seq_len, self.n_local_heads, self.head_dim)) - self.cache_v = torch.zeros((max_batch_size, max_seq_len, self.n_local_heads, self.head_dim)) - def forward( - self, - x: torch.Tensor, - start_pos: int, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - ) -> torch.Tensor: - bsz, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) - - xq, xk = llama_apply_rotary_emb(xq=xq, xk=xk, freqs_cis=freqs_cis) - - self.cache_k = self.cache_k.to(xq) - self.cache_v = self.cache_v.to(xq) - - self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk - self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv - - keys = self.cache_k[:bsz, : start_pos + seqlen] - values = self.cache_v[:bsz, : start_pos + seqlen] - - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) - scores = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - - return self.wo(output) - - -class RoFormerSinusoidalPositionalEmbeddingAgainstLLaMATest(TestCase): - def llama_ref_precompute_freqs_cis( - self, *, dim: int, end: int, theta: float = 10000.0 - ) -> torch.Tensor: - """Reference LLaMA-1 implementation. 
- - Ref: - https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47-L52 - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - @parameterized.parameters([10000.0, 1000000.0]) - def test_against_llama_for_precompute_freqs_cis(self, theta: float): - max_len = 100 - dim = 32 - positions = jnp.expand_dims(jnp.arange(max_len), 0) - axlearn_rope_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( - dim=dim, - theta=theta, - ) - axlearn_rope_layer = axlearn_rope_cfg.set(name="rope").instantiate(parent=None) - axlearn_rope, _ = F( - axlearn_rope_layer, - inputs=dict(positions=positions), - state=axlearn_rope_layer.initialize_parameters_recursively( - prng_key=jax.random.PRNGKey(0) - ), - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - llama_rope = self.llama_ref_precompute_freqs_cis(dim=dim, end=max_len, theta=theta) - axlearn_imag, axlearn_real = jnp.split(axlearn_rope, 2, axis=-1) - llama_real, llama_imag = llama_rope.real, llama_rope.imag - # [0] is added, as axlearn_real and axlearn_imag has a batch_size=1 dimension. - assert_allclose(llama_real, as_tensor(axlearn_real)[0]) - assert_allclose(llama_imag, as_tensor(axlearn_imag)[0]) - - @parameterized.product( - dtype=(jnp.float32, jnp.bfloat16), - input_linear=( - None, - attention.QKVLinear.default_config(), - attention.GroupedQKVLinear.default_config(), - ), - has_query_positions=(True, False), - ) - def test_roformer_qkv_linear( - self, - dtype: jnp.dtype, - input_linear: attention.BaseQKVLinear.Config, - has_query_positions: bool, - ): - seq_len = 6 - batch_size = 2 - model_dim = 16 - num_heads = 4 - per_head_dim = model_dim // num_heads - roformer_qkv_linear_kwargs = { - "name": "roformer_qkv_linear", - "query_dim": model_dim, - "key_dim": model_dim, - "value_dim": model_dim, - "num_heads": num_heads, - "per_head_dim": per_head_dim, - "rotary_value": False, - } - num_kv_heads = num_heads - if input_linear is not None: - if isinstance(input_linear, attention.GroupedQKVLinear.Config): - num_kv_heads = num_heads // 2 - input_linear = input_linear.set(num_kv_heads=num_kv_heads) - roformer_qkv_linear_kwargs["input_linear"] = input_linear - - roformer_qkv_linear = ( - RoFormerQKVLinear.default_config() - .set(**roformer_qkv_linear_kwargs) - .instantiate(parent=None) - ) +class BaseMultiheadLinear(DenseGeneralBaseLayer): + """The linear layer used for multi-head attention. - # Check that we see the num kv heads is propagated from child input linear. - self.assertEqual(roformer_qkv_linear.num_kv_heads, num_kv_heads) + It uses einsum for efficient computation on TPU to avoid reshaping. 
+ """ - query = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, model_dim)) - key = jax.random.uniform(jax.random.PRNGKey(2), shape=(batch_size, seq_len, model_dim)) - value = jax.random.uniform(jax.random.PRNGKey(3), shape=(batch_size, seq_len, model_dim)) - roformer_qkv_linear_state = roformer_qkv_linear.initialize_parameters_recursively( - jax.random.PRNGKey(0) - ) - input_batch = dict(query=query, key=key, value=value) - if has_query_positions: - input_batch["query_positions"] = jax.random.permutation( - jax.random.PRNGKey(1), - jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0), - axis=1, - independent=True, - ) + @config_class + class Config(DenseGeneralBaseLayer.Config): + """Configures BaseMultiheadLinear.""" + + model_dim: Required[int] = REQUIRED # Feature dim. + num_heads: Required[int] = REQUIRED # Number of attention heads. + per_head_dim: Required[int] = REQUIRED # Dimension per head. + bias: bool = True # Whether the linear modules have biases. + + @classmethod + def default_config(cls) -> Config: + cfg = super().default_config() + # Shard the 'num_heads' axis by the 'model' dim of the mesh. + cfg.param_partition_spec = (None, "model", None) + return cfg - layer_outputs, _ = F( - roformer_qkv_linear, - inputs=utils.cast_floats(input_batch, to_dtype=dtype), - state=utils.cast_floats(roformer_qkv_linear_state, to_dtype=dtype), - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - self.assertEqual(layer_outputs.query.dtype, dtype) - self.assertEqual(layer_outputs.key.dtype, dtype) - self.assertEqual(layer_outputs.value.dtype, dtype) - - def test_against_llama_for_apply_rotary_emb(self): - max_len = 100 - dim = 32 - batch_size = 4 - positions = jnp.expand_dims(jnp.arange(max_len), 0) - axlearn_rope_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set( - dim=dim - ) - axlearn_rope_layer = axlearn_rope_cfg.set(name="rope").instantiate(parent=None) - axlearn_rope, _ = F( - axlearn_rope_layer, - inputs=dict(positions=positions), - state=axlearn_rope_layer.initialize_parameters_recursively( - prng_key=jax.random.PRNGKey(0) - ), - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - llama_rope = self.llama_ref_precompute_freqs_cis(dim=dim, end=max_len) - rng = np.random.default_rng(seed=123) - query = rng.random([batch_size, max_len, dim]) - key = rng.random([batch_size, max_len, dim]) - value = rng.random([batch_size, max_len, dim]) - llama_q, llama_k = llama_apply_rotary_emb( - xq=torch.Tensor(query), xk=torch.Tensor(key), freqs_cis=llama_rope - ) - axlearn_q, axlearn_k, _ = attention.apply_rotary_position_embeddings( - query=jnp.asarray(query), - key=jnp.asarray(key), - value=jnp.asarray(value), - sinusoidal_pos=axlearn_rope, - rotary_key=True, - rotary_value=False, + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + params = dict( + weight=ParameterSpec( + shape=(cfg.model_dim, cfg.num_heads, cfg.per_head_dim), + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=("row", None, "col")), + ) ) + if cfg.bias: + params["bias"] = self._bias_spec + return params - assert_allclose(as_tensor(llama_q.reshape(batch_size, max_len, -1)), axlearn_q, atol=5e-6) - assert_allclose(as_tensor(llama_k.reshape(batch_size, max_len, -1)), axlearn_k, atol=5e-6) - - def test_against_llama_for_attention(self): - max_len = 100 - dim = 32 - batch_size = 4 - n_heads = 4 - rng = np.random.default_rng(seed=123) - x = rng.random([batch_size, max_len, dim]) - ref_llama = RefLLaMAAttention( 
- n_heads=n_heads, dim=dim, max_batch_size=batch_size, max_seq_len=max_len - ) - llama_rope = self.llama_ref_precompute_freqs_cis(dim=dim // n_heads, end=max_len) - llama_output = ref_llama.forward(torch.Tensor(x), 0, llama_rope, mask=None) - - rope_mha_cfg = attention.MultiheadAttention.default_config().set( - query_dim=dim, - key_dim=dim, - value_dim=dim, - num_heads=n_heads, - input_linear=RoFormerQKVLinear.default_config().set( - rotary_value=False, - ), + @property + def _einsum_expr(self): + raise NotImplementedError(type(self)) + + def forward(self, inputs: Tensor) -> Tensor: + params = self.parameters + outputs = self.einsum_maybe_quantized( + self._einsum_expr, activation=inputs, kernel=params["weight"] ) + return outputs + params.get("bias", 0) - rope_mha = rope_mha_cfg.set(name="rope").instantiate(parent=None) + def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: + raise NotImplementedError(type(self)) - state = rope_mha.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - state["i_proj"]["i_proj"]["q_proj"]["weight"] = jnp.asarray( - ref_llama.wq.weight.transpose(0, 1) - .reshape(dim, n_heads, dim // n_heads) - .detach() - .numpy() - ) - state["i_proj"]["i_proj"]["k_proj"]["weight"] = jnp.asarray( - ref_llama.wk.weight.transpose(0, 1) - .reshape(dim, n_heads, dim // n_heads) - .detach() - .numpy() - ) - state["i_proj"]["i_proj"]["v_proj"]["weight"] = jnp.asarray( - ref_llama.wv.weight.transpose(0, 1) - .reshape(dim, n_heads, dim // n_heads) - .detach() - .numpy() - ) - state["o_proj"]["weight"] = jnp.asarray( - ref_llama.wo.weight.reshape(dim, n_heads, dim // n_heads).detach().numpy() - ) - axlearn_output, _ = F( - rope_mha, - inputs=dict(query=jnp.asarray(x)), - state=state, - is_training=False, - prng_key=jax.random.PRNGKey(0), - ) - assert_allclose( - as_tensor(llama_output.reshape(batch_size, max_len, -1)), axlearn_output.data - ) +class MultiheadInputLinear(BaseMultiheadLinear): + """Multi-head input linear layer.""" + @property + def _einsum_expr(self): + return "btd,dnh->btnh" -class MultiheadLinearInitTest(TestCase): - """Tests MultiheadLinear initialization.""" - - @parameterized.parameters( - ( - MultiheadInputLinear, - FanAxes(in_axis=0, out_axis=(1, 2)), - { - "fan_in": 4, - "fan_out": 8 * 6, - "fan_avg": (4 + 8 * 6) / 2, - }, - ), - ( - MultiheadOutputLinear, - FanAxes(in_axis=(1, 2), out_axis=0), - { - "fan_in": 8 * 6, - "fan_out": 4, - "fan_avg": (8 * 6 + 4) / 2, - }, - ), - ( - MultiheadRelativePositionLinear, - FanAxes(in_axis=0, out_axis=(1, 2)), - { - "fan_in": 4, - "fan_out": 8 * 6, - "fan_avg": (4 + 8 * 6) / 2, - }, - ), - ) - def test_compute_fan_axes(self, cls, fan_axes, fans): - for dist in ("uniform", "normal", "truncated_normal"): - for scale in (1.0, 2.0): - for fan_type in ("fan_in", "fan_out", "fan_avg"): - cfg = cls.default_config().set( - name="test", model_dim=4, num_heads=8, per_head_dim=6 - ) - cfg.param_init = DefaultInitializer.default_config().set( - init_by_param_name={ - PARAM_REGEXP_WEIGHT: WeightInitializer.default_config().set( - fan=fan_type, scale=scale, distribution=dist - ) - } - ) - layer: BaseLayer = cfg.instantiate(parent=None) - # pylint: disable-next=protected-access - param_spec_map = layer._create_layer_parameter_specs() - self.assertEqual( - # pylint: disable-next=protected-access - layer._compute_fan_axes("weight", param_spec_map["weight"]), - fan_axes, - ) - layer_params = layer.initialize_parameters_recursively(jax.random.PRNGKey(1)) - weight = layer_params["weight"] 
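# What the einsum expressions above compute, written out with plain jnp: the input projection
# maps [batch, time, model_dim] -> [batch, time, num_heads, per_head_dim] and the output
# projection maps back to [batch, time, model_dim], in both cases without explicit reshapes.
# Shapes below are illustrative only.
import jax
from jax import numpy as jnp

b, t, d, n, h = 2, 3, 8, 4, 2  # batch, time, model_dim, num_heads, per_head_dim.
x = jax.random.uniform(jax.random.PRNGKey(0), [b, t, d])
w_in = jax.random.uniform(jax.random.PRNGKey(1), [d, n, h])  # MultiheadInputLinear weight.
w_out = jax.random.uniform(jax.random.PRNGKey(2), [d, n, h])  # MultiheadOutputLinear weight.
per_head = jnp.einsum("btd,dnh->btnh", x, w_in)  # [2, 3, 4, 2].
merged = jnp.einsum("btnh,dnh->btd", per_head, w_out)  # [2, 3, 8].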
- self.assertEqual(weight.dtype, jnp.float32) - fan = fans[fan_type] - expected_std = scale / math.sqrt(fan) - actual_std = np.std(weight) - self.assertBetween(actual_std, expected_std / 1.5, expected_std * 1.5) - - -class QKVLinearTest(TestCase): - """Tests QKVLinear, FusedQKVLinear, and associated layers.""" - - @parameterized.product( - test_cls=[ - attention.FusedQKVLinear, - attention.GroupedQKVLinear, - attention.FusedGroupedQKVLinear, - ], - with_positions=[True, False], - ) - def test_qkv_equality(self, test_cls: type[attention.BaseQKVLinear], with_positions: bool): - """Tests that the QKVLinear variants are equivalent when num_kv_heads=num_heads.""" - with utils.numeric_checks(True): - model_dim = 12 - num_heads = 4 - per_head_dim = model_dim // num_heads - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - per_head_dim=per_head_dim, - ) - base_cfg = QKVLinear.default_config().set(**layer_kwargs) - test_cfg = test_cls.default_config().set(**layer_kwargs) - maybe_set_config(test_cfg, num_kv_heads=num_heads) - base_layer = base_cfg.set(name="base").instantiate(parent=None) - test_layer = test_cfg.set(name="test").instantiate(parent=None) - - # Construct base layer state. - base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) - - # Map state to fused version. - if test_cls == attention.FusedQKVLinear: - weight = jnp.array( - [base_state[el]["weight"] for el in ("q_proj", "k_proj", "v_proj")] - ) - bias = jnp.array([base_state[el]["bias"] for el in ("q_proj", "k_proj", "v_proj")]) - test_state = {"qkv_proj": dict(weight=weight, bias=bias)} - elif test_cls == attention.FusedGroupedQKVLinear: - # Concatenate along the num_heads dim. - weight = jnp.concatenate( - [base_state[el]["weight"] for el in ("q_proj", "k_proj", "v_proj")], axis=1 - ) - bias = jnp.concatenate( - [base_state[el]["bias"] for el in ("q_proj", "k_proj", "v_proj")], axis=0 - ) - test_state = {"qkv_proj": dict(weight=weight, bias=bias)} - else: - test_state = base_state - - # Construct test inputs. - batch_size, src_len, tgt_len = 2, 6, 6 - query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) - key = jax.random.uniform(jax.random.PRNGKey(1), [batch_size, src_len, model_dim]) - value = jax.random.uniform(jax.random.PRNGKey(2), [batch_size, src_len, model_dim]) - - # In the fused GQA case, we assume query=key=value. - if test_cls == attention.FusedGroupedQKVLinear: - key = value = None - - positions = jnp.ones((1, tgt_len)) if with_positions else None - inputs = dict(query=query, key=key, value=value, query_positions=positions) - outputs = {} - layer_names = ("base", "test") - for name, layer, state in zip( - layer_names, (base_layer, test_layer), (base_state, test_state) - ): - outputs[name], _ = F( - layer, - state=state, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - ) - for layer_a, layer_b in combinations(layer_names, 2): - # Check that the outputs are close for all pairs. 
- self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) - - @parameterized.parameters( - dict(layer_cls=attention.QKVLinear, expected=4), - dict(layer_cls=attention.FusedQKVLinear, expected=4), - dict( - layer_cls=attention.QKVLinear, - num_kv_heads=2, - expected=UnknownFieldError("num_kv_heads"), - ), - dict( - layer_cls=attention.FusedQKVLinear, - num_kv_heads=2, - expected=UnknownFieldError("num_kv_heads"), - ), - dict( - layer_cls=attention.GroupedQKVLinear, - num_kv_heads=3, - expected=ValueError("should divide"), - ), - dict( - layer_cls=attention.FusedGroupedQKVLinear, - num_kv_heads=3, - expected=ValueError("should divide"), - ), - dict(layer_cls=attention.GroupedQKVLinear, num_kv_heads=2, expected=2), - dict(layer_cls=attention.FusedGroupedQKVLinear, num_kv_heads=2, expected=2), - ) - def test_num_kv_heads( - self, - layer_cls: type[attention.BaseQKVLinear], - expected: Union[int, Exception], - num_kv_heads: Optional[int] = None, - ): - model_dim = 12 - num_heads = 4 - per_head_dim = model_dim // num_heads - common_kwargs = dict( - query_dim=model_dim, key_dim=model_dim, value_dim=model_dim, per_head_dim=per_head_dim + @property + def _bias_spec(self): + cfg = self.config + return ParameterSpec( + shape=(cfg.num_heads, cfg.per_head_dim), + mesh_axes=cfg.param_partition_spec[-2:], ) - cfg = layer_cls.default_config().set(name="test", num_heads=num_heads, **common_kwargs) - if isinstance(expected, Exception): - ctx = self.assertRaisesRegex(type(expected), str(expected)) + # pylint: disable-next=no-self-use + def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: + if name == "weight": + return FanAxes(in_axis=0, out_axis=(1, 2)) else: - ctx = contextlib.nullcontext() - - with ctx: - if num_kv_heads is not None: - cfg.set(num_kv_heads=num_kv_heads) - layer = cfg.instantiate(parent=None) - self.assertEqual(expected, layer.num_kv_heads) - - @parameterized.parameters( - (QKVLinear.default_config(), QLinear.default_config()), - ( - RoFormerQKVLinear.default_config().set( - input_linear=QKVLinear.default_config(), rotary_value=False - ), - RoFormerQKVLinear.default_config().set( - input_linear=QLinear.default_config(), rotary_value=False - ), - ), - ) - def test_qlinear(self, base_cfg, test_cfg): - """Tests that QLinear is equivalent to QKVLinear with the same kv_state.""" - with utils.numeric_checks(True): - model_dim = 12 - num_heads = 3 - per_head_dim = model_dim // num_heads - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - per_head_dim=per_head_dim, - ) - base_cfg = base_cfg.set(**layer_kwargs) - test_cfg = test_cfg.set(**layer_kwargs) - maybe_set_config(test_cfg, num_kv_heads=num_heads) - base_layer = base_cfg.set(name="base").instantiate(parent=None) - test_layer = test_cfg.set(name="test").instantiate(parent=None) - - # Construct base layer state. - base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) - # Map state to QLinear. - if "q_proj" in base_state: - test_state = {"q_proj": base_state["q_proj"]} - elif "i_proj" in base_state: - test_state = {"i_proj": {"q_proj": base_state["i_proj"]["q_proj"]}} - else: - raise ValueError("Cannot find expected q_proj state.") - - # Construct test inputs. 
- batch_size, src_len, tgt_len = 2, 6, 6 - query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) - key = jax.random.uniform(jax.random.PRNGKey(1), [batch_size, src_len, model_dim]) - value = jax.random.uniform(jax.random.PRNGKey(2), [batch_size, src_len, model_dim]) - - outputs = {} - layer_names = ("base", "test") - kv_kwargs = {"key": key, "value": value} - for name, layer, state in zip( - layer_names, (base_layer, test_layer), (base_state, test_state) - ): - outputs[name], _ = F( - layer, - state=state, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=dict(query=query, **kv_kwargs), - ) - if name == "base": - kv_kwargs = { - "kv_state": KVState(k_proj=outputs[name].key, v_proj=outputs[name].value) - } - for layer_a, layer_b in combinations(layer_names, 2): - # Check that the outputs are close for all pairs. - self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) - - @parameterized.parameters( - (attention.QKVLinear, 1), - (attention.FusedQKVLinear, 1), - (attention.GroupedQKVLinear, 1), - (attention.FusedGroupedQKVLinear, 1), - (attention.RoFormerQKVLinear, 1), - (attention.QKVLinear, 2), - (attention.FusedQKVLinear, 3), - (attention.GroupedQKVLinear, 4), - (attention.FusedGroupedQKVLinear, 3), - (attention.RoFormerQKVLinear, 2), - ) - def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], extend_step_len): - """Tests that calling QKVLinear.extend_step() multiple times with the - same time_step results in the same output.""" - model_dim = 8 - num_heads = 2 - per_head_dim = model_dim // num_heads - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - per_head_dim=per_head_dim, - ) - cfg = layer_cls.default_config().set(**layer_kwargs) - maybe_set_config(cfg, num_kv_heads=num_heads, rotary_value=False) - layer = cfg.set(name="test").instantiate(parent=None) - - # Construct base layer state. - layer_state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) - - # Construct test inputs. 
- batch_size, tgt_len = 2, 4 - query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) - - fwd_output, _ = F( - layer, - state=layer_state, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict(query=query), - ) + return None + + +class MultiheadOutputLinear(BaseMultiheadLinear): + """Multi-head output linear layer.""" - cache_state, init_output = layer.init_states( - time_step=None, query=TensorSpec([batch_size, tgt_len]) + @property + def _einsum_expr(self): + return "btnh,dnh->btd" + + @property + def _bias_spec(self): + cfg = self.config + return ParameterSpec( + shape=(cfg.model_dim,), + mesh_axes=cfg.param_partition_spec[:1], ) - self.assertIsNone(init_output) - step_querys = [] - step_keys = step_values = None - for t in range(0, tgt_len, extend_step_len): - (cache_state, step_output), _ = F( - layer, - state=layer_state, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict(cached_states=cache_state, query=query[:, t : t + extend_step_len]), - method="extend_step", - ) - step_querys.append(step_output.query) - step_keys = step_output.key - step_values = step_output.value - self.assertNestedAllClose(fwd_output.query, jnp.concat(step_querys, axis=1)) - self.assertNestedAllClose(fwd_output.key, step_keys) - self.assertNestedAllClose(fwd_output.value, step_values) + # pylint: disable-next=no-self-use + def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: + if name == "weight": + return FanAxes(in_axis=(1, 2), out_axis=0) + else: + return None - @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) - def test_dtypes_inherited_from_parent(self, dtype: jnp.dtype): - """Test that the dtype is inherited from the parent. - When neither `Config.cache_dtype` nor `BaseLayer.Config.dtype` are set the dtype should - be inherited from the parent, and the dtype should be preserved in values in the - cached states and outputs. - """ +def apply_attention_logit_biases( + logits: Tensor, attention_logit_biases: Optional[Tensor] = None +) -> Tensor: + """Applies `attention_logit_biases` on `logits`. - target_batch_size = 3 - target_max_len = 6 - model_dim = 12 - num_heads = 4 - per_head_dim = model_dim // num_heads - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - per_head_dim=per_head_dim, - ) + Args: + logits: A float Tensor. + attention_logit_biases: A float Tensor. If None, assume all zeros. - class Parent(BaseLayer): - @config_class - class Config(BaseLayer.Config): - qkv_linear: InstantiableConfig = QKVLinear.default_config().set(**layer_kwargs) - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - self._add_child("qkv_linear", cfg.qkv_linear) - - parent_cfg = Parent.default_config().set(name="parent", dtype=dtype) - # Test assumes that dtype is not set in test_cfg. - self.assertIs(parent_cfg.qkv_linear.dtype, None) - parent = parent_cfg.instantiate(parent=None) - qkv_linear = parent.qkv_linear - state = qkv_linear.initialize_parameters_recursively(jax.random.PRNGKey(0)) - - # Check dtypes from init_states. 
- (cache, init_output), _ = F( - qkv_linear, - prng_key=jax.random.PRNGKey(0), - state=state, - inputs=dict( - time_step=None, - query=TensorSpec([target_batch_size, target_max_len]), - ), - method="init_states", - is_training=False, - ) - self.assertIsNone(init_output) - self.assertEqual(cache["key"].dtype, dtype) - self.assertEqual(cache["value"].dtype, dtype) - - query = jax.random.uniform( - jax.random.PRNGKey(0), - shape=(target_batch_size, target_max_len, model_dim), - dtype=dtype, - ) - # Time step in the middle, so that some of the init_state is masked. - time_step = jnp.full( - shape=target_batch_size, - fill_value=target_max_len // 2, - dtype=jnp.int32, - ) - (init_state, output), _ = F( - qkv_linear, - prng_key=jax.random.PRNGKey(0), - state=state, - inputs=dict(time_step=time_step, query=query), - method="init_states", - is_training=False, - ) - self.assertEqual(init_state["key"].dtype, dtype) - self.assertEqual(init_state["value"].dtype, dtype) - self.assertEqual(output.query.dtype, dtype) - self.assertEqual(output.key.dtype, dtype) - self.assertEqual(output.value.dtype, dtype) - - -class PerDimScaleTest(TestCase): - """Tests PerDimScale.""" - - @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) - def test_per_dim_scale(self, dtype: jnp.dtype): - batch_size, tgt_len, num_head, model_dim = 3, 5, 2, 8 - per_head_dim = model_dim // num_head - layer: PerDimScale = ( - PerDimScale.default_config() - .set( - name="test", - dim=per_head_dim, - ) # We do not set layer dtype. - .instantiate(parent=None) - ) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - query = jax.random.normal( - jax.random.PRNGKey(456), [batch_size, tgt_len, num_head, per_head_dim], dtype=dtype - ) - self.assertEqual(dict(param=(per_head_dim,)), shapes(state)) - outputs, _ = F( - layer, - state=state, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=(query,), - ) - expected_outputs = query - assert_allclose(outputs, expected_outputs) - self.assertEqual(outputs.dtype, query.dtype) + Returns: + logits + attention_logit_biases, in logits.dtype. + """ + if attention_logit_biases is None: + return logits + return logits + attention_logit_biases.astype(logits.dtype) -class ScaleQueryTest(TestCase): - """Tests ScaleQuery.""" +def softmax_with_biases(logits: Tensor, attention_logit_biases: Optional[Tensor] = None) -> Tensor: + """Computes softmax with optional masking. - @parameterized.product( - scale_factor=[None, 7], - norm=[None, RMSNorm.default_config()], - per_dim_scale=[ - None, - PerDimScale.default_config(), - ], - ) - def test_scale_query( - self, - *, - scale_factor: Optional[float], - norm: Optional[RMSNorm.Config], - per_dim_scale: Optional[PerDimScale.Config], - ): - kwargs = self._scale_kwargs( - scale_factor=scale_factor, norm=norm, per_dim_scale=per_dim_scale - ) - forward_outputs, _ = F(**kwargs) - - self.assertEqual(forward_outputs.shape, kwargs["inputs"]["proj"].shape) - q_proj_scaled = kwargs["inputs"]["proj"] - if norm is not None: - assert isinstance(norm, RMSNorm.Config) - moment2 = (q_proj_scaled * q_proj_scaled).mean(axis=-1, keepdims=True) - q_proj_scaled = q_proj_scaled * jax.lax.rsqrt(moment2 + norm.eps) - if per_dim_scale is not None: - assert isinstance(per_dim_scale, PerDimScale.Config) - # We overrode the initializer for PerDimScale so we can measure the effect. 
- q_proj_scaled = q_proj_scaled * jax.nn.softplus(1.0) * 1.442695041 - - if scale_factor is None: - scale_factor = kwargs["module"].config.per_head_dim ** -0.5 - scale_factor = float(scale_factor) - q_proj_scaled = q_proj_scaled * scale_factor - - self.assertNestedAllClose(forward_outputs, q_proj_scaled) - - def _scale_kwargs( + Args: + logits: A Tensor of any shape. + attention_logit_biases: A Tensor that is broadcastable with logits. + See ``On attention logit biases`` in the file comments. + + Returns: + A Tensor of same shape and dtype as logits. + """ + check_numerics(logits) + logits = apply_attention_logit_biases(logits, attention_logit_biases) + logits_dtype = logits.dtype + if logits_dtype in (jnp.bfloat16, jnp.float16): + # Avoid computing softmax in 16-bit floats. + logits = logits.astype(jnp.float32) + probs = jax.nn.softmax(logits, axis=-1) + if probs.dtype != logits_dtype: + probs = probs.astype(logits_dtype) + check_numerics(probs) + return probs + + +def sigmoid_with_biases( + logits: Tensor, + attention_logit_biases: Optional[Tensor] = None, +) -> Tensor: + """Computes sigmoid with optional masking. + + Args: + logits: A Tensor of any shape. + attention_logit_biases: A Tensor that is broadcastable with logits. + See ``On attention logit biases`` in the file comments. + + Returns: + A Tensor of same shape and dtype as logits. + """ + check_numerics(logits) + logits = apply_attention_logit_biases(logits, attention_logit_biases) + # Avoid computing sigmoid in 16-bit floats. + logits_dtype = logits.dtype + if logits_dtype in (jnp.bfloat16, jnp.float16): + logits = logits.astype(jnp.float32) + probs = jax.nn.sigmoid(logits) + check_numerics(probs) + return probs + + +class BaseQKVLinear(BaseLayer): + """A layer that encapsulates mapping input queries, keys, and values to + multi-headed output queries, keys, and values. + """ + + @config_class + class Config(BaseLayer.Config): + """Configures BaseQKVLinear.""" + + # Input query feature dim. + query_dim: Required[int] = REQUIRED + # Input key feature dim. + key_dim: Required[int] = REQUIRED + # Input value feature dim. + value_dim: Required[int] = REQUIRED + # Number of attention heads. + num_heads: Required[int] = REQUIRED + # Dimension of each attention head. + per_head_dim: Required[int] = REQUIRED + # Autoregressive cache dtype. Should match the step dtype. + # Needs to match the forward dtype for Repeated layers. If None, infer as BaseLayer.dtype(). + cache_dtype: Optional[jnp.dtype] = None + + class Output(NamedTuple): + # [batch, target_length, num_heads, per_head_dim]. + query: Tensor + # [batch, source_length, num_heads, per_head_dim]. + key: Tensor + # [batch, source_length, num_heads, per_head_dim]. 
+ value: Tensor + + @property + def num_kv_heads(self): + return self.config.num_heads + + def init_states( self, *, - scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]], - norm: Optional[InstantiableConfig], - per_dim_scale: Optional[PerDimScale.Config], - ): - model_dim = 16 - if isinstance(scale_factor, (int, float)): - scale_factor = config_for_function(attention.constant_scale_fn).set(value=scale_factor) - - num_heads = 2 - per_head_dim = model_dim // num_heads - if per_dim_scale is not None: - per_dim_scale = per_dim_scale.set(dim=per_head_dim) - - cfg = attention.ScaleQuery.default_config().set( - name="test", - per_head_dim=per_head_dim, - norm=norm, - scale_factor=scale_factor, - per_dim_scale=per_dim_scale, - ) - layer = cfg.instantiate(parent=None) + time_step: Optional[Tensor], + query: Union[Tensor, TensorSpec], + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + ) -> tuple[Nested[Tensor], Optional[Output]]: + """Initializes cache for autoregressive cached decoding. + + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `query` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `query` as Tensors. + + Args: + time_step: An optional Tensor of shape [batch]. Each value is an index into the length + dimension indicating where decoding will start from. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding + to query vector at `time_step` indices. + For batch index `i`, only `query[i, :time_step[i], ...]` will affect subsequent + decoding. + key: An optional Tensor of shape [batch, source_length, source_dim]. + If None, will use `query`. + value: An optional Tensor of shape [batch, source_length, source_dim]. + If None, will use `query`. + kv_state: An optional KVState. If not None, both key and value must be None. + + Returns: + A tuple (init_states, output): + * init_states: A Nested Tensor state of `key`, `value` of shape + [batch, num_heads, per_head_dim, source_length], and `time_step` of shape [batch]. + * output: In the prefill case, an Output instance, where query is of size + [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim + [batch, source_length, num_heads, per_head_dim]. + Otherwise, if initializing cache from scratch, output will be None. + + Raises: + ValueError: If key/value and kv_state are an invalid combination. + ValueError: If query and time_step are an invalid combination. + """ + cfg: BaseQKVLinear.Config = self.config + # Default to base layer dtype for initialization if cache_dtype is None. + dtype = cfg.cache_dtype or self.dtype() + assert dtype is not None + + if kv_state is not None and (key is not None or value is not None): + raise ValueError("kv_state should not be specified together with key/value.") + if time_step is not None and isinstance(query, TensorSpec): + raise ValueError("query must be a Tensor if time_step is provided.") + + output = None + # Always initialize to all 0's; if `time_step` is provided, we invoke `extend_step` below + # which updates the cache with the new `time_step`. + init_state = dict(time_step=jnp.zeros(query.shape[0], dtype=jnp.int32)) + + # If `kv_state` is provided externally, we do not have to maintain key/value in cache. + # Otherwise, initialize the cache from provided query, key, value. 
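# A minimal usage sketch of the two `init_states` modes documented above, patterned on
# the tests removed elsewhere in this diff. The layer name and dimensions are
# hypothetical; `F` is assumed to be `axlearn.common.module.functional`.
import jax
from jax import numpy as jnp

from axlearn.common.attention import QKVLinear
from axlearn.common.module import functional as F
from axlearn.common.utils import TensorSpec

layer = QKVLinear.default_config().set(
    name="example_qkv", query_dim=8, key_dim=8, value_dim=8, num_heads=2, per_head_dim=4
).instantiate(parent=None)
state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0))

# Empty cache: time_step=None and query may be a TensorSpec; the returned Output is None.
(cache, out), _ = F(
    layer, state=state, is_training=False, prng_key=jax.random.PRNGKey(1),
    inputs=dict(time_step=None, query=TensorSpec([2, 6])), method="init_states",
)
assert out is None

# Prefill: a time_step and a real query Tensor; the Output carries q/k/v projections.
query = jax.random.uniform(jax.random.PRNGKey(2), [2, 6, 8])
(cache, out), _ = F(
    layer, state=state, is_training=False, prng_key=jax.random.PRNGKey(1),
    inputs=dict(time_step=jnp.array([3, 3]), query=query), method="init_states",
)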
+ if kv_state is None: - param_specs = layer.create_parameter_specs_recursively() - layer_params = jax.tree.map( - lambda spec: jnp.ones(spec.shape, dtype=spec.dtype), param_specs - ) + def maybe_initialize(kv: Optional[Tensor]): + # [batch, source/target_len, num_kv_heads, per_head_dim]. + if kv is None: + kv = jnp.zeros( + (*query.shape[:2], self.num_kv_heads, cfg.per_head_dim), dtype=dtype + ) + else: + kv = jnp.reshape(kv, (*kv.shape[:2], self.num_kv_heads, cfg.per_head_dim)) + return kv - batch_size = 3 - tgt_len = 10 - q_proj = jnp.concatenate( - ( - jnp.ones([batch_size, tgt_len // 2, num_heads, per_head_dim]), - jnp.zeros([batch_size, tgt_len // 2, num_heads, per_head_dim]), - ), - axis=1, - ) - kwargs = dict( - module=layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict(proj=q_proj), - ) - return kwargs + init_state.update(key=maybe_initialize(key), value=maybe_initialize(value)) + # If time_step is not provided, initialize an empty cache (i.e., all 0's). + # Otherwise, treat as prefill case and invoke `extend_step`. + if time_step is not None: + init_state, output = self.extend_step( + init_state, query, key=key, value=value, kv_state=kv_state + ) + # The time_step from `extend_step` includes full query length. + init_state["time_step"] = time_step -class ScaleKeyTest(TestCase): - """Tests ScaleKey.""" + return init_state, output - @parameterized.product( - scale_factor=[None, 7], - norm=[None, RMSNorm.default_config()], - ) - def test_scale_key( + def forward( self, + query: Tensor, *, - scale_factor: Optional[float], - norm: Optional[RMSNorm.Config], - ): - kwargs = self._scale_kwargs(scale_factor=scale_factor, norm=norm) - forward_outputs, _ = F(**kwargs) - - self.assertEqual(forward_outputs.shape, kwargs["inputs"]["proj"].shape) - q_proj_scaled = kwargs["inputs"]["proj"] - if norm is not None: - assert isinstance(norm, RMSNorm.Config) - moment2 = (q_proj_scaled * q_proj_scaled).mean(axis=-1, keepdims=True) - q_proj_scaled = q_proj_scaled * jax.lax.rsqrt(moment2 + norm.eps) - - if scale_factor is None: - scale_factor = 1.0 - scale_factor = float(scale_factor) - q_proj_scaled = q_proj_scaled * scale_factor - self.assertNestedAllClose(forward_outputs, q_proj_scaled) - - def _scale_kwargs( + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + query_positions: Optional[Tensor] = None, + ) -> Output: + """Computes per-head query, key, and value for the input query, key, value. + + Args: + query: A Tensor of shape [batch, target_length, target_dim]. + key: an optional Tensor of shape [batch, source_length, source_dim]. + If None, will use `query`. + value: An optional Tensor of shape [batch, source_length, source_dim]. + If None, will use `query`. + kv_state: An optional KVState. If not None, both key and value must be None. + query_positions: An optional Tensor of shape [batch, target_length]. + + Returns: + An Output instance, where query is of size + [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim + [batch, source_length, num_heads, per_head_dim]. 
+ """ + raise NotImplementedError(type(self)) + + def extend_step( self, + cached_states: NestedTensor, + query: Tensor, *, - scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]], - norm: Optional[InstantiableConfig], - ): - model_dim = 16 - if isinstance(scale_factor, (int, float)): - scale_factor = config_for_function(attention.constant_scale_fn).set(value=scale_factor) - - num_heads = 2 - per_head_dim = model_dim // num_heads - - cfg = attention.ScaleKey.default_config().set( - name="test", - per_head_dim=per_head_dim, - norm=norm, - scale_factor=scale_factor, - ) - layer = cfg.instantiate(parent=None) + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + ) -> tuple[NestedTensor, Output]: + """Computes the value vector given the query of the current step. + This function is used by autoregressive decoding. + + Based on: + https://github.com/tensorflow/lingvo/blob/5754b2f840ebf0f8c52d87e5d4d76f22e372513e/lingvo/jax/layers/attentions.py#L1249 + + Args: + cached_states: A `NestedTensor` object containing tensors which are the results of + previous attentions, and index used for fast decoding. Contains "key" and "value" of + shape [batch, num_heads, per_head_dim, target_length], and a Tensor "time_step" of + shape [batch]. + query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting + at "time_step" indices. + key: An optional Tensor of shape [batch, source_length, source_dim]. If None, will use + `query`. + value: An optional Tensor of shape [batch, source_length, source_dim]. If None, will + use `query`. + kv_state: An optional KVState. If not None, both key and value must be None. + + Returns: + A `NestedTensor` state of key and value pair along with index updated at `time_step`. + An Output instance, where query is of size + [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim + [batch, source_length, num_heads, per_head_dim]. + """ + time_step = cached_states["time_step"] + assert time_step.ndim == 1 - param_specs = layer.create_parameter_specs_recursively() - layer_params = jax.tree.map( - lambda spec: jnp.ones(spec.shape, dtype=spec.dtype), param_specs - ) + if kv_state is not None: + if key is not None or value is not None: + raise ValueError("kv_state should not be specified together with key/value") + kv_kwargs = dict(kv_state=kv_state) + else: + kv_kwargs = dict(key=key, value=value) - batch_size = 4 - tgt_len = 12 - k_proj = jnp.concatenate( - ( - jnp.ones([batch_size, tgt_len // 2, num_heads, per_head_dim]), - jnp.zeros([batch_size, tgt_len // 2, num_heads, per_head_dim]), - ), - axis=1, - ) - kwargs = dict( - module=layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(123), - inputs=dict(proj=k_proj), - ) - return kwargs + num_query_steps = query.shape[1] + query_positions = jnp.arange(num_query_steps)[None] + query_positions += time_step[:, None] + # Project inputs to key, value and query. Each has shape [B, steps, N, H]. + q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions) + updated_state = dict(time_step=time_step + num_query_steps) + if kv_state is None: + # Update the cache via dynamic slice. [B, S, N, H]. 
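# Sketch of the step-position arithmetic used just above in `extend_step`: each new query
# step is assigned its batch-specific absolute position, i.e. `time_step` plus the step
# offset (the values here are hypothetical).
from jax import numpy as jnp

time_step = jnp.array([0, 3])  # [batch]: how many tokens each example has decoded so far.
num_query_steps = 2
query_positions = jnp.arange(num_query_steps)[None] + time_step[:, None]
# query_positions == [[0, 1],
#                     [3, 4]]   # [batch, steps]; fed to position-aware input projections.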
+ cached_key = cached_states["key"] + cached_value = cached_states["value"] -def _convert_to_qkv_linear( - base_state: Nested[Tensor], *, input_linear_layer_class: type -) -> Nested[Tensor]: - """Converts the params of a MultiheadAttention layer + # Ensure that we accumulate using the original dtype. + k_proj = k_proj.astype(cached_key.dtype) + v_proj = v_proj.astype(cached_value.dtype) - ... to params of a MultiheadAttention layer with input_linear of the given type.""" - test_state = copy.deepcopy(base_state) + # TODO(dhwang2): jax.lax.dynamic_update_slice_in_dim is generally faster than advanced + # indexing, but an unusual slowdown was observed, with RLHF sampling taking up to + # 3 hours per run. Investigate and fix it. + # Note: All X_idx are small, so generating them on-demand is not costly. + b, _, n, h = cached_key.shape + b_idx = jnp.arange(b)[:, None, None, None] + t_idx = (jnp.arange(k_proj.shape[1])[None] + time_step[:, None])[:, :, None, None] + n_idx = jnp.arange(n)[None, None, :, None] + h_idx = jnp.arange(h)[None, None, None, :] + k_proj = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_proj) + v_proj = cached_value.at[b_idx, t_idx, n_idx, h_idx].set(v_proj) - if issubclass( - input_linear_layer_class, (attention.FusedQKVLinear, attention.FusedGroupedQKVLinear) - ): + updated_state.update(key=k_proj, value=v_proj) + return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj) - def combine_qkv(param_name: str) -> Tensor: - qkv_params = [ - utils.get_recursively(base_state, f"i_proj/{proj}/{param_name}") - for proj in ("q_proj", "k_proj", "v_proj") - ] - if issubclass(input_linear_layer_class, attention.FusedQKVLinear): - return jnp.stack(qkv_params) - else: - return jnp.concatenate(qkv_params, axis=-2) - qkv_proj = {"weight": combine_qkv("weight")} - if "bias" in base_state["i_proj"]["q_proj"]: - qkv_proj["bias"] = combine_qkv("bias") - test_state["i_proj"] = VDict({"qkv_proj": qkv_proj}) +class QKVLinear(BaseQKVLinear): + """Maps input query, key, and value to multi-headed output query, key, and value.""" - return test_state + @config_class + class Config(BaseQKVLinear.Config): + """Configures QKVLinear.""" + # The layer used to project. 
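# A tiny standalone sketch of the per-example cache write used in `extend_step` above:
# one new step of keys is written into a preallocated [batch, source_len, heads, dim]
# cache at each example's own `time_step` (shapes and values here are hypothetical).
from jax import numpy as jnp

cached_key = jnp.zeros((2, 4, 1, 1))   # [batch=2, source_len=4, heads=1, dim=1].
k_step = jnp.ones((2, 1, 1, 1))        # One newly projected step per example.
time_step = jnp.array([0, 2])          # Per-example write offsets.
b_idx = jnp.arange(2)[:, None, None, None]
t_idx = (jnp.arange(1)[None] + time_step[:, None])[:, :, None, None]
n_idx = jnp.arange(1)[None, None, :, None]
h_idx = jnp.arange(1)[None, None, None, :]
updated = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_step)
# updated[0, 0] and updated[1, 2] now hold the new step; all other slots remain zero.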
+ layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() -class MultiheadAttentionTest(TestCase): - """Tests MultiheadAttention, GroupedQueryAttention, and associated layers.""" + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + for name, dim, num_heads in ( + ("q", cfg.query_dim, cfg.num_heads), + ("k", cfg.key_dim, self.num_kv_heads), + ("v", cfg.value_dim, self.num_kv_heads), + ): + proj_cfg = cfg.layer + proj_cfg.model_dim = dim + proj_cfg.num_heads = num_heads + proj_cfg.per_head_dim = cfg.per_head_dim + self._add_child(f"{name}_proj", proj_cfg) - def test_add_tensor_stats(self): - model_dim = 12 - num_heads = 4 - cfg = attention.MultiheadAttention.default_config().set( - name="attn", - query_dim=12, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - tensor_stats=DefaultTensorStats.default_config(), - ) - layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - - batch_size, src_len, tgt_len = 2, 6, 6 - rng = np.random.default_rng(seed=123) - query = jnp.asarray(rng.random([batch_size, tgt_len, model_dim])) - key = jnp.asarray(rng.random([batch_size, src_len, model_dim])) - value = jnp.asarray(rng.random([batch_size, src_len, model_dim])) - attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF - x = dict(query=query, key=key, value=value, attention_logit_biases=attention_logit_biases) - _, output_collection = F( - layer, - inputs=x, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - if "tensor_stats" in output_collection.summaries: - output_stats = output_collection.summaries["tensor_stats"] - else: - output_stats = {} - expected_stats = ["o_proj_outputs"] - for k in expected_stats: - assert k in output_stats - - def test_invalid_key_value_combinations_raise(self): - model_dim = 12 - num_heads = 4 - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - ) - multihead_attention = ( - attention.MultiheadAttention.default_config() - .set(name="test_multihead_attention", **layer_kwargs) - .instantiate(parent=None) - ) - fused_multihead_attention = ( - attention.MultiheadAttention.default_config() - .set( - name="test_fused_multihead_attention", - input_linear=attention.FusedQKVLinear.default_config(), - **layer_kwargs, - ) - .instantiate(parent=None) - ) - rng = np.random.default_rng(seed=123) - inputs = jnp.asarray(rng.random([2, 6, model_dim])) - for layer in (multihead_attention, fused_multihead_attention): - for query, key, value in [(inputs, None, inputs), (inputs, inputs, None)]: - with self.assertRaisesRegex( - ValueError, "key and value must be both None or both set" - ): - layer.forward(query, key=key, value=value) - - @parameterized.parameters(None, PerDimScale.default_config()) - def test_input_linear_variants(self, per_dim_scale): - with utils.numeric_checks(True): - model_dim = 12 - num_heads = 4 - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - ) - multihead_attention = ( - attention.MultiheadAttention.default_config() - .set(name="test_multihead_attention", **layer_kwargs) - .instantiate(parent=None) - ) - multihead_attention_state = multihead_attention.initialize_parameters_recursively( - jax.random.PRNGKey(0) + def forward( + self, + query: 
Tensor, + *, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[Tensor] = None, + query_positions: Optional[Tensor] = None, + ) -> BaseQKVLinear.Output: + """Computes attention for the given query, key, value. + + If `key` or `value` are None, will use `query` in place. + + See parent class for full docstring. + """ + if kv_state is not None: + raise ValueError( + "QKVLinear computes key and value projections " + "and does not expect external `kv_state`." ) - fused_multihead_attention = ( - attention.MultiheadAttention.default_config() - .set( - name="test_fused_multihead_attention", - input_linear=attention.FusedQKVLinear.default_config(), - **layer_kwargs, - ) - .instantiate(parent=None) + del query_positions + + key = query if key is None else key + value = query if value is None else value + q_proj = self.q_proj(query) + k_proj = self.k_proj(key) + v_proj = self.v_proj(value) + return self.Output(query=q_proj, key=k_proj, value=v_proj) + + +class GroupedQKVLinear(QKVLinear): + """A variant of QKVLinear that supports configuring a different number of key, value + projections. + + Note that the number of key, value projections must evenly divide the number of query heads. + """ + + @config_class + class Config(QKVLinear.Config): + """Configures GroupedQKVLinear.""" + + # Number of heads for key, value projections. + # It is required that num_heads % num_kv_heads == 0. + num_kv_heads: Required[int] = REQUIRED + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + if cfg.num_heads % cfg.num_kv_heads != 0: + raise ValueError( + f"The number of query subgroups ({cfg.num_kv_heads}) should divide " + f"the number of query heads ({cfg.num_heads})." ) - def fused_state_from(state): - output_state = {} - for k, v in state.items(): - if k == "i_proj": - weight = jnp.array( - [v[el]["weight"] for el in ("q_proj", "k_proj", "v_proj")] - ) - bias = jnp.array([v[el]["bias"] for el in ("q_proj", "k_proj", "v_proj")]) - output_state[k] = {"qkv_proj": dict(weight=weight, bias=bias)} - else: - output_state[k] = v - return output_state - - # Map state to fused version. - fused_multihead_attention_state = fused_state_from(multihead_attention_state) - - batch_size, src_len, tgt_len = 2, 6, 6 - rng = np.random.default_rng(seed=123) - query = jnp.asarray(rng.random([batch_size, tgt_len, model_dim])) - key = jnp.asarray(rng.random([batch_size, src_len, model_dim])) - value = jnp.asarray(rng.random([batch_size, src_len, model_dim])) - attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF - inputs = dict( - query=query, key=key, value=value, attention_logit_biases=attention_logit_biases + @property + def num_kv_heads(self): + return self.config.num_kv_heads + + +class QLinear(BaseQKVLinear): + """Maps input query to multi-headed output query. Assumes external KVState.""" + + @config_class + class Config(BaseQKVLinear.Config): + """Configures QLinear.""" + + # The layer used to project. 
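# Sketch: a grouped projection where 8 query heads share 2 key/value heads, as allowed
# by GroupedQKVLinear above. Sizes are hypothetical; `F` is assumed to be
# `axlearn.common.module.functional`.
import jax
from axlearn.common.attention import GroupedQKVLinear
from axlearn.common.module import functional as F

layer = GroupedQKVLinear.default_config().set(
    name="example_gqa_proj", query_dim=16, key_dim=16, value_dim=16,
    num_heads=8, num_kv_heads=2, per_head_dim=2,
).instantiate(parent=None)
state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
query = jax.random.uniform(jax.random.PRNGKey(1), [2, 5, 16])
out, _ = F(layer, state=state, is_training=False, prng_key=jax.random.PRNGKey(2),
           inputs=dict(query=query))
# out.query: [2, 5, 8, 2]; out.key and out.value: [2, 5, 2, 2] (num_kv_heads=2).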
+ layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + proj_cfg = cfg.layer + proj_cfg.model_dim = cfg.query_dim + proj_cfg.num_heads = cfg.num_heads + proj_cfg.per_head_dim = cfg.per_head_dim + self._add_child("q_proj", proj_cfg) + + def forward( + self, + query: Tensor, + *, + kv_state: KVState, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + query_positions: Optional[Tensor] = None, + ) -> BaseQKVLinear.Output: + """Computes projects for the given query. Uses {k,v}_proj from `kv_state`. + + See parent class for full docstring. + """ + if kv_state is None or key is not None or value is not None: + raise ValueError( + f"Only kv_state is expected: key={key}, value={value}, kv_state={kv_state}" ) + q_proj = self.q_proj(query) + return self.Output(query=q_proj, key=kv_state.k_proj, value=kv_state.v_proj) - outputs = {} - layer_names = ("multihead_attention", "fused_multihead_attention") - for name, layer, state in zip( - layer_names, - (multihead_attention, fused_multihead_attention), - (multihead_attention_state, fused_multihead_attention_state), - ): - outputs[name], _ = F( - layer, - state=state, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - ) - layer_output_data = outputs[name].data - # No NaN. - self.assertTrue(jnp.all(jnp.isfinite(layer_output_data)), layer_output_data) - for layer_a, layer_b in combinations(layer_names, 2): - # Check that the outputs are close for all pairs. - self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) - - @parameterized.parameters(None, PerDimScale.default_config()) - def test_all_mask(self, per_dim_scale): - with utils.numeric_checks(True): - model_dim = 12 - num_heads = 4 - per_head_dim = model_dim // num_heads - cfg = attention.MultiheadAttention.default_config().set( - name="test", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + +class FusedQKVLinear(BaseQKVLinear): + """Maps input query, key, and value to multi-headed query, key, and value using a fused weight. + + N.B. Only supports cases where query, key, and value all have the same shape. + """ + + @config_class + class Config(BaseQKVLinear.Config): + """Configures FusedQKVLinear.""" + + # The layer used to project. 
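# Sketch: QLinear projects only the query and reuses key/value projections supplied via
# `kv_state`, mirroring the removed test_qlinear. Names and sizes are hypothetical.
import jax
from axlearn.common.attention import KVState, QKVLinear, QLinear
from axlearn.common.module import functional as F

dims = dict(query_dim=12, key_dim=12, value_dim=12, num_heads=3, per_head_dim=4)
kv_layer = QKVLinear.default_config().set(name="example_kv", **dims).instantiate(parent=None)
q_layer = QLinear.default_config().set(name="example_q", **dims).instantiate(parent=None)
qkv_params = kv_layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
q_params = dict(q_proj=qkv_params["q_proj"])  # QLinear only owns a query projection.

x = jax.random.uniform(jax.random.PRNGKey(1), [2, 6, 12])
kv_out, _ = F(kv_layer, state=qkv_params, is_training=False,
              prng_key=jax.random.PRNGKey(2), inputs=dict(query=x))
q_out, _ = F(q_layer, state=q_params, is_training=False, prng_key=jax.random.PRNGKey(2),
             inputs=dict(query=x, kv_state=KVState(k_proj=kv_out.key, v_proj=kv_out.value)))
# q_out.key and q_out.value are the externally provided projections, unchanged.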
+ layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + if not cfg.query_dim == cfg.key_dim == cfg.value_dim: + raise ValueError( + f"All projection dims must be equal for {type(self)}, saw: " + f"query:{cfg.query_dim}, key:{cfg.key_dim}, value:{cfg.value_dim}" ) - layer: attention.MultiheadAttention = cfg.instantiate(parent=None) - self.assertContainsSubset( - dict( - dropout={}, - i_proj={ - **{ - proj: { - "weight": ParameterSpec( - dtype=layer.dtype(), - shape=(model_dim, num_heads, per_head_dim), - mesh_axes=PartitionSpec(None, "model", None), - factorization=FactorizationSpec(axes=("row", None, "col")), - ), - "bias": ParameterSpec( - dtype=layer.dtype(), - shape=(num_heads, per_head_dim), - mesh_axes=PartitionSpec("model", None), - factorization=None, - ), - } - for proj in ("q_proj", "k_proj", "v_proj") - }, - }, - o_proj={ - "bias": ParameterSpec( - dtype=layer.dtype(), - shape=(model_dim,), - mesh_axes=PartitionSpec( - None, - ), - factorization=None, - ), - "weight": ParameterSpec( - dtype=layer.dtype(), - shape=(model_dim, num_heads, per_head_dim), - mesh_axes=PartitionSpec(None, "model", None), - factorization=FactorizationSpec(axes=("row", None, "col")), - ), - }, + proj_cfg = cfg.layer + proj_cfg.model_dim = cfg.query_dim + proj_cfg.num_heads = cfg.num_heads + proj_cfg.per_head_dim = cfg.per_head_dim + self._add_child("qkv_proj", proj_cfg) + + def create_parameter_specs_recursively(self) -> NestedParameterSpec: + specs = VDict(**super().create_parameter_specs_recursively()) + + def transform_factorization_spec( + spec: Optional[FactorizationSpec], + ) -> Optional[FactorizationSpec]: + if spec is None: + return None + return FactorizationSpec(axes=[None] + list(spec.axes)) + + return jax.tree.map( + lambda spec: ParameterSpec( + dtype=spec.dtype, + shape=(3, *spec.shape), + mesh_axes=PartitionSpec(None, *spec.mesh_axes), + factorization=transform_factorization_spec(spec.factorization), + fan_axes=param_init.maybe_prepend_axis( + spec.fan_axes, axis_type=param_init.FanAxes.AxisType.BATCH_AXIS ), - layer.create_parameter_specs_recursively(), + ), + specs, + ) + + def initialize_parameters_recursively( + self, prng_key: Tensor, *, prebuilt: Optional[Nested[Optional[ParameterSpec]]] = None + ) -> NestedTensor: + if self._use_prebuilt_params(prebuilt): + return jax.tree.map(lambda _: None, prebuilt) + + def init(prng_key_i): + return VDict(qkv_proj=self.qkv_proj.initialize_parameters_recursively(prng_key_i)) + + return jax.vmap(init)(split_prng_key(prng_key, 3).keys) + + def forward( + self, + query: Tensor, + *, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + query_positions: Optional[Tensor] = None, + ) -> BaseQKVLinear.Output: + """Computes multi-head query, key, and value for the input query, key, value + using a fused weight. + + N.B. Only supports cases where query, key, and value all have the same shape if set. + + See parent class for full docstring. + + Raises: + ValueError: If key and value are not both set or both None; or if kv_state is not None. + """ + if kv_state is not None: + raise ValueError( + "FusedQKVLinear computes key and value projections " + "and does not expect external `kv_state`." ) + del query_positions + + with child_context("qkv_proj"): + params = self.qkv_proj.parameters + if key is None and value is None: + # Computing self attention. + # N.B. 
this branch (with just the query inputs) is required in + # order to get the best step time on TPU for self-attention. + inputs = query # [batch, target_length, target_dim]. + proj = self.qkv_proj.einsum_maybe_quantized( + "btd,pdnh->pbtnh", activation=inputs, kernel=params["weight"] + ) + elif key is not None and value is not None: + # Compute cross attention but with same target/source shapes. + assert ( + query.shape == key.shape == value.shape # pytype: disable=attribute-error + ), f"Not supported for {type(self)}." + inputs = jnp.stack( + [query, key, value], axis=0 + ) # [q/k/v, batch, target, model_dim]. + proj = self.qkv_proj.einsum_maybe_quantized( + "pbtd,pdnh->pbtnh", activation=inputs, kernel=params["weight"] + ) + else: + raise ValueError("Key and value should be either both None or both set.") + if self.qkv_proj.config.bias: + bias = jnp.expand_dims( + params.get("bias", jnp.array([0], dtype=query.dtype)), + (1, 2), + ) + proj = proj + bias + q_proj, k_proj, v_proj = proj + return self.Output(query=q_proj, key=k_proj, value=v_proj) + - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - qkv_shapes = dict( - weight=(model_dim, num_heads, per_head_dim), bias=(num_heads, per_head_dim) +class FusedGroupedQKVLinear(BaseQKVLinear): + """Maps input query, key, and value to multi-headed query, key, and value using a fused weight. + + The main difference from FusedQKVLinear is supporting a different number of key, value heads + than query heads. All of the projection weights are concatenated/fused along the `num_heads` + axis and then split after projection. + """ + + @config_class + class Config(BaseQKVLinear.Config): + """Configures FusedGroupedQKVLinear.""" + + # Number of heads for key, value projections. + # It is required that num_heads % num_kv_heads == 0. + num_kv_heads: Required[int] = REQUIRED + # The layer used to project. + layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config() + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + if not cfg.query_dim == cfg.key_dim == cfg.value_dim: + raise ValueError( + f"All projection dims must be equal for {type(self)}, saw: " + f"query:{cfg.query_dim}, key:{cfg.key_dim}, value:{cfg.value_dim}" ) - expected_scale_query_params = {} - if per_dim_scale: - expected_scale_query_params["per_dim_scale"] = dict(param=(per_head_dim,)) - expected_params = { - "i_proj": {f"{x}_proj": qkv_shapes for x in ("q", "k", "v")}, - "o_proj": dict(weight=(model_dim, num_heads, per_head_dim), bias=(model_dim,)), - "dropout": {}, - "scale_key": {}, - "scale_query": expected_scale_query_params, - } - self.assertEqual( - expected_params, - shapes(layer_params), + if cfg.num_heads % cfg.num_kv_heads != 0: + raise ValueError( + f"The number of query subgroups {cfg.num_kv_heads} should divide " + f"the number of query heads {cfg.num_heads}." 
) + proj_cfg = cfg.layer + proj_cfg.model_dim = cfg.query_dim + proj_cfg.num_heads = cfg.num_heads + 2 * cfg.num_kv_heads + proj_cfg.per_head_dim = cfg.per_head_dim + self._add_child("qkv_proj", proj_cfg) - batch_size, src_len, tgt_len = 2, 4, 6 - rng = np.random.default_rng(seed=123) - query = jnp.asarray(rng.random([batch_size, tgt_len, model_dim])) - key = jnp.asarray(rng.random([batch_size, src_len, model_dim])) - value = jnp.asarray(rng.random([batch_size, src_len, model_dim])) - attention_logit_biases = jnp.ones([batch_size, tgt_len, src_len]) * NEG_INF - inputs = dict( - query=query, key=key, value=value, attention_logit_biases=attention_logit_biases - ) - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, + @property + def num_kv_heads(self): + return self.config.num_kv_heads + + def forward( + self, + query: Tensor, + *, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[Tensor] = None, + query_positions: Optional[Tensor] = None, + ) -> FusedQKVLinear.Output: + """See FusedQKVLinear for full docstring. + + N.B. Only supports cases where key and value are both None. + """ + if kv_state is not None: + raise ValueError( + "FusedGroupedQKVLinear computes key and value projections " + "and does not expect external `kv_state`." ) - layer_output_data = layer_outputs.data - # No NaN. - self.assertTrue(jnp.all(jnp.isfinite(layer_output_data)), layer_output_data) + if key is not None or value is not None: + raise ValueError("Key and value should be both None.") + del query_positions + cfg = self.config + proj = self.qkv_proj(query) + q_proj, k_proj, v_proj = jnp.split( + proj, [cfg.num_heads, cfg.num_heads + cfg.num_kv_heads], axis=-2 + ) + return self.Output(query=q_proj, key=k_proj, value=v_proj) - @parameterized.product( - dtype=(jnp.float32, jnp.float16, jnp.bfloat16), - per_dim_scale=(None, PerDimScale.default_config()), - ) - def test_data_types(self, dtype: jnp.dtype, per_dim_scale: Optional[PerDimScale.Config]): - model_dim = 16 - num_heads = 4 - cfg = attention.MultiheadAttention.default_config().set( - name="test", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - dtype=dtype, - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), + +def _rotary_sinusoidal_positional_embeddings( + *, positions: Tensor, dim: int, theta: float = 10000.0 +) -> Tensor: + """Generate the sin/cos positional embedding. + + Ref: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L76-L90 + + Args: + positions: A tensor representing the token position IDs with shape [batch_size, seq_len]. + dim: The dimensionality of the positional embedding. + theta: A parameter to scale the frequencies. + + Returns: + Rotary Positional Embedding with shape [batch_size, seq_len, dim]. 
+    """
+    if dim % 2 != 0:
+        raise ValueError(f"dim: {dim} should be a multiple of 2.")
+    exponents = jnp.arange(dim).astype(jnp.float32)
+    pos_array = positions.astype(jnp.float32)
+    exponents = jnp.power(theta, 2 * (exponents // 2) / dim)
+    position_enc = jnp.expand_dims(pos_array, 2) / jnp.expand_dims(exponents, [0, 1])
+
+    rope_part_1 = jnp.sin(position_enc[:, :, 0::2])
+    rope_part_2 = jnp.cos(position_enc[:, :, 1::2])
+    rope = jnp.concatenate((rope_part_1, rope_part_2), axis=-1)
+    return rope
+
+
+class RoFormerSinusoidalPositionalEmbedding(BaseLayer):
+    """Implementation of Rotary Position Embedding (RoPE).
+
+    Ref:
+    https://github.com/huggingface/transformers/blob/62ceb4/src/transformers/models/roformer/modeling_roformer.py
+    """
+
+    @config_class
+    class Config(BaseLayer.Config):
+        """Configures RoFormerSinusoidalPositionalEmbedding."""
+
+        dim: Required[int] = REQUIRED  # The dimensionality of the positional embedding.
+        theta: float = 10000.0  # The scale of base frequency.
+
+    def default_query_positions(self, max_seq_len: int) -> Tensor:
+        """Computes the default `positions` to pass to `forward` when `positions` is not
+        provided by the calling QKVLinear class (e.g. `RoFormerQKVLinear`).
+        """
+        return jnp.arange(max_seq_len)[None]  # [batch_size=1, max_seq_len].
+
+    def forward(
+        self, positions: Optional[Tensor] = None, max_seq_len: Optional[int] = None
+    ) -> Tensor:
+        """Computes rotary positional embeddings for `positions`.
+
+        TODO(bwzhang): 1. verify the performance under float32.
+
+        Args:
+            positions: A tensor representing the token position IDs.
+                The shape is [batch_size, seq_len].
+            max_seq_len: The maximum sequence length. Required if `positions` is None.
+
+        Returns:
+            Rotary positional embeddings of shape [batch_size, seq_len, dim].
+
+        Raises:
+            ValueError: If both `positions` and `max_seq_len` are None.
+        """
+        cfg = self.config
+        if positions is None:
+            if max_seq_len is None:
+                raise ValueError(
+                    "Must provide `max_seq_len` for computing default query positions if "
+                    "`positions` is None."
+                )
+            positions = self.default_query_positions(max_seq_len)
+        return _rotary_sinusoidal_positional_embeddings(
+            positions=positions, dim=cfg.dim, theta=cfg.theta
+        )
+
+
+def apply_rotary_position_embeddings(
+    *,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    sinusoidal_pos: Tensor,
+    rotary_key: bool,
+    rotary_value: bool,
+) -> tuple[Tensor, Tensor, Tensor]:
+    """This is a jax implementation (a copy) of the RoPE apply_rotary_position_embeddings.
+
+    Ref:
+    https://github.com/huggingface/transformers/blob/v4.21.2/src/transformers/models/roformer/modeling_roformer.py#L322-L346
+
+    Args:
+        query: Query embeddings with shape [batch_size, seq_len, num_heads, dim].
+        key: Key embeddings with shape [batch_size, seq_len, num_heads, dim].
+        value: Value embeddings with shape [batch_size, seq_len, num_heads, dim].
+        sinusoidal_pos: Rotary position embeddings with shape [batch_size, seq_len, 1, dim].
+        rotary_key: Whether to apply rotary position embeddings on key.
+ rotary_value: Whether to apply rotary position embeddings on value. + + Returns: + A tuple of: + Rotary position affined query embeddings with shape [batch_size, seq_len, num_heads, dim] + Rotary position affined key embeddings with shape [batch_size, seq_len, num_heads, dim] + Rotary position affined value embeddings with shape [batch_size, seq_len, num_heads, dim] + if rotary_value == True, else original value embeddings + """ + # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] + # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] + sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) + # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape) + # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape) + # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] + rotate_half_query = jnp.reshape( + jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape + ) + query = query * cos_pos + rotate_half_query * sin_pos + + if rotary_key: + # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + rotate_half_key = jnp.reshape( + jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape ) - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, + key = key * cos_pos + rotate_half_key * sin_pos + if rotary_value: + # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] + rotate_half_value = jnp.reshape( + jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape ) - self.assertEqual(layer_outputs.data.dtype, dtype) + value = value * cos_pos + rotate_half_value * sin_pos + return query, key, value - @parameterized.product( - base_cfg=( - attention.MultiheadAttention.default_config(), - attention.GroupedQueryAttention.default_config().set( - input_linear=attention.GroupedQKVLinear.default_config().set(num_kv_heads=2) - ), - attention.GroupedQueryAttention.default_config().set( - input_linear=attention.FusedGroupedQKVLinear.default_config().set(num_kv_heads=2) - ), - attention.GroupedQueryAttention.default_config().set( - input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False) - ), - attention.SigmoidAttention.default_config().set( - input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False), - seq_len=4, - ), - attention.SigmoidAttention.default_config().set( - # Used in ALiBi position encoding. 
- input_linear=FusedQKVLinear.default_config(), - seq_len=4, - ), - ), - attention_logit_biases_fn=( - lambda query_len, kv_len: None, - lambda query_len, kv_len: _random_mask(jax.random.PRNGKey(1), query_len, kv_len), - ), - kv_length_multiplier=(0.5, 1, 2), - has_query_positions=(False, True), - ) - def test_causal( - self, - base_cfg: attention.MultiheadAttention.Config, - attention_logit_biases_fn: Callable[[int, int], Tensor], - kv_length_multiplier: float, - has_query_positions: bool, - ): - """Tests that base_cfg(causal=True) is equivalent to applying a causal mask.""" - if ( - has_query_positions - and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config) - or kv_length_multiplier != 1 - and isinstance( - base_cfg.input_linear, - (FusedGroupedQKVLinear.Config, RoFormerQKVLinear.Config, FusedQKVLinear.Config), - ) - ): - pytest.skip(reason="Incompatible test setting that does not need testing.") - - model_dim = 16 - num_heads = 4 - ref_cfg = base_cfg.clone( - name="test", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, + +class RoFormerQKVLinear(BaseQKVLinear): + """RoFormerQKVLinear class + + This class maps the query, key, and value using the RoPE embeddings. + """ + + @config_class + class Config(BaseQKVLinear.Config): + """Configures RoFormerQKVLinear.""" + + rope_pos_emb_layer: RoFormerSinusoidalPositionalEmbedding.Config = ( + RoFormerSinusoidalPositionalEmbedding.default_config() ) - self.assertFalse(ref_cfg.causal) - ref_layer = ref_cfg.instantiate(parent=None) - layer_params = ref_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - - test_cfg = ref_cfg.clone(causal=True) - test_layer = test_cfg.instantiate(parent=None) - - batch_size, query_len = 2, 4 - query = jnp.zeros([batch_size, query_len, model_dim], dtype=jnp.float32) - outputs = [] - - if has_query_positions: - query_positions = jax.random.permutation( - jax.random.PRNGKey(1), - jnp.arange(query_len)[None, :].repeat(batch_size, axis=0), - axis=1, - independent=True, - ) + input_linear: BaseQKVLinear.Config = QKVLinear.default_config() + # Whether to apply RoPE rotations to the value embeddings. + rotary_value: Required[bool] = REQUIRED - for layer in (ref_layer, test_layer): - inputs = dict(query=query) - kv_len = int(kv_length_multiplier * query_len) - if kv_length_multiplier < 1: - inputs["key"] = query[:, :kv_len] - inputs["value"] = query[:, :kv_len] - elif kv_length_multiplier > 1: - inputs["key"] = jnp.tile(query, [1, int(kv_length_multiplier), 1]) - inputs["value"] = jnp.tile(query, [1, int(kv_length_multiplier), 1]) - - attention_logit_biases = attention_logit_biases_fn(inputs["query"].shape[1], kv_len) - if layer is ref_layer: - # Apply causal mask on top of the logit biases for `ref_layer`. - causal_biases = make_index_position_biases(inputs["query"].shape[1], kv_len=kv_len) - if attention_logit_biases is None: - attention_logit_biases = causal_biases - else: - attention_logit_biases = apply_attention_logit_biases( - attention_logit_biases, causal_biases - ) - inputs["attention_logit_biases"] = attention_logit_biases - if has_query_positions: - inputs["query_positions"] = query_positions - - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - ) - outputs.append(layer_outputs) - # The outputs are equivalent. 
- self.assertNestedAllClose(outputs[0], outputs[1]) - - @parameterized.product( - base_cfg=( - attention.MultiheadAttention.default_config(), - attention.GroupedQueryAttention.default_config().set( - input_linear=attention.GroupedQKVLinear.default_config().set(num_kv_heads=2) - ), - attention.GroupedQueryAttention.default_config().set( - input_linear=attention.FusedGroupedQKVLinear.default_config().set(num_kv_heads=2) - ), - attention.GroupedQueryAttention.default_config().set( - input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False) - ), - attention.SigmoidAttention.default_config().set( - input_linear=attention.RoFormerQKVLinear.default_config().set(rotary_value=False), - seq_len=4, - ), - attention.SigmoidAttention.default_config().set( - # Used in ALiBi position encoding. - input_linear=FusedQKVLinear.default_config(), - seq_len=4, + def __init__(self, cfg: QKVLinear.Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + self._add_child( + "rope_pos_emb_layer", + cfg.rope_pos_emb_layer.set(dim=cfg.per_head_dim), + ) + self._add_child( + "i_proj", + cfg.input_linear.set( + query_dim=cfg.query_dim, + value_dim=cfg.value_dim, + key_dim=cfg.key_dim, + num_heads=cfg.num_heads, + per_head_dim=cfg.per_head_dim, ), - ), - attention_logit_biases_fn=( - lambda seq_len: None, - lambda seq_len: _random_mask(jax.random.PRNGKey(1), seq_len, seq_len), - ), - has_query_positions=(False, True), - ) - def test_sliding_window( + ) + + @property + def num_kv_heads(self): + """Propagate num KV heads from input linear.""" + return self.i_proj.num_kv_heads + + def forward( self, - base_cfg: attention.MultiheadAttention.Config, - attention_logit_biases_fn: Callable[[int], Tensor], - has_query_positions: bool, - ): + query: Tensor, + *, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + query_positions: Optional[Tensor] = None, + ) -> BaseQKVLinear.Output: + cfg = self.config + # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim]. 
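# Sketch: routing caller-chosen positions through RoFormerQKVLinear via `query_positions`.
# Sizes and position values are hypothetical; `F` is assumed to be
# `axlearn.common.module.functional`.
import jax
from jax import numpy as jnp

from axlearn.common.attention import QKVLinear, RoFormerQKVLinear
from axlearn.common.module import functional as F

layer = RoFormerQKVLinear.default_config().set(
    name="example_rope_proj", query_dim=8, key_dim=8, value_dim=8, num_heads=2,
    per_head_dim=4, input_linear=QKVLinear.default_config(), rotary_value=False,
).instantiate(parent=None)
state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
query = jax.random.uniform(jax.random.PRNGKey(1), [2, 4, 8])
positions = jnp.array([[0, 1, 4, 5], [2, 3, 6, 7]])  # Need not be contiguous per example.
out, _ = F(layer, state=state, is_training=False, prng_key=jax.random.PRNGKey(2),
           inputs=dict(query=query, query_positions=positions))
# With query_positions=None, the layer falls back to jnp.arange(seq_len) internally.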
+ query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state) + seq_len = query.shape[1] + sinusoidal_pos_emb = self.rope_pos_emb_layer.forward( + positions=query_positions, max_seq_len=seq_len + ).astype(query.dtype) + # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim] + sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2) + + i_proj_computes_kv = kv_state is None + query, key, value = apply_rotary_position_embeddings( + sinusoidal_pos=sinusoidal_pos_emb, + query=query, + key=key, + value=value, + rotary_key=i_proj_computes_kv, + rotary_value=i_proj_computes_kv and cfg.rotary_value, + ) + + return self.Output(query, key, value) + + +class PerDimScale(BaseLayer): + """A layer to scale individual dimensions of the input.""" + + @config_class + class Config(BaseLayer.Config): + """Configures PerDimScale.""" + + dim: Required[int] = REQUIRED + + @classmethod + def default_config(cls) -> Config: + cfg: PerDimScale.Config = super().default_config() + cfg.param_init = ConstantInitializer.default_config().set(value=0.0) + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + return { + "param": ParameterSpec(shape=(cfg.dim,), mesh_axes=(None,)), + } + + def forward(self, x: Tensor) -> Tensor: + """Returns x * per_dim_scale.""" + cfg = self.config + assert x.shape[-1] == cfg.dim + # https://github.com/tensorflow/lingvo/blob/3d16483b749a1181330ae9ce318688e7518d63c9/lingvo/jax/layers/attentions.py#L232-L234 + # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number to avoid unnecessary + # XLA op fusion. + r_softplus_0 = 1.442695041 + scale = jax.nn.softplus(self.parameters["param"]) * r_softplus_0 + return (x * scale).astype(x.dtype) + + +ScaleFn = Callable[[int], float] # A function mapping per_head_dim to a scale. + + +def constant_scale_fn(value: float) -> ScaleFn: + """A constant scale function for `MultiheadAttention`. + + Example: + `key_scale = config_for_function(constant_scale_fn).set(value=0.01)` + + Args: + value: The value to scale by. + + Returns: + A `ScaleFn` that always returns `value`. + """ + + def constant_function(per_head_dim: int) -> float: + del per_head_dim + return value + + return constant_function + + +def pow_scale_fn(exp: float) -> ScaleFn: + """A scale function for `MultiheadAttention` that computes `per_head_dim ** exp`. + + Example: + `query_scale = config_for_function(pow_scale_fn).set(exp=-0.5)` + + Args: + exp: The exponent. + + Returns: + A `ScaleFn` that computes `per_head_dim ** exp`. + """ + + return functools.partial(pow, exp=exp) + + +class BaseScaleQK(BaseLayer): + """Defines the common interface for scaling projected attention queries or keys. + + * All subclasses must have `per_head_dim` in their config. + """ + + @config_class + class Config(BaseLayer.Config): + """Configures BaseScaleQK.""" + + # The per-head dimension. + per_head_dim: Required[int] = REQUIRED + + def forward(self, proj: Tensor) -> Tensor: + """Scales the projected queries or keys. + + Args: + proj: The projected queries/keys. + Shape: [batch, seq_length, num_heads, per_head_dim]. + + Returns: + A tensor with the same shape as the input. """ - Tests that base_cfg with sliding window causal mask fns is equivalent to applying a - causal sliding window mask. 
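# Sketch: building scale-factor configs from `pow_scale_fn` / `constant_scale_fn` above;
# these plug into the `scale_factor` field of the ScaleQuery / ScaleKey layers defined
# below (`config_for_function` is assumed from axlearn.common.config).
from axlearn.common.attention import constant_scale_fn, pow_scale_fn
from axlearn.common.config import config_for_function

query_scale_factor = config_for_function(pow_scale_fn).set(exp=-0.5)  # per_head_dim ** -0.5.
key_scale_factor = config_for_function(constant_scale_fn).set(value=1.0)  # No extra scaling.
# E.g. query_scale_factor.instantiate()(64) == 0.125.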
+        raise NotImplementedError(type(self))
+
+
+class ScaleQuery(BaseScaleQK):
+    """Default implementation for scaling projected queries."""
+
+    @config_class
+    class Config(BaseScaleQK.Config):
+        """Configures ScaleQuery."""
+
+        # The config for a normalization layer applied along the per-head dim.
+        # If None, no normalization is applied.
+        norm: Optional[InstantiableConfig] = None
+        # The config for a function to compute a query scale multiplier factor.
+        # If None, then self.default_scale_factor_config.
+        scale_factor: Optional[InstantiableConfig[ScaleFn]] = None
+        # A vector to apply per dimension scale to the query projection.
+        per_dim_scale: Optional[PerDimScale.Config] = None
+
+    def __init__(self, cfg: Config, *, parent: Module):
+        super().__init__(cfg, parent=parent)
+        cfg = self.config
+        self._scale_factor = self.default_scale_factor_config()
+        if cfg.scale_factor is not None:
+            self._scale_factor = cfg.scale_factor
+        self._scale_factor = self._scale_factor.instantiate()
+        if cfg.norm is not None:
+            self._add_child("norm", cfg.norm.set(input_dim=cfg.per_head_dim))
+        if cfg.per_dim_scale:
+            self._add_child("per_dim_scale", cfg.per_dim_scale.set(dim=cfg.per_head_dim))
+
+    def apply_norm(self, proj: Tensor) -> Tensor:
+        """Applies the norm to projected queries if configured."""
+        if "norm" in self.children:
+            proj = self.norm(proj)
+        return proj
+
+    def apply_per_dim_scale(self, proj: Tensor) -> Tensor:
+        """Applies the per-dim scale to projected queries if configured."""
+        if "per_dim_scale" in self.children:
+            # The Lingvo MultiheadAttention applies a per_dim_scale:
+            # https://github.com/tensorflow/lingvo/blob/41212226eac7a26491790c2bd476b78493f93ff6/lingvo/core/batch_major_attention.py#L790
+            proj = self.per_dim_scale(proj)
+        return proj
+
+    def apply_scale_factor(self, proj: Tensor) -> Tensor:
+        """Applies the scale-factor to projected queries."""
+        scale = self._scale_factor(self.config.per_head_dim)
+        return proj * scale
+
+    def forward(self, proj: Tensor) -> Tensor:
+        """Scales the projected queries."""
+        proj = self.apply_norm(proj)
+        proj = self.apply_per_dim_scale(proj)
+        proj = self.apply_scale_factor(proj)
+        # Stop scale constant from being folded with others.
+        # May increase numerical stability.
+        return ops.forward_optimization_barrier(proj)
+
+    @staticmethod
+    def default_scale_factor_config() -> InstantiableConfig[ScaleFn]:
+        """The config for the default function used to compute the query scale."""
+
+        return config_for_function(pow_scale_fn).set(exp=-0.5)
+
+
+class ScaleKey(BaseScaleQK):
+    """Default implementation for scaling projected keys."""
+
+    @config_class
+    class Config(BaseScaleQK.Config):
+        """Configures ScaleKey."""
+
+        # The config for a normalization layer applied along the per-head dim.
+        # If None, no normalization is applied.
+        norm: Optional[InstantiableConfig] = None
+        # The config for a function to compute a key scale multiplier factor.
+        # If None, then self.default_scale_factor_config.
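A configuration sketch for the query-scaling stack above (norm, per-dim scale, then scale factor); the `config_for_function` import path is an assumption and the values are illustrative.

from axlearn.common import attention
from axlearn.common.config import config_for_function  # Assumed import path.

# Replace the default per_head_dim**-0.5 factor with a constant 1/8 and enable the
# learned per-dimension scale; the norm slot is left unset (no normalization).
query_scale = attention.ScaleQuery.default_config().set(
    scale_factor=config_for_function(attention.constant_scale_fn).set(value=0.125),
    per_dim_scale=attention.PerDimScale.default_config(),
)
# Typically plugged into an attention layer, e.g.
# attention.MultiheadAttention.default_config().set(query_scale=query_scale, ...).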
+ scale_factor: Optional[InstantiableConfig[ScaleFn]] = None + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + self._scale_factor = self.default_scale_factor_config() + if cfg.scale_factor is not None: + self._scale_factor = cfg.scale_factor + self._scale_factor = self._scale_factor.instantiate() + if cfg.norm is not None: + self._add_child("norm", cfg.norm.set(input_dim=cfg.per_head_dim)) + + def forward(self, proj: Tensor) -> Tensor: + """Scales the projected keys.""" + cfg = self.config + if cfg.norm is not None: + proj = self.norm(proj) + scale = self._scale_factor(cfg.per_head_dim) + proj = proj * scale + # Stop scale constant from being folded with others. + # May increase numerical stability. + return ops.forward_optimization_barrier(proj) + + @staticmethod + def default_scale_factor_config() -> InstantiableConfig[ScaleFn]: + """The config for the default function used to compute the key scale.""" + + return config_for_function(constant_scale_fn).set(value=1) + + +class MultiheadAttention(BaseLayer): + """A basic multi-head attention layer. + + Differences from torch.nn.MultiheadAttention: + - Use of einsum for efficient computation on TPU to avoid reshaping; + - Separate weights for {q,k,v}_proj for proper weight initialization that depends + on fan-out and efficient TPU execution (where split is not free). + """ + + @config_class + class Config(BaseLayer.Config): + """Configures MultiheadAttention.""" + + query_dim: Required[int] = REQUIRED # Input query feature dim. + key_dim: Required[int] = REQUIRED # Input key feature dim. + value_dim: Required[int] = REQUIRED # Input value feature dim. + output_dim: Optional[int] = None # Output feature dim. If None, use query_dim. + hidden_dim: Optional[int] = None # Hidden feature dim. If None, use query_dim. + # Number of attention heads. Must divide hidden_dim evenly. + num_heads: Required[int] = REQUIRED + # Config used to produce Q,K,V projections. + input_linear: BaseQKVLinear.Config = QKVLinear.default_config() + # Config used for the output projection. + output_linear: MultiheadOutputLinear.Config = MultiheadOutputLinear.default_config() + # The dropout layer. + dropout: Dropout.Config = Dropout.default_config() + # Config used to scale projected queries prior to computing logits. + query_scale: BaseScaleQK.Config = ScaleQuery.default_config() + # Config used to scale projected keys prior to computing logits. + key_scale: BaseScaleQK.Config = ScaleKey.default_config() + # Cap the absolute values of logits by tanh. Enabled by setting a positive value. + atten_logit_cap: Optional[float] = None + # A function to compute the boolean mask to apply when computing the attention + # where True means "attend" and False means "do not attend". + # Set to `causal_mask` for causal masking. + # When used with certain flash attention implementations, more efficient + # code paths may be used. (See the FlashAttention docstring for more details.) + # This field may not be specified if `causal` (deprecated) is specified. + # If `attention_logit_biases` argument is also specified, both masks are combined with AND. + mask: ConfigOr[Optional[MaskFn]] = None + # Deprecated. Use `mask=causal_mask` instead. + # If True, applies causal masking. `key` and `value` must be None. + # May not be specified if `mask` is already specified. + # If `attention_logit_biases` argument is also specified, both masks are combined with AND. + # TODO (apghml) Eliminate this field in favor of `mask`. 
+ causal: Optional[bool] = None + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + if cfg.causal and cfg.mask is not None: + raise NotImplementedError("Cannot specify `causal` when using `mask`.") + if cfg.causal: + self._mask_fn = causal_mask + else: + self._mask_fn = maybe_instantiate(cfg.mask) + # Configure inputs to multi-headed QKV projection. + i_proj_cfg = cfg.input_linear + i_proj_cfg.query_dim = cfg.query_dim + i_proj_cfg.key_dim = cfg.key_dim + i_proj_cfg.value_dim = cfg.value_dim + i_proj_cfg.num_heads = cfg.num_heads + i_proj_cfg.per_head_dim = self.per_head_dim() + self._add_child("i_proj", i_proj_cfg) + # Configure output projection. + o_proj_cfg = cfg.output_linear + o_proj_cfg.model_dim = self.output_dim() + o_proj_cfg.num_heads = cfg.num_heads + o_proj_cfg.per_head_dim = self.per_head_dim() + self._add_child("o_proj", o_proj_cfg) + # Add dropout layer. + self._add_child("dropout", cfg.dropout) + # Add query scaling layer. + self._add_child("scale_query", cfg.query_scale.set(per_head_dim=self.per_head_dim())) + # Add key scaling layer. + self._add_child("scale_key", cfg.key_scale.set(per_head_dim=self.per_head_dim())) + + def output_dim(self): + cfg = self.config + return cfg.output_dim or cfg.query_dim + + def hidden_dim(self): + cfg = self.config + return cfg.hidden_dim or cfg.query_dim + + def per_head_dim(self): + cfg = self.config + hidden_dim = self.hidden_dim() + if hidden_dim % cfg.num_heads != 0: + raise ValueError(f"num_heads ({cfg.num_heads}) must divide hidden_dim ({hidden_dim})") + return hidden_dim // cfg.num_heads + + class Output(NamedTuple): + """Outputs of MultiheadAttention. + + Fields: + data: [batch, target_length, output_dim]. The attention output. Always present. + probs: [batch, num_heads, target_length, source_length]. The attention probabilities. + Populated if "probs" is in `return_aux`. + kv_state: The KV state used for computing the attention outputs. + Populated if "kv_state" is in `return_aux`. 
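A configuration sketch for `MultiheadAttention` using the preferred `mask` field instead of the deprecated `causal` flag; `causal_mask` is assumed to be importable from `axlearn.common.attention_bias`, and the dimensions are illustrative.

from axlearn.common import attention
from axlearn.common.attention_bias import causal_mask  # Assumed import path.

model_dim, num_heads = 16, 4
cfg = attention.MultiheadAttention.default_config().set(
    name="self_atten",
    query_dim=model_dim,
    key_dim=model_dim,
    value_dim=model_dim,
    num_heads=num_heads,
    mask=causal_mask,  # Preferred over the deprecated `causal=True`.
)
layer = cfg.instantiate(parent=None)
# hidden_dim defaults to query_dim, so per_head_dim() == 16 // 4 == 4.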
""" - if has_query_positions and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config): - return - - model_dim = 16 - num_heads = 4 - ref_cfg = base_cfg.clone( - name="test", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - ) - self.assertFalse(ref_cfg.causal) - ref_layer = ref_cfg.instantiate(parent=None) - layer_params = ref_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - - sliding_window_size = 2 - test_cfg = ref_cfg.clone( - causal=False, - mask=config_for_function(sliding_window_causal_mask).set( - sliding_window_size=sliding_window_size - ), - ) - test_layer = test_cfg.instantiate(parent=None) - - batch_size, seq_len = 2, 4 - query = jnp.zeros([batch_size, seq_len, model_dim], dtype=jnp.float32) - outputs = [] - - if has_query_positions: - query_positions = jax.random.permutation( - jax.random.PRNGKey(1), - jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0), - axis=1, - independent=True, + + data: Tensor + probs: Optional[Tensor] = None + kv_state: Optional[KVState] = None + + def _forward_for_mode( + self, + *, + mode: ForwardMode, + query: Union[Tensor, TensorSpec], + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + attention_logit_biases: Union[None, Tensor, BaseAttentionBias] = None, + segment_ids: Optional[Tensor] = None, + query_positions: Optional[Tensor] = None, + cached_states: Optional[NestedTensor] = None, + return_aux: Optional[set[str]] = None, + ) -> tuple[Nested[Tensor], Optional[Output]]: + """Computes attention for the given query, key, value, and attention logit biases. + + If key and value are both None, computes self-attention using query. + + Args: + mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for + details. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. + key: An optional Tensor of shape [batch, source_length, source_dim]. + value: An optional Tensor of shape [batch, source_length, source_dim]. + kv_state: An optional KVState. If specified, both `key` and `value` should be None. + attention_logit_biases: See ``On attention logit biases`` in the file comments. + segment_ids: See ``On segment_ids`` in the file comments. + query_positions: See ``On positions`` in the file comments. + cached_states: Optional NestedTensor as produced by `init_states`. + return_aux: See comments on `Output`. + + Returns: + A tuple (cached_states, output): + * cached_states: An optional NestedTensor of cache states, depending on `mode`. + * output: An optional Output instance, where .data is of the same shape as query and + .probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. + + Raises: + ValueError: If key & value are an invalid combination. + ValueError: If `mode` is unsupported. + """ + # Validate key & value combination. 
+ if (key is None) != (value is None): + raise ValueError( + "key and value must be both None or both set, " + f"key:{type(key)}, value:{type(value)}" ) + if kv_state is not None: + if key is not None or value is not None: + raise ValueError("kv_state should not be specified together with key/value") + kv_kwargs = dict(kv_state=kv_state) + else: + kv_kwargs = dict(key=key, value=value) - for layer in (ref_layer, test_layer): - attention_logit_biases = attention_logit_biases_fn(seq_len) - if layer is ref_layer: - # Apply causal and sliding window mask on top of the logit biases for `ref_layer`. - attention_logit_biases = apply_attention_logit_biases( - make_sliding_window_causal_biases(seq_len, sliding_window_size), - attention_logit_biases, - ) - inputs = dict(query=query, attention_logit_biases=attention_logit_biases) - if has_query_positions: - inputs["query_positions"] = query_positions - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, + if mode == ForwardMode.FORWARD: + i_proj_state, i_proj_output = ( + None, + self.i_proj(query, query_positions=query_positions, **kv_kwargs), ) - outputs.append(layer_outputs) - # The outputs are equivalent. - self.assertNestedAllClose(outputs[0], outputs[1]) - - @parameterized.product( - dtype=(jnp.float32, jnp.float16, jnp.bfloat16), - per_dim_scale=(None, PerDimScale.default_config()), - atten_logit_cap=(0.0, 20.0), - input_linear=( - None, # Use the default linear. - attention.QKVLinear.default_config(), - attention.FusedQKVLinear.default_config(), - attention.GroupedQKVLinear.default_config().set(num_kv_heads=4), - attention.FusedGroupedQKVLinear.default_config().set(num_kv_heads=4), - ), - bias=(True, False), - ) - def test_gqa_forward( - self, - dtype: jnp.dtype, - per_dim_scale: Optional[PerDimScale.Config], - atten_logit_cap: float, - input_linear: attention.BaseQKVLinear.Config, - bias: bool, - ): - """When num_kv_heads=num_heads, GQA should be equivalent to MHA.""" - model_dim = 16 - num_heads = 4 - layer_kwargs = dict( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - atten_logit_cap=atten_logit_cap, - dtype=dtype, - ) - init_key = jax.random.PRNGKey(123) - # Initialize MultiheadAttention. - base_cfg = attention.MultiheadAttention.default_config().set(**layer_kwargs) - set_bias_recursively(base_cfg, bias=bias) - base_layer = base_cfg.set(name="base").instantiate(parent=None) - base_state = base_layer.initialize_parameters_recursively(prng_key=init_key) - # Initialize GroupedQueryAttenion. - cfg = attention.GroupedQueryAttention.default_config().set(**layer_kwargs) - if input_linear is not None: - cfg.set(input_linear=input_linear) - set_bias_recursively(cfg, bias=bias) - test_layer = cfg.set(name="test").instantiate(parent=None) - logging.info("base_state=%s", shapes(base_state)) - # We convert 'base_state' to 'test_state' because JAX does not ensure that RNG behavior - # remains the same with vs. without vmap. So test_layer initialization may behave - # differently even with the same seed. - test_state = _convert_to_qkv_linear( - base_state, input_linear_layer_class=cfg.input_linear.klass - ) - logging.info("transformed_test_state=%s", shapes(test_state)) - - # Dummy inputs. 
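A sketch of the external-KV path validated above: `kv_state` replaces `key`/`value` and pairs with a query-only input projection. `QLinear`, `KVState`, and the `F` import path are assumptions; shapes are illustrative.

import jax
from jax import numpy as jnp

from axlearn.common import attention
from axlearn.common.module import functional as F  # Assumed import path.

model_dim, num_heads, per_head_dim = 16, 4, 4
batch, tgt_len, src_len = 2, 3, 6
cfg = attention.MultiheadAttention.default_config().set(
    name="cross_atten",
    query_dim=model_dim,
    key_dim=model_dim,
    value_dim=model_dim,
    num_heads=num_heads,
    # Query-only projection; K/V come from the externally supplied kv_state.
    input_linear=attention.QLinear.default_config(),
)
layer = cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
kv_state = attention.KVState(
    k_proj=jnp.zeros([batch, src_len, num_heads, per_head_dim]),
    v_proj=jnp.zeros([batch, src_len, num_heads, per_head_dim]),
)
# Passing `kv_state` together with `key`/`value`, or only one of `key`/`value`, raises ValueError.
output, _ = F(
    layer,
    state=state,
    is_training=False,
    prng_key=jax.random.PRNGKey(1),
    inputs=dict(query=jnp.zeros([batch, tgt_len, model_dim]), kv_state=kv_state),
)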
- batch_size, tgt_len = 2, 6 - inputs = dict( - query=jax.random.normal( - jax.random.PRNGKey(124), - [batch_size, tgt_len, model_dim], - dtype=dtype, - ), - key=None, - value=None, - attention_logit_biases=attention_bias.make_causal_biases(tgt_len), - ) - # Get outputs. - forward_key = jax.random.PRNGKey(456) - base_outputs, _ = F( - base_layer, - state=base_state, - is_training=False, - prng_key=forward_key, - inputs=inputs, + elif mode == ForwardMode.INIT_STATES: + assert cached_states is not None + assert query_positions is None + i_proj_state, i_proj_output = self.i_proj.init_states( + time_step=cached_states["i_proj"], query=query, **kv_kwargs + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + assert query_positions is None + i_proj_state, i_proj_output = self.i_proj.extend_step( + cached_states["i_proj"], query, **kv_kwargs + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") + + if i_proj_output is None: + assert mode == ForwardMode.INIT_STATES + return dict(i_proj=i_proj_state), None + + q_proj, k_proj, v_proj = i_proj_output + kv_state = KVState(k_proj=k_proj, v_proj=v_proj) + q_proj = self._remat_name(q_proj, "q_proj") + k_proj = self._remat_name(k_proj, "k_proj") + v_proj = self._remat_name(v_proj, "v_proj") + self.vlog(3, "atten.q_proj=%s", q_proj.sum()) + self.vlog(3, "atten.k_proj=%s", k_proj.sum()) + self.vlog(3, "atten.v_proj=%s", v_proj.sum()) + attention_logit_biases = as_attention_bias(attention_logit_biases) + if self._mask_fn is not None: + target_positions = None + if mode == ForwardMode.EXTEND_STEP: + target_positions = cached_states["i_proj"]["time_step"] + if self._mask_fn is causal_mask: + # Needed for legacy flash attention implementations that don't have + # sparse mask support. + # E.g., the legacy tpu flash attention, all current gpu flash attention + # implementations. + attention_logit_biases += CausalAttentionBias( + shape=(q_proj.shape[1], k_proj.shape[1]), + target_positions=target_positions, + dtype=q_proj.dtype, + ) + else: + attention_logit_biases += MaskFnAttentionBias( + self._mask_fn, + shape=(q_proj.shape[1], k_proj.shape[1]), + target_positions=target_positions, + dtype=q_proj.dtype, + ) + if segment_ids is not None: + attention_logit_biases += SegmentIdAttentionBias(segment_ids) + context, probs = self._compute_attention( + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + attention_logit_biases=attention_logit_biases, ) - test_outputs, _ = F( - test_layer, - state=test_state, - is_training=False, - prng_key=forward_key, - inputs=inputs, + self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :]) + self.vlog(3, "atten.context=%s", context.sum()) + + # [batch, target_length, output_dim]. 
+ o_proj = self.o_proj(context) + outputs = self._remat_name(o_proj, "o_proj") + self._add_tensor_stats("o_proj_outputs", outputs) + return_aux = return_aux or set() + output = self.Output( + data=outputs, + probs=probs if "probs" in return_aux else None, + kv_state=kv_state if "kv_state" in return_aux else None, ) - self.assertNestedAllClose(base_outputs, test_outputs) + return dict(i_proj=i_proj_state), output - def _test_extend_step( + def _compute_attention( self, - attention_cfg: attention.MultiheadAttention.Config, *, - model_dim: int, - num_heads: int, - dtype: jnp.dtype, - bias: bool, - extend_step_len: int, - ): - cfg = attention_cfg.set( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, + q_proj: Tensor, + k_proj: Tensor, + v_proj: Tensor, + attention_logit_biases: BaseAttentionBias, + ) -> tuple[Tensor, Tensor]: + """Computes attention context and probs. + + Args: + q_proj: [batch_size, target_length, num_heads, per_head_dim]. + k_proj: [batch_size, source_length, num_heads, per_head_dim]. + v_proj: [batch_size, source_length, num_heads, per_head_dim]. + attention_logit_biases: See ``On attention logit biases`` in the file comments. + + Returns: + The context of shape [batch_size, target_length, num_heads, per_head_dim], + and probs of shape [batch, num_heads, target_length, source_length]. + """ + logits = self._compute_logits(q_proj, k_proj) + logits = self._cap_logits(logits) + self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) + probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases.value()) + probs = self.dropout(probs) + context = self._compute_context(probs, v_proj) + context = self._remat_name(context, "context") + return context, probs + + def forward( + self, + query: Tensor, + *, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + attention_logit_biases: Optional[Tensor] = None, + segment_ids: Optional[Tensor] = None, + query_positions: Optional[Tensor] = None, + return_aux: Optional[set[str]] = None, + ) -> Output: + """Computes attention for the given query, key, value, and attention logit biases. + + If key and value are both None, computes self-attention using query. + + Args: + query: A Tensor of shape [batch, target_length, target_dim]. + key: An optional Tensor of shape [batch, source_length, source_dim]. + value: An optional Tensor of shape [batch, source_length, source_dim]. + kv_state: An optional KVState. If not None, both key and value must be None. + attention_logit_biases: See ``On attention logit biases`` in the file comments. + segment_ids: See `On segment_ids` in the file comments. + query_positions: See ``On positions`` in the file comments. + return_aux: See comments on `Output`. + + Returns: + An Output instance, where .data is of the same shape as query and .probs is of shape + [batch, num_heads, target_length, source_length]. + + Raises: + ValueError: If key & value are an invalid combination. 
+ """ + _, output = self._forward_for_mode( + mode=ForwardMode.FORWARD, + query=query, + key=key, + value=value, + kv_state=kv_state, + attention_logit_biases=attention_logit_biases, + segment_ids=segment_ids, + query_positions=query_positions, + return_aux=return_aux, ) - cfg.input_linear.set(dtype=dtype, cache_dtype=None) - set_bias_recursively(cfg, bias=bias) - layer: attention.MultiheadAttention = cfg.set(name="test").instantiate(parent=None) + return output - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + def _cap_logits(self, logits: Tensor) -> Tensor: + """Caps the logits with tanh.""" + cfg = self.config + if not cfg.atten_logit_cap or cfg.atten_logit_cap <= 0.0: + return logits + cap = jnp.array(cfg.atten_logit_cap, dtype=logits.dtype) + return cap * jnp.tanh(logits / cap) - batch_size, tgt_len = 2, 6 - head_dim = model_dim // num_heads - query = jax.random.normal( - jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim], dtype=dtype - ) - key = value = kv_state = None - if attention_cfg.klass == attention.GroupedQueryAttention: - pass - elif attention_cfg.input_linear.klass == QLinear: - kv_state = KVState( - k_proj=jax.random.normal( - jax.random.PRNGKey(124), [batch_size, tgt_len, num_heads, head_dim], dtype=dtype - ), - v_proj=jax.random.normal( - jax.random.PRNGKey(125), [batch_size, tgt_len, num_heads, head_dim], dtype=dtype - ), - ) - else: - # Make key and value distinct from query. Otherwise, it is equivalent - # to the query only case. - key = value = query + 0.1 - attention_logit_biases = attention_bias.make_causal_biases(tgt_len) - return_aux = {"probs"} - inputs = dict( + def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. + + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ + q_proj = self.scale_query(q_proj) + k_proj = self.scale_key(k_proj) + return jnp.einsum("btnh,bsnh->bnts", q_proj, k_proj) + + def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. + """ + return jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + + def init_states( + self, + *, + time_step: Optional[Tensor], + query: Union[Tensor, TensorSpec], + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + attention_logit_biases: Optional[Tensor], + return_aux: Optional[set[str]] = None, + ) -> tuple[Nested[Tensor], Optional[Output]]: + """Initializes cache for autoregressive cached decoding. + + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `query` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `query` as Tensors. + + Args: + time_step: A Tensor of shape [batch]. Each value is an index into the length dimension + indicating where decoding will start from. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding + to query projection input vector up to `time_step`. 
For batch index `i`, only + `query[i, :time_step[i], ...]` will affect subsequent decoding. + key: Same description as `query`, but for the key projection input vector. + Key and value have to both be tensors or both be None. + If they are tensors, key and value are used as the unique input to the + input projection. Otherwise, query is used as the key and value input. + value: Same description as `query`, but for the value projection input vector. + See the above comment for `key` for additional constraints. + kv_state: An optional KVState. + attention_logit_biases: See ``On attention logit biases`` in the file comments. + return_aux: See comments on `Output`. + + Returns: + A tuple (init_states, output): + * init_states: A Nested Tensor state of key and value pair along with index updated at + `time_step`. + * output: In the prefill case, an Output instance, where .data is of the same shape as + query and .probs is of shape [batch, num_heads, target_length, source_length]. + Otherwise, if initializing cache from scratch, output will be None. + """ + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, query=query, key=key, value=value, + cached_states=dict(i_proj=time_step), kv_state=kv_state, attention_logit_biases=attention_logit_biases, return_aux=return_aux, ) - forward_outputs, _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - ) - initial_state, initial_output = layer.init_states( - time_step=None, - query=TensorSpec([batch_size, tgt_len]), + def extend_step( + self, + cached_states: NestedTensor, + query: Tensor, + *, + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, + attention_logit_biases: Optional[Tensor], + return_aux: Optional[set[str]] = None, + ) -> tuple[NestedTensor, Output]: + """Computes the value vector given the query of the current step. + This function is used by autoregressive decoding. + + Based on: + https://github.com/tensorflow/lingvo/blob/5754b2f840ebf0f8c52d87e5d4d76f22e372513e/lingvo/jax/layers/attentions.py#L1249 + + Args: + cached_states: A `NestedTensor` object containing tensors which are the results of + previous attentions, and index used for fast decoding. Contains "key" and "value" of + shape [B, N, H, T], and a Tensor "time_step" of shape [B]. + query: Tensor of shape [B, 1, D] corresponding to query projection input vector + at "time_step" indices. + key: Tensor of shape [B, 1, D] corresponding to key projection input vector at + "time_step" indices. Key and value have to both be tensors or both be None. + If they are tensors, key and value are used as the unique input to the + input projection. Otherwise, query is used as the key and value input. + value: Tensor of shape [B, 1, D] corresponding to value projection input vector + at "time_step" indices. See the above comment for `key` for additional + constraints. + kv_state: An optional KVState. + attention_logit_biases: See ``On attention logit biases`` in the file comments. + Additionally, target_length is expected to be 1 since this is per time step. + The biases should already include causal masking for decoding, plus other biases + if necessary. + return_aux: See comments on `Output`. + + Returns: + A `NestedTensor` state of key and value pair along with index updated at `time_step`. + An Output instance, where .data is of the same shape as query, .probs is of shape + [batch, num_heads, 1, source_length]. 
+ """ + return self._forward_for_mode( + mode=ForwardMode.EXTEND_STEP, + query=query, + key=key, + value=value, + cached_states=cached_states, kv_state=kv_state, - # This is unused for initializing state from scratch. - attention_logit_biases=None, + attention_logit_biases=attention_logit_biases, + return_aux=return_aux, ) - self.assertIsNone(initial_output) - if kv_state is None: - for k in ["key", "value"]: - # Check that the cache dtype is inferred as the layer dtype. - self.assertEqual(initial_state["i_proj"][k].dtype, dtype) - else: - self.assertNotIn("key", initial_state["i_proj"]) - self.assertNotIn("value", initial_state["i_proj"]) - inputs = dict(cached_states=initial_state, kv_state=kv_state, return_aux=return_aux) - decoder_output = [] - decoder_probs = [] - for t in range(0, tgt_len, extend_step_len): - inputs["query"] = query[:, t : t + extend_step_len, :] - if key is not None: - inputs["key"] = key[:, t : t + extend_step_len, :] - if value is not None: - inputs["value"] = value[:, t : t + extend_step_len, :] - inputs["attention_logit_biases"] = attention_logit_biases[t : t + extend_step_len, :] - (cached_states, extend_step_outputs), _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="extend_step", - ) - inputs["cached_states"] = cached_states - decoder_output.append(extend_step_outputs.data) - decoder_probs.append(extend_step_outputs.probs) - decoder_output = jnp.concatenate(decoder_output, axis=1) - decoder_probs = jnp.concatenate(decoder_probs, axis=2) - assert_allclose(decoder_output, forward_outputs.data, atol=1e-6) - assert_allclose(decoder_probs, forward_outputs.probs, atol=1e-6) - - @parameterized.product( - dtype=(jnp.float32, jnp.float16, jnp.bfloat16), - per_dim_scale=(None, PerDimScale.default_config()), - atten_logit_cap=(0.0, 20.0), - bias=(True, False), - input_linear=(QKVLinear, RoFormerQKVLinear, QLinear), - extend_step_len=(1, 4), - ) - def test_extend_step( + + @staticmethod + def default_query_scale_config() -> InstantiableConfig[ScaleFn]: + """The config for the default function used to compute the query scale.""" + + return config_for_function(pow_scale_fn).set(exp=-0.5) + + @staticmethod + def default_key_scale_config() -> InstantiableConfig[ScaleFn]: + """The config for the default function used to compute the key scale.""" + + return config_for_function(constant_scale_fn).set(value=1) + + +class GroupedQueryAttention(MultiheadAttention): + """A Grouped-Query Attention (GQA) layer. + + Query projections are divided into K groups along the `num_heads` dimension. Projections in the + same query subgroup share one common key/value head. This reduces the size of the KV-cache by a + factor of `num_heads/num_kv_heads`. + + When `input_linear` is a `GroupedQKVLinear` layer with `num_kv_heads=1`, GQA reduces to + multi-query attention (MQA). + When `input_linear` is a `QKVLinear` layer (i.e. `num_kv_heads=num_heads`), GQA is equivalent to + multi-head attention (MHA). + + Note that in some cases fused variants `FusedQKVLinear` or `FusedGroupedQKVLinear` can be used + as drop-in replacements for `QKVLinear` or `GroupedQKVLinear` respectively (see corresponding + layer docstrings for details). + + Reference: https://arxiv.org/abs/2305.13245 + """ + + @property + def num_kv_heads(self): + return self.i_proj.num_kv_heads + + def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. 
+ + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ + kv_heads = k_proj.shape[-2] + num_head_group = self.config.num_heads // kv_heads + if num_head_group == 1: + return super()._compute_logits(q_proj=q_proj, k_proj=k_proj) + + q_proj = self.scale_query(q_proj) + k_proj = self.scale_key(k_proj) + q_proj = einops.rearrange(q_proj, "b t (k g) h -> b t k g h", k=kv_heads, g=num_head_group) + k_proj = einops.rearrange(k_proj, "b s k h -> b s k 1 h") + logits = jnp.einsum("btkgh,bsk1h->bkgts", q_proj, k_proj) + return einops.rearrange(logits, "b k g t s -> b (k g) t s") + + def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. + """ + kv_heads = v_proj.shape[-2] + num_head_group = self.config.num_heads // kv_heads + if num_head_group == 1: + return super()._compute_context(probs=probs, v_proj=v_proj) + + probs = einops.rearrange(probs, "b (k g) t s -> b k g t s", k=kv_heads, g=num_head_group) + v_proj = einops.rearrange(v_proj, "b s k h -> b s k 1 h") + context = jnp.einsum("bkgts,bsk1h->btkgh", probs, v_proj) + return einops.rearrange(context, "b t k g h -> b t (k g) h") + + +class SigmoidAttention(MultiheadAttention): + """A multi-head sigmoid-based attention layer, instead of softmax. + + TODO(floris_weers): Add paper reference. + """ + + @config_class + class Config(MultiheadAttention.Config): + """Configures SigmoidAttention.""" + + seq_len: Required[int] = REQUIRED # Maximum sequence length used. + + def _compute_attention( self, - dtype: jnp.dtype, - per_dim_scale: Optional[PerDimScale.Config], - atten_logit_cap: float, - input_linear: attention.BaseQKVLinear, - bias: bool, - extend_step_len: int, - ): - model_dim = 16 - num_heads = 4 - if input_linear == attention.RoFormerQKVLinear: - input_linear = input_linear.default_config().set(rotary_value=False) - else: - input_linear = input_linear.default_config() - cfg = attention.MultiheadAttention.default_config().set( - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - atten_logit_cap=atten_logit_cap, - input_linear=input_linear, + *, + q_proj: Tensor, + k_proj: Tensor, + v_proj: Tensor, + attention_logit_biases: BaseAttentionBias, + ) -> tuple[Tensor, Tensor]: + """See `MultiheadAttention._compute_attention` for details.""" + cfg = self.config + logits = self._compute_logits(q_proj, k_proj) + logits = self._cap_logits(logits) + self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) + + attention_logit_biases = attention_logit_biases.value() + if attention_logit_biases is None: + attention_logit_biases = 0 + # To approximate softmax, we subtract a bias dependent on sequence length. 
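A pure-JAX sketch of the head grouping used in the grouped logits above, written with an equivalent einsum that skips the explicit size-1 KV axis: 8 query heads share 2 KV heads, so each group holds 4 heads (scaling omitted, sizes illustrative).

import einops
import jax
from jax import numpy as jnp

batch, tgt_len, src_len, num_heads, kv_heads, per_head_dim = 2, 4, 4, 8, 2, 16
groups = num_heads // kv_heads
q = jax.random.normal(jax.random.PRNGKey(0), [batch, tgt_len, num_heads, per_head_dim])
k = jax.random.normal(jax.random.PRNGKey(1), [batch, src_len, kv_heads, per_head_dim])
qg = einops.rearrange(q, "b t (k g) h -> b t k g h", k=kv_heads, g=groups)
logits = jnp.einsum("btkgh,bskh->bkgts", qg, k)  # Each group attends to its own KV head.
logits = einops.rearrange(logits, "b k g t s -> b (k g) t s")
assert logits.shape == (batch, num_heads, tgt_len, src_len)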
+ attention_logit_biases = attention_logit_biases - jnp.log(cfg.seq_len) + probs = sigmoid_with_biases( + logits, + attention_logit_biases=attention_logit_biases, ) - self._test_extend_step( - cfg, - model_dim=model_dim, - num_heads=num_heads, - dtype=dtype, - bias=bias, - extend_step_len=extend_step_len, + probs = self.dropout(probs) + + context = self._compute_context(probs, v_proj) + context = self._remat_name(context, "context") + return context, probs + + +def rel_pos_to_abs_pos(x: Tensor) -> Tensor: + """Converts a (T, relative_pos_offset) Tensor to a (T, abs_position) tensor. + + For example, t = 3: + ..abc abc + .def. => def + ghi.. ghi + + Input shape: [t, 2t - 1]: + ..abc + .def. + ghi.. + + 1. Reshape to [t * (2t - 1)] + ..abc.def.ghi.. + + 2. Trim by [t-1:-1], producing shape [t * (2t - 2)]. + abc.def.ghi. + + 3. Reshape to [t, 2t - 2]: + abc. + def. + ghi. + + 4. Trim by [:, :-(t-2)] + abc + def + ghi + + Args: + x: a Tensor of shape [T, 2*T - 1], where x[i, j] represents the bias between query[i] and + absolute position k = i + j - (T - 1), if 0 <= k < T, otherwise the value is not used. + T is expected to be >= 1. + + Returns: + y: a Tensor of shape [T, T], s.t. y[i, k] = x[i, j] where k = i + j - (T - 1), + if 0 <= k < T. + """ + t, offset_length = x.shape + assert offset_length == 2 * t - 1 + if t <= 1: + return x + # [t * (2t - 1)]. + x = x.reshape([-1]) + # [t * (2t - 2)]. + x = x[t - 1 : -1] + # [t, 2t - 2]. + x = x.reshape([t, -1]) + # [t, t]. When t = 2, do not trim. + if t > 2: + x = x[:, : -(t - 2)] + return x + + +class MultiheadRelativePositionLinear(BaseMultiheadLinear): + """Multi-head relative position linear layer.""" + + @property + def _einsum_expr(self): + return "ld,dnh->lnh" + + @property + def _bias_spec(self): + cfg = self.config + return ParameterSpec( + shape=(cfg.num_heads, cfg.per_head_dim), + mesh_axes=cfg.param_partition_spec[-2:], ) - @parameterized.product( - dtype=(jnp.float32, jnp.float16, jnp.bfloat16), - per_dim_scale=(None, PerDimScale.default_config()), - atten_logit_cap=(0.0, 20.0), - num_kv_heads=(1, 2, 4), - input_linear=(attention.GroupedQKVLinear, attention.FusedGroupedQKVLinear), - bias=(True, False), - extend_step_len=(1, 4), - ) - def test_gqa_extend_step( - self, - dtype: jnp.dtype, - per_dim_scale: Optional[PerDimScale.Config], - atten_logit_cap: float, - num_kv_heads: int, - input_linear: type[attention.BaseQKVLinear], - bias: bool, - extend_step_len: int, - ): - model_dim = 16 - num_heads = 4 - cfg = attention.GroupedQueryAttention.default_config().set( - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - atten_logit_cap=atten_logit_cap, - input_linear=input_linear.default_config().set(num_kv_heads=num_kv_heads), + # pylint: disable-next=no-self-use + def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: + if name == "weight": + return FanAxes(in_axis=0, out_axis=(1, 2)) + else: + return None + + +def xl_attention_logits( + q_proj: Tensor, k_proj: Tensor, relative_pos_emb: Tensor, u: Tensor, v: Tensor +): + """Computes Transformer XL self-attention logits. + + Note that this implementation follows XLNet implementation and is different from the lingvo + implementation in that here the relative_pos_emb index is computed from key_i - query_i, + while lingvo computes from query_i - key_i. + + Args: + q_proj: A Tensor of shape [batch, target_length, num_heads, per_head_dim], representing + projected queries. 
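A small numeric check of `rel_pos_to_abs_pos` for T=3, with zeros standing in for the unused '.' slots in the docstring diagram; the import assumes the function is exported from `axlearn.common.attention` as added here.

from jax import numpy as jnp

from axlearn.common.attention import rel_pos_to_abs_pos

x = jnp.array(
    [
        [0, 0, 1, 2, 3],
        [0, 4, 5, 6, 0],
        [7, 8, 9, 0, 0],
    ]
)
y = rel_pos_to_abs_pos(x)
# y == [[1, 2, 3],
#       [4, 5, 6],
#       [7, 8, 9]]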
+ k_proj: A Tensor of shape [batch, target_length, num_heads, per_head_dim], representing + projected keys. + relative_pos_emb: A Tensor of shape [num_embeddings, num_heads, per_head_dim], representing + projected relative positional embeddings, where num_embeddings = 2 * target_length - 1. + relative_pos_emb[key_i - query_i + target_length - 1] represents positional + embeddings between query[:, query_i] and key[:, key_i] and is usually computed from + sinusoidal_positional_embeddings(query_i - key_i), i.e., + relative_pos_emb[0] represents query_i = target_length - 1 and key_i = 0. + relative_pos_emb[-1] represents query_i = 0 and key_i = target_length - 1. + u: A Tensor of shape [num_heads, per_head_dim]. + The trainable `u` in https://arxiv.org/pdf/1901.02860.pdf 3.3 used for term 'ac'. + v: A Tensor of shape [num_heads, per_head_dim]. + The trainable `v` in https://arxiv.org/pdf/1901.02860.pdf 3.3 used for term 'bd'. + + Returns: + A tensor of shape [batch, num_heads, target_length, target_length] representing attention + logits. logit[:, :, i, j] represents the logit for query[i] and key[j]. + """ + term_ac = jnp.einsum("btnh,bsnh->bnts", q_proj + u, k_proj) + term_bd = jnp.einsum("btnh,lnh->bntl", q_proj + v, relative_pos_emb) + # Apply vmap twice to map over both `batch` and `num_heads`. + term_bd = jax.vmap(jax.vmap(rel_pos_to_abs_pos))(term_bd) + return term_ac + term_bd + + +class MultiheadAttentionXL(MultiheadAttention): + """Multi-head self-attention with relative positional embeddings. + + The default config matches XL-Net implementation with `per_dim_scale=None` and + `scale_position=LOGIT`. + To match with Lingvo implementation, enable `per_dim_scale` + and set `scale_position=QUERY`. Note the positional embeddings are in descending + order, which is different from Lingvo's implementation. + + Reference: + https://github.com/zihangdai/xlnet/blob/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/modeling.py + https://github.com/huggingface/transformers/blob/224bde91caff4ccfd12277ab5e9bf97c61e22ee9/src/transformers/models/xlnet/modeling_xlnet.py#L204 + https://github.com/tensorflow/lingvo/blob/a1326a09641a6ec7d775a51148012551756d888d/lingvo/core/batch_major_attention.py#L1345 + https://github.com/tensorflow/lingvo/blob/f02fed838836bcc513d51c95d482247b119543fb/lingvo/core/attention_util.py#L464-L513 + """ + + @unique + class ScalePosition(Enum): + # Applies query scale-factor to the logits. + LOGIT = 0 + # Applies query scale-factor to the queries. + QUERY = 1 + + @config_class + class Config(MultiheadAttention.Config): + """Configures MultiheadAttentionXL.""" + + pos_emb_dim: Optional[int] = None # Positional embedding dim. If None, use query_dim. + # Config for computing relative position embeddings for range [-seq_len + 1, seq_len - 1]. + relative_pos_emb: SinusoidalPositionalEmbedding.Config = ( + SinusoidalPositionalEmbedding.default_config() + ) + # Config used for the R projection. 
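A shape sketch for `xl_attention_logits`: with target length T, the relative embedding table has 2*T - 1 rows (one per relative offset) and the output is [batch, num_heads, T, T]. The import assumes the function is exported from `axlearn.common.attention`; sizes are illustrative.

import jax
from jax import numpy as jnp

from axlearn.common.attention import xl_attention_logits

batch, tgt_len, num_heads, per_head_dim = 2, 4, 8, 16
q = jax.random.normal(jax.random.PRNGKey(0), [batch, tgt_len, num_heads, per_head_dim])
k = jax.random.normal(jax.random.PRNGKey(1), [batch, tgt_len, num_heads, per_head_dim])
rel = jax.random.normal(jax.random.PRNGKey(2), [2 * tgt_len - 1, num_heads, per_head_dim])
u = jnp.zeros([num_heads, per_head_dim])
v = jnp.zeros([num_heads, per_head_dim])
logits = xl_attention_logits(q_proj=q, k_proj=k, relative_pos_emb=rel, u=u, v=v)
assert logits.shape == (batch, num_heads, tgt_len, tgt_len)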
+ relative_pos_linear: MultiheadRelativePositionLinear.Config = ( + MultiheadRelativePositionLinear.default_config().set(bias=False) + ) + scale_position: Required["MultiheadAttentionXL.ScalePosition"] = REQUIRED + + @classmethod + def default_config(cls) -> Config: + cfg: MultiheadAttentionXL.Config = super().default_config() + cfg.scale_position = MultiheadAttentionXL.ScalePosition.LOGIT + # pylint: disable=no-member + cfg.input_linear = FusedQKVLinear.default_config() + cfg.input_linear.layer.bias = False + # pylint: enable=no-member + return cfg + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg: MultiheadAttentionXL.Config = self.config + if not cfg.query_dim == cfg.key_dim == cfg.value_dim: + raise ValueError( + f"MultiheadAttentionXL requires query_dim ({cfg.query_dim}) == " + f"key_dim ({cfg.key_dim}) == value_dim ({cfg.value_dim})" + ) + pos_emb_dim = cfg.pos_emb_dim or cfg.query_dim + self._add_child("relative_pos_emb", cfg.relative_pos_emb.set(dim=pos_emb_dim)) + self._add_child( + "r_proj", + cfg.relative_pos_linear.clone( + model_dim=pos_emb_dim, num_heads=cfg.num_heads, per_head_dim=self.per_head_dim() + ), ) - self._test_extend_step( - cfg, - model_dim=model_dim, - num_heads=num_heads, - dtype=dtype, - bias=bias, - extend_step_len=extend_step_len, + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + params = super()._create_layer_parameter_specs() + params["u_bias"] = params["v_bias"] = ParameterSpec( + shape=(cfg.num_heads, self.per_head_dim()), + initializer=constant_initializer(0), + mesh_axes=cfg.relative_pos_linear.param_partition_spec[-2:], ) + return params - def _test_prefill_states( + def forward( self, - attention_cfg: attention.MultiheadAttention.Config, + query: Tensor, *, - model_dim: int, - num_heads: int, - dtype: jnp.dtype, - bias: bool, - num_kv_heads: Optional[int] = None, - ): - cfg = attention_cfg.set( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - ) - cfg.input_linear.set(dtype=dtype, cache_dtype=None) - set_bias_recursively(cfg, bias=bias) - layer: attention.MultiheadAttention = cfg.set(name="test").instantiate(parent=None) + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, + **kwargs, + ) -> MultiheadAttention.Output: + if key is not None or value is not None: + raise ValueError("Both key and value must be None for MultiheadAttentionXL") + return super().forward(query, **kwargs) + + def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + cfg = self.config + with child_context("apply_query_norm", module=self): + # We apply the query norm (if configured) to the projection (not the logits). + q_proj = self.scale_query.apply_norm(q_proj) + + with child_context("apply_per_dim_scale", module=self): + q_proj = self.scale_query.apply_per_dim_scale(q_proj) + + if cfg.scale_position == MultiheadAttentionXL.ScalePosition.QUERY: + with child_context("apply_scale_factor_queries", module=self): + q_proj = self.scale_query.apply_scale_factor(q_proj) + + seq_len = q_proj.shape[1] + # [2*seq_len - 1, pos_emb_dim]. + # + # Following the XLNet implementation + # https://github.com/zihangdai/xlnet/blob/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/modeling.py#L215 + # https://github.com/huggingface/transformers/blob/28d0048218ad7bce69510b16024510afba0daed2/src/transformers/models/xlnet/modeling_xlnet.py#L1030 + # the positions are from descending from seq_len - 1 to -seq_len + 1. 
+ pos_emb = self.relative_pos_emb(jnp.arange(seq_len - 1, -seq_len, -1, dtype=jnp.int32)) + # [2*seq_len - 1, num_heads, per_head_dim]. + r_proj = self.r_proj(pos_emb) + + # Apply key scaling. + k_proj = self.scale_key(k_proj) + + logits = xl_attention_logits( + q_proj=q_proj, + k_proj=k_proj, + relative_pos_emb=r_proj, + u=self.parameters["u_bias"], + v=self.parameters["v_bias"], + ) + if cfg.scale_position == MultiheadAttentionXL.ScalePosition.LOGIT: + # In the original XL-Net code, it applies scale on AC + BD: + # + # https://github.com/zihangdai/xlnet/blob/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/modeling.py#L148 + with child_context("apply_scale_factor_logits", module=self): + logits = self.scale_query.apply_scale_factor(logits) + return logits + + def extend_step( + self, + cached_states: NestedTensor, + query: Tensor, + **kwargs, + ) -> tuple[NestedTensor, MultiheadAttention.Output]: + raise NotImplementedError(type(self)) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - batch_size, tgt_len = 3, 6 - query = jax.random.normal( - jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim], dtype=dtype - ) - if attention_cfg.klass == attention.GroupedQueryAttention: - key = value = None +class TransformerAttentionLayer(BaseLayer): + """A Transformer attention layer with normalization and a skip connection. + + Can be used for either self-attention or cross-attention. + """ + + @config_class + class Config(BaseLayer.Config): + """Configures TransformerAttentionLayer.""" + + target_dim: Required[int] = REQUIRED # Input target feature dim. + source_dim: Required[int] = REQUIRED # Input source feature dim. + norm: InstantiableConfig = LayerNorm.default_config() # The normalization layer config. + attention: InstantiableConfig = ( + MultiheadAttention.default_config() + ) # The attention layer config. + dropout: InstantiableConfig = Dropout.default_config() # The dropout layer config. + # The stochastic depth layer config. + # Pytorch reference: + # https://github.com/facebookresearch/deit/blob/main/models_v2.py#L58 + # Tensorflow reference: + # https://github.com/tensorflow/models/blob/master/official/projects/vit/modeling/nn_blocks.py#L86-L92 + stochastic_depth: InstantiableConfig = StochasticDepth.default_config() + # The inner structure of the layer: prenorm or postnorm. See + # https://arxiv.org/abs/2002.04745 for background. + # The structure also support hybridnorm, which uses two norms in the residual branch. + # hybridnorm: TransformerAttentionLayer(x) = x + layernorm_2(attention(layernorm_1(x))) + # Ref: https://github.com/google/praxis/blob/main/praxis/layers/transformers.py#L1129 + # TODO (bwzhang@) Adding a unittest for the hybridnorm. + structure: str = "prenorm" + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + if cfg.structure in ["prenorm", "postnorm"]: + self._add_child("norm", cfg.norm.set(input_dim=cfg.target_dim)) + elif cfg.structure == "hybridnorm": + self._add_child("prenorm", cfg.norm.set(input_dim=cfg.target_dim)) + self._add_child("postnorm", cfg.norm.set(input_dim=cfg.target_dim)) else: - # Make key and value distinct from query. Otherwise, it is equivalent - # to the query only case. 
- key = value = query + 0.1 - attention_logit_biases = attention_bias.make_causal_biases(tgt_len) - return_aux = {"probs"} - - forward_outputs, _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict( - query=query, - key=key, - value=value, - attention_logit_biases=attention_logit_biases, - return_aux=return_aux, + raise NotImplementedError(cfg.structure) + self._add_child( + "attention", + cfg.attention.set( + query_dim=cfg.target_dim, + key_dim=cfg.source_dim, + value_dim=cfg.source_dim, + output_dim=cfg.target_dim, ), ) + self._add_child("dropout", cfg.dropout) + self._add_child("stochastic_depth", cfg.stochastic_depth) - time_step = jnp.arange(batch_size) - (initial_states, initial_output), _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict( - time_step=time_step, - query=query, - key=key, - value=value, - attention_logit_biases=attention_logit_biases, - return_aux=return_aux, - ), - method="init_states", - ) + class Output(NamedTuple): + """Outputs of TransformerAttentionLayer. - # Check time_step and shapes of state. - self.assertEqual(["i_proj"], list(initial_states.keys())) - self.assertTrue(jnp.all(time_step == initial_states["i_proj"]["time_step"])) - for proj in ["key", "value"]: - self.assertEqual( - (batch_size, tgt_len, num_kv_heads or num_heads, model_dim // num_heads), - initial_states["i_proj"][proj].shape, - ) - self.assertEqual( - dtype, - initial_states["i_proj"][proj].dtype, - ) + Fields: + data: [batch, target_length, output_dim]. The attention output. Always present. + probs: The attention probabilities returned by the attention layer. + Populated if "probs" is in return_aux. + kv_state: The KV state used to compute output. + Populated if "kv_state" is in return_aux. + """ + + data: Tensor + probs: Optional[Tensor] = None + kv_state: Optional[KVState] = None + + def _forward_for_mode( + self, + *, + mode: ForwardMode, + target: Union[Tensor, TensorSpec], + source: Optional[Union[Tensor, KVState]] = None, + attention_logit_biases: Optional[Tensor] = None, + segment_ids: Optional[Tensor] = None, + target_positions: Optional[Tensor] = None, + cached_states: Optional[NestedTensor] = None, + return_aux: Optional[set[str]] = None, + ) -> tuple[Optional[Nested[Tensor]], Optional[Output]]: + """Computes either self-attention or cross-attention for the given target and source. + + Args: + mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for + details. + target: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. + source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. + If None, uses norm(target) as source (self-attention). + attention_logit_biases: See ``On attention logit biases`` in the file comments. + segment_ids: See ``On segment_ids`` in the file comments. + target_positions: See ``On positions`` in the file comments. + cached_states: Optional NestedTensor as produced by `init_states`. + return_aux: See comments on `Output`. + + Returns: + A tuple (cached_states, output): + * cached_states: An optional Nested Tensor of cache states, depending on `mode`. + * output: An optional Output instance, where .data is of the same shape as query and + .probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. + + Raises: + ValueError: If `mode` is unsupported. + NotImplementedError: If `cfg.structure` is not supported. 
+ """ + cfg = self.config - # Zero-out outputs starting from initial time_step, and test that we can recover the full - # outputs by calling extend_step starting from time_step. - # [batch, tgt_len]. - time_step_mask = jnp.arange(tgt_len) < time_step[:, None] - # [batch, tgt_len, model_dim]. - decoder_output = initial_output.data * time_step_mask[..., None] - # [batch, tgt_len, model_dim] --> [batch, model_dim, tgt_len]. - decoder_output = jnp.moveaxis(decoder_output, -2, -1) - - # [batch, num_heads, tgt_len, src_len]. - if initial_output.probs is None: - decoder_probs = None + if source is None: + kv_kwargs = {} + elif isinstance(source, KVState): + kv_kwargs = {"kv_state": source} + elif isinstance(source, Tensor): + kv_kwargs = {"key": source, "value": source} else: - decoder_probs = initial_output.probs * time_step_mask[:, None, :, None] - # [batch, num_heads, tgt_len, src_len] --> [batch, num_heads, src_len, tgt_len]. - decoder_probs = jnp.moveaxis(decoder_probs, -2, -1) - - # Call extend_step from time_step, ensuring that outputs match. - inputs = dict(cached_states=initial_states, return_aux=return_aux) - while jnp.any(time_step < tgt_len): - # [batch, tgt_len=1, model_dim]. - inputs["query"] = jnp.take_along_axis( - query, time_step[:, None, None], axis=1, mode="clip" - ) - if key is not None: - inputs["key"] = jnp.take_along_axis( - key, time_step[:, None, None], axis=1, mode="clip" + raise NotImplementedError(source) + kv_kwargs["return_aux"] = return_aux + + def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: + if mode == ForwardMode.FORWARD: + atten_state, atten_output = ( + None, + self.attention( + query=target, + **kv_kwargs, + attention_logit_biases=attention_logit_biases, + segment_ids=segment_ids, + query_positions=target_positions, + ), ) - inputs["value"] = jnp.take_along_axis( - value, time_step[:, None, None], axis=1, mode="clip" + elif mode == ForwardMode.INIT_STATES: + assert cached_states is not None + assert segment_ids is None + assert target_positions is None + atten_state, atten_output = self.attention.init_states( + time_step=cached_states["attention"], + query=target, + **kv_kwargs, + attention_logit_biases=attention_logit_biases, ) - # [batch=1, tgt_len=1, tgt_len]. - inputs["attention_logit_biases"] = jnp.take_along_axis( - attention_logit_biases[None, :, :], time_step[:, None, None], axis=1, mode="clip" - ) - (updated_state, outputs), _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="extend_step", + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + assert segment_ids is None + assert target_positions is None + atten_state, atten_output = self.attention.extend_step( + cached_states["attention"], + target, + **kv_kwargs, + attention_logit_biases=attention_logit_biases, + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") + return atten_state, atten_output + + if mode == ForwardMode.INIT_STATES: + assert cached_states is not None + if cached_states["attention"] is None: + atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype)) + return dict(attention=atten_state), atten_output + + if cfg.structure == "prenorm": + skip_input = target # pre-norm: where normalization happens within the residual part. 
+ norm_target = self.norm(target) + atten_state, atten_output = attention_thunk(norm_target) + data = skip_input + self.stochastic_depth(self.dropout(atten_output.data)) + elif cfg.structure == "postnorm": + # This is the structure used by the original Transformer, BERT, and RoBERTa. + atten_state, atten_output = attention_thunk(target) + # Post-norm: norm applied on the sum of input and attention output. + data = self.norm(target + self.stochastic_depth(self.dropout(atten_output.data))) + elif cfg.structure == "hybridnorm": + skip_input = target # pre-norm: where normalization happens within the residual part. + norm_target = self.prenorm(target) + atten_state, atten_output = attention_thunk(norm_target) + data = skip_input + self.stochastic_depth( + self.dropout(self.postnorm(atten_output.data)) ) - inputs["cached_states"] = updated_state - - # [batch, model_dim, tgt_len=1] - curr_outputs = jnp.moveaxis(outputs.data, -2, -1) - # [batch, num_heads, src_len, tgt_len=1] - curr_probs = jnp.moveaxis(outputs.probs, -2, -1) - - # [batch, 1, tgt_len]. - oh_indices = jax.nn.one_hot(time_step, tgt_len)[:, None, :] - decoder_output = decoder_output + curr_outputs * oh_indices - # [batch, 1, 1, tgt_len]. - oh_indices = oh_indices[..., None, :] - decoder_probs = decoder_probs + curr_probs * oh_indices - time_step = time_step + 1 - - # [batch, model_dim, tgt_len] --> [batch, tgt_len, model_dim]. - decoder_output = jnp.moveaxis(decoder_output, -1, -2) - # [batch, num_heads, src_len, tgt_len] --> [batch, num_heads, tgt_len, src_len]. - decoder_probs = jnp.moveaxis(decoder_probs, -1, -2) - - assert_allclose(decoder_output, forward_outputs.data) - assert_allclose(decoder_probs, forward_outputs.probs) - - @parameterized.product( - dtype=(jnp.float32, jnp.float16, jnp.bfloat16), - per_dim_scale=(None, PerDimScale.default_config()), - atten_logit_cap=(0.0, 20.0), - bias=(True, False), - input_linear=(attention.QKVLinear, attention.RoFormerQKVLinear), - ) - def test_prefill_states( - self, - dtype: jnp.dtype, - per_dim_scale: Optional[PerDimScale.Config], - atten_logit_cap: float, - bias: bool, - input_linear: attention.BaseQKVLinear, - ): - model_dim = 16 - num_heads = 4 - if input_linear == attention.RoFormerQKVLinear: - input_linear = input_linear.default_config().set(rotary_value=False) else: - input_linear = input_linear.default_config() - cfg = attention.MultiheadAttention.default_config().set( - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - atten_logit_cap=atten_logit_cap, - input_linear=input_linear, - ) - self._test_prefill_states( - cfg, model_dim=model_dim, num_heads=num_heads, dtype=dtype, bias=bias + raise NotImplementedError(cfg.structure) + return dict(attention=atten_state), self.Output( + data=data, probs=atten_output.probs, kv_state=atten_output.kv_state ) - @parameterized.product( - dtype=(jnp.float32, jnp.float16, jnp.bfloat16), - per_dim_scale=(None, PerDimScale.default_config()), - atten_logit_cap=(0.0, 20.0), - num_kv_heads=(1, 2, 4), - input_linear=(attention.GroupedQKVLinear, attention.FusedGroupedQKVLinear), - bias=(True, False), - ) - def test_gqa_prefill_states( + def forward( self, - dtype: jnp.dtype, - per_dim_scale: Optional[PerDimScale.Config], - atten_logit_cap: float, - num_kv_heads: int, - input_linear: type[attention.BaseQKVLinear], - bias: bool, - ): - model_dim = 16 - num_heads = 4 - cfg = attention.GroupedQueryAttention.default_config().set( - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - 
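A configuration sketch for the three residual structures handled above; only `structure` changes between the variants, the dimensions are illustrative, and the attention sub-config still needs `num_heads`.

from axlearn.common import attention

cfg = attention.TransformerAttentionLayer.default_config().set(
    name="self_atten_block",
    target_dim=16,
    source_dim=16,
    structure="prenorm",  # Or "postnorm" / "hybridnorm".
)
cfg.attention.num_heads = 4
layer = cfg.instantiate(parent=None)
# With stochastic depth omitted:
# prenorm:    x + dropout(attention(norm(x)))
# postnorm:   norm(x + dropout(attention(x)))
# hybridnorm: x + dropout(postnorm(attention(prenorm(x))))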
atten_logit_cap=atten_logit_cap, - input_linear=input_linear.default_config().set(num_kv_heads=num_kv_heads), - ) - self._test_prefill_states( - cfg, - model_dim=model_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - dtype=dtype, - bias=bias, + *, + target: Tensor, + source: Optional[Union[Tensor, KVState]] = None, + attention_logit_biases: Optional[Tensor] = None, + segment_ids: Optional[Tensor] = None, + target_positions: Optional[Tensor] = None, + return_aux: Optional[set[str]] = None, + ) -> Output: + """Computes attention with target as query and source as key and value. + + Args: + target: A Tensor of shape [batch, target_length, target_dim]. + source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. + If None, uses norm(target) as source (self-attention) + attention_logit_biases: See ``On attention logit biases`` in the file comments. + segment_ids: See ``segment_ids`` in the file comments. + target_positions: See ``positions`` in the file comments. + return_aux: See comments on `Output`. + + Returns: + An Output instance, where .data is of the same shape as target and .probs is of shape + [batch, num_heads, target_length, source_length]. + + Raises: + NotImplementedError: If cfg.structure is unsupported. + """ + _, output = self._forward_for_mode( + mode=ForwardMode.FORWARD, + target=target, + source=source, + attention_logit_biases=attention_logit_biases, + segment_ids=segment_ids, + target_positions=target_positions, + cached_states=None, + return_aux=return_aux, ) + return output - def test_gqa_against_mha(self): - model_dim = 16 - num_heads = 4 - num_kv_heads = 2 - ref_cfg = attention.MultiheadAttention.default_config().set( - name="mha", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = attention.GroupedQueryAttention.default_config().set( - name="gqa", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - input_linear=attention.GroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads), - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key, data_key = jax.random.split(prng_key, num=3) - state = ref_layer.initialize_parameters_recursively(init_key) - - batch, seq_len = 2, 10 - per_head_dim = ref_layer.per_head_dim() - q = jax.random.uniform(data_key, (batch, seq_len, num_heads, per_head_dim)) - k = jax.random.uniform(data_key, (batch, seq_len, num_kv_heads, per_head_dim)) - v = jax.random.uniform(data_key, (batch, seq_len, num_kv_heads, per_head_dim)) - attention_logit_biases = attention_logit_biases = attention_bias.ZeroAttentionBias() - - (test_context, ref_probs), _ = F( - test_layer, - method="_compute_attention", - state=state, - is_training=False, - prng_key=prng_key, - inputs=dict( - q_proj=q, k_proj=k, v_proj=v, attention_logit_biases=attention_logit_biases - ), + def init_states( + self, + *, + time_step: Optional[Tensor], + target: Union[Tensor, TensorSpec], + source: Optional[Union[Tensor, KVState]] = None, + attention_logit_biases: Optional[Tensor] = None, + return_aux: Optional[set[str]] = None, + ) -> tuple[Nested[Tensor], Optional[Output]]: + """Initializes cache for autoregressive cached decoding. + + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `target` is allowed to be a TensorSpec. 
+ * To prefill, provide `time_step` and `target` as Tensors. + + Args: + time_step: A Tensor of shape [batch]. Each value is an index into the length dimension + indicating where decoding will start from. + target: Tensor of shape [batch, target_length, target_dim] corresponding to query vector + at `time_step` indices. For batch index `i`, only `target[i, :time_step[i], ...]` + will affect subsequent decoding. + source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. + If None, uses norm(target) as source (self-attention) + attention_logit_biases: See ``On attention logit biases`` in the file comments. + return_aux: See comments on `Output`. + + Returns: + A tuple (init_states, output): + * init_states: A Nested Tensor state depending on the `attention` layer implementation. + * output: In the prefill case, an Output instance, where .data is of the same shape as + query, .probs is of shape [batch, num_heads, target_length, source_length]. + Otherwise, if initializing cache from scratch, output will be None. + """ + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, + target=target, + source=source, + cached_states=dict(attention=time_step), + attention_logit_biases=attention_logit_biases, + return_aux=return_aux, ) - k = jnp.repeat(k, num_heads // num_kv_heads, axis=2) - v = jnp.repeat(v, num_heads // num_kv_heads, axis=2) - - (ref_context, ref_probs), _ = F( - ref_layer, - method="_compute_attention", - state=state, - is_training=False, - prng_key=prng_key, - inputs=dict( - q_proj=q, k_proj=k, v_proj=v, attention_logit_biases=attention_logit_biases - ), + def extend_step( + self, + cached_states: NestedTensor, + target: Tensor, + *, + source: Optional[Union[Tensor, KVState]] = None, + attention_logit_biases: Optional[Tensor] = None, + return_aux: Optional[set[str]] = None, + ) -> tuple[Nested[Tensor], Output]: + """Computes the value vector given the query of the current step. + This function is used by autoregressive decoding. + + Args: + cached_states: A `NestedTensor` object containing tensors which are the + results of previous attentions, and index used for fast decoding. Contains + "attention" cached states. + target: Tensor of shape [B, 1, D] corresponding to query vector at index + time_step. + source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. + If None, uses norm(target) as source (self-attention) + attention_logit_biases: See ``On attention logit biases`` in the file comments. + Additionally, target_length is expected to be 1 since this is per time step. + attention_logit_biases should have already taken care of causal masking for + decoding, plus other maskings necessary. + return_aux: See comments on `Output`. + + Returns: + A `NestedTensor` state of key and value pair along with index updated at `time_step`. + An Output instance, where .data is of the same shape as query, .probs is of shape + [batch, num_heads, 1, source_length]. + + Raises: + NotImplementedError: If cfg.structure is not supported. 
+ """ + return self._forward_for_mode( # pytype: disable=bad-return-type + mode=ForwardMode.EXTEND_STEP, + target=target, + source=source, + cached_states=cached_states, + attention_logit_biases=attention_logit_biases, + return_aux=return_aux, ) - assert_allclose(ref_context, test_context) - assert_allclose(ref_probs, ref_probs) - def _scale_query_kwargs( - self, - *, - query_scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]], - key_scale_factor: Union[None, int, float, InstantiableConfig[attention.ScaleFn]], - ): - model_dim = 16 - if isinstance(query_scale_factor, (int, float)): - query_scale_factor = config_for_function(attention.constant_scale_fn).set( - value=query_scale_factor - ) - if isinstance(key_scale_factor, (int, float)): - key_scale_factor = config_for_function(attention.constant_scale_fn).set( - value=key_scale_factor - ) +def scaled_hidden_dim(scale: float = 4) -> FunctionConfigBase: + def scale_fn(input_dim: int, *, scale: float) -> int: + return round(input_dim * scale) - cfg = attention.MultiheadAttention.default_config().set( - name="test", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=2, - query_scale=attention.ScaleQuery.default_config().set(scale_factor=query_scale_factor), - key_scale=attention.ScaleKey.default_config().set(scale_factor=key_scale_factor), - ) - cfg.input_linear.layer.bias = False - cfg.output_linear.bias = False - layer = cfg.instantiate(parent=None) + return config_for_function(scale_fn).set(scale=scale) - param_specs = layer.create_parameter_specs_recursively() - layer_params = jax.tree.map( - lambda spec: jnp.ones(spec.shape, dtype=spec.dtype), param_specs - ) - batch_size = 3 - tgt_len = 10 # Must be even. - query = jnp.concatenate( - ( - jnp.ones([batch_size, tgt_len // 2, model_dim]), - jnp.zeros([batch_size, tgt_len // 2, model_dim]), - ), - axis=1, - ) - kwargs = dict( - module=layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict(query=query), - ) - return kwargs +class TransformerFeedForwardLayer(BaseLayer): + """A Transformer feed-forward layer.""" - @parameterized.product(query_scale_factor=[None, 7], key_scale_factor=[None, 11]) - def test_scale_query_key( - self, *, query_scale_factor: Optional[float], key_scale_factor: Optional[float] - ): - kwargs = self._scale_query_kwargs( - query_scale_factor=query_scale_factor, key_scale_factor=key_scale_factor - ) - kwargs["inputs"]["return_aux"] = {"probs"} - forward_outputs, _ = F(**kwargs) - if query_scale_factor is None: - query_scale_factor = kwargs["module"].per_head_dim() ** -0.5 - if key_scale_factor is None: - key_scale_factor = 1 - query_scale_factor = float(query_scale_factor) - key_scale_factor = float(key_scale_factor) - self.assertNestedAllClose( - forward_outputs.probs[0, 0, 0, 0], - # All ones matrix times all ones vector has l2 norm dim ** 1.5. - # Half of input tokens are all ones, half are all zeros. - jax.nn.sigmoid( - kwargs["inputs"]["query"].shape[-1] ** 3 * query_scale_factor * key_scale_factor, + @config_class + class Config(BaseLayer.Config): + """Configures TransformerFeedForwardLayer.""" + + input_dim: Required[int] = REQUIRED # Input feature dim. + # The hidden dim. + # It should be given either as an integer or a function config that instantiates + # a dim-to-dim function, e.g., scaled_hidden_dim(4). + hidden_dim: Required[Union[int, FunctionConfigBase]] = REQUIRED + # Config for the first linear layer. 
+ linear1: InstantiableConfig = Linear.default_config().set( + param_partition_spec=[None, "model"] + ) + # Config for the second linear layer. + linear2: InstantiableConfig = Linear.default_config().set( + param_partition_spec=["model", None] + ) + norm: InstantiableConfig = LayerNorm.default_config() # The normalization layer config. + + # The activation function(s). + # + # If a single string, the activation applied on the output of linear1. + # + # If a tuple of two strings, this layer will contain separate child Linear layers, one for + # each activation function, according to cfg.linear1 with `hidden_dim` as the output dim. + # The activation outputs will be multiplied element-wise to produce the inputs for linear2. + # See the implementation in _linear1_activation(). + # This supports the gated linear activations proposed by Shazeer in + # https://arxiv.org/abs/2002.05202. + activation: Union[str, tuple[str, str]] = "nn.relu" + + # The dropout layer config. + dropout: InstantiableConfig = Dropout.default_config() + + # The stochastic depth layer config. + # Pytorch reference: + # https://github.com/facebookresearch/deit/blob/main/models_v2.py#L59 + # Tensorflow reference: + # https://github.com/tensorflow/models/blob/master/official/projects/vit/modeling/nn_blocks.py#L103-L119 + stochastic_depth: InstantiableConfig = StochasticDepth.default_config() + + # The inner structure of the layer: "prenorm", "postnorm", "hybridnorm", "nonorm". + # * prenorm: y = x + feedforward(norm(x)) + # * postnorm: y = norm(x + feedforward(x)) + # * hybridnorm: y = postnorm(x + feedforward(prenorm(x))) + # * nonorm: y = feedforward(x) # no residual, which is usually applied externally. + # + # References: + # prenorm/postnorm: https://arxiv.org/abs/2002.04745. + # hybridnorm: https://github.com/google/praxis/blob/main/praxis/layers/transformers.py#L273 + # nonorm: see ParallelTransformerLayer. + structure: str = "prenorm" + + # outputs = inputs + residual_weight * x. + residual_weight: float = 1.0 + + # Auxiliary stats. + + # If True, add "dead_neurons/{activation}" stats for activation functions that have + # zones of near-zero gradients, e.g., x < 0 for ReLU. + # + # A "neuron" `i` is considered dead if all of x[..., i] (across batch/seq) fall within the + # dead zone. + # + # Only supported for a subset of activation functions, including relu, gelu, and silu. + add_dead_neuron_summary: Optional[bool] = None + + # Adds summary of RMS norms of the specified values. Supported value are: + # - "inputs": inputs of the layer. + # - "linear1_outputs": outputs of linear1. + # - "linear2_outputs": outputs of linear2. + # TODO(tlei3): deprecate this feature since we use TensorStats. 
+ add_value_rms_norm_summary: Sequence[str] = [] + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg: TransformerFeedForwardLayer.Config = self.config + if cfg.structure in ["prenorm", "postnorm"]: + self._add_child("norm", cfg.norm.set(input_dim=cfg.input_dim)) + elif cfg.structure == "hybridnorm": + self._add_child("prenorm", cfg.norm.set(input_dim=cfg.input_dim)) + self._add_child("postnorm", cfg.norm.set(input_dim=cfg.input_dim)) + elif cfg.structure == "nonorm": + pass + else: + raise NotImplementedError(cfg.structure) + + if isinstance(cfg.hidden_dim, int): + hidden_dim = cfg.hidden_dim + else: + hidden_dim = cfg.hidden_dim.set(input_dim=cfg.input_dim).instantiate() + if isinstance(cfg.activation, tuple): + assert len(cfg.activation) == 2, cfg.activation + # Create a linear1 projection for each activation. + for i in range(len(cfg.activation)): + self._add_child( + f"linear1_{i}", + cfg.linear1.set(input_dim=cfg.input_dim, output_dim=hidden_dim), + ) + else: + assert isinstance(cfg.activation, str), cfg.activation + self._add_child( + "linear1", + cfg.linear1.set(input_dim=cfg.input_dim, output_dim=hidden_dim), ) - / (kwargs["inputs"]["query"].shape[1] // 2), - ) + self._add_child( + "linear2", + cfg.linear2.set(input_dim=hidden_dim, output_dim=cfg.input_dim), + ) + if cfg.structure in ["prenorm", "hybridnorm", "nonorm"]: + self._add_child("dropout1", cfg.dropout) + self._add_child("dropout2", cfg.dropout) + elif cfg.structure in ["postnorm"]: + self._add_child("dropout", cfg.dropout) + else: + raise NotImplementedError(cfg.structure) - def test_scale_query_key_dim_dependence(self): - query_scale_factor = config_for_function(attention.pow_scale_fn).set(exp=1) - key_scale_factor = config_for_function(attention.pow_scale_fn).set(exp=-1) - kwargs = self._scale_query_kwargs( - query_scale_factor=query_scale_factor, key_scale_factor=key_scale_factor - ) - kwargs["inputs"]["return_aux"] = {"probs"} - forward_outputs, _ = F(**kwargs) - self.assertNestedAllClose( - forward_outputs.probs[0, 0, 0, 0], - # All ones matrix times all ones vector has l2 norm dim ** 1.5. - # Half of input tokens are all ones, half are all zeros. - jax.nn.sigmoid(float(kwargs["inputs"]["query"].shape[-1] ** 3)) - / (kwargs["inputs"]["query"].shape[1] // 2), - ) + self._add_child("stochastic_depth", cfg.stochastic_depth) + # TODO(tlei3): deprecate this check since we will use TensorStats to handle what + # tensors are logged. + for value in cfg.add_value_rms_norm_summary: + if value not in ["inputs", "linear1_outputs", "linear2_outputs"]: + raise NotImplementedError(f"add_value_rms_norm_summary: {value}") - def test_scale_query_key_barrier(self): - """Tests that the scale factors are not combined. + def forward(self, inputs: Tensor) -> Tensor: + cfg = self.config - Note that even without the barrier, it's not clear that they would be combined. - (They aren't on CPU even without the barrier.) 
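When `cfg.activation` is a tuple, the layer above builds one `linear1_{i}` projection per activation and multiplies the activated outputs element-wise before `linear2`, the gated-linear-unit family from https://arxiv.org/abs/2002.05202. A bias-free sketch in plain JAX; the ("nn.silu", "linear") pairing, the dimensions, and the 0.02 init scale are illustrative choices, not values taken from this patch.

    import jax
    import jax.numpy as jnp

    def gated_ffn(x, w1_0, w1_1, w2):
        gate = jax.nn.silu(x @ w1_0)   # activation[0] applied to the linear1_0 projection.
        value = x @ w1_1               # activation[1] == "linear": identity on linear1_1.
        return (gate * value) @ w2     # element-wise product feeds linear2.

    key = jax.random.PRNGKey(0)
    k0, k1, k2, kx = jax.random.split(key, 4)
    input_dim, hidden_dim = 8, 32      # e.g. hidden_dim = 4 * input_dim, as with scaled_hidden_dim(4).
    x = jax.random.normal(kx, (2, 5, input_dim))
    w1_0 = jax.random.normal(k0, (input_dim, hidden_dim)) * 0.02
    w1_1 = jax.random.normal(k1, (input_dim, hidden_dim)) * 0.02
    w2 = jax.random.normal(k2, (hidden_dim, input_dim)) * 0.02
    print(gated_ffn(x, w1_0, w1_1, w2).shape)  # (2, 5, 8)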
- """ - query_scale_factor = 7 - key_scale_factor = 11 - kwargs = self._scale_query_kwargs( - query_scale_factor=query_scale_factor, key_scale_factor=key_scale_factor - ) + def _linear2(x): + """Applies linear2, optionally logging RMS norm of the output.""" + x = self.linear2(x) + self._add_tensor_stats("linear2_outputs", x) + return x + + self._add_tensor_stats("inputs", inputs) + + remat_pt2 = "linear2" + if cfg.structure == "prenorm": + x = self.norm(inputs) + x = self._linear1_activation(x) + x = self.dropout1(x) + x = _linear2(x) + x = self._remat_name(x, remat_pt2) + x = self.dropout2(x) + x = self.stochastic_depth(x) + if cfg.residual_weight != 1: + x *= cfg.residual_weight + x += inputs + elif cfg.structure == "postnorm": + x = self._linear1_activation(inputs) + x = _linear2(x) + x = self._remat_name(x, remat_pt2) + x = self.dropout(x) + x = self.stochastic_depth(x) + if cfg.residual_weight != 1: + x *= cfg.residual_weight + x = self.norm(x + inputs) + elif cfg.structure == "hybridnorm": + x = self.prenorm(inputs) + x = self._linear1_activation(x) + x = self.dropout1(x) + x = _linear2(x) + x = self._remat_name(x, remat_pt2) + x = self.postnorm(x) + x = self.dropout2(x) + x = self.stochastic_depth(x) + if cfg.residual_weight != 1: + x *= cfg.residual_weight + x += inputs + elif cfg.structure == "nonorm": + x = inputs + x = self._linear1_activation(x) + x = self.dropout1(x) + x = _linear2(x) + x = self._remat_name(x, remat_pt2) + x = self.dropout2(x) + x = self.stochastic_depth(x) + # We still apply `residual_weight`, since there is usually a residual link outside of + # this layer, e.g., in ParallelTransformerLayer. + if cfg.residual_weight != 1: + x *= cfg.residual_weight + else: + raise NotImplementedError(cfg.structure) + return x - # Check optimized HLO scales by query_scale_factor and key_scale_factor as separate - # multiplications. This only checks the default backend, so it doesn't check - # what happens on gpu/tpu unless jax is configured to use them. 
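Among the branches above, "nonorm" is the easiest to misread: it applies neither normalization nor an internal residual, yet it still scales by `residual_weight`, because the residual link is expected to be added by the caller (as ParallelTransformerLayer does further below). A toy illustration, not the axlearn code:

    import jax.numpy as jnp

    def nonorm_ffn(inputs, f, residual_weight=1.0):
        # "nonorm": no norm, no internal residual; only the residual_weight scaling.
        return residual_weight * f(inputs)

    x = jnp.ones((2, 3, 4))
    # The caller adds the residual externally, e.g. ParallelTransformerLayer.
    y = x + nonorm_ffn(x, f=lambda t: 2.0 * t, residual_weight=0.5)
    print(y[0, 0])  # [2. 2. 2. 2.]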
- f = jax.jit(F, static_argnames=("module", "is_training")) - compile_options = dict( - xla_cpu_enable_fast_math=True, - xla_cpu_fast_math_honor_nans=False, - xla_cpu_fast_math_honor_infs=False, - xla_cpu_fast_math_honor_functions=False, - xla_cpu_fast_math_honor_division=False, - ) - hlo = f.lower(**kwargs).compile(compile_options).as_text() - hlo = test_utils.clean_hlo(hlo) - self.assertIn(str(query_scale_factor), hlo) - self.assertIn(str(key_scale_factor), hlo) - self.assertNotIn(str(query_scale_factor * key_scale_factor), hlo) - - @parameterized.parameters( - [ - ( - 1.0, - jax.nn.sigmoid((1.0 * 1.0) * 2 - jnp.log(6)), - 6, - ), - ( - 1.0, - jax.nn.sigmoid((1.0 * 1.0) * 2 - jnp.log(4)), - 4, - ), - ( - 2.0, - jax.nn.sigmoid((2.0 * 2.0) * 2 - jnp.log(6)), - 6, - ), - ] - ) - def test_sigmoid_compute_attention(self, qkv_value: float, expected_value: float, seq_len: int): - model_dim = 16 - num_heads = 4 - batch_size = 2 - init_key = jax.random.PRNGKey(123) - - cfg = attention.SigmoidAttention.default_config().set( - seq_len=seq_len, - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - query_scale=attention.ScaleQuery.default_config(), - atten_logit_cap=0.0, - dtype=jnp.float32, - ) - sigmoid_attention = cfg.set(name="sigmoid_attention").instantiate(parent=None) - state = sigmoid_attention.initialize_parameters_recursively(prng_key=init_key) - - qkv_shape = [batch_size, seq_len, num_heads, num_heads] - inputs = dict( - q_proj=jnp.full(qkv_shape, fill_value=qkv_value), - k_proj=jnp.full(qkv_shape, fill_value=qkv_value), - v_proj=jnp.full(qkv_shape, fill_value=qkv_value), - attention_logit_biases=attention_bias.CausalAttentionBias(shape=(seq_len, seq_len)), - ) + def _linear1_activation(self, x: Tensor) -> Tensor: + cfg = self.config + if isinstance(cfg.activation, tuple): + activations = [ + self._get_activation( + self._remat_name(self.children[f"linear1_{i}"](x), f"linear1_{i}"), + activation_fn_name=activation, + ) + for i, activation in enumerate(cfg.activation) + ] + assert len(activations) == 2, cfg.activation + outputs = activations[0] * activations[1] + self._add_tensor_stats("linear1_0_outputs", activations[0]) + self._add_tensor_stats("linear1_1_outputs", activations[1]) + self._add_tensor_stats("linear1_outputs", outputs) + return outputs + else: + x = self.linear1(x) + x = self._remat_name(x, "linear1_0") + x = self._get_activation(x, activation_fn_name=cfg.activation) + self._add_tensor_stats("linear1_outputs", x) + return x - # Get outputs. - forward_key = jax.random.PRNGKey(456) + def _get_activation(self, x: Tensor, activation_fn_name: str) -> Tensor: + """Applies activation function on 'x' and optionally counts the number of dead neurons. - (_, probs), _ = F( - sigmoid_attention, - method="_compute_attention", - state=state, - is_training=False, - prng_key=forward_key, - inputs=inputs, - ) + Args: + x: A tensor of shape [B, S, H]. + activation_fn_name: The name of the activation fn. - output_shape = [batch_size, num_heads, seq_len, seq_len] - indexes = jnp.arange(seq_len) - # Zeros outside of the causal triangle. - causal_biases = jax.lax.ge(indexes[:, None], indexes[None, :]) - expected_output = jnp.full(output_shape, fill_value=expected_value) * causal_biases + Returns: + activation_fn(x). + """ + cfg = self.config + if cfg.add_dead_neuron_summary: + if activation_fn_name in ["quick_gelu", "exact_gelu"]: + # To make GELU be sufficiently small. 
+ threshold = -4.0 + elif activation_fn_name in ["nn.silu", "nn.sigmoid"]: + # nn.silu(jnp.array(-10.)) = -0.00045398 + # nn.sigmoid(jnp.array(-10.)) = 4.5397872e-05 + threshold = -10.0 + elif activation_fn_name in ["nn.relu", "squared_relu"]: + threshold = 0 + else: + threshold = None + if threshold is not None: + max_hidden_units = jnp.max(x, axis=(0, 1)) + num_dead_units = jnp.count_nonzero( + jnp.less(max_hidden_units, threshold).astype(jnp.int32) + ) + self.add_summary( + f"dead_neurons/{activation_fn_name}", + num_dead_units, + ) + return get_activation_fn(activation_fn_name)(x) - self.assertNestedAllClose(probs, expected_output) +class TransformerLayer(BaseTransformerLayer): + """A Transformer layer. -def oracle_xl_attention_logits( - query: np.ndarray, - key: np.ndarray, - relative_pos_emb: np.ndarray, - content_bias: np.ndarray, - positional_bias: np.ndarray, -) -> np.ndarray: - """Computes expected attention logits using non-vectorized approach. + Unlike torch.nn.TransformerLayer, this allows components to be customized, e.g., replacing + vanilla attention with relative positional attention from TransformerXL/DeBERTa or replacing + feed-forward with a mixture-of-expert feed-forward layer. + """ - Reference: - https://github.com/tensorflow/lingvo/blob/41212226eac7a26491790c2bd476b78493f93ff6/lingvo/core/attention_util_test.py#L48-L73. + @config_class + class Config(BaseTransformerLayer.Config): + """Configures TransformerLayer.""" - Note that this implementation follows XLNet implementation and is different from the lingvo - implementation in that here the relative_pos_emb index is computed from key_i - query_i, - while lingvo computes from query_i - key_i. + self_attention: InstantiableConfig = TransformerAttentionLayer.default_config() + # If not None, the cross-attention layer config. + cross_attention: Optional[InstantiableConfig] = None + feed_forward: InstantiableConfig = TransformerFeedForwardLayer.default_config() - See comments on xl_attention_logits(). - """ - batch, seqlen, num_heads, _ = query.shape - tgtlen, srclen = seqlen, seqlen - - logits = np.zeros((batch, num_heads, tgtlen, srclen)) - - for b in range(batch): - for n in range(num_heads): - for i in range(tgtlen): - for j in range(srclen): - offset = seqlen - 1 - pos_emb = relative_pos_emb[j - i + offset] - logits[b][n][i][j] = np.dot(query[b][i][n], key[b][j][n]) - logits[b][n][i][j] += np.dot(query[b][i][n], pos_emb[n]) - logits[b][n][i][j] += np.dot(content_bias[n], key[b][j][n]) - logits[b][n][i][j] += np.dot(positional_bias[n], pos_emb[n]) - return logits - - -class TransformerXLTest(TestCase): - """Tests TransformerXL.""" - - @parameterized.parameters(5, 2, 1) - def test_rel_pos_to_abs_pos(self, seq_len): - # rel_offset[:, i] = i - (seq_len - 1), i.e., in range [-seq_len + 1, seq_len - 1]. - rel_offset = jnp.tile(jnp.arange(-seq_len + 1, seq_len)[None, :], [seq_len, 1]) - # abs_pos[i, j] = j - i. 
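The dead-neuron summary above boils down to: take each hidden unit's maximum pre-activation over batch and sequence, and count the units whose maximum never leaves the activation's flat region (threshold 0 for ReLU-like functions, roughly -4 for GELU, -10 for SiLU/sigmoid, as chosen above). A minimal stand-alone version:

    import jax.numpy as jnp

    def count_dead_neurons(pre_activation, threshold):
        # pre_activation: [batch, seq, hidden], i.e. the linear1 output before the activation.
        max_per_unit = jnp.max(pre_activation, axis=(0, 1))  # [hidden].
        return jnp.count_nonzero(max_per_unit < threshold)

    x = jnp.array([[[-1.0, 0.5, -2.0], [-3.0, 1.0, -0.1]]])  # [batch=1, seq=2, hidden=3].
    print(count_dead_neurons(x, threshold=0.0))  # 2: units 0 and 2 never exceed 0.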
- abs_pos = rel_pos_to_abs_pos(rel_offset) - expected = jnp.arange(seq_len)[None, :] - jnp.arange(seq_len)[:, None] - assert_allclose(abs_pos, expected) - - def test_xl_attention_logits(self): - num_heads, per_head_dim = 4, 3 - batch_size, tgt_len = 2, 5 - q = jax.random.normal( - jax.random.PRNGKey(100), - [batch_size, tgt_len, num_heads, per_head_dim], - dtype=jnp.float32, - ) - k = jax.random.normal( - jax.random.PRNGKey(101), - [batch_size, tgt_len, num_heads, per_head_dim], - dtype=jnp.float32, - ) - relative_pos_emb = jax.random.normal( - jax.random.PRNGKey(102), [2 * tgt_len - 1, num_heads, per_head_dim], dtype=jnp.float32 - ) - u = jax.random.normal(jax.random.PRNGKey(103), [num_heads, per_head_dim], dtype=jnp.float32) - v = jax.random.normal(jax.random.PRNGKey(104), [num_heads, per_head_dim], dtype=jnp.float32) - expected = oracle_xl_attention_logits( - query=q, key=k, relative_pos_emb=relative_pos_emb, content_bias=u, positional_bias=v - ) - actual = xl_attention_logits( - q_proj=q, k_proj=k, relative_pos_emb=relative_pos_emb, u=u, v=v - ) - assert_allclose(actual, expected) - - @parameterized.product( - per_dim_scale=(None, PerDimScale.default_config()), - scale_position=( - MultiheadAttentionXL.ScalePosition.LOGIT, - MultiheadAttentionXL.ScalePosition.QUERY, - ), - ) - def test_per_dim_scale(self, per_dim_scale, scale_position): - model_dim = 6 - num_heads = 2 - cfg = attention.TransformerAttentionLayer.default_config().set( - name="test", - target_dim=model_dim, - source_dim=model_dim, - structure="postnorm", - attention=MultiheadAttentionXL.default_config().set( - num_heads=num_heads, - query_scale=attention.ScaleQuery.default_config().set(per_dim_scale=per_dim_scale), - scale_position=scale_position, - ), - ) - cfg.attention.output_linear.bias = False - cfg.attention.vlog = 5 - - layer: attention.TransformerAttentionLayer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(jax.random.PRNGKey(123)) - batch_size, tgt_len = 2, 5 - target = jax.random.normal( - jax.random.PRNGKey(100), [batch_size, tgt_len, model_dim], dtype=jnp.float32 + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg: TransformerLayer.Config = self.config + self._add_child( + "self_attention", + cfg.self_attention.set(target_dim=cfg.input_dim, source_dim=cfg.input_dim), ) + self._add_child("feed_forward", cfg.feed_forward.set(input_dim=cfg.input_dim)) + if cfg.cross_attention is not None: + self._add_child("cross_attention", cfg.cross_attention.set(target_dim=cfg.input_dim)) - layer_params["attention"]["u_bias"] = jax.random.normal( - jax.random.PRNGKey(0), [num_heads, model_dim // num_heads] - ) - layer_params["attention"]["v_bias"] = jax.random.normal( - jax.random.PRNGKey(1), [num_heads, model_dim // num_heads] - ) - if per_dim_scale: - layer_params["attention"]["scale_query"]["per_dim_scale"]["param"] = jax.random.normal( - jax.random.PRNGKey(2), [model_dim // num_heads] + def _forward_for_mode( + self, + *, + mode: ForwardMode, + data: Union[Tensor, TensorSpec], + self_attention_kv_state: Optional[KVState] = None, + self_attention_logit_biases: Optional[Tensor] = None, + cross_attention_data: Optional[Tensor] = None, + cross_attention_logit_biases: Optional[Tensor] = None, + target_segment_ids: Optional[Tensor] = None, + target_positions: Optional[Tensor] = None, + cached_states: Optional[NestedTensor] = None, + return_aux: Optional[set[str]] = None, + ) -> tuple[Optional[NestedTensor], 
Optional[BaseTransformerLayer.Output]]: + """Computes transformer layer outputs and self/cross-attention probabilities. + + Args: + mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for + details. + data: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. + self_attention_kv_state: An optional KVState used for self-attention. + self_attention_logit_biases: An optional Tensor representing the self-attention biases. + cross_attention_data: An optional Tensor of shape [batch, source_length, source_dim]. + cross_attention_logit_biases: An optional Tensor representing the cross-attention + biases. + target_segment_ids: See ``segment_ids`` in the file comments. + target_positions: See ``positions`` in the file comments. + cached_states: Optional NestedTensor as produced by `init_states`. + return_aux: See comments on BaseTransformerLayer.forward. + + Returns: + A tuple (cached_states, output): + * cached_states: An optional Nested Tensor of cache states, depending on `mode`. + * output: An optional Output instance, where .data is of the same shape as `data`, + .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; + .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. + + Raises: + ValueError: If `mode` is unsupported. + """ + if isinstance(data, Tensor): + self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error + self_attention_return_aux = set() + cross_attention_return_aux = set() + if return_aux: + if "self_attention_probs" in return_aux: + self_attention_return_aux.add("probs") + if "self_attention_kv_state" in return_aux: + self_attention_return_aux.add("kv_state") + if "cross_attention_probs" in return_aux: + cross_attention_return_aux.add("probs") + if mode == ForwardMode.FORWARD: + self_atten_state, self_atten_outputs = ( + None, + self.self_attention( + target=data, + segment_ids=target_segment_ids, + target_positions=target_positions, + source=self_attention_kv_state, + attention_logit_biases=self_attention_logit_biases, + return_aux=self_attention_return_aux, + ), ) - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=dict(target=target), - ) - expected_vals = { - str(None): { - MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.683887, - MultiheadAttentionXL.ScalePosition.QUERY.value: 48.598305, - }, - str(PerDimScale.default_config()): { - MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.790010, - MultiheadAttentionXL.ScalePosition.QUERY.value: 48.858986, - }, - } - assert_allclose( - expected_vals[str(per_dim_scale)][scale_position.value], - jnp.abs(layer_outputs.data).sum(), + elif mode == ForwardMode.INIT_STATES: + assert cached_states is not None + if target_segment_ids is not None: + raise NotImplementedError("target_segment_ids is not supported in INIT_STATES.") + if target_positions is not None: + raise NotImplementedError("target_positions is not supported in INIT_STATES.") + self_atten_state, self_atten_outputs = self.self_attention.init_states( + time_step=cached_states["self_attention"], + target=data, + source=self_attention_kv_state, + attention_logit_biases=self_attention_logit_biases, + return_aux=self_attention_return_aux, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + if target_segment_ids is not None: + raise NotImplementedError("target_segment_ids is not supported in 
EXTEND_STEP.") + if target_positions is not None: + raise NotImplementedError("target_positions is not supported in EXTEND_STEP.") + self_atten_state, self_atten_outputs = self.self_attention.extend_step( + cached_states=cached_states["self_attention"], + target=data, + source=self_attention_kv_state, + attention_logit_biases=self_attention_logit_biases, + return_aux=self_attention_return_aux, + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") + + if self_atten_outputs is None: + assert mode == ForwardMode.INIT_STATES + return dict(self_attention=self_atten_state), self_atten_outputs + + data = self_atten_outputs.data + self.vlog(3, "self_attention.output=%s", data.sum()) + if cross_attention_data is not None: + cross_atten_outputs = self.cross_attention( + target=data, + source=cross_attention_data, + attention_logit_biases=cross_attention_logit_biases, + return_aux=cross_attention_return_aux, + ) + data = cross_atten_outputs.data + cross_attention_probs = cross_atten_outputs.probs + else: + cross_attention_probs = None + data = self.feed_forward(data) + self.vlog(3, "transformer.output=%s", data.sum()) + # TODO(markblee): Support module outputs in decoding. + if mode == ForwardMode.FORWARD: + self.add_module_output("output", data) + return dict(self_attention=self_atten_state), BaseTransformerLayer.Output( + data=data, + self_attention_probs=self_atten_outputs.probs, + self_attention_kv_state=self_atten_outputs.kv_state, + cross_attention_probs=cross_attention_probs, ) - def test_multihead_attention_xl(self): - model_dim = 6 - num_heads = 2 - per_head_dim = model_dim // num_heads - cfg = attention.TransformerAttentionLayer.default_config().set( - name="test", - target_dim=model_dim, - source_dim=model_dim, - structure="postnorm", - attention=MultiheadAttentionXL.default_config().set(num_heads=num_heads), - ) - cfg.attention.output_linear.bias = False - cfg.attention.vlog = 5 - layer: attention.TransformerAttentionLayer = cfg.instantiate(parent=None) - layer.initialize_parameters_recursively(jax.random.PRNGKey(123)) - ref_cfg = hf_xlnet.XLNetConfig( - n_head=num_heads, - d_model=model_dim, - d_head=model_dim // num_heads, - dropout=0, - layer_norm_eps=cfg.norm.eps, - ) - ref = hf_xlnet.XLNetRelativeAttention(ref_cfg) - # XLNetRelativeAttention is not properly initialized. - with torch.no_grad(): - for var in ("q", "k", "v", "o", "r"): - getattr(ref, var).copy_( - torch.normal(0, np.sqrt(model_dim), [model_dim, num_heads, per_head_dim]) - ) - for var in ("r_w_bias", "r_r_bias"): - getattr(ref, var).copy_( - torch.normal(0, np.sqrt(model_dim), [num_heads, model_dim // num_heads]) - ) - batch_size, tgt_len = 2, 5 - target = jax.random.normal( - jax.random.PRNGKey(100), [batch_size, tgt_len, model_dim], dtype=jnp.float32 - ) - num_tokens = jax.random.randint( - jax.random.PRNGKey(101), - minval=2, - maxval=tgt_len + 1, - shape=[batch_size], - ) - # [batch_size, tgt_len]. - is_valid_token = jnp.arange(tgt_len)[None, :] < num_tokens[:, None] - # [batch_size, 1, tgt_len, tgt_len]. - attention_logit_biases = jnp.expand_dims( - NEG_INF * (1 - jnp.einsum("bt,bs->bts", is_valid_token, is_valid_token)), 1 - ) - # [2 * tgt_len, model_dim]. - rel_pos_emb = sinusoidal_positional_embeddings( - jnp.arange(tgt_len, -tgt_len, -1), dim=model_dim - ) - ref_inputs = dict( - g=None, - h=target.transpose([1, 0, 2]), # [qlen, bsz, d_model]. - r=rel_pos_emb[:, None, :], # [rlen, 1, d_model]. - attn_mask_g=None, - # [qlen, klen, bsz, n_head]. 
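Stripped of caching, aux routing, and module outputs, the FORWARD path of TransformerLayer above is a straight pipeline: self-attention, then optional cross-attention, then the feed-forward block, with residuals and norms living inside each sub-layer. A functional sketch with stand-in sub-layers (none of these lambdas are the real layers):

    import jax.numpy as jnp

    def transformer_layer(data, self_atten, feed_forward, cross_atten=None, cross_data=None):
        data = self_atten(data)
        if cross_data is not None:
            data = cross_atten(data, cross_data)
        return feed_forward(data)

    x = jnp.ones((2, 5, 8))
    src = jnp.ones((2, 7, 8))
    y = transformer_layer(
        x,
        self_atten=lambda t: t + 0.1,           # stand-in for self_attention(...).data
        feed_forward=lambda t: t * 2.0,         # stand-in for feed_forward(...)
        cross_atten=lambda t, s: t + s.mean(),  # stand-in for cross_attention(...).data
        cross_data=src,
    )
    print(y.shape)  # (2, 5, 8)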
- attn_mask_h=attention_logit_biases.transpose([2, 3, 0, 1]) < 0, - seg_mat=None, - ) - logging.info("ref_inputs=%s", ref_inputs) - - test_outputs, ref_outputs = self._compute_layer_outputs( - test_layer=layer, - ref_layer=ref, - test_inputs=dict(target=target, attention_logit_biases=attention_logit_biases), - ref_inputs=as_torch_tensor(ref_inputs), - parameters_from_ref_layer=parameters_from_torch_layer, - require_same_num_params=False, + def forward( + self, + data: Tensor, + **kwargs, + ) -> BaseTransformerLayer.Output: + _, output = self._forward_for_mode( + mode=ForwardMode.FORWARD, data=data, cached_states=None, **kwargs ) - logging.info("test_outputs=%s", test_outputs) - logging.info("ref_outputs=%s", ref_outputs) - self.assertNestedAllClose( - test_outputs.data, as_tensor(ref_outputs[0]).transpose([1, 0, 2]), atol=6e-6 + return output + + def init_states( + self, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], + **kwargs, + ) -> tuple[Nested[Tensor], Optional[BaseTransformerLayer.Output]]: + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, + cached_states=dict(self_attention=time_step), + data=data, + **kwargs, + ) + + def extend_step( + self, + cached_states: NestedTensor, + data: Tensor, + **kwargs, + ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: + return self._forward_for_mode( # pytype:disable=bad-return-type + mode=ForwardMode.EXTEND_STEP, + cached_states=cached_states, + data=data, + **kwargs, ) -class TransformerAttentionLayerTest(TestCase): - @parameterized.parameters([False, True]) - def test_forward_vs_extend_step(self, with_source: bool): - init_prng, target_prng, source_prng = jax.random.split(jax.random.PRNGKey(0), 3) +class ParallelTransformerLayer(BaseTransformerLayer): + """A Transformer layer with parallel self-attention and feed-forward layers: - model_dim = 8 - layer_kwargs = dict(target_dim=model_dim, source_dim=model_dim) - cfg: TransformerAttentionLayer.Config = TransformerAttentionLayer.default_config().set( - **layer_kwargs - ) - cfg.attention.set(num_heads=2, mask=causal_mask) - layer: TransformerAttentionLayer = cfg.set(name="test").instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=init_prng) + x = norm(inputs) + outputs = inputs + self_atten(x) + ffn(x) - batch, decode_len = 2, 6 - target = jax.random.uniform(target_prng, shape=[batch, decode_len, model_dim]) - input_kwargs = {} + TODO(rpang): experiment to understand whether we should use separate normalization layers + for self_atten and ffn as in PaLM. - if with_source: - input_kwargs.update( - source=jax.random.uniform(source_prng, shape=[batch, decode_len, model_dim]) - ) + References: + https://github.com/kingoflolz/mesh-transformer-jax + PaLM: https://arxiv.org/abs/2204.02311 + """ - forward_outputs, _ = F( - layer, - inputs=dict(target=jnp.asarray(target), **input_kwargs), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), + @config_class + class Config(BaseTransformerLayer.Config): + norm: InstantiableConfig = LayerNorm.default_config() # The normalization layer config. 
+ self_attention: MultiheadAttention.Config = MultiheadAttention.default_config() + feed_forward: TransformerFeedForwardLayer.Config = ( + TransformerFeedForwardLayer.default_config().set(structure="nonorm") ) - for start_time_step in (-1, 0, 2, decode_len): - if start_time_step < 0: - (cached_states, init_outputs), _ = F( - layer, - inputs=dict( - time_step=None, - target=TensorSpec(target.shape, target.dtype), - **input_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - method="init_states", - ) - self.assertIsNone(init_outputs) - data = jnp.zeros([batch, decode_len, model_dim]) - start_time_step = 0 - else: - (cached_states, prefill_outputs), _ = F( - layer, - inputs=dict( - time_step=jnp.array([start_time_step] * batch, dtype=jnp.int32), - target=target, - **input_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - method="init_states", - ) - data = prefill_outputs.data + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg: TransformerLayer.Config = self.config + self._add_child("norm", cfg.norm.set(input_dim=cfg.input_dim)) + self._add_child( + "self_attention", + cfg.self_attention.set( + query_dim=cfg.input_dim, + key_dim=cfg.input_dim, + value_dim=cfg.input_dim, + output_dim=cfg.input_dim, + ), + ) + self._add_child("feed_forward", cfg.feed_forward.set(input_dim=cfg.input_dim)) - data = jnp.einsum("btd->tbd", data) + def forward( + self, + *, + data: Tensor, + self_attention_logit_biases: Optional[Tensor] = None, + target_segment_ids: Optional[Tensor] = None, + ) -> BaseTransformerLayer.Output: + """Computes transformer layer outputs and self/cross-attention probabilities. - for time_step in range(start_time_step, decode_len): - extend_kwargs = {} - for k, v in input_kwargs.items(): - extend_kwargs[k] = jnp.asarray(v[:, time_step : time_step + 1, :]) + Args: + data: A Tensor of shape [batch, target_length, target_dim]. + self_attention_logit_biases: An optional Tensor representing the self-attention biases. + target_segment_ids: See ``segment_ids`` in the file comments. - (cached_states, extend_outputs), _ = F( - layer, - inputs=dict( - target=jnp.asarray(target[:, time_step : time_step + 1, :]), - cached_states=cached_states, - **extend_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - method="extend_step", - ) - data = data.at[time_step].set(jnp.squeeze(extend_outputs.data, axis=1)) + Returns: + An Output instance, where .data is of the same shape as `data`, .self_attention_probs is + of shape [batch, num_heads, target_length, target_length]. - data = jnp.einsum("tbd->btd", data) + Raises: + ValueError: If `mode` is unsupported. + """ + inputs = data + data = self.norm(data) + self_atten_outputs = self.self_attention( + query=data, + key=data, + value=data, + attention_logit_biases=self_attention_logit_biases, + segment_ids=target_segment_ids, + ) + feed_forward_outputs = self.feed_forward(data) + outputs = inputs + self_atten_outputs.data + feed_forward_outputs + return BaseTransformerLayer.Output( + data=outputs, + self_attention_probs=self_atten_outputs.probs, + self_attention_kv_state=self_atten_outputs.kv_state, + cross_attention_probs=None, + ) - # Prefill + extend_step == forward. 
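ParallelTransformerLayer above normalizes once and feeds the same normalized input to both attention and the (nonorm) feed-forward, with a single residual add, in contrast to the sequential pre-norm stack. A side-by-side sketch with stand-in sub-layers and an unlearned RMS norm (not the axlearn layers):

    import jax
    import jax.numpy as jnp

    def parallel_block(inputs, self_atten, ffn, norm):
        # mesh-transformer-jax / PaLM style: shared norm, one residual add.
        x = norm(inputs)
        return inputs + self_atten(x) + ffn(x)

    def sequential_block(inputs, self_atten, ffn, norm):
        # Conventional pre-norm stack, for contrast.
        x = inputs + self_atten(norm(inputs))
        return x + ffn(norm(x))

    rms = lambda t: t * jax.lax.rsqrt(jnp.mean(t * t, axis=-1, keepdims=True) + 1e-6)
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))
    a = parallel_block(x, self_atten=lambda t: 0.1 * t, ffn=lambda t: 0.2 * t, norm=rms)
    b = sequential_block(x, self_atten=lambda t: 0.1 * t, ffn=lambda t: 0.2 * t, norm=rms)
    print(a.shape, b.shape, bool(jnp.allclose(a, b)))  # same shapes, generally different values.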
- assert_allclose(forward_outputs.data, data) +def _next_power_of_two(n: float) -> int: + if n <= 1: + return 2 + return 1 << int(math.log2(n - 1)) + 1 -class TransformerFeedForwardLayerTest(TestCase): - @parameterized.parameters( - dict(rms_norm_summary=[]), - dict(rms_norm_summary=["linear2_outputs"]), - dict(rms_norm_summary=["final_outputs"], expected_raise_regex="add_value_rms_norm_summary"), - ) - def test_add_value_rms_norm_summary( - self, rms_norm_summary: list[str], *, expected_raise_regex=None - ): - batch, seq_len, dim = 2, 3, 4 - cfg = TransformerFeedForwardLayer.default_config().set( - name="ffn", - input_dim=dim, - hidden_dim=dim * 4, - add_value_rms_norm_summary=rms_norm_summary, - tensor_stats=DefaultTensorStats.default_config(), - ) - if expected_raise_regex is not None: - with self.assertRaisesRegex(NotImplementedError, expected_raise_regex): - layer = cfg.instantiate(parent=None) - return - layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim]) - y, output_collection = F( - layer, - inputs=dict(inputs=x), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - self.assertSequenceEqual(x.shape, y.shape) - self.assertNestedAllClose(2.663487, jnp.sum(y)) - if "tensor_stats" in output_collection.summaries: - output_stats = output_collection.summaries["tensor_stats"] - else: - output_stats = {} - for k in rms_norm_summary: - assert k in output_stats - - @parameterized.parameters( - dict(activation_fn="nn.relu"), - dict(activation_fn=("nn.relu", "linear")), - dict(activation_fn=("linear", "quick_gelu")), - dict(activation_fn=("linear", "exact_gelu")), - dict(activation_fn=("linear", "nn.silu")), - ) - def test_add_dead_neuron_summary(self, activation_fn: Union[str, list[str]]): - batch, seq_len, dim = 2, 3, 4 - cfg = TransformerFeedForwardLayer.default_config().set( - name="ffn", - input_dim=dim, - hidden_dim=dim * 4, - activation=activation_fn, - add_dead_neuron_summary=True, - ) - layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim]) - y, output_collection = F( - layer, - inputs=dict(inputs=x), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - self.assertSequenceEqual(x.shape, y.shape) - if isinstance(activation_fn, str): - activation_fn = [activation_fn] - self.assertSetEqual( - {k for k in output_collection.summaries.keys() if k.startswith("dead_neurons/")}, - { - f"dead_neurons/{k}" - for k in activation_fn - if k in ("nn.relu", "quick_gelu", "exact_gelu", "nn.silu") - }, - ) - def test_linear_remat(self): - batch, seq_len, dim = 2, 3, 4 - cfg = TransformerFeedForwardLayer.default_config().set( - name="ffn", - input_dim=dim, - hidden_dim=dim * 4, - add_value_rms_norm_summary=[], - tensor_stats=DefaultTensorStats.default_config(), - activation=("nn.relu", "nn.relu"), - ) - layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim]) - - def f(x, layer_params): - y, _ = F( - layer, - inputs=dict(inputs=x), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - return y - - _, save_name_backward = jax.linearize( - jax.remat( - f, - 
policy=save_and_offload_only_these_names_regex( - names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value, - names_which_can_be_offloaded=None, - offload_src="device", - offload_dst="pinned_host", - ), - ), - x, - layer_params, - ) - _, save_dots_backward = jax.linearize( - jax.remat(f, policy=jax_remat_policies.dots_saveable), x, layer_params - ) +class BottleNeckAdapterTransformerLayer(BaseTransformerLayer): + """TransformerLayer with bottleneck adaptor for fine-tuning. + Figure 3(a) in https://arxiv.org/pdf/2110.04366.pdf + """ - self.assertEqual(str(save_name_backward).count(" dot_general"), 6) - self.assertEqual( - str(save_name_backward).count(" dot_general"), - str(save_dots_backward).count(" dot_general"), - ) + @config_class + class Config(BaseTransformerLayer.Config): + """Configures BottleNeckAdapterTransformerLayer.""" + # The transformer layer to which an adapter will be added. + layer: BaseTransformerLayer.Config = TransformerLayer.default_config() -class BaseTransformerTest(TestCase): - def _test_decoder_with_transformer(self, transformer_cfg: BaseTransformerLayer.Config): - prefix_length = jnp.asarray([0, 2]) - batch_size, num_decodes, seq_len, vocab_size = prefix_length.shape[0], 3, 7, 6 - bos_id = eos_id = 1 - pad_token_id = 0 + # The adapter, which in this case is a bottleneck layer composed of + # a downward and an upward projection. + adapter: TransformerFeedForwardLayer.Config = TransformerFeedForwardLayer.default_config() - cfg = Decoder.default_config().set( - transformer=transformer_cfg.clone(name="transformer"), - dim=transformer_cfg.input_dim, - vocab_size=vocab_size, - emb=TransformerTextEmbeddings.default_config().set( - pos_emb=LearnedPositionalEmbedding.default_config().set(shape=(seq_len,)) + # The ratio by which the input dimension will be + # reduced in the downward projection in the adapter. + bottleneck_ratio: float = 0.5 + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + self._add_child("layer", cfg.layer) + self._add_child( + "adapter", + cfg.adapter.set( + input_dim=cfg.layer.input_dim, + hidden_dim=_next_power_of_two(cfg.layer.input_dim * cfg.bottleneck_ratio), + structure="postnorm", ), - # output_norm=LayerNorm.default_config().set(eps=layer_norm_epsilon), - # dropout_rate=dropout_rate, - pad_token_id=pad_token_id, - eos_token_id=eos_id, ) - decoder: Decoder = cfg.set(name="decoder").instantiate(parent=None) - decoder_state = decoder.initialize_parameters_recursively(jax.random.PRNGKey(0)) + def _forward_for_mode( + self, + *, + mode: ForwardMode, + data: Union[Tensor, TensorSpec], + cached_states: Optional[NestedTensor] = None, + **kwargs, + ) -> tuple[Optional[Nested[Tensor]], Optional[Tensor]]: + """Computes transformer layer outputs and self/cross-attention probabilities. + + Args: + mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for + details. + data: A Tensor of shape [batch, target_length, target_dim]. + cached_states: Optional NestedTensor as produced by `init_states`. + + Returns: + A tuple (cached_states, output): + * cached_states: An optional NestedTensor of cache states, depending on `mode`. + * output: An Output instance, where .data is of the same shape as `data`; + .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; + .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. 
+ + Raises: + ValueError: If `mode` is unsupported. + """ + if isinstance(data, Tensor): + self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error + if mode == ForwardMode.FORWARD: + output = self.layer.forward(data=data, **kwargs) + elif mode == ForwardMode.INIT_STATES: + assert cached_states is not None + cached_states, output = self.layer.init_states( + time_step=cached_states["layer"], + data=data, + **kwargs, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + cached_states, output = self.layer.extend_step( + cached_states=cached_states, + data=data, + **kwargs, + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") + + if output is None: + assert mode == ForwardMode.INIT_STATES and cached_states["layer"] is None + return cached_states, output - prefix = jax.random.randint( - jax.random.PRNGKey(124), - shape=[batch_size, seq_len], - # Prefix can consist of any tokens, including pad and eos. - minval=0, - maxval=vocab_size, - ) - # Explicitly fill positions >= prefix_length with pad_token_id. - # Note that each batch example may have a different prefix length. - # [batch_size, seq_len]. - prefix_mask = jnp.arange(seq_len) < prefix_length[:, None] - prefix = prefix * prefix_mask + pad_token_id * (1 - prefix_mask) - # Set last token to a non-pad token, to fix the prefix length. - oh_indices = jax.nn.one_hot(prefix_length - 1, seq_len, dtype=prefix.dtype) - prefix = prefix * (1 - oh_indices) + bos_id * oh_indices - inputs = dict( - input_batch=dict(prefix=prefix), - max_sequence_length=seq_len, - # cross_attention_data=None, - # cross_attention_logit_biases=None, - num_decodes=num_decodes, - ) - outputs, _ = F( - decoder, - inputs=inputs, - state=decoder_state, - is_training=False, - prng_key=jax.random.PRNGKey(2), - method="sample_decode", + skip_input = output.data + data = self.adapter(output.data) + data += skip_input + self.vlog(3, "adapted_transformer.output=%s", data.sum()) + return cached_states, output._replace(data=data) + + def forward( + self, + data: Tensor, + **kwargs, + ) -> BaseTransformerLayer.Output: + _, output = self._forward_for_mode( + mode=ForwardMode.FORWARD, + data=data, + cached_states=None, + **kwargs, ) - sequences = outputs.sequences - self.assertEqual(sequences.shape, (batch_size, num_decodes, seq_len)) + return output - def _test_forward_vs_extend_step( + def init_states( self, - cfg: BaseTransformerLayer.Config, *, - input_kwargs: Optional[dict[str, Any]] = None, - ): - """Tests that {init,prefill}_states + extend_step is equivalent to forward for `cfg`.""" - if input_kwargs is None: - input_kwargs = {} - layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - - batch_size, tgt_len = 2, 5 - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) - - forward_outputs, _ = F( - layer, - inputs=dict( - data=jnp.asarray(target), - **input_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], + **kwargs, + ) -> tuple[Nested[Tensor], Optional[BaseTransformerLayer.Output]]: + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, + cached_states=dict(layer=time_step), + data=data, + **kwargs, + ) + + def extend_step( + self, + cached_states: NestedTensor, + data: Tensor, + **kwargs, + ) -> tuple[NestedTensor, 
BaseTransformerLayer.Output]: + return self._forward_for_mode( # pytype: disable=bad-return-type + mode=ForwardMode.EXTEND_STEP, + cached_states=cached_states, + data=data, + **kwargs, ) - for start_time_step in (-1, 0, 2, tgt_len): - if start_time_step > tgt_len: - continue - print(f"start_time_step={start_time_step} layer={type(layer)}") - if start_time_step < 0: - (cached_states, init_outputs), _ = F( - layer, - inputs=dict( - time_step=None, - data=TensorSpec([batch_size, tgt_len]), - **input_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - method="init_states", - ) - self.assertIsNone(init_outputs) - decoder_output = jnp.zeros_like(target) - start_time_step = 0 - else: - (cached_states, prefill_outputs), _ = F( - layer, - inputs=dict( - time_step=jnp.array([start_time_step] * batch_size, dtype=jnp.int32), - data=jnp.asarray(target), - **input_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - method="init_states", - ) - decoder_output = prefill_outputs.data - # Transpose to [tgt_len, batch_size, model_dim]. - decoder_output = jnp.einsum("bsd->sbd", decoder_output) - for time_step in range(start_time_step, tgt_len): - (cached_states, extend_step_outputs), _ = F( - layer, - inputs=dict( - data=jnp.asarray(target[:, time_step : time_step + 1, :]), - cached_states=cached_states, - **input_kwargs, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - method="extend_step", - ) - decoder_output = decoder_output.at[time_step].set( - jnp.squeeze(extend_step_outputs.data, axis=1) - ) - # Transpose to [batch_size, tgt_len, model_dim]. - decoder_output = jnp.einsum("sbd->bsd", decoder_output) - # Prefill + extend_step == forward. - assert_allclose(forward_outputs.data, decoder_output) +def set_double_shard_weights_config( + cfg: Union[TransformerLayer.Config, Sequence[TransformerLayer.Config]], + *, + batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"), + fsdp_axis_names: Union[str, Sequence[str]] = "fsdp", + tp_axis_names: Union[str, Sequence[str]] = "model", + seq_axis_names: Union[str, Sequence[str]] = "seq", +): + """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes. -class TransformerTest(BaseTransformerTest): - """Tests TransformerLayer.""" + Args: + cfg: (A sequence of) Transformer layer config to apply sharding spec to. + batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors. + fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors. + tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors. + seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors. 
+ """ - def _compare_against_roberta_attention( - self, ref: hf_roberta.RobertaAttention, layer: TransformerAttentionLayer - ): - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - layer_param_shapes = jax.tree.map(lambda x: x.shape, layer_params) - print(f"layer state={layer_param_shapes}") - layer_params = parameters_from_torch_layer(ref) - batch_size, tgt_len = 2, 6 - model_dim, num_heads = layer.config.target_dim, layer.config.attention.num_heads - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) - null_mask = jnp.zeros([tgt_len, tgt_len]) - rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) - for mask in (None, null_mask, rand_mask): - if mask is not None: - mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) - layer_outputs, _ = F( - layer, - inputs=dict(target=jnp.asarray(target), attention_logit_biases=mask), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - attn_mask = None if mask is None else as_torch_tensor(mask) - (ref_outputs,) = ref.forward( - torch.as_tensor(target, dtype=torch.float32), - attention_mask=attn_mask, - output_attentions=False, - ) - assert_allclose(layer_outputs.data, as_tensor(ref_outputs)) - - def test_against_roberta_attention(self): - model_dim = 16 - num_heads = 4 - cfg = attention.TransformerAttentionLayer.default_config().set( - name="test", - target_dim=model_dim, - source_dim=model_dim, - structure="postnorm", - ) - cfg.attention.set(num_heads=num_heads) - layer = cfg.instantiate(parent=None) - roberta_config = hf_roberta.RobertaConfig( - hidden_size=model_dim, - num_attention_heads=num_heads, - attention_probs_dropout_prob=0, - hidden_dropout_prob=0, - classifier_dropout=0, - ) - print(f"roberta_config={roberta_config}") - ref = hf_roberta.RobertaAttention(roberta_config) - self._compare_against_roberta_attention(ref, layer) - - def _compare_against_roberta_layer(self, ref: hf_roberta.RobertaLayer, layer: TransformerLayer): - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - layer_params = parameters_from_torch_layer(ref) - batch_size, tgt_len = 2, 6 - model_dim, num_heads = ( - layer.config.input_dim, - layer.config.self_attention.attention.num_heads, - ) - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) - null_mask = jnp.zeros([tgt_len, tgt_len]) - rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) - for mask in (None, null_mask, rand_mask): - if mask is not None: - mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) - layer_outputs, output_collection = F( - layer, - inputs=dict(data=jnp.asarray(target), self_attention_logit_biases=mask), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - drop_output_collections=(), - ) - if layer_outputs.self_attention_probs is not None: - self.assertEqual( - (batch_size, num_heads, tgt_len, tgt_len), - layer_outputs.self_attention_probs.shape, - ) - attn_mask = None if mask is None else as_torch_tensor(mask) - (ref_outputs,) = ref.forward( - torch.as_tensor(target, dtype=torch.float32), - attention_mask=attn_mask, - output_attentions=False, - ) - assert_allclose(layer_outputs.data, as_tensor(ref_outputs)) - self.assertNestedEqual(layer_outputs.data, output_collection.module_outputs["output"]) - - def test_against_roberta_layer(self): - model_dim = 16 - num_heads = 4 - cfg = 
TransformerLayer.default_config().set(name="test", input_dim=model_dim) - cfg.self_attention.set(structure="postnorm") - cfg.feed_forward.set( - structure="postnorm", activation="nn.silu", hidden_dim=scaled_hidden_dim(4) - ) - cfg.feed_forward.linear1.set(bias=True) - cfg.feed_forward.linear2.set(bias=True) - cfg.self_attention.attention.set(num_heads=num_heads) - cfg.self_attention.attention.input_linear.layer.set(bias=True) - cfg.self_attention.attention.output_linear.set(bias=True) - layer: TransformerLayer = cfg.instantiate(parent=None) - roberta_config = hf_roberta.RobertaConfig( - hidden_size=model_dim, - num_attention_heads=num_heads, - attention_probs_dropout_prob=0, - hidden_dropout_prob=0, - classifier_dropout=0, - # Jax's gelu uses an approximation by default and is slightly different from - # torch.nn.gelu. - hidden_act="silu", - ) - ref = hf_roberta.RobertaLayer(roberta_config) - self._compare_against_roberta_layer(ref, layer) - - def test_decoding(self): - model_dim, num_heads = 6, 2 - cfg = TransformerLayer.default_config().set(input_dim=model_dim) - cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - cfg.feed_forward.hidden_dim = model_dim * 4 - cfg.vlog = 5 - self._test_forward_vs_extend_step(cfg) - - def test_self_attention_kv_state(self): - """Tests TransformerLayer with explicit self_attention_kv_state. - - Creates a base TransformerLayer and a test TransformerLayer with QLinear. Uses the kv_state - of the base layer as the explicit kv_state for the test layer. Checks that the outputs are - identical. - """ - model_dim = 16 - num_heads = 4 - base_cfg = TransformerLayer.default_config().set(name="test", input_dim=model_dim) - base_cfg.feed_forward.set(hidden_dim=scaled_hidden_dim(4)) - base_cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - base_layer: TransformerLayer = base_cfg.instantiate(parent=None) - base_layer_params = base_layer.initialize_parameters_recursively( - prng_key=jax.random.PRNGKey(0) - ) + # pytype: disable=attribute-error + def set_attn_partition_specs(attn_layer: MultiheadAttention.Config): + # Shard weights. + input_linear_cfg = attn_layer.input_linear + if hasattr(input_linear_cfg, "input_linear"): + input_linear_cfg = input_linear_cfg.input_linear + input_linear_cfg.layer.param_partition_spec = (fsdp_axis_names, tp_axis_names, None) + attn_layer.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None) - test_cfg = base_cfg.clone() - test_cfg.self_attention.attention.input_linear = QLinear.default_config() - test_layer: TransformerLayer = test_cfg.instantiate(parent=None) - # Let test_layer_params to be identical to base_layer_params except removing {k,v}_proj. - test_layer_params = copy.deepcopy(base_layer_params) - for k in ("k_proj", "v_proj"): - test_layer_params["self_attention"]["attention"]["i_proj"].pop(k) - self.assertEqual( - shapes(test_layer_params), - shapes(test_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))), - ) + def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): + # Shard weights. + ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names) + ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) + # Encourage the right activation sharding. 
+ ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + + if not isinstance(cfg, Sequence): + cfg = [cfg] + + for layer_cfg in cfg: + set_attn_partition_specs(layer_cfg.self_attention.attention) + if layer_cfg.cross_attention is not None: + set_attn_partition_specs(layer_cfg.cross_attention.attention) + if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config): + set_ffn_partition_specs(layer_cfg.feed_forward) + # pytype: enable=attribute-error + + +class BaseStackedTransformerLayer(BaseTransformerLayer): + """The common interface of all stacked transformer layer classes. + + Note that BaseStackedTransformerLayer is a subclass of BaseTransformerLayer and therefore + can be used where a BaseTransformerLayer is expected. + + The Output returned by BaseStackedTransformerLayer has the following fields: + * .data is of the same shape as query, from the output of the final layer; + * .self_attention_kv_state is of shape [batch, target_length, num_heads, head_dim], + from the self-attention KV state of the final layer; + * .probs is of shape [num_layers, batch, num_heads, target_length, source_length], + from all layers of the stack; + """ - batch_size, tgt_len = 2, 5 - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) - base_layer_outputs, _ = F( - base_layer, - inputs=dict(data=jnp.asarray(target), return_aux={"self_attention_kv_state"}), - state=base_layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - test_layer_outputs, _ = F( - test_layer, - # Explicitly pass `self_attention_kv_state` from `base_layer_outputs` as inputs to - # test_layer. - inputs=dict( - data=jnp.asarray(target), - self_attention_kv_state=base_layer_outputs.self_attention_kv_state, - return_aux={"self_attention_kv_state"}, - ), - state=test_layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - assert_allclose(base_layer_outputs.data, test_layer_outputs.data) - - # Tests prefill_state and extend_step. - self._test_forward_vs_extend_step( - test_cfg, - input_kwargs=dict( - # Explicitly pass `self_attention_kv_state`. - self_attention_kv_state=base_layer_outputs.self_attention_kv_state, - ), - ) + @config_class + class Config(BaseTransformerLayer.Config): + """Configures BaseStackedTransformerLayer.""" + + # The number of layers in the stack. + num_layers: Required[int] = REQUIRED + # Config for each layer in the stack. + # The layer must be a subclass of BaseTransformerLayer. + layer: BaseTransformerLayer.Config = TransformerLayer.default_config() + peak_stochastic_depth_rate: Optional[float] = None + + +class UpdateDataFn(Protocol): + """A function for updating the constituent layers' input in a StackTransformerLayer.""" + def __call__( + self, data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output] + ) -> Tensor: + """Returns a new Tensor with the same shape as `data`, reflecting some desired updates. 
-class ParallelTransformerTest(TestCase): - """Tests ParallelTransformerLayer.""" - - def test_with_golden_value(self): - """A test of ParallelTransformerLayer by comparing results to a golden value.""" - model_dim = 16 - num_heads = 4 - cfg = ParallelTransformerLayer.default_config().set(name="test", input_dim=model_dim) - cfg.feed_forward.set(hidden_dim=scaled_hidden_dim(4)) - cfg.self_attention.set(num_heads=num_heads) - cfg.norm = RMSNorm.default_config() - set_bias_recursively(cfg, bias=False) - layer: TransformerLayer = cfg.instantiate(parent=None) - - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - self.assertEqual( - { - "feed_forward": { - "dropout1": {}, - "dropout2": {}, - "linear1": {"weight": (16, 64)}, - "linear2": {"weight": (64, 16)}, - "stochastic_depth": {}, - }, - "norm": {"scale": (16,)}, - "self_attention": { - "dropout": {}, - "i_proj": { - "k_proj": {"weight": (16, 4, 4)}, - "q_proj": {"weight": (16, 4, 4)}, - "v_proj": {"weight": (16, 4, 4)}, - }, - "o_proj": {"weight": (16, 4, 4)}, - "scale_key": {}, - "scale_query": {}, - }, - }, - utils.shapes(layer_params), - ) + Args: + data: A Tensor denoting the input data to the upcoming layer. + all_layer_outputs: A list of BaseTransformerLayer.Output that is appended with + the output of each constituent layer in the stack. - batch_size, tgt_len = 2, 6 - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) - mask = attention_bias.make_causal_biases(tgt_len) - mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) - layer_outputs, _ = F( - layer, - inputs=dict(data=jnp.asarray(target), self_attention_logit_biases=mask), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - self.assertEqual(target.shape, layer_outputs.data.shape) - self.assertNestedAllClose(0.609666, np.mean(layer_outputs.data)) - - def test_build_remat_spec(self): - model_dim, num_heads = 6, 2 - cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) - cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - cfg.feed_forward.hidden_dim = model_dim * 4 - cfg.vlog = 5 - - layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - - batch_size, tgt_len = 2, 5 - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) - - def f(x, layer_params): - forward_outputs, _ = F( - layer, - inputs=dict( - data=x, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - return forward_outputs + Returns: + A new Tensor with the same shape as `data`. + """ - # Ignore type errors. - spec: Any = build_remat_spec(mock.MagicMock()) - _, default_policy_backward = jax.linearize( - jax.remat(f, policy=spec.policy.instantiate(), prevent_cse=spec.prevent_cse), - jnp.asarray(target), - layer_params, - ) - _, full_remat_backward = jax.linearize( - jax.remat(f), - jnp.asarray(target), - layer_params, - ) - # Eliminated the remat of qkv_proj, context and o_proj = 5 dots. This assumes - # FlashAttention is not enabled. 
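# Sketch: applying `set_double_shard_weights_config` (added above) to a single layer config.
# The mesh axis names and dimensions are illustrative assumptions, not values fixed by this
# change.
from axlearn.common.attention import TransformerLayer, set_double_shard_weights_config

layer_cfg = TransformerLayer.default_config().set(input_dim=1024)
layer_cfg.self_attention.attention.set(num_heads=16)
layer_cfg.feed_forward.set(hidden_dim=4096)
set_double_shard_weights_config(
    layer_cfg,
    batch_axis_names=("data", "fsdp"),
    fsdp_axis_names="fsdp",
    tp_axis_names="model",
    seq_axis_names="seq",
)
# Attention q/k/v/o projection weights are now partitioned as (fsdp, model, None); the FFN
# linear1/linear2 weights as (fsdp, model) / (model, fsdp), with their outputs sharded over
# (batch, seq, model).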
- self.assertEqual( - str(full_remat_backward).count(" dot_general") - - str(default_policy_backward).count(" dot_general"), - 5, - ) +def update_data_with_skip_connection(skip_connections: dict[int, int]) -> UpdateDataFn: + """Creates a function that adds skip connection to the input data tensor. - def test_build_remat_spec_neuron(self): - model_dim, num_heads = 6, 2 - cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) - cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - cfg.feed_forward.hidden_dim = model_dim * 4 - cfg.vlog = 5 - - layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) - - batch_size, tgt_len = 2, 5 - rng = np.random.default_rng(seed=123) - target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) - - def f(x, layer_params): - forward_outputs, _ = F( - layer, - inputs=dict( - data=x, - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - return forward_outputs - - # Ignore type errors. - spec: Any = build_remat_spec(mock.MagicMock()) - - policy = ( - config_for_function(save_and_offload_only_these_names_regex) - .set( - names_which_can_be_saved="|".join( - [ - RematRegexSavePatterns.QKV_PROJ.value, - RematRegexSavePatterns.LINEAR1_X.value, - ] - ), - names_which_can_be_offloaded=None, - offload_src=None, - offload_dst=None, - ) - .instantiate() - ) + Args: + skip_connections: A dictionary where keys and values represent 0-indexed layer indices. + For a (k, v) pair, the output of the v-th layer will be added to the input + of the k-th layer. - _, default_policy_backward = jax.linearize( - jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse), - jnp.asarray(target), - layer_params, - ) - _, full_remat_backward = jax.linearize( - jax.remat(f), - jnp.asarray(target), - layer_params, - ) + Returns: + A function that implements skip connections, following the UpdateDataFn protocol, . + """ - # Eliminated the remat of qkv_proj and linear1_0 = 4 dots. - self.assertEqual( - str(full_remat_backward).count(" dot_general") - - str(default_policy_backward).count(" dot_general"), - 4, - ) + def update_data(data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output]) -> Tensor: + layer_index = len(all_layer_outputs) + if layer_index in skip_connections: + data += all_layer_outputs[skip_connections[layer_index]].data + return data + return update_data -class TestStackModel(BaseLayer): - """A dummy transformer stack.""" - @config_class - class Config(BaseLayer.Config): - stack: Optional[BaseStackedTransformerLayer.Config] = None # The transformer stack. - output_self_attention_kv_state: bool = False +class StackedTransformerLayer(BaseStackedTransformerLayer): + """A simple implementation of BaseStackedTransformerLayer.""" - def __init__(self, cfg: Config, *, parent: Module): + @config_class + class Config(BaseStackedTransformerLayer.Config): + """Configures StackedTransformerLayer.""" + + # If `layer` is a Config, it will be stacked cfg.num_layers times. If `layer` is a + # sequence of Configs, the sequence length should match cfg.num_layers. + layer: Union[ + BaseTransformerLayer.Config, Sequence[BaseTransformerLayer.Config] + ] = TransformerLayer.default_config() + # If set, implements the UpdateDataFn protocol to update individual layers' input + # data in some specified way. This operation is applied before calling every layer. 
+ data_merger: Optional[InstantiableConfig[UpdateDataFn]] = None + + def __init__(self, cfg: Config, *, parent: Optional[Module]): super().__init__(cfg, parent=parent) cfg = self.config - self._add_child("stack", cfg.stack) + self._update_data = maybe_instantiate(cfg.data_merger) + + if isinstance(cfg.layer, Sequence): + layer_cfgs = cfg.layer + if len(layer_cfgs) != cfg.num_layers: + raise ValueError( + f"Number of layer configs ({len(layer_cfgs)}) must match " + f"cfg.num_layers ({cfg.num_layers})." + ) + else: + layer_cfgs = [cfg.layer] * cfg.num_layers + self._layers = [] + for i, layer_cfg in enumerate(layer_cfgs): + if layer_cfg.input_dim is not REQUIRED: + raise ValueError( + f"Do not set Config.layer.input_dim. Set Config.input_dim instead: {layer_cfg}" + ) + layer_cfg = layer_cfg.clone(input_dim=cfg.input_dim) + if cfg.peak_stochastic_depth_rate: + layer_rate = get_stochastic_depth_linear_rate( + cfg.peak_stochastic_depth_rate, + stage_order=i + 1, + num_stages=cfg.num_layers, + ) + layer_cfg.self_attention.stochastic_depth.rate = layer_rate + layer_cfg.feed_forward.stochastic_depth.rate = layer_rate + self._layers.append(self._add_child(f"layer{i}", layer_cfg)) + + def initialize_parameters_recursively( + self, prng_key: Tensor, *, prebuilt: Optional[Nested[Optional[ParameterSpec]]] = None + ) -> NestedTensor: + cfg = self.config # type: StackedTransformerLayer.Config + prng_key = split_prng_key(prng_key, cfg.num_layers) + state = {} + for i in range(cfg.num_layers): + layer = self._layers[i] + key = jax.tree.map(lambda x, index=i: x[index], prng_key.keys) + state[layer.name] = layer.initialize_parameters_recursively( + key, prebuilt=get_or_none(prebuilt, layer.name) + ) + return state - def forward(self, data, **layer_kwargs): - cfg = self.config + def _forward_for_mode( + self, + *, + mode: ForwardMode, + data: Union[Tensor, TensorSpec], + cached_states: Optional[Nested[Tensor]] = None, + **layer_kwargs, + ) -> tuple[list[Optional[Nested[Tensor]]], Optional[TransformerLayer.Output]]: + """Computes transformer stack outputs. + + Args: + mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for + details. + data: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. + cached_states: Optional Nested Tensor as produced by `init_states`. + + Returns: + A tuple (updated_cache_states, outputs): + * updated_cached_states: An optional NestedTensor of cache states, depending on `mode`; + * outputs: An optional instance of Output (see comments on BaseStackedTransformerLayer). + + Raises: + ValueError: If `mode` is unsupported. + """ + all_layer_outputs = [] + all_layer_states = [] + + # True iff we are initializing an empty cache (i.e., not prefilling). + cache_init = mode == ForwardMode.INIT_STATES and cached_states is None + + for i, layer in enumerate(self._layers): + # Prepare inputs to the current layer. + if self._update_data is not None: + data = self._update_data(data, all_layer_outputs) + # TODO(markblee): Consider folding into _update_data. + self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs) + + if mode == ForwardMode.FORWARD: + layer_states, layer_outputs = None, layer(data, **layer_kwargs) + elif mode == ForwardMode.INIT_STATES: + # cached_states is allowed to be None in the case where we initialize from scratch. 
+ layer_states, layer_outputs = layer.init_states( + time_step=cached_states, + data=data, + **layer_kwargs, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + layer_states, layer_outputs = layer.extend_step( + cached_states=cached_states[i], + data=data, + **layer_kwargs, + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") - # [batch, length, dim]. - output = self.stack(data, **layer_kwargs) - x = output.data - x_mean = jnp.mean(x, axis=1, keepdims=True) - # [batch, length]. - x_var = jnp.sum((x - x_mean) ** 2, axis=-1) - loss = jnp.mean(x_var) - if cfg.output_self_attention_kv_state: - return loss, {"mean": x_mean, "self_attention_kv_state": output.self_attention_kv_state} - return loss, {"mean": x_mean} - - -def _recursive_stack(inputs: Nested[Tensor], axis=0): - def stack(*xs): - return jnp.stack(xs, axis=axis) - - return {"layer": utils.vectorized_tree_map(stack, *inputs.values())} - - -def _convert_from_stacked_params( - layer_params: Nested[Tensor], *, target_stack_cfg: BaseStackedTransformerLayer.Config -) -> Nested[Tensor]: - """Converts params of a StackedTransformerLayer to params for `target_stack_cfg`.""" - # First stack to params of a RepeatedTransformerLayer. - layer_params = {"stack": {"repeat": VDict(_recursive_stack(layer_params["stack"]))}} - if target_stack_cfg.klass == RepeatedTransformerLayer: - return layer_params - elif target_stack_cfg.klass == PipelinedTransformerLayer: - pipeline_stage_cfg = target_stack_cfg.stage - num_layers_per_stage = target_stack_cfg.num_layers // target_stack_cfg.num_stages - - def reshape(x): - """Reshapes x from [num_layers, ...] to [num_stages, num_layers_per_stage, ...].""" - x_shape = list(x.shape) - return jnp.reshape(x, [target_stack_cfg.num_stages, num_layers_per_stage] + x_shape[1:]) - - pipeline_params = jax.tree.map(reshape, layer_params["stack"].pop("repeat")) - - if pipeline_stage_cfg.klass == RepeatedTransformerLayer: - layer_params["stack"]["pipeline"] = VDict({"layer": {"repeat": pipeline_params}}) - elif pipeline_stage_cfg.klass == StackedTransformerLayer: - layer_params["stack"]["pipeline"] = VDict( - { - "layer": { - f"layer{i}": jax.tree.map(lambda x, i=i: x[:, i], pipeline_params["layer"]) - for i in range(num_layers_per_stage) - } - } - ) - else: - raise NotImplementedError(target_stack_cfg) - return layer_params - else: - raise NotImplementedError(target_stack_cfg) + all_layer_states.append(layer_states) + # If initializing the cache from scratch, layer_outputs will be None. Further, `data` + # can be effectively treated as a TensorSpec, and thus does not need to be carried + # across layers. + if layer_outputs is None: + assert cache_init + continue -class NonUniformStack(StackedTransformerLayer): - def _aggregate_layer_outputs( - self, layer_outputs: Sequence[BaseTransformerLayer.Output] - ) -> BaseTransformerLayer.Output: - return BaseTransformerLayer.Output( - # Use data and self_attention_kv_state from the final layer outputs. - data=layer_outputs[-1].data, - self_attention_kv_state=layer_outputs[-1].self_attention_kv_state, - # Do not aggregate *_attention_probs. 
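# Sketch: wiring the skip-connection helper above into a stack via `data_merger`
# (illustrative sizes; `config_for_function` is assumed to come from axlearn.common.config).
from axlearn.common.attention import StackedTransformerLayer, update_data_with_skip_connection
from axlearn.common.config import config_for_function

stack_cfg = StackedTransformerLayer.default_config().set(
    name="stack_with_skip",
    input_dim=4,
    num_layers=5,
    data_merger=config_for_function(update_data_with_skip_connection).set(
        skip_connections={3: 1}  # Add the output of layer 1 to the input of layer 3.
    ),
)
stack_cfg.layer.self_attention.attention.set(num_heads=2)
stack_cfg.layer.feed_forward.hidden_dim = 8
skip_layer = stack_cfg.instantiate(parent=None)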
- self_attention_probs=None, - cross_attention_probs=None, - ) + all_layer_outputs.append(layer_outputs) + data = layer_outputs.data + outputs = None if cache_init else self._aggregate_layer_outputs(all_layer_outputs) + return all_layer_states, outputs -class TestStackedTransformerLayerWithKVState(NonUniformStack): - """A class with a simple override of _update_layer_kwargs for unit testing.""" + def init_states( + self, + *, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], + **layer_kwargs, + ) -> tuple[list[Nested[Tensor]], Optional[TransformerLayer.Output]]: + """See `BaseTransformerLayer.init_states` for details.""" + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, + cached_states=time_step, + data=data, + **layer_kwargs, + ) def _update_layer_kwargs( self, @@ -4169,1244 +3639,609 @@ def _update_layer_kwargs( *, all_layer_outputs: list[BaseTransformerLayer.Output], ): - layer_index = len(all_layer_outputs) - if layer_index == 1: - layer_kwargs["self_attention_kv_state"] = all_layer_outputs[-1].self_attention_kv_state - elif layer_index == 2: - layer_kwargs["self_attention_kv_state"] = None + """Updates `layer_kwargs` using other args. + This method is called before we invoke each layer in `self._layers`. + The updated `layer_kwargs` will be passed to the layer invocation. -class TestStackedTransformerLayerWithSkipConnection(StackedTransformerLayer): - """A class that outputs all layers' output for unit testing.""" + Args: + layer_kwargs: a dictionary of arguments that can be used by individual layers. + all_layer_outputs: a list of BaseTransformerLayer.Output that is appended with + the output of each constituent layer in the stack. + """ + pass # Do nothing by default. def _aggregate_layer_outputs( self, layer_outputs: Sequence[BaseTransformerLayer.Output], - ) -> Sequence[BaseTransformerLayer.Output]: - return layer_outputs + ) -> BaseTransformerLayer.Output: + """Aggregates outputs from the stack.""" + data = layer_outputs[-1].data + self_attention_kv_state = layer_outputs[-1].self_attention_kv_state + aux_outputs = [ + output._replace(data=None, self_attention_kv_state=None) for output in layer_outputs + ] + # Stack auxiliary outputs along axis 0. + outputs = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *aux_outputs) + return outputs._replace(data=data, self_attention_kv_state=self_attention_kv_state) + def forward( + self, + data: Tensor, + **layer_kwargs, + ) -> TransformerLayer.Output: + _, output = self._forward_for_mode( + mode=ForwardMode.FORWARD, + data=data, + cached_states=None, + **layer_kwargs, + ) + return output + + def extend_step( + self, + cached_states: list[NestedTensor], + data: Tensor, + **layer_kwargs, + ) -> tuple[list[Nested[Tensor]], TransformerLayer.Output]: + return self._forward_for_mode( # pytype: disable=bad-return-type + mode=ForwardMode.EXTEND_STEP, + cached_states=cached_states, + data=data, + **layer_kwargs, + ) -class StackedTransformerTest(BaseTransformerTest): - """Tests StackedTransformerLayer.""" - def _stack_config( +class _TransformerRepeat(Repeat): + """A Repeat layer with layer=TransformerLayer.""" + + @config_class + class Config(Repeat.Config): + """Configures _TransformerRepeat.""" + + # The additional fields of BaseTransformerLayer.Output that should propagate as input to + # the next layer. + # + # For example, carry=("data", "self_attention_kv_state") means that both `data` and + # `self_attention_kv_state` will propagate between layers. + # + # If None, only "data" is propagated. 
+ carry: Optional[Sequence[str]] = None + + def _forward_for_mode( self, - stack_cfg, *, - num_layers, - model_dim, - num_heads, - dtype, - remat_spec, - output_self_attention_kv_state=False, - ) -> TestStackModel.Config: - if isinstance(stack_cfg, type): - stack_cfg = stack_cfg.default_config() - if callable(remat_spec): - remat_spec = remat_spec(stack_cfg) - cfg = TestStackModel.default_config().set( - name="test", - stack=stack_cfg.set( - input_dim=model_dim, - num_layers=num_layers, - vlog=5, - dtype=dtype, - layer=TransformerLayer.default_config().set(remat_spec=remat_spec), - ), - output_self_attention_kv_state=output_self_attention_kv_state, - ) - layer_cfg = cfg.stack.layer - layer_cfg.self_attention.attention.set(num_heads=num_heads) - layer_cfg.feed_forward.hidden_dim = model_dim * 4 - layer_cfg.vlog = 5 - return cfg + mode: ForwardMode, + data: Union[Tensor, TensorSpec], + cached_states: Optional[Nested[Tensor]] = None, + **layer_kwargs, + ) -> tuple[Optional[Nested[Tensor]], Optional[TransformerLayer.Output]]: + """Computes transformer stack outputs. + + Args: + mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for + details. + data: A Tensor of shape [batch, target_length, target_dim]. + cached_states: Optional Nested Tensor as produced by `init_states`. + layer_kwargs: Additional kwargs to each layer. + + Returns: + A tuple (updated_cache_states, outputs): + * updated_cached_states: An optional NestedTensor of cache states, depending on `mode`; + * outputs: An optional instance of Output (see comments on BaseStackedTransformerLayer). + + Raises: + ValueError: If `mode` is unsupported. + """ + cfg: _TransformerRepeat.Config = self.config + + # True iff we are initializing an empty cache (i.e., not prefilling). + cache_init = mode == ForwardMode.INIT_STATES and cached_states is None + + if cached_states is not None: + for path, value in flatten_items(cached_states): + assert value.shape[0] == cfg.num_layers, f"{path}={shapes(value)}" + + def layer_fn(carry, x_i): + if mode == ForwardMode.FORWARD: + layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs) + elif mode == ForwardMode.INIT_STATES: + # Note that x_i can be None if initializing an empty cache. This corresponds to the + # case where `cached_states=None`. + layer_states, layer_outputs = self.layer.init_states( + time_step=x_i, **carry, **layer_kwargs + ) + elif mode == ForwardMode.EXTEND_STEP: + assert x_i is not None + layer_states, layer_outputs = self.layer.extend_step( + cached_states=x_i, **carry, **layer_kwargs + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") - @parameterized.product( - transformer_type=[StackedTransformerLayer, RepeatedTransformerLayer], - # Also tests stack-of-stacks and repeat-of-stacks. - layer_type=[TransformerLayer, StackedTransformerLayer], - ) - def test_transformer_extend_step(self, transformer_type, layer_type): - batch_size, src_len, tgt_len = 10, 4, 6 - num_dec_layers, model_dim, num_heads = 3, 16, 4 - - cfg: BaseStackedTransformerLayer.Config = transformer_type.default_config().set( - name="test", - input_dim=model_dim, - num_layers=num_dec_layers, - ) - cross_atten_cfg = TransformerAttentionLayer.default_config().set( - source_dim=model_dim * 2, - structure="postnorm", - ) - cross_atten_cfg.attention.set(num_heads=num_heads) + ys = {} + if layer_states is not None: + ys["cached_states"] = layer_states - # Prepare layer config. 
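# Toy sketch of the cache-shape invariant asserted above: every leaf of `cached_states` for a
# repeated stack carries a leading axis of size `num_layers` (the real cache holds nested
# per-layer key/value/time-step tensors; the single entry below is only illustrative).
from jax import numpy as jnp

num_layers, batch = 4, 2
toy_cached_states = {"time_step": jnp.zeros([num_layers, batch], dtype=jnp.int32)}
assert all(v.shape[0] == num_layers for v in toy_cached_states.values())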
- if layer_type == StackedTransformerLayer: - cfg.layer = layer_type.default_config().set(num_layers=2) - layer_cfg = cfg.layer.layer - else: - layer_cfg = cfg.layer - layer_cfg.self_attention.attention.set(num_heads=num_heads) - layer_cfg.cross_attention = cross_atten_cfg - layer_cfg.feed_forward.hidden_dim = model_dim * 4 + # If initializing the cache from scratch, layer_outputs will be None. + if layer_outputs is None: + assert cache_init + return carry, ys - # Instantiate transformer stack. - layer: BaseStackedTransformerLayer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + ys.update({k: v for k, v in layer_outputs._asdict().items() if k not in carry}) + return {k: getattr(layer_outputs, k) for k in carry}, ys - target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) - source = jax.random.normal(jax.random.PRNGKey(456), [batch_size, src_len, model_dim * 2]) + if cfg.carry is None: + carry = {"data": data} + else: + layer_kwargs["data"] = data + carry = {k: layer_kwargs.pop(k) for k in cfg.carry} - self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) - cross_attention_logit_biases = ( - jnp.array(np.random.randint(0, 2, [tgt_len, src_len])) * NEG_INF - ) - return_aux = {"self_attention_probs", "cross_attention_probs"} - - forward_outputs, _ = F( - layer, - inputs=dict( - data=target, - self_attention_logit_biases=self_attention_logit_biases, - cross_attention_data=source, - cross_attention_logit_biases=cross_attention_logit_biases, - return_aux=return_aux, - ), - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(0), - ) - initial_state, initial_output = layer.init_states( - time_step=None, - data=TensorSpec([batch_size, tgt_len]), - ) - self.assertIsNone(initial_output) - inputs = dict( - cached_states=initial_state, cross_attention_data=source, return_aux=return_aux - ) - decoder_output = jnp.zeros(shape=[tgt_len, batch_size, model_dim]) - - # [num_dec_layers, [num_stacked_layers,] batch_size, num_heads, tgt_len, tgt_len] --> - # [tgt_len, num_dec_layers, [num_stacked_layers,] batch_size, num_heads, tgt_len]. - # The layer being stacked can itself be a stack, in which case we have an extra dim. - decoder_self_attention_probs = jnp.moveaxis( - jnp.zeros_like(forward_outputs.self_attention_probs), - -2, - 0, - ) - # [tgt_len, num_dec_layers, [num_stacked_layers,] batch_size, num_heads, src_len]. - decoder_cross_attention_probs = jnp.moveaxis( - jnp.zeros_like(forward_outputs.cross_attention_probs), - -2, - 0, - ) - for t in range(tgt_len): - inputs["data"] = jnp.expand_dims(target[:, t, :], axis=1) - inputs["self_attention_logit_biases"] = self_attention_logit_biases[ - jnp.newaxis, jnp.newaxis, t, : - ] - inputs["cross_attention_logit_biases"] = cross_attention_logit_biases[ - jnp.newaxis, jnp.newaxis, t, : - ] - (updated_states, layer_outputs), _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="extend_step", - ) - # Check that updated_states are VDicts for the Repeated layer. 
- if transformer_type is RepeatedTransformerLayer: - jax.tree.map( - lambda v: self.assertIsInstance(v, utils.VDict), - updated_states, - is_leaf=lambda v: isinstance(v, dict), - ) - inputs["cached_states"] = updated_states - decoder_output = decoder_output.at[t].set(jnp.squeeze(layer_outputs.data, axis=1)) - decoder_self_attention_probs = decoder_self_attention_probs.at[t].set( - jnp.squeeze(layer_outputs.self_attention_probs, axis=-2) - ) - decoder_cross_attention_probs = decoder_cross_attention_probs.at[t].set( - jnp.squeeze(layer_outputs.cross_attention_probs, axis=-2) - ) - decoder_out_transposed = jnp.transpose(decoder_output, [1, 0, 2]) - decoder_self_attention_probs_transposed = jnp.moveaxis(decoder_self_attention_probs, 0, -2) - decoder_cross_attention_probs_transposed = jnp.moveaxis( - decoder_cross_attention_probs, 0, -2 - ) + repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states) + carry = repeat_outputs.carry + ys = repeat_outputs.ys + updated_states = ys.pop("cached_states", None) - assert_allclose(decoder_out_transposed, forward_outputs.data, atol=1e-6) - assert_allclose( - decoder_self_attention_probs_transposed, forward_outputs.self_attention_probs, atol=1e-6 - ) - assert_allclose( - decoder_cross_attention_probs_transposed, - forward_outputs.cross_attention_probs, - atol=1e-6, - ) + if cache_init: + assert ys == {} + return updated_states, None - @parameterized.product( - transformer_type=[StackedTransformerLayer, RepeatedTransformerLayer], - # Also tests stack-of-stacks and repeat-of-stacks. - layer_type=[TransformerLayer, StackedTransformerLayer], - ) - # pylint: disable-next=too-many-statements - def test_transformer_prefill_states(self, transformer_type, layer_type): - batch_size, src_len, tgt_len = 10, 4, 6 - num_dec_layers, model_dim, num_heads = 3, 16, 4 - - cfg = transformer_type.default_config().set( - name="test", - input_dim=model_dim, - num_layers=num_dec_layers, - ) - cross_atten_cfg = TransformerAttentionLayer.default_config().set( - source_dim=model_dim * 2, - structure="postnorm", - ) - cross_atten_cfg.attention.set(num_heads=num_heads) + for k in ("data", "self_attention_kv_state"): + if k in carry: + continue + v = ys.pop(k, None) + if v is not None: + # Take the output from the last layer. + if isinstance(v, KVState): + v = KVState(k_proj=v.k_proj[-1], v_proj=v.v_proj[-1]) + else: + v = v[-1] + carry[k] = v + return updated_states, TransformerLayer.Output(**carry, **ys) - # Prepare layer config. - if layer_type == StackedTransformerLayer: - cfg.layer = layer_type.default_config().set(num_layers=2) - layer_cfg = cfg.layer.layer - else: - layer_cfg = cfg.layer - layer_cfg.self_attention.attention.set(num_heads=num_heads) - layer_cfg.cross_attention = cross_atten_cfg - layer_cfg.feed_forward.hidden_dim = model_dim * 4 + def forward( + self, + data: Tensor, + **layer_kwargs, + ) -> TransformerLayer.Output: + _, output = self._forward_for_mode( + mode=ForwardMode.FORWARD, + data=data, + cached_states=None, + **layer_kwargs, + ) + return output + + def init_states( + self, + *, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], + **layer_kwargs, + ) -> tuple[Nested[Tensor], Optional[TransformerLayer.Output]]: + cfg: _TransformerRepeat.Config = self.config + # time_step is allowed to be None if initializing an empty cache. + if time_step is not None: + time_step = jnp.tile(time_step, [cfg.num_layers, 1]) - # Instantiate transformer stack. 
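# Sketch: prefill-then-decode with the stacked transformer interface exercised above
# (illustrative sizes; `functional as F` is the assumed axlearn helper for calling modules).
import jax
from jax import numpy as jnp

from axlearn.common.attention import StackedTransformerLayer
from axlearn.common.module import functional as F

batch_size, tgt_len, model_dim, num_heads = 2, 6, 16, 4
cfg = StackedTransformerLayer.default_config().set(
    name="decoder_stack", input_dim=model_dim, num_layers=3
)
cfg.layer.self_attention.attention.set(num_heads=num_heads, causal=True)
cfg.layer.feed_forward.hidden_dim = model_dim * 4
layer = cfg.instantiate(parent=None)
params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
target = jax.random.normal(jax.random.PRNGKey(1), [batch_size, tgt_len, model_dim])

# Prefill the KV cache up to time_step=2 for every sequence in the batch.
(cached_states, prefill_out), _ = F(
    layer,
    inputs=dict(time_step=jnp.array([2] * batch_size), data=target),
    state=params,
    is_training=False,
    prng_key=jax.random.PRNGKey(2),
    method="init_states",
)
# prefill_out.data is [batch, tgt_len, model_dim]; positions >= time_step are redone below.
for t in range(2, tgt_len):
    (cached_states, step_out), _ = F(
        layer,
        inputs=dict(cached_states=cached_states, data=target[:, t : t + 1, :]),
        state=params,
        is_training=False,
        prng_key=jax.random.PRNGKey(3),
        method="extend_step",
    )
    # step_out.data is [batch, 1, model_dim] for the newly decoded position.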
- layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + # In the repeat case, scan requires a Tensor rather than ShapeDtypeStruct. + # Use vmap rather than materializing the Tensor. + if isinstance(data, TensorSpec): - target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) - source = jax.random.normal(jax.random.PRNGKey(456), [batch_size, src_len, model_dim * 2]) + def layer_fn(_): + return self.layer.init_states(time_step=time_step, data=data, **layer_kwargs) - self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) - cross_attention_logit_biases = ( - jnp.array(np.random.randint(0, 2, [tgt_len, src_len])) * NEG_INF - ) - return_aux = {"self_attention_probs", "cross_attention_probs"} - - forward_outputs, _ = F( - layer, - inputs=dict( - data=target, - self_attention_logit_biases=self_attention_logit_biases, - cross_attention_data=source, - cross_attention_logit_biases=cross_attention_logit_biases, - return_aux=return_aux, - ), - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(0), - ) - # Initialize state. - time_step = jnp.arange(batch_size) - (initial_states, initial_output), _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict( - time_step=time_step, - data=target, - self_attention_logit_biases=self_attention_logit_biases, - cross_attention_data=source, - cross_attention_logit_biases=cross_attention_logit_biases, - return_aux=return_aux, - ), - method="init_states", - ) + return jax.vmap(layer_fn)(jnp.empty(cfg.num_layers)) - # Zero-out outputs starting from initial time_step, and test that we can recover the full - # outputs by calling extend_step starting from time_step. - time_step_mask = jnp.arange(tgt_len) < time_step[:, None] - # [batch, tgt_len, model_dim]. - decoder_output = initial_output.data * time_step_mask[..., None] - # [num_layers, batch, num_heads, tgt_len, tgt_len]. - decoder_self_attention_probs = ( - initial_output.self_attention_probs * time_step_mask[None, :, None, :, None] - ) - # [num_layers, batch, num_heads, tgt_len, src_len]. - decoder_cross_attention_probs = ( - initial_output.cross_attention_probs * time_step_mask[None, :, None, :, None] + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, + data=data, + cached_states=time_step, + **layer_kwargs, ) - # Transpose for simpler updates during extend_step. - # [batch, tgt_len, model_dim] --> [batch, model_dim, tgt_len]. - decoder_output = jnp.moveaxis(decoder_output, -2, -1) - # [..., tgt_len, src_len] --> [..., src_len, tgt_len]. - decoder_self_attention_probs = jnp.moveaxis(decoder_self_attention_probs, -2, -1) - decoder_cross_attention_probs = jnp.moveaxis(decoder_cross_attention_probs, -2, -1) - - # Call extend_step from time_step, ensuring that outputs match. - inputs = dict( - cached_states=initial_states, cross_attention_data=source, return_aux=return_aux - ) - while jnp.any(time_step < tgt_len): - # [batch, tgt_len=1, model_dim]. - inputs["data"] = jnp.take_along_axis( - target, time_step[:, None, None], axis=1, mode="clip" - ) - # [batch=1, tgt_len=1, tgt_len]. - inputs["self_attention_logit_biases"] = jnp.take_along_axis( - self_attention_logit_biases[None, :, :], - time_step[:, None, None], - axis=1, - mode="clip", - ) - # [batch=1, tgt_len=1, src_len]. 
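# Follow-up sketch (reusing `layer`, `params`, `batch_size`, `tgt_len`, `jax`, and `F` from the
# decode sketch above): when no prefill is needed, `init_states` also accepts `time_step=None`
# with a TensorSpec instead of real data, returning only an empty cache (the outputs are None).
from axlearn.common.utils import TensorSpec

(empty_cache, no_output), _ = F(
    layer,
    inputs=dict(time_step=None, data=TensorSpec([batch_size, tgt_len])),
    state=params,
    is_training=False,
    prng_key=jax.random.PRNGKey(4),
    method="init_states",
)
assert no_output is None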
- inputs["cross_attention_logit_biases"] = jnp.take_along_axis( - cross_attention_logit_biases[None, :, :], - time_step[:, None, None], - axis=1, - mode="clip", - ) - (updated_states, layer_outputs), _ = F( - layer, - state=layer_params, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="extend_step", - ) - # Check that updated_states are VDicts for the Repeated layer. - if transformer_type is RepeatedTransformerLayer: - jax.tree.map( - lambda v: self.assertIsInstance(v, utils.VDict), - updated_states, - is_leaf=lambda v: isinstance(v, dict), - ) - inputs["cached_states"] = updated_states - - # [batch, model_dim, tgt_len=1] - curr_outputs = jnp.moveaxis(layer_outputs.data, -2, -1) - # [..., tgt_len, tgt_len=1] - curr_self_attention_probs = jnp.moveaxis(layer_outputs.self_attention_probs, -2, -1) - # [..., src_len, tgt_len=1] - curr_cross_attention_probs = jnp.moveaxis(layer_outputs.cross_attention_probs, -2, -1) - - # [batch, 1, tgt_len]. - oh_indices = jax.nn.one_hot(time_step, tgt_len)[:, None, :] - decoder_output = decoder_output + curr_outputs * oh_indices - # [num_layers=1, batch, num_heads=1, tgt_len=1, tgt_len]. - oh_indices = oh_indices[None, :, None, :, :] - decoder_self_attention_probs = ( - decoder_self_attention_probs + curr_self_attention_probs * oh_indices - ) - decoder_cross_attention_probs = ( - decoder_cross_attention_probs + curr_cross_attention_probs * oh_indices - ) - time_step = time_step + 1 - - # [batch, model_dim, tgt_len] --> [batch, tgt_len, model_dim]. - decoder_output = jnp.moveaxis(decoder_output, -1, -2) - # [..., src_len, tgt_len] --> [..., tgt_len, src_len]. - decoder_self_attention_probs = jnp.moveaxis(decoder_self_attention_probs, -1, -2) - decoder_cross_attention_probs = jnp.moveaxis(decoder_cross_attention_probs, -1, -2) - - assert_allclose(decoder_output, forward_outputs.data) - assert_allclose(decoder_self_attention_probs, forward_outputs.self_attention_probs) - assert_allclose(decoder_cross_attention_probs, forward_outputs.cross_attention_probs) - - def test_skip_connection(self): - batch_size = 2 - seq_len = 6 - num_heads = 2 - input_dim = 4 - hidden_dim = 8 - num_layers = 5 - layer_with_skip_input = 3 - - cfg = TestStackedTransformerLayerWithSkipConnection.default_config().set( - name="test", input_dim=input_dim, num_layers=num_layers + def extend_step( + self, + cached_states: NestedTensor, + data: Tensor, + **layer_kwargs, + ) -> tuple[NestedTensor, TransformerLayer.Output]: + return self._forward_for_mode( # pytype: disable=bad-return-type + mode=ForwardMode.EXTEND_STEP, + data=data, + cached_states=cached_states, + **layer_kwargs, ) - transformer_cfg = TransformerLayer.default_config() - transformer_cfg.self_attention.attention.num_heads = num_heads - transformer_cfg.feed_forward.hidden_dim = hidden_dim - cfg.layer = transformer_cfg - test_cfg = cfg.clone().set( - data_merger=config_for_function(update_data_with_skip_connection).set( - skip_connections={layer_with_skip_input: 1} - ) - ) +class RepeatedTransformerLayer(BaseStackedTransformerLayer): + """An implementation of BaseStackedTransformerLayer with a scan loop. - base_layer = cfg.instantiate(parent=None) - test_layer = test_cfg.instantiate(parent=None) + Compared with StackedTransformerLayer, the size of the XLA program for RepeatedTransformerLayer + does not grow proportional to the number of layers. In practice, this significantly reduces + XLA compilation overhead of large models with many layers. 
+ """ - random_inputs = jax.random.uniform( - jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim) - ) - state = base_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - base_output, _ = F( - base_layer, - is_training=True, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict(data=random_inputs), - ) - test_output, _ = F( - test_layer, - is_training=True, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict(data=random_inputs), - ) + @config_class + class Config(BaseStackedTransformerLayer.Config): + """Configures RepeatedTransformerLayer.""" - for i in range(layer_with_skip_input): - self.assertNestedAllClose( - base_output[i].data, - test_output[i].data, - ) - for i in range(layer_with_skip_input, num_layers): - self.assertNotAlmostEqual( - jnp.min(jnp.abs(base_output[i].data - test_output[i].data)), - 0.0, - ) + repeat: Repeat.Config = _TransformerRepeat.default_config() - def test_update_layer_kwargs(self): - batch_size = 2 - seq_len = 6 - num_heads = 2 - input_dim = 4 - per_head_dim = input_dim // num_heads - hidden_dim = 8 - num_layers = 3 - - # Create a StackedTransformerLayer by specifying a sequence of non-uniform layer configs. - cfg = TestStackedTransformerLayerWithKVState.default_config().set(name="test") - cfg.input_dim = input_dim - cfg.num_layers = num_layers - cfg.layer = [] - for i in range(num_layers): - transformer_cfg = TransformerLayer.default_config() - transformer_cfg.self_attention.attention.num_heads = num_heads - transformer_cfg.feed_forward.hidden_dim = hidden_dim - - if i == 1: - transformer_cfg.self_attention.attention.input_linear = QLinear.default_config() - - cfg.layer.append(transformer_cfg) - - layer: StackedTransformerLayer = cfg.instantiate(parent=None) - inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - outputs, _ = F( - layer, - is_training=True, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict(data=inputs, return_aux={"self_attention_kv_state"}), - ) - self.assertEqual( - BaseTransformerLayer.Output( - data=(batch_size, seq_len, input_dim), - self_attention_probs=None, - self_attention_kv_state=KVState( - k_proj=(batch_size, seq_len, num_heads, per_head_dim), - v_proj=(batch_size, seq_len, num_heads, per_head_dim), - ), - cross_attention_probs=None, - ), - shapes(outputs), + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config # type: RepeatedTransformerLayer.Config + repeat_cfg = cfg.repeat.set( + layer=cfg.layer.set(input_dim=cfg.input_dim), + num_layers=cfg.num_layers, + ) + self._add_child("repeat", repeat_cfg) + + def initialize_parameters_recursively( + self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None + ) -> NestedTensor: + # We need to call self.repeat.initialize_parameters_recursively() with the same prng_key + # to ensure initialization parity with StackedTransformerLayer. + return dict( + repeat=self.repeat.initialize_parameters_recursively( + prng_key, prebuilt=get_or_none(prebuilt, "repeat") + ) ) - def test_stack_vs_repeat(self): - self._compare_layers(StackedTransformerLayer, RepeatedTransformerLayer) - - def test_stack_vs_repeat_bfloat16(self): - # FIXME(rpang): fix the following test, which is caused by different behaviors of bfloat16 - # to float32 casting. 
- # self._compare_layers(StackedTransformerLayer, RepeatedTransformerLayer, - # dtype=jnp.bfloat16) - pass - - def test_stack_vs_repeat_remat_everything_saveable(self): - self._compare_layers( - StackedTransformerLayer, - RepeatedTransformerLayer, - remat_spec=RematSpec(policy=jax_remat_policies.everything_saveable), - ) + def forward( + self, + data: Tensor, + **layer_kwargs, + ) -> TransformerLayer.Output: + return self.repeat(data, **layer_kwargs) - def test_stack_vs_repeat_with_build_remat_spec(self): - self._compare_layers( - StackedTransformerLayer, - RepeatedTransformerLayer, - remat_spec=build_remat_spec, - ) + def init_states(self, *args, **kwargs): + cached_states, output = self.repeat.init_states(*args, **kwargs) + return VDict(repeat=cached_states), output - @parameterized.product( - stage_cls=[StackedTransformerLayer, RepeatedTransformerLayer], - schedule_cls=[GPipeSchedule, StreamSchedule], - remat_spec=[None, RematSpec(policy=jax_remat_policies.everything_saveable)], - ) - def test_stack_vs_pipeline( + def extend_step( self, - stage_cls: type[BaseTransformerLayer], - schedule_cls: type[BaseSchedule], - remat_spec: Optional[RematSpec], - ): - pipelined_cfg: PipelinedTransformerLayer.Config = PipelinedTransformerLayer.default_config() - pipelined_cfg.stage = stage_cls.default_config().set(layer=None) - pipelined_cfg.pipeline.schedule = schedule_cls.default_config() - - # If using StreamSchedule, we expect `num_microbatches` to be divisible by `num_stages`. - if schedule_cls is StreamSchedule: - # num_microbatches = 6, num_stages = 3, microbatch_size = 2 - batch_size, num_layers = 12, 6 - else: - # num_microbatches = 5, num_stages = 3, microbatch_size = 2 - batch_size, num_layers = 10, 6 - - pipelined_cfg.num_microbatches = batch_size // 2 - pipelined_cfg.num_stages = num_layers // 2 - self._compare_layers( - StackedTransformerLayer, - pipelined_cfg, - remat_spec=remat_spec, - batch_size=batch_size, - num_layers=num_layers, + cached_states: NestedTensor, + data: Tensor, + **layer_kwargs, + ) -> tuple[list[NestedTensor], TransformerLayer.Output]: + repeat_cached_states, output = self.repeat.extend_step( + cached_states=cached_states["repeat"], + data=data, + **layer_kwargs, ) + return VDict(repeat=repeat_cached_states), output + - # pylint: disable-next=too-many-statements,too-many-branches - def _compare_layers( +class _TransformerPipeline(Pipeline): + """Transformer pipeline layer.""" + + def forward( self, - *stack_configs, - dtype=jnp.float32, - remat_spec=None, - batch_size: int = 10, - num_layers: int = 6, - ): - assert stack_configs[0] == StackedTransformerLayer, stack_configs[0] - with utils.numeric_checks(False): - tgt_len, model_dim, num_heads = 5, 8, 4 + data: Tensor, + *, + return_aux: Optional[set[str]] = None, + **kwargs, + ) -> TransformerLayer.Output: + carry_in = dict(data=data) + return_aux = return_aux or set() + + # Even though attention logit biases do not change across layers, we + # include them in the carry so that they are aligned with the microbatches. 
+ carry_in.update(kwargs) + carry_in = self._to_microbatches(carry_in) + self.vlog(3, "carry_in=%s", shapes(carry_in)) + + def layer_fn(carry, _): + layer_outputs: TransformerLayer.Output = self.layer(**carry) + carry.pop("data") + return dict(**carry, data=layer_outputs.data), { + k: v if k in return_aux else None + for k, v in layer_outputs._asdict().items() + if k != "data" + } + + pipeline_outputs: Pipeline.Output = self._run(layer_fn, carry_in) + carry_out = self._from_microbatches(pipeline_outputs.carry["data"]) - target = jax.random.normal( - jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim], dtype=dtype + ys = pipeline_outputs.ys + self.vlog(3, "ys=%s", shapes(ys)) + return TransformerLayer.Output(data=carry_out, **ys) + + +class PipelinedTransformerLayer(BaseStackedTransformerLayer): + """An implementation of BaseStackedTransformerLayer with pipeline model parallelism.""" + + @config_class + class Config(BaseStackedTransformerLayer.Config): + """Configures PipelinedTransformerLayer.""" + + # The number of pipeline stages. Must evenly divide `num_layers`. + num_stages: Required[int] = REQUIRED + # The number of pipeline microbatches. Must evenly divide batch size. + num_microbatches: Required[int] = REQUIRED + # Config for each stage in the pipeline. + stage: BaseLayer.Config = StackedTransformerLayer.default_config().set(layer=None) + # Config for the pipeline implementation, such as pipeline schedule. + pipeline: _TransformerPipeline.Config = _TransformerPipeline.default_config() + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config # type: PipelinedTransformerLayer.Config + if cfg.num_layers % cfg.num_stages != 0: + raise ValueError(f"num_stages {cfg.num_stages} must divide num_layers {cfg.num_layers}") + num_layers_per_stage = cfg.num_layers // cfg.num_stages + stage_cfg = cfg.stage.set( + input_dim=cfg.input_dim, layer=cfg.layer, num_layers=num_layers_per_stage + ) + pipeline_cfg = cfg.pipeline.set( + layer=stage_cfg, num_layers=cfg.num_stages, num_microbatches=cfg.num_microbatches + ) + self._add_child("pipeline", pipeline_cfg) + + def initialize_parameters_recursively( + self, prng_key: Tensor, *, prebuilt: Optional[Nested[Optional[ParameterSpec]]] = None + ) -> NestedTensor: + cfg = self.config # type: PipelinedTransformerLayer.Config + # We pre-split all num_layers keys to ensure initialization parity with + # StackedTransformerLayer. 
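# Sketch: configuring pipeline model parallelism (illustrative values). `num_stages` must
# evenly divide `num_layers`, and `num_microbatches` must evenly divide the batch size.
from axlearn.common.attention import PipelinedTransformerLayer

model_dim, num_heads, num_layers = 16, 4, 6
pipeline_cfg = PipelinedTransformerLayer.default_config().set(
    name="pipelined_decoder",
    input_dim=model_dim,
    num_layers=num_layers,
    num_stages=3,  # Two layers per stage.
    num_microbatches=5,  # A global batch of 10 yields microbatches of size 2.
)
pipeline_cfg.layer.self_attention.attention.set(num_heads=num_heads)
pipeline_cfg.layer.feed_forward.hidden_dim = model_dim * 4
pipelined_layer = pipeline_cfg.instantiate(parent=None)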
+ prng_key = split_prng_key(prng_key, (cfg.num_stages, cfg.num_layers // cfg.num_stages)) + return dict( + pipeline=self.pipeline.initialize_parameters_recursively( + prng_key, prebuilt=get_or_none(prebuilt, "pipeline") ) - rand_mask = _random_mask(jax.random.PRNGKey(123), tgt_len, tgt_len) - rand_mask = jnp.tile(rand_mask[None, None, :, :], (batch_size, num_heads, 1, 1)) - - all_params = [] - all_outputs = [] - all_gradients = [] - all_updates = [] - stacked_layer_params = None - for stack_cfg in stack_configs: - cfg = self._stack_config( - stack_cfg, - num_layers=num_layers, - model_dim=model_dim, - num_heads=num_heads, - dtype=dtype, - remat_spec=remat_spec, - ) - cls = cfg.stack.klass - layer: TestStackModel = cfg.instantiate(parent=None) - - param_specs = layer.create_parameter_specs_recursively() - logging.info( - "%s.factorization_specs=%s", - cls, - jax.tree.map(lambda x: x.factorization, param_specs), - ) - layer_params = layer.initialize_parameters_recursively( - prng_key=jax.random.PRNGKey(123) - ) - logging.info( - "%s.params=%s", - cls, - [ - f"{path}={value.dtype}({value.shape})" - for path, value in flatten_items(layer_params) - ], - ) - if cls == StackedTransformerLayer: - stacked_layer_params = copy.deepcopy(layer_params) - else: - layer_params = _convert_from_stacked_params( - stacked_layer_params, target_stack_cfg=cfg.stack - ) - logging.info( - "Converted: %s.params=%s", - cls, - [ - f"{path}={value.dtype}({value.shape})" - for path, value in flatten_items(layer_params) - ], - ) + ) - def _loss(layer_params, data, mask, layer=layer): - layer_outputs, layer_output_collection = F( - layer, - inputs=dict( - data=data, self_attention_logit_biases=mask, target_segment_ids=None - ), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - loss, aux = layer_outputs - return loss, (aux, layer_output_collection) + def forward( + self, + data: Tensor, + **kwargs, + ) -> TransformerLayer.Output: + return self.pipeline(data, **kwargs) - value, grads = jax.value_and_grad(_loss, has_aux=True)( - layer_params, jnp.asarray(target), rand_mask - ) - loss, (aux, layer_output_collection) = value - layer_outputs = (loss, aux) - - # Note that we do not compare summaries across stack layer types because: - # (1) attention layers do not emit summaries yet; - # (2) pipelines emit per-microbatch summaries which have a different structure - # than summaries from other stack layers. - summaries = layer_output_collection.summaries - logging.info( - "layer_outputs=%s summaries=%s", - shapes(flatten_items(layer_outputs)), - shapes(flatten_items(summaries)), - ) - logging.info( - "global_grad_norm=%s, grads=%s", - optax.global_norm(grads), - shapes(flatten_items(grads)), - ) + # TODO(sneha): extend_step - optimizer = adafactor_optimizer( - learning_rate=0.1, - b1=0.9, - b2=0.98, - multiply_by_parameter_scale=False, - clipping_threshold=1.0, - eps=1e-2, - ) - opt_params = jax.tree.map( - lambda spec, p: OptParam( - value=p, - factorization_spec=spec.factorization, - weight_decay_scale=spec.weight_decay_scale, - ), - param_specs, - layer_params, - ) - opt_state = optimizer.init(opt_params) - logging.info("opt_state=%s", shapes(opt_state)) - updates, opt_state = optimizer.update(grads, opt_state, opt_params) - def rms_norm(x): - return jnp.sqrt(jnp.mean(x**2)) +# Adapted from jax source code to support regex. 
Reference: +# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 +# TODO(kelvin-zou): deprecated, keep it here to minimize distruption to the golden configs. +# Please use axlearn.common.utils.extended_checkpoint_policies instead. +def _save_and_offload_only_these_names_regex( + *, + names_which_can_be_saved: SavePattern, + names_which_can_be_offloaded: SavePattern, + offload_src: str, + offload_dst: str, +) -> OffloadPolicy: + return save_and_offload_only_these_names_regex( + names_which_can_be_saved=names_which_can_be_saved, + names_which_can_be_offloaded=names_which_can_be_offloaded, + offload_src=offload_src, + offload_dst=offload_dst, + ) + - if cls == StackedTransformerLayer: - update_norms = jax.tree.map(rms_norm, updates) - else: - update_norms = jax.vmap(lambda x, norm=rms_norm: jax.tree.map(norm, x))(updates) - logging.info( - "global_update_norm=%s update_norms=%s", - optax.global_norm(updates), - dict(utils.flatten_items(update_norms)), - ) +# Regex patterns for matching remat names +class RematRegexSavePatterns(enum.Enum): + QKV_PROJ = r".*[kqv]_proj" + O_PROJ = r".*o_proj" + CONTEXT = r".*context" + LINEAR1_X = r".*linear1_[01]" + LINEAR2_X = r".*linear2_[01]" + SELF_ATTENTION = ".*([qkvo]_proj|context)" + FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X]) + + +def build_remat_spec( + stack_cfg: Union[ + BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore + ], + save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value, + offload_pattern: SavePattern = None, + offload_dst: str = "pinned_host", +) -> Optional[RematSpec]: + """Configures how the Transformer or Conformer stack will save the linearization points. + + We try to save activations from the forward pass that are inefficient to recompute on the + backward pass. We choose the linearization points in the MultiHeadAttention layer, as that + demonstrated (empirically) the best throughput, allowing us to train with a batch size of 16 on + gpt2-10b with adamw and full sharding across 4 TPU v4 chips and a RepeatedTransformerLayer, + with 1.8x the step time of a stacked layer with a batch size of 8 and the same sharding config. + + For conformer model, we start from the same remat policy as language models. + TODO(zhiyunlu): investigate Conformer model's memory/step-time tradeoffs. Possibly we + need to save points in the LConv module. - if cls == StackedTransformerLayer: - for x in (layer_params, grads, updates): - x["stack"] = _recursive_stack(x["stack"]) - - if cls == RepeatedTransformerLayer: - for x in (layer_params, grads, updates): - x["stack"] = x["stack"]["repeat"] - - if cls == PipelinedTransformerLayer: - for x in (layer_params, grads, updates): - logging.info("x=%s", shapes(x)) - if cfg.stack.stage.klass == StackedTransformerLayer: - # First stack within each stage. - x["stack"]["pipeline"]["layer"] = _recursive_stack( - x["stack"]["pipeline"]["layer"], axis=1 - ) - logging.info("x=%s", shapes(x)) - elif cfg.stack.stage.klass == RepeatedTransformerLayer: - x["stack"]["pipeline"]["layer"] = x["stack"]["pipeline"]["layer"][ - "repeat" - ] - else: - raise NotImplementedError(cfg.stack.stage.klass) - - # Then reshape across stages. 
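# Sketch: the enum values above are plain regex strings, so custom remat save patterns can be
# composed by joining members (mirroring how FEED_FORWARD is derived).
from axlearn.common.attention import RematRegexSavePatterns

custom_save_pattern = "|".join(
    [RematRegexSavePatterns.QKV_PROJ.value, RematRegexSavePatterns.LINEAR1_X.value]
)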
- x["stack"] = jax.tree.map( - lambda x: x.reshape([num_layers] + list(x.shape[2:])), - x["stack"]["pipeline"]["layer"], - ) - - all_params.append(layer_params) - all_outputs.append(layer_outputs) - all_gradients.append(grads) - all_updates.append(updates) - - if cls == StackedTransformerLayer: - one_layer = layer.stack.layer0 - elif cls == RepeatedTransformerLayer: - one_layer = layer.stack.repeat.layer - else: - one_layer = None + Args: + stack_cfg: A transformer config. + save_pattern: Activation regex pattern to save in HBM. + offload_pattern: Activation regex pattern to offload to `offload_dst`. + offload_dst: Destination of remat checkptoing offloading. Relevant Maxtext example: + https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L230. - # pylint: disable=protected-access - if one_layer is not None: - logging.info( - "%s._remat_methods = %s", one_layer.path(), one_layer._remat_methods - ) - if remat_spec is not None: - self.assertSequenceEqual( - one_layer._remat_methods, ["forward"], msg=one_layer.path() - ) - else: - self.assertEmpty(one_layer._remat_methods, msg=one_layer.path()) - # pylint: enable=protected-access - - self.assertNestedAllClose(all_params[0], all_params[1]) - self.assertNestedAllClose(all_outputs[0], all_outputs[1]) - self.assertNestedAllClose(all_gradients[0], all_gradients[1]) - self.assertNestedAllClose(all_updates[0], all_updates[1]) - - @parameterized.parameters(StackedTransformerLayer, RepeatedTransformerLayer) - def test_stacked_decoding(self, stack_cls): - model_dim, num_heads = 6, 2 - cfg = stack_cls.default_config().set(num_layers=5, input_dim=model_dim) - layer_cfg = cfg.layer - layer_cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - layer_cfg.feed_forward.hidden_dim = model_dim * 4 - self._test_forward_vs_extend_step(cfg) - self._test_decoder_with_transformer(cfg) - - @parameterized.product( - outer_stack_cls=(StackedTransformerLayer, RepeatedTransformerLayer), - inner_stack_cls=(StackedTransformerLayer, RepeatedTransformerLayer), + Returns: + None (if no rematerialization is needed) or a RematSpec. + """ + # TODO(markblee): Switch to using isinstance everywhere. 
+ if stack_cfg.klass is PipelinedTransformerLayer: + return None + + policy = config_for_function(_save_and_offload_only_these_names_regex).set( + names_which_can_be_saved=save_pattern, + names_which_can_be_offloaded=offload_pattern, + offload_src="device", + offload_dst=offload_dst, ) - def test_nested_stacked_decoding(self, outer_stack_cls, inner_stack_cls): - model_dim, num_heads = 6, 2 - cfg = outer_stack_cls.default_config().set(num_layers=2, input_dim=model_dim) - cfg.layer = inner_stack_cls.default_config().set(num_layers=3) - layer_cfg = cfg.layer.layer - layer_cfg.self_attention.attention.set(num_heads=num_heads, causal=True) - layer_cfg.feed_forward.hidden_dim = model_dim * 4 - self._test_forward_vs_extend_step(cfg) - self._test_decoder_with_transformer(cfg) - - @parameterized.parameters(None, 0.0, 0.2, 1.0) - def test_stochastic_depth(self, rate): - batch_size, tgt_len = 10, 6 - num_dec_layers, model_dim, num_heads = 3, 16, 4 - model_dim = 16 - num_heads = 4 - cfg = StackedTransformerLayer.default_config().set( - name="test", - input_dim=model_dim, - num_layers=num_dec_layers, - peak_stochastic_depth_rate=rate, - ) - layer_cfg = cfg.layer - layer_cfg.self_attention.attention.set(num_heads=num_heads) - layer_cfg.feed_forward.hidden_dim = model_dim * 4 - - if rate is None or 0 <= rate < 1: - layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) - F( - layer, - inputs=dict(data=target), - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(0), - ) - else: - with self.assertRaises(ValueError): - cfg.instantiate(parent=None) - - @parameterized.product(is_training=(True, False)) - def test_stacked_transformer_with_seq_layer_cfgs(self, is_training): - batch_size = 2 - seq_len = 16 - input_dim = 4 - hidden_dim = 16 - num_layers = 4 - num_heads = 4 - - # Create a StackedTransformerLayer by specifying a sequence of layer configs. - cfg = StackedTransformerLayer.default_config().set(name="test") - cfg.input_dim = input_dim - cfg.num_layers = num_layers - transformer_cfg = TransformerLayer.default_config() - transformer_cfg.self_attention.attention.num_heads = num_heads - transformer_cfg.feed_forward.hidden_dim = hidden_dim - cfg.layer = (transformer_cfg,) * num_layers - layer: StackedTransformerLayer = cfg.instantiate(parent=None) - inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - outputs, _ = F( - layer, - is_training=is_training, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict(data=inputs), - ) - # Create a ref StackedTransformerLayer with repeating the default layer cfg. 
- ref_cfg = StackedTransformerLayer.default_config().set(name="test") - ref_cfg.input_dim = input_dim - ref_cfg.num_layers = num_layers - ref_cfg.layer.self_attention.attention.num_heads = num_heads - ref_cfg.layer.feed_forward.hidden_dim = hidden_dim - ref_layer: StackedTransformerLayer = ref_cfg.instantiate(parent=None) - ref_outputs, _ = F( - ref_layer, - is_training=is_training, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict(data=inputs), - ) - assert_allclose(outputs.data, ref_outputs.data) - assert_allclose(outputs.self_attention_probs, ref_outputs.self_attention_probs) - - @parameterized.product(is_training=(True, False)) - def test_stacked_transformer_with_non_uniform_layers(self, is_training): - """Tests that a custom StackedTransformerLayer can support non-uniform layers.""" - batch_size = 2 - seq_len = 16 - input_dim = 4 - hidden_dim = 16 - num_layers = 2 - - # Create a StackedTransformerLayer by specifying a sequence of non-uniform layer configs. - cfg = NonUniformStack.default_config().set(name="test") - cfg.input_dim = input_dim - cfg.num_layers = num_layers - cfg.layer = [] - for i in range(num_layers): - transformer_cfg = TransformerLayer.default_config() - # Different numbers of heads between the layers. - transformer_cfg.self_attention.attention.num_heads = 2 if i == 0 else 1 - transformer_cfg.feed_forward.hidden_dim = hidden_dim - cfg.layer.append(transformer_cfg) - layer: StackedTransformerLayer = cfg.instantiate(parent=None) - inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - outputs, _ = F( - layer, - is_training=is_training, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict(data=inputs, return_aux={"self_attention_kv_state"}), - ) - self.assertEqual( - BaseTransformerLayer.Output( - data=(2, 16, 4), - self_attention_probs=None, - self_attention_kv_state=KVState(k_proj=(2, 16, 1, 4), v_proj=(2, 16, 1, 4)), - cross_attention_probs=None, - ), - shapes(outputs), - ) - @parameterized.parameters( - [None, False], - [("data",), False], - [("data",), True], - [("data", "self_attention_kv_state"), True], + return RematSpec( + prevent_cse=stack_cfg.klass is StackedTransformerLayer, + # If we are running inside a jax.lax.scan (Repeated/Pipelined transformers + # or Repeated Conformers) we can enable common subexpression elimination optimizations. + policy=policy, ) - def test_repeated_layer_with_custom_carry(self, repeat_carry, precomputed_kv_state): - """Tests RepeatedTransformerLayer with customized `carry`.""" - batch_size = 1 - seq_len = 16 - input_dim = 4 - num_heads = 2 - head_dim = input_dim // num_heads - num_layers = 3 - - cfg = self._stack_config( - RepeatedTransformerLayer, - num_layers=num_layers, - model_dim=input_dim, - num_heads=num_heads, - dtype=jnp.float32, - remat_spec=None, - output_self_attention_kv_state=True, - ) - cfg.stack.repeat.carry = repeat_carry - cfg.stack.layer.remat_spec = build_remat_spec(cfg.stack) - if precomputed_kv_state: - kv_shape = (batch_size, seq_len, num_heads, head_dim) - kv_state = KVState( - k_proj=jax.random.normal(key=jax.random.PRNGKey(1), shape=kv_shape), - v_proj=jax.random.normal(key=jax.random.PRNGKey(2), shape=kv_shape), - ) - cfg.stack.layer.self_attention.attention.input_linear = QLinear.default_config() - expected_output = 1.8719857 - else: - kv_state = None - # carry=None and carry=("data",) are equivalent. 
- expected_output = 5.3901253 - - layer = cfg.instantiate(parent=None) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - inputs = jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim)) - outputs, _ = F( - layer, - is_training=True, - prng_key=jax.random.PRNGKey(123), - state=state, - inputs=dict( - data=inputs, - self_attention_kv_state=kv_state, - return_aux={"self_attention_kv_state"}, - ), - ) - self.assertNestedAllClose(expected_output, outputs[0]) - if precomputed_kv_state: - self.assertNestedAllClose(kv_state, outputs[1]["self_attention_kv_state"]) - else: - self.assertIsInstance(outputs[1]["self_attention_kv_state"], KVState) - def test_pipeline_return_aux(self): - batch_size, num_heads, seq_len, dim = 2, 3, 4, 6 - class DummyTransformerLayer(TransformerLayer): - def forward(self, data, **kwargs): - return TransformerLayer.Output( - data=data, - self_attention_probs=jnp.empty([batch_size, num_heads, seq_len, seq_len]), - self_attention_kv_state=KVState( - k_proj=jnp.empty([batch_size, seq_len, num_heads, dim]), - v_proj=jnp.empty([batch_size, seq_len, num_heads, dim]), - ), - ) +class AttentionLogitBiasLayer(BaseLayer): + """Base attention logit bias layer. - cfg: PipelinedTransformerLayer.Config = PipelinedTransformerLayer.default_config().set( - num_stages=2, - num_microbatches=2, - num_layers=2, - input_dim=dim, - layer=DummyTransformerLayer.default_config(), - ) - cfg.layer.self_attention.attention.set(num_heads=num_heads) - cfg.layer.feed_forward.hidden_dim = scaled_hidden_dim(4) - - with test_utils.bind_layer(cfg) as layer: - data = jax.random.uniform(layer.prng_key, shape=[2, 3, 4]) - out = layer(data, return_aux={"self_attention_kv_state"}) - self.assertNestedAllClose(data, out.data) - self.assertIsNone(out.self_attention_probs) - self.assertIsNotNone(out.self_attention_kv_state) - - @parameterized.parameters( - ([],), - (["self_attention"],), - (["feed_forward"],), - (["self_attention", "feed_forward"],), - ) - def test_initialize_parameters_recursively(self, prebuilt_layers: list[str]): - """Tests initialize_parameters_recursively with various prebuilt layers.""" - input_dim = 4 - num_heads = 2 - num_layers = 3 - - cfg = self._stack_config( - RepeatedTransformerLayer, - num_layers=num_layers, - model_dim=input_dim, - num_heads=num_heads, - dtype=jnp.float32, - remat_spec=None, - output_self_attention_kv_state=True, - ) - cfg.stack.layer.remat_spec = build_remat_spec(cfg.stack) - layer = cfg.instantiate(parent=None) - param_specs = layer.create_parameter_specs_recursively() - initialized_from_scratch = layer.initialize_parameters_recursively( - prng_key=jax.random.PRNGKey(123) - ) - jax.tree_util.tree_map_with_path( - lambda path, spec, param: self.assertEqual(param.shape, spec.shape, path), - param_specs, - initialized_from_scratch, - ) + The attention logit bias layer should have input_ids as input. + """ - def has_prebuilt_layers(path): - for prebuilt_layer in prebuilt_layers: - for part in path: - if prebuilt_layer == part.key: - return True - return False + def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: + """Produces attention logit biases. + + Args: + segment_ids: An integer Tensor of shape [batch_size, seq_len] with values in + [0, num_segments). Tokens are only allowed to attend to other tokens within the same + segment. segment_ids == 0 represents paddings. + positions: An Tensor of broadcastable shape to `input_ids` with values in [0, seq_len). 
+ This can be used to produce biases for packed inputs. + + Returns: + A float attention logit biases of shape [batch_size, 1, seq_len, seq_len] or + [batch_size, num_heads, seq_len, seq_len]. + Output[b,i,j] is -inf iff attention is disabled with query=input[b, i] and + key=input[b, j]. + """ + raise NotImplementedError(type(self)) - # ParameterSpec for a prebuilt param, None otherwise. - prebuilt_specs = jax.tree_util.tree_map_with_path( - lambda path, spec: spec if has_prebuilt_layers(path) else None, param_specs - ) - if prebuilt_layers: - self.assertNotEmpty(jax.tree_util.tree_leaves(prebuilt_specs)) - initialized_state = layer.initialize_parameters_recursively( - prng_key=jax.random.PRNGKey(123), prebuilt=prebuilt_specs - ) - def validate_initialized(path, spec, initialized, prebuilt): - if prebuilt is None: - self.assertEqual(spec.shape, initialized.shape, path) - else: - self.assertIsNone(initialized) +def compute_padding_biases(input_ids: Tensor, *, pad_token_id: Optional[int]) -> Tensor: + """Compute the logits bias to disable attention to/from paddings. - jax.tree_util.tree_map_with_path( - validate_initialized, param_specs, initialized_state, prebuilt_specs - ) + Args: + input_ids: A Tensor of shape [batch_size, seq_len]. + pad_token_id: An int representing the padded token ID or None. + Returns: + A float logit biases of shape [batch_size, 1, seq_len, seq_len]. + """ + if pad_token_id is None: + batch_size, seq_len = input_ids.shape + return jnp.zeros([batch_size, 1, seq_len, seq_len]) + padding_bias = (input_ids == pad_token_id) * NEG_INF + return padding_bias[:, None, None, :] + padding_bias[:, None, :, None] -class ConfigHelperTest(TestCase): - """Tests config utils.""" - - @parameterized.product( - self_attention_input_linear_cfg=( - QKVLinear.default_config(), - FusedQKVLinear.default_config(), - RoFormerQKVLinear.default_config().set(input_linear=FusedQKVLinear.default_config()), - ), - cross_attention_cfg=(None, TransformerAttentionLayer.default_config()), - batch_axis_names=("data", ("replica", "data", "fsdp")), - fsdp_axis_names=("fsdp",), - tp_axis_names=("model",), - seq_axis_names=("seq",), - ) - def test_set_double_shard_weights_config( - self, - self_attention_input_linear_cfg, - cross_attention_cfg, - batch_axis_names, - fsdp_axis_names, - tp_axis_names, - seq_axis_names, - ): - cfg: TransformerLayer.Config = TransformerLayer.default_config().set( - cross_attention=cross_attention_cfg - ) - cfg.self_attention.attention.input_linear = self_attention_input_linear_cfg - set_double_shard_weights_config( - cfg, - batch_axis_names=batch_axis_names, - fsdp_axis_names=fsdp_axis_names, - tp_axis_names=tp_axis_names, - seq_axis_names=seq_axis_names, - ) - ff_layer = cfg.feed_forward - self.assertSequenceEqual( - ff_layer.linear1.param_partition_spec, (fsdp_axis_names, tp_axis_names) - ) - self.assertSequenceEqual( - ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names) - ) - self.assertSequenceEqual( - ff_layer.linear1.output_partition_spec, - (batch_axis_names, seq_axis_names, tp_axis_names), - ) - self.assertSequenceEqual( - ff_layer.linear2.output_partition_spec, - (batch_axis_names, seq_axis_names, tp_axis_names), - ) +class CausalAttentionLogitBiasLayer(AttentionLogitBiasLayer): + """Causal attention logit bias layer.""" - self_atten = cfg.self_attention.attention - input_linear = self_atten.input_linear - if isinstance(self_attention_input_linear_cfg, RoFormerQKVLinear.Config): - input_linear = input_linear.input_linear - # Shard weights. 
- self.assertSequenceEqual( - input_linear.layer.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) - self.assertSequenceEqual( - self_atten.output_linear.param_partition_spec, (fsdp_axis_names, tp_axis_names, None) + def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: + """Refer to AttentionLogitBiasLayer.forward for docstring.""" + # Note: padding tokens are not explicitly masked. + causal_bias = (positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF + return apply_attention_logit_biases( + causal_bias, make_segment_mask(source_segments=segment_ids, target_segments=segment_ids) ) - if cross_attention_cfg is None: - self.assertIsNone(cfg.cross_attention) - else: - cross_atten = cfg.cross_attention.attention - # Shard weights. - self.assertSequenceEqual( - cross_atten.input_linear.layer.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) - self.assertSequenceEqual( - cross_atten.output_linear.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) - @parameterized.product( - self_attention_input_linear_cfg=( - QKVLinear.default_config(), - FusedQKVLinear.default_config(), - ), - cross_attention_cfg=(None, TransformerAttentionLayer.default_config()), - batch_axis_names=("data", ("replica", "data", "fsdp")), - fsdp_axis_names=("fsdp",), - tp_axis_names=("model",), - seq_axis_names=("seq",), - ) - def test_set_double_shard_weights_config_for_list_of_configs( - self, - self_attention_input_linear_cfg, - cross_attention_cfg, - batch_axis_names, - fsdp_axis_names, - tp_axis_names, - seq_axis_names, - ): - cfg_layer: TransformerLayer.Config = TransformerLayer.default_config().set( - cross_attention=cross_attention_cfg - ) - cfg_layer.self_attention.attention.input_linear = self_attention_input_linear_cfg - cfg_layers = [cfg_layer, cfg_layer] - set_double_shard_weights_config( - cfg_layers, - batch_axis_names=batch_axis_names, - fsdp_axis_names=fsdp_axis_names, - tp_axis_names=tp_axis_names, - seq_axis_names=seq_axis_names, - ) +class FullAttentionLogitBiasLayer(AttentionLogitBiasLayer): + """Full attention logit bias layer.""" - for cfg in cfg_layers: - ff_layer = cfg.feed_forward - self.assertSequenceEqual( - ff_layer.linear1.param_partition_spec, (fsdp_axis_names, tp_axis_names) - ) - self.assertSequenceEqual( - ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names) - ) - self.assertSequenceEqual( - ff_layer.linear1.output_partition_spec, - (batch_axis_names, seq_axis_names, tp_axis_names), - ) - self.assertSequenceEqual( - ff_layer.linear2.output_partition_spec, - (batch_axis_names, seq_axis_names, tp_axis_names), - ) + def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: + """Refer to AttentionLogitBiasLayer.forward for docstring.""" + del positions + return make_segment_mask(source_segments=segment_ids, target_segments=segment_ids) - self_atten = cfg.self_attention.attention - # Shard weights. - self.assertSequenceEqual( - self_atten.input_linear.layer.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) - self.assertSequenceEqual( - self_atten.output_linear.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) - if cross_attention_cfg is None: - self.assertIsNone(cfg.cross_attention) - else: - cross_atten = cfg.self_attention.attention - # Shard weights. 
- self.assertSequenceEqual( - cross_atten.input_linear.layer.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) - self.assertSequenceEqual( - cross_atten.output_linear.param_partition_spec, - (fsdp_axis_names, tp_axis_names, None), - ) +def alibi_get_slopes(num_heads: int) -> list: + """Get the slopes for different attention heads defined in ALiBi paper. + This is a direct copy from ALiBi codebase. + Ref: + https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca/fairseq/models/transformer.py#L742-L752 -class PositionalEmbeddingTest(TestCase): - """Tests PositionalEmbedding.""" + Args: + num_heads: An integer for the number of attention heads. - def test_learned_positional_embedding_1d(self): - """ - Simple test that LearnedPositionalEmbedding returns expected outputs for a 1d sequence. - """ - positions = np.arange(10) - dim = 8 - pos_emb_cfg = LearnedPositionalEmbedding.default_config().set( - name="test", - dim=dim, - shape=(len(positions),), - ) - pos_emb = pos_emb_cfg.instantiate(parent=None) + Returns: + A tensor of slopes with shape of [num_heads]. Each value represents + a slope for one attention head. + """ - state = pos_emb.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + def get_slopes_power_of_2(n: int) -> list: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] - outputs, _ = F( - pos_emb, - is_training=True, - prng_key=jax.random.PRNGKey(1), - state=state, - inputs={"positions": positions}, + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + alibi_get_slopes(2 * closest_power_of_2)[0::2][: num_heads - closest_power_of_2] ) - context = InvocationContext( - name="root", - parent=None, - module=pos_emb, - state=state, - output_collection=new_output_collection(), - is_training=True, - prng_key=jax.random.PRNGKey(2), - ) - with set_current_context(context): - embeddings_tensor = pos_emb.embeddings() - assert embeddings_tensor.shape == (len(positions), dim) - for position in positions: - assert_allclose(outputs[position], embeddings_tensor[position]) +class ALiBiAttentionLogitBiasLayer(CausalAttentionLogitBiasLayer): + """attention logit bias layer in ALiBi. + Ref: https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca + """ -@pytest.mark.parametrize("x, output", [(300, 512), (127.1, 128), (128, 128), (0.1, 2)]) -def test_next_power_of_two(x, output): - assert _next_power_of_two(x) == output + @config_class + class Config(CausalAttentionLogitBiasLayer.Config): + """Configures ALiBiAttentionLogitBiasLayer.""" + num_heads: Required[int] = REQUIRED -class BottleNeckAdapterTransformerLayerTest(TestCase): - """Tests BottleNeckAdapterTransformerLayer.""" + def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: + """Produces an attention logit biases of shape [batch_size, num_heads, seq_len, seq_len]. - @parameterized.parameters( - {"bottleneck_ratio": 0.1}, - {"bottleneck_ratio": 0.5}, - {"bottleneck_ratio": 1.0}, - ) - def test_forward(self, bottleneck_ratio): - batch_size, tgt_len, model_dim, num_heads = 2, 3, 32, 1 + The ALiBi bias is defined as below: + 1. Create a lower triangle matrix with the value of: + bias = [-(i-1), ..., -2, -1, 0] * slopes + 2. Apply the casual biases. 
+ bias = apply_apply_attention_logit_biases(bias, causal_bias) - layer_cfg = TransformerLayer.default_config().set(name="layer", input_dim=model_dim) - layer_cfg.self_attention.attention.set(num_heads=num_heads) - layer_cfg.feed_forward.hidden_dim = model_dim + Refer to AttentionLogitBiasLayer.forward for docstring. + """ + cfg = self.config + slopes = jnp.asarray(alibi_get_slopes(cfg.num_heads)) + # Create the lower triangle matrix w/ value [-(i-1), ..., -2, -1, 0] for each segment. + alibi_bias = jnp.expand_dims(positions, [1]) - jnp.expand_dims(positions, [2]) + # Add head dim. + alibi_bias = jnp.expand_dims(alibi_bias, [1]) + # Multiply w/ the slopes. + alibi_bias = alibi_bias * jnp.expand_dims(slopes, [0, 2, 3]) + bias = super().forward(segment_ids=segment_ids, positions=positions) + # Combine the biases. + return apply_attention_logit_biases(alibi_bias, bias) + + +class SymmetricALiBiAttentionLogitBiasLayer(FullAttentionLogitBiasLayer): + """Symmetric full attention version of ALiBiAttentionLogitBiasLayer. + + Main implementation differences between this one and `ALiBiAttentionLogitBiasLayer` (above): + 1. Muliplies alibi slopes by -1. + 2. Computes absolute value of relative positions. + 3. Multiplies results of steps 1 and 2 to get symmetric bias matrix. + + Originally proposed here by an author of the ALiBi paper: + https://github.com/ofirpress/attention_with_linear_biases/issues/5 + """ - adapter_cfg = BottleNeckAdapterTransformerLayer.default_config().set( - input_dim=model_dim, name="adapter", bottleneck_ratio=bottleneck_ratio - ) - adapter_cfg.layer = layer_cfg + @config_class + class Config(FullAttentionLogitBiasLayer.Config): + """Configures SymmetricALiBiAttentionLogitBiasLayer.""" - adapter = adapter_cfg.instantiate(parent=None) + num_heads: Required[int] = REQUIRED - state = adapter.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: + cfg = self.config - data = jax.random.normal(jax.random.PRNGKey(1), [batch_size, tgt_len, model_dim]) - self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) + slopes = -1 * jnp.asarray(alibi_get_slopes(cfg.num_heads)) - outputs, _ = F( - adapter, - is_training=True, - prng_key=jax.random.PRNGKey(2), - state=state, - inputs=dict( - data=data, - self_attention_logit_biases=self_attention_logit_biases, - ), - ) + # Create the lower triangle matrix w/ value [-(i-1), ..., -2, -1, 0] for each segment. + alibi_bias = jnp.abs(positions[:, jnp.newaxis, :] - positions[:, :, jnp.newaxis]) - # Output shape is left unchanged. - assert outputs.data.shape == (2, 3, 32) + # Add head dim. + alibi_bias = alibi_bias[:, jnp.newaxis, :, :] + # Multiply w/ the slopes. + alibi_bias = alibi_bias * jnp.expand_dims(slopes, [0, 2, 3]) -if __name__ == "__main__": - with utils.numeric_checks(True): - absltest.main() + bias = super().forward(segment_ids=segment_ids, positions=positions) + # Combine the biases. 
+ return apply_attention_logit_biases(alibi_bias, bias) From 7aa6dd5dda3e9e977e46d5d2eab7872dea1cf2d4 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 18:54:34 -0800 Subject: [PATCH 07/12] Update dit.py --- axlearn/common/dit.py | 71 ++++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/axlearn/common/dit.py b/axlearn/common/dit.py index 711f0457c..a388d1920 100644 --- a/axlearn/common/dit.py +++ b/axlearn/common/dit.py @@ -13,6 +13,8 @@ from typing import Optional, Union +import chex +import einops import jax import jax.numpy as jnp @@ -31,7 +33,21 @@ def modulate(*, x, shift, scale): - return x * (1 + jnp.expand_dims(scale, 1)) + jnp.expand_dims(shift, 1) + """Modulates the input x tensor. + + Note: shift and scale must have the same shape. + + Args: + x: input tensor with shape [batch_size, num_length, input_dim]. + shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim]. + scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim]. + + Returns: + A tensor with shape [batch_size, num_length, input_dim]. + """ + chex.assert_equal_shape((shift, scale)) + chex.assert_equal_rank((x, shift, scale)) + return x * (1 + scale) + shift class TimeStepEmbedding(BaseLayer): @@ -211,15 +227,18 @@ def forward(self, input: Tensor) -> Tensor: """Generate the parameters for modulation. Args: - input: A tensor with shape [batch_size, ..., dim]. + input: A tensor with shape [batch_size, dim] or [batch_size, num_length, dim]. Returns: A list of tensors with length num_outputs. - Each tensor has shape [batch_size, ..., dim]. + Each tensor has shape [batch_size, 1|num_length, dim]. """ cfg = self.config x = get_activation_fn(cfg.activation)(input) output = self.linear(x) + assert output.ndim in (2, 3) + if output.ndim == 2: + output = einops.rearrange(output, "b d -> b 1 d") output = jnp.split(output, cfg.num_outputs, axis=-1) return output @@ -292,14 +311,16 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor) Args: input: input tensor with shape [batch_size, num_length, input_dim]. - shift: shifting the norm tensor with shape [batch_size, input_dim]. - scale: scaling the norm tensor with shape [batch_size, input_dim]. + shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim]. + scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim]. gate: applying before the residual addition with shape - [batch_size, input_dim]. + [batch_size, 1|num_length, input_dim]. Returns: A tensor with shape [batch_size, num_length, input_dim]. """ + chex.assert_equal_shape((shift, scale, gate)) + chex.assert_equal_rank((input, shift)) cfg = self.config remat_pt1 = "linear1_0" remat_pt2 = "linear2" @@ -325,7 +346,7 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor) x = self.postnorm(x) x = self.dropout2(x) - x = x * jnp.expand_dims(gate, 1) + x = x * gate x += input return x @@ -390,12 +411,12 @@ def forward( Args: input: input tensor with shape [batch_size, num_length, target_dim]. - shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and - scale should be provided. - scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and - shift should be provided. + shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length, + target_dim] and scale should be provided. 
+ scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length, + target_dim] and shift should be provided. gate: If provided, applying before the residual addition with shape - [batch_size, target_dim]. + [batch_size, 1|num_length, target_dim]. attention_logit_biases: Optional Tensor representing the self attention biases. Returns: @@ -429,7 +450,7 @@ def forward( x = self.postnorm(x) if gate is not None: - x = x * jnp.expand_dims(gate, 1) + x = x * gate output = input + x return output @@ -466,12 +487,12 @@ def extend_step( results of previous attentions, and index used for fast decoding. Contains "attention" cached states. target: target tensor with shape [batch_size, step_length, target_dim]. - shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and - scale should be provided. - scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and - shift should be provided. + shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length, + target_dim] and scale should be provided. + scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length, + target_dim] and shift should be provided. gate: If provided, applying before the residual addition with shape - [batch_size, target_dim]. + [batch_size, 1|num_length, target_dim]. Returns: A tuple (cached_states, output): @@ -507,7 +528,7 @@ def extend_step( x = self.postnorm(x) if gate is not None: - x = x * jnp.expand_dims(gate, 1) + x = x * gate output = target + x return dict(attention=attn_states), output @@ -545,8 +566,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor: Args: input: input tensor with shape [batch_size, num_length, input_dim]. - condition: tensor with shape [batch_size, input_dim] for generating - layer norm shift, scale, and gate. + condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length, + input_dim] for generating layer norm shift, scale, and gate. Returns: A tensor with shape [batch_size, num_length, input_dim]. @@ -587,8 +608,8 @@ def extend_step( results of previous attentions, and index used for fast decoding. Contains "attention" cached states. target: target tensor with shape [batch_size, step_length, input_dim]. - condition: tensor with shape [batch_size, input_dim] for generating - layer norm shift, scale, and gate. + condition: tensor with shape [batch_size, input_dim] or [batch_size, step_length, + input_dim] for generating layer norm shift, scale, and gate. Returns: A tuple (cached_states, output): @@ -642,8 +663,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor: Args: input: input tensor with shape [batch_size, num_length, input_dim]. - condition: tensor with shape [batch_size, input_dim] for generating - layer norm shift and scale. + condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length, + input_dim] for generating layer norm shift and scale. Returns: A tensor with shape [batch_size, num_length, output_dim]. 
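To make the dit.py change above concrete: `modulate` now accepts `shift`/`scale` (and, in the block layers, `gate`) that are either per-sequence, with shape [batch_size, 1, input_dim], or per-token, with shape [batch_size, num_length, input_dim], and `AdaptiveLayerNormModulation` rearranges a 2-D condition to [batch_size, 1, dim] so both cases broadcast the same way. The sketch below is illustrative only, not axlearn code: the `modulate` body mirrors the patched function, while the shapes and variable names are made up for the example.

import chex
import jax.numpy as jnp


def modulate(*, x, shift, scale):
    # Mirrors the patched axlearn.common.dit.modulate: shift/scale must share a shape
    # and have the same rank as x, so broadcasting happens only over the length axis.
    chex.assert_equal_shape((shift, scale))
    chex.assert_equal_rank((x, shift, scale))
    return x * (1 + scale) + shift


batch_size, num_length, dim = 2, 16, 8
x = jnp.ones((batch_size, num_length, dim))

# Per-sequence conditioning (the original DiT behavior): one (shift, scale) per example,
# broadcast over all tokens.
shift = jnp.zeros((batch_size, 1, dim))
scale = 0.5 * jnp.ones((batch_size, 1, dim))
assert modulate(x=x, shift=shift, scale=scale).shape == (batch_size, num_length, dim)

# Per-token conditioning (e.g. masked vs. non-masked positions): one (shift, scale) per token.
shift = jnp.zeros((batch_size, num_length, dim))
scale = 0.5 * jnp.ones((batch_size, num_length, dim))
assert modulate(x=x, shift=shift, scale=scale).shape == (batch_size, num_length, dim)

# A rank-2 condition such as [batch_size, dim] now fails the rank check instead of being
# expanded inside modulate; callers expand it first (as AdaptiveLayerNormModulation does
# with einops.rearrange("b d -> b 1 d")) before passing it in.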
From 0be3a830603153ca925a46bf38afb6a75b3be4c6 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Tue, 14 Jan 2025 19:16:18 -0800 Subject: [PATCH 08/12] Update axlearn/common/attention.py Co-authored-by: Mark Lee --- axlearn/common/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index b2cc40f61..71f2dcb3f 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1231,7 +1231,7 @@ def forward( Args: positions: A tensor representing the token position IDs. The shape is [batch_size, seq_len]. - max_seq_len: Max length of sequence, required if positions is not provided + max_seq_len: Max length of sequence, required if positions is not provided. Returns: Rotary Positional Embedding. Shape is [seq_len, dim]. From 82a29f743381550bc04cb972426387e8a8dbb1eb Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Fri, 17 Jan 2025 09:47:45 -0800 Subject: [PATCH 09/12] respond to comments. Co-authored-by: Ruoming Pang --- axlearn/common/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 71f2dcb3f..e77033a28 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1223,7 +1223,7 @@ def default_query_positions(self, max_seq_len: int) -> Tensor: return jnp.arange(max_seq_len)[None] # [batch_size=1, max_seq_len]. def forward( - self, positions: Optional[Tensor] = None, max_seq_len: Optional[int] = None + self, *, positions: Optional[Tensor] = None, max_seq_len: Optional[int] = None ) -> Tensor: """ TODO(bwzhang): 1. verify the performance under float32. From 2bb2a2bb265e6be953ebc5cece6785efd6ba41ca Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Fri, 17 Jan 2025 09:53:32 -0800 Subject: [PATCH 10/12] Update attention.py --- axlearn/common/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index e77033a28..33b8ec575 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1231,7 +1231,8 @@ def forward( Args: positions: A tensor representing the token position IDs. The shape is [batch_size, seq_len]. - max_seq_len: Max length of sequence, required if positions is not provided. + max_seq_len: Max length of sequence, required if positions is not provided, + ignored if positions is provided. Returns: Rotary Positional Embedding. Shape is [seq_len, dim]. From 42552344e69bc575fcd9463ebb39fb7dbdc8de57 Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Fri, 24 Jan 2025 15:30:04 -0800 Subject: [PATCH 11/12] Update attention.py --- axlearn/common/attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 8c4a8b7b7..d199c0984 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1238,9 +1238,14 @@ def forward( Rotary Positional Embedding. Shape is [seq_len, dim]. Raises: - ValueError: If positions is None and max_seq_len is None. + ValueError: If positions is None and max_seq_len is None, or they both exist + but do not match. """ cfg = self.config + if positions is not None and max_seq_len is not None: + if max_seq_len != positions.shape[-1]: + raise ValueError("Both `positions` and `max_seq_len` are provided and they " + "do not match. 
You only need to provide one of them.") if positions is None: if max_seq_len is None: raise ValueError( From 20a1b4c3d35a5d8831ef8f92ac654506a481c1cc Mon Sep 17 00:00:00 2001 From: Firenze11 Date: Fri, 24 Jan 2025 15:56:30 -0800 Subject: [PATCH 12/12] Update attention.py --- axlearn/common/attention.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 41ae3a5ba..e00e58de1 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1245,8 +1245,10 @@ def forward( cfg = self.config if positions is not None and max_seq_len is not None: if max_seq_len != positions.shape[-1]: - raise ValueError("Both `positions` and `max_seq_len` are provided and they " - "do not match. You only need to provide one of them.") + raise ValueError( + "Both `positions` and `max_seq_len` are provided and they " + "do not match. You only need to provide one of them." + ) if positions is None: if max_seq_len is None: raise ValueError(