Thanks @Birch-san for opening this issue. I’m an interested observer who is anxious to hear the response. This would mean a lot to those of us who want to pre-train T5 the “correct” way.
Hi t5x community,
I am trying to align the HF transformers implementation of T5 with MTF and T5X, such that it can be relied upon for pretraining.
Could you confirm whether what I've found regarding HF's pretraining weight initialization differs from the official T5X or MTF implementations?
huggingface/transformers#26441
HF initialize their (untied) lm_head with std=1: https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/t5/modeling_t5.py#L831
I suspect this is a mistake. I think they are approaching its initialization as though the lm_head were to be tied to the embedding layer (though in this code path, it is not).
Perhaps they could compensate for its increased variance by scaling the decoder output down before giving it to the lm_head:
https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/t5/modeling_t5.py#L1769
…but they are not applying this compensation in the untied case.
Consequently, I see that the lm_head initially outputs huge activations with variance ~= hidden_dim, resulting in an initial cross-entropy loss of ~110.
I see that nanoT5 copied the same convention, and empirically exhibits the same behaviour: huge initial loss. PiotrNawrot/nanoT5#25
I think this will have had consequences for the paper they wrote around pretraining on HF code.
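As a rough sanity check on the variance and loss figures above, here is a minimal pure-Python sketch (not HF or T5X code). It assumes the lm_head's input has roughly unit variance per element; the helper name `initial_lm_head_stats` and the sizes (hidden_dim=512, a reduced vocab of 512, 8 tokens) are made up for illustration. Because the vocab is reduced, it will not reproduce ~110 exactly; it should only show the gap in order of magnitude between the two initializations.

```python
import math
import random


def initial_lm_head_stats(hidden_dim, vocab_size, std, n_tokens=8, seed=0):
    """Return (logit variance, mean cross-entropy) for a freshly initialized lm_head."""
    rng = random.Random(seed)
    # Untied lm_head weight: one row per vocab entry, entries drawn from N(0, std^2).
    weight = [[rng.gauss(0.0, std) for _ in range(hidden_dim)]
              for _ in range(vocab_size)]
    all_logits, losses = [], []
    for _ in range(n_tokens):
        # Assumption: decoder output is ~unit variance per element.
        hidden = [rng.gauss(0.0, 1.0) for _ in range(hidden_dim)]
        logits = [sum(w * h for w, h in zip(row, hidden)) for row in weight]
        target = rng.randrange(vocab_size)  # random target token
        # Numerically stable log-sum-exp for the softmax normalizer.
        peak = max(logits)
        lse = peak + math.log(sum(math.exp(z - peak) for z in logits))
        losses.append(lse - logits[target])
        all_logits.extend(logits)
    mean = sum(all_logits) / len(all_logits)
    var = sum((z - mean) ** 2 for z in all_logits) / len(all_logits)
    return var, sum(losses) / len(losses)


if __name__ == "__main__":
    d = 512
    for label, std in (("std=1", 1.0), ("std=hidden_dim**-.5", d ** -0.5)):
        var, loss = initial_lm_head_stats(d, vocab_size=512, std=std)
        print(f"{label}: logit variance={var:.2f}, mean loss={loss:.2f}")
```

Under these assumptions, std=1 should give a logit variance near hidden_dim and a loss far above log(vocab_size), while std=hidden_dim**-.5 should give a variance near 1 and a loss close to log(vocab_size).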
As for the fix…
I think I found evidence in t5x that the lm_head initialization needs to be changed to std=hidden_dim**-.5: huggingface/transformers#26441 (comment)
Though there is a competing theory (based on MTF) that lm_head ought to be initialized to std=0.05: huggingface/transformers#26441 (comment)
I do note though that 0.05 is of a very similar magnitude to hidden_dim**-.5, so perhaps it could work similarly well: 512**-.5 ≈ 0.044 and 768**-.5 ≈ 0.036.
Does this sound about right? Should an untied lm_head be initialized with std=hidden_dim**-.5 normally-distributed noise?
Bonus question: are any of HF's other layers initialized incorrectly?
https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/t5/modeling_t5.py#L818
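To put numbers on the std=0.05 versus std=hidden_dim**-.5 comparison, here is a small arithmetic sketch. It assumes independent, zero-mean, unit-variance inputs to the lm_head, in which case each logit has variance hidden_dim * std**2. The helper name `logit_variance` is hypothetical and not taken from HF, T5X or MTF.

```python
def logit_variance(hidden_dim, std, input_var=1.0):
    # Variance of sum_i h_i * w_i for independent zero-mean terms:
    # hidden_dim * Var[h] * Var[w], with Var[w] = std**2.
    return hidden_dim * input_var * std ** 2


if __name__ == "__main__":
    for hidden_dim in (512, 768):
        inv_sqrt_std = hidden_dim ** -0.5
        print(
            f"hidden_dim={hidden_dim}: "
            f"hidden_dim**-.5={inv_sqrt_std:.3f}, "
            f"variance at std=1: {logit_variance(hidden_dim, 1.0):.2f}, "
            f"at std=hidden_dim**-.5: {logit_variance(hidden_dim, inv_sqrt_std):.2f}, "
            f"at std=0.05: {logit_variance(hidden_dim, 0.05):.2f}"
        )
```

Under that assumption, both candidate fixes keep the initial logit variance within a small factor of 1, whereas std=1 gives a variance equal to hidden_dim.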
Thanks!