Skip to content

Commit

Permalink
Rewriting rev_svd to (hopefully) be faster
Browse files Browse the repository at this point in the history
This uses fewer matrix multiplications.

The code no longer uses the helper function _mulsubtrans!!
so it has been removed.
  • Loading branch information
perrutquist committed Dec 4, 2023
1 parent ea25c11 commit f288c7a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 44 deletions.
31 changes: 16 additions & 15 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ end
#####

function _svd_pullback::Tangent, F)
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt')
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt)
return (NoTangent(), ∂X)
end
_svd_pullback(Ȳ::AbstractThunk, F) = _svd_pullback(unthunk(Ȳ), F)
Expand Down Expand Up @@ -244,34 +244,35 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
end

# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, Ū, s̄, )
function svd_rev(USV::SVD, Ū, s̄, V̄t)
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
U = USV.U
s = USV.S
V = USV.V
Vt = USV.Vt

k = length(s)
T = eltype(s)
F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k]

# We do a lot of matrix operations here, so we'll try to be memory-friendly and do
# as many of the computations in-place as possible. Benchmarking shows that the in-
# place functions here are significantly faster than their out-of-place, naively
# implemented counterparts, and allocate no additional memory.
Ut = U'
FUᵀŪ = _mulsubtrans!!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU)
FVᵀV̄ = _mulsubtrans!!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV)
UtŪ = U'*Ū

Check warning on line 257 in src/rulesets/LinearAlgebra/factorization.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/LinearAlgebra/factorization.jl:257:- UtŪ = U'*Ū src/rulesets/LinearAlgebra/factorization.jl:258:- V̄tV = V̄t*Vt' src/rulesets/LinearAlgebra/factorization.jl:265:+ UtŪ = U' * Ū src/rulesets/LinearAlgebra/factorization.jl:266:+ V̄tV = V̄t * Vt'
V̄tV = V̄t*Vt'

FUᵀŪS = F .* (UtŪ .- UtŪ') .* s'
SFVᵀV̄ = F .* (V̄tV' .- V̄tV) .* s

Check warning on line 262 in src/rulesets/LinearAlgebra/factorization.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/LinearAlgebra/factorization.jl:262:- src/rulesets/LinearAlgebra/factorization.jl:270:+
S = Diagonal(s)
=isa AbstractZero ?: Diagonal(s̄)

Check warning on line 265 in src/rulesets/LinearAlgebra/factorization.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/LinearAlgebra/factorization.jl:265:- src/rulesets/LinearAlgebra/factorization.jl:273:+
Ā = U * (FUᵀŪS ++ SFVᵀV̄) * Vt

# TODO: consider using MuladdMacro here
Ūs = Ū / S
V̄ts = S \'
Ā = add!!(U * FUᵀŪ * S, Ūs - U * (Ut * Ūs)) * Vt
Ā = add!!(Ā, U ** Vt)
Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, V̄ts - (V̄ts * V) * Vt))
if size(U,1) > size(U,2)

Check warning on line 269 in src/rulesets/LinearAlgebra/factorization.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/LinearAlgebra/factorization.jl:269:- if size(U,1) > size(U,2) src/rulesets/LinearAlgebra/factorization.jl:277:+ if size(U, 1) > size(U, 2)
Ā = add!!(Ā, ((Ū .- U * UtŪ) / S) * Vt)
end

Check warning on line 272 in src/rulesets/LinearAlgebra/factorization.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/LinearAlgebra/factorization.jl:272:- src/rulesets/LinearAlgebra/factorization.jl:273:- if size(Vt,2) > size(Vt,1) src/rulesets/LinearAlgebra/factorization.jl:280:+ src/rulesets/LinearAlgebra/factorization.jl:281:+ if size(Vt, 2) > size(Vt, 1)
if size(Vt,2) > size(Vt,1)
Ā = add!!(Ā, U * (S \ (V̄t .- V̄tV * Vt)))
end

return Ā
end
Expand Down
18 changes: 0 additions & 18 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,6 @@
# Some utility functions for optimizing linear algebra operations that aren't specific
# to any particular rule definition

# F .* (X - X'), overwrites X if possible
function _mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractMatrix{<:Real})
T = promote_type(eltype(X), eltype(F))
Y = (T <: eltype(X)) ? X : similar(X, T)
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
Y[i,i] = zero(T)
else
Y[i,j], Y[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
return Y
end
_mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X
_mulsubtrans!!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X
_mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F

_extract_imag(x) = complex(0, imag(x))

"""
Expand Down
11 changes: 0 additions & 11 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,6 @@ end
@test dX_thunked == dX_unthunked
end
end

@testset "Helper functions" begin
X = randn(10, 10)
Y = randn(10, 10)
@test ChainRules._mulsubtrans!!(copy(X), Y) Y .* (X - X')

Z = randn(Float32, 10, 10)
result = ChainRules._mulsubtrans!!(copy(Z), Y)
@test result Y .* (Z - Z')
@test eltype(result) == Float64
end
end

@testset "eigendecomposition" begin
Expand Down

0 comments on commit f288c7a

Please sign in to comment.