Skip to content

Commit

Permalink
Merge branch 'master' into tidy_scoregradelbo
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Oct 22, 2024
2 parents feb3200 + 616b581 commit 45b5afd
Show file tree
Hide file tree
Showing 25 changed files with 147 additions and 99 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
version: '1.10'
arch: x64
- uses: actions/cache@v4
env:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs=["", "bench", "test", "docs"])'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
3 changes: 2 additions & 1 deletion .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: '1'
version: '1.10'
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
Expand All @@ -29,6 +29,7 @@ jobs:
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
- name: Run doctests
shell: julia --project=docs --color=yes {0}
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/TagBot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ jobs:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}

ssh: ${{ secrets.DOCUMENTER_KEY }}
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.swp
.vscode/
docs/build/
.DS_Store
Manifest.toml
22 changes: 22 additions & 0 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,35 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
BenchmarkTools = "1"
Bijectors = "0.13"
Distributions = "0.25.111"
DistributionsAD = "0.6"
Enzyme = "0.13.7"
FillArrays = "1"
ForwardDiff = "0.10"
InteractiveUtils = "1"
LogDensityProblems = "2"
Mooncake = "0.4.5"
Optimisers = "0.3"
Random = "1"
ReverseDiff = "1"
SimpleUnPack = "1"
StableRNGs = "1"
Zygote = "0.6"
julia = "1.10"
10 changes: 10 additions & 0 deletions bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@
This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`.
The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main).
The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl).

To run the benchmarks locally, follow the following steps:

```julia
using Pkg
Pkg.activate(".")
Pkg.instantiate()
Pkg.develop("AdvancedVI")
include("benchmarks.jl")
```
87 changes: 59 additions & 28 deletions bench/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@

using ADTypes, ForwardDiff, ReverseDiff, Zygote
using ADTypes
using AdvancedVI
using BenchmarkTools
using Bijectors
using Distributions
using DistributionsAD
using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake
using FillArrays
using InteractiveUtils
using LinearAlgebra
Expand All @@ -17,37 +18,67 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))
@info sprint(versioninfo)
@info "BLAS threads: $(BLAS.get_num_threads())"

include("utils.jl")
include("normallognormal.jl")
include("unconstrdist.jl")

const SUITES = BenchmarkGroup()

# Comment until https://github.com/TuringLang/Bijectors.jl/pull/315 is merged
# SUITES["normal + bijector"]["meanfield"]["Zygote"] =
# @benchmarkable normallognormal(
# ;
# fptype = Float64,
# adtype = AutoZygote(),
# family = :meanfield,
# objective = :RepGradELBO,
# n_montecarlo = 4,
# )

SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoReverseDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)

SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoForwardDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol)
if family == :meanfield
MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims)))
else
FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims))
end
end

begin
T = Float64

for (probname, prob) in [
("normal + bijector", normallognormal(; n_dims=10, realtype=T))
("normal", normal(; n_dims=10, realtype=T))
]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
("RepGradELBO + STL", RepGradELBO(10; entropy=StickingTheLandingEntropy())),
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ForwardDiff", AutoForwardDiff()),
("ReverseDiff", AutoReverseDiff()),
#("Mooncake", AutoMooncake(; config=Mooncake.Config())),
#("Enzyme", AutoEnzyme()),
],
(familyname, family) in [
("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))),
(
"fullrank",
FullRankGaussian(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d))),
),
]

b = Bijectors.bijector(prob)
binv = inverse(b)
q = Bijectors.TransformedDistribution(family, binv)

SUITES[probname][objname][familyname][adname] = begin
@benchmarkable AdvancedVI.optimize(
$prob,
$obj,
$q,
$max_iter;
adtype=$adtype,
optimizer=$optimizer,
show_progress=false,
)
end
end
end
end

BenchmarkTools.tune!(SUITES; verbose=true)
results = BenchmarkTools.run(SUITES; verbose=true)
Expand Down
32 changes: 6 additions & 26 deletions bench/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,10 @@ function Bijectors.bijector(model::NormalLogNormal)
)
end

function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...)
n_dims = 10
μ_x = fptype(5.0)
σ_x = fptype(0.3)
μ_y = Fill(fptype(5.0), n_dims)
σ_y = Fill(fptype(0.3), n_dims)
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))

obj = variational_objective(objective; kwargs...)

d = LogDensityProblems.dimension(model)
q = variational_standard_mvnormal(fptype, d, family)

b = Bijectors.bijector(model)
binv = inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)

return AdvancedVI.optimize(
model,
obj,
q_transformed,
max_iter;
adtype,
optimizer=Optimisers.Adam(fptype(1e-3)),
show_progress=false,
)
function normallognormal(; n_dims=10, realtype=Float64)
μ_x = realtype(5.0)
σ_x = realtype(0.3)
μ_y = Fill(realtype(5.0), n_dims)
σ_y = Fill(realtype(0.3), n_dims)
return model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))
end
26 changes: 26 additions & 0 deletions bench/unconstrdist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

struct UnconstrDist{D<:ContinuousMultivariateDistribution}
dist::D
end

function LogDensityProblems.logdensity(model::UnconstrDist, x)
return logpdf(model.dist, x)
end

function LogDensityProblems.dimension(model::UnconstrDist)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:UnconstrDist})
return LogDensityProblems.LogDensityOrder{0}()
end

function Bijectors.bijector(model::UnconstrDist)
return identity
end

function normal(; n_dims=10, realtype=Float64)
μ = fill(realtype(5), n_dims)
Σ = Diagonal(ones(realtype, n_dims))
return UnconstrDist(MvNormal(μ, Σ))
end
20 changes: 0 additions & 20 deletions bench/utils.jl

This file was deleted.

6 changes: 4 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

Expand All @@ -18,13 +19,14 @@ ADTypes = "1"
AdvancedVI = "0.3"
Bijectors = "0.13.6"
Distributions = "0.25"
Documenter = "0.26, 0.27"
Documenter = "1"
FillArrays = "1"
ForwardDiff = "0.10"
LogDensityProblems = "2.1.1"
Optimisers = "0.3"
Plots = "1"
QuasiMonteCarlo = "0.3"
ReverseDiff = "1"
SimpleUnPack = "1"
StatsFuns = "1"
julia = "1.6"
julia = "1.10"
Empty file removed ext/AdvancedVIForwardDiffExt.jl
Empty file.
Empty file removed ext/AdvancedVIMooncakeExt.jl
Empty file.
Empty file removed ext/AdvancedVIReverseDiffExt.jl
Empty file.
Empty file removed ext/AdvancedVIZygoteExt.jl
Empty file.
9 changes: 0 additions & 9 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,6 @@ end
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/AdvancedVIEnzymeExt.jl")
end
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/AdvancedVIForwardDiffExt.jl")
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("../ext/AdvancedVIReverseDiffExt.jl")
end
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/AdvancedVIZygoteExt.jl")
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
@test cov(z_samples; dims=2) cov(q_true) rtol = realtype(1e-2)

samples_ref = rand(StableRNG(1), q, n_montecarlo)
@test samples_ref == rand(StableRNG(1), q, n_montecarlo)
@test samples_ref rand(StableRNG(1), q, n_montecarlo)
end

@testset "rand! AbstractVector" begin
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_distributionsad = Dict(
)

if @isdefined(Mooncake)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_locationscale = Dict(
)

if @isdefined(Mooncake)
AD_locationscale[:Mooncake] = AutoMooncake(; config=nothing)
AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_locationscale_bijectors = Dict(
)

if @isdefined(Mooncake)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_scoregradelbo_distributionsad = Dict(
)

if @isdefined(Mooncake)
AD_scoregradelbo_distributionsad[:Moonscake] = AutoMooncake(; config=nothing)
AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_scoregradelbo_locationscale = Dict(
)

if @isdefined(Mooncake)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=nothing)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
4 changes: 2 additions & 2 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ const interface_ad_backends = Dict(
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
Loading

0 comments on commit 45b5afd

Please sign in to comment.