refactor interface for projections/proximal operators #147

Open
wants to merge 10 commits into base: master
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ q_avg, _, stats, _ = AdvancedVI.optimize(
q_transformed,
max_iter;
adtype=ADTypes.AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
optimizer=ProjectScale(Optimisers.Adam(1e-3)),
)

# Evaluate final ELBO with 10^3 Monte Carlo samples
2 changes: 1 addition & 1 deletion bench/benchmarks.jl
@@ -40,7 +40,7 @@ begin
]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))
optimizer = ProjectScale(Optimisers.Adam(T(1e-3)))

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
6 changes: 3 additions & 3 deletions docs/src/elbo/repgradelbo.md
@@ -218,7 +218,7 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
optimizer = ProjectScale(Optimisers.Adam(3e-3)),
callback = callback,
);

@@ -229,7 +229,7 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
optimizer = ProjectScale(Optimisers.Adam(3e-3)),
callback = callback,
);

@@ -316,7 +316,7 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
optimizer = ProjectScale(Optimisers.Adam(3e-3)),
callback = callback,
);

5 changes: 4 additions & 1 deletion docs/src/examples.md
@@ -117,11 +117,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
n_max_iter;
show_progress=false,
adtype=AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
optimizer=ProjectScale(Optimisers.Adam(1e-3)),
);
nothing
```

`ProjectScale` wraps an optimization rule so that the variational approximation stays within a stable region of the variational family.
For more information, see [this section](@ref projectscale).
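As a sketch of how the wrapper slots into the setup (assuming the `ProjectScale` name introduced in this PR; the learning rate and `scale_eps` values below are illustrative, not prescriptive):

```julia
using AdvancedVI, Optimisers

# Wrap Adam so that each update is followed by a projection keeping the
# scale of the variational approximation away from singularity.
rule = ProjectScale(Optimisers.Adam(1e-3), 1e-5)
```

Any `Optimisers.AbstractRule` can be wrapped this way; the second argument is the lower bound enforced on the scale.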

`q_avg_trans` is the final output of the optimization procedure.
If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` is the output of the averaging strategy, while `q_trans` is the last iterate.

10 changes: 10 additions & 0 deletions docs/src/families.md
@@ -56,6 +56,16 @@ FullRankGaussian
MeanFieldGaussian
```

### [Scale Projection Operator](@id projectscale)

For location-scale families, optimization is often stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
To ensure this, we provide the following wrapper around an optimization rule:

```@docs
ProjectScale
```
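Concretely, the projection performed after each gradient step amounts to clamping the diagonal of the triangular scale factor, which bounds its smallest eigenvalue from below. A minimal sketch of the idea (not the package's exact internals; the `project_scale!` name is illustrative):

```julia
using LinearAlgebra

# Clamp the diagonal of a triangular scale factor L so that its
# smallest eigenvalue (a diagonal entry) is at least ϵ.
function project_scale!(L::AbstractMatrix, ϵ::Real)
    idx = diagind(L)
    @. L[idx] = max(L[idx], ϵ)
    return L
end
```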

[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.

### Gaussian Variational Families

```julia
Expand Down
23 changes: 21 additions & 2 deletions ext/AdvancedVIBijectorsExt.jl
@@ -16,6 +16,7 @@ else
end

function AdvancedVI.update_variational_params!(
proj::ProjectScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
opt_st,
params,
@@ -24,9 +25,8 @@
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.dist.scale_eps
ϵ = proj.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.dist.scale)
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)

@@ -35,6 +35,25 @@
return opt_st, params
end

function AdvancedVI.update_variational_params!(
proj::ProjectScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
opt_st,
params,
restructure,
grad,
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = proj.scale_eps

@. q.dist.scale_diag = max(q.dist.scale_diag, ϵ)

params, _ = Optimisers.destructure(q)

return opt_st, params
end

function AdvancedVI.reparam_with_entropy(
rng::Random.AbstractRNG,
q::Bijectors.TransformedDistribution,
13 changes: 7 additions & 6 deletions src/AdvancedVI.jl
@@ -62,16 +62,17 @@ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restruct

# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)
update_variational_params!(rule, family_type, opt_st, params, restructure, grad)

Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.
Update the variational distribution according to the optimization rule `rule`, the optimizer state `opt_st`, and the variational family type `family_type`.

This is a wrapper around `Optimisers.update!` that provides a layer of indirection.
For example, depending on the optimization rule and the variational family, it may additionally apply a projection or proximal mapping.
As with `Optimisers.update!`, `params` and `opt_st` may be mutated by this routine and are no longer valid after calling this function.
Instead, use the returned values.

# Arguments
- `rule`: Optimization rule.
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
Expand All @@ -82,9 +83,9 @@ Instead, the return values should be used.
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

function update_variational_params!(::Type, opt_st, params, restructure, grad)
function update_variational_params!(
::Optimisers.AbstractRule, family_type, opt_st, params, restructure, grad
)
return Optimisers.update!(opt_st, params, grad)
end

@@ -185,7 +186,7 @@ include("objectives/elbo/repgradelbo.jl")
include("objectives/elbo/scoregradelbo.jl")

# Variational Families
export MvLocationScale, MeanFieldGaussian, FullRankGaussian
export MvLocationScale, MeanFieldGaussian, FullRankGaussian, ProjectScale

include("families/location_scale.jl")

84 changes: 39 additions & 45 deletions src/families/location_scale.jl
@@ -1,14 +1,6 @@

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
ContinuousMultivariateDistribution
location::L
scale::S
dist::D
scale_eps::E
end

"""
MvLocationScale(location, scale, dist; scale_eps)
MvLocationScale(location, scale, dist)

The location scale variational family broadly represents various variational
families using `location` and `scale` variational parameters.
@@ -20,21 +12,11 @@ represented as follows:
u = rand(dist, d)
z = scale*u + location
```

`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization.
This is necessary to guarantee stable convergence.

# Keyword Arguments
- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`).
"""
function MvLocationScale(
location::AbstractVector{T},
scale::AbstractMatrix{T},
dist::ContinuousUnivariateDistribution;
scale_eps::T=T(1e-4),
) where {T<:Real}
@assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior."
return MvLocationScale(location, scale, dist, scale_eps)
struct MvLocationScale{S,D<:ContinuousDistribution,L} <: ContinuousMultivariateDistribution
location::L
scale::S
dist::D
end

Functors.@functor MvLocationScale (location, scale)
@@ -44,18 +26,18 @@ Functors.@functor MvLocationScale (location, scale)
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{S<:Diagonal,D,L,E}
model::MvLocationScale{S,D,L,E}
struct RestructureMeanField{S<:Diagonal,D,L}
model::MvLocationScale{S,D,L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
return MvLocationScale(location, scale, re.model.dist)
end

function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
@@ -66,7 +48,7 @@ Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D)
Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
(; location, scale, dist) = q
@@ -131,49 +113,61 @@ function Distributions.cov(q::MvLocationScale)
end

"""
FullRankGaussian(μ, L; scale_eps)
FullRankGaussian(μ, L)

Construct a Gaussian variational approximation with a dense covariance matrix.

# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function FullRankGaussian(
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4)
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

"""
MeanFieldGaussian(μ, L; scale_eps)
MeanFieldGaussian(μ, L)

Construct a Gaussian variational approximation with a diagonal covariance matrix.

# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
"""
function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real}
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
ProjectScale(rule, scale_eps)

Compose an optimization `rule` with a projection, where the projection ensures that the scale of an `MvLocationScale` or `MvLocationScaleLowRank` has eigenvalues no smaller than `scale_eps`.

# Arguments
- `rule::Optimisers.AbstractRule`: Optimization rule to compose with the projection.
- `scale_eps::Real`: Lower bound on the eigenvalues of the scale matrix of the projection.
"""
struct ProjectScale{Rule<:Optimisers.AbstractRule,F<:Real} <: Optimisers.AbstractRule
rule::Rule
scale_eps::F
end

function ProjectScale(rule, scale_eps::Real=1e-5)
return ProjectScale{typeof(rule),typeof(scale_eps)}(rule, scale_eps)
end

Optimisers.setup(proj::ProjectScale, x) = Optimisers.setup(proj.rule, x)

Optimisers.init(proj::ProjectScale, x) = Optimisers.init(proj.rule, x)

function update_variational_params!(
::Type{<:MvLocationScale}, opt_st, params, restructure, grad
proj::ProjectScale, ::Type{<:MvLocationScale}, opt_st, params, restructure, grad
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.scale_eps
ϵ = convert(eltype(params), proj.scale_eps)

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.scale)