Migrate to DifferentiationInterface #98

Merged: 17 commits, Sep 30, 2024
18 changes: 18 additions & 0 deletions .buildkite/pipeline.yml
@@ -0,0 +1,18 @@
steps:
  - label: "CUDA with julia {{matrix.julia}}"
    plugins:
      - JuliaCI/julia#v1:
          version: "{{matrix.julia}}"
      - JuliaCI/julia-test#v1:
          test_args: "--quickfail"
    agents:
      queue: "juliagpu"
      cuda: "*"
    timeout_in_minutes: 60
    env:
      GROUP: "GPU"
      ADVANCEDVI_TEST_CUDA: "true"
    matrix:
      setup:
        julia:
          - "1.10"
20 changes: 9 additions & 11 deletions Project.toml
@@ -7,6 +7,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -24,50 +25,47 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVITapirExt = "Tapir"
AdvancedVIZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2, 1"
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
ForwardDiff = "0.10"
Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1.15.1"
ReverseDiff = "1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Tapir = "0.2"
Zygote = "0.6.63"
Zygote = "0.6"
julia = "1.7"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -14,7 +14,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ADTypes = "0.1.6"
ADTypes = "1"
AdvancedVI = "0.3"
Bijectors = "0.13.6"
Distributions = "0.25"
16 changes: 0 additions & 16 deletions ext/AdvancedVIEnzymeExt.jl
@@ -1,4 +1,3 @@

module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
@@ -15,21 +14,6 @@ function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, pa
    return restructure(params)::typeof(restructure.model)
end

function AdvancedVI.value_and_gradient!(
    ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
    ∇x = DiffResults.gradient(out)
    fill!(∇x, zero(eltype(∇x)))
    _, y = Enzyme.autodiff(
        Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
        Enzyme.Const(f),
        Enzyme.Active,
        Enzyme.Duplicated(x, ∇x),
    )
    DiffResults.value!(out, y)
    return out
end

function AdvancedVI.value_and_gradient!(
    ::ADTypes.AutoEnzyme,
    f,
42 changes: 0 additions & 42 deletions ext/AdvancedVIForwardDiffExt.jl
@@ -1,42 +0,0 @@

module AdvancedVIForwardDiffExt

if isdefined(Base, :get_extension)
    using ForwardDiff
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
else
    using ..ForwardDiff
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
end

getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoForwardDiff,
    f,
    x::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
)
    chunk_size = getchunksize(ad)
    config = if isnothing(chunk_size)
        ForwardDiff.GradientConfig(f, x)
    else
        ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
    end
    ForwardDiff.gradient!(out, f, x, config)
    return out
end

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoForwardDiff,
    f,
    x::AbstractVector,
    aux,
    out::DiffResults.MutableDiffResult,
)
    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
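For context, here is a minimal sketch (illustrative only, not part of the diff) of how the chunk-size handling that this deleted extension implemented by hand is now expressed through ADTypes and DifferentiationInterface; the toy function `f` and all values are assumptions:

```julia
using ADTypes: AutoForwardDiff
import DifferentiationInterface as DI

f(x) = sum(abs2, x)   # toy function, for illustration only
x    = randn(8)
grad = similar(x)

# The chunk size now travels with the ADTypes backend object; DI turns it into
# the appropriate ForwardDiff configuration internally.
backend = AutoForwardDiff(; chunksize=4)
y, _ = DI.value_and_gradient!(f, grad, backend, x)
```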
Empty file added ext/AdvancedVIMooncakeExt.jl
Empty file.
36 changes: 0 additions & 36 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -1,36 +0,0 @@

module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
    using ReverseDiff
else
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
    using ..ReverseDiff
end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoReverseDiff,
    f,
    x::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
)
    tp = ReverseDiff.GradientTape(f, x)
    ReverseDiff.gradient!(out, tp, x)
    return out
end

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoReverseDiff,
    f,
    x::AbstractVector{<:Real},
    aux,
    out::DiffResults.MutableDiffResult,
)
    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
37 changes: 0 additions & 37 deletions ext/AdvancedVITapirExt.jl

This file was deleted.

36 changes: 0 additions & 36 deletions ext/AdvancedVIZygoteExt.jl
@@ -1,36 +0,0 @@

module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
    using ChainRulesCore
    using Zygote
else
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
    using ..ChainRulesCore
    using ..Zygote
end

function AdvancedVI.value_and_gradient!(
    ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
    y, back = Zygote.pullback(f, x)
    ∇x = back(one(y))
    DiffResults.value!(out, y)
    DiffResults.gradient!(out, only(∇x))
    return out
end

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoZygote,
    f,
    x::AbstractVector{<:Real},
    aux,
    out::DiffResults.MutableDiffResult,
)
    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
20 changes: 14 additions & 6 deletions src/AdvancedVI.jl
@@ -16,16 +16,17 @@ using LinearAlgebra

using LogDensityProblems

using ADTypes, DiffResults
using ADTypes
using DiffResults
using DifferentiationInterface
using ChainRulesCore

using FillArrays

using StatsBase

# derivatives
# Derivatives
"""
    value_and_gradient!(ad, f, x, out)
    value_and_gradient!(ad, f, x, aux, out)

Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
@@ -38,7 +39,14 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end
function value_and_gradient!(
    ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult
)
    grad_buf = DiffResults.gradient(out)
    y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux))
    DiffResults.value!(out, y)
    return out
end

Review discussion on this hunk:

Contributor: This might be better as an extension, x-ref JuliaDiff/DifferentiationInterface.jl#509.

Comment: You would benefit from DI's preparation mechanism, especially for backends like ForwardDiff and ReverseDiff.

Member: Also for Mooncake.

Member Author: @gdalle @willtebbutt In our case, features to come in the future (like subsampling) will result in `aux.prob` changing at every iteration. Is it still valid to re-use the same prep object? Or is the idea that I should call `DI.prepare_gradient` every time before calling `DI.value_and_gradient`?

Member: Mooncake.jl ought to be fine provided that the type doesn't change. Am I correct in assuming that, in the case of subsampling, you're just changing the indices of the data which are subsampled at each iteration?

Contributor: But @gdalle, is that true in general? For example, I imagine ReverseDiff or the upcoming Reactant support would result in an invalid prep object, no?

gdalle (Sep 30, 2024): If the size of `x` and the size of the contents of `aux.prob` don't change, you should be good to go (except for ReverseDiff, as explained above). That would require the minibatches to always have the same length, I guess?

Comment: For most backends, what preparation does is allocate caches that correspond to the various differentiated inputs and contexts. These caches have a specific size, which determines whether preparation can be reused, but inside your function you can create more objects with arbitrary sizes; that's none of my business.

Member Author:
> that's none of my business.

Haha, fair enough.

> (except for ReverseDiff, as explained above). That would require the minibatches to always have the same length, I guess?

I think compiled tapes actually need everything to be the same except `x`, no? But in any case, I think the best course of action would be to expose control over whether to use preparation or not. But I feel that would be a little too complicated for this PR alone.

Comment: Oh yeah, I had forgotten constants. As soon as you use constants, DI recompiles the tape every time anyway.

Member Author: Ah, I see. Okay, then I guess there's nothing preventing us from using `prepare_gradient`.
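To make the preparation discussion concrete, here is a minimal sketch (not part of this PR; the toy objective and every name besides the DifferentiationInterface calls are assumptions) of how `prepare_gradient` could be reused across calls when only the values inside the `Constant` context change:

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff

# Toy objective with auxiliary data, mirroring the f(x, aux) convention above.
f(x, aux) = sum(abs2, x .- aux.target)

x       = randn(4)
aux     = (target=zeros(4),)
backend = AutoForwardDiff()

# Preparation allocates backend-specific caches once. It stays valid as long as
# the size and type of `x` (and of the Constant contents) do not change.
prep = DI.prepare_gradient(f, backend, x, DI.Constant(aux))

grad = similar(x)
y, _ = DI.value_and_gradient!(f, grad, prep, backend, x, DI.Constant(aux))
```

The same pattern would apply to other backends, subject to the tape-recompilation caveats raised in the thread above.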

"""
    restructure_ad_forward(adtype, restructure, params)
@@ -131,7 +139,7 @@ function estimate_objective end
export estimate_objective

"""
    estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
    estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)

Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `params`.

Expand All @@ -141,7 +149,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `λ`: Variational parameters to evaluate the gradient on.
- `params`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `params`.
- `obj_state`: Previous state of the objective.

2 changes: 1 addition & 1 deletion src/optimize.jl
@@ -42,7 +42,7 @@ The arguments are as follows:
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation.
- `gradient`: The estimated (possibly stochastic) gradient.

`cb` can return a `NamedTuple` containing some additional information computed within `cb`.
`callback` can return a `NamedTuple` containing some additional information computed within `callback`.
This will be appended to the statistic of the current corresponding iteration.
Otherwise, just return `nothing`.
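As an illustration of the callback contract described above, a hypothetical sketch (the keyword name `gradient` follows the excerpted docstring; the exact signature and argument types should be checked against the package docs):

```julia
using LinearAlgebra: norm

# Hypothetical callback: record the norm of the current gradient estimate as an
# extra statistic. Returning a NamedTuple appends it to this iteration's stats;
# returning `nothing` records nothing extra.
function my_callback(; gradient, kwargs...)
    return (gradient_norm=norm(gradient),)
end
```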

Expand Down
6 changes: 4 additions & 2 deletions test/Project.toml
@@ -2,9 +2,9 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -26,7 +26,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "0.2.1, 1"
Bijectors = "0.13"
DiffResults = "1.0"
DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DistributionsAD = "0.6.45"
FillArrays = "1.6.1"
@@ -41,6 +42,7 @@ ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StableRNGs = "1.0.0"
Statistics = "1"
StatsBase = "0.34"
Test = "1"
Tracker = "0.2.20"
Zygote = "0.6.63"
8 changes: 5 additions & 3 deletions test/inference/repgradelbo_distributionsad.jl
@@ -5,12 +5,14 @@ AD_distributionsad = Dict(
    :Zygote => AutoZygote(),
)

if @isdefined(Tapir)
    AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
    AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing)
end

if @isdefined(Enzyme)
    AD_distributionsad[:Enzyme] = AutoEnzyme()
    AD_distributionsad[:Enzyme] = AutoEnzyme(;
        mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
    )
end

@testset "inference RepGradELBO DistributionsAD" begin