diff --git a/Project.toml b/Project.toml index 1ebcf9c..b38af29 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" ManualMemory = "d125e4d3-2237-4719-b19c-fa641b8a4667" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" @@ -38,6 +39,7 @@ LayoutPointers = "0.1.3" LoopVectorization = "0.12.104" ManualMemory = "0.1.8" Polyester = "0.4, 0.5, 0.6, 0.7" +ReverseDiff = "1.14" SIMDTypes = "0.1" SLEEFPirates = "0.6" Static = "0.8.4" diff --git a/src/SimpleChains.jl b/src/SimpleChains.jl index 6c4f3a9..d3ab44d 100644 --- a/src/SimpleChains.jl +++ b/src/SimpleChains.jl @@ -43,6 +43,7 @@ import ForwardDiff import LoopVectorization import StaticArrays using Random: AbstractRNG +using ReverseDiff using LoopVectorization: matmul_params, @turbo # using LoopVectorization: matmul_params diff --git a/src/chain_rules.jl b/src/chain_rules.jl index b07e7ab..44959ef 100644 --- a/src/chain_rules.jl +++ b/src/chain_rules.jl @@ -1,4 +1,3 @@ - if isdefined(ChainRulesCore, :NoTangent) const NoTangent = ChainRulesCore.NoTangent else @@ -176,6 +175,60 @@ end _returns_scalar(::AbstractPenalty) = True() _returns_scalar(sc::SimpleChain) = has_loss_typed(sc) -function ChainRulesCore.rrule(sc::Chain, arg, params) - _rrule(sc, arg, params, _returns_scalar(sc)) +function ChainRulesCore.rrule(::typeof(call_chain), sc::Chain, arg, params) + v, pb = _rrule(sc, arg, params, _returns_scalar(sc)) + return v, Δ -> (NoTangent(), pb(collect(Δ))...) +end +function call_chain(sc::SimpleChain, arg::AbstractArray, params::ReverseDiff.TrackedVector) + return ReverseDiff.track(call_chain, sc, arg, params) +end +function call_chain(sc::SimpleChain, arg::AbstractArray, params::SubArray{<:ReverseDiff.TrackedReal, 1}) + return ReverseDiff.track(call_chain, sc, arg, params) +end +function call_chain(sc::SimpleChain, arg::ReverseDiff.TrackedArray, params::ReverseDiff.AbstractVector) + return ReverseDiff.track(call_chain, sc, arg, params) +end +function call_chain(sc::SimpleChain, arg::ReverseDiff.TrackedArray, params::ReverseDiff.TrackedVector) + return ReverseDiff.track(call_chain, sc, arg, params) +end +function call_chain(sc::SimpleChain, arg::ReverseDiff.TrackedArray, params::SubArray{<:ReverseDiff.TrackedReal, 1}) + return ReverseDiff.track(call_chain, sc, arg, params) +end +ReverseDiff.@grad function call_chain(sc::Chain, arg::AbstractArray, params::AbstractVector) + argv = ReverseDiff.value(arg) + paramsv = ReverseDiff.value(params) + v, pb = _rrule(sc, argv, paramsv, _returns_scalar(sc)) + return v, Δ -> begin + _Δ = Base.tail(pb(collect(Δ))) + _Δ = Base.tail(pb(collect(Δ))) + (nothing, _Δ...) + end +end + +function params_rrule(Δ) + n = sum(x -> sum(length, x), Δ) + T = eltype(first(first(Δ))) + v = zeros(T, n) + offset = 0 + for x in Δ + for y in x + if y isa Real + v[offset + 1] = y + offset += 1 + else + l = length(y) + v[offset + 1 : offset + l] = vec(y) + offset += l + end + end + end + return v +end +function ChainRulesCore.rrule(::typeof(params), sc::SimpleChain, p::AbstractVector) + return params(sc, p), Δ -> (NoTangent(), NoTangent(), params_rrule(Δ)) +end +params(sc::SimpleChain, p::ReverseDiff.TrackedArray) = ReverseDiff.track(params, sc, p) +params(sc::SimpleChain, p::SubArray{<:ReverseDiff.TrackedReal, 1}) = ReverseDiff.track(params, sc, p) +ReverseDiff.@grad function params(sc::SimpleChain, p::AbstractArray) + return params(sc, p), Δ -> (nothing, params_rrule(Δ)) end diff --git a/src/penalty.jl b/src/penalty.jl index b1d19c2..b48fb03 100644 --- a/src/penalty.jl +++ b/src/penalty.jl @@ -1,9 +1,6 @@ function (Λ::AbstractPenalty{<:SimpleChain})(arg, params) - Base.FastMath.add_fast( - getchain(Λ)(arg, params), - apply_penalty(Λ, params, static_size(arg)) - ) + return call_chain(Λ, arg, params) end function valgrad!(g, Λ::AbstractPenalty{<:SimpleChain}, arg, params) Base.FastMath.add_fast( @@ -12,6 +9,17 @@ function valgrad!(g, Λ::AbstractPenalty{<:SimpleChain}, arg, params) ) end +function call_chain( + Λ::AbstractPenalty{<:SimpleChain}, + arg, + params, +) + Base.FastMath.add_fast( + getchain(Λ)(arg, params), + apply_penalty(Λ, params, static_size(arg)) + ) +end + _penalty_applied_to_sc(_::IO, ::Nothing) = nothing function _penalty_applied_to_sc(io::IO, sc::SimpleChain) println(io, " applied to:") diff --git a/src/simple_chain.jl b/src/simple_chain.jl index e91a583..4d27d1e 100644 --- a/src/simple_chain.jl +++ b/src/simple_chain.jl @@ -151,7 +151,11 @@ end _maybe_sarray(fx, static_size(fx)) end -function (c::SimpleChain)( +function (c::SimpleChain)(arg::AbstractArray, params::AbstractVector) + return call_chain(c, arg, params) +end +function call_chain( + c::SimpleChain, arg::AbstractArray{T0}, params::AbstractVector{T1} ) where {T0,T1} @@ -267,6 +271,9 @@ end arg::StaticArrays.SArray, params::AbstractVector ) + call_chain(c, arg, params) +end +function call_chain(c::SimpleChain, arg::StaticArrays.SArray, params::AbstractVector) verify_arg(c, arg) @unpack layers = c marg = StaticArrays.MArray(arg)