Skip to content

Commit

Permalink
Improve QR-algorithm stability (#136)
Browse files Browse the repository at this point in the history
Previously a perfect single shift was used in the 2x2 case of real eigenvalues, which can be disastrous for ill-conditioned eigenvalues.

The new implementation considers a 2x2 block converged whether its eigenvalues are a complex conjugate pair or real, and in the real case it upper triangularizes the block in a much more stable way. The most stable rotation is picked (there is a choice when the two eigenvalues are distinct), and some formula manipulations and scaling tricks are applied.
  • Loading branch information
haampie committed Feb 17, 2024
1 parent 57b7fef commit 8929e0b
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 66 deletions.
185 changes: 120 additions & 65 deletions src/schurfact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ function double_shift_schur!(
H::AbstractMatrix{Tv},
from::Int,
to::Int,
μ::Complex,
trace::Tv,
determinant::Tv,
Q = NotWanted(),
) where {Tv<:Real}
m, n = size(H)

# Compute the nonzero entries of p = (H - μI)(H - μ'I)e₁.
# Compute the nonzero entries of p = (H - μ₋I)(H - μI)e₁.
# Because of the Hessenberg structure we only need H[min:min+2,min:min+1] to form p.
@inbounds H₁₁ = H[from+0, from+0]
@inbounds H₂₁ = H[from+1, from+0]
Expand All @@ -165,8 +166,9 @@ function double_shift_schur!(
@inbounds H₂₂ = H[from+1, from+1]
@inbounds H₃₂ = H[from+2, from+1]

p₁ = abs2(μ) - 2real(μ) * H₁₁ + H₁₁ * H₁₁ + H₁₂ * H₂₁
p₂ = -2real(μ) * H₂₁ + H₂₁ * H₁₁ + H₂₂ * H₂₁
# Todo: avoid under/overflow.
p₁ = H₁₁ * H₁₁ + H₁₂ * H₂₁ - trace * H₁₁ + determinant
p₂ = H₂₁ * (H₁₁ + H₂₂ - trace)
p₃ = H₃₂ * H₂₁

# Map that column to a multiple of e₁ via two Givens rotations
Expand Down Expand Up @@ -317,6 +319,74 @@ function single_shift_schur!(
H
end

"""
    upper_triangular_2x2(H₁₁, H₁₂, H₂₁, H₂₂) -> (is_real, c, s)

Compute the Givens rotation `G = [c s; -s c]` that makes `G * H * G'` upper triangular
(i.e. zeros out the bottom left entry) for the real matrix `H = [H₁₁ H₁₂; H₂₁ H₂₂]`.

Returns `(true, c, s)` when `H` has real eigenvalues and `H₂₁` is nonzero; `(c, s)` is then
the more stable of the possible rotations. Returns `(false, one(T), zero(T))`, the identity
rotation, when `H` has a complex conjugate pair of eigenvalues or when `H₂₁` is already
zero, so that no rotation has to be applied.
"""
function upper_triangular_2x2(H₁₁::T, H₁₂::T, H₂₁::T, H₂₂::T) where {T<:Real}
    # Already upper triangular: nothing to rotate.
    if iszero(H₂₁)
        return false, one(T), zero(T)
    end

    # Lower triangular with nonzero H₂₁: the eigenvalues are the (real) diagonal entries and
    # swapping rows and columns triangularizes exactly. This check has to precede the
    # equal-diagonal test below, because sign(0) differs from the sign of any nonzero H₂₁
    # and would otherwise classify [a 0; c a] as a conjugate pair.
    if iszero(H₁₂)
        return true, zero(T), one(T)
    end

    # Equal diagonal with off-diagonal entries of opposite sign: the discriminant reduces
    # to H₁₂H₂₁ < 0, so the eigenvalues form a conjugate pair.
    if iszero(H₁₁ - H₂₂) && sign(H₁₂) != sign(H₂₁)
        return false, one(T), zero(T)
    end

    # The characteristic polynomial is `λ² - tr(H)λ + det(H) = 0`
    # => λ = (tr(H) ± √(tr(H)² - 4det(H))) / 2.
    # Conjugate pair iff tr(H)² - 4det(H) < 0
    # => (H₁₁ + H₂₂)² - 4(H₁₁H₂₂ - H₁₂H₂₁) < 0
    # => ((H₁₁ - H₂₂) / 2)² + H₁₂H₂₁ < 0.
    # Below, z is that quantity divided by `scale`, evaluated without forming the products
    # p² and H₁₂H₂₁ directly.
    p = (H₁₁ - H₂₂) / 2
    bcmax = max(abs(H₁₂), abs(H₂₁))
    bcmis = min(abs(H₁₂), abs(H₂₁)) * sign(H₁₂) * sign(H₂₁)
    scale = max(abs(p), bcmax) # nonzero, since H₂₁ is nonzero here
    z = (p / scale) * p + (bcmax / scale) * bcmis

    # Return when complex. LAPACK also deals with 0 < z < 4eps(T), but I don't find it worth it.
    # Note that the < is important, cause H = [1 -1/4; 1 2] for example is not upper triangular
    # and has zero discriminant.
    if z < 0
        return false, one(T), zero(T)
    end

    # The rotation is basically just a "perfect" Wilkinson shift, so compute it from either
    # (H₁₁ - λ₁, H₂₁) or (H₁₁ - λ₂, H₂₁), the first column of H - λI. We pick the option that
    # avoids catastrophic cancellation by choosing equal signs. A small rewrite:
    # H₁₁ - λ with λ = (tr(H)/2 ∓ √((tr(H)/2)² - det(H))) means
    # H₁₁ - λ = (H₁₁ - H₂₂)/2 ± √(((H₁₁ - H₂₂)/2)² + H₁₂H₂₁)
    # and √(scale) * √(z) is that square root.
    H₁₁_min_λ = p + copysign(sqrt(scale) * sqrt(z), p)
    nrm = hypot(H₂₁, H₁₁_min_λ) # could use givensAlgorithm, but hypot is likely better regardless.
    return true, H₁₁_min_λ / nrm, H₂₁ / nrm
end

"""
    use_single_shift(H₁₁, H₁₂, H₂₁, H₂₂) -> (is_single, λ)

Decide which kind of shift to use for the block `H = [H₁₁ H₁₂; H₂₁ H₂₂]`.

Returns `(true, λ)` when `H` has real eigenvalues; `λ` is then a Wilkinson shift: the
eigenvalue of `H` closest to `H₂₂`. Returns `(false, zero(T))` when the eigenvalues form a
complex conjugate pair, in which case the caller needs a double shift.
"""
function use_single_shift(H₁₁::T, H₁₂::T, H₂₁::T, H₂₂::T) where {T}
    # TODO: Merge the scaling tricks with the above.
    # Scaling to avoid losing precision in the case where we have nearly
    # repeated eigenvalues.
    scale = abs(H₁₁) + abs(H₁₂) + abs(H₂₁) + abs(H₂₂)

    # An all-zero block has the double real eigenvalue 0. Dividing by the zero scale below
    # would instead produce a NaN shift.
    iszero(scale) && return true, zero(T)

    H₁₁ /= scale
    H₁₂ /= scale
    H₂₁ /= scale
    H₂₂ /= scale

    # Trace and discriminant of small eigenvalue problem. Again,
    # λ = tr(H) / 2 ± √(tr(H)² / 4 - det(H))
    #   = (H₁₁ + H₂₂) / 2 ± √(((H₁₁ - H₂₂)/2)² + H₁₂H₂₁) written in a funny way:
    # with t = tr(H) / 2, d below equals minus the expression under the square root.
    t = (H₁₁ + H₂₂) / 2
    d = (H₁₁ - t) * (H₂₂ - t) - H₁₂ * H₂₁

    # Conjugate pair: need to do a double shift.
    d > zero(T) && return false, zero(T)

    # The shift is picked as the closest eigenvalue of the 2x2 block near H[to,to]
    sqrt_discr = sqrt(abs(d))
    λ₁ = t + sqrt_discr
    λ₂ = t - sqrt_discr
    λ = abs(H₂₂ - λ₁) < abs(H₂₂ - λ₂) ? λ₁ : λ₂

    # Undo the scaling applied to the entries.
    return true, λ * scale
end

###
### Real arithmetic
###
Expand All @@ -331,12 +401,9 @@ function local_schurfact!(
# iteration count
iter = 0

@inbounds while true
@inbounds while to > start
iter += 1

if iter > maxiter
throw("QR algorithm did not converge")
end
iter > maxiter && throw("QR algorithm did not converge")

# Indexing
# `to` points to the column where the off-diagonal value was last zero.
Expand All @@ -361,71 +428,59 @@ function local_schurfact!(
# We keep `from` one column past the zero off-diagonal value, so we check whether
# the `from - 1` column has a small off-diagonal value.
from = to
while from > start && !is_offdiagonal_small(H, from - 1, tol)
while from > start
if is_offdiagonal_small(H, from - 1, tol)
H[from, from-1] = zero(T)
break
end
from -= 1
end

if from == to
# This just means H[to, to-1] == 0, so one eigenvalue converged at the end
H[from, from-1] = zero(T)
# A single eigenvalue has converged
to -= 1
else
# Now we are sure we can work with a 2×2 block H[to-1:to,to-1:to]
# We check if this block has a conjugate eigenpair, which might mean we have
# converged w.r.t. this block if from + 1 == to.
# Otherwise, if from + 1 < to, we do either a single or double shift, based on
# whether the H[to-1:to,to-1:to] part has real eigenvalues or a conjugate pair.

H₁₁, H₁₂ = H[to-1, to-1], H[to-1, to]
H₂₁, H₂₂ = H[to, to-1], H[to, to]
continue
end

# Scaling to avoid losing precision in the case where we have nearly
# repeated eigenvalues.
scale = abs(H₁₁) + abs(H₁₂) + abs(H₂₁) + abs(H₂₂)
H₁₁ /= scale
H₁₂ /= scale
H₂₁ /= scale
H₂₂ /= scale

# Trace and discriminant of small eigenvalue problem.
t = (H₁₁ + H₂₂) / 2
d = (H₁₁ - t) * (H₂₂ - t) - H₁₂ * H₂₁
sqrt_discr = sqrt(abs(d))

# Very important to have a strict comparison here!
if d < zero(T)
# Real eigenvalues.
# Note that if from + 1 == to in this case, then just one additional
# iteration is necessary, since the Wilkinson shift will do an exact shift.

# Determine the Wilkinson shift -- the closest eigenvalue of the 2x2 block
# near H[to,to]

λ₁ = t + sqrt_discr
λ₂ = t - sqrt_discr
λ = abs(H₂₂ - λ₁) < abs(H₂₂ - λ₂) ? λ₁ : λ₂
λ *= scale

# Run a bulge chase
single_shift_schur!(H, from, to, λ, Q)
else
# Conjugate pair
if from + 1 == to
# A conjugate pair has converged apparently!
if from != 1
H[from, from-1] = zero(T)
end
to -= 2
else
# Otherwise we do a double shift!
complex_shift = scale * (t + sqrt_discr * im)
double_shift_schur!(H, from, to, complex_shift, Q)
end
# We can safely work with the bottom 2×2 block C := H[to-1:to,to-1:to] now.
C₁₁, C₁₂ = H[to-1, to-1], H[to-1, to]
C₂₁, C₂₂ = H[to, to-1], H[to, to]

# A 2x2 block is always considered converged. Complex conjugates are left as a 2x2 block.
# Real eigenvalues are "manually" upper triangularized.
if from + 1 == to
# In case of real eigenvalues, it should in principle be enough to do a single
# Wilkinson shift: that would be a perfect shift, and upper triangularizes the 2x2
# block. But it can also completely destroy accuracy. So, we do this single shift with
# more accurate arithmetic with the rotation computed above.
is_real, cs, sn = upper_triangular_2x2(C₁₁, C₁₂, C₂₁, C₂₂)

if is_real
G = Rotation2(cs, sn, from)
lmul!(G, H, from, size(H, 2))
rmul!(H, G, 1, to)
rmul!(Q, G)
H[to, to-1] = zero(T)
end

to -= 2
continue
end

# Converged!
to start && break
# Real eigenvalues: single wilkinson shift. Conjugate pair: Francis double shift.
is_single, μ = use_single_shift(C₁₁, C₁₂, C₂₁, C₂₂)

if is_single
single_shift_schur!(H, from, to, μ, Q)
else
# A double shift is done by computing the first column
# of (H - μ₊I)(H - μ₋I) where μ₊ and μ₋ are the eigenvalues of C. That's identical to
# (H² - (μ₊ + μ₋)H + μ₊μ₋), and since μ₊₋ = (tr(C) ± √(tr(C)² - 4det(C))) / 2.
# So, identical to H² - tr(C)H + det(C)I.
trace = C₁₁ + C₂₂
determinant = C₁₁ * C₂₂ - C₁₂ * C₂₁
double_shift_schur!(H, from, to, trace, determinant, Q)
end
end

return true
Expand Down
47 changes: 46 additions & 1 deletion test/schurfact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# Here we look at some edge cases.

using Test, LinearAlgebra
using ArnoldiMethod: eigenvalues, local_schurfact!, is_offdiagonal_small, NotWanted
using ArnoldiMethod:
eigenvalues,
local_schurfact!,
is_offdiagonal_small,
NotWanted,
use_single_shift,
upper_triangular_2x2

include("utils.jl")

Expand Down Expand Up @@ -127,3 +133,42 @@ end
@test local_schurfact!(mat(eps(T)))
end
end

@testset "Convergence issue encountered in the wild" begin
    # Regression inputs: Hessenberg matrices whose eigenvalues are nearly identical.

    # 4x4 case: used to need tens of thousands of QR iterations to converge, likely because
    # the shifts and the column (H - μ₁I)(H - μ₂I)e₁ were computed in an unstable way.
    clustered_4x4 = [
        -9.000000046596169 9.363971416904122e-6 0.6216202324428521 0.783119615978767
        -3.1249216068055166e-10 -9.000000125049475 -0.005030734831215954 0.026538692060151765
        0.0 2.5838932886290116e-12 -8.999999884550379 -4.118678562647915e-7
        0.0 0.0 5.499735555858365e-9 -8.99999994380397
    ]

    # 3x3 case: did not converge at all, due to catastrophic cancellation when forming the
    # first column of (H - μ₁I)(H - μ₂I).
    clustered_3x3 = [
        -9.99999999890572 -5.359512176950441e-5 0.5057150345932383
        6.673511665530937e-11 -9.999999865827567 -0.0009029114103036593
        0.0 1.432733142195386e-11 -10.000000096783797
    ]

    @test local_schurfact!(clustered_4x4)
    @test local_schurfact!(clustered_3x3)
end

@testset "Exactly repeated eigenvalues in 2x2 block" begin
    # Not upper triangular, zero discriminant: tr(A)² - 4det(A) = 9 - 9 = 0, so the
    # eigenvalue 1.5 is repeated.
    A = Float64[1 -1/4; 1 2]

    # Test for upper triangularizing a 2x2 block.
    # Splatting A' passes the entries in the order H₁₁, H₁₂, H₂₁, H₂₂.
    is_real, c, s = upper_triangular_2x2(A'...)
    @test is_real
    G = [c s; -s c]
    @test G * A * G' ≈ Float64[1.5 -1.25; 0 1.5]
    @test G' * G ≈ I

    # Test for determining what type of shift to use
    is_real, λ = use_single_shift(A'...)
    @test is_real
    @test λ ≈ 1.5
end

0 comments on commit 8929e0b

Please sign in to comment.