Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add transformer class for review #11491

Closed
wants to merge 78 commits into from

Conversation

paarthneekhara
Copy link
Collaborator

Added the transformer stack currently being used in T5TTS - Identify unused code paths, clean up the code and see what modules can be reused.

Signed-off-by: Paarth Neekhara <[email protected]>
@github-actions github-actions bot added the TTS label Dec 5, 2024
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
@blisc blisc requested review from XuesongYang and rlangman December 9, 2024 18:33
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved

self.d_model = d_model
self.non_linearity = nn.GELU(approximate="tanh")
self.proj = ConvNorm(d_model, d_model * 4, bias=bias, kernel_size=kernel_size, is_causal=is_causal)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the FFN size be a configuration instead of hardcoded to 4 * d_model?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Paarth changed this

nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
Comment on lines 268 to 277
q = self.q_net(query).reshape(Bq, Tq, self.n_heads, self.d_head)
kv = self.kv_net(memory).reshape(Bkv, Tkv, 2, self.n_heads, self.d_head)
if self.pos_emb_name == 'rope':
q, kv = self.rope(q, kv)
elif self.pos_emb_name == 'alibi':
alibi_slopes = self.m[:, 0, 0]
q = q[~query_mask].reshape(-1, self.n_heads, self.d_head)
kv = kv[~memory_mask].reshape(-1, 2, self.n_heads, self.d_head)
lengths_q = (~query_mask).sum(1)
lengths_k = (~memory_mask).sum(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code like this will be a lot easier to read if we replace .reshape () with einops rearrange(), and add comments with the output shapes for operations that are not reshape/rearrange.

nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
if self.has_xattn:
self.cross_attention.reset_cache(use_cache)

def forward(self, x, x_mask, cond, cond_mask, dump_attention=False, attn_prior=None, idx=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If cond and cond_mask are optional we should default them to None.

Should we throw an error if cond is provided, but self.has_xattn is False? Or if cond is not provided, but self.has_xattn is True?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can default them to None, but I wouldnt raise an error if has_xattn is True and cond is None. I use that feature to pretrain the decoder with context as None, but still having the same architecture and parameters when using it as the pretrained T5 decoder for TTS,

nemo/collections/tts/modules/transformer_dec24.py Outdated Show resolved Hide resolved
p_dropout=p_dropout,
is_causal=False,
is_self_attention=False,
d_memory=params['d_heads'],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should rename d_heads in params to d_memory here. d_memory is supposed to be the dimension of the context information for cross attention. d_heads refers to the size of each attention head, but which this code hardcodes to be d_memory // n_heads.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer use a params dict so this should no longer happen.

paarthneekhara and others added 2 commits December 13, 2024 19:16
…d suggest careful review of changes

Signed-off-by: Paarth Neekhara <[email protected]>
Signed-off-by: Jason <[email protected]>
@blisc blisc self-requested a review December 17, 2024 19:04
use_flash_self_attention=True,
use_flash_x_attention=True,
deterministic=False,
pos_emb={"name": "learnable"},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make the pos_emb argument more structured? Either a dataclass, or similar to xattn flatten into parameters like pos_emb_name, pos_emb_base, pos_emb_kwargs, etc.

Nitpick: Mutable objects like dictionaries should not be used as default arguments.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vote for @dataclass to group the configs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We got rid of this parameter, and changed it to a bool

Comment on lines 376 to 380
has_xattn,
xa_d_memory=None,
xa_n_heads=None,
xa_pos_emb=None,
xa_max_length_causal_mask=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we group these into a CrossAttentionConfig dataclass? To make it clear which arguments are related/optional. Then we can check if the config is None rather than the has_xattn flag.

paarthneekhara and others added 2 commits December 19, 2024 02:43
…n pos emb and x-attn causal mask args

Signed-off-by: Paarth Neekhara <[email protected]>
XuesongYang and others added 6 commits January 18, 2025 00:53
Signed-off-by: Jason <[email protected]>
It requires that `xa_d_memory` and `xa_n_heads` are specified when `has_xattn` is True

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Xuesong Yang <[email protected]>
Copy link

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@github-actions github-actions bot removed the ASR label Jan 18, 2025
@XuesongYang XuesongYang removed the audio label Jan 18, 2025
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
tests/collections/tts/modules/test_tts_new_transformer.py Dismissed Show dismissed Hide dismissed
@XuesongYang
Copy link
Collaborator

XuesongYang commented Jan 18, 2025

added a unit test. This is the necessary test to ensure the forward pass of Transformer class succeed.

The multiple conditions from difference encoders failed the tests (test_forward_causal_self_attn_and_has_xattn). It seems a list of tensors are not supported. @paarthneekhara could you pls verify?

pls run pytest -s -vvv tests/collections/tts/modules/test_tts_new_transformer.py locally to test the code.

@paarthneekhara
Copy link
Collaborator Author

@XuesongYang We need to pass multi_encoder_mapping to the forward function as well for multi-encoder case. I have updated the test case with some comments (that still need to be incorporated) and fixes. Also corrected the x = (x + x_) * x_mask.unsqueeze(-1) bug which I believe was inserted when the mask was flipped.

FYI, I tested this new transformer code training and inference with t5tts locally, and it seems to be working fine. Also for a fixed set of weights, the transformer implementation in experimentalt5tts and this branch, give the same output, so I think we should be good.

blisc added 2 commits January 21, 2025 08:41
Signed-off-by: Jason <[email protected]>
Signed-off-by: Jason <[email protected]>
Copy link
Contributor

beep boop 🤖: 🚨 The following files must be fixed before merge!


Your code was analyzed with PyLint. The following annotations have been identified:

************* Module nemo.collections.tts.modules.transformer_2412
nemo/collections/tts/modules/transformer_2412.py:26:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:85:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:94:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:138:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:171:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:191:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:195:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:280:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:340:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:398:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:471:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:536:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:627:4: C0116: Missing function or method docstring (missing-function-docstring)

-----------------------------------
Your code has been rated at 9.47/10

Mitigation guide:

  • Add sensible and useful docstrings to functions and methods
  • For trivial methods like getter/setters, consider adding # pylint: disable=C0116 inside the function itself
  • To disable multiple functions/methods at once, put a # pylint: disable=C0116 before the first and a # pylint: enable=C0116 after the last.

By applying these rules, we reduce the occurance of this message in future.

Thank you for improving NeMo's documentation!

@blisc
Copy link
Collaborator

blisc commented Jan 21, 2025

Closed in favour of #11911

@blisc blisc closed this Jan 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants