Skip to content

Commit

Permalink
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into pro…
Browse files Browse the repository at this point in the history
…jected_proximal_location_scale
  • Loading branch information
Red-Portal committed Dec 9, 2024
2 parents 074218a + 1dbf2ac commit a11e5ce
Show file tree
Hide file tree
Showing 32 changed files with 326 additions and 343 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7'
- '1.10'
- 'lts'
- '1'
os:
- ubuntu-latest
- macOS-latest
Expand All @@ -29,7 +29,7 @@ jobs:
- x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
Expand Down
40 changes: 40 additions & 0 deletions .github/workflows/Enzyme.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: Enzyme
on:
push:
branches:
- master
tags: ['*']
pull_request:
workflow_dispatch:
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
env:
TEST_GROUP: Enzyme
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- 'lts'
- '1'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
arch:
- x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
14 changes: 4 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Expand All @@ -36,7 +35,7 @@ AdvancedVIEnzymeExt = "Enzyme"
[compat]
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13"
Bijectors = "0.13, 0.14, 0.15"
ChainRulesCore = "1.16"
DiffResults = "1"
DifferentiationInterface = "0.6"
Expand All @@ -45,29 +44,24 @@ DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
ForwardDiff = "0.10"
Functors = "0.4"
Functors = "0.4, 0.5"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Zygote = "0.6"
julia = "1.7"
julia = "1.10, 1.11.2"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Test"]
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ a `LogDensityProblem` can be implemented as

