I've found that Zygote fails to compute gradients when using the method of `logpdf` defined here.

Here's an MWE:
```julia
using Bijectors
using DistributionsAD
using Flux
using Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)
x = [0.42 0.24; 0.42 0.24]  # one observation per column

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)
println()

# Gradient w.r.t. the PlanarLayer parameters through the batched (matrix) method
gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

# Same gradient through the per-column method
gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];
```
With output:
Tested on Bijectors v0.10.0.

I'm not sure, but maybe the optimised dispatch for `logpdf` (or some of the methods called within it) needs additional ChainRules support?
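If the batched method does turn out to lack a working reverse rule, one possible remedy is a hand-written `rrule` via ChainRulesCore, which Zygote picks up automatically. The sketch below does this for a hypothetical `batch_logdens` (an isotropic Gaussian log-density, not the Bijectors flow), then compares its gradient against a per-column formulation in the spirit of `loss2_`. The names `batch_logdens`, `single_logdens`, `loss_batch` and `loss_cols` are illustrative assumptions, not Bijectors code or a confirmed fix for the `PlanarLayer` case.

```julia
using ChainRulesCore
using Zygote

# Hypothetical batched log-density of N(θ, I): returns one value per column of `x`.
function batch_logdens(θ::AbstractVector, x::AbstractMatrix)
    r = x .- θ  # residuals, d × n
    return vec(-0.5 .* sum(abs2, r; dims=1)) .- size(x, 1) / 2 * log(2π)
end

# Hand-written reverse rule, so Zygote uses it instead of tracing the body.
function ChainRulesCore.rrule(::typeof(batch_logdens), θ::AbstractVector, x::AbstractMatrix)
    r = x .- θ
    y = vec(-0.5 .* sum(abs2, r; dims=1)) .- size(x, 1) / 2 * log(2π)
    function batch_logdens_pullback(ȳ)
        ȳ = unthunk(ȳ)
        # ∂y_j/∂x_ij = -r_ij and ∂y_j/∂θ_i = r_ij
        x̄ = -r .* ȳ'
        θ̄ = vec(sum(r .* ȳ'; dims=2))
        return NoTangent(), θ̄, x̄
    end
    return y, batch_logdens_pullback
end

# Per-observation version, differentiated by Zygote without a custom rule.
single_logdens(θ, obs) = -0.5 * sum(abs2, obs .- θ) - length(obs) / 2 * log(2π)

loss_batch(θ, x) = -sum(batch_logdens(θ, x))
loss_cols(θ, x) = -sum([single_logdens(θ, x[:, j]) for j in 1:size(x, 2)])

θ = zeros(2)
x = [0.42 0.24; 0.42 0.24]

g_batch = Zygote.gradient(loss_batch, θ, x)
g_cols = Zygote.gradient(loss_cols, θ, x)
@show g_batch[1] g_cols[1]
```

Both formulations should agree; analytically the gradient with respect to `θ` here is approximately `[-0.66, -0.66]`, and with respect to `x` it equals `x` itself, since `θ` is zero. A mismatch of this kind between the batched and per-column paths is what the MWE above checks for the flow.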