Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use broadcasting rules from ChainRules #89

Merged
merged 6 commits into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
ChainRules = "1.5"
ChainRulesCore = "1.2"
ChainRules = "1.44.6"
ChainRulesCore = "1.15.3"
Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
Expand Down
39 changes: 19 additions & 20 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
end

function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...)
xs[i...], ∇getindex(xs, i)
end

Expand Down Expand Up @@ -150,12 +150,6 @@ end

ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()

# Skip AD'ing through the axis computation
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
return Base.Broadcast.instantiate(bc), Δ->begin
Core.tuple(NoTangent(), Δ)
end
end


using StaticArrays
Expand Down Expand Up @@ -187,9 +181,10 @@ end

@ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T)

function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
getindex(A, args...), getindex(∂A, args...)
end
# function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
# getindex(A, args...), getindex(∂A, args...)
# end
# WARNING: Method definition frule(Any, typeof(Base.getindex), AbstractArray{T, N} where N where T, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/KVV0e/src/rulesets/Base/indexing.jl:59 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:184

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
Expand Down Expand Up @@ -226,26 +221,31 @@ struct BackMap{T}
end
(f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...)
back_apply(x, y) = x(y)
back_apply_zero(x) = x(Zero())
back_apply_zero(x) = x(Zero()) # Zero is not defined

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
a, b = unzip_tuple(map(BackMap(f), args))
function back(Δ)
function map_back(Δ)
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
(NoTangent(), sum(fs), xs)
end
function back(Δ::ZeroTangent)
(fs, xs) = unzip_tuple(map(back_apply_zero, b))
(NoTangent(), sum(fs), xs)
end
a, back
map_back(Δ::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
# function back(Δ::ZeroTangent)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
# (fs, xs) = unzip_tuple(map(back_apply_zero, b))
# (NoTangent(), sum(fs), xs)
# end
a, map_back
end

ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent())

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
a, b = unzip_tuple(ntuple(BackMap(f), n))
a, function (Δ)
function ntuple_back(Δ)
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
end
ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
a, ntuple_back
end

function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
Expand All @@ -267,5 +267,4 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
val, Δ->(NoTangent(), NoTangent(), Δ)
end

Base.real(z::ZeroTangent) = z # TODO should be in CRC
Base.real(z::NoTangent) = z
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
1 change: 1 addition & 0 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing

_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()
43 changes: 0 additions & 43 deletions src/stage1/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,3 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
end
return r
end

# Broadcast over one element is just map
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
∂⃖ₙ(map, f, a)
end

# The below is from Zygote: TODO: DO we want to do something better here?

accum_sum(xs::Nothing; dims = :) = NoTangent()
accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::Number; dims = :) = xs

# https://github.com/FluxML/Zygote.jl/issues/594
function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ? x̄ :
length(x) == length(x̄) ? trim(x, x̄) :
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))

unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)

unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()

const Numeric = Union{Number, AbstractArray{<:Number, N} where N}

function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::Numeric...)
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
end

ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end

ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
5 changes: 3 additions & 2 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
end

# TODO: Temporary - make better
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) where {N}
getindex(a, inds...), let
EvenOddOdd{1, c_order(N)}(
(@Base.constprop :aggressive Δ->begin
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
BB = zero(a)
BB[inds...] = Δ
BB[inds...] = unthunk(Δ)
(NoTangent(), BB, map(x->NoTangent(), inds)...)
end),
(@Base.constprop :aggressive (_, Δ, _)->begin
Expand All @@ -334,6 +334,7 @@ struct tuple_back{M}; end
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
(::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...)
(t::tuple_back)(Δ::AbstractThunk) = t(unthunk(Δ))

function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
Core.tuple(args...),
Expand Down
68 changes: 65 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
@test @inferred(sin'(1.0)) == cos(1.0)
@test @inferred(sin''(1.0)) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)
@test sin''''(1.0) == sin(1.0) broken = VERSION >= v"1.8"
@test sin'''''(1.0) == cos(1.0) broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) broken = VERSION >= v"1.8"
@test sin''''(1.0) == sin(1.0) # broken = VERSION >= v"1.8"
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
@test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8"

f_getfield(x) = getfield((x,), 1)
@test f_getfield'(1) == 1
Expand Down Expand Up @@ -219,6 +219,68 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
@test z45 ≈ 2.0
@test delta45 ≈ 1.0

# PR #82 - getindex on non-numeric arrays
@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}

@testset "broadcast" begin
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)

@test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
exp_log(x) = exp(log(x))
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure

@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
@test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path

@test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],)
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule

@test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
@test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
@test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)

@test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent())
@test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input
@test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero
@test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero

tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5]))
@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
@test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
@test tup_adj[2] isa Transpose
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal

@test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
end

@testset "broadcast, 2nd order" begin
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order

@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]

@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing)
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0]
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]

@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
end

# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
#include("pinn.jl")

Expand Down