Skip to content

Commit

Permalink
Apply a few fixes and improvements from PDMats.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
David Widmann committed Oct 16, 2023
1 parent 1b4bf27 commit e088639
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
ChainRulesCore = "0.9.17, 0.10, 1"
Distributions = "0.23, 0.24"
FiniteDifferences = "0.11, 0.12"
PDMats = "0.9, 0.10, 0.11"
PDMats = "0.11.19"
Zygote = "0.5.5, 0.6"
julia = "1"

Expand Down
2 changes: 2 additions & 0 deletions src/PDMatsExtras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ include("psd_mat.jl")
include("woodbury_pd_mat.jl")
include("utils.jl")

Base.@deprecate PSDMat{T,S}(d::Int, m::AbstractMatrix{T}, c::CholType{T,S}) where {T,S} PSDMat{T,S}(m, c)

end
50 changes: 41 additions & 9 deletions src/psd_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,40 @@ References for discussion on supporting degenerate mvnormal distributions:
Positive semi-definite matrix together with a CholeskyPivoted factorization object.
"""
struct PSDMat{T<:Real,S<:AbstractMatrix} <: AbstractPDMat{T}
dim::Int
mat::S
chol::CholType{T, S}

PSDMat{T,S}(d::Int,m::AbstractMatrix{T},c::CholType{T,S}) where {T,S} = new{T, S}(d, m, c)
PSDMat{T,S}(m::AbstractMatrix{T},c::CholType{T,S}) where {T,S} = new{T, S}(m, c)
end

function PSDMat(mat::AbstractMatrix, chol::CholType)
d = size(mat, 1)
d = LinearAlgebra.checksquare(mat)
size(chol, 1) == d ||
throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
PSDMat{eltype(mat),typeof(mat)}(d, mat, chol)
PSDMat{eltype(mat),typeof(mat)}(mat, chol)
end

PSDMat(mat::Matrix) = PSDMat(mat, cholesky(mat, Val(true); check=false))
PSDMat(mat::Matrix) = PSDMat(mat, cholesky(mat, VERSION >= v"1.8.0-rc1" ? RowMaximum() : Val(true); check=false))
PSDMat(mat::Symmetric) = PSDMat(Matrix(mat))
PSDMat(fac::CholType) = PSDMat(Matrix(fac), fac)

function Base.getproperty(a::PSDMat, s::Symbol)
if s === :dim
return size(getfield(a, :mat), 1)
end
return getfield(a, s)
end
Base.propertynames(::PSDMat) = (:mat, :chol, :dim)

Check warning on line 43 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L43

Added line #L43 was not covered by tests

### Conversion
Base.convert(::Type{PSDMat{T}}, a::PSDMat) where {T<:Real} = PSDMat(convert(AbstractArray{T}, a.mat))
Base.convert(::Type{AbstractArray{T}}, a::PSDMat) where {T<:Real} = convert(PSDMat{T}, a)
Base.convert(::Type{PSDMat{T}}, a::PSDMat{T}) where {T<:Real} = a
function Base.convert(::Type{PSDMat{T}}, a::PSDMat) where {T<:Real}
chol = convert(CholType{T}, a.chol)
S = typeof(chol.factors)
mat = convert(S, a.mat)
PSDMat{T,S}(mat, chol)

Check warning on line 51 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L46-L51

Added lines #L46 - L51 were not covered by tests
end
Base.convert(::Type{AbstractPDMat{T}}, a::PSDMat) where {T<:Real} = convert(PSDMat{T}, a)

Check warning on line 53 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L53

Added line #L53 was not covered by tests

### Basics

Expand All @@ -63,6 +76,7 @@ end
### Algebra

Base.inv(a::PSDMat) = PSDMat(inv(a.chol))
LinearAlgebra.cholesky(a::PSDMat) = a.chol

Check warning on line 79 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L79

Added line #L79 was not covered by tests
LinearAlgebra.logdet(a::PSDMat) = logdet(a.chol)
LinearAlgebra.eigmax(a::PSDMat) = eigmax(a.mat)
LinearAlgebra.eigmin(a::PSDMat) = eigmin(a.mat)
Expand Down Expand Up @@ -92,8 +106,26 @@ end

### quadratic forms

PDMats.quad(a::PSDMat, x::StridedVector) = dot(x, a * x)
PDMats.invquad(a::PSDMat, x::StridedVector) = dot(x, a \ x)
function PDMats.quad(a::PSDMat, x::AbstractVecOrMat)
if a.dim != size(x, 1)
throw(DimensionMismatch("Inconsistent argument dimensions."))

Check warning on line 111 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L111

Added line #L111 was not covered by tests
end
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
if VERSION < v"1.4.0-DEV.92"
z = a.mat * x
return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z))

Check warning on line 116 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L115-L116

Added lines #L115 - L116 were not covered by tests
else
return x isa AbstractVector ? dot(x, a.mat, x) : map(Base.Fix1(quad, a), eachcol(x))
end
end

function PDMats.invquad(a::PSDMat, x::AbstractVecOrMat)
if a.dim != size(x, 1)
throw(DimensionMismatch("Inconsistent argument dimensions."))

Check warning on line 124 in src/psd_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/psd_mat.jl#L124

Added line #L124 was not covered by tests
end
z = a.chol \ x
return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z))
end

"""
quad!(r::AbstractArray, a::AbstractPDMat, x::StridedMatrix)
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ function submat(A::ScalMat, inds)
checkbounds(Bool, A, inds) || throw(BoundsError(A, inds))
return ScalMat(length(inds), A.value)
end

# https://github.com/JuliaLang/julia/pull/29749
if VERSION < v"1.1.0-DEV.792"
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))

Check warning on line 20 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L20

Added line #L20 was not covered by tests
end
4 changes: 2 additions & 2 deletions test/psd_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
verbose = 1
@testset "Positive definite" begin
M = TEST_MATRICES["Positive definite"]
pivoted = cholesky(M, Val(true))
pivoted = cholesky(M, VERSION >= v"1.8.0-rc1" ? RowMaximum() : Val(true))
C = PSDMat(M, pivoted)
test_pdmat(
C,
Expand All @@ -33,7 +33,7 @@ end
@testset "Positive semi-definite" begin
M = TEST_MATRICES["Positive semi-definite"]
@test !isposdef(M)
pivoted = cholesky(M, Val(true); check=false)
pivoted = cholesky(M, VERSION >= v"1.8.0-rc1" ? RowMaximum() : Val(true); check=false)
C = PSDMat(M, pivoted)
test_pdmat(
C,
Expand Down
2 changes: 1 addition & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

@testset "PSDMat" begin
SM = TEST_MATRICES["Positive semi-definite"]
pivoted = cholesky(SM, Val(true); check=false)
pivoted = cholesky(SM, VERSION >= v"1.8.0-rc1" ? RowMaximum() : Val(true); check=false)
M = PSDMat(SM, pivoted)
M_dense = Matrix(M)

Expand Down

0 comments on commit e088639

Please sign in to comment.