Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor Touches for ScoreGradELBO #99

Merged
merged 29 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
00d43d0
fix make `ScoreGradELBO` immutable
Red-Portal Sep 30, 2024
af7a5a6
fix error in `ScoreGradELBO`
Red-Portal Sep 30, 2024
d73ed54
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into tid…
Red-Portal Oct 1, 2024
7370248
fix type instability, use OnlineStats for baseline window
Red-Portal Oct 1, 2024
1ab02c5
fix default options of `ScoreGradientELBO`, enable more tests
Red-Portal Oct 2, 2024
d110ec7
refactor `ScoreGradELBO`
Red-Portal Oct 4, 2024
42c1034
run formatter
Red-Portal Oct 4, 2024
feb3200
fix default value for baseline control variate
Red-Portal Oct 4, 2024
45b5afd
Merge branch 'master' into tidy_scoregradelbo
yebai Oct 22, 2024
45b37c1
Update CI.yml
yebai Oct 22, 2024
86eccf7
fix move log density computation out of the AD path
Red-Portal Nov 5, 2024
35ea7ec
Merge branch 'tidy_scoregradelbo' of github.com:TuringLang/AdvancedVI…
Red-Portal Nov 5, 2024
dc23a02
update change the `ScoreGradELBO` objective to be VarGrad underneath
Red-Portal Nov 5, 2024
f030d14
fix remove unnecessary import
Red-Portal Nov 5, 2024
2577dce
fix ScoreGradELBO outdated docs and removed unused parametric type
Red-Portal Nov 5, 2024
f0bbc1b
update docs for `ScoreGradELBO`
Red-Portal Nov 5, 2024
43e8581
update docs for `ScoreGradELBO`
Red-Portal Nov 5, 2024
3e6ef6f
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into tid…
Red-Portal Dec 3, 2024
9ab4d89
run formatter
Red-Portal Dec 3, 2024
b3d72e8
remove outdated test
Red-Portal Dec 3, 2024
c69a5ed
fix error in `ScoreGradELBO` for `n_montecarlo=1`
Red-Portal Dec 3, 2024
bfa9de0
add basic tests for interface tests of variational objectives
Red-Portal Dec 3, 2024
eeb2d34
run formatter
Red-Portal Dec 4, 2024
b2fd6e2
tweak stepsize for inference test of ScoreGradELBO
Red-Portal Dec 4, 2024
5ff79e3
tweak `n_montecarlo` in inference test
Red-Portal Dec 4, 2024
eda4ea0
add docstrings to elbo objective forward ad paths
Red-Portal Dec 4, 2024
b6083ed
remove `n_montecarlo` option in the inference tests and just fix it
Red-Portal Dec 4, 2024
09c7276
fix bug in `ScoreGradELBO`
Red-Portal Dec 5, 2024
801a76b
fix wrong usage of `n_montecarlo` in inference tests
Red-Portal Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions src/objectives/elbo/entropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,3 @@ function estimate_entropy(
-logpdf(q, mc_sample)
end
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end
29 changes: 27 additions & 2 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
return print(io, ")")
end

# Estimate the entropy over `samples` with `entropy_estimator`.
# `q` and `q_stop` are first passed through `maybe_stop_entropy_score`, and
# whatever distribution it returns is the one handed to `estimate_entropy`.
# NOTE(review): `q_stop` is presumably a gradient-stopped copy of `q`, and the
# choice between the two presumably depends on the estimator type — confirm
# against the `maybe_stop_entropy_score` methods.
function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

# Monte Carlo estimate of the energy term: the mean of
# `LogDensityProblems.logdensity(prob, x)` over every `x` produced by
# `eachsample(samples)`.
function estimate_energy_with_samples(prob, samples)
return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end
Expand Down Expand Up @@ -85,9 +92,27 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_repgradelbo_ad_forward(params′, aux)
"""
estimate_repgradelbo_ad_forward(params, aux)

AD-guaranteed forward path of the reparameterization gradient objective.

# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.

# Auxiliary Information
`aux` should contain the following entries:
- `rng`: Random number generator.
- `obj`: The `RepGradELBO` objective.
- `problem`: The target `LogDensityProblem`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the variational distribution from `params`.
- `q_stop`: A copy of `restructure(params)` with its gradient "stopped" (excluded from the AD path).
"""
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, obj, problem, adtype, restructure, q_stop) = aux
q = restructure_ad_forward(adtype, restructure, params)
q = restructure_ad_forward(adtype, restructure, params)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(problem, samples)
elbo = energy + entropy
Expand Down
133 changes: 39 additions & 94 deletions src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
@@ -1,113 +1,63 @@

"""
ScoreGradELBO(n_samples; kwargs...)

Evidence lower-bound objective computed with score function gradients.
```math
\\begin{aligned}
\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
&\\=
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
\\right]
+ \\mathbb{H}\\left(q_{\\lambda}\\right),
\\end{aligned}
```

To reduce the variance of the gradient estimator, we use a baseline computed from a running average of the previous ELBO values and subtract it from the objective.

```math
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\pi\\left(z\\right) - \\beta\\right)
\\right]
```
Evidence lower-bound objective computed with score function gradient with the VarGrad objective, also known as the leave-one-out control variate.

# Arguments
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.

# Keyword Arguments
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the VarGrad objective.

# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
- `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
- The target distribution and the variational approximation have the same support.

Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
"""
struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <:
AdvancedVI.AbstractVariationalObjective
entropy::EntropyEst
struct ScoreGradELBO <: AbstractVariationalObjective
n_samples::Int
baseline_window_size::Int
baseline_history::Vector{Float64}
end

