Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP #940

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

WIP #940

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/block_gmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
ΔX, X, W, V, Z = solver.ΔX, solver.X, solver.W, solver.V, solver.Z
C, D, R, H, τ, stats = solver.C, solver.D, solver.R, solver.H, solver.τ, solver.stats
Ψtmp = C
buffer = solver.buffer
warm_start = solver.warm_start
RNorms = stats.residuals
reset!(stats)
Expand Down Expand Up @@ -198,9 +199,9 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
# Initial Γ and V₁
copyto!(V[1], R₀)
if C isa Matrix
householder!(V[1], Z[1], τ[1])
householder!(V[1], Z[1], τ[1], buffer)
else
householder!(V[1], Z[1], τ[1], solver.tmp)
householder!(V[1], Z[1], τ[1], buffer, solver.tmp)
end

npass = npass + 1
Expand Down Expand Up @@ -242,17 +243,17 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto

# Vₖ₊₁ and Ψₖ₊₁.ₖ are stored in Q and C.
if C isa Matrix
householder!(Q, C, τ[inner_iter])
householder!(Q, C, τ[inner_iter], buffer)
else
householder!(Q, C, τ[inner_iter], solver.tmp)
householder!(Q, C, τ[inner_iter], buffer, solver.tmp)
end

# Update the QR factorization of Hₖ₊₁.ₖ.
# Apply previous Householder reflections Ωᵢ.
for i = 1 : inner_iter-1
D1 .= R[nr+i]
D2 .= R[nr+i+1]
kormqr!('L', trans, H[i], τ[i], D)
kormqr!('L', trans, H[i], τ[i], D, buffer)
R[nr+i] .= D1
R[nr+i+1] .= D2
end
Expand All @@ -261,15 +262,15 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
H[inner_iter][1:p,:] .= R[nr+inner_iter]
H[inner_iter][p+1:2p,:] .= C
if C isa Matrix
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true)
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], buffer, compact=true)
else
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], solver.tmp, compact=true)
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], buffer, solver.tmp, compact=true)
end

# Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁)
D1 .= Z[inner_iter]
D2 .= zero(FC)
kormqr!('L', trans, H[inner_iter], τ[inner_iter], D)
kormqr!('L', trans, H[inner_iter], τ[inner_iter], D, buffer)
Z[inner_iter] .= D1

# Update residual norm estimate.
Expand Down
7 changes: 6 additions & 1 deletion src/block_krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ mutable struct BlockGmresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
H :: Vector{SM}
τ :: Vector{SV}
tmp :: SM
size_buffer:: Vector{Int}
buffer :: SV
warm_start :: Bool
stats :: SimpleStats{T}
end
Expand All @@ -125,8 +127,11 @@ function BlockGmresSolver(m, n, p, memory, SV, SM)
H = SM[SM(undef, 2p, p) for i = 1 : memory]
τ = SV[SV(undef, p) for i = 1 : memory]
tmp = C isa Matrix ? SM(undef, 0, 0) : SM(undef, p, p)
trans = FC <: AbstractFloat ? 'T' : 'C'
size_buffer = Int[kgeqrf_buffer!(V[1], τ[1]), korgqr_buffer!(V[1], τ[1]), kgeqrf_buffer!(H[1], τ[1]), korgqr_buffer!(H[1], τ[1]), kormqr_buffer!('L', trans, H[1], τ[1], D)]
buffer = C isa Matrix ? SV(undef, 0) : SV(undef, size_buffer |> maximum)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, tmp, false, stats)
solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, tmp, size_buffer, buffer, false, stats)
return solver
end

Expand Down
125 changes: 118 additions & 7 deletions src/block_krylov_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import LinearAlgebra.BLAS.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.LAPACK.liblapack

# """
# Q, R = gs(A)
#
Expand Down Expand Up @@ -187,29 +191,136 @@ end
# Output :
# Q an n-by-k orthonormal matrix: QᴴQ = Iₖ
# R an k-by-k upper triangular matrix: QR = A
function householder(A::AbstractMatrix{FC}; compact::Bool=false) where FC <: FloatOrComplex
function householder(A::Matrix{FC}; compact::Bool=false) where FC <: FloatOrComplex
n, k = size(A)
Q = copy(A)
τ = zeros(FC, k)
R = zeros(FC, k, k)
householder!(Q, R, τ; compact)
end

function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}, tmp::AbstractMatrix{FC}; compact::Bool=false) where FC <: FloatOrComplex
function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}; compact::Bool=false) where FC <: FloatOrComplex
n, k = size(Q)
kfill!(R, zero(FC))
kgeqrf!(Q, τ)
copyto!(tmp, view(Q, 1:k, 1:k))
copy_triangle(tmp, R, k)
copy_triangle(Q, R, k)
!compact && korgqr!(Q, τ)
return Q, R
end

