Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define functions for Cholesky #168

Merged
merged 12 commits into from
Oct 13, 2023
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,17 @@ While in theory all of them can be defined, at present only the following subset

PRs to implement more generic fallbacks are welcome.

### Fallbacks for `LinearAlgebra.Cholesky`

For Cholesky decompositions of type `Cholesky` the following functions are defined as well:

- `dim`
- `whiten`, `whiten!`
- `unwhiten`, `unwhiten!`
- `quad`, `quad!`
- `invquad`, `invquad!`
- `X_A_Xt`, `Xt_A_X`, `X_invA_Xt`, `Xt_invA_X`

## Define Customized Subtypes

In some situation, it is useful to define a customized subtype of `AbstractPDMat` to capture positive definite matrices with special structures. For this purpose, one has to define a subset of methods (as listed below), and other methods will be automatically provided.
Expand Down
64 changes: 64 additions & 0 deletions src/chol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,67 @@ if HAVE_CHOLMOD
chol_lower(cf::CholTypeSparse) = cf.PtL
chol_upper(cf::CholTypeSparse) = cf.UP
end

# Interface for `Cholesky`

dim(A::Cholesky) = LinearAlgebra.checksquare(A)

# whiten
whiten(A::Cholesky, x::AbstractVecOrMat) = chol_lower(A) \ x
whiten!(A::Cholesky, x::AbstractVecOrMat) = ldiv!(chol_lower(A), x)

# unwhiten
unwhiten(A::Cholesky, x::AbstractVecOrMat) = chol_lower(A) * x
unwhiten!(A::Cholesky, x::AbstractVecOrMat) = lmul!(chol_lower(A), x)

# 3-argument whiten/unwhiten
for T in (:AbstractVector, :AbstractMatrix)
@eval begin
whiten!(r::$T, A::Cholesky, x::$T) = whiten!(A, copyto!(r, x))
unwhiten!(r::$T, A::Cholesky, x::$T) = unwhiten!(A, copyto!(r, x))
end
end

# quad
quad(A::Cholesky, x::AbstractVector) = sum(abs2, chol_upper(A) * x)
function quad(A::Cholesky, X::AbstractMatrix)
Z = chol_upper(A) * X
return vec(sum(abs2, Z; dims=1))
end
function quad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix)
Z = chol_upper(A) * X
return map!(Base.Fix1(sum, abs2), r, eachcol(Z))
end

# invquad
invquad(A::Cholesky, x::AbstractVector) = sum(abs2, chol_lower(A) \ x)
function invquad(A::Cholesky, X::AbstractMatrix)
Z = chol_lower(A) \ X
return vec(sum(abs2, Z; dims=1))
end
function invquad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix)
Z = chol_lower(A) * X
return map!(Base.Fix1(sum, abs2), r, eachcol(Z))
end

# tri products

function X_A_Xt(A::Cholesky, X::AbstractMatrix)
Z = X * chol_lower(A)
return Z * transpose(Z)
end

function Xt_A_X(A::Cholesky, X::AbstractMatrix)
Z = chol_upper(A) * X
return transpose(Z) * Z
end

function X_invA_Xt(A::Cholesky, X::AbstractMatrix)
Z = X / chol_upper(A)
return Z * transpose(Z)
end

function Xt_invA_X(A::Cholesky, X::AbstractMatrix)
Z = chol_lower(A) \ X
return transpose(Z) * Z
end
13 changes: 11 additions & 2 deletions test/chol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@ using PDMats: chol_lower, chol_upper

@testset "chol_lower and chol_upper" begin
@testset "allocations" begin
A = rand(100, 100)
d = 100
A = rand(d, d)
C = A'A
invC = inv(C)
size_of_one_copy = sizeof(C)
@assert size_of_one_copy > 100 # ensure the matrix is large enough that few-byte allocations don't matter
@assert size_of_one_copy > d # ensure the matrix is large enough that few-byte allocations don't matter

@test chol_lower(C) ≈ chol_upper(C)'
@test (@allocated chol_lower(C)) < 1.05 * size_of_one_copy # allow 5% overhead
@test (@allocated chol_upper(C)) < 1.05 * size_of_one_copy

X = randn(d, 10)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this one?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d should be 100 and used above as well. It's a remaining (final?) issue from the merge commit afaf699 that broke tests a bit.

for uplo in (:L, :U)
ch = cholesky(Symmetric(C, uplo))
@test chol_lower(ch) ≈ chol_upper(ch)'
@test (@allocated chol_lower(ch)) < 33 # allow small overhead for wrapper types
@test (@allocated chol_upper(ch)) < 33 # allow small overhead for wrapper types

# Only test dim, `quad`/`invquad`, `whiten`/`unwhiten`, and tri products
@test dim(ch) == size(C, 1)
pdtest_quad(ch, C, invC, X, 0)
pdtest_triprod(ch, C, invC, X, 0)
pdtest_whiten(ch, C, 0)
end
end

Expand Down
15 changes: 9 additions & 6 deletions test/specialarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using StaticArrays
@testset "Special matrix types" begin
@testset "StaticArrays" begin
# Full matrix
S = (x -> x * x')(@SMatrix(randn(4, 7)))
S = (x -> x * x' + I)(@SMatrix(randn(4, 7)))
PDS = PDMat(S)
@test PDS isa PDMat{Float64, <:SMatrix{4, 4, Float64}}
@test isbits(PDS)
Expand All @@ -27,12 +27,15 @@ using StaticArrays
X = @SMatrix rand(10, 4)
Y = @SMatrix rand(4, 10)

for A in (PDS, D, E)
@test A * x isa SVector{4, Float64}
@test A * x ≈ Matrix(A) * Vector(x)
for A in (PDS, D, E, C)
if !(A isa Cholesky)
# `*(::Cholesky, ::SArray)` is not defined
@test A * x isa SVector{4, Float64}
@test A * x ≈ Matrix(A) * Vector(x)

@test A * Y isa SMatrix{4, 10, Float64}
@test A * Y ≈ Matrix(A) * Matrix(Y)
@test A * Y isa SMatrix{4, 10, Float64}
@test A * Y ≈ Matrix(A) * Matrix(Y)
end

@test X / A isa SMatrix{10, 4, Float64}
@test X / A ≈ Matrix(X) / Matrix(A)
Expand Down
2 changes: 1 addition & 1 deletion test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function test_pdmat(C, Cmat::Matrix;
t_cholesky::Bool=true, # whether to test cholesky method
t_scale::Bool=true, # whether to test scaling
t_add::Bool=true, # whether to test pdadd
t_det::Bool=true, # whether to test det method
t_det::Bool=true, # whether to test det method
t_logdet::Bool=true, # whether to test logdet method
t_eig::Bool=true, # whether to test eigmax and eigmin
t_mul::Bool=true, # whether to test multiplication
Expand Down