Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rule for Dict iteration #1285

Merged
merged 1 commit into from
Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,45 @@ end
end
end

# This rule behaves much like the getindex adjoint,
# just with an (internal) ordinal index instead of a key.
function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i)
iter = iterate(d, i)
function dict_iterate_pullback(Δ)
(iter === nothing || Δ === nothing) && return
k, v = iter[1]
_, dv = Δ[1]
accum_param(cx, v, dv) === nothing && return
grad = grad_mut(cx, d)
grad[k] = accum(get(grad, k, nothing), dv)
return (nothing, grad, nothing)
end
return iter, dict_iterate_pullback
end

# ...while this one is to avoid duplicating code or differentiating skip_deleted.
# The alternative would be to write a rule for the private _iterate(::Dict, i).
function _pullback(cx::AContext, ::typeof(iterate), d::Dict)
# Calculation of i is the same used in iterate(::Dict)
return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor))
end

function _pullback(cx::AContext, ::typeof(iterate), vi::Base.ValueIterator{<:Dict}, i::Int)
iter = iterate(vi, i)
function values_iterate_pullback(Δ)
(iter === nothing || Δ === nothing) && return
v, dv = iter[1], Δ[1]
accum_param(cx, v, dv) === nothing && return
# Same as vi.dict.keys[i], but without reaching into Dict internals.
# Iterating the dict instead of keys() is to hit the rules above in nested AD.
k = iterate(vi.dict, i)[1][1]
grad = grad_mut(cx, vi.dict)
grad[k] = accum(get(grad, k, nothing), dv)
return (nothing, (; dict = grad), nothing)
end
return iter, values_iterate_pullback
end

# Channels

grad_mut(ch::Channel) = Channel(ch.sz_max)
Expand Down
32 changes: 32 additions & 0 deletions test/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,36 @@

@test result1 == result2
end

@testset "Dict iteration" begin
# https://github.com/FluxML/Zygote.jl/issues/1065
function sumkv(d)
s = zero(d["c"])
for (k, v) in d
s += v
k == :b && (s += v)
end
return sum(s)
end

function sumvals(d)
s = zero(d["c"])
for v in values(d)
s += v
end
return sum(s)
end

d_num = Dict(:a => 3, :b => 4, "c" => 5)
d_arr = Dict(:a => [3], :b => [4], "c" => [5])
ps = d_arr |> values |> collect |> Params

@test gradient(sumkv, d_num)[1] == Dict(:a => 1, :b => 2, "c" => 1)
grads = gradient(() -> sumkv(d_arr), ps)
@test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [2], [1])

@test gradient(sumvals, d_num)[1] == Dict(:a => 1, :b => 1, "c" => 1)
grads = gradient(() -> sumvals(d_arr), ps)
@test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [1], [1])
end
end