Commit cb3b838
fix avoid re-defining the differentiation objective to support AD pre-compilation (#66)

* update interface for objective initialization
* improve `RepGradELBO` to not redefine AD forward path
* add auxiliary argument to `value_and_gradient!`
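
In rough terms (an illustrative sketch only; `make_step_closure`, `fixed_forward`, and `scale` are made-up names, not AdvancedVI internals): the previous design rebuilt a closure over per-step state on every gradient step, handing the AD backend a new function object each time, while the new interface differentiates one fixed function and passes per-step state through an `aux` argument, which is what allows tapes and compiled paths to be reused.

```julia
# Before (sketch): a fresh closure per step, so the differentiated function changes every call.
make_step_closure(state) = x -> state.scale * sum(abs2, x)

# After (sketch): one fixed, top-level forward function; per-step state travels in `aux`.
fixed_forward(x, aux) = aux.scale * sum(abs2, x)
```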
Red-Portal authored Jun 15, 2024
1 parent c93b5d7 commit cb3b838
Showing 10 changed files with 146 additions and 44 deletions.
6 changes: 3 additions & 3 deletions ext/AdvancedVIBijectorsExt.jl
@@ -42,9 +42,9 @@ function AdvancedVI.reparam_with_entropy(
n_samples::Int,
ent_est ::AdvancedVI.AbstractEntropyEstimator
)
transform = q.transform
q_unconst = q.dist
q_unconst_stop = q_stop.dist
transform = q.transform
q_unconst = q.dist
q_unconst_stop = q_stop.dist

# Draw samples and compute entropy of the unconstrained distribution
unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
23 changes: 18 additions & 5 deletions ext/AdvancedVIForwardDiffExt.jl
@@ -14,16 +14,29 @@ end
getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
) where {T<:Real}
ad ::ADTypes.AutoForwardDiff,
f,
x ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
chunk_size = getchunksize(ad)
config = if isnothing(chunk_size)
ForwardDiff.GradientConfig(f, θ)
ForwardDiff.GradientConfig(f, x)
else
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
end
ForwardDiff.gradient!(out, f, θ, config)
ForwardDiff.gradient!(out, f, x, config)
return out
end

function AdvancedVI.value_and_gradient!(
ad ::ADTypes.AutoForwardDiff,
f,
x ::AbstractVector,
aux,
out::DiffResults.MutableDiffResult
)
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
19 changes: 16 additions & 3 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -13,11 +13,24 @@ end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
tp = ReverseDiff.GradientTape(f, θ)
ReverseDiff.gradient!(out, tp, θ)
tp = ReverseDiff.GradientTape(f, x)
ReverseDiff.gradient!(out, tp, x)
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult
)
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
23 changes: 19 additions & 4 deletions ext/AdvancedVIZygoteExt.jl
@@ -4,21 +4,36 @@ module AdvancedVIZygoteExt
if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ChainRulesCore
using Zygote
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ChainRulesCore
using ..Zygote
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
::ADTypes.AutoZygote,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
y, back = Zygote.pullback(f, θ)
∇θ = back(one(y))
y, back = Zygote.pullback(f, x)
∇x = back(one(y))
DiffResults.value!(out, y)
DiffResults.gradient!(out, only(∇θ))
DiffResults.gradient!(out, only(∇x))
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult
)
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
32 changes: 25 additions & 7 deletions src/AdvancedVI.jl
@@ -25,18 +25,35 @@ using StatsBase

# derivatives
"""
value_and_gradient!(ad, f, θ, out)
value_and_gradient!(ad, f, x, out)
value_and_gradient!(ad, f, x, aux, out)
Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`.
Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
`f` may receive auxiliary input as `f(x,aux)`.
# Arguments
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
- `f`: Function subject to differentiation.
- `θ`: The point to evaluate the gradient.
- `x`: The point to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end
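
A hedged usage sketch of the two call forms documented above; the objective `f`, the `scale` field, and the concrete values are illustrative, and the ForwardDiff backend is assumed to be loaded so its extension provides the method:

```julia
using ADTypes, DiffResults, ForwardDiff
import AdvancedVI

f(x, aux) = aux.scale * sum(abs2, x)          # toy objective taking auxiliary input

x   = [1.0, 2.0, 3.0]
out = DiffResults.DiffResult(0.0, similar(x))

# Five-argument form: `aux` is forwarded to `f` as its second argument.
AdvancedVI.value_and_gradient!(ADTypes.AutoForwardDiff(), f, x, (scale = 2.0,), out)

DiffResults.value(out)      # 2 * (1 + 4 + 9) == 28.0
DiffResults.gradient(out)   # elementwise 4 .* x
```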

"""
stop_gradient(x)
Stop the gradient from propagating to `x` if the selected AD backend supports it.
Otherwise, it is equivalent to `identity`.
# Arguments
- `x`: Input
# Returns
- `x`: Same value as the input.
"""
function stop_gradient end
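
For context, a purely illustrative sketch of the kind of rule a backend extension could add to honor `stop_gradient`; this is a guess at the pattern (the Zygote extension's new `ChainRulesCore` import hints at something similar) and not the package's actual implementation:

```julia
using ChainRulesCore

# Hypothetical stand-in; AdvancedVI's own `stop_gradient` may be defined differently.
my_stop_gradient(x) = x   # fallback behaves like `identity`

# For ChainRules-based backends: forward pass returns `x`, pullback returns no gradient.
function ChainRulesCore.rrule(::typeof(my_stop_gradient), x)
    my_stop_gradient_pullback(ȳ) = (NoTangent(), ZeroTangent())
    return x, my_stop_gradient_pullback
end
```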

# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)
@@ -78,22 +78,23 @@ If the estimator is stateful, it can implement `init` to initialize the state.
abstract type AbstractVariationalObjective end

"""
init(rng, obj, λ, restructure)
init(rng, obj, prob, params, restructure)
Initialize a state of the variational objective `obj` given the target problem `prob` and the initial variational parameters `params`.
This function needs to be implemented only if `obj` is stateful.
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `prob`: Target problem.
- `λ`: Initial variational parameters.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `params`.
"""
init(
::Random.AbstractRNG,
::AbstractVariationalObjective,
::AbstractVector,
::Any
::Any,
::Any,
::Any,
) = nothing
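
As a hedged sketch of what a stateful objective's `init` overload might look like under the new five-argument signature; `MyStatefulObjective` and its buffer state are hypothetical, not part of AdvancedVI:

```julia
import AdvancedVI
using Random

struct MyStatefulObjective <: AdvancedVI.AbstractVariationalObjective
    n_samples::Int
end

function AdvancedVI.init(
    rng::Random.AbstractRNG, obj::MyStatefulObjective, prob, params, restructure
)
    # For example, pre-allocate a sample buffer sized from the initial parameters.
    (sample_buffer = similar(params, length(params), obj.n_samples),)
end
```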

"""
36 changes: 21 additions & 15 deletions src/objectives/elbo/repgradelbo.jl
@@ -56,14 +56,13 @@ function estimate_energy_with_samples(prob, samples)
end

"""
reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
reparam_with_entropy(rng, q, n_samples, ent_est)
Draw `n_samples` from `q` and compute its entropy.
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q`: Variational approximation.
- `q_stop`: `q` but with its gradient stopped.
- `n_samples::Int`: Number of Monte Carlo samples
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
@@ -72,7 +71,11 @@ Draw `n_samples` from `q` and compute its entropy.
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
"""
function reparam_with_entropy(
rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
rng ::Random.AbstractRNG,
q,
q_stop,
n_samples::Int,
ent_est ::AbstractEntropyEstimator
)
samples = rand(rng, q, n_samples)
entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
@@ -94,28 +97,31 @@ end
estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
estimate_objective(Random.default_rng(), obj, q, prob; n_samples)

function estimate_repgradelbo_ad_forward(params′, aux)
@unpack rng, obj, problem, restructure, q_stop = aux
q = restructure(params′)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(problem, samples)
elbo = energy + entropy
-elbo
end

function estimate_gradient!(
rng ::Random.AbstractRNG,
obj ::RepGradELBO,
adtype::ADTypes.AbstractADType,
out ::DiffResults.MutableDiffResult,
prob,
λ,
params,
restructure,
state,
)
q_stop = restructure(λ)
function f(λ′)
q = restructure(λ′)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(prob, samples)
elbo = energy + entropy
-elbo
end
value_and_gradient!(adtype, f, λ, out)

q_stop = restructure(params)
aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
value_and_gradient!(
adtype, estimate_repgradelbo_ad_forward, params, aux, out
)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)

out, nothing, stat
end
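
To make the `reparam_with_entropy` step above concrete, a hedged sketch of a standalone call; it assumes `MeanFieldGaussian` and `StickingTheLandingEntropy` behave as in the test file further down, and passing `q` itself as `q_stop` is only for illustration (in `RepGradELBO` the stopped copy is rebuilt outside the differentiated path):

```julia
using Random, LinearAlgebra
import AdvancedVI
using AdvancedVI: MeanFieldGaussian, StickingTheLandingEntropy

rng = Random.default_rng()
q   = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

samples, entropy = AdvancedVI.reparam_with_entropy(
    rng, q, q, 10, StickingTheLandingEntropy()
)
size(samples)   # expected (2, 10): one column per Monte Carlo sample
```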
2 changes: 1 addition & 1 deletion src/optimize.jl
@@ -66,7 +66,7 @@ function optimize(
)
params, restructure = Optimisers.destructure(deepcopy(q_init))
opt_st = maybe_init_optimizer(state_init, optimizer, params)
obj_st = maybe_init_objective(state_init, rng, objective, params, restructure)
obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
stats = NamedTuple[]

17 changes: 13 additions & 4 deletions src/utils.jl
@@ -6,19 +6,28 @@ end
function maybe_init_optimizer(
state_init::NamedTuple,
optimizer ::Optimisers.AbstractRule,
params ::AbstractVector
params
)
haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, params)
if haskey(state_init, :optimizer)
state_init.optimizer
else
Optimisers.setup(optimizer, params)
end
end

function maybe_init_objective(
state_init::NamedTuple,
rng ::Random.AbstractRNG,
objective ::AbstractVariationalObjective,
params ::AbstractVector,
problem,
params,
restructure
)
haskey(state_init, :objective) ? state_init.objective : init(rng, objective, params, restructure)
if haskey(state_init, :objective)
state_init.objective
else
init(rng, objective, problem, params, restructure)
end
end

eachsample(samples::AbstractMatrix) = eachcol(samples)
3 changes: 1 addition & 2 deletions test/Project.toml
@@ -1,9 +1,9 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -26,7 +26,6 @@ ADTypes = "0.2.1, 1"
Bijectors = "0.13"
Distributions = "0.25.100"
DistributionsAD = "0.6.45"
Enzyme = "0.12"
FillArrays = "1.6.1"
ForwardDiff = "0.10.36"
Functors = "0.4.5"
29 changes: 29 additions & 0 deletions test/interface/repgradelbo.jl
@@ -26,3 +26,32 @@ using Test
@test elbo ≈ elbo_ref rtol=0.1
end
end

@testset "interface RepGradELBO STL variance reduction" begin
seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = normal_meanfield(rng, Float64)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

@testset for ad in [
ADTypes.AutoForwardDiff(),
ADTypes.AutoReverseDiff(),
ADTypes.AutoZygote()
]
q_true = MeanFieldGaussian(
Vector{eltype(μ_true)}(μ_true),
Diagonal(Vector{eltype(L_true)}(diag(L_true)))
)
params, re = Optimisers.destructure(q_true)
obj = RepGradELBO(10; entropy=StickingTheLandingEntropy())
out = DiffResults.DiffResult(zero(eltype(params)), similar(params))

aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true)
AdvancedVI.value_and_gradient!(
ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
)
grad = DiffResults.gradient(out)
@test norm(grad) ≈ 0 atol=1e-5
end
end

1 comment on commit cb3b838

@github-actions (Contributor)

Benchmark Results

| Benchmark suite | Current: cb3b838 | Previous: c93b5d7 | Ratio |
| --- | --- | --- | --- |
| normal + bijector/meanfield/ForwardDiff | 468762711 ns | 529221525.5 ns | 0.89 |
| normal + bijector/meanfield/ReverseDiff | 185732693 ns | 187259407 ns | 0.99 |

This comment was automatically generated by workflow using github-action-benchmark.
