Skip to content

Commit

Permalink
LinearAlgebra: diagzero for non-OneTo axes (#55252)
Browse files Browse the repository at this point in the history
Currently, the off-diagonal zeros for a block-`Diagonal` matrix is
computed using `diagzero`, which calls `zeros` for the sizes of the
elements. This returns an `Array`, unless one specializes `diagzero` for
the custom `Diagonal` matrix type.

This PR defines a `zeroslike` function that dispatches on the axes of
the elements, which lets packages specialize on the axes to return
custom `AbstractArray`s. Choosing to specialize on the `eltype` avoids
the need to specialize on the container, and allows packages to return
appropriate types for custom axis types.

With this,
```julia
julia> LinearAlgebra.zeroslike(::Type{S}, ax::Tuple{SOneTo, Vararg{SOneTo}}) where {S<:SMatrix} = SMatrix{map(length, ax)...}(ntuple(_->zero(eltype(S)), prod(length, ax)))

julia> D = Diagonal(fill(SMatrix{2,3}(1:6), 2))
2×2 Diagonal{SMatrix{2, 3, Int64, 6}, Vector{SMatrix{2, 3, Int64, 6}}}:
 [1 3 5; 2 4 6]        ⋅       
       ⋅         [1 3 5; 2 4 6]

julia> D[1,2] # now an SMatrix
2×3 SMatrix{2, 3, Int64, 6} with indices SOneTo(2)×SOneTo(3):
 0  0  0
 0  0  0

julia> LinearAlgebra.zeroslike(::Type{S}, ax::Tuple{SOneTo, Vararg{SOneTo}}) where {S<:MMatrix} = MMatrix{map(length, ax)...}(ntuple(_->zero(eltype(S)), prod(length, ax)))

julia> D = Diagonal(fill(MMatrix{2,3}(1:6), 2))
2×2 Diagonal{MMatrix{2, 3, Int64, 6}, Vector{MMatrix{2, 3, Int64, 6}}}:
 [1 3 5; 2 4 6]        ⋅       
       ⋅         [1 3 5; 2 4 6]

julia> D[1,2] # now an MMatrix
2×3 MMatrix{2, 3, Int64, 6} with indices SOneTo(2)×SOneTo(3):
 0  0  0
 0  0  0
```
The reason this can't be the default behavior is that we are not
guaranteed that there exists a `similar` method that accepts the
combination of axes. This is why we have to fall back to using the
sizes, unless a specialized method is provided by a package.

One positive outcome of this is that indexing into such a block-diagonal
matrix will now usually be type-stable, which mitigates
https://github.com/JuliaLang/julia/issues/45535 to some extent (although
it doesn't resolve the issue).

I've also updated the `getindex` for `Bidiagonal` to use `diagzero`,
instead of the similarly defined `bidiagzero` function that it was
using. Structured block matrices may now use `diagzero` uniformly to
generate the zero elements.
  • Loading branch information
jishnub authored Oct 9, 2024
1 parent 9c55783 commit 91da4bf
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 10 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ Standard library changes
(callable via `cholesky[!](A, RowMaximum())`) ([#54619]).
* The number of default BLAS threads now respects process affinity, instead of
using total number of logical threads available on the system ([#55574]).
* A new function `zeroslike` is added that is used to generate the zero elements for matrix-valued banded matrices.
Custom array types may specialize this function to return an appropriate result. ([#55252])

#### Logging

Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public AbstractTriangular,
peakflops,
symmetric,
symmetric_type,
zeroslike,
matprod_dest

const BlasFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
Expand Down
15 changes: 7 additions & 8 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,14 @@ Bidiagonal(A::Bidiagonal) = A
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo)

bidiagzero(::Bidiagonal{T}, i, j) where {T} = zero(T)
function bidiagzero(A::Bidiagonal{<:AbstractMatrix}, i, j)
Tel = eltype(eltype(A.dv))
function diagzero(A::Bidiagonal{<:AbstractMatrix}, i, j)
Tel = eltype(A)
if i < j && A.uplo == 'U' #= top right zeros =#
return zeros(Tel, size(A.ev[i], 1), size(A.ev[j-1], 2))
return zeroslike(Tel, axes(A.ev[i], 1), axes(A.ev[j-1], 2))
elseif j < i && A.uplo == 'L' #= bottom left zeros =#
return zeros(Tel, size(A.ev[i-1], 1), size(A.ev[j], 2))
return zeroslike(Tel, axes(A.ev[i-1], 1), axes(A.ev[j], 2))
else
return zeros(Tel, size(A.dv[i], 1), size(A.dv[j], 2))
return zeroslike(Tel, axes(A.dv[i], 1), axes(A.dv[j], 2))
end
end

Expand Down Expand Up @@ -161,7 +160,7 @@ end
elseif i == j - _offdiagind(A.uplo)
return @inbounds A.ev[A.uplo == 'U' ? i : j]
else
return bidiagzero(A, i, j)
return diagzero(A, i, j)
end
end

Expand All @@ -173,7 +172,7 @@ end
# we explicitly compare the possible bands as b.band may be constant-propagated
return @inbounds A.ev[b.index]
else
return bidiagzero(A, Tuple(_cartinds(b))...)
return diagzero(A, Tuple(_cartinds(b))...)
end
end

Expand Down
23 changes: 21 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,27 @@ end
end
r
end
diagzero(::Diagonal{T}, i, j) where {T} = zero(T)
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = zeros(T, size(D.diag[i], 1), size(D.diag[j], 2))
"""
diagzero(A::AbstractMatrix, i, j)
Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`.
"""
diagzero(A::AbstractMatrix, i, j) = zero(eltype(A))
diagzero(D::Diagonal{M}, i, j) where {M<:AbstractMatrix} =
zeroslike(M, axes(D.diag[i], 1), axes(D.diag[j], 2))
# dispatching on the axes permits specializing on the axis types to return something other than an Array
zeroslike(M::Type, ax::Vararg{Union{AbstractUnitRange, Integer}}) = zeroslike(M, ax)
"""
zeroslike(::Type{M}, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) where {M<:AbstractMatrix}
zeroslike(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix}
Return an appropriate zero-ed array similar to `M`, with either the axes `ax` or the size `sz`.
This will be used as a structural zero element of a matrix-valued banded matrix.
By default, `zeroslike` falls back to using the size along each axis to construct the array.
"""
zeroslike(M::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = zeroslike(M, map(length, ax))
zeroslike(M::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(M, sz)
zeroslike(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix} = zeros(eltype(M), sz)

@inline function getindex(D::Diagonal, b::BandIndex)
@boundscheck checkbounds(D, b)
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,16 @@ end
B = Bidiagonal(dv, ev, :U)
@test B == Matrix{eltype(B)}(B)
end

@testset "non-standard axes" begin
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
zeros(T, ax)

s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
B = Bidiagonal(fill(s,4), fill(s,3), :U)
@test @inferred(B[2,1]) isa typeof(s)
@test all(iszero, B[2,1])
end
end

@testset "copyto!" begin
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,13 @@ end
D = Diagonal(fill(S,3))
@test D * fill(S,2,3)' == fill(S * S', 3, 2)
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)

@testset "indexing with non-standard-axes" begin
s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
D = Diagonal(fill(s,3))
@test @inferred(D[1,2]) isa typeof(s)
@test all(iszero, D[1,2])
end
end

@testset "Eigensystem for block diagonal (issue #30681)" begin
Expand Down
3 changes: 3 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@ mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Num
mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) =
mul!(dest, M, _data(v), α, β)

LinearAlgebra.zeroslike(::Type{S}, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) where {S<:SizedArray} =
zeros(eltype(S), ax)

end

0 comments on commit 91da4bf

Please sign in to comment.