diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 7ae72de1..38dbc6e3 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -21,9 +21,16 @@ rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + T = 1000 + η = 1e-3 + opt = Optimisers.Descent(realtype(η)) + + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η*strong_convexity μ0 = Zeros(realtype, n_dims) L0 = Diagonal(Ones(realtype, n_dims)) @@ -33,7 +40,7 @@ Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( rng, model, objective, q0, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) @@ -42,7 +49,7 @@ L = sqrt(cov(q)) Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ Δλ0/T^(1/4) + @test Δλ ≤ contraction_rate^(T/2)*Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -51,7 +58,7 @@ rng = StableRNG(seed) q, stats, _ = optimize( rng, model, objective, q0, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) @@ -61,7 +68,7 @@ rng_repl = StableRNG(seed) q, stats, _ = optimize( rng_repl, model, objective, q0, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index e98a72f2..73846675 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -22,9 +22,16 @@ rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + T = 1000 + η = 1e-3 + opt = Optimisers.Descent(realtype(η)) + + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η*strong_convexity q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) @@ -37,7 +44,7 @@ Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) q, stats, _ = optimize( rng, model, objective, q0, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) @@ -46,7 +53,7 @@ L = q.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ Δλ0/T^(1/4) + @test Δλ ≤ contraction_rate^(T/2)*Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -55,7 +62,7 @@ rng = StableRNG(seed) q, stats, _ = optimize( rng, model, objective, q0, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) @@ -65,7 +72,7 @@ rng_repl = StableRNG(seed) q, stats, _ = optimize( rng_repl, model, objective, q0, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 6921688d..416a4169 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -21,9 +21,11 @@ rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + T = 1000 + η = 1e-3 + opt = Optimisers.Descent(realtype(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -38,11 +40,16 @@ end q0_z = Bijectors.transformed(q0_η, b⁻¹) + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η*strong_convexity + @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( rng, model, objective, q0_z, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) @@ -51,7 +58,7 @@ L = q.dist.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ Δλ0/T^(1/4) + @test Δλ ≤ contraction_rate^(T/2)*Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -60,7 +67,7 @@ rng = StableRNG(seed) q, stats, _ = optimize( rng, model, objective, q0_z, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt, show_progress = PROGRESS, adtype = adtype, ) @@ -70,7 +77,7 @@ rng_repl = StableRNG(seed) q, stats, _ = optimize( rng_repl, model, objective, q0_z, T; - optimizer = Optimisers.Adam(realtype(η)), + optimizer = opt show_progress = PROGRESS, adtype = adtype, ) diff --git a/test/models/normal.jl b/test/models/normal.jl index 3f305e1a..5c3e22e8 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -20,24 +20,28 @@ end function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 - μ = randn(rng, realtype, n_dims) - L = tril(I + ones(realtype, n_dims, n_dims))/2 - Σ = L*L' |> Hermitian + σ0 = realtype(0.3) + μ = Fill(realtype(5), n_dims) + L = Matrix(σ0*I, n_dims, n_dims) + Σ = L*L' |> Hermitian model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) - TestModel(model, μ, L, n_dims, false) + TestModel(model, μ, LowerTriangular(L), n_dims, 1/σ0^2, false) end function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 - μ = randn(rng, realtype, n_dims) - σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + σ0 = realtype(0.3) + μ = Fill(realtype(5), n_dims) + #randn(rng, realtype, n_dims) + σ = Fill(σ0, n_dims) + #log.(exp.(randn(rng, realtype, n_dims)) .+ 1) model = TestNormal(μ, Diagonal(σ.^2)) L = σ |> Diagonal - TestModel(model, μ, L, n_dims, true) + TestModel(model, μ, L, n_dims, 1/σ0^2, true) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 6615084b..54adcd48 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -26,40 +26,29 @@ function Bijectors.bijector(model::NormalLogNormal) [1:1, 2:1+length(μ_y)]) end -function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - L_y = tril(I + ones(realtype, n_dims, n_dims))/2 - Σ_y = L_y*L_y' |> Hermitian - - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) - - Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) - Σ[1,1] = σ_x^2 - Σ[2:end,2:end] = Σ_y - Σ = Σ |> Hermitian - - μ = vcat(μ_x, μ_y) - L = cholesky(Σ).L - - TestModel(model, μ, L, n_dims+1, false) +function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type) + n_y_dims = 5 + + σ0 = realtype(0.3) + μ = Fill(realtype(5.0), n_y_dims+1) + L = Matrix(σ0*I, n_y_dims+1, n_y_dims+1) + Σ = L*L' |> Hermitian + + model = NormalLogNormal( + μ[1], L[1,1], μ[2:end], PDMat(Σ[2:end,2:end], Cholesky(L[2:end,2:end], 'L', 0)) + ) + TestModel(model, μ, LowerTriangular(L), n_y_dims+1, 1/σ0^2, false) end -function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) +function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type) + n_y_dims = 5 - model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + σ0 = realtype(0.3) + μ = Fill(realtype(5), n_y_dims + 1) + σ = Fill(σ0, n_y_dims + 1) + L = Diagonal(σ) - μ = vcat(μ_x, μ_y) - L = vcat(σ_x, σ_y) |> Diagonal + model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end].^2)) - TestModel(model, μ, L, n_dims+1, true) + TestModel(model, μ, L, n_y_dims+1, 1/σ0^2, true) end diff --git a/test/runtests.jl b/test/runtests.jl index 8b8b15d0..3bd13144 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,11 +25,12 @@ using AdvancedVI const GROUP = get(ENV, "GROUP", "All") # Models for Inference Tests -struct TestModel{M,L,S} +struct TestModel{M,L,S,SC} model::M μ_true::L L_true::S n_dims::Int + strong_convexity::SC is_meanfield::Bool end include("models/normal.jl")