-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
448cdda
commit 314eacf
Showing
6 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
name: Benchmarks | ||
on: | ||
push: | ||
branches: | ||
- master | ||
pull_request: | ||
branches: | ||
- master | ||
|
||
concurrency: | ||
# Skip intermediate builds: always. | ||
# Cancel intermediate builds: only if it is a pull request build. | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} | ||
|
||
jobs: | ||
benchmark: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: julia-actions/setup-julia@v2 | ||
with: | ||
version: '1' | ||
arch: x64 | ||
- uses: actions/cache@v4 | ||
env: | ||
cache-name: cache-artifacts | ||
with: | ||
path: ~/.julia/artifacts | ||
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} | ||
restore-keys: | | ||
${{ runner.os }}-test-${{ env.cache-name }}- | ||
${{ runner.os }}-test- | ||
${{ runner.os }}- | ||
- name: Run benchmark | ||
run: | | ||
cd bench | ||
julia --project --threads=2 --color=yes -e ' | ||
using Pkg; | ||
Pkg.develop(PackageSpec(path=joinpath(pwd(), ".."))); | ||
Pkg.instantiate(); | ||
include("benchmarks.jl")' | ||
- name: Parse & Upload Benchmark Results | ||
uses: benchmark-action/github-action-benchmark@v1 | ||
with: | ||
name: Benchmark Results | ||
tool: 'julia' | ||
output-file-path: bench/benchmark_results.json | ||
summary-always: true | ||
github-token: ${{ secrets.GITHUB_TOKEN }} | ||
comment-always: true | ||
alert-threshold: "200%" | ||
fail-on-alert: true | ||
benchmark-data-dir-path: benchmarks | ||
auto-push: ${{ github.event_name != 'pull_request' }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | ||
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
# AdvancedVI.jl Continuous Benchmarking | ||
|
||
This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`. | ||
The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main). | ||
The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
|
||
using ADTypes, ForwardDiff, ReverseDiff, Zygote | ||
using AdvancedVI | ||
using BenchmarkTools | ||
using Bijectors | ||
using Distributions | ||
using DistributionsAD | ||
using FillArrays | ||
using InteractiveUtils | ||
using LinearAlgebra | ||
using LogDensityProblems | ||
using Optimisers | ||
using Random | ||
|
||
BLAS.set_num_threads(min(4, Threads.nthreads())) | ||
|
||
@info sprint(versioninfo) | ||
@info "BLAS threads: $(BLAS.get_num_threads())" | ||
|
||
include("utils.jl") | ||
include("normallognormal.jl") | ||
|
||
const SUITES = BenchmarkGroup() | ||
|
||
# Comment until https://github.com/TuringLang/Bijectors.jl/pull/315 is merged | ||
# SUITES["normal + bijector"]["meanfield"]["Zygote"] = | ||
# @benchmarkable normallognormal( | ||
# ; | ||
# fptype = Float64, | ||
# adtype = AutoZygote(), | ||
# family = :meanfield, | ||
# objective = :RepGradELBO, | ||
# n_montecarlo = 4, | ||
# ) | ||
|
||
SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = | ||
@benchmarkable normallognormal( | ||
; | ||
fptype = Float64, | ||
adtype = AutoReverseDiff(), | ||
family = :meanfield, | ||
objective = :RepGradELBO, | ||
n_montecarlo = 4, | ||
) | ||
|
||
SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = | ||
@benchmarkable normallognormal( | ||
; | ||
fptype = Float64, | ||
adtype = AutoForwardDiff(), | ||
family = :meanfield, | ||
objective = :RepGradELBO, | ||
n_montecarlo = 4, | ||
) | ||
|
||
BenchmarkTools.tune!(SUITES; verbose=true) | ||
results = BenchmarkTools.run(SUITES; verbose=true) | ||
display(median(results)) | ||
|
||
BenchmarkTools.save(joinpath(@__DIR__, "benchmark_results.json"), median(results)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
|
||
struct NormalLogNormal{MX,SX,MY,SY} | ||
μ_x::MX | ||
σ_x::SX | ||
μ_y::MY | ||
Σ_y::SY | ||
end | ||
|
||
function LogDensityProblems.logdensity(model::NormalLogNormal, θ) | ||
(; μ_x, σ_x, μ_y, Σ_y) = model | ||
logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) | ||
end | ||
|
||
function LogDensityProblems.dimension(model::NormalLogNormal) | ||
length(model.μ_y) + 1 | ||
end | ||
|
||
function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) | ||
LogDensityProblems.LogDensityOrder{0}() | ||
end | ||
|
||
function Bijectors.bijector(model::NormalLogNormal) | ||
(; μ_x, σ_x, μ_y, Σ_y) = model | ||
Bijectors.Stacked( | ||
Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), | ||
[1:1, 2:1+length(μ_y)]) | ||
end | ||
|
||
function normallognormal(; fptype, adtype, family, objective, kwargs...) | ||
n_dims = 10 | ||
μ_x = fptype(5.0) | ||
σ_x = fptype(0.3) | ||
μ_y = Fill(fptype(5.0), n_dims) | ||
σ_y = Fill(fptype(0.3), n_dims) | ||
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) | ||
|
||
obj = variational_objective(objective; kwargs...) | ||
|
||
d = LogDensityProblems.dimension(model) | ||
q = variational_standard_mvnormal(fptype, d, family) | ||
|
||
b = Bijectors.bijector(model) | ||
binv = inverse(b) | ||
q_transformed = Bijectors.TransformedDistribution(q, binv) | ||
|
||
max_iter = 10^3 | ||
AdvancedVI.optimize( | ||
model, | ||
obj, | ||
q_transformed, | ||
max_iter; | ||
adtype, | ||
optimizer = Optimisers.Adam(fptype(1e-3)), | ||
show_progress = false, | ||
) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
|
||
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol) | ||
if family == :meanfield | ||
AdvancedVI.MeanFieldGaussian( | ||
zeros(type, n_dims), Diagonal(ones(type, n_dims)) | ||
) | ||
else | ||
AdvancedVI.FullRankGaussian( | ||
zeros(type, n_dims), Matrix(type, I, n_dims, n_dims) | ||
) | ||
end | ||
end | ||
|
||
function variational_objective(objective::Symbol; kwargs...) | ||
if objective == :RepGradELBO | ||
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]) | ||
elseif objective == :RepGradELBOSTL | ||
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo], entropy=StickingTheLandingEntropy()) | ||
end | ||
end |