function ScoreGradELBO(
n_samples::Int;
entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
baseline_window_size::Int=10,
baseline_history::Vector{Float64}=Float64[],
)
return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
end

function Base.show(io::IO, obj::ScoreGradELBO)
print(io, "ScoreGradELBO(entropy=")
print(io, obj.entropy)
print(io, ", n_samples=")
print(io, "ScoreGradELBO(n_samples=")
print(io, obj.n_samples)
print(io, ", baseline_window_size=")
print(io, obj.baseline_window_size)
return print(io, ")")
end

function compute_control_variate_baseline(history, window_size)
if length(history) == 0
return 1.0
end
min_index = max(1, length(history) - window_size)
return mean(history[min_index:end])
end

function estimate_energy_with_samples(
prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
fv_mean = mean(fv)
score_grad = mean(@. samples_logprob * (fv - baseline))
score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline))
return fv_mean + (score_grad - score_grad_stop)
end

function estimate_objective(
mhauru marked this conversation as resolved.
Show resolved Hide resolved
rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
)
samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy)
energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
return mean(energy) + entropy
samples = rand(rng, q, n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
mhauru marked this conversation as resolved.
Show resolved Hide resolved
return mean(ℓπ - ℓq)
end

# Convenience method without an explicit RNG: forwards to the `rng`-taking
# `estimate_objective` method with `Random.default_rng()`.
# `n_samples` defaults to `obj.n_samples` and is passed through unchanged.
function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_scoregradelbo_ad_forward(params′, aux)
(; rng, obj, problem, adtype, restructure, q_stop) = aux
baseline = compute_control_variate_baseline(
obj.baseline_history, obj.baseline_window_size
)
q = restructure_ad_forward(adtype, restructure, params′)
samples_stop = rand(rng, q_stop, obj.n_samples)
entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
energy = estimate_energy_with_samples(
problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
elbo = energy + entropy
return -elbo
"""
estimate_scoregradelbo_ad_forward(params, aux)

AD-guaranteed forward path of the score gradient objective.

# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.

# Auxiliary Information
`aux` should contain the following entries:
- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path).
- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the variational distribution from `params`.
"""
function estimate_scoregradelbo_ad_forward(params, aux)
(; samples_stop, logprob_stop, adtype, restructure) = aux
# Rebuild the variational distribution from `params`; this is the only place
# `params` enters the computation.
q = restructure_ad_forward(adtype, restructure, params)
# Target log-densities come precomputed from `aux`; they are not re-evaluated here.
ℓπ = logprob_stop
# Log-density of `q` at each of the fixed samples taken from `aux`.
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
f = ℓq - ℓπ
# mean(f^2) - mean(f)^2 is the empirical variance of `f`, normalized by the
# number of samples; the returned objective is half of that variance.
return (mean(abs2, f) - mean(f)^2) / 2
end

function AdvancedVI.estimate_gradient!(
Expand All @@ -120,20 +70,15 @@ function AdvancedVI.estimate_gradient!(
restructure,
state,
)
q_stop = restructure(params)
aux = (
rng=rng,
adtype=adtype,
obj=obj,
problem=prob,
restructure=restructure,
q_stop=q_stop,
)
q = restructure(params)
samples = rand(rng, q, obj.n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
AdvancedVI.value_and_gradient!(
adtype, estimate_scoregradelbo_ad_forward, params, aux, out
)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
push!(obj.baseline_history, -nelbo)
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
elbo = mean(ℓπ - ℓq)
stat = (elbo=elbo,)
return out, nothing, stat
end
5 changes: 2 additions & 3 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ end
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
RepGradELBO(10; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_repgradelbo_distributionsad

Expand Down
7 changes: 3 additions & 4 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@ else
)
end

@testset "inference ScoreGradELBO VILocationScale" begin
@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
RepGradELBO(10; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_repgradelbo_locationscale

Expand Down
5 changes: 2 additions & 3 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ end
[Float64, Float32],
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
RepGradELBO(10; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_repgradelbo_locationscale_bijectors

Expand Down
9 changes: 2 additions & 7 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ end
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
(adbackname, adtype) in AD_scoregradelbo_distributionsad

seed = (0x38bef07cf9cc549d)
Expand All @@ -29,7 +24,7 @@ end
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
η = 1e-4
opt = Optimisers.Descent(realtype(η))

# For small enough η, the error of SGD, Δλ, is bounded as
Expand Down
9 changes: 2 additions & 7 deletions test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ end
[Float64, Float32],
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
(adbackname, adtype) in AD_scoregradelbo_locationscale

seed = (0x38bef07cf9cc549d)
Expand All @@ -30,7 +25,7 @@ end
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
η = 1e-4
opt = Optimisers.Descent(realtype(η))

# For small enough η, the error of SGD, Δλ, is bounded as
Expand Down
9 changes: 2 additions & 7 deletions test/inference/scoregradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ end
[Float64, Float32],
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
#:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet.
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
(adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors

seed = (0x38bef07cf9cc549d)
Expand All @@ -30,7 +25,7 @@ end
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
η = 1e-4
opt = Optimisers.Descent(realtype(η))

b = Bijectors.bijector(model)
Expand Down
Loading
Loading