Skip to content

Commit

Permalink
Unthunk each element in ∇eachslice
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Dec 26, 2024
1 parent e055009 commit 6e23e61
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,20 +262,22 @@ end
# Using Val(dim) here is worth a factor of 2 in this, on Julia 1.8-
# @btime rrule(eachcol, $([1 2; 3 4]))[2]($([[10, 20], [30, 40]]))
function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
dys = unthunk(dys_raw)
dys = unthunk.(unthunk(dys_raw))
i1 = findfirst(dy -> dy isa AbstractArray, dys)
if i1 === nothing # all slices are Zero!
return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
end

T = Base.promote_eltype(dys...)
# The whole point of this gradient is that we can allocate one `dx` array:
dx = similar(x, T, axes(x))
for i in axes(x, dim)
slice = selectdim(dx, dim, i)
if dys[i] isa AbstractZero
dy = dys[i]
if dy isa AbstractZero
_zero_fill!(slice) # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
else
copyto!(slice, dys[i])
copyto!(slice, dy)
end
end
return ProjectTo(x)(dx)
Expand Down

0 comments on commit 6e23e61

Please sign in to comment.