Add ScoreELBO objective #72

Merged · 15 commits · Sep 30, 2024
43 changes: 24 additions & 19 deletions README.md
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)

# AdvancedVI.jl

[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational inference (VI) algorithms, a family of algorithms that aim for scalable approximate Bayesian inference by leveraging optimization.
`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration.
For example, integrating `Turing` with `AdvancedVI.ADVI` only involves converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`.

## Examples
For the normal-log-normal model

$$
\begin{aligned}
x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right), \\
y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right),
\end{aligned}
$$

a `LogDensityProblem` can be implemented as

```julia
using LogDensityProblems
using SimpleUnPack
using Distributions

struct NormalLogNormal{MX,SX,MY,SY}
    μ_x::MX
    σ_x::SX
    μ_y::MY
    Σ_y::SY
end

function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
    (; μ_x, σ_x, μ_y, Σ_y) = model
    return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

function LogDensityProblems.dimension(model::NormalLogNormal)
    return length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
    return LogDensityProblems.LogDensityOrder{0}()
end
```

Since the support of `x` is constrained to be positive, and VI is best performed in an unconstrained Euclidean space, we need a *bijector* to map `x` to the real line.
We will use the [`Bijectors.jl`](https://github.com/TuringLang/Bijectors.jl) package for this purpose.
This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017].

```julia
using Bijectors

function Bijectors.bijector(model::NormalLogNormal)
    (; μ_x, σ_x, μ_y, Σ_y) = model
    return Bijectors.Stacked(
        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
        [1:1, 2:(1 + length(μ_y))],
    )
end
```

A simpler approach is to use `Turing`, where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.
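
As a rough, untested sketch (not from this README): a `Turing.Model` can be wrapped with `DynamicPPL.LogDensityFunction`, while whether `Bijectors.bijector` accepts the model directly depends on the Turing/DynamicPPL version, so that call is an assumption here.

```julia
using Turing, DynamicPPL, Bijectors, LinearAlgebra

# Hypothetical Turing version of the normal-log-normal model above
@model function normal_lognormal(n_dims)
    x ~ LogNormal(0.0, 1.0)
    y ~ MvNormal(zeros(n_dims), Diagonal(ones(n_dims)))
end

turing_model = normal_lognormal(10)
prob = DynamicPPL.LogDensityFunction(turing_model) # satisfies the LogDensityProblems interface
b = Bijectors.bijector(turing_model)               # assumed: bijector extraction for Turing models
```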

Let us instantiate a random normal-log-normal model.

```julia
using LinearAlgebra

n_dims = 10
μ_x = randn()
σ_x = exp.(randn())
μ_y = randn(n_dims)
σ_y = exp.(randn(n_dims))
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))
```

We can perform VI with stochastic gradient descent (SGD) using reparameterization gradient estimates of the ELBO[^TL2014][^RMW2014][^KW2014] as follows:

```julia
using Optimisers
using ADTypes, ForwardDiff
using AdvancedVI

# ELBO objective with the reparameterization gradient
n_montecarlo = 10
elbo = AdvancedVI.RepGradELBO(n_montecarlo)

# Mean-field Gaussian variational family
d = LogDensityProblems.dimension(model)
μ = zeros(d)
L = Diagonal(ones(d))
q = AdvancedVI.MeanFieldGaussian(μ, L)

# Match support by applying the `model`'s inverse bijector
b = Bijectors.bijector(model)
binv = inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)


# Run inference
max_iter = 10^3
q_avg, _, stats, _ = AdvancedVI.optimize(
    model,
    elbo,
    q_transformed,
    max_iter;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
)

# Evaluate final ELBO with 10^3 Monte Carlo samples
estimate_objective(elbo, q_avg, model; n_samples=10^3)
```
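
This pull request adds a `ScoreGradELBO` objective. As a sketch, it can be swapped in for `RepGradELBO` in the call above; the constructor call mirrors the one in `bench/utils.jl` further down in this diff, so the exact keyword set is an assumption:

```julia
# Score-gradient ELBO with the sticking-the-landing entropy estimator (sketch)
score_elbo = AdvancedVI.ScoreGradELBO(n_montecarlo; entropy=AdvancedVI.StickingTheLandingEntropy())

q_avg_score, _, _, _ = AdvancedVI.optimize(
    model,
    score_elbo,
    q_transformed,
    max_iter;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
)
```
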
2 changes: 0 additions & 2 deletions bench/README.md
# AdvancedVI.jl Continuous Benchmarking

This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`.
The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main).
The GitHub Actions and pages integration is provided by [github-action-benchmark](https://github.com/benchmark-action/github-action-benchmark/) and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl).

34 changes: 15 additions & 19 deletions bench/benchmarks.jl
const SUITES = BenchmarkGroup()
# n_montecarlo = 4,
# )

SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(;
    fptype=Float64,
    adtype=AutoReverseDiff(),
    family=:meanfield,
    objective=:RepGradELBO,
    n_montecarlo=4,
)

SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(;
    fptype=Float64,
    adtype=AutoForwardDiff(),
    family=:meanfield,
    objective=:RepGradELBO,
    n_montecarlo=4,
)
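
# Sketch (not part of this diff): a benchmark entry for the new score-gradient
# objective would follow the same pattern, routed through the :ScoreGradELBOSTL
# branch of `variational_objective` in bench/utils.jl; the suite key below is
# illustrative only.
SUITES["normal + bijector"]["meanfield"]["ForwardDiff + STL score"] = @benchmarkable normallognormal(;
    fptype=Float64,
    adtype=AutoForwardDiff(),
    family=:meanfield,
    objective=:ScoreGradELBOSTL,
    n_montecarlo=4,
)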

BenchmarkTools.tune!(SUITES; verbose=true)
results = BenchmarkTools.run(SUITES; verbose=true)
36 changes: 18 additions & 18 deletions bench/normallognormal.jl
struct NormalLogNormal{MX,SX,MY,SY}
    μ_x::MX
    σ_x::SX
    μ_y::MY
    Σ_y::SY
end

function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
    (; μ_x, σ_x, μ_y, Σ_y) = model
    return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

function LogDensityProblems.dimension(model::NormalLogNormal)
    return length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
    return LogDensityProblems.LogDensityOrder{0}()
end

function Bijectors.bijector(model::NormalLogNormal)
    (; μ_x, σ_x, μ_y, Σ_y) = model
    return Bijectors.Stacked(
        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
        [1:1, 2:(1 + length(μ_y))],
    )
end

function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...)
    n_dims = 10
    μ_x = fptype(5.0)
    σ_x = fptype(0.3)
    μ_y = Fill(fptype(5.0), n_dims)
    σ_y = Fill(fptype(0.3), n_dims)
    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))

    obj = variational_objective(objective; kwargs...)

    d = LogDensityProblems.dimension(model)
    q = variational_standard_mvnormal(fptype, d, family)

    b = Bijectors.bijector(model)
    binv = inverse(b)
    q_transformed = Bijectors.TransformedDistribution(q, binv)

    return AdvancedVI.optimize(
        model,
        obj,
        q_transformed,
        max_iter;
        adtype,
        optimizer=Optimisers.Adam(fptype(1e-3)),
        show_progress=false,
    )
end
14 changes: 7 additions & 7 deletions bench/utils.jl
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol)
    if family == :meanfield
        AdvancedVI.MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims)))
    else
        # Full-rank family initialized with a dense identity scale matrix
        AdvancedVI.FullRankGaussian(zeros(type, n_dims), Matrix{type}(I, n_dims, n_dims))
    end
end

function variational_objective(objective::Symbol; kwargs...)
    if objective == :RepGradELBO
        AdvancedVI.RepGradELBO(kwargs[:n_montecarlo])
    elseif objective == :RepGradELBOSTL
        AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy())
    elseif objective == :ScoreGradELBO
        throw("ScoreGradELBO not supported yet. Please use ScoreGradELBOSTL instead.")
    elseif objective == :ScoreGradELBOSTL
        AdvancedVI.ScoreGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy())
    end
end
12 changes: 9 additions & 3 deletions docs/src/elbo/overview.md
# [Evidence Lower Bound Maximization](@id elbomax)

## Introduction

Evidence lower bound (ELBO) maximization[^JGJS1999] is a general family of algorithms that minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence between the target distribution ``\pi`` and a variational approximation ``q_{\lambda}``.
More generally, they aim to solve the following problem:
```math
\mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right),
```

where ``\mathcal{Q}`` is some family of distributions, often called the variational family.
Since the target distribution ``\pi`` is generally intractable (its density is typically known only up to a normalizing constant), the KL divergence is also intractable.
Instead, the ELBO maximization strategy maximizes a surrogate objective, the *ELBO*:

```math
\mathrm{ELBO}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right) + \mathbb{H}\left(q\right),
```

which lower-bounds the log evidence and differs from the negative KL divergence only by an additive constant, so maximizing the ELBO is equivalent to minimizing the KL divergence.
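
In practice, only an unnormalized density ``\gamma\left(\theta\right) = Z \, \pi\left(\theta\right)`` can be evaluated (notation introduced here only for illustration); substituting ``\gamma`` for ``\pi`` in the definition above merely shifts the ELBO by the constant ``\log Z``:

```math
\mathbb{E}_{\theta \sim q} \log \gamma\left(\theta\right) + \mathbb{H}\left(q\right) = \log Z - \mathrm{KL}\left(q, \pi\right),
```

so maximizing it over ``q \in \mathcal{Q}`` still minimizes the KL divergence to ``\pi``.
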
The ELBO and its gradient can be readily estimated through various strategies.
Overall, ELBO maximization algorithms aim to solve the problem:

```math
\mathrm{maximize}_{q \in \mathcal{Q}}\quad \mathrm{ELBO}\left(q\right).
```

Multiple ways to solve this problem exist, each leading to a different variational inference algorithm.

## Algorithms

Currently, `AdvancedVI` only provides the approach known as black-box variational inference (also known as Monte Carlo VI or stochastic gradient VI), which was introduced independently by two groups[^RGB2014][^TL2014] in 2014.
In particular, `AdvancedVI` focuses on the reparameterization gradient estimator[^TL2014][^RMW2014][^KW2014], which is generally superior to alternative strategies[^XQKS2019], discussed in the following section:
  - [RepGradELBO](@ref repgradelbo)
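
Schematically, the reparameterization gradient writes a sample from ``q_{\lambda}`` as a deterministic transformation ``\mathcal{T}_{\lambda}`` of a base random variable ``\epsilon \sim \varphi`` (notation introduced here only for orientation), which allows the gradient to be moved inside the expectation:

```math
\nabla_{\lambda} \mathrm{ELBO}\left(q_{\lambda}\right) = \mathbb{E}_{\epsilon \sim \varphi} \nabla_{\lambda} \log \pi\left(\mathcal{T}_{\lambda}\left(\epsilon\right)\right) + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right).
```

The score-gradient (REINFORCE) estimator added by this pull request instead uses the identity ``\nabla_{\lambda} \mathbb{E}_{\theta \sim q_{\lambda}} f\left(\theta\right) = \mathbb{E}_{\theta \sim q_{\lambda}} f\left(\theta\right) \nabla_{\lambda} \log q_{\lambda}\left(\theta\right)`` for a function ``f`` that does not depend on ``\lambda``, and therefore does not require differentiating through the target ``\pi``.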

[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233.
[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*.
[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*.
[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*.
[^XQKS2019]: Xu, M., Quiroz, M., Kohn, R., & Sisson, S. A. (2019). Variance reduction properties of the reparameterization trick. In *The International Conference on Artificial Intelligence and Statistics*.
20 changes: 14 additions & 6 deletions docs/src/general.md
# [General Usage](@id general)

Each VI algorithm provides the following:

 1. Variational families supported by each VI algorithm.
 2. A variational objective corresponding to the VI algorithm.

Note that each variational family is subject to its own constraints.
Thus, please refer to the documentation of the variational inference algorithm of interest.

## Optimizing a Variational Objective

After constructing a *variational objective* `objective` and initializing a *variational approximation*, one can optimize `objective` by calling `optimize`:

```@docs
optimize
```
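
For orientation, here is a minimal call (a sketch mirroring the README example; `model`, `elbo`, and `q_transformed` are assumed to have been constructed as shown there):

```julia
using ADTypes, ForwardDiff, Optimisers

q_avg, _, stats, _ = AdvancedVI.optimize(
    model,          # target posterior as a LogDensityProblem
    elbo,           # variational objective, e.g. AdvancedVI.RepGradELBO(10)
    q_transformed,  # initial variational approximation
    10^3;           # maximum number of iterations
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
)
```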

## Estimating the Objective

In some cases, it is useful to directly estimate the objective value.
This can be done with the following function:

```@docs
estimate_objective
```

!!! info
    Note that `estimate_objective` is not expected to be differentiated through, and may not result in optimal statistical performance.
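
For example, the final ELBO of a fitted approximation can be estimated as in the README example (a sketch; the `n_samples` keyword follows the README usage and is an assumption beyond that):

```julia
# `elbo`, `q_avg`, and `model` as constructed in the README example
estimate_objective(elbo, q_avg, model; n_samples=10^3)
```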

## Advanced Usage

Each variational objective is a subtype of the following abstract type:

```@docs
AdvancedVI.AbstractVariationalObjective
```

Furthermore, `AdvancedVI` only interacts with each variational objective by querying gradient estimates.
Therefore, to create a new custom objective to be optimized through `AdvancedVI`, it suffices to implement the following function:

```@docs
AdvancedVI.estimate_gradient!
```

If an objective needs to be stateful, one can implement the following function to initialize the state.

```@docs
AdvancedVI.init
```
5 changes: 4 additions & 1 deletion docs/src/index.md
```@meta
CurrentModule = AdvancedVI
```
# AdvancedVI

## Introduction

[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.

## Provided Algorithms

`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization:
  - [Evidence Lower-Bound Maximization](@ref elbomax)
4 changes: 3 additions & 1 deletion src/AdvancedVI.jl
Estimate the entropy of `q`.
"""
function estimate_entropy end

export RepGradELBO,
    ScoreGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy

include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")
include("objectives/elbo/scoregradelbo.jl")

# Variational Families
export MvLocationScale, MeanFieldGaussian, FullRankGaussian