Skip to content

Commit

Permalink
Merge pull request #17 from oschulz/unthunking-adjoint
Browse files Browse the repository at this point in the history
Make @adjoint unthunk pullback inputs
  • Loading branch information
DhairyaLGandhi authored Sep 22, 2021
2 parents 5e48881 + 05cd6e1 commit 37dc97a
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ function adjoint end
function _pullback end
function pullback end

function gradm(ex, mut = false)

function unthunk_tangent end
@inline unthunk_tangent(x) = x
@inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x)
@inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x)


function gradm(ex, mut = false, keepthunks = false)
@capture(shortdef(ex), (name_(args__) = body_) |
(name_(args__) where {Ts__} = body_)) || error("Need a function definition")
kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing
Expand All @@ -51,28 +58,29 @@ function gradm(ex, mut = false)
gradtuple = isclosure ? gradtuple0 : gradtuple1
gradtuplekw = isclosure ? gradtuple2 : gradtuple3
adj = @q @inline ZygoteRules.adjoint($(fargs...)) where $(Ts...) = $(esc(body))
maybe_unthunked_Δ = keepthunks ? : :(unthunk_tangent(Δ))
quote
$adj
@inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...)
y, _back = adjoint(__context__, $f, $(argnames...))
$(mut ? nothing : :(back(::Nothing) = nothing))
back(Δ) = $gradtuple(_back(Δ))
back(Δ) = $gradtuple(_back($maybe_unthunked_Δ))
return y, back
end
@inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
y, _back = adjoint(__context__, $f, $(argnames...); kw...)
$(mut ? nothing : :(back(::Nothing) = nothing))
back(Δ) = $gradtuplekw(_back(Δ))
back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ))
return y, back
end
nothing
end
end

macro adjoint(ex)
gradm(ex)
gradm(ex, false, false)
end

macro adjoint!(ex)
gradm(ex, true)
gradm(ex, true, false)
end

0 comments on commit 37dc97a

Please sign in to comment.