diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index fa34022a..210b49ca 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -37,10 +37,3 @@ function estimate_entropy(
         -logpdf(q, mc_sample)
     end
 end
-
-function estimate_entropy_maybe_stl(
-    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
-)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
-end
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 5b2fd828..d8079c2b 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
     return print(io, ")")
 end
 
+function estimate_entropy_maybe_stl(
+    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
+)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
+end
+
 function estimate_energy_with_samples(prob, samples)
     return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
 end
@@ -85,9 +92,27 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp
     return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
 end
 
-function estimate_repgradelbo_ad_forward(params′, aux)
+"""
+    estimate_repgradelbo_ad_forward(params, aux)
+
+AD-guaranteed forward path of the reparameterization gradient objective.
+
+# Arguments
+- `params`: Variational parameters.
+- `aux`: Auxiliary information excluded from the AD path.
+
+# Auxiliary Information
+`aux` should contain the following entries:
+- `rng`: Random number generator.
+- `obj`: The `RepGradELBO` objective.
+- `problem`: The target `LogDensityProblem`.
+- `adtype`: The `ADType` used for differentiating the forward path.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+- `q_stop`: A copy of `restructure(params)` with its gradient "stopped" (excluded from the AD path).
+"""
+function estimate_repgradelbo_ad_forward(params, aux)
     (; rng, obj, problem, adtype, restructure, q_stop) = aux
-    q = restructure_ad_forward(adtype, restructure, params′)
+    q = restructure_ad_forward(adtype, restructure, params)
     samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
     energy = estimate_energy_with_samples(problem, samples)
     elbo = energy + entropy
diff --git a/src/objectives/elbo/scoregradelbo.jl b/src/objectives/elbo/scoregradelbo.jl
index 8b9a91c6..8b96fa1c 100644
--- a/src/objectives/elbo/scoregradelbo.jl
+++ b/src/objectives/elbo/scoregradelbo.jl
@@ -1,113 +1,63 @@
+
 """
     ScoreGradELBO(n_samples; kwargs...)
 
-Evidence lower-bound objective computed with score function gradients.
-```math
-\\begin{aligned}
-\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
-&\\=
-\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
-    \\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
-\\right]
-+ \\mathbb{H}\\left(q_{\\lambda}\\right),
-\\end{aligned}
-```
-
-To reduce the variance of the gradient estimator, we use a baseline computed from a running average of the previous ELBO values and subtract it from the objective.
-
-```math
-\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
-    \\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\pi\\left(z\\right) - \\beta\\right)
-\\right]
-```
+Evidence lower-bound objective computed with score function gradients using the VarGrad objective, also known as the leave-one-out control variate.
 
 # Arguments
-- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
-
-# Keyword Arguments
-- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
-- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
-- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)
+- `n_samples::Int`: Number of Monte Carlo samples used to estimate the VarGrad objective.
 
 # Requirements
 - The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
 - `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
 - The target distribution and the variational approximation have the same support.
-
-Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 """
-struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <:
-       AdvancedVI.AbstractVariationalObjective
-    entropy::EntropyEst
+struct ScoreGradELBO <: AbstractVariationalObjective
     n_samples::Int
-    baseline_window_size::Int
-    baseline_history::Vector{Float64}
-end
-
-function ScoreGradELBO(
-    n_samples::Int;
-    entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
-    baseline_window_size::Int=10,
-    baseline_history::Vector{Float64}=Float64[],
-)
-    return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
 end
 
 function Base.show(io::IO, obj::ScoreGradELBO)
-    print(io, "ScoreGradELBO(entropy=")
-    print(io, obj.entropy)
-    print(io, ", n_samples=")
+    print(io, "ScoreGradELBO(n_samples=")
     print(io, obj.n_samples)
-    print(io, ", baseline_window_size=")
-    print(io, obj.baseline_window_size)
     return print(io, ")")
 end
 
-function compute_control_variate_baseline(history, window_size)
-    if length(history) == 0
-        return 1.0
-    end
-    min_index = max(1, length(history) - window_size)
-    return mean(history[min_index:end])
-end
-
-function estimate_energy_with_samples(
-    prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
-)
-    fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
-    fv_mean = mean(fv)
-    score_grad = mean(@. samples_logprob * (fv - baseline))
-    score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline))
-    return fv_mean + (score_grad - score_grad_stop)
-end
-
 function estimate_objective(
     rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
 )
-    samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy)
-    energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
-    return mean(energy) + entropy
+    samples = rand(rng, q, n_samples)
+    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
+    return mean(ℓπ - ℓq)
 end
 
 function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
     return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
 end
 
-function estimate_scoregradelbo_ad_forward(params′, aux)
-    (; rng, obj, problem, adtype, restructure, q_stop) = aux
-    baseline = compute_control_variate_baseline(
-        obj.baseline_history, obj.baseline_window_size
-    )
-    q = restructure_ad_forward(adtype, restructure, params′)
-    samples_stop = rand(rng, q_stop, obj.n_samples)
-    entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
-    samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
-    samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
-    energy = estimate_energy_with_samples(
-        problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
-    )
-    elbo = energy + entropy
-    return -elbo
+"""
+    estimate_scoregradelbo_ad_forward(params, aux)
+
+AD-guaranteed forward path of the score gradient objective.
+
+# Arguments
+- `params`: Variational parameters.
+- `aux`: Auxiliary information excluded from the AD path.
+
+# Auxiliary Information
+`aux` should contain the following entries:
+- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path).
+- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`.
+- `adtype`: The `ADType` used for differentiating the forward path.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+"""
+function estimate_scoregradelbo_ad_forward(params, aux)
+    (; samples_stop, logprob_stop, adtype, restructure) = aux
+    q = restructure_ad_forward(adtype, restructure, params)
+    ℓπ = logprob_stop
+    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
+    f = ℓq - ℓπ
+    return (mean(abs2, f) - mean(f)^2) / 2
 end
 
 function AdvancedVI.estimate_gradient!(
@@ -120,20 +70,15 @@ function AdvancedVI.estimate_gradient!(
     restructure,
     state,
 )
-    q_stop = restructure(params)
-    aux = (
-        rng=rng,
-        adtype=adtype,
-        obj=obj,
-        problem=prob,
-        restructure=restructure,
-        q_stop=q_stop,
-    )
+    q = restructure(params)
+    samples = rand(rng, q, obj.n_samples)
+    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+    aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
     AdvancedVI.value_and_gradient!(
         adtype, estimate_scoregradelbo_ad_forward, params, aux, out
     )
-    nelbo = DiffResults.value(out)
-    stat = (elbo=-nelbo,)
-    push!(obj.baseline_history, -nelbo)
+    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
+    elbo = mean(ℓπ - ℓq)
+    stat = (elbo=elbo,)
     return out, nothing, stat
 end
diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
index 42eaeeed..286011ad 100644
--- a/test/inference/repgradelbo_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -14,11 +14,10 @@ end
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
         (modelname, modelconstr) in Dict(:Normal => normal_meanfield),
-        n_montecarlo in [1, 10],
         (objname, objective) in Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
             :RepGradELBOStickingTheLanding =>
-                RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+                RepGradELBO(10; entropy=StickingTheLandingEntropy()),
         ),
         (adbackname, adtype) in AD_repgradelbo_distributionsad
diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl
index 4a84526b..2e294c90 100644
--- a/test/inference/repgradelbo_locationscale.jl
+++ b/test/inference/repgradelbo_locationscale.jl
@@ -10,16 +10,15 @@ else
     )
 end
 
-@testset "inference ScoreGradELBO VILocationScale" begin
+@testset "inference RepGradELBO VILocationScale" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
         (modelname, modelconstr) in
             Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
-        n_montecarlo in [1, 10],
         (objname, objective) in Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
             :RepGradELBOStickingTheLanding =>
-                RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+                RepGradELBO(10; entropy=StickingTheLandingEntropy()),
         ),
         (adbackname, adtype) in AD_repgradelbo_locationscale
diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl
index 5b197cc4..39aa0d10 100644
--- a/test/inference/repgradelbo_locationscale_bijectors.jl
+++ b/test/inference/repgradelbo_locationscale_bijectors.jl
@@ -15,11 +15,10 @@ end
         [Float64, Float32],
         (modelname, modelconstr) in
             Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
-        n_montecarlo in [1, 10],
         (objname, objective) in Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
             :RepGradELBOStickingTheLanding =>
-                RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+                RepGradELBO(10; entropy=StickingTheLandingEntropy()),
         ),
         (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors
diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl
index 962b9e03..c7aa9a44 100644
--- a/test/inference/scoregradelbo_distributionsad.jl
+++ b/test/inference/scoregradelbo_distributionsad.jl
@@ -14,12 +14,7 @@ end
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
         (modelname, modelconstr) in Dict(:Normal => normal_meanfield),
-        n_montecarlo in [1, 10],
-        (objname, objective) in Dict(
-            :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
-            :ScoreGradELBOStickingTheLanding =>
-                ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
-        ),
+        (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
         (adbackname, adtype) in AD_scoregradelbo_distributionsad
 
     seed = (0x38bef07cf9cc549d)
@@ -29,7 +24,7 @@ end
     (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
 
     T = 1000
-    η = 1e-5
+    η = 1e-4
     opt = Optimisers.Descent(realtype(η))
 
     # For small enough η, the error of SGD, Δλ, is bounded as
diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl
index d882be71..4ba97d45 100644
--- a/test/inference/scoregradelbo_locationscale.jl
+++ b/test/inference/scoregradelbo_locationscale.jl
@@ -15,12 +15,7 @@ end
         [Float64, Float32],
         (modelname, modelconstr) in
             Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
-        n_montecarlo in [1, 10],
-        (objname, objective) in Dict(
-            :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
-            :ScoreGradELBOStickingTheLanding =>
-                ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
-        ),
+        (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
         (adbackname, adtype) in AD_scoregradelbo_locationscale
 
     seed = (0x38bef07cf9cc549d)
@@ -30,7 +25,7 @@ end
     (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
 
     T = 1000
-    η = 1e-5
+    η = 1e-4
     opt = Optimisers.Descent(realtype(η))
 
     # For small enough η, the error of SGD, Δλ, is bounded as
diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl
index 22be98ba..a996d0a7 100644
--- a/test/inference/scoregradelbo_locationscale_bijectors.jl
+++ b/test/inference/scoregradelbo_locationscale_bijectors.jl
@@ -15,12 +15,7 @@ end
         [Float64, Float32],
         (modelname, modelconstr) in
             Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
-        n_montecarlo in [1, 10],
-        (objname, objective) in Dict(
-            #:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet.
- :ScoreGradELBOStickingTheLanding => - ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), - ), + (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)), (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) @@ -30,7 +25,7 @@ end (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 - η = 1e-5 + η = 1e-4 opt = Optimisers.Descent(realtype(η)) b = Bijectors.bijector(model) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index b34dc659..da3a59ac 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -1,5 +1,14 @@ -using Test +AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" + [AutoEnzyme()] +else + [ + AutoForwardDiff(), + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] +end @testset "interface RepGradELBO" begin seed = (0x38bef07cf9cc549d) @@ -9,7 +18,24 @@ using Test (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) + + @testset "basic" begin + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] + obj = RepGradELBO(n_montecarlo) + _, _, stats, _ = optimize( + rng, + model, + obj, + q0, + 10; + optimizer=Descent(1e-5), + show_progress=false, + adtype=adtype, + ) + @assert isfinite(last(stats).elbo) + end + end obj = RepGradELBO(10) rng = StableRNG(seed) @@ -27,17 +53,6 @@ using Test end end -AD_repgradelbo_stl = if TEST_GROUP == "Enzyme" - [AutoEnzyme()] -else - [ - AutoForwardDiff(), - AutoReverseDiff(), - AutoZygote(), - AutoMooncake(; config=Mooncake.Config()), - ] -end - @testset "interface RepGradELBO STL variance reduction" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -45,12 +60,12 @@ end modelstats = normal_meanfield(rng, Float64) (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - @testset for adtype in AD_repgradelbo_stl + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) params, re = Optimisers.destructure(q_true) - obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = ( diff --git a/test/interface/scoregradelbo.jl b/test/interface/scoregradelbo.jl index 8b0a3428..f368626e 100644 --- a/test/interface/scoregradelbo.jl +++ b/test/interface/scoregradelbo.jl @@ -1,5 +1,14 @@ -using Test +AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" + [AutoEnzyme()] +else + [ + AutoForwardDiff(), + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] +end @testset "interface ScoreGradELBO" begin seed = (0x38bef07cf9cc549d) @@ -9,7 +18,24 @@ using Test (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) + + @testset "basic" begin + @testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10] + obj = ScoreGradELBO(n_montecarlo) + _, _, stats, _ = optimize( + rng, + model, + obj, + q0, + 10; + optimizer=Descent(1e-5), + show_progress=false, + adtype=adtype, + ) + @assert isfinite(last(stats).elbo) + end + end obj = ScoreGradELBO(10) rng = StableRNG(seed)
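For reference, the new `estimate_scoregradelbo_ad_forward` returns half the sample variance (computed from uncentered moments, hence the biased estimator) of `f = log q(z) - log π(z)` over the gradient-stopped samples, which is the VarGrad objective: differentiating it with respect to the parameters of `q`, while the samples and target log-densities stay fixed through `aux`, recovers the leave-one-out control-variate score gradient. Below is a minimal standalone sketch of the forward value only, assuming Distributions.jl and a toy Gaussian target; `target_logpdf`, `q`, and `zs` are illustrative names, not part of the diff.

```julia
using Distributions, LinearAlgebra, Statistics

# Illustrative stand-ins for the LogDensityProblem target and the variational q.
target_logpdf(z) = logpdf(MvNormal([1.0, -1.0], Diagonal(ones(2))), z)
q = MvNormal(zeros(2), Diagonal(ones(2)))
n_samples = 10

zs = rand(q, n_samples)                       # plays the role of `samples_stop`
ℓq = [logpdf(q, z) for z in eachcol(zs)]      # differentiated with respect to q
ℓπ = [target_logpdf(z) for z in eachcol(zs)]  # plays the role of `logprob_stop`
f = ℓq - ℓπ
vargrad = (mean(abs2, f) - mean(f)^2) / 2     # biased sample variance of f, halved
```

Because only `ℓq` depends on the variational parameters inside the AD'd function, no baseline is needed, which is what lets the diff drop `compute_control_variate_baseline` and the `baseline_history` field.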
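Similarly, the `q_stop` entry documented for `estimate_repgradelbo_ad_forward` feeds the sticking-the-landing (STL) entropy estimator selected via `estimate_entropy_maybe_stl`: samples are reparameterized through `q`, but the Monte Carlo entropy is evaluated under a copy of `q` that the AD backend does not trace. The following is a hedged sketch of the idea only; in the package, `q_stop = restructure(params)` is built outside the AD'd forward path, and the names below are illustrative.

```julia
using Distributions, Statistics

# Variational parameters under optimization (illustrative names).
μ, logσ = 0.5, 0.1
n_samples = 10

# Reparameterized samples: under AD, gradients flow through μ and logσ.
ε = randn(n_samples)
zs = μ .+ exp(logσ) .* ε

# STL entropy: log q is evaluated under q_stop, a copy of q constructed from the
# raw parameter values outside the AD path, so its score term is excluded from
# the gradient.
q_stop = Normal(μ, exp(logσ))
entropy_stl = -mean(logpdf.(Ref(q_stop), zs))
```

Stopping the score term this way removes its contribution to the gradient variance near the optimum, which is the motivation for routing the entropy through `q_stop` rather than `q`.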