Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to DifferentiationInterface #98

Merged
merged 17 commits into the base branch
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Buildkite CI pipeline: runs the package test suite on CUDA-capable GPU agents.
steps:
  # One step per Julia version in the matrix below.
  - label: "CUDA with julia {{matrix.julia}}"
    plugins:
      # Installs the requested Julia version on the agent.
      - JuliaCI/julia#v1:
          version: "{{matrix.julia}}"
      # Runs `Pkg.test`; --quickfail aborts the test run on the first failure.
      - JuliaCI/julia-test#v1:
          test_args: "--quickfail"
    agents:
      # Schedule onto the shared JuliaGPU queue; "*" accepts any CUDA device.
      queue: "juliagpu"
      cuda: "*"
    # Kill the job if it runs longer than an hour.
    timeout_in_minutes: 60
    env:
      # NOTE(review): presumably read by the test suite to select the GPU
      # test group and enable CUDA-specific tests — confirm against runtests.jl.
      GROUP: "GPU"
      ADVANCEDVI_TEST_CUDA: "true"
    # Build matrix: currently only Julia 1.10 is exercised on GPU.
    matrix:
      setup:
        julia:
          - "1.10"
26 changes: 11 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version = "0.3.0"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -24,52 +24,48 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVITapirExt = "Tapir"
AdvancedVIZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2, 1"
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
ForwardDiff = "0.10"
Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1.15.1"
ReverseDiff = "1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Tapir = "0.2"
Zygote = "0.6.63"
Zygote = "0.6"
julia = "1.7"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Pkg", "Test"]
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ADTypes = "0.1.6"
ADTypes = "1"
AdvancedVI = "0.3"
Bijectors = "0.13.6"
Distributions = "0.25"
Expand Down
34 changes: 5 additions & 29 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,29 @@

module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using AdvancedVI: ADTypes
else
using ..Enzyme
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..AdvancedVI: ADTypes
end

function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
return restructure(params)::typeof(restructure.model)
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
_, y = Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, ∇x),
)
DiffResults.value!(out, y)
return out
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
function AdvancedVI.value_and_gradient(::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, aux)
∇x = zero(x)
_, y = Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, ∇x),
Enzyme.Const(aux),
)
DiffResults.value!(out, y)
return out
return y, ∇x
end

end
42 changes: 0 additions & 42 deletions ext/AdvancedVIForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,42 +0,0 @@

module AdvancedVIForwardDiffExt

if isdefined(Base, :get_extension)
using ForwardDiff
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
else
using ..ForwardDiff
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
end

getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
chunk_size = getchunksize(ad)
config = if isnothing(chunk_size)
ForwardDiff.GradientConfig(f, x)
else
ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
end
ForwardDiff.gradient!(out, f, x, config)
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
f,
x::AbstractVector,
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
Empty file added ext/AdvancedVIMooncakeExt.jl
Empty file.
36 changes: 0 additions & 36 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,36 +0,0 @@

module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ReverseDiff
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ReverseDiff
end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
tp = ReverseDiff.GradientTape(f, x)
ReverseDiff.gradient!(out, tp, x)
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
37 changes: 0 additions & 37 deletions ext/AdvancedVITapirExt.jl

This file was deleted.

36 changes: 0 additions & 36 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,36 +0,0 @@

module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ChainRulesCore
using Zygote
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ChainRulesCore
using ..Zygote
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
y, back = Zygote.pullback(f, x)
∇x = back(one(y))
DiffResults.value!(out, y)
DiffResults.gradient!(out, only(∇x))
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
19 changes: 11 additions & 8 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ using LinearAlgebra

using LogDensityProblems

using ADTypes, DiffResults
using ADTypes
using DifferentiationInterface
using ChainRulesCore

using FillArrays

using StatsBase

# derivatives
# Derivatives
"""
value_and_gradient!(ad, f, x, out)
value_and_gradient!(ad, f, x, aux, out)
value_and_gradient(ad, f, x, aux, out)

Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
`f` may receive auxiliary input as `f(x,aux)`.
Expand All @@ -36,9 +36,13 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
- `f`: Function subject to differentiation.
- `x`: The point to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.

# Returns
- `value`: `f` evaluated at `x`.
- `grad`: Gradient of `f` evaluated at `x`.
"""
function value_and_gradient! end
value_and_gradient(ad::ADTypes.AbstractADType, f, x, aux) =
DifferentiationInterface.value_and_gradient(f, ad, x, Constant(aux))

"""
restructure_ad_forward(adtype, restructure, params)
Expand Down Expand Up @@ -131,15 +135,14 @@ function estimate_objective end
export estimate_objective

"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
estimate_gradient(rng, obj, adtype, prob, λ, restructure, obj_state)

Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `λ`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
Expand Down
8 changes: 3 additions & 5 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,10 @@ function estimate_repgradelbo_ad_forward(params′, aux)
return -elbo
end

function estimate_gradient!(
function estimate_gradient(
rng::Random.AbstractRNG,
obj::RepGradELBO,
adtype::ADTypes.AbstractADType,
out::DiffResults.MutableDiffResult,
prob,
params,
restructure,
Expand All @@ -120,8 +119,7 @@ function estimate_gradient!(
restructure=restructure,
q_stop=q_stop,
)
value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
nelbo = DiffResults.value(out)
nelbo, g = value_and_gradient(adtype, estimate_repgradelbo_ad_forward, params, aux)
stat = (elbo=-nelbo,)
return out, nothing, stat
return g, nothing, stat
end
Loading