
Minor Touches for ScoreGradELBO #99

Merged — 29 commits, merged on Dec 6, 2024
Changes from 10 commits

Commits (29)
00d43d0
fix make `ScoreGradELBO` immutable
Red-Portal Sep 30, 2024
af7a5a6
fix error in `ScoreGradELBO`
Red-Portal Sep 30, 2024
d73ed54
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into tid…
Red-Portal Oct 1, 2024
7370248
fix type instability, use OnlineStats for baseline window
Red-Portal Oct 1, 2024
1ab02c5
fix default options of `ScoreGradientELBO`, enable more tests
Red-Portal Oct 2, 2024
d110ec7
refactor `ScoreGradELBO`
Red-Portal Oct 4, 2024
42c1034
run formatter
Red-Portal Oct 4, 2024
feb3200
fix default value for baseline control variate
Red-Portal Oct 4, 2024
45b5afd
Merge branch 'master' into tidy_scoregradelbo
yebai Oct 22, 2024
45b37c1
Update CI.yml
yebai Oct 22, 2024
86eccf7
fix move log density computation out of the AD path
Red-Portal Nov 5, 2024
35ea7ec
Merge branch 'tidy_scoregradelbo' of github.com:TuringLang/AdvancedVI…
Red-Portal Nov 5, 2024
dc23a02
update change the `ScoreGradELBO` objective to be VarGrad underneath
Red-Portal Nov 5, 2024
f030d14
fix remove unnecessary import
Red-Portal Nov 5, 2024
2577dce
fix ScoreGradELBO outdated docs and removed unused parametric type
Red-Portal Nov 5, 2024
f0bbc1b
update docs for `ScoreGradELBO`
Red-Portal Nov 5, 2024
43e8581
update docs for `ScoreGradELBO`
Red-Portal Nov 5, 2024
3e6ef6f
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into tid…
Red-Portal Dec 3, 2024
9ab4d89
run formatter
Red-Portal Dec 3, 2024
b3d72e8
remove outdated test
Red-Portal Dec 3, 2024
c69a5ed
fix error in `ScoreGradELBO` for `n_montecarlo=1`
Red-Portal Dec 3, 2024
bfa9de0
add basic tests for interface tests of variational objectives
Red-Portal Dec 3, 2024
eeb2d34
run formatter
Red-Portal Dec 4, 2024
b2fd6e2
tweak stepsize for inference test of ScoreGradELBO
Red-Portal Dec 4, 2024
5ff79e3
tweak `n_montecarlo` in inference test
Red-Portal Dec 4, 2024
eda4ea0
add docstrings to elbo objective forward ad paths
Red-Portal Dec 4, 2024
b6083ed
remove `n_montecarlo` option in the inference tests and just fix it
Red-Portal Dec 4, 2024
09c7276
fix bug in `ScoreGradELBO`
Red-Portal Dec 5, 2024
801a76b
fix wrong usage of `n_montecarlo` in inference tests
Red-Portal Dec 5, 2024
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -19,8 +19,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7'
- '1.10'
- '1'
os:
- ubuntu-latest
- macOS-latest
2 changes: 2 additions & 0 deletions Project.toml
@@ -14,6 +14,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -49,6 +50,7 @@ Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
OnlineStats = "1"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
1 change: 1 addition & 0 deletions src/AdvancedVI.jl
@@ -6,6 +6,7 @@ using Accessors

using Random
using Distributions
using OnlineStats

using Functors
using Optimisers
7 changes: 0 additions & 7 deletions src/objectives/elbo/entropy.jl
@@ -37,10 +37,3 @@ function estimate_entropy(
-logpdf(q, mc_sample)
end
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end
7 changes: 7 additions & 0 deletions src/objectives/elbo/repgradelbo.jl
@@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
return print(io, ")")
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

function estimate_energy_with_samples(prob, samples)
return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end
79 changes: 39 additions & 40 deletions src/objectives/elbo/scoregradelbo.jl
@@ -1,3 +1,4 @@

"""
ScoreGradELBO(n_samples; kwargs...)

@@ -25,9 +26,8 @@ To reduce the variance of the gradient estimator, we use a baseline computed fro
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.

# Keyword Arguments
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `MonteCarloEntropy()`)
- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)

# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
@@ -36,21 +36,24 @@ To reduce the variance of the gradient estimator, we use a baseline computed fro

Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
"""
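The control-variate scheme described in this docstring can be summarized as follows (a sketch in generic notation, not the package's exact formulation; here ``b`` is the moving-average baseline and ``W`` is the window of past ELBO estimates):

```latex
\nabla_{\lambda}\,\mathbb{E}_{q_{\lambda}}\!\left[\log \pi(z)\right]
  = \mathbb{E}_{q_{\lambda}}\!\left[\bigl(\log \pi(z) - b\bigr)\,
    \nabla_{\lambda}\log q_{\lambda}(z)\right],
\qquad
b = \frac{1}{\lvert W \rvert}\sum_{t \in W} \widehat{\mathrm{ELBO}}_{t}.
```

The identity holds for any constant ``b`` because ``\mathbb{E}_{q_{\lambda}}[\nabla_{\lambda}\log q_{\lambda}(z)] = 0``, so the baseline changes only the variance of the estimator, not its mean.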
struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <:
AdvancedVI.AbstractVariationalObjective
struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <: AbstractVariationalObjective
entropy::EntropyEst
n_samples::Int
baseline_window_size::Int
baseline_history::Vector{Float64}
end

function ScoreGradELBO(
n_samples::Int;
entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
entropy::AbstractEntropyEstimator=MonteCarloEntropy(),
baseline_window_size::Int=10,
baseline_history::Vector{Float64}=Float64[],
)
return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
return ScoreGradELBO(entropy, n_samples, baseline_window_size)
end

function init(
::Random.AbstractRNG, obj::ScoreGradELBO, prob, params::AbstractVector{T}, restructure
) where {T<:Real}
return MovingWindow(T, obj.baseline_window_size)
end

function Base.show(io::IO, obj::ScoreGradELBO)
@@ -63,28 +66,11 @@ function Base.show(io::IO, obj::ScoreGradELBO)
return print(io, ")")
end

function compute_control_variate_baseline(history, window_size)
if length(history) == 0
return 1.0
end
min_index = max(1, length(history) - window_size)
return mean(history[min_index:end])
end

function estimate_energy_with_samples(
prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
fv_mean = mean(fv)
score_grad = mean(@. samples_logprob * (fv - baseline))
score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline))
return fv_mean + (score_grad - score_grad_stop)
end

function estimate_objective(
rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
)
samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy)
samples = rand(rng, q, n_samples)
entropy = estimate_entropy(obj.entropy, samples, q)
energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
return mean(energy) + entropy
end
@@ -94,18 +80,19 @@ function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_sa
end

function estimate_scoregradelbo_ad_forward(params′, aux)
@unpack rng, obj, problem, adtype, restructure, q_stop = aux
baseline = compute_control_variate_baseline(
obj.baseline_history, obj.baseline_window_size
)
@unpack rng, obj, problem, adtype, restructure, samples, q_stop, baseline = aux
q = restructure_ad_forward(adtype, restructure, params′)
samples_stop = rand(rng, q_stop, obj.n_samples)
entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
energy = estimate_energy_with_samples(
problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
)

ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
ℓq_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples))
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, problem), eachsample(samples))
ℓπ_mean = mean(ℓπ)
score_grad = mean(@. ℓq * (ℓπ - baseline))
score_grad_stop = mean(@. ℓq_stop * (ℓπ - baseline))

energy = ℓπ_mean + (score_grad - score_grad_stop)
entropy = estimate_entropy(obj.entropy, samples, q)

elbo = energy + entropy
return -elbo
end
@@ -120,20 +107,32 @@ function AdvancedVI.estimate_gradient!(
restructure,
state,
)
baseline_buf = state
baseline_history = OnlineStats.value(baseline_buf)
baseline = if isempty(baseline_history)
zero(eltype(params))
else
mean(baseline_history)
end
q_stop = restructure(params)
samples = rand(rng, q_stop, obj.n_samples)
aux = (
rng=rng,
adtype=adtype,
obj=obj,
problem=prob,
restructure=restructure,
baseline=baseline,
samples=samples,
q_stop=q_stop,
)
AdvancedVI.value_and_gradient!(
adtype, estimate_scoregradelbo_ad_forward, params, aux, out
)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
push!(obj.baseline_history, -nelbo)
return out, nothing, stat
if obj.baseline_window_size > 0
fit!(baseline_buf, -nelbo)
end
return out, baseline_buf, stat
end
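The estimator refactored above pairs a score-function gradient with a baseline taken from a moving window of past objective values. A minimal Python sketch of the same idea on a toy problem (all names and the toy objective are illustrative, not the package's API): it estimates the gradient of E_{z~N(mu,1)}[z^2] with respect to mu, whose true value is 2*mu.

```python
import random

def score_grad(mu, f, zs, baseline):
    """Score-function estimate of d/dmu E_{z ~ N(mu, 1)}[f(z)].

    Uses grad_mu log N(z; mu, 1) = (z - mu), giving the estimator
    mean_i (f(z_i) - baseline) * (z_i - mu). The baseline does not
    bias the estimate; it only reduces variance.
    """
    return sum((f(z) - baseline) * (z - mu) for z in zs) / len(zs)

rng = random.Random(0)
mu = 1.0
f = lambda z: z * z  # E[f] = mu**2 + 1, so the true gradient is 2*mu = 2.0

window = []          # moving window of recent objective values (the baseline history)
window_size = 10
estimates = []
for _ in range(4000):
    zs = [rng.gauss(mu, 1.0) for _ in range(8)]
    b = sum(window) / len(window) if window else 0.0  # empty window -> baseline 0
    estimates.append(score_grad(mu, f, zs, b))
    # push the batch's mean objective value and keep only the last `window_size`
    window = (window + [sum(f(z) for z in zs) / len(zs)])[-window_size:]

avg = sum(estimates) / len(estimates)
print(avg)  # should be close to the true gradient, 2.0
```

The moving window here plays the role of `OnlineStats.MovingWindow` in the diff: old values fall out after `window_size` entries, so the baseline tracks the recent scale of the objective.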
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
@@ -15,7 +15,7 @@ if @isdefined(Enzyme)
)
end

@testset "inference ScoreGradELBO VILocationScale" begin
@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in
12 changes: 4 additions & 8 deletions test/inference/scoregradelbo_distributionsad.jl
@@ -9,20 +9,16 @@ if @isdefined(Mooncake)
AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

#if @isdefined(Enzyme)
# AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme()
#end
if @isdefined(Enzyme)
AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme()
end

@testset "inference ScoreGradELBO DistributionsAD" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_scoregradelbo_distributionsad

seed = (0x38bef07cf9cc549d)
6 changes: 1 addition & 5 deletions test/inference/scoregradelbo_locationscale.jl
@@ -21,11 +21,7 @@ end
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_locationscale

seed = (0x38bef07cf9cc549d)
12 changes: 4 additions & 8 deletions test/inference/scoregradelbo_locationscale_bijectors.jl
@@ -5,9 +5,9 @@ AD_scoregradelbo_locationscale_bijectors = Dict(
#:Zygote => AutoZygote(),
)

#if @isdefined(Tapir)
# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
#end
if @isdefined(Mooncake)
AD_scoregradelbo_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing)
end

if @isdefined(Enzyme)
AD_scoregradelbo_locationscale_bijectors[:Enzyme] = AutoEnzyme()
@@ -19,11 +19,7 @@ end
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
#:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet.
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors

seed = (0x38bef07cf9cc549d)
17 changes: 17 additions & 0 deletions test/interface/scoregradelbo.jl
@@ -25,4 +25,21 @@ using Test
elbo = estimate_objective(obj, q0, model; n_samples=10^4)
@test elbo ≈ elbo_ref rtol = 0.2
end

@testset "baseline_window" begin
T = 100
adtype = AutoForwardDiff()

obj = ScoreGradELBO(10)
_, _, stats, _ = optimize(rng, model, obj, q0, T; show_progress=false, adtype)
@test isfinite(last(stats).elbo)

obj = ScoreGradELBO(10; baseline_window_size=0)
_, _, stats, _ = optimize(rng, model, obj, q0, T; show_progress=false, adtype)
@test isfinite(last(stats).elbo)

obj = ScoreGradELBO(10; baseline_window_size=1)
_, _, stats, _ = optimize(rng, model, obj, q0, T; show_progress=false, adtype)
@test isfinite(last(stats).elbo)
end
end