Skip to content

Commit

Permalink
Generalize (inv)quad and (un)whiten methods
Browse files Browse the repository at this point in the history
Now necessary because of JuliaStats/PDMats.jl#183
  • Loading branch information
sethaxen committed Nov 30, 2023
1 parent bcb0f46 commit ef3c7ad
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/woodbury.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,13 @@ end

PDMats.dim(W::WoodburyPDMat) = size(W.A, 1)

function PDMats.invquad(W::WoodburyPDMat, x::AbstractVector{T}) where {T}
return sum(abs2, pdfactorize(W).L \ x)
function PDMats.invquad(W::WoodburyPDMat, x::AbstractVecOrMat{T}) where {T}
WL_inv_x = pdfactorize(W).L \ x
if x isa AbstractVector
return sum(abs2, WL_inv_x)
else
return vec(sum(abs2, WL_inv_x; dims=1))
end
end

function PDMats.invquad!(r::AbstractArray, W::WoodburyPDMat, x::AbstractMatrix{T}) where {T}
Expand All @@ -382,18 +387,26 @@ function PDMats.quad!(r::AbstractArray, W::WoodburyPDMat, x::AbstractMatrix{T})
return r
end

function PDMats.quad(W::WoodburyPDMat, x::AbstractVector{T}) where {T}
v = pdfactorize(W).R * x
return sum(abs2, v)
function PDMats.quad(W::WoodburyPDMat, x::AbstractVecOrMat{T}) where {T}
WR_inv_x = pdfactorize(W).R * x
if x isa AbstractVector
return sum(abs2, WR_inv_x)
else
return vec(sum(abs2, WR_inv_x; dims=1))
end
end

PDMats.unwhiten(W::WoodburyPDMat, x::AbstractVecOrMat) = pdfactorize(W).L * x

function PDMats.unwhiten!(
r::AbstractVecOrMat{T}, W::WoodburyPDMat, x::AbstractVecOrMat{T}
) where {T}
copyto!(r, x)
return lmul!(pdfactorize(W).L, r)
end

PDMats.whiten(W::WoodburyPDMat, x::AbstractVecOrMat) = pdfactorize(W).R \ x

function invunwhiten!(
r::AbstractVecOrMat{T}, W::WoodburyPDMat, x::AbstractVecOrMat{T}
) where {T}
Expand Down

0 comments on commit ef3c7ad

Please sign in to comment.