Fails to train on mini-ImageNet #8
Comments
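I think you can modify the parameters, e.g. replace the epoch-based annealing coefficient with a small fixed annealing_coef:

import torch

def edl_loss(self, func, y, alpha, annealing_step, device="cuda"):
    # One-hot encode the integer labels and put everything on the same device.
    y = self.one_hot_embedding(y)
    y = y.to(device)
    alpha = alpha.to(device)
    # Dirichlet strength S = sum_k alpha_k, one value per sample.
    S = torch.sum(alpha, dim=1, keepdim=True)
    # Data term: sum_k y_k * (func(S) - func(alpha_k)), with func = torch.digamma or torch.log.
    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)
    # Original schedule, ramping from 0 to 1 over `annealing_step` epochs:
    # annealing_coef = torch.min(
    #     torch.tensor(1.0, dtype=torch.float32),
    #     torch.tensor(self.epoch / annealing_step, dtype=torch.float32),
    # )
    # Fixed, smaller weight for the KL regularizer instead.
    annealing_coef = 0.1
    # Remove the true-class evidence so the KL term only penalizes misleading evidence.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * self.kl_divergence(kl_alpha, device=device)
    return A + kl_div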
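For experimenting outside the repo's class, here is a minimal self-contained sketch of the same loss. It is not the repo's code: `one_hot_embedding` and `kl_divergence` are hypothetical stand-ins for the `self.` helpers, the KL term is the standard closed-form KL between Dir(alpha) and the uniform Dir(1, ..., 1), `kl_weight` plays the role of `annealing_coef`, and the softplus evidence in the usage part is an assumption.

```python
import torch
import torch.nn.functional as F


def one_hot_embedding(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Hypothetical stand-in for self.one_hot_embedding: (N,) int64 labels -> (N, K) floats.
    return F.one_hot(labels, num_classes).float()


def kl_divergence(alpha: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(Dir(alpha) || Dir(1, ..., 1)) per sample; returns shape (N, 1).
    num_classes = alpha.shape[1]
    sum_alpha = alpha.sum(dim=1, keepdim=True)
    log_norm = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        - torch.lgamma(torch.tensor(float(num_classes), device=alpha.device))
    )  # the lgamma(1) = 0 terms of the uniform prior drop out
    digamma_term = ((alpha - 1) * (torch.digamma(alpha) - torch.digamma(sum_alpha))).sum(
        dim=1, keepdim=True
    )
    return log_norm + digamma_term


def edl_loss(func, labels, alpha, num_classes, kl_weight):
    # func: torch.digamma (or torch.log); alpha: Dirichlet parameters, shape (N, K).
    y = one_hot_embedding(labels, num_classes).to(alpha.device)
    S = alpha.sum(dim=1, keepdim=True)
    # Data term: sum_k y_k * (func(S) - func(alpha_k)).
    A = (y * (func(S) - func(alpha))).sum(dim=1, keepdim=True)
    # Strip the true-class evidence so the KL term only penalizes misleading evidence.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = kl_weight * kl_divergence(kl_alpha)
    # Averaged over the batch here; the snippet above returns per-sample values.
    return (A + kl_div).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 5)
    # Assumption: evidence = softplus(logits), so every alpha_k > 1.
    alpha = F.softplus(logits) + 1
    labels = torch.tensor([0, 1, 2, 3])
    for w in (0.0, 0.05, 0.1):
        print(w, edl_loss(torch.digamma, labels, alpha, 5, w).item())
```

Because the KL term is non-negative, lowering `kl_weight` can only lower the loss for fixed outputs; comparing runs with `kl_weight` at 0.0, 0.05 and 0.1 is a quick way to check whether the regularizer is what blocks convergence.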
Hi, thank you for your answer. I tried setting annealing_coef to 0.1 and to 0.05, but it still doesn't work. Did you get it to train successfully?
I haven't tried this repo, but I have trained a face recognition network on 8631 identities with the EDL loss. ResNet-100 works: it reaches the same accuracy on the in-distribution dataset as softmax training. ResNet-50, however, does not converge. We have also all found that the KL loss hurts accuracy, so try decreasing its coefficient or removing it altogether.
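One way to act on the advice to decrease or remove the KL coefficient while keeping the epoch-based annealing from the original snippet is to cap the schedule. This is a sketch under that assumption; `kl_weight` and `max_weight` are hypothetical names. With `max_weight=1.0` it reproduces `min(1, epoch / annealing_step)`, smaller values weaken the KL term, and `0.0` removes it.

```python
def kl_weight(epoch: int, annealing_step: int, max_weight: float = 1.0) -> float:
    # Linear ramp from 0 to max_weight over `annealing_step` epochs, then constant.
    # max_weight=1.0 matches min(1, epoch / annealing_step); 0.0 disables the KL term.
    if annealing_step <= 0:
        return max_weight
    return max_weight * min(1.0, epoch / annealing_step)


if __name__ == "__main__":
    for epoch in (0, 5, 10, 20):
        print(epoch, kl_weight(epoch, annealing_step=10, max_weight=0.1))
```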
Thank you so much. It still doesn't work; I think I may need to tune other hyperparameters.
Maybe you can refer to this: https://github.com/RuoyuChen10/FaceTechnologyTool/blob/master/FaceRecognition/evidential_learning.py. I have tried it on face recognition, and it also failed for me at first. I concluded that it mainly comes down to the following:
The learning rate and the depth of the network may have little influence when training with softmax and cross-entropy loss, but they seem to matter much more with the EDL loss.
I used the EDL loss to train on the mini-ImageNet dataset with 64 classes, but the loss does not converge and the accuracy is very low.