Skip to content

Commit

Permalink
Accept more general Integer sizes in reshape (#55521)
Browse files Browse the repository at this point in the history
This PR generalizes the `reshape` methods to accept `Integer`s instead
of `Int`s, and adds a `_reshape_uncolon` method for `Integer` arguments.
The current `_reshape_uncolon` method that accepts `Int`s is left
unchanged to ensure that the inferred types are not impacted. I've also
tried to ensure that most `Integer` subtypes in `Base` that may be
safely converted to `Int`s pass through that method.

The call sequence would now go like this:
```julia
reshape(A, ::Tuple{Vararg{Union{Integer, Colon}}}) -> reshape(A, ::Tuple{Vararg{Integer}}) -> reshape(A, ::Tuple{Vararg{Int}}) (fallback)
```
This lets packages define `reshape(A::CustomArray, ::Tuple{Integer,
Vararg{Integer}})` without having to implement `_reshape_uncolon` by
themselves (or having to call internal `Base` functions, as in
JuliaArrays/FillArrays.jl#373). `reshape`
calls involving a `Colon` would convert this to an `Integer` in `Base`,
and then pass the `Integer` sizes to the custom method defined in the
package.

This PR does not resolve issues like
#40076 because this still
converts `Integer`s to `Int`s in the actual reshaping step. However,
`BigInt` sizes that may be converted to `Int`s will work now:
```julia
julia> reshape(1:4, big(2), big(2))
2×2 reshape(::UnitRange{Int64}, 2, 2) with eltype Int64:
 1  3
 2  4

julia> reshape(1:4, big(1), :)
1×4 reshape(::UnitRange{Int64}, 1, 4) with eltype Int64:
 1  2  3  4
```

Note that the reshape method with `Integer` sizes explicitly converts
these to `Int`s to avoid self-recursion (as opposed to calling
`to_shape` to carry out the conversion implicitly). In the future, we
may want to decide what to do with types or values that can't be
converted to an `Int`.

---------

Co-authored-by: Neven Sajko <[email protected]>
  • Loading branch information
jishnub and nsajko authored Dec 5, 2024
1 parent e572d23 commit 5835c3b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 20 deletions.
57 changes: 38 additions & 19 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,37 +121,56 @@ reshape

reshape(parent::AbstractArray, dims::IntOrInd...) = reshape(parent, dims)
reshape(parent::AbstractArray, shp::Tuple{Union{Integer,OneTo}, Vararg{Union{Integer,OneTo}}}) = reshape(parent, to_shape(shp))
reshape(parent::AbstractArray, dims::Tuple{Integer, Vararg{Integer}}) = reshape(parent, map(Int, dims))
reshape(parent::AbstractArray, dims::Dims) = _reshape(parent, dims)

# Allow missing dimensions with Colon():
reshape(parent::AbstractVector, ::Colon) = parent
reshape(parent::AbstractVector, ::Tuple{Colon}) = parent
reshape(parent::AbstractArray, dims::Int...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Union{Int,Colon}...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = reshape(parent, _reshape_uncolon(parent, dims))
@inline function _reshape_uncolon(A, dims)
@noinline throw1(dims) = throw(DimensionMismatch(string("new dimensions $(dims) ",
"may have at most one omitted dimension specified by `Colon()`")))
@noinline throw2(A, dims) = throw(DimensionMismatch(string("array size $(length(A)) ",
"must be divisible by the product of the new dimensions $dims")))
pre = _before_colon(dims...)::Tuple{Vararg{Int}}
reshape(parent::AbstractArray, dims::Integer...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Union{Integer,Colon}...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Integer,Colon}}}) = reshape(parent, _reshape_uncolon(parent, dims))

@noinline throw1(dims) = throw(DimensionMismatch(LazyString("new dimensions ", dims,
" may have at most one omitted dimension specified by `Colon()`")))
@noinline throw2(lenA, dims) = throw(DimensionMismatch(string("array size ", lenA,
" must be divisible by the product of the new dimensions ", dims)))

