gradient of Flux.normalise return NaN when std is zero #2096

Closed

chengchingwen opened this issue Nov 1, 2022 · 8 comments
@chengchingwen
Member

The $\epsilon$ argument of Flux.normalise only prevents division by zero in the forward pass, but there is also a division by $\sigma$ in the pullback of std. We might need a custom rrule for Flux.normalise.

julia> using Flux, Zygote

julia> Zygote.gradient(x -> sum(sin.(Flux.normalise(x; dims=1))), ones(3,3))
([NaN NaN NaN; NaN NaN NaN; NaN NaN NaN],)
@mcabbott
Member

mcabbott commented Nov 1, 2022

Xref JuliaML/MLUtils.jl#123 (about moving & renaming) and #1992 (about NaN from batch of 1).

@ToucheSir
Member

Other frameworks have implemented this using sqrt + var + the eps instead of using std directly.
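
A rough sketch of that approach, assuming a hypothetical normalise_var (not Flux's API) in which ϵ is folded into the variance before the square root:

using Statistics: mean, var
using Zygote

# Hypothetical variant: ϵ goes inside the sqrt, so the sqrt pullback divides by
# 2*sqrt(var + ϵ) ≥ 2*sqrt(ϵ) and never hits zero.
normalise_var(x; dims=1, ϵ=1e-5) =
    (x .- mean(x; dims=dims)) ./ sqrt.(var(x; dims=dims, corrected=false) .+ ϵ)

Zygote.gradient(x -> sum(sin.(normalise_var(x; dims=1))), ones(3,3))
# finite (here all-zero) gradient instead of NaN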

@chengchingwen
Member Author

> and #1992 (about NaN from batch of 1).

I'm not sure #1992 is related. BatchNorm doesn't use normalise, and this issue isn't caused by batch size.

> Other frameworks have implemented this using sqrt + var + the eps instead of using std directly.

FWIW, PyTorch's LayerNorm adds the eps to the variance, stores var + eps, and then uses that directly in the pullback.
One quick and dirty solution could be to update the std value in place with eps (with AD ignoring the update).

This also raises the question of the error between the true value and the value with eps. Since we are dividing by $\sqrt{var} + \epsilon$ while they are dividing by $\sqrt{var + \epsilon}$, the resulting values could differ on the order of $\epsilon$. I'm not sure which is better (IMO it should be handled by a branch).
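
For a sense of scale, take $\epsilon = 10^{-5}$: at $var = 1$ the two denominators are $\sqrt{1} + \epsilon = 1.00001$ versus $\sqrt{1 + \epsilon} \approx 1.000005$, a difference of order $\epsilon$; at $var = 0$ they are $\epsilon = 10^{-5}$ versus $\sqrt{\epsilon} \approx 3.2 \times 10^{-3}$, a much larger gap.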

@ToucheSir
Member

I think we'd have to go with $\sqrt{var + \epsilon}$ because the rule for sqrt(x) divides by 2*sqrt(x).
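
Concretely, $\frac{d}{dv}\sqrt{v} = \frac{1}{2\sqrt{v}}$ blows up as $v \to 0$, whereas $\frac{d}{dv}\sqrt{v + \epsilon} = \frac{1}{2\sqrt{v + \epsilon}} \le \frac{1}{2\sqrt{\epsilon}}$ stays bounded.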

@chengchingwen
Member Author

You can replace 2sqrt(x) with 2sqrt(x) + ϵ if you wrap everything in a single rrule.
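
A minimal sketch of what such a fused rule could look like, assuming a hypothetical normalise_eps (not Flux's implementation) that puts ϵ inside the sqrt and reuses the regularised denominator in the pullback:

using Statistics: mean, var
using ChainRulesCore

# Hypothetical fused normalise: uncorrected variance, ϵ inside the sqrt.
function normalise_eps(x::AbstractArray; dims=1, ϵ=1e-5)
    μ = mean(x; dims=dims)
    σ² = var(x; dims=dims, mean=μ, corrected=false)
    return (x .- μ) ./ sqrt.(σ² .+ ϵ)
end

function ChainRulesCore.rrule(::typeof(normalise_eps), x::AbstractArray; dims=1, ϵ=1e-5)
    μ = mean(x; dims=dims)
    σ² = var(x; dims=dims, mean=μ, corrected=false)
    s = sqrt.(σ² .+ ϵ)              # regularised denominator, reused below
    y = (x .- μ) ./ s
    function normalise_eps_pullback(ȳ)
        Δ = unthunk(ȳ)
        # standard layer-norm backward: only divides by s ≥ √ϵ, so no NaN when var ≈ 0
        x̄ = (Δ .- mean(Δ; dims=dims) .- y .* mean(Δ .* y; dims=dims)) ./ s
        return NoTangent(), x̄
    end
    return y, normalise_eps_pullback
end

This mirrors the PyTorch approach mentioned above: the ϵ-regularised statistic is computed once and reused in the backward pass.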

@ToucheSir
Member

For sure, but I'm loath to create an rrule just for this. I actually have a WIP PR bringing the norm functions to NNlib, so @chengchingwen, if you want to continue this design discussion, I can publish it.

@chengchingwen
Member Author

I would be interested. I actually have a function for computing the gradient of a layer norm directly in NAlib. It is the best (in terms of both performance and memory efficiency) I can get without writing a CUDA kernel. The gradient of normalise can easily be split out from it. So let's see if we can get even more performance from the new design.

RomeoV added a commit to RomeoV/DisentanglingVAE.jl that referenced this issue Mar 16, 2023
There is a problem in normalise that if `std(x) ≈ 0`, then the
chain rule evaluates to NaN. See e.g. here: [FluxML/Flux.jl#2096].

We tried to fix this here by adding some noise to x, although that might
not be the best solution. A later commit also ensures that all images
actually have some noise in the background.
@ToucheSir
Member

#2421 has been merged now.
