Skip to content

Commit

Permalink
fix conditional testing on Enzyme and Mooncake
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Oct 8, 2024
1 parent 82246ad commit 0ee74f4
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 84 deletions.
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -29,13 +31,16 @@ DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DistributionsAD = "0.6.45"
Enzyme = "0.13"
FillArrays = "1.6.1"
ForwardDiff = "0.10.36"
Functors = "0.4.5"
LinearAlgebra = "1"
LogDensityProblems = "2.1.1"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3"
PDMats = "0.11.7"
Pkg = "1"
Random = "1"
ReverseDiff = "1.15.1"
StableRNGs = "1.0.0"
Expand Down
12 changes: 2 additions & 10 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,10 @@ AD_distributionsad = Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Mooncake)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
AD_distributionsad[:Enzyme] = AutoEnzyme(;
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
)
end

@testset "inference RepGradELBO DistributionsAD" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
12 changes: 2 additions & 10 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,10 @@ AD_locationscale = Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Mooncake)
AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
AD_locationscale[:Enzyme] = AutoEnzyme(;
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
)
end

@testset "inference ScoreGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
12 changes: 2 additions & 10 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,10 @@ AD_locationscale_bijectors = Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Mooncake)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
AD_locationscale_bijectors[:Enzyme] = AutoEnzyme(;
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
)
end

@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
10 changes: 2 additions & 8 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,10 @@ AD_scoregradelbo_distributionsad = Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
#:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Mooncake)
AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

#if @isdefined(Enzyme)
# AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme()
#end

@testset "inference ScoreGradELBO DistributionsAD" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
12 changes: 2 additions & 10 deletions test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,10 @@ AD_scoregradelbo_locationscale = Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Mooncake)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
AD_scoregradelbo_locationscale[:Enzyme] = AutoEnzyme(;
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
)
end

@testset "inference ScoreGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
12 changes: 3 additions & 9 deletions test/inference/scoregradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
AD_scoregradelbo_locationscale_bijectors = Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
:Zygote => AutoZygote(),
#:Mooncake => AutoMooncake(; safe_mode=false)
:Enzyme => AutoEnzyme()
)

#if @isdefined(Tapir)
# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
#end

if @isdefined(Enzyme)
AD_scoregradelbo_locationscale_bijectors[:Enzyme] = AutoEnzyme()
end

@testset "inference ScoreGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
10 changes: 2 additions & 8 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,10 @@ const interface_ad_backends = Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Mooncake)
interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
interface_ad_backends[:Enzyme] = AutoEnzyme()
end

@testset "ad" begin
@testset "$(adname)" for (adname, adtype) in interface_ad_backends
D = 10
Expand Down
17 changes: 5 additions & 12 deletions test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,12 @@ end
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

ad_backends = [
ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote()
ADTypes.AutoForwardDiff(),
ADTypes.AutoReverseDiff(),
ADTypes.AutoZygote(),
AutoMooncake(; config=Mooncake.Config()),
AutoEnzyme(),
]
if @isdefined(Mooncake)
push!(ad_backends, AutoMooncake(; config=Mooncake.Config()))
end
if @isdefined(Enzyme)
push!(
ad_backends,
AutoEnzyme(;
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
),
)
end

@testset for adtype in ad_backends
q_true = MeanFieldGaussian(
Expand Down
7 changes: 0 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ using DistributionsAD
using ADTypes
using ForwardDiff, ReverseDiff, Zygote

if VERSION >= v"1.10"
Pkg.add("Mooncake")
Pkg.add("Enzyme")
using Mooncake
using Enzyme
end

using AdvancedVI

const GROUP = get(ENV, "GROUP", "All")
Expand Down

0 comments on commit 0ee74f4

Please sign in to comment.