Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve QR-algorithm stability #136

Merged
merged 7 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 120 additions & 65 deletions src/schurfact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ function double_shift_schur!(
H::AbstractMatrix{Tv},
from::Int,
to::Int,
μ::Complex,
trace::Tv,
determinant::Tv,
Q = NotWanted(),
) where {Tv<:Real}
m, n = size(H)

# Compute the nonzero entries of p = (H - μI)(H - μ'I)e₁.
# Compute the nonzero entries of p = (H - μ₋I)(H - μI)e₁.
# Because of the Hessenberg structure we only need H[min:min+2,min:min+1] to form p.
@inbounds H₁₁ = H[from+0, from+0]
@inbounds H₂₁ = H[from+1, from+0]
Expand All @@ -165,8 +166,9 @@ function double_shift_schur!(
@inbounds H₂₂ = H[from+1, from+1]
@inbounds H₃₂ = H[from+2, from+1]

p₁ = abs2(μ) - 2real(μ) * H₁₁ + H₁₁ * H₁₁ + H₁₂ * H₂₁
p₂ = -2real(μ) * H₂₁ + H₂₁ * H₁₁ + H₂₂ * H₂₁
# Todo: avoid under/overflow.
p₁ = H₁₁ * H₁₁ + H₁₂ * H₂₁ - trace * H₁₁ + determinant
p₂ = H₂₁ * (H₁₁ + H₂₂ - trace)
p₃ = H₃₂ * H₂₁

# Map that column to a multiple of e₁ via two Givens rotations
Expand Down Expand Up @@ -317,6 +319,74 @@ function single_shift_schur!(
H
end

"""
Returns a tuple (is_real, c, s) where is_real is true iff the matrix H = [H₁₁ H₁₂; H₂₁ H₂₂] has
real eigenvalues. The (c, s) values form the most stable Given's rotation that makes G * H * G'
upper triangular (i.e. zeros out the bottom left entry).
"""
function upper_triangular_2x2(H₁₁::T, H₁₂::T, H₂₁::T, H₂₂::T) where {T<:Real}
# Early exit in trivial cases.
(iszero(H₂₁) || (iszero(H₁₁ - H₂₂) && sign(H₁₂) != sign(H₂₁))) &&
return false, one(T), zero(T)
iszero(H₁₂) && return true, zero(T), one(T)

# The characteristic polynomial is `λ² - tr(H)λ + det(H) = 0`
# => λ = (tr(H) ± √(tr(H)² - 4det(H))) / 2.
# Conjugate pair iff tr(H)² - 4det(H) < 0
# => (H₁₁ + H₂₂)² - 4(H₁₁H₂₂ - H₁₂H₂₁) < 0
# => ((H₁₁ - H₂₂) / 2)² + H₁₂H₂₁ < 0.
p = (H₁₁ - H₂₂) / 2
bcmax = max(abs(H₁₂), abs(H₂₁))
bcmis = min(abs(H₁₂), abs(H₂₁)) * sign(H₁₂) * sign(H₂₁)
scale = max(abs(p), bcmax)
z = (p / scale) * p + (bcmax / scale) * bcmis

# Return when complex. LAPACK also deals with 0 < z < 4eps(T), but I don't find it worth it.
# Note that the < is important, cause H = [1 -1/4; 1 2] for example is not upper triangular
# and has zero discriminant.
z < 0 && return false, one(T), zero(T)

# The rotation is basically just a "perfect" Wilkinson shift, so compute from either
# (H₁₁ - λ₁, H₁₂) or (H₁₁ - λ₂, H₁₂). We pick the option that would avoid catastrophic
# cancellation by choosing equal signs. A small rewrite:
# H₁₁ - λ with λ = (tr(H)/2 ± √((tr(H)/2)² - det(H))) means
# H₁₁ - λ = (H₁₁ - H₂₂)/2 ± √(((H₁₁ - H₂₂)/2)² - H₂₁H₁₂)
H₁₁_min_λ = p + copysign(sqrt(scale) * sqrt(z), p)
nrm = hypot(H₂₁, H₁₁_min_λ) # could use givensAlgorithm, but hypot is likely better regardless.
return true, H₁₁_min_λ / nrm, H₂₁ / nrm
end

"""
Returns a tuple (is_single, λ) where is_single is true iff the matrix H = [H₁₁ H₁₂; H₂₁ H₂₂] has
real eigenvalues. In that case λ is a Wilkinson shift: the eigenvalue closest to H₂₂.
"""
function use_single_shift(H₁₁::T, H₁₂::T, H₂₁::T, H₂₂::T) where {T}
# TODO: Merge the scaling tricks with the above.
# Scaling to avoid losing precision in the case where we have nearly
# repeated eigenvalues.
scale = abs(H₁₁) + abs(H₁₂) + abs(H₂₁) + abs(H₂₂)
H₁₁ /= scale
H₁₂ /= scale
H₂₁ /= scale
H₂₂ /= scale

# Trace and discriminant of small eigenvalue problem. Again,
# λ = tr(H) / 2 ± √(tr(H)² / 4 - det(H))
# = (H₁₁ + H₂₂) / 2 ± √((H₁₁ - H₂₂)/2)² - H₁₂H₂₁) written in a funny way:
t = (H₁₁ + H₂₂) / 2
d = (H₁₁ - t) * (H₂₂ - t) - H₁₂ * H₂₁

# Conjugate pair: need to do a double shift.
d > zero(T) && return false, zero(T)

# The shift is picked as the closest eigenvalue of the 2x2 block near H[to,to]
sqrt_discr = sqrt(abs(d))
λ₁ = t + sqrt_discr
λ₂ = t - sqrt_discr
λ = abs(H₂₂ - λ₁) < abs(H₂₂ - λ₂) ? λ₁ : λ₂
return true, λ * scale
end

###
### Real arithmetic
###
Expand All @@ -331,12 +401,9 @@ function local_schurfact!(
# iteration count
iter = 0

@inbounds while true
@inbounds while to > start
iter += 1

if iter > maxiter
throw("QR algorithm did not converge")
end
iter > maxiter && throw("QR algorithm did not converge")

# Indexing
# `to` points to the column where the off-diagonal value was last zero.
Expand All @@ -361,71 +428,59 @@ function local_schurfact!(
# We keep `from` one column past the zero off-diagonal value, so we check whether
# the `from - 1` column has a small off-diagonal value.
from = to
while from > start && !is_offdiagonal_small(H, from - 1, tol)
while from > start
if is_offdiagonal_small(H, from - 1, tol)
H[from, from-1] = zero(T)
break
end
from -= 1
end

if from == to
# This just means H[to, to-1] == 0, so one eigenvalue converged at the end
H[from, from-1] = zero(T)
# A single eigenvalue has converged
to -= 1
else
# Now we are sure we can work with a 2×2 block H[to-1:to,to-1:to]
# We check if this block has a conjugate eigenpair, which might mean we have
# converged w.r.t. this block if from + 1 == to.
# Otherwise, if from + 1 < to, we do either a single or double shift, based on
# whether the H[to-1:to,to-1:to] part has real eigenvalues or a conjugate pair.

H₁₁, H₁₂ = H[to-1, to-1], H[to-1, to]
H₂₁, H₂₂ = H[to, to-1], H[to, to]
continue
end

# Scaling to avoid losing precision in the case where we have nearly
# repeated eigenvalues.
scale = abs(H₁₁) + abs(H₁₂) + abs(H₂₁) + abs(H₂₂)
H₁₁ /= scale
H₁₂ /= scale
H₂₁ /= scale
H₂₂ /= scale

# Trace and discriminant of small eigenvalue problem.
t = (H₁₁ + H₂₂) / 2
d = (H₁₁ - t) * (H₂₂ - t) - H₁₂ * H₂₁
sqrt_discr = sqrt(abs(d))

# Very important to have a strict comparison here!
if d < zero(T)
# Real eigenvalues.
# Note that if from + 1 == to in this case, then just one additional
# iteration is necessary, since the Wilkinson shift will do an exact shift.

# Determine the Wilkinson shift -- the closest eigenvalue of the 2x2 block
# near H[to,to]

λ₁ = t + sqrt_discr
λ₂ = t - sqrt_discr
λ = abs(H₂₂ - λ₁) < abs(H₂₂ - λ₂) ? λ₁ : λ₂
λ *= scale

# Run a bulge chase
single_shift_schur!(H, from, to, λ, Q)
else
# Conjugate pair
if from + 1 == to
# A conjugate pair has converged apparently!
if from != 1
H[from, from-1] = zero(T)
end
to -= 2
else
# Otherwise we do a double shift!
complex_shift = scale * (t + sqrt_discr * im)
double_shift_schur!(H, from, to, complex_shift, Q)
end
# We can safely work with the bottom 2×2 block C := H[to-1:to,to-1:to] now.
C₁₁, C₁₂ = H[to-1, to-1], H[to-1, to]
C₂₁, C₂₂ = H[to, to-1], H[to, to]

# A 2x2 block is always considered converged. Complex conjugates are left as a 2x2 block.
# Real eigenvalues are "manually" upper triangularized.
if from + 1 == to
# In case of real eigenvalues, it should in principle be enough to do a single
# Wilkinson shift: that would be a perfect shift, and upper triangularizes the 2x2
# block. But it can also completely destroy accuracy. So, we do this single shift with
# more accurate arithmetic with the rotation computed above.
is_real, cs, sn = upper_triangular_2x2(C₁₁, C₁₂, C₂₁, C₂₂)

if is_real
G = Rotation2(cs, sn, from)
lmul!(G, H, from, size(H, 2))
rmul!(H, G, 1, to)
rmul!(Q, G)
H[to, to-1] = zero(T)
end

to -= 2
continue
end

# Converged!
to ≤ start && break
# Real eigenvalues: single wilkinson shift. Conjugate pair: Francis double shift.
is_single, μ = use_single_shift(C₁₁, C₁₂, C₂₁, C₂₂)

if is_single
single_shift_schur!(H, from, to, μ, Q)
else
# A double shift is done by computing the first column
# of (H - μ₊I)(H - μ₋I), where μ₊ and μ₋ are the eigenvalues of C. That product equals
# H² - (μ₊ + μ₋)H + μ₊μ₋I, and since μ₊ + μ₋ = tr(C) and μ₊μ₋ = det(C), it is
# identical to H² - tr(C)H + det(C)I.
trace = C₁₁ + C₂₂
determinant = C₁₁ * C₂₂ - C₁₂ * C₂₁
double_shift_schur!(H, from, to, trace, determinant, Q)
end
end

return true
Expand Down
47 changes: 46 additions & 1 deletion test/schurfact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# Here we look at some edge cases.

using Test, LinearAlgebra
using ArnoldiMethod: eigenvalues, local_schurfact!, is_offdiagonal_small, NotWanted
using ArnoldiMethod:
eigenvalues,
local_schurfact!,
is_offdiagonal_small,
NotWanted,
use_single_shift,
upper_triangular_2x2

include("utils.jl")

Expand Down Expand Up @@ -127,3 +133,42 @@ end
@test local_schurfact!(mat(eps(T)))
end
end

# Regression tests for upper Hessenberg matrices with tightly clustered eigenvalues.
# local_schurfact! returns `true` once the iteration finishes and throws when the
# iteration limit is exceeded, so each `@test` below fails on non-convergence.
@testset "Convergence issue encountered in the wild" begin
# This 4x4 matrix with almost identical eigenvalues (all close to -9) previously caused
# tens of thousands of iterations of the QR algorithm to converge, likely due to unstable
# computation of shifts and of the (H - μ₁I)(H - μ₂I)e₁ column.
H1 = [
-9.000000046596169 9.363971416904122e-6 0.6216202324428521 0.783119615978767
-3.1249216068055166e-10 -9.000000125049475 -0.005030734831215954 0.026538692060151765
0.0 2.5838932886290116e-12 -8.999999884550379 -4.118678562647915e-7
0.0 0.0 5.499735555858365e-9 -8.99999994380397
]
@test local_schurfact!(H1)

# Similarly this 3x3 matrix (eigenvalues all close to -10) did not converge due to
# catastrophic cancellation when computing the first column (H - μ₁I)(H - μ₂I)e₁.
H2 = [
-9.99999999890572 -5.359512176950441e-5 0.5057150345932383
6.673511665530937e-11 -9.999999865827567 -0.0009029114103036593
0.0 1.432733142195386e-11 -10.000000096783797
]
@test local_schurfact!(H2)

end

@testset "Exactly repeated eigenvalues in 2x2 block" begin
    # Non-triangular block whose discriminant is exactly zero: double eigenvalue 1.5.
    block = Float64[1 -1/4; 1 2]

    # Splatting the adjoint passes the entries in row-major order: H₁₁, H₁₂, H₂₁, H₂₂.
    entries = block'

    # The block must be recognised as having real eigenvalues, and the returned
    # rotation must be orthogonal and bring the block to upper triangular form.
    triangularizable, cs, sn = upper_triangular_2x2(entries...)
    @test triangularizable
    rotation = [cs sn; -sn cs]
    @test rotation * block * rotation' ≈ Float64[1.5 -1.25; 0 1.5]
    @test rotation' * rotation ≈ I

    # The shift selection must agree: single (real) shift equal to the double eigenvalue.
    single, shift = use_single_shift(entries...)
    @test single
    @test shift ≈ 1.5
end
Loading