Migrate to DifferentiationInterface #98
New file (Buildkite pipeline for the CUDA test group):

@@ -0,0 +1,18 @@
steps:
  - label: "CUDA with julia {{matrix.julia}}"
    plugins:
      - JuliaCI/julia#v1:
          version: "{{matrix.julia}}"
      - JuliaCI/julia-test#v1:
          test_args: "--quickfail"
    agents:
      queue: "juliagpu"
      cuda: "*"
    timeout_in_minutes: 60
    env:
      GROUP: "GPU"
      ADVANCEDVI_TEST_CUDA: "true"
    matrix:
      setup:
        julia:
          - "1.10"
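For reference, a rough way to reproduce this CI group locally, assuming the package's test suite branches on the GROUP and ADVANCEDVI_TEST_CUDA environment variables as the pipeline above suggests (a sketch, not an officially documented workflow):

# Hypothetical local reproduction of the Buildkite "GPU" group; assumes
# test/runtests.jl reads GROUP and ADVANCEDVI_TEST_CUDA, as implied above.
ENV["GROUP"] = "GPU"
ENV["ADVANCEDVI_TEST_CUDA"] = "true"

using Pkg
Pkg.test("AdvancedVI")  # the test subprocess inherits these environment variables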
Deleted file (extension module AdvancedVIForwardDiffExt):

@@ -1,42 +0,0 @@
module AdvancedVIForwardDiffExt

if isdefined(Base, :get_extension)
    using ForwardDiff
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
else
    using ..ForwardDiff
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
end

getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoForwardDiff,
    f,
    x::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
)
    chunk_size = getchunksize(ad)
    config = if isnothing(chunk_size)
        ForwardDiff.GradientConfig(f, x)
    else
        ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
    end
    ForwardDiff.gradient!(out, f, x, config)
    return out
end

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoForwardDiff,
    f,
    x::AbstractVector,
    aux,
    out::DiffResults.MutableDiffResult,
)
    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
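For comparison, a minimal sketch of how the behavior of this deleted extension maps onto DifferentiationInterface: the ForwardDiff chunk size now travels inside the ADTypes backend object rather than through a custom getchunksize helper. The objective f and input x below are placeholders, not code from this PR.

using ADTypes: AutoForwardDiff
using DifferentiationInterface
import ForwardDiff  # loads DifferentiationInterface's ForwardDiff support

f(x) = sum(abs2, x)          # placeholder objective
x = randn(10)
grad = similar(x)

backend = AutoForwardDiff(; chunksize=4)  # chunk size is carried by the backend object
y, _ = DifferentiationInterface.value_and_gradient!(f, grad, backend, x)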
Deleted file (extension module AdvancedVIReverseDiffExt):

@@ -1,36 +0,0 @@
module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
    using ReverseDiff
else
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
    using ..ReverseDiff
end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoReverseDiff,
    f,
    x::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
)
    tp = ReverseDiff.GradientTape(f, x)
    ReverseDiff.gradient!(out, tp, x)
    return out
end

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoReverseDiff,
    f,
    x::AbstractVector{<:Real},
    aux,
    out::DiffResults.MutableDiffResult,
)
    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
This file was deleted.
Deleted file (extension module AdvancedVIZygoteExt):

@@ -1,36 +0,0 @@
module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
    using ChainRulesCore
    using Zygote
else
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
    using ..ChainRulesCore
    using ..Zygote
end

function AdvancedVI.value_and_gradient!(
    ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
    y, back = Zygote.pullback(f, x)
    ∇x = back(one(y))
    DiffResults.value!(out, y)
    DiffResults.gradient!(out, only(∇x))
    return out
end

function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoZygote,
    f,
    x::AbstractVector{<:Real},
    aux,
    out::DiffResults.MutableDiffResult,
)
    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
Modified file (main AdvancedVI module):

@@ -16,16 +16,17 @@ using LinearAlgebra

using LogDensityProblems

using ADTypes, DiffResults
using ADTypes
using DiffResults
using DifferentiationInterface
using ChainRulesCore

using FillArrays

using StatsBase

# derivatives
# Derivatives
"""
    value_and_gradient!(ad, f, x, out)
    value_and_gradient!(ad, f, x, aux, out)

Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.

@@ -38,7 +39,14 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end
function value_and_gradient!(
    ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult
)
    grad_buf = DiffResults.gradient(out)
    y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux))
    DiffResults.value!(out, y)
    return out
end

Review discussion on the DifferentiationInterface.value_and_gradient! call:

- You would benefit from DI's preparation mechanism, especially for backends like ForwardDiff and ReverseDiff.
- Also for Mooncake.
- @gdalle @willtebbutt In our case, features to come in the future (like subsampling) will result in …
- Mooncake.jl ought to be fine provided that the type doesn't change. Am I correct in assuming that in the case of subsampling, you're just changing the indices of the data which are subsampled at each iteration?
- But @gdalle is that true in general? For example, I imagine ReverseDiff or the upcoming Reactant support would result in an invalid prep object, no?
- If the size of …
- For most backends, what preparation does is allocate caches that correspond to the various differentiated inputs and contexts. These caches have a specific size which determines whether preparation can be reused, but inside your function you can create more objects with arbitrary sizes; that's none of my business.
- Haha, fair enough.
- I think compiled tapes actually need everything to be the same except …
- Oh yeah, I had forgotten constants. As soon as you use constants, DI recompiles the tape every time anyway.
- Ah, I see. Okay, then I guess there's nothing preventing us from using …
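To illustrate the preparation mechanism discussed in the thread above, here is a rough sketch (not part of this PR) of a prepared gradient with a constant auxiliary input, assuming the size and type of x and aux stay fixed across calls; the backend and objective are placeholders chosen only for illustration.

using ADTypes: AutoReverseDiff
using DifferentiationInterface
import ReverseDiff

f(x, aux) = sum(abs2, x .- aux)   # placeholder objective with auxiliary input
x, aux = randn(10), randn(10)
grad = similar(x)
backend = AutoReverseDiff()

# Prepare once, then reuse across iterations while x and aux keep their size and type.
prep = DifferentiationInterface.prepare_gradient(f, backend, x, Constant(aux))
y, _ = DifferentiationInterface.value_and_gradient!(f, grad, prep, backend, x, Constant(aux))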
"""
    restructure_ad_forward(adtype, restructure, params)

@@ -131,7 +139,7 @@ function estimate_objective end
export estimate_objective

"""
    estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
    estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)

Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`

@@ -141,7 +149,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `λ`: Variational parameters to evaluate the gradient on.
- `params`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
- `obj_state`: Previous state of the objective.
Review comment: This might be better as an extension; x-ref JuliaDiff/DifferentiationInterface.jl#509.