Skip to content

Commit

Permalink
Prevent type=inferability escaping for rrule of sortslices
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jan 1, 2025
1 parent e055009 commit 1696cee
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ end

function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
function sortslices_pullback(dy)
return (NoTangent(), ∇getindex(x, unthunk(dy), inds...))
# avoid closing over `inds` as it doesn't fully infer and that makes it worse
# recomputing is cheap
inds_inner = ntuple(d -> d == dims ? p : (:), ndims(x))
return (NoTangent(), ∇getindex(x, unthunk(dy), inds_inner...))
end
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
return x[inds...], sortslices_pullback
end

Expand Down
2 changes: 1 addition & 1 deletion test/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum))

@test_throws Exception sortslices(Diagonal(1:3), dims=1)
end
Expand Down

0 comments on commit 1696cee

Please sign in to comment.