Skip to content

Commit

Permalink
LinearAlgebra: Don't assume AbstractVector supports deleteat! (#46672)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschnett authored Sep 12, 2022
1 parent 81eb6ef commit 9e8fb63
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,10 @@ _evview(S::SymTridiagonal) = @view S.ev[begin:begin + length(S.dv) - 2]
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))

# convert to Vector, if necessary
_makevector(x::Vector) = x
_makevector(x::AbstractVector) = Vector(x)

# append a zero element / drop the last element
_pushzero(A) = (B = similar(A, length(A)+1); @inbounds B[begin:end-1] .= A; @inbounds B[end] = zero(eltype(B)); B)
_droplast!(A) = deleteat!(A, lastindex(A))
Expand Down
5 changes: 3 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), simil
similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...)

function kron(A::Diagonal, B::Bidiagonal)
kdv = kron(diag(A), B.dv)
kev = _droplast!(kron(diag(A), _pushzero(B.ev)))
# `_droplast!` is only guaranteed to work with `Vector`
kdv = _makevector(kron(diag(A), B.dv))
kev = _droplast!(_makevector(kron(diag(A), _pushzero(B.ev))))
Bidiagonal(kdv, kev, B.uplo)
end

Expand Down
7 changes: 4 additions & 3 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,10 @@ function kron(A::Diagonal, B::SymTridiagonal)
SymTridiagonal(kdv, kev)
end
function kron(A::Diagonal, B::Tridiagonal)
kd = kron(diag(A), B.d)
kdl = _droplast!(kron(diag(A), _pushzero(B.dl)))
kdu = _droplast!(kron(diag(A), _pushzero(B.du)))
# `_droplast!` is only guaranteed to work with `Vector`
kd = _makevector(kron(diag(A), B.d))
kdl = _droplast!(_makevector(kron(diag(A), _pushzero(B.dl))))
kdu = _droplast!(_makevector(kron(diag(A), _pushzero(B.du))))
Tridiagonal(kdl, kd, kdu)
end

Expand Down
37 changes: 30 additions & 7 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,36 @@ end
@test kron(Ad, Ad).diag == kron([1, 2, 3], [1, 2, 3])
end

@testset "kron (issue #46456)" begin
A = Diagonal(randn(10))
BL = Bidiagonal(randn(10), randn(9), :L)
BU = Bidiagonal(randn(10), randn(9), :U)
C = SymTridiagonal(randn(10), randn(9))
Cl = SymTridiagonal(randn(10), randn(10))
D = Tridiagonal(randn(9), randn(10), randn(9))
# Define a vector type that does not support `deleteat!`, to ensure that `kron` handles this
struct SimpleVector{T} <: AbstractVector{T}
vec::Vector{T}
end
SimpleVector(x::SimpleVector) = SimpleVector(Vector(x.vec))
SimpleVector{T}(::UndefInitializer, n::Integer) where {T} = SimpleVector(Vector{T}(undef, n))
Base.:(==)(x::SimpleVector, y::SimpleVector) = x == y
Base.axes(x::SimpleVector) = axes(x.vec)
Base.convert(::Type{Vector{T}}, x::SimpleVector) where {T} = convert(Vector{T}, x.vec)
Base.convert(::Type{Vector}, x::SimpleVector{T}) where {T} = convert(Vector{T}, x)
Base.convert(::Type{Array{T}}, x::SimpleVector) where {T} = convert(Vector{T}, x)
Base.convert(::Type{Array}, x::SimpleVector) = convert(Vector, x)
Base.copyto!(x::SimpleVector, y::SimpleVector) = (copyto!(x.vec, y.vec); x)
Base.eltype(::Type{SimpleVector{T}}) where {T} = T
Base.getindex(x::SimpleVector, ind...) = getindex(x.vec, ind...)
Base.kron(x::SimpleVector, y::SimpleVector) = SimpleVector(kron(x.vec, y.vec))
Base.promote_rule(::Type{<:AbstractVector{T}}, ::Type{SimpleVector{U}}) where {T,U} = Vector{promote_type(T, U)}
Base.promote_rule(::Type{SimpleVector{T}}, ::Type{SimpleVector{U}}) where {T,U} = SimpleVector{promote_type(T, U)}
Base.setindex!(x::SimpleVector, val, ind...) = (setindex!(x.vec, val, ind...), x)
Base.similar(x::SimpleVector, ::Type{T}) where {T} = SimpleVector(similar(x.vec, T))
Base.similar(x::SimpleVector, ::Type{T}, dims::Dims{1}) where {T} = SimpleVector(similar(x.vec, T, dims))
Base.size(x::SimpleVector) = size(x.vec)

@testset "kron (issue #46456)" for repr in Any[identity, SimpleVector]
A = Diagonal(repr(randn(10)))
BL = Bidiagonal(repr(randn(10)), repr(randn(9)), :L)
BU = Bidiagonal(repr(randn(10)), repr(randn(9)), :U)
C = SymTridiagonal(repr(randn(10)), repr(randn(9)))
Cl = SymTridiagonal(repr(randn(10)), repr(randn(10)))
D = Tridiagonal(repr(randn(9)), repr(randn(10)), repr(randn(9)))
@test kron(A, BL)::Bidiagonal == kron(Array(A), Array(BL))
@test kron(A, BU)::Bidiagonal == kron(Array(A), Array(BU))
@test kron(A, C)::SymTridiagonal == kron(Array(A), Array(C))
Expand Down

0 comments on commit 9e8fb63

Please sign in to comment.