@inline function _reshape_uncolon(A, _dims::Tuple{Vararg{Union{Integer, Colon}}})
# promote the dims to `Int` at least
dims = map(x -> x isa Colon ? x : promote_type(typeof(x), Int)(x), _dims)
pre = _before_colon(dims...)
post = _after_colon(dims...)
_any_colon(post...) && throw1(dims)
post::Tuple{Vararg{Int}}
len = length(A)
sz, is_exact = if iszero(len)
(0, true)
_reshape_uncolon_computesize(len, dims, pre, post)
end
@inline function _reshape_uncolon_computesize(len::Int, dims, pre::Tuple{Vararg{Int}}, post::Tuple{Vararg{Int}})
sz = if iszero(len)
0
else
let pr = Core.checked_dims(pre..., post...) # safe product
if iszero(pr)
throw2(A, dims)
end
(quo, rem) = divrem(len, pr)
(Int(quo), iszero(rem))
quo = _reshape_uncolon_computesize_nonempty(len, dims, pr)
convert(Int, quo)
end
end::Tuple{Int,Bool}
is_exact || throw2(A, dims)
(pre..., sz, post...)::Tuple{Int,Vararg{Int}}
end
(pre..., sz, post...)
end
@inline function _reshape_uncolon_computesize(len, dims, pre, post)
pr = prod((pre..., post...))
sz = if iszero(len)
promote(len, pr)[1] # zero of the correct type
else
_reshape_uncolon_computesize_nonempty(len, dims, pr)
end
(pre..., sz, post...)
end
@inline function _reshape_uncolon_computesize_nonempty(len, dims, pr)
iszero(pr) && throw2(len, dims)
(quo, rem) = divrem(len, pr)
iszero(rem) || throw2(len, dims)
quo
end
@inline _any_colon() = false
@inline _any_colon(dim::Colon, tail...) = true
Expand Down
20 changes: 20 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2186,3 +2186,23 @@ end
copyto!(A, 1, x, 1)
@test A == axes(A,1)
end

@testset "reshape with Integer sizes" begin
@test reshape(1:4, big(2), big(2)) == reshape(1:4, 2, 2)
a = [1 2 3; 4 5 6]
reshaped_arrays = (
reshape(a, 3, 2),
reshape(a, (3, 2)),
reshape(a, big(3), big(2)),
reshape(a, (big(3), big(2))),
reshape(a, :, big(2)),
reshape(a, (:, big(2))),
reshape(a, big(3), :),
reshape(a, (big(3), :)),
)
@test allequal(reshaped_arrays)
for b reshaped_arrays
@test b isa Matrix{Int}
@test b.ref === a.ref
end
end
11 changes: 11 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -914,3 +914,14 @@ end
b = sum(a, dims=1)
@test b[begin] == sum(r)
end

@testset "reshape" begin
A0 = [1 3; 2 4]
A = reshape(A0, 2:3, 4:5)
@test axes(A) == Base.IdentityUnitRange.((2:3, 4:5))

B = reshape(A0, -10:-9, 9:10)
@test isa(B, OffsetArray{Int,2})
@test parent(B) == A0
@test axes(B) == Base.IdentityUnitRange.((-10:-9, 9:10))
end
4 changes: 3 additions & 1 deletion test/testhelpers/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ _similar_axes_or_length(AT, ax::I, ::I) where {I} = similar(AT, map(_indexlength

# reshape accepts a single colon
Base.reshape(A::AbstractArray, inds::OffsetAxis...) = reshape(A, inds)
function Base.reshape(A::AbstractArray, inds::Tuple{OffsetAxis,Vararg{OffsetAxis}})
function Base.reshape(A::AbstractArray, inds::Tuple{Vararg{OffsetAxis}})
AR = reshape(no_offset_view(A), map(_indexlength, inds))
O = OffsetArray(AR, map(_offset, axes(AR), inds))
return _popreshape(O, axes(AR), _filterreshapeinds(inds))
Expand Down Expand Up @@ -557,6 +557,8 @@ Base.reshape(A::OffsetArray, inds::Tuple{OffsetAxis,Vararg{OffsetAxis}}) =
OffsetArray(_reshape(parent(A), inds), map(_toaxis, inds))
# And for non-offset axes, we can just return a reshape of the parent directly
Base.reshape(A::OffsetArray, inds::Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}) = _reshape_nov(A, inds)
Base.reshape(A::OffsetArray, inds::Tuple{Integer,Vararg{Integer}}) = _reshape_nov(A, inds)
Base.reshape(A::OffsetArray, inds::Tuple{Union{Colon, Integer}, Vararg{Union{Colon, Integer}}}) = _reshape_nov(A, inds)
Base.reshape(A::OffsetArray, inds::Dims) = _reshape_nov(A, inds)
Base.reshape(A::OffsetVector, ::Colon) = A
Base.reshape(A::OffsetVector, ::Tuple{Colon}) = A
Expand Down

0 comments on commit 5835c3b

Please sign in to comment.