add transformer class for review #11491
base: main
Conversation
self.d_model = d_model
self.non_linearity = nn.GELU(approximate="tanh")
self.proj = ConvNorm(d_model, d_model * 4, bias=bias, kernel_size=kernel_size, is_causal=is_causal)
Should the FFN size be configurable instead of hardcoded to 4 * d_model?
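For illustration, a minimal sketch of what that could look like (the `d_ffn` argument is a hypothetical addition; `nn` and `ConvNorm` are taken from the excerpt above):

```python
# Sketch only: d_ffn is a hypothetical new argument that falls back to the
# current behaviour (4 * d_model) when not provided.
def __init__(self, d_model, kernel_size, bias=True, is_causal=False, d_ffn=None):
    super().__init__()
    self.d_model = d_model
    self.d_ffn = d_ffn if d_ffn is not None else 4 * d_model
    self.non_linearity = nn.GELU(approximate="tanh")
    self.proj = ConvNorm(d_model, self.d_ffn, bias=bias, kernel_size=kernel_size, is_causal=is_causal)
```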
Paarth changed this
q = self.q_net(query).reshape(Bq, Tq, self.n_heads, self.d_head)
kv = self.kv_net(memory).reshape(Bkv, Tkv, 2, self.n_heads, self.d_head)
if self.pos_emb_name == 'rope':
    q, kv = self.rope(q, kv)
elif self.pos_emb_name == 'alibi':
    alibi_slopes = self.m[:, 0, 0]
q = q[~query_mask].reshape(-1, self.n_heads, self.d_head)
kv = kv[~memory_mask].reshape(-1, 2, self.n_heads, self.d_head)
lengths_q = (~query_mask).sum(1)
lengths_k = (~memory_mask).sum(1)
Code like this will be a lot easier to read if we replace .reshape() with einops rearrange(), and add comments with the output shapes for operations that are not reshape/rearrange.
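For example, a rough einops version of the block above (a sketch, not a drop-in replacement; the shape comments are inferred from the excerpt):

```python
from einops import rearrange

# (B, Tq, n_heads * d_head) -> (B, Tq, n_heads, d_head)
q = rearrange(self.q_net(query), 'b t (h d) -> b t h d', h=self.n_heads)
# (B, Tkv, 2 * n_heads * d_head) -> (B, Tkv, 2, n_heads, d_head)
kv = rearrange(self.kv_net(memory), 'b t (two h d) -> b t two h d', two=2, h=self.n_heads)

if self.pos_emb_name == 'rope':
    q, kv = self.rope(q, kv)
elif self.pos_emb_name == 'alibi':
    alibi_slopes = self.m[:, 0, 0]  # per-head alibi slopes

# Boolean indexing with a (B, T) mask already flattens the token dimension,
# so the trailing .reshape(-1, ...) in the original should be redundant:
# q: (total_q_tokens, n_heads, d_head); kv: (total_kv_tokens, 2, n_heads, d_head)
q = q[~query_mask]
kv = kv[~memory_mask]
lengths_q = (~query_mask).sum(1)  # (B,) valid query lengths per sample
lengths_k = (~memory_mask).sum(1)  # (B,) valid memory lengths per sample
```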
if self.has_xattn:
    self.cross_attention.reset_cache(use_cache)

def forward(self, x, x_mask, cond, cond_mask, dump_attention=False, attn_prior=None, idx=None):
If cond and cond_mask are optional we should default them to None. Should we throw an error if cond is provided but self.has_xattn is False? Or if cond is not provided but self.has_xattn is True?
We can default them to None, but I wouldn't raise an error if has_xattn is True and cond is None. I use that feature to pretrain the decoder with context as None, while still having the same architecture and parameters when using it as the pretrained T5 decoder for TTS.
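A minimal sketch of the resulting signature, reflecting the behaviour described above (cond may be None even when has_xattn is True); raising in the opposite case is an assumption, not something settled in this thread:

```python
# Sketch only: cond/cond_mask default to None. The error below covers the case
# where conditioning is passed to a layer built without cross-attention; the
# reverse case (has_xattn=True, cond=None) is intentionally allowed.
def forward(self, x, x_mask, cond=None, cond_mask=None, dump_attention=False, attn_prior=None, idx=None):
    if cond is not None and not self.has_xattn:
        raise ValueError("`cond` was provided but this layer was built with has_xattn=False")
    ...
```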
p_dropout=p_dropout,
is_causal=False,
is_self_attention=False,
d_memory=params['d_heads'],
Should rename d_heads in params to d_memory here. d_memory is supposed to be the dimension of the context information for cross-attention; d_heads refers to the size of each attention head, which this code hardcodes to d_memory // n_heads.
use_flash_self_attention=True,
use_flash_x_attention=True,
deterministic=False,
pos_emb={"name": "learnable"},
Can we make the pos_emb argument more structured? Either a dataclass, or, similar to xattn, flatten it into parameters like pos_emb_name, pos_emb_base, pos_emb_kwargs, etc.
Nitpick: mutable objects like dictionaries should not be used as default arguments.
Vote for @dataclass to group the configs.
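A possible shape for this, as a sketch (only the "learnable" name comes from the excerpt; the other fields are assumptions):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class PositionalEmbeddingConfig:
    """Sketch of a structured replacement for the pos_emb dict default."""

    name: str = "learnable"
    base: Optional[int] = None  # e.g. a rotary base when name == "rope" (assumption)
    kwargs: Dict[str, Any] = field(default_factory=dict)  # catch-all for embedding-specific options
```

Using `field(default_factory=dict)` also sidesteps the mutable-default issue noted above; in the layer's `__init__`, the argument could default to None and fall back to `PositionalEmbeddingConfig()` internally.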
has_xattn,
xa_d_memory=None,
xa_n_heads=None,
xa_pos_emb=None,
xa_max_length_causal_mask=None,
Should we group these into a CrossAttentionConfig dataclass, to make it clear which arguments are related/optional? Then we can check whether the config is None rather than the has_xattn flag.
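A rough sketch of such a config (field names follow the xa_* arguments above; everything else is an assumption):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CrossAttentionConfig:
    """Sketch: groups the xa_* arguments; cross-attention is enabled iff a config is passed."""

    d_memory: int
    n_heads: int
    pos_emb: Optional[dict] = None  # or the structured pos_emb config discussed above
    max_length_causal_mask: Optional[int] = None
```

has_xattn could then be derived as `xattn_config is not None` instead of being a separate flag.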
Went over the full implementation and left notes about bug fixes and recommendations on refactoring for readability.
beep boop 🤖: 🚨 The following files must be fixed before merge! Your code was analyzed with PyLint. The following annotations have been identified:
Mitigation guide:
By applying these rules, we reduce the occurrence of this message in future. Thank you for improving NeMo's documentation!
Added the transformer stack currently being used in T5TTS. To do: identify unused code paths, clean up the code, and see which modules can be reused.