diff --git a/src/adjoint.jl b/src/adjoint.jl index a5f3a5d..47f628e 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -30,7 +30,14 @@ function adjoint end function _pullback end function pullback end -function gradm(ex, mut = false) + +function unthunk_tangent end +@inline unthunk_tangent(x) = x +@inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x) +@inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x) + + +function gradm(ex, mut = false, keepthunks = false) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {Ts__} = body_)) || error("Need a function definition") kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing @@ -51,18 +58,19 @@ function gradm(ex, mut = false) gradtuple = isclosure ? gradtuple0 : gradtuple1 gradtuplekw = isclosure ? gradtuple2 : gradtuple3 adj = @q @inline ZygoteRules.adjoint($(fargs...)) where $(Ts...) = $(esc(body)) + maybe_unthunked_Δ = keepthunks ? :Δ : :(unthunk_tangent(Δ)) quote $adj @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...)) $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuple(_back(Δ)) + back(Δ) = $gradtuple(_back($maybe_unthunked_Δ)) return y, back end @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...); kw...) $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuplekw(_back(Δ)) + back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ)) return y, back end nothing @@ -70,9 +78,9 @@ function gradm(ex, mut = false) end macro adjoint(ex) - gradm(ex) + gradm(ex, false, false) end macro adjoint!(ex) - gradm(ex, true) + gradm(ex, true, false) end