-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Excise getindex adjoint #1328
Excise getindex adjoint #1328
Conversation
Aside from some cascading failures from Flux defining methods on a now-nonexistent Lines 163 to 164 in 81380f0
wrap_chainrules_input([nothing]) == [ZeroTangent()] , ∇getindex ends up calling _setindex_zero(x::Vector{Float64}, dy::ZeroTangent, inds::Vector{Int64}) which calls fill!(dest::Vector{ZeroTangent}, x::Bool) . I don't understand why CR is using eltype(y) for eltype(dx) , so summoning @mcabbott for help.
|
Maybe there should be a |
Do you mean something like JuliaDiff/ChainRules.jl#683? |
Thanks for pointing me to this @ToucheSir. I just rebased this on master (totally clean) and gave it a try and works great for the issue I was running into. I also note that _, back = Zygote._pullback(x->x[1]*im, randn(2))
@test back(1.0)[2] == real([-im, 0]) == [0, 0] now passes with (ChainRules v1.46.0) with no other modifications other than the rebase. Anything I can do to help this along? Currently I'm fairly motivated to have this working. |
We need to figure out how to make CI and downstream tests pass. This may or may not entail changes on the ChainRules side. |
Try locally with this perhaps: JuliaDiff/ChainRules.jl@7bb7d98 |
With JuliaDiff/ChainRules.jl#687 the getindex failures are:
|
src/lib/array.jl
Outdated
end | ||
Base.size(A::OneElement) = map(length, A.axes) | ||
Base.axes(A::OneElement) = A.axes | ||
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deleting this struct makes Flux fail.
It could just be left in deprecations. Or it could be hooked up more narrowly to CR's rule, something like this shouldn't upset 2nd derivative definitions:
ChainRules.∇getindex(x::Array{<:Number,N}, dy, inds::Vararg{Integer,N}) where N =
OneElement(dy, inds, axes(x))
But that's piracy...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't bother reverting the deletion into a deprecation because #1328 (comment) was still unresolved, but that seems like the way to go.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. To be clear I think returning one nothing
is an upgrade even though it makes those tests fail.
If someone approves JuliaDiff/ChainRules.jl#687 then perhaps CI can be made happy here.
d73f176
to
3e1847d
Compare
Dusting this one off since remaining failures in #1389 have remained quite stubborn. Downstream doesn't appear to have a problem with collapsing zeros, so I don't think it should be breaking? |
We have a better rule in Chainrules now
84b3079
to
6edfce9
Compare
@ToucheSir what is the status of this right now? |
I think I was waiting for an opinion on #1328 (comment). This PR means we'll be collapsing zeros more aggressively than Zygote currently does. I'm not sure if that counts as a breaking change or not. I do anticipate it happening more and more if/when we migrate more rules over to ChainRules and use CR zero types more internally. |
I think collapsing zeros more aggressively is fine |
Zygote's own tests pass now and the Flux failure is due to a broken hessian test passing. That just leaves Molly, which is failing because it overloads |
Thanks for keeping me in the loop here. I tried this PR and Molly with the The offending line in Molly is https://github.com/JuliaMolSim/Molly.jl/blob/master/src/interactions/implicit_solvent.jl#L642 and the stack trace is below. I can look into it more and open an issue but am posting it here in case anyone can quickly see the problem or the fix. Otherwise this PR seems good, I can remove the
|
Thanks for looking into this. It is possible to retroactively upper-bound the Zygote dep for older versions of Molly in the registry, but that would be more work and reduce future version compatibility quite a bit more than keeping the stub. I also don't know what the error you're running into could be, but we could always hold off on tagging until you've figured it out. |
I fixed the above issue so I think this PR should be merged and I can lower bound on the next Zygote release with the appropriate changes. |
Any chance this could be merged and tagged? |
Thanks! |
This finally kills a very long-standing fork I've been having to keep around for a project that needed this, thank you @ToucheSir! |
I'm not perfectly sure but this change might have caused a regression I reported here: JuliaML/MLUtils.jl#170. |
We have a better rule in Chainrules now.
Draft for now because I suspect downstream CI will have something to say about this. If it works though, would fix #820 and #1327.
PR Checklist
getindex
#944?)