```julia
using LogDensityProblems
using SimpleUnPack

struct NormalLogNormal{MX,SX,MY,SY}
μ_x::MX
Expand Down
6 changes: 3 additions & 3 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
BenchmarkTools = "1"
Bijectors = "0.13"
Bijectors = "0.13, 0.14, 0.15"
Distributions = "0.25.111"
DistributionsAD = "0.6"
Enzyme = "0.13.7"
Expand All @@ -30,10 +30,10 @@ ForwardDiff = "0.10"
InteractiveUtils = "1"
LogDensityProblems = "2"
Mooncake = "0.4.5"
Optimisers = "0.3"
Optimisers = "0.3, 0.4"
Random = "1"
ReverseDiff = "1"
SimpleUnPack = "1"
StableRNGs = "1"
Zygote = "0.6"
julia = "1.10"
julia = "1.10, 1.11.2"
10 changes: 4 additions & 6 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,20 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ADTypes = "1"
AdvancedVI = "0.3"
Bijectors = "0.13.6"
Bijectors = "0.13.6, 0.14, 0.15"
Distributions = "0.25"
Documenter = "0.26, 0.27"
Documenter = "1"
FillArrays = "1"
ForwardDiff = "0.10"
LogDensityProblems = "2.1.1"
Optimisers = "0.3"
Optimisers = "0.3, 0.4"
Plots = "1"
QuasiMonteCarlo = "0.3"
ReverseDiff = "1"
SimpleUnPack = "1"
StatsFuns = "1"
julia = "1.10"
julia = "1.10, 1.11.2"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ makedocs(;
"Variational Families" => "families.md",
"Optimization" => "optimization.md",
],
warnonly=[:missing_docs],
)

deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
9 changes: 4 additions & 5 deletions docs/src/elbo/repgradelbo.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ using LinearAlgebra
using LogDensityProblems
using Plots
using Random
using SimpleUnPack
using Optimisers
using ADTypes, ForwardDiff
Expand All @@ -143,7 +142,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
end
function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end
Expand All @@ -168,7 +167,7 @@ L = Diagonal(ones(d));
q0 = AdvancedVI.MeanFieldGaussian(μ, L)
function Bijectors.bijector(model::NormalLogNormal)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
Bijectors.Stacked(
Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
[1:1, 2:1+length(μ_y)])
Expand Down Expand Up @@ -295,7 +294,7 @@ qmcrng = SobolSample(; R=OwenScramble(; base=2, pad=32))
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
unif_samples = QuasiMonteCarlo.sample(num_samples, length(q), qmcrng)
Expand Down Expand Up @@ -337,7 +336,7 @@ savefig("advi_qmc_dist.svg")
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
Expand Down
5 changes: 2 additions & 3 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Using the `LogDensityProblems` interface, we the model can be defined as follows

```@example elboexample
using LogDensityProblems
using SimpleUnPack
struct NormalLogNormal{MX,SX,MY,SY}
μ_x::MX
Expand All @@ -25,7 +24,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
end
function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end
Expand Down Expand Up @@ -59,7 +58,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat
using Bijectors
function Bijectors.bijector(model::NormalLogNormal)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
return Bijectors.Stacked(
Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
[1:1, 2:(1 + length(μ_y))],
Expand Down
2 changes: 2 additions & 0 deletions docs/src/optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ PolynomialAveraging
```

[^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.
1 change: 0 additions & 1 deletion src/AdvancedVI.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

module AdvancedVI

using SimpleUnPack: @unpack, @pack!
using Accessors

using Random
Expand Down
18 changes: 9 additions & 9 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ function (re::RestructureMeanField)(flat::AbstractVector)
return MvLocationScale(location, scale, re.model.dist)
end

function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
@unpack location, scale, dist = q
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
end
Expand All @@ -51,27 +51,27 @@ Base.size(q::MvLocationScale) = size(q.location)
Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
# `convert` is necessary because `entropy` is not type stable upstream
return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale)
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
(; location, scale, dist) = q
return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions.rand(q::MvLocationScale)
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
return scale * rand(dist, n_dims) + location
end

function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{S,D,L}, num_samples::Int
) where {S,D,L}
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
return scale * rand(rng, dist, n_dims, num_samples) .+ location
end
Expand All @@ -80,7 +80,7 @@ end
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
Expand All @@ -89,14 +89,14 @@ end
function Distributions._rand!(
rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}
)
@unpack location, scale, dist = q
(; location, scale, dist) = q
rand!(rng, dist, x)
x[:] = scale * x
return x .+= location
end

function Distributions.mean(q::MvLocationScale)
@unpack location, scale = q
(; location, scale) = q
return location + scale * Fill(mean(q.dist), length(location))
end

Expand Down
16 changes: 8 additions & 8 deletions src/families/location_scale_low_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Base.size(q::MvLocationScaleLowRank) = size(q.location)
Base.eltype(::Type{<:MvLocationScaleLowRank{D,L,SD,SF}}) where {D,L,SD,SF} = eltype(L)

function StatsBase.entropy(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
n_dims = length(location)
scale_diag2 = scale_diag .* scale_diag
UtDinvU = Hermitian(scale_factors' * (scale_factors ./ scale_diag2))
Expand All @@ -44,7 +44,7 @@ end
function Distributions.logpdf(
q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable::Bool=false
)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
μ_base = mean(dist)
n_dims = length(location)

Expand All @@ -67,7 +67,7 @@ function Distributions.logpdf(
end

function Distributions.rand(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
n_dims = length(location)
n_factors = size(scale_factors, 2)
u_diag = rand(dist, n_dims)
Expand All @@ -78,7 +78,7 @@ end
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScaleLowRank, num_samples::Int
)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
n_dims = length(location)
n_factors = size(scale_factors, 2)
u_diag = rand(rng, dist, n_dims, num_samples)
Expand All @@ -89,7 +89,7 @@ end
function Distributions._rand!(
rng::AbstractRNG, q::MvLocationScaleLowRank, x::AbstractVecOrMat{<:Real}
)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q

rand!(rng, dist, x)
x[:] = scale_diag .* x
Expand All @@ -101,22 +101,22 @@ function Distributions._rand!(
end

function Distributions.mean(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors = q
(; location, scale_diag, scale_factors) = q
μ = mean(q.dist)
return location +
scale_diag .* Fill(μ, length(scale_diag)) +
scale_factors * Fill(μ, size(scale_factors, 2))
end

function Distributions.var(q::MvLocationScaleLowRank)
@unpack scale_diag, scale_factors = q
(; scale_diag, scale_factors) = q
σ2 = var(q.dist)
return σ2 *
(scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1])
end

function Distributions.cov(q::MvLocationScaleLowRank)
@unpack scale_diag, scale_factors = q
(; scale_diag, scale_factors) = q
σ2 = var(q.dist)
return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors')
end
Expand Down
Loading

0 comments on commit a11e5ce

Please sign in to comment.