ShiftCrossEntropy currently uses nn.CrossEntropyLoss as its backend, which expects its input to be unnormalized logits. It appears that ShiftCrossEntropy passes input probabilities and target probabilities to the backend instead (see pesto-full/src/losses/entropy.py, line 49 at commit 229f78b). This might lead to a deviation from the behavior described in equation (7) of the paper.
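A minimal sketch of what I mean (my own example, not the repository's code), assuming soft probability targets as supported by recent PyTorch versions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 10)                            # unnormalized scores
probs = logits.softmax(dim=-1)                         # normalized probabilities
target = torch.randn(4, 10).softmax(dim=-1).detach()   # soft target distribution

# Cross entropy as in eq. (7): -sum_k t_k * log p_k
ce_expected = -(target * probs.log()).sum(dim=-1).mean()

# Correct usage of the backend: pass the unnormalized logits.
ce_from_logits = F.cross_entropy(logits, target)

# Suspected current behavior: probabilities are passed as the input,
# so log_softmax is applied a second time and the result no longer
# matches the intended cross entropy.
ce_from_probs = F.cross_entropy(probs, target)

print(ce_expected.item(), ce_from_logits.item(), ce_from_probs.item())
# ce_expected and ce_from_logits agree (up to precision); ce_from_probs differs
```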
In my opinion, KL divergence should have the same effect as the cross-entropy loss: since the target is detached in the code, the two losses differ only by the entropy of the target, which is constant with respect to the model's parameters. However, replacing the cross-entropy loss with KL divergence makes the model fail to converge.
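As a sanity check (a sketch under my assumptions, not the repository's loss), the gradients of the two losses with respect to the logits should be identical when the target is detached:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 10, requires_grad=True)
target = torch.randn(4, 10).softmax(dim=-1).detach()   # detached soft target

log_p = F.log_softmax(logits, dim=-1)

# Cross entropy: -sum_k t_k * log p_k
ce = -(target * log_p).sum(dim=-1).mean()

# KL divergence: sum_k t_k * (log t_k - log p_k) = CE - H(t)
kl = F.kl_div(log_p, target, reduction="batchmean")

# H(t) does not depend on the logits, so both losses give the same gradients.
g_ce, = torch.autograd.grad(ce, logits, retain_graph=True)
g_kl, = torch.autograd.grad(kl, logits)
print(torch.allclose(g_ce, g_kl, atol=1e-6))  # expected: True
```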
The reason might be numerical issues in PyTorch, the misuse of nn.CrossEntropyLoss mentioned above, or other factors...
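For example (hypothetical numbers, just to illustrate the kind of numerical issue I have in mind), taking the log of probabilities that have already underflowed to zero produces -inf, whereas log_softmax on the raw logits stays finite:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[60.0, 0.0, -60.0]])
probs = logits.softmax(dim=-1)

print(probs.log())                    # the underflowed probability becomes log(0) = -inf
print(F.log_softmax(logits, dim=-1))  # approximately [0., -60., -120.], all finite
```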