function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}; compact::Bool=false) where FC <: FloatOrComplex
function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}, buffer::AbstractVector{FC}; compact::Bool=false) where FC <: FloatOrComplex
n, k = size(Q)
kfill!(R, zero(FC))
kgeqrf!(Q, τ)
kgeqrf!(Q, τ, buffer)
copy_triangle(Q, R, k)
!compact && korgqr!(Q, τ)
!compact && korgqr!(Q, τ, buffer)
return Q, R
end

function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}, buffer::AbstractVector{FC}, tmp::AbstractMatrix{FC}; compact::Bool=false) where FC <: FloatOrComplex
n, k = size(Q)
kfill!(R, zero(FC))
kgeqrf!(Q, τ, buffer)
copyto!(tmp, view(Q, 1:k, 1:k))
copy_triangle(tmp, R, k)
!compact && korgqr!(Q, τ, buffer)
return Q, R
end

for (Xgeqrf, Xorgqr, Xormqr, T) in ((:sgeqrf_, :sorgqr_, :sormqr_, :Float32),
(:dgeqrf_, :dorgqr_, :dormqr_, :Float64),
(:cgeqrf_, :cungqr_, :cunmqr_, :ComplexF32),
(:zgeqrf_, :zungqr_, :zunmqr_, :ComplexF64))
@eval begin
function kgeqrf_buffer!(A::Matrix{$T}, tau::Vector{$T})
symb = @blasfunc($Xgeqrf)
m, n = size(A)
work = Ref{$T}(0)
lwork = Ref{BlasInt}(-1)
info = Ref{BlasInt}()
lda = max(1, stride(A,2))
@ccall liblapack.dgeqrf_64_(m::Ref{BlasInt}, n::Ref{BlasInt}, A::Ptr{$T},
lda::Ref{BlasInt}, tau::Ptr{$T}, work::Ptr{$T},
lwork::Ref{BlasInt}, info::Ptr{BlasInt})::Cvoid
return work[] |> Int
end

function kgeqrf!(A::Matrix{$T}, tau::Vector{$T}, work::Vector{$T})
symb = @blasfunc($Xgeqrf)
m, n = size(A)
lwork = Ref{BlasInt}(length(work))
info = Ref{BlasInt}()
lda = max(1, stride(A,2))
@ccall liblapack.dgeqrf_64_(m::Ref{BlasInt}, n::Ref{BlasInt}, A::Ptr{$T},
lda::Ref{BlasInt}, tau::Ptr{$T}, work::Ptr{$T},
lwork::Ref{BlasInt}, info::Ptr{BlasInt})::Cvoid
return nothing
end

function korgqr_buffer!(A::Matrix{$T}, tau::Vector{$T})
symb = @blasfunc($Xorgqr)
m, n = size(A)
k = length(tau)
work = Ref{$T}(0)
lwork = Ref{BlasInt}(-1)
info = Ref{BlasInt}()
lda = max(1, stride(A,2))
info = Ref{BlasInt}()
@ccall liblapack.dorgqr_64_(m::Ref{BlasInt}, n::Ref{BlasInt}, k::Ref{BlasInt},
A::Ptr{$T}, lda::Ref{BlasInt}, tau::Ptr{$T}, work::Ptr{$T},
lwork::Ref{BlasInt}, info::Ptr{BlasInt})::Cvoid
return work[] |> Int
end

function korgqr!(A::Matrix{$T}, tau::Vector{$T}, work::Vector{$T})
symb = @blasfunc($Xorgqr)
m, n = size(A)
k = length(tau)
lwork = Ref{BlasInt}(length(work))
info = Ref{BlasInt}()
lda = max(1, stride(A,2))
info = Ref{BlasInt}()
@ccall liblapack.dorgqr_64_(m::Ref{BlasInt}, n::Ref{BlasInt}, k::Ref{BlasInt},
A::Ptr{$T}, lda::Ref{BlasInt}, tau::Ptr{$T}, work::Ptr{$T},
lwork::Ref{BlasInt}, info::Ptr{BlasInt})::Cvoid
return nothing
end

function kormqr_buffer!(side::Char, trans::Char, A::Matrix{$T}, tau::Vector{$T}, C::Matrix{$T})
symb = @blasfunc($Xormqr)
m, n = size(A)
k = length(tau)
work = Ref{$T}(0)
lwork = Ref{BlasInt}(-1)
info = Ref{BlasInt}()
lda = max(1, stride(A,2))
ldc = max(1, stride(C,2))
@ccall liblapack.dormqr_64_(side::Ref{UInt8}, trans::Ref{UInt8}, m::Ref{BlasInt},
n::Ref{BlasInt}, k::Ref{BlasInt}, A::Ptr{$T},
lda::Ref{BlasInt}, tau::Ptr{$T}, C::Ptr{$T},
ldc::Ref{BlasInt}, work::Ptr{$T}, lwork::Ref{BlasInt},
info::Ref{BlasInt}, 1::Clong, 1::Clong)::Cvoid
return work[] |> BlasInt
end

