
add transformer class for review #11491

Draft: wants to merge 23 commits into main
Conversation

paarthneekhara (Collaborator):

Added the transformer stack currently used in T5TTS. Goals: identify unused code paths, clean up the code, and see which modules can be reused.

Signed-off-by: Paarth Neekhara <[email protected]>
github-actions bot added the TTS label on Dec 5, 2024
nemo/collections/tts/modules/transformer_dec24.py: 8 outdated review threads (resolved)
blisc requested review from XuesongYang and rlangman on December 9, 2024
nemo/collections/tts/modules/transformer_dec24.py: 2 outdated review threads (resolved)

self.d_model = d_model
self.non_linearity = nn.GELU(approximate="tanh")
self.proj = ConvNorm(d_model, d_model * 4, bias=bias, kernel_size=kernel_size, is_causal=is_causal)
Collaborator:

Should the FFN size be configurable instead of hardcoded to 4 * d_model?
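
For illustration, a minimal sketch of what a configurable FFN width could look like, where d_ffn is a hypothetical argument name and ConvNorm plus the other arguments come from the snippet above:

from torch import nn

def __init__(self, d_model, d_ffn=None, kernel_size=1, bias=False, is_causal=False):
    super().__init__()
    self.d_model = d_model
    # Keep the conventional 4x expansion as the default, but let callers override it.
    d_ffn = d_ffn if d_ffn is not None else 4 * d_model
    self.non_linearity = nn.GELU(approximate="tanh")
    self.proj = ConvNorm(d_model, d_ffn, bias=bias, kernel_size=kernel_size, is_causal=is_causal)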

Collaborator:

Paarth changed this

nemo/collections/tts/modules/transformer_dec24.py: 3 outdated review threads (resolved)
Comment on lines 268 to 277
q = self.q_net(query).reshape(Bq, Tq, self.n_heads, self.d_head)
kv = self.kv_net(memory).reshape(Bkv, Tkv, 2, self.n_heads, self.d_head)
if self.pos_emb_name == 'rope':
    q, kv = self.rope(q, kv)
elif self.pos_emb_name == 'alibi':
    alibi_slopes = self.m[:, 0, 0]
q = q[~query_mask].reshape(-1, self.n_heads, self.d_head)
kv = kv[~memory_mask].reshape(-1, 2, self.n_heads, self.d_head)
lengths_q = (~query_mask).sum(1)
lengths_k = (~memory_mask).sum(1)
Collaborator:

Code like this would be a lot easier to read if we replaced .reshape() with einops rearrange(), and added comments giving the output shapes for operations that are not reshape/rearrange.
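
For illustration, a hedged sketch of that suggestion applied to the snippet above, assuming query_mask and memory_mask are True at padded positions and that q_net/kv_net project to n_heads * d_head and 2 * n_heads * d_head features respectively:

from einops import rearrange

# (B, Tq, n_heads * d_head) -> (B, Tq, n_heads, d_head)
q = rearrange(self.q_net(query), 'b t (h d) -> b t h d', h=self.n_heads)
# (B, Tkv, 2 * n_heads * d_head) -> (B, Tkv, 2, n_heads, d_head)
kv = rearrange(self.kv_net(memory), 'b t (two h d) -> b t two h d', two=2, h=self.n_heads)

# Boolean indexing drops padded positions for the varlen attention path;
# these are not rearranges, so the output shapes are noted explicitly.
q = q[~query_mask]                  # (total_q, n_heads, d_head)
kv = kv[~memory_mask]               # (total_kv, 2, n_heads, d_head)
lengths_q = (~query_mask).sum(1)    # (B,)
lengths_k = (~memory_mask).sum(1)   # (B,)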

nemo/collections/tts/modules/transformer_dec24.py: 1 outdated review thread (resolved)
if self.has_xattn:
    self.cross_attention.reset_cache(use_cache)

def forward(self, x, x_mask, cond, cond_mask, dump_attention=False, attn_prior=None, idx=None):
Collaborator:

If cond and cond_mask are optional, we should default them to None.

Should we throw an error if cond is provided, but self.has_xattn is False? Or if cond is not provided, but self.has_xattn is True?

Collaborator (Author):

We can default them to None, but I wouldn't raise an error if has_xattn is True and cond is None. I use that setup to pretrain the decoder with the context set to None while keeping the same architecture and parameters, so it can later be used as the pretrained T5 decoder for TTS.
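
A minimal sketch combining both points, assuming the signature from the snippet above: cond and cond_mask default to None, providing cond without cross-attention raises, and has_xattn=True with cond=None stays allowed for decoder pretraining:

def forward(self, x, x_mask, cond=None, cond_mask=None, dump_attention=False, attn_prior=None, idx=None):
    if cond is not None and not self.has_xattn:
        raise ValueError("cond was provided, but this layer was built without cross-attention (has_xattn=False)")
    # has_xattn=True with cond=None is intentionally allowed (decoder pretraining without context).
    ...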

nemo/collections/tts/modules/transformer_dec24.py: 1 outdated review thread (resolved)
p_dropout=p_dropout,
is_causal=False,
is_self_attention=False,
d_memory=params['d_heads'],
Collaborator:

We should rename d_heads in params to d_memory here. d_memory is supposed to be the dimension of the context information for cross-attention, while d_heads refers to the size of each attention head, which this code hardcodes to d_memory // n_heads.
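
As a purely illustrative sketch of the rename (the params keys shown here are hypothetical): the caller passes the context dimension under a d_memory key, and the per-head size is derived inside the module rather than supplied by the caller.

d_memory = params['d_memory']           # dimension of the cross-attention context
d_head = d_memory // params['n_heads']  # per-head size, derived rather than passed in as d_heads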

paarthneekhara and others added 2 commits December 13, 2024 19:16
…d suggest careful review of changes

Signed-off-by: Paarth Neekhara <[email protected]>
Signed-off-by: Jason <[email protected]>
blisc self-requested a review on December 17, 2024
use_flash_self_attention=True,
use_flash_x_attention=True,
deterministic=False,
pos_emb={"name": "learnable"},
Collaborator:

Can we make the pos_emb argument more structured? Either a dataclass, or, similar to xattn, flatten it into parameters like pos_emb_name, pos_emb_base, pos_emb_kwargs, etc.

Nitpick: Mutable objects like dictionaries should not be used as default arguments.

Collaborator:

Vote for a @dataclass to group the configs.
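
A hedged sketch of the grouped config (field names are illustrative, not taken from the PR):

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class PositionalEmbeddingConfig:
    name: str = "learnable"                                # e.g. "learnable", "rope", "alibi"
    base: Optional[int] = None                             # e.g. RoPE base frequency, when applicable
    kwargs: Dict[str, Any] = field(default_factory=dict)   # avoids a mutable default argument

Using field(default_factory=PositionalEmbeddingConfig) for the constructor default would also resolve the mutable-default nitpick above.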

Comment on lines 376 to 380
has_xattn,
xa_d_memory=None,
xa_n_heads=None,
xa_pos_emb=None,
xa_max_length_causal_mask=None,
Collaborator:

Should we group these into a CrossAttentionConfig dataclass, to make it clear which arguments are related/optional? Then we could check whether the config is None rather than relying on the has_xattn flag.
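
A hedged sketch of that grouping (field names follow the xa_* arguments above):

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class CrossAttentionConfig:
    d_memory: int
    n_heads: int
    pos_emb: Optional[Dict[str, Any]] = None
    max_length_causal_mask: Optional[int] = None

The layer could then accept xattn_config: Optional[CrossAttentionConfig] = None (hypothetical argument name) and build cross-attention only when the config is present, replacing the has_xattn flag.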

paarthneekhara and others added 2 commits December 19, 2024 02:43
…n pos emb and x-attn causal mask args

Signed-off-by: Paarth Neekhara <[email protected]>
@XuesongYang (Collaborator) left a comment:

I went over the full implementation and left notes about bug fixes and recommendations on refactoring for readability.

nemo/collections/tts/modules/transformer_2412.py: 9 outdated review threads (resolved)
Co-authored-by: Xuesong Yang <[email protected]>
Signed-off-by: Paarth Neekhara <[email protected]>
github-actions bot (Contributor) commented on Jan 3, 2025:

beep boop 🤖: 🚨 The following files must be fixed before merge!


Your code was analyzed with PyLint. The following annotations have been identified:

************* Module nemo.collections.tts.modules.transformer_2412
nemo/collections/tts/modules/transformer_2412.py:22:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:71:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:81:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:125:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:156:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:176:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:180:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:264:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:317:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:366:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:441:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:532:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:623:4: C0116: Missing function or method docstring (missing-function-docstring)

-----------------------------------
Your code has been rated at 9.51/10

Mitigation guide:

  • Add sensible and useful docstrings to functions and methods
  • For trivial methods like getter/setters, consider adding # pylint: disable=C0116 inside the function itself
  • To disable multiple functions/methods at once, put a # pylint: disable=C0116 before the first and a # pylint: enable=C0116 after the last.

By applying these rules, we reduce the occurrence of this message in the future.
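
For example, the block form described above would look roughly like this (the accessor names are purely illustrative):

# pylint: disable=C0116
def get_d_model(self):
    return self.d_model

def set_d_model(self, value):
    self.d_model = value
# pylint: enable=C0116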

Thank you for improving NeMo's documentation!
