Skip to content

Commit

Permalink
add basic tests for interface tests of variational objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Dec 3, 2024
1 parent c69a5ed commit bfa9de0
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_locationscale
(adbackname, adtype) in AD_scoregradelbo_locationscale

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)
Expand Down
45 changes: 30 additions & 15 deletions test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@

using Test
AD_repgradelbo_interface = if TEST_GROUP == "Enzyme"
[AutoEnzyme()]
else
[
AutoForwardDiff(),
AutoReverseDiff(),
AutoZygote(),
AutoMooncake(; config=Mooncake.Config()),
]
end

@testset "interface RepGradELBO" begin
seed = (0x38bef07cf9cc549d)
Expand All @@ -9,7 +18,24 @@ using Test

(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

@testset "basic" begin
@testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10]
obj = RepGradELBO(n_montecarlo)
_, _, stats, _ = optimize(
rng,
model,
obj,
q0,
10;
optimizer=Descent(1e-5),
show_progress=false,
adtype=adtype,
)
@assert isfinite(last(stats).elbo)
end
end

obj = RepGradELBO(10)
rng = StableRNG(seed)
Expand All @@ -27,30 +53,19 @@ using Test
end
end

AD_repgradelbo_stl = if TEST_GROUP == "Enzyme"
[AutoEnzyme()]
else
[
AutoForwardDiff(),
AutoReverseDiff(),
AutoZygote(),
AutoMooncake(; config=Mooncake.Config()),
]
end

@testset "interface RepGradELBO STL variance reduction" begin
seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = normal_meanfield(rng, Float64)
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

@testset for adtype in AD_repgradelbo_stl
@testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10]
q_true = MeanFieldGaussian(
Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true)))
)
params, re = Optimisers.destructure(q_true)
obj = RepGradELBO(10; entropy=StickingTheLandingEntropy())
obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy())
out = DiffResults.DiffResult(zero(eltype(params)), similar(params))

aux = (
Expand Down
30 changes: 28 additions & 2 deletions test/interface/scoregradelbo.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@

using Test
AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme"
[AutoEnzyme()]
else
[
AutoForwardDiff(),
AutoReverseDiff(),
AutoZygote(),
AutoMooncake(; config=Mooncake.Config()),
]
end

@testset "interface ScoreGradELBO" begin
seed = (0x38bef07cf9cc549d)
Expand All @@ -9,7 +18,24 @@ using Test

(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

@testset "basic" begin
@testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10]
obj = ScoreGradELBO(n_montecarlo)
_, _, stats, _ = optimize(
rng,
model,
obj,
q0,
10;
optimizer=Descent(1e-5),
show_progress=false,
adtype=adtype,
)
@assert isfinite(last(stats).elbo)
end
end

obj = ScoreGradELBO(10)
rng = StableRNG(seed)
Expand Down

0 comments on commit bfa9de0

Please sign in to comment.