Commit 27fd3a1 (1 parent: 213f386). Showing 13 changed files with 711 additions and 2 deletions.
Binary file not shown.

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.t5._component_builders import t5_encoder
from torchtune.training.seed import set_seed

VOCAB_SIZE = 512
MAX_SEQ_LEN = 8
BSZ = 2
EMBED_DIM = 2


@pytest.fixture(autouse=True)
def random():
    set_seed(0)


class TestT5Encoder:
    @pytest.fixture
    def model(self):
        model = t5_encoder(
            embed_dim=EMBED_DIM,
            mlp_dim=4,
            num_heads=2,
            head_dim=EMBED_DIM // 2,
            num_layers=2,
            rel_pos_num_buckets=4,
            rel_pos_max_dist=4,
            vocab_size=VOCAB_SIZE,
            norm_eps=1e-6,
            max_seq_len=MAX_SEQ_LEN,
        )

        for param in model.parameters():
            param.data.uniform_(0, 1)

        return model

    @pytest.fixture
    def inputs(self):
        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

    def test_forward(self, model, inputs):
        actual = model(inputs)
        expected = torch.tensor(
            [
                [
                    [0.3670, 0.2938],
                    [0.3692, 0.2921],
                    [0.3611, 0.2984],
                    [0.4207, 0.2437],
                    [0.3447, 0.3106],
                    [0.3383, 0.3150],
                    [0.3727, 0.2892],
                    [0.3996, 0.2653],
                ],
                [
                    [0.3855, 0.2783],
                    [0.2627, 0.3581],
                    [0.3601, 0.2992],
                    [0.3473, 0.3087],
                    [0.3549, 0.3032],
                    [0.2871, 0.3459],
                    [0.2753, 0.3520],
                    [0.2285, 0.3728],
                ],
            ]
        )
        assert actual.shape == (BSZ, MAX_SEQ_LEN, EMBED_DIM)
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_backward(self, model, inputs):
        y = model(inputs)
        loss = y.mean()
        loss.backward()

@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.t5._model_builders import t5_tokenizer


class TestT5Tokenizer:
    @pytest.fixture
    def tokenizer(self):
        return t5_tokenizer(str(ASSETS / "sentencepiece.model"))

    def test_encoding(self, tokenizer):
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [3, 9, 9321, 15539, 147, 8, 8114, 1],
            [3, 9, 2690, 7833, 6165, 1],
        ]
        for text, correct in zip(texts, correct_tokens):
            tokens = tokenizer.encode(text)
            print(tokens)
            assert tokens == correct

    def test_decoding(self, tokenizer):
        text = "this is torchtune"
        assert text == tokenizer.decode(tokenizer.encode(text))

    def test_call(self, tokenizer):
        sample = {"text": "hello world"}
        sample = tokenizer(sample)
        assert "text" not in sample
        assert "tokens" in sample

@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import t5_encoder
from ._model_builders import t5_tokenizer, t5_v1_1_xxl_encoder

__all__ = [
    "t5_encoder",
    "t5_tokenizer",
    "t5_v1_1_xxl_encoder",
]

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn

from torchtune.models.t5._encoder import (
    T5Encoder,
    T5EncoderLayer,
    T5EncoderSelfAttention,
)
from torchtune.modules.feed_forward import FeedForward
from torchtune.modules.rms_norm import RMSNorm


def t5_encoder(
    embed_dim: int,
    mlp_dim: int,
    num_heads: int,
    head_dim: int,
    num_layers: int,
    rel_pos_num_buckets: int,
    rel_pos_max_dist: int,
    vocab_size: int,
    norm_eps: float,
    max_seq_len: int,
):
    """
    Builder for the T5 encoder.

    T5 paper: https://arxiv.org/abs/1910.10683

    Args:
        embed_dim (int): The model dimension.
        mlp_dim (int): The inner dimension of the feed forward layers.
        num_heads (int): The number of attention heads.
        head_dim (int): The dimension of the attention heads (should equal `embed_dim // num_heads`)
        num_layers (int): Number of encoder layers.
        rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        rel_pos_max_dist (int): Maximum distance for relative positions.
            Distances beyond this are grouped into the last bucket.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        vocab_size (int): Vocab size of the model's tokenizer.
        norm_eps (float): Small value added to denominator for numerical stability.
        max_seq_len (int): The maximum sequence length (context length) of the model.

    Returns:
        T5Encoder
    """
    token_embedding = nn.Embedding(vocab_size, embed_dim)

    attn = T5EncoderSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        k_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        v_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    )

    mlp = FeedForward(
        gate_proj=nn.Linear(embed_dim, mlp_dim, bias=False),
        down_proj=nn.Linear(mlp_dim, embed_dim, bias=False),
        up_proj=nn.Linear(embed_dim, mlp_dim, bias=False),
        activation=nn.GELU(),
    )

    layer = T5EncoderLayer(
        attn=attn,
        mlp=mlp,
        sa_norm=RMSNorm(embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(embed_dim, eps=norm_eps),
    )

    final_norm = RMSNorm(embed_dim, eps=norm_eps)

    return T5Encoder(
        token_embedding=token_embedding,
        layer=layer,
        final_norm=final_norm,
        num_layers=num_layers,
        num_heads=num_heads,
        rel_pos_num_buckets=rel_pos_num_buckets,
        rel_pos_max_dist=rel_pos_max_dist,
        max_seq_len=max_seq_len,
    )
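
For orientation, here is a minimal sketch of how this builder can be exercised end to end. It mirrors the hyperparameters used in the encoder test fixture above; the tiny dimensions are illustrative only, and the shape check simply restates what test_forward asserts.

import torch

from torchtune.models.t5._component_builders import t5_encoder

# Build a tiny T5 encoder (illustrative hyperparameters, matching the test fixture).
encoder = t5_encoder(
    embed_dim=2,
    mlp_dim=4,
    num_heads=2,
    head_dim=1,  # embed_dim // num_heads
    num_layers=2,
    rel_pos_num_buckets=4,
    rel_pos_max_dist=4,
    vocab_size=512,
    norm_eps=1e-6,
    max_seq_len=8,
)

# Token IDs in, per-token embeddings out: (batch, seq_len, embed_dim).
tokens = torch.randint(0, 512, (2, 8))
out = encoder(tokens)
assert out.shape == (2, 8, 2)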

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
    # emb
    "encoder.embed_tokens.weight": "token_embedding.weight",
    "encoder.block.{}.layer._0.SelfAttention.relative_attention_bias.weight": "relative_position_bias.embedding.weight",
    # attn
    "encoder.block.{}.layer._0.SelfAttention.q.weight": "layers.{}.attn.q_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.k.weight": "layers.{}.attn.k_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.v.weight": "layers.{}.attn.v_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.o.weight": "layers.{}.attn.output_proj.weight",
    # ff
    "encoder.block.{}.layer._1.DenseReluDense.wi_0.weight": "layers.{}.mlp.w1.weight",
    "encoder.block.{}.layer._1.DenseReluDense.wo.weight": "layers.{}.mlp.w2.weight",
    "encoder.block.{}.layer._1.DenseReluDense.wi_1.weight": "layers.{}.mlp.w3.weight",
    # norm
    "encoder.block.{}.layer._0.layer_norm.weight": "layers.{}.sa_norm.scale",
    "encoder.block.{}.layer._1.layer_norm.weight": "layers.{}.mlp_norm.scale",
    "encoder.final_layer_norm.weight": "final_norm.scale",
}

_IGNORE = {
    "shared.weight",
    "lm_head.weight",
}


def t5_encoder_hf_to_tune(state_dict):
    converted_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("decoder.") or key in _IGNORE:
            continue

        # NOTE: HF's T5 has ".<integer>." parts that we do NOT want to be dynamically mapped
        # to corresponding ".<integer>." parts in our converted state dict.
        # This breaks the `get_mapped_key` implementation, so as a temporary hack,
        # we add leading underscores to these parts here and in the `_FROM_HF` map above.
        key = key.replace("layer.0.", "layer._0.").replace("layer.1.", "layer._1.")

        new_key = get_mapped_key(key, _FROM_HF)
        converted_state_dict[new_key] = value
    return converted_state_dict
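
As a usage sketch (not part of the commit), the converter can be applied to a hand-built HF-style state dict fragment. The module path torchtune.models.t5._convert_weights and the expected output keys are assumptions, inferred from the mapping table above and from get_mapped_key substituting the block index captured by "{}".

import torch

# Assumed module path for the converter shown above.
from torchtune.models.t5._convert_weights import t5_encoder_hf_to_tune

# Toy HF-style checkpoint fragment (shapes are illustrative only).
hf_state_dict = {
    "encoder.embed_tokens.weight": torch.zeros(512, 2),
    "encoder.block.0.layer.0.SelfAttention.q.weight": torch.zeros(2, 2),
    "encoder.block.0.layer.1.layer_norm.weight": torch.zeros(2),
    "decoder.block.0.layer.0.SelfAttention.q.weight": torch.zeros(2, 2),  # dropped: decoder key
    "lm_head.weight": torch.zeros(512, 2),  # dropped: listed in _IGNORE
}

tune_state_dict = t5_encoder_hf_to_tune(hf_state_dict)

# Expected keys (assuming get_mapped_key fills the block index back into "{}"):
#   token_embedding.weight
#   layers.0.attn.q_proj.weight
#   layers.0.mlp_norm.scale
print(sorted(tune_state_dict))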