From 722657d5705b0fe1aa24db7bfd887ab6ce968fd3 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 3 Sep 2024 13:55:10 +0530 Subject: [PATCH] Forward istriu/istril for triangular to parent (#55663) --- stdlib/LinearAlgebra/src/special.jl | 4 ++++ stdlib/LinearAlgebra/src/triangular.jl | 22 +++++++++++++++++-- stdlib/LinearAlgebra/test/triangular.jl | 28 +++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 8263e632a86b8a..5a7c98cfdf32c1 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -586,3 +586,7 @@ function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), sym_uplo(S.uplo)) cholesky!(Hermitian(B, sym_uplo(S.uplo)), NoPivot(); check = check) end + +# istriu/istril for triangular wrappers of structured matrices +_istril(A::LowerTriangular{<:Any, <:BandedMatrix}, k) = istril(parent(A), k) +_istriu(A::UpperTriangular{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 8d32dac824ce8f..71473e0dc11743 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -330,14 +330,32 @@ function Base.replace_in_print_matrix(A::Union{LowerTriangular,UnitLowerTriangul return i >= j ? s : Base.replace_with_centered_mark(s) end -Base.@constprop :aggressive function istril(A::Union{LowerTriangular,UnitLowerTriangular}, k::Integer=0) +istril(A::UnitLowerTriangular, k::Integer=0) = k >= 0 +istriu(A::UnitUpperTriangular, k::Integer=0) = k <= 0 +Base.@constprop :aggressive function istril(A::LowerTriangular, k::Integer=0) k >= 0 && return true return _istril(A, k) end -Base.@constprop :aggressive function istriu(A::Union{UpperTriangular,UnitUpperTriangular}, k::Integer=0) +@inline function _istril(A::LowerTriangular, k) + P = parent(A) + m = size(A, 1) + for j in max(1, k + 2):m + all(iszero, view(P, j:min(j - k - 1, m), j)) || return false + end + return true +end +Base.@constprop :aggressive function istriu(A::UpperTriangular, k::Integer=0) k <= 0 && return true return _istriu(A, k) end +@inline function _istriu(A::UpperTriangular, k) + P = parent(A) + m = size(A, 1) + for j in 1:min(m, m + k - 1) + all(iszero, view(P, max(1, j - k + 1):j, j)) || return false + end + return true +end istril(A::Adjoint, k::Integer=0) = istriu(A.parent, -k) istril(A::Transpose, k::Integer=0) = istriu(A.parent, -k) istriu(A::Adjoint, k::Integer=0) = istril(A.parent, -k) diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index b88be00b0209c3..8748d11bd1da4d 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -1226,4 +1226,32 @@ end end end +@testset "istriu/istril forwards to parent" begin + @testset "$(nameof(typeof(M)))" for M in [Tridiagonal(rand(n-1), rand(n), rand(n-1)), + Tridiagonal(zeros(n-1), zeros(n), zeros(n-1)), + Diagonal(randn(n)), + Diagonal(zeros(n)), + ] + @testset for TriT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular) + U = TriT(M) + A = Array(U) + for k in -n:n + @test istriu(U, k) == istriu(A, k) + @test istril(U, k) == istril(A, k) + end + end + end + z = zeros(n,n) + @testset for TriT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular) + P = Matrix{BigFloat}(undef, n, n) + copytrito!(P, z, TriT <: Union{UpperTriangular, UnitUpperTriangular} ? 'U' : 'L') + U = TriT(P) + A = Array(U) + @testset for k in -n:n + @test istriu(U, k) == istriu(A, k) + @test istril(U, k) == istril(A, k) + end + end +end + end # module TestTriangular