Merging all stages into one #4
Comments
1, 2/3. MDN loss is very memory-expensive to compute, so you might want to avoid calculating it for as long as possible. Otherwise, I think you would be fine. My copy seems to converge just fine, and I run the entire network at once.
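As a rough illustration of "avoid calculating it for as long as possible", one option is to gate the MDN loss behind a step threshold in the training loop. This is only a sketch under assumed names: `mdn_start_step` and the model's `compute_mdn` flag are hypothetical, not part of this repo.

```python
# Sketch: skip the memory-hungry MDN loss for the first part of training.
for step, batch in enumerate(train_loader):
    compute_mdn = step >= mdn_start_step                             # hypothetical hyperparameter
    decoder_loss, mdn_loss = model(batch, compute_mdn=compute_mdn)   # hypothetical flag
    loss = decoder_loss + (mdn_loss if compute_mdn else 0.0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```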
Makes sense, thanks. I guess the stages are there for efficiency purposes and not necessarily to ensure convergence.
Could you tell me what modifications had to be done?
(1) Line 19 in ed9c29d

(2) Lines 22 to 23 in ed9c29d

This computation involves 3 calculations that can be pruned. Here it is with the modifications; I also use:

```python
exponential = -0.5*torch.sum((x-mu).pow_(2)/log_sigma.exp().pow_(2), dim=-1)  # B, L, T
log_prob_matrix = exponential - 0.5*log_sigma.sum(dim=-1)  # -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))
```

(3) Example graph of what my MDN loss looks like when I calculate the mean.

To calculate the mean loss, just replace

```python
exponential = -0.5*torch.sum((x-mu).pow_(2)/log_sigma.exp().pow_(2), dim=-1)  # B, L, T
log_prob_matrix = exponential - 0.5*log_sigma.sum(dim=-1)  # -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))
```

with

```python
exponential = -0.5*torch.mean((x-mu).pow_(2)/log_sigma.exp().pow_(2), dim=-1)  # B, L, T
log_prob_matrix = exponential - 0.5*log_sigma.mean(dim=-1)  # -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))
```

and Line 32 in ed9c29d becomes

```python
mdn_loss = -alpha_last.sum()/mel_lengths.sum()
```
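For context on the shapes involved, here is a minimal, self-contained sketch of how the `# B, L, T` matrix above comes about. The tensor shapes (`mu`, `log_sigma`: batch × text length × mel channels; `mel`: batch × frames × mel channels) are my assumption based on the comments in the snippets, not taken from the repo.

```python
import torch

# Hypothetical sizes: B=batch, L=text tokens, T=mel frames, D=mel channels.
B, L, T, D = 2, 5, 7, 80
mu = torch.randn(B, L, D)          # per-token Gaussian means
log_sigma = torch.randn(B, L, D)   # per-token log standard deviations
mel = torch.randn(B, T, D)         # target mel-spectrogram frames

# Broadcast to (B, L, T, D) so every (token, frame) pair is scored.
x = mel.unsqueeze(1)               # (B, 1, T, D)
mu_ = mu.unsqueeze(2)              # (B, L, 1, D)
log_sigma_ = log_sigma.unsqueeze(2)

exponential = -0.5 * torch.mean((x - mu_).pow(2) / log_sigma_.exp().pow(2), dim=-1)  # (B, L, T)
log_prob_matrix = exponential - 0.5 * log_sigma_.mean(dim=-1)                        # (B, L, T)
print(log_prob_matrix.shape)       # torch.Size([2, 5, 7])
```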
Oh wow. I learned something today.

```python
exponential = -0.5*torch.mean(torch.nn.functional.mse_loss(x, mu, reduction='none')/log_sigma.exp().pow_(2), dim=-1)  # B, L, T
```

uses 25% less VRAM than

```python
exponential = -0.5*torch.mean((x-mu).pow_(2)/log_sigma.exp().pow_(2), dim=-1)  # B, L, T
```

But for some reason it gives out an annoying warning. The outputs and gradients are still 100% correct, so feel free to use that too. You will just have to suppress the warning or ignore it.
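If you want to silence it rather than ignore it, the standard-library `warnings` module works; I'm assuming the message is the usual `UserWarning` from `F.mse_loss` about input and target sizes being broadcast, so the message prefix below may need adjusting.

```python
import warnings

# Suppress the (assumed) broadcasting UserWarning from F.mse_loss(x, mu, reduction='none').
warnings.filterwarnings(
    "ignore",
    message="Using a target size",  # regex matched against the start of the warning text
    category=UserWarning,
)
```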
Great! Thanks for your advice. It makes good sense. I am thinking about reimplementing the alignment loss with numba to speed up the backward pass. Hopefully it will be possible to merge all the stages into one. 🤔
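As a rough sketch of that idea (my own illustration, not code from this repo), the forward (alpha) recursion of the alignment can be jitted with numba on a detached, per-utterance `log_prob_matrix`; the backward pass would then still need to be provided separately, e.g. through a custom `torch.autograd.Function`.

```python
import numpy as np
from numba import njit

@njit
def forward_alpha(log_prob_matrix, text_len, mel_len):
    # Monotonic forward recursion in log space for one utterance.
    # log_prob_matrix: (L, T) array of per-(token, frame) log-probabilities.
    # alpha[text_len - 1, mel_len - 1] is the log-likelihood used in the MDN loss.
    L, T = log_prob_matrix.shape
    neg_inf = -1e30
    alpha = np.full((L, T), neg_inf)
    alpha[0, 0] = log_prob_matrix[0, 0]
    for t in range(1, mel_len):
        for l in range(min(t + 1, text_len)):
            stay = alpha[l, t - 1]                              # stay on token l
            move = alpha[l - 1, t - 1] if l > 0 else neg_inf    # advance from token l-1
            m = max(stay, move)
            # log-sum-exp of the two incoming paths, then add the emission term
            alpha[l, t] = m + np.log(np.exp(stay - m) + np.exp(move - m)) + log_prob_matrix[l, t]
    return alpha
```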
Hi, thanks. Have you tried to train the model completely end to end (all stages merged into one)? If not, what is the rationale for not doing so? Is the model expected not to converge in that case?