Zygote AD & logpdf for transformed multivariate #217

Open
tpgillam opened this issue Apr 13, 2022 · 0 comments

I've found that Zygote fails to compute gradients when using the method of logpdf defined here: the gradient for the first parameter of the layer comes back as nothing, with no error raised.

Here's a MWE:

using Bijectors
using DistributionsAD
using Flux
using Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)

x = [0.42 0.24; 0.42 0.24]

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)

println()

gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];

With output:

loss_(flow, x) = 3.089176357252711
loss2_(flow, x) = 3.089176357252711

gs.grads[(Flux.params(b))[1]] = nothing
gs.grads[(Flux.params(b))[1]] = [-2.603210756288831, -4.3264084139896095]

Tested on Bijectors v0.10.0.

I'm not sure, but maybe the optimised dispatch for logpdf (or some of the methods called within) needs additional ChainRules support?
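
For reference, here is a minimal sketch of what "additional ChainRules support" usually looks like. It is not taken from Bijectors: `sumsq_mutating` is a hypothetical toy function that mutates an array, which Zygote cannot differentiate through by itself. Whether the optimised logpdf path fails for a similar reason is only a guess. Defining an `rrule` with ChainRulesCore gives Zygote a hand-written pullback to use instead.

```julia
using ChainRulesCore
using Zygote

# Hypothetical stand-in for a method Zygote cannot trace:
# it writes into a buffer, and Zygote does not support array mutation.
function sumsq_mutating(x::AbstractVector)
    buf = similar(x)
    for i in eachindex(x)
        buf[i] = x[i]^2
    end
    return sum(buf)
end

# Hand-written reverse rule: d/dx sum(x .^ 2) = 2x.
# Zygote picks up rules defined through ChainRulesCore automatically.
function ChainRulesCore.rrule(::typeof(sumsq_mutating), x::AbstractVector)
    y = sumsq_mutating(x)
    # NoTangent() is the tangent for the function object itself.
    sumsq_mutating_pullback(ybar) = (NoTangent(), 2 .* x .* ybar)
    return y, sumsq_mutating_pullback
end

g = Zygote.gradient(sumsq_mutating, [1.0, 2.0])[1]
@show g
```

With the rule in place the gradient call goes through; without it, Zygote throws an error about mutating arrays. A rule of this kind on the batched logpdf (or on whichever inner call is at fault) would be one possible fix, assuming that is where differentiation breaks down.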
