`gradient` broken for `(*)(::Diagonal{Real}, ::Matrix{Complex}, ::Diagonal{Real})` when updating Julia 1.8 -> 1.9 #1483
Comments
Do you mind testing the same function but with ForwardDiff instead of Zygote on 1.8/1.9? Zygote's broadcasting code isn't doing anything different between versions, so I wonder if changes in the stdlib or ForwardDiff are leading to a different path being taken depending on the version.
Using ForwardDiff works. FWIW, I quickly checked the code in `LinearAlgebra/src/diagonal.jl`, and it seems to have changed from 1.8 to 1.9.
The blame points to JuliaLang/julia#46400. I imagine this was hitting the 2-arg
I have the same thought process. Personally, I'd think we should write a new

Edit: Please let me know if there is anything I can do to speed up fixing this (the way you guys wish; I'm going to try writing a
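As background for the discussion of the new 3-arg `*` method: with `D1 = Diagonal(d1)` and `D2 = Diagonal(d2)`, the product `D1 * B * D2` only scales row `i` of `B` by `d1[i]` and column `j` by `d2[j]`, so it can be written as a single 3-argument broadcast. The stdlib-only sketch below checks that identity on hypothetical inputs (`d1`, `d2`, `B` are made up here); it is an illustration, not the code from JuliaLang/julia#46400.

```julia
using LinearAlgebra

# Hypothetical inputs: real diagonals on both sides, complex matrix in the middle,
# matching the argument types in the issue title.
d1 = [0.1, 0.2, 0.3]
d2 = [2.0, -1.0, 0.5]
B = [1.0+1.0im   2.0-1.0im   0.0+0.5im;
     -1.0+2.0im  0.5+0.0im   3.0-2.0im;
     0.0-1.0im   1.5+1.5im  -2.0+0.0im]

# Explicit parentheses force two 2-arg products instead of the 3-arg method.
pairwise = (Diagonal(d1) * B) * Diagonal(d2)

# One 3-arg broadcast: d1 (a column) scales rows, permutedims(d2) (a row) scales columns.
broadcasted = broadcast(*, d1, B, permutedims(d2))

println(pairwise ≈ broadcasted)
```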
Agree that JuliaLang/julia#46400 must be what's new; previously this would have gone on to pairwise

```julia
julia> gradient(x -> sum(abs2, x .* Ac .* x), [0.1, 0.2, 0.3])  # fine?
([0.008354213288778839, 0.04584256955208588, 0.27367634356707043],)

julia> gradient(x -> sum(abs2, broadcast(*, x, Ac, x)), [0.1, 0.2, 0.3])  # same error
ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{3, Float64}, ::ForwardDiff.Partials{6, Float64}, ::Float64, ::Float64)

julia> using ForwardDiff

julia> ForwardDiff.gradient(x -> sum(abs2, broadcast(*, x, Ac, x)), [0.1, 0.2, 0.3])  # fine, same result
3-element Vector{Float64}:
 0.008354213288778839
 0.04584256955208588
 0.27367634356707043

julia> gradient(x -> sum(abs2, broadcast(*, x, Ar, x)), [0.1, 0.2, 0.3])  # all real is fine, as above
([0.006549471072183613, 0.017877747316265458, 0.05009079087702957],)

julia> gradient(x -> sum(abs2, broadcast(*, x*im, Ac, x*im)), [0.1, 0.2, 0.3])  # all complex also fine?
([0.008354213288778839, 0.04584256955208588, 0.27367634356707043],)
```

I do think this probably points to Zygote's treatment of broadcasting with complex numbers. There are special rules for broadcasting

When there is no special broadcasting rule, the generic one tries to use Dual numbers before giving up and eventually saving all the Zygote scalar pullbacks. The upgrade to use `Complex{Dual}` was, I think, #1324, and it's possible that this mismatch of 6 + 3 partials comes from some bug in that. The code is reworked in #1441 (not merged yet), but it might be worth trying that to see if anything changes.

Attempting to trigger such a broadcasting bug, without

```julia
julia> Zygote.Numeric
Union{AbstractArray{<:T}, T} where T<:Number

julia> gradient(x -> sum(abs2, broadcast(+, x, Ac, x)), [0.1, 0.2, 0.3])  # 3-arg + is ok
([6.217191396508677, 11.013318686350626, 11.801770970288118],)

julia> gradient(x -> sum(abs2, broadcast((a,b,c) -> (a/b+c), x, Ac, x)), [0.1, 0.2, 0.3])
ERROR: Cannot determine ordering of Dual tags Nothing and Nothing
Stacktrace:
 [1] ≺(a::Type, b::Type)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:54
 [2] promote_rule(::Type{Dual{Nothing, Float64, 6}}, ::Type{Dual{Nothing, Float64, 3}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:407

julia> gradient(x -> sum(abs2, broadcast((a,b,c) -> (a/b+c), x, Ar, x)), [0.1, 0.2, 0.3])  # all real
([25.238201775657366, 49.06406347939432, 39.42294666177911],)

julia> gradient(x -> sum(abs2, broadcast((a,b,c) -> (a/b+c), x*im, Ac, x*im)), [0.1, 0.2, 0.3])  # all complex
([2.799189365814652, 15.391840124749505, 6.102148153348823],)
```

Aside from looking for bugs in Zygote's broadcasting, it would not be crazy to have a rule for this 3-arg
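To make the "generic broadcasting tries Dual numbers" idea concrete, here is a toy sketch of forward-mode broadcasting: each argument of a scalar function is seeded with a one-hot partials tuple, the function is evaluated once per element, and the per-argument derivatives are read back out. `ToyDual`, `seed` and `toy_broadcast_forward` are hypothetical names; this is neither Zygote's implementation nor `ForwardDiff.Dual`, and it handles only real inputs. It does show why every dual in one evaluation must carry the same number of partials: combining widths 3 and 6 has no method, which is the same shape of failure as the `Partials{3}` vs `Partials{6}` errors above.

```julia
# Toy forward-mode dual number carrying N partial derivatives (real values only).
struct ToyDual{N}
    val::Float64
    partials::NTuple{N,Float64}
end

# Sum and product rules; only defined when both operands have the same width N.
Base.:+(a::ToyDual{N}, b::ToyDual{N}) where {N} =
    ToyDual{N}(a.val + b.val, a.partials .+ b.partials)
Base.:*(a::ToyDual{N}, b::ToyDual{N}) where {N} =
    ToyDual{N}(a.val * b.val, a.val .* b.partials .+ b.val .* a.partials)

# Seed argument k (of N) with a one-hot partials tuple.
seed(x::Real, k, ::Val{N}) where {N} = ToyDual{N}(x, ntuple(i -> i == k ? 1.0 : 0.0, N))

# Evaluate f elementwise on dualized inputs; return values and d f / d (argument k).
function toy_broadcast_forward(f, xs::Vararg{AbstractArray,N}) where {N}
    duals = broadcast((args...) -> f(ntuple(k -> seed(args[k], k, Val(N)), N)...), xs...)
    vals = map(d -> d.val, duals)
    derivs = ntuple(k -> map(d -> d.partials[k], duals), N)
    return vals, derivs
end

vals, derivs = toy_broadcast_forward((a, b, c) -> a * b * c, [1.0, 2.0], [3.0, 4.0], [5.0, 6.0])
println(vals)       # [15.0, 48.0]
println(derivs[1])  # [15.0, 24.0], i.e. b .* c
```

A complex argument would need two partials per input (real and imaginary parts), which is one plausible way widths like 3 and 6 could end up meeting in a single call; that link to the actual bug is speculation, as in the comment above.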
FWIW, this temporary workaround fixed my issue. I'm sure it is not ideal (I'm a beginner with writing

```julia
using LinearAlgebra
using ChainRulesCore

# Call this instead of A * B * C so that Zygote picks up the rrule below.
_3arg_mul(A::Diagonal, B::AbstractMatrix{<:Complex}, C::Diagonal) = A * B * C

function ChainRulesCore.rrule(
    ::typeof(_3arg_mul), A::Diagonal, B::AbstractMatrix{<:Complex}, C::Diagonal
)
    # Projectors map cotangents back onto the primal types (e.g. a real Diagonal).
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    project_C = ProjectTo(C)
    function _3arg_mul_pullback(Ȳ)
        ȳ = unthunk(Ȳ)
        # Reverse-mode rule for Y = A * B * C.
        dA = ȳ * (B * C)'
        dB = A' * ȳ * C'
        dC = (A * B)' * ȳ
        return NoTangent(), project_A(dA), project_B(dB), project_C(dC)
    end
    return A * B * C, _3arg_mul_pullback
end
```
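One way to sanity-check the pullback formulas in the workaround without Zygote is to compare them with central finite differences on a real-valued loss. The sketch below assumes the loss `L = sum(abs2, A * B * C)`, for which the incoming cotangent is `ȳ = 2Y`, and emulates what `ProjectTo` does for a real `Diagonal` by keeping the real part of the diagonal. `loss`, `manual_grads` and `fd_grad` are hypothetical helpers for illustration only; it uses just the LinearAlgebra stdlib.

```julia
using LinearAlgebra

# Hypothetical test loss: L(a, B, c) = sum(abs2, Diagonal(a) * B * Diagonal(c)).
loss(a, B, c) = sum(abs2, Diagonal(a) * B * Diagonal(c))

# Pullback formulas from the workaround, with ȳ = 2Y for this loss; the real
# diagonals are recovered by keeping real.(diag(...)), mimicking the projection.
function manual_grads(a, B, c)
    A, C = Diagonal(a), Diagonal(c)
    ȳ = 2 .* (A * B * C)
    dA = ȳ * (B * C)'
    dB = A' * ȳ * C'
    dC = (A * B)' * ȳ
    return real.(diag(dA)), dB, real.(diag(dC))
end

# Central finite differences with respect to each entry of a real vector x.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for k in eachindex(x)
        e = zeros(length(x))
        e[k] = h
        g[k] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

a = [0.5, 1.5, 2.0]
c = [1.0, -0.3, 0.7]
B = [1.0+2.0im   0.5-1.0im  0.0+1.0im;
     -1.0+0.0im  2.0+0.5im  0.3-0.3im;
     0.2+0.1im  -0.4+1.0im  1.0+0.0im]

dA, dB, dC = manual_grads(a, B, c)
println(isapprox(dA, fd_grad(x -> loss(x, B, c), a); rtol = 1e-5))
println(isapprox(dC, fd_grad(x -> loss(a, B, x), c); rtol = 1e-5))
```

Only the diagonal cotangents are checked here; `dB` could be checked the same way by perturbing the real and imaginary parts of each entry of `B` separately.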
`gradient` breaks when triple-multiplying a `Diagonal{<:Real}`, a `Matrix{<:Complex}`, and a `Diagonal{Real}`. This breaks going from Julia 1.8 -> 1.9.

MWE:

Error message:

Versions:

Package versions: