From ab9b7a5af3f01e63b59434e4bf5f7306c1f53732 Mon Sep 17 00:00:00 2001 From: Maxence Gollier Date: Thu, 19 Dec 2024 16:48:50 +0100 Subject: [PATCH] add transpose keyword argument for qrm_min_norm_semi_normal --- src/QRMumps.jl | 1 + src/utils.jl | 61 ++++++++++++++++++++++++++++-------------------- test/test_qrm.jl | 17 ++++++++++++++ 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/src/QRMumps.jl b/src/QRMumps.jl index 942fae6..18e4905 100644 --- a/src/QRMumps.jl +++ b/src/QRMumps.jl @@ -451,6 +451,7 @@ Remark that the Q-factor is not used in this sequence but rather A and R. * `x`: the solution vector(s). * `Δx`: an auxiliary vector (or matrix if x and b are matrices) used to compute the solution, the size of this vector (resp. matrix) is the same as x. * `y`: an auxiliary vector (or matrix if x and b are matrices) used to compute the solution, the size of this vector (resp. matrix) is the same as b. +* `transp`: whether to use A, Aᵀ or Aᴴ. Can be either `'t'`, `'c'` or `'n'`. """ function qrm_min_norm_semi_normal! end diff --git a/src/utils.jl b/src/utils.jl index fe48682..40b9d86 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -58,54 +58,65 @@ function qrm_refine!(spmat :: qrm_spmat{T}, spfct :: qrm_spfct{T}, x :: Abstract end function qrm_min_norm_semi_normal(spmat :: qrm_spmat{T}, b :: AbstractVecOrMat{T}; transp :: Char = 'n') where T + + n = transp == 'n' ? spmat.mat.n : spmat.mat.m + spfct = qrm_spfct_init(spmat) qrm_set(spfct, "qrm_keeph", 0) if typeof(b) <: AbstractVector{T} - x = similar(b, spmat.mat.n) + x = similar(b, n) else - x = similar(b, (spmat.mat.n, size(b, 2))) + x = similar(b, (n, size(b, 2))) end Δx = similar(x) y = similar(b) - qrm_min_norm_semi_normal!(spmat, spfct, b, x, Δx, y) + qrm_min_norm_semi_normal!(spmat, spfct, b, x, Δx, y, transp = transp) return x end function qrm_min_norm_semi_normal!(spmat :: qrm_spmat{T}, spfct :: qrm_spfct{T}, b :: AbstractVector{T}, x :: AbstractVector{T}, Δx :: AbstractVector{T}, y :: AbstractVector{T}; transp :: Char = 'n') where T - @assert length(x) == spmat.mat.n - @assert length(b) == spmat.mat.m - @assert length(Δx) == spmat.mat.n - @assert length(y) == spmat.mat.m - - transp = T <: Real ? 't' : 'c' + + n = transp == 'n' ? spmat.mat.n : spmat.mat.m + m = transp == 'n' ? spmat.mat.m : spmat.mat.n + t = T <: Real ? 't' : 'c' + ntransp = transp == 't' || transp == 'c' ? 'n' : t + + @assert length(x) == n + @assert length(b) == m + @assert length(Δx) == n + @assert length(y) == m - qrm_analyse!(spmat, spfct, transp = transp) - qrm_factorize!(spmat, spfct, transp = transp) - qrm_solve!(spfct, b, Δx, transp = transp) + qrm_analyse!(spmat, spfct, transp = ntransp) + qrm_factorize!(spmat, spfct, transp = ntransp) + qrm_solve!(spfct, b, Δx, transp = t) qrm_solve!(spfct, Δx, y, transp = 'n') #x = A^T y - qrm_spmat_mv!(spmat, T(1), y, T(0), x, transp = transp) + qrm_spmat_mv!(spmat, T(1), y, T(0), x, transp = ntransp) end function qrm_min_norm_semi_normal!(spmat :: qrm_spmat{T}, spfct :: qrm_spfct{T}, b :: AbstractMatrix{T}, x :: AbstractMatrix{T}, Δx :: AbstractMatrix{T}, y :: AbstractMatrix{T}; transp :: Char = 'n') where T - @assert size(x, 1) == spmat.mat.n - @assert size(b, 1) == spmat.mat.m - @assert size(Δx, 1) == spmat.mat.n - @assert size(y, 1) == spmat.mat.m + + n = transp == 'n' ? spmat.mat.n : spmat.mat.m + m = transp == 'n' ? spmat.mat.m : spmat.mat.n + t = T <: Real ? 't' : 'c' + ntransp = transp == 't' || transp == 'c' ? 'n' : t + + @assert size(x, 1) == n + @assert size(b, 1) == m + @assert size(Δx, 1) == n + @assert size(y, 1) == m @assert size(x, 2) == size(b, 2) @assert size(Δx, 2) == size(b, 2) @assert size(y, 2) == size(b, 2) - transp = T <: Real ? 't' : 'c' - - qrm_analyse!(spmat, spfct, transp = transp) - qrm_factorize!(spmat, spfct, transp = transp) - qrm_solve!(spfct, b, Δx, transp = transp) + qrm_analyse!(spmat, spfct, transp = ntransp) + qrm_factorize!(spmat, spfct, transp = ntransp) + qrm_solve!(spfct, b, Δx, transp = t) qrm_solve!(spfct, Δx, y, transp = 'n') #x = A^T y - qrm_spmat_mv!(spmat, T(1), y, T(0), x, transp = transp) + qrm_spmat_mv!(spmat, T(1), y, T(0), x, transp = ntransp) end function qrm_least_squares_semi_normal(spmat :: qrm_spmat{T}, b :: AbstractVecOrMat{T}; transp :: Char = 'n') where T @@ -148,7 +159,7 @@ function qrm_least_squares_semi_normal!(spmat :: qrm_spmat{T}, spfct :: qrm_spfc qrm_solve!(spfct, z, y, transp = t) qrm_solve!(spfct, y, x, transp = 'n') - qrm_refine!(spmat, spfct, x, z, Δx, y, transp = transp) + qrm_refine!(spmat, spfct, x, z, Δx, y) end function qrm_least_squares_semi_normal!(spmat :: qrm_spmat{T}, spfct :: qrm_spfct{T}, b :: AbstractMatrix{T}, x :: AbstractMatrix{T}, z :: AbstractMatrix{T}, Δx :: AbstractMatrix{T}, y :: AbstractMatrix{T}; transp :: Char = 'n') where T @@ -176,5 +187,5 @@ function qrm_least_squares_semi_normal!(spmat :: qrm_spmat{T}, spfct :: qrm_spfc qrm_solve!(spfct, z, y, transp = t) qrm_solve!(spfct, y, x, transp = 'n') - qrm_refine!(spmat, spfct, x, z, Δx, y, transp = transp) + qrm_refine!(spmat, spfct, x, z, Δx, y) end diff --git a/test/test_qrm.jl b/test/test_qrm.jl index a1dc620..e9f74e8 100644 --- a/test/test_qrm.jl +++ b/test/test_qrm.jl @@ -198,10 +198,19 @@ end for I in (Int32 , Int64) A = sprand(T, n, m, 0.3) A = convert(SparseMatrixCSC{T,I}, A) + + A_transp = sprand(T, m, n, 0.3) + A_transp = convert(SparseMatrixCSC{T,I}, A_transp) + b = rand(T, n) B = rand(T, n, p) + spmat = qrm_spmat_init(A) spfct = qrm_analyse(spmat, transp=transp) + + spmat_transp = qrm_spmat_init(T) + qrm_spmat_init!(spmat_transp, A_transp) + qrm_factorize!(spmat, spfct, transp=transp) spfct2 = (T <: Real) ? Transpose(spfct) : Adjoint(spfct) @@ -268,6 +277,10 @@ end r = b - A * x @test norm(r) ≤ tol + x = qrm_min_norm_semi_normal(spmat_transp, b, transp = transp) + r = b - A_transp' * x + @test norm(r) ≤ tol + X = qrm_min_norm(spmat, B) R = B - A * X @test norm(R) ≤ tol @@ -276,6 +289,10 @@ end R = B - A * X @test norm(R) ≤ tol + X = qrm_min_norm_semi_normal(spmat_transp, B, transp = transp) + R = B - A_transp' * X + @test norm(R) ≤ tol + qrm_min_norm!(spmat, b, x) r = b - A * x @test norm(r) ≤ tol