function kormqr!(side::Char, trans::Char, A::Matrix{$T}, tau::Vector{$T}, C::Matrix{$T}, work::Vector{$T})
symb = @blasfunc($Xormqr)
m, n = size(A)
k = length(tau)
lwork = Ref{BlasInt}(length(work))
info = Ref{BlasInt}()
lda = max(1, stride(A,2))
ldc = max(1, stride(C,2))
@ccall liblapack.dormqr_64_(side::Ref{UInt8}, trans::Ref{UInt8}, m::Ref{BlasInt},
n::Ref{BlasInt}, k::Ref{BlasInt}, A::Ptr{$T},
lda::Ref{BlasInt}, tau::Ptr{$T}, C::Ptr{$T},
ldc::Ref{BlasInt}, work::Ptr{$T}, lwork::Ref{BlasInt},
info::Ref{BlasInt}, 1::Clong, 1::Clong)::Cvoid
return nothing
end
end
end

kgeqrf!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}, buffer:: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.geqrf!(A, tau)
korgqr!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}, buffer:: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.orgqr!(A, tau)
kormqr!(side :: Char, trans :: Char, A :: AbstractMatrix{T}, tau :: AbstractVector{T}, C :: AbstractMatrix{T}, buffer:: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.ormqr!(side, trans, A, tau, C)
54 changes: 54 additions & 0 deletions test/test_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

A = FC.(get_div_grad(18, 18, 18)) # Dimension m x n
m,n = size(A)
p = 5
k = div(n, 2)
Au = A[1:k,:] # Dimension k x n
Ao = A[:,1:k] # Dimension m x k
b = Ao * ones(FC, k) # Dimension m
c = Au * ones(FC, n) # Dimension k
B = A * Matrix{FC}(I, m, p) # Dimension m × p
mem = 200

T = real(FC)
Expand Down Expand Up @@ -636,6 +638,58 @@
inplace_gpmr_bytes = @allocated gpmr!(solver, Ao, Au, b, c)
@test inplace_gpmr_bytes == 0
end

@testset "BLOCK-GMRES" begin
# BLOCK-GMRES needs:
# - 2 (n*p)-matrices: X, W
# - 1 (p*p)-matrix: C
# - 1 (2p*p)-matrix: D
# - mem p-vectors: τ
# - mem (n*p)-matrices: V
# - mem (p*p)-matrices: Z
# - mem*(mem+1)/2 (p*p)-matrices: R
# - mem (2p*p)-matrices: H
function storage_block_gmres_bytes(mem, n, p)
res = (2*n*p + p*p + 2p*p + mem*p + mem*n*p + mem*p*p + mem*(mem+1)*p*p/2 + mem*2p*p)
return nbits_FC * res
end

expected_block_gmres_bytes = storage_block_gmres_bytes(mem, n, p)
block_gmres(A, B, memory=mem, itmax=mem) # warmup
actual_block_gmres_bytes = @allocated block_gmres(A, B, memory=mem, itmax=mem)
@test expected_block_gmres_bytes ≤ actual_block_gmres_bytes ≤ 1.02 * expected_block_gmres_bytes

solver = BlockGmresSolver(A, B, mem)
block_gmres!(solver, A, B) # warmup
inplace_block_gmres_bytes = @allocated block_gmres!(solver, A, B)
@test inplace_block_gmres_bytes == 0
end

# @testset "BLOCK-MINRES" begin
# # BLOCK-MINRES needs:
# # - 2 (n*p)-matrices: X, W
# # - 1 (p*p)-matrix: C
# # - 1 (2p*p)-matrix: D
# # - mem p-vectors: τ
# # - mem (n*p)-matrices: V
# # - mem (p*p)-matrices: Z
# # - mem*(mem+1)/2 (p*p)-matrices: R
# # - mem (2p*p)-matrices: H
# function storage_block_minres_bytes(mem, n, p)
# res = (2*n*p + p*p + 2p*p + mem*p + mem*n*p + mem*p*p + mem*(mem+1)*p*p/2 + mem*2p*p)
# return nbits_FC * res
# end

# expected_block_minres_bytes = storage_block_minres_bytes(mem, n, p)
# block_minres(A, B) # warmup
# actual_block_minres_bytes = @allocated block_minres(A, B)
# @test expected_block_minres_bytes ≤ actual_block_minres_bytes ≤ 1.02 * expected_block_minres_bytes

# solver = BlockGmresSolver(A, B, mem)
# block_minres!(solver, A, B) # warmup
# inplace_block_minres_bytes = @allocated block_minres!(solver, A, B)
# @test inplace_block_minres_bytes == 0
# end
end
end
end
2 changes: 2 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Krylov

include("get_div_grad.jl")
include("gen_lsq.jl")
include("check_min_norm.jl")
Expand Down
Loading
Loading