From 2d1ea3c59b1bb06a774ac566fa0ae9168b29ccb0 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 17 Feb 2022 00:02:23 +0800 Subject: [PATCH] Fix `stride(A, i)` for 0-dim inputs (#44090) Fixes #44087 --- base/abstractarray.jl | 8 +++++++- base/reinterpretarray.jl | 2 ++ test/abstractarray.jl | 13 +++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 0f900aeef5f45..50b83dff86e6b 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -546,7 +546,13 @@ julia> stride(A,3) function stride(A::AbstractArray, k::Integer) st = strides(A) k ≤ ndims(A) && return st[k] - return sum(st .* size(A)) + ndims(A) == 0 && return 1 + sz = size(A) + s = st[1] * sz[1] + for i in 2:ndims(A) + s += st[i] * sz[i] + end + return s end @inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...) diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index c33c2027839ea..3b54ed04089cd 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -149,6 +149,8 @@ StridedMatrix{T} = StridedArray{T,2} StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}} strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...) +stride(A::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, k::Integer) = + k ≤ ndims(A) ? strides(A)[k] : length(A) function strides(a::ReshapedReinterpretArray) ap = parent(a) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 84d69200368e8..060f1ffa8b8cb 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1584,6 +1584,19 @@ end end end +@testset "stride for 0 dims array #44087" begin + struct Fill44087 <: AbstractArray{Int,0} + a::Int + end + # `stride` shouldn't work if `strides` is not defined. + @test_throws MethodError stride(Fill44087(1), 1) + # It is intentionally to only check the return type. (The value is somehow arbitrary) + @test stride(fill(1), 1) isa Int + @test stride(reinterpret(Float64, fill(Int64(1))), 1) isa Int + @test stride(reinterpret(reshape, Float64, fill(Int64(1))), 1) isa Int + @test stride(Base.ReshapedArray(fill(1), (), ()), 1) isa Int +end + @testset "to_indices inference (issue #42001 #44059)" begin @test (@inferred to_indices([], ntuple(Returns(CartesianIndex(1)), 32))) == ntuple(Returns(1), 32) @test (@inferred to_indices([], ntuple(Returns(CartesianIndices(1:1)), 32))) == ntuple(Returns(Base.OneTo(1)), 32)