
Allow external positions to be input in RoPE embedding layer #926

Merged
14 commits merged on Jan 27, 2025
38 changes: 33 additions & 5 deletions axlearn/common/attention.py
@@ -1217,18 +1217,45 @@ class Config(BaseLayer.Config):
dim: Required[int] = REQUIRED # The dimensionality of the positional embedding.
theta: float = 10000.0 # The scale of base frequency.

def forward(self, positions: Tensor) -> Tensor:
def default_query_positions(self, max_seq_len: int) -> Tensor:
Contributor

Suggested change
def default_query_positions(self, max_seq_len: int) -> Tensor:
def _default_query_positions(self, max_seq_len: int) -> Tensor:

Users should pass max_seq_len rather than calling this method publicly.

Contributor Author

There might be situations where we want to access the default query positions from the outside, such as getting the default positions and computing new positions based on them before passing them to `forward`. Therefore we want to keep this method public.

Contributor

Do we expect subclasses to override this method?

If not, it would be more readable for callers to call jnp.arange directly, given that the implementation is only one line.

Contributor Author

Yes, we do expect subclasses to override this method. One example: we might want to specify a rotation start index and a rotation end index for this embedding class, and the default embedding positions would then be different.

Contributor

Thanks for the explanation. Given this, it's reasonable to consolidate the position computation logic in this class.

"""Compute default `positions` value to be inputed into forward when `positions` is
not provided to the corresponding QKVLinear class such as `RoFormerQKVLinear`
"""
return jnp.arange(max_seq_len)[None] # [batch_size=1, max_seq_len].
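
As the discussion above notes, keeping this method public lets callers derive custom positions from the defaults before calling `forward`. A minimal sketch of that pattern, not part of the diff; the layer name, `dim`, and the `prefix_len` offset are illustrative assumptions:

from axlearn.common import attention

# Instantiate a standalone RoPE layer (mirrors the pattern used in attention_test.py).
rope_layer = (
    attention.RoFormerSinusoidalPositionalEmbedding.default_config()
    .set(name="rope", dim=16)
    .instantiate(parent=None)
)

# Start from the public defaults ([batch_size=1, max_seq_len]) and shift them,
# e.g. to account for a hypothetical prefix of already-processed tokens.
prefix_len = 4
positions = rope_layer.default_query_positions(max_seq_len=8) + prefix_len
sin_pos_emb = rope_layer.forward(positions=positions)  # Per the docstring: [batch_size, seq_len, dim].
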

def forward(
self, *, positions: Optional[Tensor] = None, max_seq_len: Optional[int] = None
) -> Tensor:
"""
TODO(bwzhang): 1. verify the performance under float32.

Args:
positions: An optional tensor of token position IDs with shape [batch_size, seq_len].
max_seq_len: The maximum sequence length. Required if `positions` is not provided;
ignored otherwise.

Returns:
The rotary positional embedding, with shape [batch_size, seq_len, dim].

Raises:
ValueError: If both `positions` and `max_seq_len` are None, or if both are provided
but do not match.
"""
cfg = self.config
if positions is not None and max_seq_len is not None:
if max_seq_len != positions.shape[-1]:
raise ValueError(
"Both `positions` and `max_seq_len` are provided and they "
"do not match. You only need to provide one of them."
)
if positions is None:
if max_seq_len is None:
raise ValueError(
"Must provide `max_seq_len` for computing default query positions if "
"`positions` is None."
)
positions = self.default_query_positions(max_seq_len)
return _rotary_sinusoidal_positional_embeddings(
positions=positions, dim=cfg.dim, theta=cfg.theta
)
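
For reference, a small self-contained sketch of the two calling conventions this `forward` now accepts (layer name and `dim` are illustrative assumptions):

import jax.numpy as jnp
from axlearn.common import attention

rope_layer = (
    attention.RoFormerSinusoidalPositionalEmbedding.default_config()
    .set(name="rope", dim=16)
    .instantiate(parent=None)
)

# 1) No positions: the layer falls back to default_query_positions(max_seq_len).
emb_default = rope_layer.forward(max_seq_len=8)

# 2) Explicit positions, e.g. externally computed position IDs of shape [batch_size, seq_len].
positions = jnp.arange(8)[None]
emb_explicit = rope_layer.forward(positions=positions)

# Passing neither argument raises ValueError; passing both is only accepted when
# max_seq_len == positions.shape[-1].
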
@@ -1301,7 +1328,7 @@ class RoFormerQKVLinear(BaseQKVLinear):
class Config(BaseQKVLinear.Config):
"""Configures RoFormerQKVLinear."""

rope_pos_emb_layer: InstantiableConfig = (
rope_pos_emb_layer: RoFormerSinusoidalPositionalEmbedding.Config = (
RoFormerSinusoidalPositionalEmbedding.default_config()
)
input_linear: BaseQKVLinear.Config = QKVLinear.default_config()
@@ -1343,9 +1370,10 @@ def forward(
cfg = self.config
# Query should have shape of [batch_size, seq_len, num_heads, per_head_dim].
query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state)
if query_positions is None:
query_positions = jnp.arange(query.shape[1])[None]
sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(query_positions).astype(query.dtype)
seq_len = query.shape[1]
sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(
positions=query_positions, max_seq_len=seq_len
).astype(query.dtype)
# sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim]
sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2)

121 changes: 115 additions & 6 deletions axlearn/common/attention_test.py
@@ -812,9 +812,73 @@ def test_rope_emb(self, batch_size, max_len, dim):
.set(name="test_rope_emb", dim=dim)
.instantiate(parent=None)
)
test_output = test_layer.forward(positions)
test_output = test_layer.forward(positions=positions)
np.testing.assert_allclose(np.expand_dims(ref_output, 0), test_output, atol=5e-7)

@parameterized.parameters(
(None, True),
(10, False),
)
def test_rope_emb_no_pos(self, max_len, should_raise):
test_layer = (
attention.RoFormerSinusoidalPositionalEmbedding.default_config()
.set(name="test_rope_emb", dim=10)
.instantiate(parent=None)
)
if should_raise:
with self.assertRaises(ValueError):
test_layer.forward(max_seq_len=max_len)
else:
test_layer.forward(max_seq_len=max_len)

@parameterized.parameters(
(2, 10, 32, 4),
)
def test_default_rope_emb(self, batch_size, max_len, dim, num_heads):
rng = np.random.default_rng(seed=123)
query = jnp.asarray(rng.random([batch_size, max_len, dim]))
key = jnp.asarray(rng.random([batch_size, max_len, dim]))
value = jnp.asarray(rng.random([batch_size, max_len, dim]))
per_head_dim = dim // num_heads

emb_layer_cfg = attention.RoFormerSinusoidalPositionalEmbedding.default_config().set(
dim=per_head_dim,
)
linear_layer_cfg = attention.RoFormerQKVLinear.default_config().set(
query_dim=dim,
key_dim=dim,
value_dim=dim,
num_heads=num_heads,
per_head_dim=per_head_dim,
rope_pos_emb_layer=emb_layer_cfg,
rotary_value=False,
name="test_rope_linear",
)
rope_linear_layer = linear_layer_cfg.instantiate(parent=None)
state = rope_linear_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

rope_emb_layer = emb_layer_cfg.set(name="test_rope_emb").instantiate(parent=None)
default_positions = rope_emb_layer.default_query_positions(max_len)

input_dict = dict(query=query, key=key, value=value)

layer_outputs_no_position, _ = F(
rope_linear_layer,
inputs=input_dict,
state=state,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
layer_outputs, _ = F(
rope_linear_layer,
inputs=dict(**input_dict, query_positions=default_positions),
state=state,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
# Tests that RoFormerQKVLinear falls back to the default positions from RoFormerSinusoidalPositionalEmbedding.
np.testing.assert_allclose(layer_outputs_no_position, layer_outputs, atol=1e-5)

def _compare_against_roformer_attention(
self,
ref,
@@ -887,7 +951,7 @@ def test_rope_self_attention(self, rotary_value: bool, override_positions: bool)
if override_positions
else jnp.expand_dims(jnp.arange(max_sequence_length), 0)
)
ref_rope_emb = as_torch_tensor(rope_emb_layer.forward(positions)).unsqueeze(1)
ref_rope_emb = as_torch_tensor(rope_emb_layer.forward(positions=positions)).unsqueeze(1)
layer = attention.TransformerAttentionLayer.default_config().set(
source_dim=model_dim,
target_dim=model_dim,
@@ -1075,9 +1139,13 @@ def test_against_llama_for_precompute_freqs_cis(self, theta: float):
attention.QKVLinear.default_config(),
attention.GroupedQKVLinear.default_config(),
),
has_query_positions=(True, False),
)
def test_roformer_qkv_linear(
self, dtype: jnp.dtype, input_linear: attention.BaseQKVLinear.Config
self,
dtype: jnp.dtype,
input_linear: attention.BaseQKVLinear.Config,
has_query_positions: bool,
):
seq_len = 6
batch_size = 2
@@ -1116,6 +1184,14 @@ def test_roformer_qkv_linear(
jax.random.PRNGKey(0)
)
input_batch = dict(query=query, key=key, value=value)
if has_query_positions:
input_batch["query_positions"] = jax.random.permutation(
jax.random.PRNGKey(1),
jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0),
axis=1,
independent=True,
)

layer_outputs, _ = F(
roformer_qkv_linear,
inputs=utils.cast_floats(input_batch, to_dtype=dtype),
@@ -2168,17 +2244,24 @@ def test_data_types(self, dtype: jnp.dtype, per_dim_scale: Optional[PerDimScale.
lambda query_len, kv_len: _random_mask(jax.random.PRNGKey(1), query_len, kv_len),
),
kv_length_multiplier=(0.5, 1, 2),
has_query_positions=(False, True),
)
def test_causal(
self,
base_cfg: attention.MultiheadAttention.Config,
attention_logit_biases_fn: Callable[[int, int], Tensor],
kv_length_multiplier: float,
has_query_positions: bool,
):
"""Tests that base_cfg(causal=True) is equivalent to applying a causal mask."""
if kv_length_multiplier != 1 and isinstance(
base_cfg.input_linear,
(FusedGroupedQKVLinear.Config, RoFormerQKVLinear.Config, FusedQKVLinear.Config),
if (
has_query_positions
and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config)
or kv_length_multiplier != 1
and isinstance(
base_cfg.input_linear,
(FusedGroupedQKVLinear.Config, RoFormerQKVLinear.Config, FusedQKVLinear.Config),
)
):
pytest.skip(reason="Incompatible test setting that does not need testing.")

Expand All @@ -2202,6 +2285,14 @@ def test_causal(
query = jnp.zeros([batch_size, query_len, model_dim], dtype=jnp.float32)
outputs = []

if has_query_positions:
query_positions = jax.random.permutation(
jax.random.PRNGKey(1),
jnp.arange(query_len)[None, :].repeat(batch_size, axis=0),
axis=1,
independent=True,
)

for layer in (ref_layer, test_layer):
inputs = dict(query=query)
kv_len = int(kv_length_multiplier * query_len)
Expand All @@ -2223,6 +2314,8 @@ def test_causal(
attention_logit_biases, causal_biases
)
inputs["attention_logit_biases"] = attention_logit_biases
if has_query_positions:
inputs["query_positions"] = query_positions

layer_outputs, _ = F(
layer,
@@ -2261,16 +2354,21 @@ def test_causal(
lambda seq_len: None,
lambda seq_len: _random_mask(jax.random.PRNGKey(1), seq_len, seq_len),
),
has_query_positions=(False, True),
)
def test_sliding_window(
self,
base_cfg: attention.MultiheadAttention.Config,
attention_logit_biases_fn: Callable[[int], Tensor],
has_query_positions: bool,
):
"""
Tests that base_cfg with sliding window causal mask fns is equivalent to applying a
causal sliding window mask.
"""
if has_query_positions and not isinstance(base_cfg.input_linear, RoFormerQKVLinear.Config):
return

model_dim = 16
num_heads = 4
ref_cfg = base_cfg.clone(
@@ -2296,6 +2394,15 @@ def test_sliding_window(
batch_size, seq_len = 2, 4
query = jnp.zeros([batch_size, seq_len, model_dim], dtype=jnp.float32)
outputs = []

if has_query_positions:
query_positions = jax.random.permutation(
jax.random.PRNGKey(1),
jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0),
axis=1,
independent=True,
)

for layer in (ref_layer, test_layer):
attention_logit_biases = attention_logit_biases_fn(seq_len)
if layer is ref_layer:
@@ -2305,6 +2412,8 @@ def test_sliding_window(
attention_logit_biases,
)
inputs = dict(query=query, attention_logit_biases=attention_logit_biases)
if has_query_positions:
inputs["query_positions"] = query_positions
layer_outputs, _ = F(
layer,
state=layer_params,
5 changes: 4 additions & 1 deletion axlearn/common/dit.py
@@ -405,6 +405,7 @@ def forward(
shift: Optional[Tensor] = None,
scale: Optional[Tensor] = None,
gate: Optional[Tensor] = None,
query_positions: Optional[Tensor] = None,
attention_logit_biases: Optional[Tensor] = None,
) -> Tensor:
"""The forward function of DiTAttentionLayer.
@@ -440,7 +441,9 @@ def forward(
if shift is not None and scale is not None:
x = modulate(x=x, shift=shift, scale=scale)

x = self.attention(query=x, attention_logit_biases=attention_logit_biases).data
x = self.attention(
query=x, query_positions=query_positions, attention_logit_biases=attention_logit_biases
).data

if cfg.structure == "postnorm":
x = self.norm(x)