Use ProjectTo in broadcasting & gradient #1044
Conversation
src/compiler/interface.jl
Outdated
@@ -73,7 +73,8 @@ julia> gradient([7, 11], 0, 1) do x, y, d
 """
 function gradient(f, args...)
   y, back = pullback(f, args...)
-  return back(sensitivity(y))
+  grad = back(sensitivity(y))
+  map(_project, args, grad)
Do we want this at the `gradient` or the `pullback` level?
My thinking was to start small! Applying it to `gradient` applies it to the user-facing calls, once. Applying it to `pullback` or `_pullback` inserts it into many more places internally... maybe it'll make `sin'''(1.0)` unhappy (see the note after the error output below).
One side-effect of this is that it makes this wrong answer into an error:

julia> gradient((x,y) -> sum(map(+,x,y)), [1,2], [3,4,5,6])  # before
([1, 1], [1, 1])

julia> gradient((x,y) -> sum(map(+,x,y)), [1,2], [3,4,5,6])  # after
ERROR: DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (2,)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Int64})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ySyqy/src/projection.jl:197
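A side note on the `sin'''(1.0)` example above, for readers following along: Zygote's `'` takes a derivative, so that expression nests differentiation three levels deep, exactly the kind of internal call a `pullback`-level projection would also hit. A quick sanity check, assuming stock Zygote:

julia> using Zygote

julia> sin'''(1.0)  # third derivative, computed by nested differentiation
-0.5403023058681398  # == -cos(1.0)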
That error seems awkward to me. Previously, the Julia behaviour of the function was the reason behind this gradient. Presumably the resulting gradient should be sized appropriately, not error.
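To spell out the reply above: Base's `map` over vectors stops at the shortest argument, like `zip` does, which is where the pre-PR gradient shape came from:

julia> map(+, [1, 2], [3, 4, 5, 6])  # extra elements of the longer vector are ignored
2-element Vector{Int64}:
 4
 6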
src/compiler/interface.jl
Outdated
@@ -95,11 +97,32 @@ true
 """
 function withgradient(f, args...)
   y, back = pullback(f, args...)
-  (val = y, grad = back(sensitivity(y)))
+  grad = back(sensitivity(y))
+  isnothing(grad) && return (val=y, grad=nothing)
Why is this check necessary?
You can't `map` over `nothing`.
@@ -45,18 +45,20 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
   Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
 end

 trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
I think doing this makes `unbroadcast` less generic; we don't need to define projections here AFAICT. Let's retain the current definition.
What case exactly is not handled, if this is less generic?
It restricts it to what can be handled by `_project`, as opposed to simple sizes and lengths of arrays.
Those are broadly the same now, as of recent changes. `_project` will never throw a MethodError now.
Note that before the CRC changes, `_project` had extra methods to handle other cases.
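For context, a paraphrased sketch (not the exact Zygote source; names and details approximate) of the size-and-length-based reduction that `trim` supports in `unbroadcast`: sum the incoming gradient over the dimensions that broadcasting expanded, then reshape it back to the argument's dimensionality.

# Paraphrased sketch of a Zygote-style unbroadcast.
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

function unbroadcast(x::AbstractArray, Δ)
    if size(x) == size(Δ)
        Δ                          # shapes already match
    elseif length(x) == length(Δ)
        trim(x, Δ)                 # same length, shape mismatch only
    else
        # sum over the dimensions broadcasting expanded (where size(x, i) == 1),
        # then cut the result back down to ndims(x)
        dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ) + 1, Val(ndims(Δ)))
        trim(x, sum(Δ; dims = dims))
    end
end

The disagreement above is whether `_project` now covers these same size/length cases, which the recent ChainRulesCore changes are said to make true.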
src/compiler/chainrules.jl
Outdated
@inline function _project(x::Union{Numeric, Ref{<:Numeric}}, dx)
  wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
end
_project(x::AbstractArray, dx) = dx isa AbstractArray ? reshape(dx, axes(x)) : dx
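For reference, the `Numeric` branch delegates to ChainRulesCore's `ProjectTo`, which standardizes a cotangent onto the type of the primal. A small illustration (exact printing may vary by version):

julia> using ChainRulesCore

julia> ProjectTo([1.0, 2.0])(ComplexF64[3 + 4im, 5 + 6im])  # real primal: imaginary parts dropped
2-element Vector{Float64}:
 3.0
 5.0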
This can be broken down into a different method.
Can you write exactly what method you prefer? There are obviously always other ways to write things.
 ```
 """
 function gradient(f, args...)
   y, back = pullback(f, args...)
-  return back(sensitivity(y))
+  grad = back(sensitivity(y))
+  isnothing(grad) ? nothing : map(_project, args, grad)
 end
You can add a method to `_project` and avoid this change.
> You can add a method to `_project` and avoid this change

Can you write exactly what method that would be?
Something like `_project(x, ::Nothing) = nothing`, maybe.
This is easy to try:

julia> _project(x, ::Nothing) = nothing
_project (generic function with 1 method)

julia> map(_project, (1,2,3), nothing)
ERROR: MethodError: no method matching length(::Nothing)
Is there a reason not to pull all of broadcasting down into ChainRules.jl?
Why don't we give Zygote.Forward more love? It's better for neural networks.
One reason not to is that Zygote's un-fused broadcast might not be the last word here. Maybe you can write a fused forward broadcast in Diffractor, which would be hopelessly slow here in Zygote. I think there's a lot of exploring left to be done, unlike the basic rules in ChainRules, where we can write a pretty close to optimal rule once and let everything use it. Anyway, this PR has much more modest goals: in the linked Flux issues it comes pretty close to entirely removing the penalty for mixing up your eltypes, and it fixes a lot of Zygote issues about real/complex.
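To illustrate the eltype point (hedged: the REPL output here is illustrative, not taken from the PR), with projection a stray Float64 constant no longer promotes the gradient of a Float32 array:

julia> gradient(x -> sum(x .* 2.0), Float32[1, 2])  # 2.0 is a Float64
(Float32[2.0, 2.0],)  # the gradient keeps the argument's eltype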
Mixing eltypes is going to get really important as low-precision work picks up the pace. We shouldn't have to write custom passes for every operation related to 16-bit floats. Besides, it's good not to be opinionated and to guide users to be type stable. Wouldn't we expect complex numbers to have gradients with complex types? Changing that seems like a bug.
Yes, there have been rumours of mixed-precision training for ages. I don't see any obvious problem, though: it does not involve randomly mixing types and hoping that Julia's promotion will figure it out. Complex/real has been discussed at great length; this PR really isn't the place to argue it. If you think it's wrong you should open an issue on ChainRulesCore and make your case.
Err, they do? There would be a lot of broken tests if that were altered. I think you may have misunderstood what problem this projection solves. The first message has examples, and links to issues closed.
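A hedged example of what the projection actually does here (illustrative output): complex inputs keep complex gradients; it is real inputs whose complex cotangents get projected back to the reals.

julia> gradient(x -> abs2(x + im), 1.0)  # abs2(x + im) == x^2 + 1 for real x
(2.0,)  # real input, real gradient, not 2.0 + 2.0im

julia> gradient(abs2, 1.0 + im)  # complex input keeps a complex gradient
(2.0 + 2.0im,)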
Co-authored-by: Lyndon White <[email protected]>
I did respond on the slack where I'd mentioned wanting to take a look at it today.
This starts building ChainRulesCore's type projection story into how Zygote handles broadcasting, and into its user-facing functions. Projection is already called in some rules handled by ChainRules, but this applies it a bit more broadly.
After:
Before:
Replaces #965, or most of it.
~~Many tests will fail, including most of the FFT tests I think, since those tend to return a complex gradient for a real input.~~ FFT tests are unchanged.

Closes #342, closes #402. Fixes #917, fixes #431.
Closes FluxML/Flux.jl#886