
Merging all stages into one #4

Open · janvainer opened this issue Jan 22, 2021 · 6 comments

@janvainer

Hi, thanks! Have you tried to train the model completely end to end, with all stages merged into one? If not, what is the rationale for not doing so? Is the model expected not to converge in that case?

@CookiePPP commented Jan 22, 2021

1. I have not used this repo in its original state, but I can confirm that all the modules appear to work. The MDN loss needs to be modified slightly but is otherwise perfectly usable.

2/3.
https://github.com/Deepest-Project/AlignTTS/blob/master/modules/loss.py#L22

The MDN loss is very memory-expensive to compute, so you might want to avoid calculating it for as long as possible.

Otherwise, I think you would be fine. My copy seems to converge just fine and I run the entire network at once.

@janvainer (Author)

Makes sense, thanks. I guess the stages are there for efficiency purposes and not necessarily to ensure convergence.

@janvainer (Author)

Could you tell me what modifications had to be made?

@CookiePPP commented Jan 23, 2021

(1)

mu = torch.sigmoid(mu_sigma[:, :, :hp.n_mel_channels].unsqueeze(2)) # B, L, 1, F

mu should be able to take on any value that a spectrogram element can, roughly -11.52 to 4.5 if unnormalized. I don't know why, but torch.sigmoid is applied to this tensor in the stock repo. I recommend removing it or multiplying the outputs by 11.52.
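For example, a minimal sketch of the change (only the mu line comes from the repo; the log_sigma slice is my assumption about how the other half of mu_sigma is used):

mu = mu_sigma[:, :, :hp.n_mel_channels].unsqueeze(2)         # B, L, 1, F; no sigmoid, mu is unconstrained
log_sigma = mu_sigma[:, :, hp.n_mel_channels:].unsqueeze(2)  # B, L, 1, F; assumed layout of mu_sigma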

(2)

AlignTTS/modules/loss.py, lines 22 to 23 at ed9c29d:

exponential = -0.5*torch.sum((x-mu)*(x-mu)/log_sigma.exp()**2, dim=-1) # B, L, T
log_prob_matrix = exponential - (hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi)) - 0.5 * log_sigma.sum(dim=-1)

This computation involves three calculations that can be pruned: the -0.5 scaling only needs to be applied once, x - mu only needs to be computed once, and the -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi)) term is needed to be technically accurate, but it doesn't change the gradient flow (I believe) and I've not found it to matter for calculating the alignment.
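For reference, the per-frame log-density of a diagonal Gaussian with F = hp.n_mel_channels features is

log N(x; mu, sigma) = -0.5 * sum_f((x_f - mu_f)^2 / sigma_f^2) - sum_f(log(sigma_f)) - (F/2) * log(2*pi)

so the dropped term is the same constant for every text-position/frame pair: it shifts the loss value but has zero gradient and does not change which alignment scores best.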

Here it is with the modifications. I also use .pow_(2), which does the squaring in place and saves VRAM during the forward pass.

exponential = -0.5 * torch.sum((x - mu).pow_(2) / log_sigma.exp().pow_(2), dim=-1)  # B, L, T
log_prob_matrix = exponential - 0.5 * log_sigma.sum(dim=-1)  # - (hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))

(3)
This function is calculating the sum of all the NLLs. I'm not sure if this is intentional, but it makes the gradients explode and makes it very hard to keep track of model performance.

Example graph of what my MDN loss looks like when I calculate the mean:
[image: MDN loss training curve]

To calculate the mean loss, just replace the torch.sum with torch.mean and divide the final loss by the number of non-padded frames.


exponential = -0.5 * torch.sum((x - mu).pow_(2) / log_sigma.exp().pow_(2), dim=-1)  # B, L, T
log_prob_matrix = exponential - 0.5 * log_sigma.sum(dim=-1)  # - (hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))

becomes

exponential = -0.5 * torch.mean((x - mu).pow_(2) / log_sigma.exp().pow_(2), dim=-1)  # B, L, T
log_prob_matrix = exponential - 0.5 * log_sigma.mean(dim=-1)  # - (hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))

and

mdn_loss = -alpha_last.mean()

becomes

mdn_loss = -alpha_last.sum()/mel_lengths.sum()
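Putting the three changes together, here is a minimal sketch of the matrix computation as a standalone helper (the function name mdn_log_prob_matrix is mine, and I use out-of-place pow for simplicity; the alpha forward recursion over the result stays as in the repo):

import torch

def mdn_log_prob_matrix(x, mu, log_sigma):
    # x: (B, 1, T, F) mel frames; mu, log_sigma: (B, L, 1, F) per-token Gaussian parameters.
    # Mean over the mel axis instead of sum; the constant (F/2)*log(2*pi) term is dropped.
    exponential = -0.5 * torch.mean((x - mu).pow(2) / log_sigma.exp().pow(2), dim=-1)  # B, L, T
    return exponential - 0.5 * log_sigma.mean(dim=-1)                                  # B, L, T

# ...run the repo's alpha forward recursion over the returned matrix, then:
# mdn_loss = -alpha_last.sum() / mel_lengths.sum()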

@CookiePPP commented Jan 23, 2021

Oh wow. I learned something today.

exponential = -0.5 * torch.mean(torch.nn.functional.mse_loss(x, mu, reduction='none') / log_sigma.exp().pow_(2), dim=-1)  # B, L, T

uses 25% less VRAM than

exponential = -0.5 * torch.mean((x - mu).pow_(2) / log_sigma.exp().pow_(2), dim=-1)  # B, L, T

But for some reason it gives out an annoying warning.

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py:2652: UserWarning: Using a target size ([4, 256, 1, 160]) that is different to the input size ([4, 1, 1024, 160]). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  warnings.warn("Using a target size ({}) that is different to the input size ({}). "

The outputs and gradients are still 100% correct, so feel free to use that too; you will just have to suppress the warning or ignore it.
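If the warning gets noisy, one option (my suggestion, not something from the repo) is a targeted filter, since the broadcasting here is intentional:

import warnings

# Silence only this specific broadcasting warning from F.mse_loss.
warnings.filterwarnings("ignore", message="Using a target size")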

@janvainer (Author)

Great! Thanks for your advice. It makes good sense. I am thinking about reimplementing the alignment loss with numba to speed up the backward pass. Hopefully it will be possible to merge all the stages into one.🤔
