diff --git a/Project.toml b/Project.toml index 6f1e02dd..43aa719e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.8.4" [deps] IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/ADMM.jl b/src/ADMM.jl index b63348f8..f4cbc5fe 100644 --- a/src/ADMM.jl +++ b/src/ADMM.jl @@ -11,9 +11,11 @@ mutable struct ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT} <: AbstractLinearSolver β_y::vecT # fields for primal & dual variables x::vecT + xᵒˡᵈ::vecT z::Vector{vecT} zᵒˡᵈ::Vector{vecT} u::Vector{vecT} + uᵒˡᵈ::Vector{vecT} # other parameters precon::preconT ρ::rvecT # TODO: Switch all these vectors to Tuple @@ -27,12 +29,13 @@ mutable struct ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT} <: AbstractLinearSolver ɛᵖʳⁱ::rvecT ɛᵈᵘᵃ::rvecT σᵃᵇˢ::rT + Δ::rvecT absTol::rT relTol::rT tolInner::rT normalizeReg::Bool regFac::rT - vary_ρ::Bool + vary_ρ::Symbol verbose::Bool end @@ -70,10 +73,10 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg=nothing, reg , relTol::Real=eps(real(T)) , tolInner::Real=1e-5 , normalizeReg::Bool=false - , vary_ρ::Bool=false + , vary_ρ::Symbol=:none , verbose::Bool=false , kargs...) where {T,matT,opT} -# TODO: The constructor is not type stable + # TODO: The constructor is not type stable # unify Floating types if typeof(ρ) <: Number @@ -96,10 +99,13 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg=nothing, reg regTrafo = [opEye(eltype(x),size(A,2)) for _=1:length(reg)] end + xᵒˡᵈ = similar(x) + # fields for primal & dual variables z = [similar(x, size(regTrafo[i],1)) for i=1:length(reg)] zᵒˡᵈ = [similar(z[i]) for i=1:length(reg)] u = [similar(z[i]) for i=1:length(reg)] + uᵒˡᵈ = [similar(u[i]) for i=1:length(reg)] # operator and fields for the update of x if AᴴA == nothing @@ -117,10 +123,10 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg=nothing, reg sᵏ = similar(rᵏ) ɛᵖʳⁱ = similar(rᵏ) ɛᵈᵘᵃ = similar(rᵏ) + Δ = similar(rᵏ) - - return ADMM(A,reg,regTrafo,AᴴA,β,β_y,x,z,zᵒˡᵈ,u,precon,ρ_vec,iterations - ,iterationsInner,statevars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(real(T)),absTol,relTol,tolInner + return ADMM(A,reg,regTrafo,AᴴA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,ρ_vec,iterations + ,iterationsInner,statevars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(real(T)),Δ,absTol,relTol,tolInner ,normalizeReg,one(real(T)), vary_ρ, verbose) end @@ -154,8 +160,7 @@ function init!(solver::ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT}, b::vecT # primal and dual variables for i=1:length(solver.reg) - solver.z[i][:] .= solver.regTrafo[i]*solver.x - solver.zᵒˡᵈ[i][:] .= 0 + solver.z[i] .= solver.regTrafo[i]*solver.x solver.u[i] .= 0 end @@ -168,6 +173,7 @@ function init!(solver::ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT}, b::vecT solver.ɛᵖʳⁱ .= 0 solver.ɛᵈᵘᵃ .= 0 solver.σᵃᵇˢ = sqrt(length(b))*solver.absTol + solver.Δ .= Inf # normalization of regularization parameters if solver.normalizeReg @@ -218,28 +224,31 @@ performs one ADMM iteration. """ function iterate(solver::ADMM, iteration::Integer=0) if done(solver, iteration) return nothing end + solver.verbose && println("Outer ADMM Iteration #$iteration") # 1. 
solve arg min_x 1/2|| Ax-b ||² + ρ/2 Σ_i||Φi*x+ui-zi||²
  #    <=> (A'A+ρ Σ_i Φi'Φi)*x = A'b+ρΣ_i Φi'(zi-ui)
-  copyto!(solver.β, solver.β_y)
+  solver.β .= solver.β_y
   AᴴA = solver.AᴴA
   for i=1:length(solver.reg)
     solver.β[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.z[i].-solver.u[i])
     AᴴA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
   end
   solver.verbose && println("conjugate gradients: ")
+  solver.xᵒˡᵈ .= solver.x
   cg!(solver.x, AᴴA, solver.β, Pl=solver.precon
       , maxiter=solver.iterationsInner, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose)

   for i=1:length(solver.reg)
     # 2. update z using the proximal map of 1/ρ*g(x)
-    copyto!(solver.zᵒˡᵈ[i], solver.z[i])
+    solver.zᵒˡᵈ[i] .= solver.z[i]
     solver.z[i] .= solver.regTrafo[i]*solver.x .+ solver.u[i]
     if solver.ρ[i] != 0
       solver.reg[i].prox!(solver.z[i], solver.regFac*solver.reg[i].λ/solver.ρ[i]; solver.reg[i].params...)
     end

     # 3. update u
+    solver.uᵒˡᵈ[i] .= solver.u[i]
     solver.u[i] .+= solver.regTrafo[i]*solver.x .- solver.z[i]

     # update convergence measures (one for each constraint)
@@ -249,20 +258,30 @@ function iterate(solver::ADMM, iteration::Integer=0)
     solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i]))
     solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i])

-    if solver.verbose
-      println("rᵏ[$i] = $(solver.rᵏ[i])")
-      println("sᵏ[$i] = $(solver.sᵏ[i])")
-    end
+    Δᵒˡᵈ = solver.Δ[i]
+    solver.Δ[i] = norm(solver.x    .- solver.xᵒˡᵈ   ) +
+                  norm(solver.z[i] .- solver.zᵒˡᵈ[i]) +
+                  norm(solver.u[i] .- solver.uᵒˡᵈ[i])

-    # adapt ρ according to Boyd et al.
-    if solver.vary_ρ && solver.rᵏ[i] > 10solver.sᵏ[i]
+    if (solver.vary_ρ == :balance && solver.rᵏ[i]/solver.ɛᵖʳⁱ[i] > 10solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i]) || # adapt ρ according to Boyd et al.
+       (solver.vary_ρ == :PnP && solver.Δ[i]/Δᵒˡᵈ > 0.9) # adapt ρ according to Chang et al.
       solver.ρ[i] *= 2
       solver.u[i] ./= 2
-      solver.verbose && println("updated ρ[$i] = $(solver.ρ[i])")
-    elseif solver.vary_ρ && solver.sᵏ[i] > 10solver.rᵏ[i]
+    elseif solver.vary_ρ == :balance && solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i] > 10solver.rᵏ[i]/solver.ɛᵖʳⁱ[i]
       solver.ρ[i] /= 2
       solver.u[i] .*= 2
-      solver.verbose && println("updated ρ[$i] = $(solver.ρ[i])")
+    end
+
+    if solver.verbose
+      println("rᵏ[$i] = $(solver.rᵏ[i])")
+      println("sᵏ[$i] = $(solver.sᵏ[i])")
+      println("ɛᵖʳⁱ[$i] = $(solver.ɛᵖʳⁱ[i])")
+      println("ɛᵈᵘᵃ[$i] = $(solver.ɛᵈᵘᵃ[i])")
+      println("Δᵒˡᵈ = $(Δᵒˡᵈ)")
+      println("Δ[$i] = $(solver.Δ[i])")
+      println("Δ/Δᵒˡᵈ = $(solver.Δ[i]/Δᵒˡᵈ)")
+      println("current ρ[$i] = $(solver.ρ[i])")
+      flush(stdout)
+    end
   end

diff --git a/src/RegularizedLeastSquares.jl b/src/RegularizedLeastSquares.jl
index 3135f9f8..1a8482b1 100644
--- a/src/RegularizedLeastSquares.jl
+++ b/src/RegularizedLeastSquares.jl
@@ -9,6 +9,7 @@ using IterativeSolvers
 using Random
 using VectorizationBase
 using VectorizationBase: shufflevector, zstridedpointer
+using Polyester
 #@reexport using SparsityOperators
 using SparsityOperators: normalOperator, opEye
 using ProgressMeter
@@ -81,7 +82,7 @@ Function returns the chosen solver.
 * `"fista"` - Fast Iterative Shrinkage-Thresholding Algorithm
 * `"admm"` - Alternating Direction Method of Multipliers
 * `"splitBregman"` - Split Bregman method for constrained & regularized inverse problems
-* `"primaldualsolver"`- First order primal dual method
+* `"primaldualsolver"` - First order primal dual method
 """
 function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A,2));
                             log::Bool=false, kargs...)
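For context on the `vary_ρ` change in src/ADMM.jl above: `:balance` follows the residual-balancing rule of Boyd et al., comparing the relative primal residual rᵏ/ɛᵖʳⁱ against the relative dual residual sᵏ/ɛᵈᵘᵃ and keeping them within a factor of 10 of each other, while `:PnP` doubles ρ whenever the combined iterate change Δ = ‖x-xᵒˡᵈ‖ + ‖z-zᵒˡᵈ‖ + ‖u-uᵒˡᵈ‖ has not contracted by at least 10% (Δ/Δᵒˡᵈ > 0.9). In both branches u is rescaled together with ρ, which leaves the unscaled dual variable ρ·u unchanged. A minimal free-standing sketch of the rule (the function and its argument names are illustrative and mirror the solver fields; they are not part of this PR):

    # sketch only: r, s are the primal/dual residual norms, ɛᵖʳⁱ, ɛᵈᵘᵃ their tolerances,
    # Δ, Δᵒˡᵈ the combined iterate change at the current/previous iteration
    function adapt_ρ(ρ, u, mode::Symbol, r, s, ɛᵖʳⁱ, ɛᵈᵘᵃ, Δ, Δᵒˡᵈ)
        if (mode == :balance && r/ɛᵖʳⁱ > 10s/ɛᵈᵘᵃ) || # primal residual dominates (Boyd et al.)
           (mode == :PnP && Δ/Δᵒˡᵈ > 0.9)             # iterates no longer contract
            ρ *= 2
            u ./= 2     # keep the unscaled dual ρ*u constant
        elseif mode == :balance && s/ɛᵈᵘᵃ > 10r/ɛᵖʳⁱ  # dual residual dominates
            ρ /= 2
            u .*= 2
        end
        return ρ
    end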
@@ -90,7 +91,7 @@ function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A,

   if solver == "kaczmarz"
     return Kaczmarz(A; kargs...)
-  elseif solver == "kaczmarzUpdated"
+  elseif solver == "kaczmarzUpdated"
     return KaczmarzUpdated(A; kargs...)
   elseif solver == "cgnr"
     return CGNR(A, x; kargs...)
diff --git a/src/proximalMaps/ProxLLR.jl b/src/proximalMaps/ProxLLR.jl
index de3f2752..229eb634 100644
--- a/src/proximalMaps/ProxLLR.jl
+++ b/src/proximalMaps/ProxLLR.jl
@@ -1,61 +1,72 @@
-export proxLLR!, normLLR
+export proxLLR!, normLLR, proxLLROverlapping!

 """
-    proxLLR!(x::Vector{T}, λ::Float64=1e-6; kargs...) where T
+    proxLLR!(x::Vector{T}, λ=1e-6; kargs...) where T

 proximal map for LLR regularization using singular value thresholding

 # Arguments
 * `x::Vector{T}` - Vector to apply proximal map to
-* `λ::Float64` - regularization parameter
-* `shape::Tuple{Int}=[]` - dimensions of the image
-* `blockSize::Tuple{Int}=[2;2]` - size of patches to perform singluar value thresholding on
+* `λ` - regularization parameter
+* `shape::NTuple{N,Int}` - dimensions of the image
+* `blockSize::NTuple{N,Int}=ntuple(_ -> 2, N)` - size of patches to perform singular value thresholding on
 * `randshift::Bool=true` - randomly shifts the patches to ensure translation invariance
 """
-function proxLLR!(x::Vector{T}, λ; shape::NTuple{N,TI}=error(),
-    blockSize::NTuple{N,TI}=ntuple(_-> 2, N), randshift::Bool=true) where {T, N,TI <: Integer}
-
-  x = reshape(x, tuple(shape..., length(x) ÷ prod(shape)))
-
-  block_idx = CartesianIndices(blockSize)
-  K = size(x)[end]
-
-  if randshift
-    # Random.seed!(1234)
-    shift_idx = (Tuple(rand(block_idx))..., 0)
-    xs = circshift(x, shift_idx)
-  else
-    xs = x
-  end
-
-  ext = mod.(shape,blockSize)
-  pad = mod.(blockSize .- ext, blockSize)
-  if any(pad .!= 0)
-    xp = zeros(T, (shape .+ pad)..., K)
-    xp[CartesianIndices(x)] .= xs
-  else
-    xp = xs
-  end
-
-  xᴸᴸᴿ = Array{T}(undef, prod(blockSize), K)
-  for i ∈ CartesianIndices(StepRange.(0, blockSize, shape .- 1))
-    @views xᴸᴸᴿ .= reshape(xp[i .+ block_idx,:], :, K)
-    # threshold singular values
-    SVDec = svd!(xᴸᴸᴿ)
-    proxL1!(SVDec.S,λ)
-    xp[i .+ block_idx,:] .= reshape(SVDec.U * Diagonal(SVDec.S) * SVDec.Vt, blockSize..., :)
-  end
-
-  if any(pad .!= 0)
-    xs .= xp[CartesianIndices(xs)]
-  end
-
-  if randshift
-    x .= circshift(xs, -1 .* shift_idx)
-  end
-
-  x = vec(x)
-  return x
+function proxLLR!(
+    x::Vector{T},
+    λ;
+    shape::NTuple{N,TI},
+    blockSize::NTuple{N,TI} = ntuple(_ -> 2, N),
+    randshift::Bool = true,
+) where {T,N,TI<:Integer}
+
+    x = reshape(x, tuple(shape..., length(x) ÷ prod(shape)))
+
+    block_idx = CartesianIndices(blockSize)
+    K = size(x)[end]
+
+    if randshift
+        # Random.seed!(1234)
+        shift_idx = (Tuple(rand(block_idx))..., 0)
+        xs = circshift(x, shift_idx)
+    else
+        xs = x
+    end
+
+    ext = mod.(shape, blockSize)
+    pad = mod.(blockSize .- ext, blockSize)
+    if any(pad .!= 0)
+        xp = zeros(T, (shape .+ pad)..., K)
+        xp[CartesianIndices(x)] .= xs
+    else
+        xp = xs
+    end
+
+    bthreads = BLAS.get_num_threads()
+    try
+        BLAS.set_num_threads(1)
+        xᴸᴸᴿ = [Array{T}(undef, prod(blockSize), K) for _ = 1:Threads.nthreads()]
+        @batch for i ∈ CartesianIndices(StepRange.(TI(0), blockSize, shape .- 1))
+            @views xᴸᴸᴿ[Threads.threadid()] .= reshape(xp[i.+block_idx, :], :, K)
+            # threshold singular values
+            SVDec = svd!(xᴸᴸᴿ[Threads.threadid()])
+            proxL1!(SVDec.S, λ)
+            xp[i.+block_idx, :] .= reshape(SVDec.U * Diagonal(SVDec.S) * SVDec.Vt, blockSize..., :)
+        end
+    finally
+        BLAS.set_num_threads(bthreads)
+    end
+
+    if any(pad .!= 0)
+        xs .= xp[CartesianIndices(xs)]
+    end
+
+    if randshift
+        x .= circshift(xs, -1 .* shift_idx)
+    end
+
+    x = vec(x)
+    return x
 end

 """
@@ -64,48 +75,128 @@ end
returns the value of the LLR-regularization term. Arguments are the same as in `proxLLR!`
 """
-function normLLR(x::Vector{T}, λ::Float64; shape::NTuple{N,TI}, L=1, blockSize::NTuple{N,TI}=ntuple(_-> 2, N), randshift::Bool=true, kargs...) where {N, T, TI <: Integer}
-
-  Nvoxel = prod(shape)
-  K = floor(Int,length(x)/(Nvoxel*L))
-  normᴸᴸᴿ = 0.
-  for i = 1:L
-    normᴸᴸᴿ += blockNuclearNorm(x[(i-1)*Nvoxel*K+1:i*Nvoxel*K], shape; blockSize=blockSize, randshift=randshift, kargs...)
-  end
+function normLLR(
+    x::Vector{T},
+    λ::Float64;
+    shape::NTuple{N,TI},
+    L = 1,
+    blockSize::NTuple{N,TI} = ntuple(_ -> 2, N),
+    randshift::Bool = true,
+    kargs...,
+) where {N,T,TI<:Integer}
+
+    Nvoxel = prod(shape)
+    K = floor(Int, length(x) / (Nvoxel * L))
+    normᴸᴸᴿ = 0.0
+    for i = 1:L
+        normᴸᴸᴿ += blockNuclearNorm(
+            x[(i-1)*Nvoxel*K+1:i*Nvoxel*K],
+            shape;
+            blockSize = blockSize,
+            randshift = randshift,
+            kargs...,
+        )
+    end

-  return λ*normᴸᴸᴿ
+    return λ * normᴸᴸᴿ
 end

-function blockNuclearNorm(x::Vector{T}, shape::NTuple{N,TI}; blockSize::NTuple{N,TI}=ntuple(_-> 2, N),
-    randshift::Bool=true, kargs...) where {N, T, TI <: Integer}
-  x = reshape( x, tuple( shape...,floor(Int64, length(x)/prod(shape)) ) )
+function blockNuclearNorm(
+    x::Vector{T},
+    shape::NTuple{N,TI};
+    blockSize::NTuple{N,TI} = ntuple(_ -> 2, N),
+    randshift::Bool = true,
+    kargs...,
+) where {N,T,TI<:Integer}
+    x = reshape(x, tuple(shape..., floor(Int64, length(x) / prod(shape))))

     Wy = blockSize[1]
     Wz = blockSize[2]

     if randshift
-        srand(1234)
-        shift_idx = [rand(1:Wy) rand(1:Wz) 0]
-        x = circshift(x, shift_idx)
+        Random.seed!(1234)
+        shift_idx = [rand(1:Wy) rand(1:Wz) 0]
+        x = circshift(x, shift_idx)
     end

     ny, nz, K = size(x)

     # reshape into patches
-    L = floor(Int,ny*nz/Wy/Wz) # number of patches, assumes that image dimensions are divisble by the blocksizes
+    L = floor(Int, ny * nz / Wy / Wz) # number of patches; assumes the image dimensions are divisible by the block sizes

-    xᴸᴸᴿ = zeros(T,Wy*Wz,L,K)
-    for i=1:K
-        xᴸᴸᴿ[:,:,i] = im2colDistinct(x[:,:,i], (Wy,Wz))
+    xᴸᴸᴿ = zeros(T, Wy * Wz, L, K)
+    for i = 1:K
+        xᴸᴸᴿ[:, :, i] = im2colDistinct(x[:, :, i], (Wy, Wz))
     end
-    xᴸᴸᴿ = permutedims(xᴸᴸᴿ,[1 3 2])
+    xᴸᴸᴿ = permutedims(xᴸᴸᴿ, [1 3 2])

     # L1-norm of singular values
-    normᴸᴸᴿ = 0.
-    for i = 1:L
-        SVDec = svd(xᴸᴸᴿ[:,:,i])
-        normᴸᴸᴿ += norm(SVDec.S,1)
+    normᴸᴸᴿ = 0.0
+    for i = 1:L
+        SVDec = svd(xᴸᴸᴿ[:, :, i])
+        normᴸᴸᴿ += norm(SVDec.S, 1)
     end

     return normᴸᴸᴿ
 end
+
+
+"""
+    proxLLROverlapping!(x::Vector{T}, λ=1e-6; kargs...) where T
+
+proximal map for LLR regularization with fully overlapping blocks
+
+# Arguments
+* `x::Vector{T}` - Vector to apply proximal map to
+* `λ` - regularization parameter
+* `shape::NTuple{N,Int}` - dimensions of the image
+* `blockSize::NTuple{N,Int}=ntuple(_ -> 2, N)` - size of patches to perform singular value thresholding on
+"""
+function proxLLROverlapping!(
+    x::Vector{T},
+    λ;
+    shape::NTuple{N,TI},
+    blockSize::NTuple{N,TI} = ntuple(_ -> 2, N),
+) where {T,N,TI<:Integer}
+
+    x = reshape(x, tuple(shape..., length(x) ÷ prod(shape)))
+
+    block_idx = CartesianIndices(blockSize)
+    K = size(x)[end]
+
+    ext = mod.(shape, blockSize)
+    pad = mod.(blockSize .- ext, blockSize)
+    if any(pad .!= 0)
+        xp = zeros(T, (shape .+ pad)..., K)
+        xp[CartesianIndices(x)] .= x
+    else
+        xp = copy(x)
+    end
+
+    x .= 0 # from here on x is the output
+
+    bthreads = BLAS.get_num_threads()
+    try
+        BLAS.set_num_threads(1)
+        xᴸᴸᴿ = [Array{T}(undef, prod(blockSize), K) for _ = 1:Threads.nthreads()]
+        for is ∈ block_idx
+            shift_idx = (Tuple(is)..., 0)
+            xs = circshift(xp, shift_idx)
+
+            @batch for i ∈ CartesianIndices(StepRange.(TI(0), blockSize, shape .- 1))
+                @views xᴸᴸᴿ[Threads.threadid()] .= reshape(xs[i.+block_idx, :], :, K)
+
+                # threshold singular values
+                SVDec = svd!(xᴸᴸᴿ[Threads.threadid()])
+                proxL1!(SVDec.S, λ)
+                xs[i.+block_idx, :] .= reshape(SVDec.U * Diagonal(SVDec.S) * SVDec.Vt, blockSize..., :)
+            end
+            x .+= circshift(xs, -1 .* shift_idx)[CartesianIndices(x)]
+        end
+    finally
+        BLAS.set_num_threads(bthreads)
+    end
+
+    x ./= length(block_idx)
+    return vec(x)
+end
diff --git a/test/testProxMaps.jl b/test/testProxMaps.jl
index 1ffea040..dde17dba 100644
--- a/test/testProxMaps.jl
+++ b/test/testProxMaps.jl
@@ -186,6 +186,34 @@ function testLLR(shape=(32,32,80),blockSize=(4,4);σ=0.05)
   @test 0.5*norm(xNoisy-x_llr)^2+normLLR(x_llr,10*σ,shape=shape[1:2],blockSize=blockSize,randshift=false) <= normLLR(xNoisy,10*σ,shape=shape[1:2],blockSize=blockSize,randshift=false)
 end

+function testLLROverlapping(shape=(32,32,80),blockSize=(4,4);σ=0.05)
+  @info "test Overlapping LLR regularization"
+  Random.seed!(1234)
+  x = zeros(ComplexF64,shape);
+  for j=1:div(shape[2],blockSize[2])
+    for i=1:div(shape[1],blockSize[1])
+      ampl = rand()
+      r = rand()
+      for t=1:shape[3]
+        x[(i-1)*blockSize[1]+1:i*blockSize[1],(j-1)*blockSize[2]+1:j*blockSize[2],t] .= ampl*exp.(-r*t)
+      end
+    end
+  end
+  x = vec(x)
+
+  xNoisy = copy(x)
+  σ = sum(abs.(x))/length(x)*σ
+  xNoisy[:] += σ/sqrt(2.0)*(randn(prod(shape))+1im*randn(prod(shape)))
+
+  x_llr = copy(xNoisy)
+  proxLLROverlapping!(x_llr,10*σ,shape=shape[1:2],blockSize=blockSize)
+  @test norm(x - x_llr) <= norm(x - xNoisy)
+  @test norm(x - x_llr) / norm(x) ≈ 0 atol=0.05
+  @info "rel. LLR error: $(norm(x - x_llr)/ norm(x)) vs $(norm(x - xNoisy)/ norm(x))"
+  # check decrease of the objective function
+  @test 0.5*norm(xNoisy-x_llr)^2+normLLR(x_llr,10*σ,shape=shape[1:2],blockSize=blockSize,randshift=false) <= normLLR(xNoisy,10*σ,shape=shape[1:2],blockSize=blockSize,randshift=false)
+end
+
 function testLLR_3D(shape=(32,32,32,80),blockSize=(4,4,4);σ=0.05)
   @info "test LLR 3D regularization"
   Random.seed!(1234)
@@ -225,5 +253,6 @@ end
   testProj()
   testNuclear()
   testLLR()
+  testLLROverlapping()
   testLLR_3D()
 end
diff --git a/test/testSolvers.jl b/test/testSolvers.jl
index cb89e7a3..20c9e226 100644
--- a/test/testSolvers.jl
+++ b/test/testSolvers.jl
@@ -1,112 +1,188 @@
 Random.seed!(12345)

 @testset "Real Linear Solver" begin
-  A = rand(3,2);
-  x = rand(2);
-  b = A*x;
-
-  solvers = linearSolverListReal()
-
-  for solver in solvers
-    solverInfo = SolverInfo(Float64)
-    S = createLinearSolver(solver, A, iterations=100, solverInfo=solverInfo, shape=(2,1))
-    x_approx = solve(S,b)
-    @info "Testing solver $solver ...: $x == $x_approx"
-    @test norm(x - x_approx) / norm(x) ≈ 0 atol=0.1
-  end
+    A = rand(3, 2)
+    x = rand(2)
+    b = A * x
+
+    solvers = linearSolverListReal()
+
+    for solver in solvers
+        solverInfo = SolverInfo(Float64)
+        S = createLinearSolver(
+            solver,
+            A,
+            iterations = 100,
+            solverInfo = solverInfo,
+            shape = (2, 1),
+        )
+        x_approx = solve(S, b)
+        @info "Testing solver $solver ...: $x == $x_approx"
+        @test norm(x - x_approx) / norm(x) ≈ 0 atol = 0.1
+    end
 end

 @testset "Complex Linear Solver" begin
-  A = rand(3,2)+im*rand(3,2);
-  x = rand(2)+im*rand(2);
-  b = A*x;
-
-  solvers = linearSolverList()
-
-  for solver in solvers
-    solverInfo = SolverInfo(ComplexF64)
-    S = createLinearSolver(solver, A, iterations=100, solverInfo=solverInfo)
-    x_approx = solve(S,b)
-    @info "Testing solver $solver ...: $x == $x_approx"
-    @test norm(x - x_approx) / norm(x) ≈ 0 atol=0.1
-  end
+    A = rand(3, 2) + im * rand(3, 2)
+    x = rand(2) + im * rand(2)
+    b = A * x
+
+    solvers = linearSolverList()
+
+    for solver in solvers
+        solverInfo = SolverInfo(ComplexF64)
+        S = createLinearSolver(solver, A, iterations = 100, solverInfo = solverInfo)
+        x_approx = solve(S, b)
+        @info "Testing solver $solver ...: $x == $x_approx"
+        @test norm(x - x_approx) / norm(x) ≈ 0 atol = 0.1
+    end
 end

 @testset "General Convex Solver" begin
-  # fully sampled operator, image and data
-  N = 256
-  numPeaks = 5
-  F = [1 / sqrt(N)*exp(-2im * π *j*k/N) for j=0:N-1, k=0:N-1]
-  x = zeros(N)
-  for i = 1:3
-    x[rand(1:N)] = rand()
-  end
-  b = 1 / sqrt(N)*fft(x)
-
-  # random undersampling
-  idx = sort(unique(rand(1:N, div(N,2))))
-  b = b[idx]
-  F = F[idx,:]
-
-  for solver in ["fista","admm"]
-    reg = Regularization("L1",1e-3)
+    # fully sampled operator, image and data
+    N = 256
+    numPeaks = 5
+    F = [1 / sqrt(N) * exp(-2im * π * j * k / N) for j = 0:N-1, k = 0:N-1]
+    x = zeros(N)
+    for i = 1:3
+        x[rand(1:N)] = rand()
+    end
+    b = 1 / sqrt(N) * fft(x)
+
+    # random undersampling
+    idx = sort(unique(rand(1:N, div(N, 2))))
+    b = b[idx]
+    F = F[idx, :]
+
+    for solver in ["fista", "admm"]
+        reg = Regularization("L1", 1e-3)
+        solverInfo = SolverInfo(ComplexF64)
+        S = createLinearSolver(
solver, + F .* scale_F; + reg = reg, + iterations = 200, + solverInfo = solverInfo, + normalizeReg = true, + ) + x_approx = solve(S, b) + x_approx .*= scale_F + @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" + @test x ≈ x_approx rtol = 0.1 + end + + # test ADMM with option vary_ρ + solver = "admm" + reg = Regularization("L1", 1.e-3) solverInfo = SolverInfo(ComplexF64) - S = createLinearSolver(solver,F; reg=reg, iterations=200, solverInfo=solverInfo, normalizeReg=false) + S = createLinearSolver( + solver, + F; + reg = reg, + iterations = 200, + solverInfo = solverInfo, + normalizeReg = false, + ρ = 1e6, + vary_ρ = :balance, + verbose = false, + ) x_approx = solve(S, b) @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - - reg.λ *= length(b)/norm(b,1) - scale_F = 1e3 # test invariance to the maximum eigenvalue - S = createLinearSolver(solver, F .* scale_F; reg=reg, iterations=200, solverInfo=solverInfo, normalizeReg=true) + @test x ≈ x_approx rtol = 0.1 + + S = createLinearSolver( + solver, + F; + reg = reg, + iterations = 200, + solverInfo = solverInfo, + normalizeReg = false, + ρ = 1e-6, + vary_ρ = :balance, + verbose = false, + ) x_approx = solve(S, b) - x_approx .*= scale_F @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - end - - # test ADMM with option vary_ρ - solver = "admm" - reg = Regularization("L1",1.e-3) - solverInfo = SolverInfo(ComplexF64) - S = createLinearSolver(solver,F; reg=reg, iterations=200, solverInfo=solverInfo, normalizeReg=false, ρ=1e6, vary_ρ=true, verbose=false) - x_approx = solve(S, b) - @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - - S = createLinearSolver(solver,F; reg=reg, iterations=200, solverInfo=solverInfo, normalizeReg=false, ρ=1e-6, vary_ρ=true, verbose=false) - x_approx = solve(S, b) - @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - - + @test x ≈ x_approx rtol = 0.1 + + # the PnP scheme only increases ρ, hence we only test it with a small initial ρ + S = createLinearSolver( + solver, + F; + reg = reg, + iterations = 200, + solverInfo = solverInfo, + normalizeReg = false, + ρ = 1e-6, + vary_ρ = :PnP, + verbose = false, + ) + x_approx = solve(S, b) + @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" + @test x ≈ x_approx rtol = 0.1 - for solver in ["splitBregman"] - reg = Regularization("L1",1.e-3) + ## + solver = "splitBregman" + reg = Regularization("L1", 1.e-3) solverInfo = SolverInfo(ComplexF64) - S = createLinearSolver(solver,F; reg=reg,iterations=5,iterationsInner=40, - ρ=1.0,solverInfo=solverInfo, normalizeReg=false) + S = createLinearSolver( + solver, + F; + reg = reg, + iterations = 5, + iterationsInner = 40, + ρ = 1.0, + solverInfo = solverInfo, + normalizeReg = false, + ) x_approx = solve(S, b) @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - - reg.λ *= length(b)/norm(b,1) - S = createLinearSolver(solver,F; reg=reg,iterations=5,iterationsInner=40, - ρ=1.0,solverInfo=solverInfo, normalizeReg=true) + @test x ≈ x_approx rtol = 0.1 + + reg.λ *= length(b) / norm(b, 1) + S = createLinearSolver( + solver, + F; + reg = reg, + iterations = 5, + iterationsInner = 40, + ρ = 1.0, + solverInfo = solverInfo, + normalizeReg = true, + ) x_approx = solve(S, 
b) @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - end + @test x ≈ x_approx rtol = 0.1 - for solver in ["primaldualsolver"] - reg = [Regularization("L1",1.e-4), Regularization("TV",1.e-4)] + ## + solver = "primaldualsolver" + reg = [Regularization("L1", 1.e-4), Regularization("TV", 1.e-4)] solverInfo = SolverInfo(Float64) - FR = [real.(F./norm(F)); imag.(F./norm(F))] - bR = [real.(b./norm(F)); imag.(b./norm(F))] - S = createLinearSolver(solver,FR; reg=reg, regName=["L1","TV"], iterations=1000, solverInfo=solverInfo) + FR = [real.(F ./ norm(F)); imag.(F ./ norm(F))] + bR = [real.(b ./ norm(F)); imag.(b ./ norm(F))] + S = createLinearSolver( + solver, + FR; + reg = reg, + regName = ["L1", "TV"], + iterations = 1000, + solverInfo = solverInfo, + ) x_approx = solve(S, bR) @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" - @test x ≈ x_approx rtol=0.1 - end - + @test x ≈ x_approx rtol = 0.1 end
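Two closing notes on the changes above. First, the threading pattern shared by `proxLLR!` and `proxLLROverlapping!`: BLAS is pinned to a single thread while Polyester's `@batch` parallelizes over the blocks, each task works in a preallocated per-thread buffer indexed by `Threads.threadid()`, and the `try`/`finally` restores the caller's BLAS thread count even if an `svd!` throws. Without this, every Polyester task would itself spawn BLAS threads and oversubscribe the machine. The pattern in isolation (a sketch; `foreach_block!` and `process_block!` are hypothetical names, not part of this PR):

    using LinearAlgebra, Polyester

    function foreach_block!(process_block!, blocks, worksize, ::Type{T} = Float64) where {T}
        bthreads = BLAS.get_num_threads()
        try
            BLAS.set_num_threads(1)   # one BLAS thread per task: svd! calls into LAPACK
            work = [Vector{T}(undef, worksize) for _ = 1:Threads.nthreads()]
            @batch for k = 1:length(blocks)
                # each task only touches its own workspace, so no locking is needed
                process_block!(work[Threads.threadid()], blocks[k])
            end
        finally
            BLAS.set_num_threads(bthreads)   # restore the caller's setting
        end
        return nothing
    end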
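Second, a quick usage sketch for the exported `proxLLROverlapping!` (illustrative values, cf. `testLLROverlapping` above): unlike `proxLLR!`, which applies a single random block shift per call, the overlapping variant averages the thresholded results over all `prod(blockSize)` shifts, so it is translation invariant by construction, at roughly `prod(blockSize)` times the cost.

    using RegularizedLeastSquares

    shape = (32, 32)                         # spatial dimensions
    K = 80                                   # number of time frames
    x = randn(ComplexF64, prod(shape) * K)   # flattened image series
    λ = 1e-2                                 # illustrative regularization strength
    proxLLROverlapping!(x, λ; shape = shape, blockSize = (4, 4))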