From 045085a284515aa536891763dfbe2db0177f851b Mon Sep 17 00:00:00 2001 From: Andrew Mao Date: Wed, 31 May 2023 11:54:03 -0400 Subject: [PATCH 1/5] implement POGM/FISTA with adaptive restart --- src/FISTA.jl | 13 +- src/POGM.jl | 239 +++++++++++++++++++++++++++++++++ src/RegularizedLeastSquares.jl | 4 + 3 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 src/POGM.jl diff --git a/src/FISTA.jl b/src/FISTA.jl index f58b4c26..1a5d5d9a 100644 --- a/src/FISTA.jl +++ b/src/FISTA.jl @@ -18,6 +18,7 @@ mutable struct FISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVecto norm_x₀::rT rel_res_norm::rT verbose::Bool + restart::Bool end """ @@ -38,6 +39,7 @@ creates a `FISTA` object for the system matrix `A`. * (`t=1.0`) - parameter for predictor-corrector step * (`relTol::Float64=1.e-5`) - tolerance for stopping criterion * (`iterations::Int64=50`) - maximum number of iterations +* (`restart=true`) - toggle whether to use adaptive GR scheme """ function FISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=nothing, regName=["L1"] , AᴴA=A'*A @@ -48,6 +50,7 @@ function FISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=n , relTol=eps(real(T)) , iterations=50 , normalizeReg=false + , restart=false , verbose = false , kargs...) where {T} @@ -65,7 +68,7 @@ function FISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=n ρ /= abs(power_iterations(AᴴA)) end - return FISTA(A, AᴴA, vec(reg)[1], x, x₀, xᵒˡᵈ, res, rT(ρ),rT(t),rT(t),iterations,rT(relTol),normalizeReg,one(rT),one(rT),rT(Inf),verbose) + return FISTA(A, AᴴA, vec(reg)[1], x, x₀, xᵒˡᵈ, res, rT(ρ),rT(t),rT(t),iterations,rT(relTol),normalizeReg,one(rT),one(rT),rT(Inf),verbose,restart) end """ @@ -162,6 +165,14 @@ function iterate(solver::FISTA, iteration::Int=0) # proximal map solver.reg.prox!(solver.x, solver.regFac*solver.ρ*solver.reg.λ; solver.reg.params...) + # gradient restart conditions + if solver.restart + if real(solver.res ⋅ (solver.x .- solver.xᵒˡᵈ) ) > 0 #if momentum is at an obtuse angle to the negative gradient + solver.verbose && println("Gradient restart at iter $iteration") + solver.t = 1 + end + end + # predictor-corrector update solver.tᵒˡᵈ = solver.t solver.t = (1 + sqrt(1 + 4 * solver.tᵒˡᵈ^2)) / 2 diff --git a/src/POGM.jl b/src/POGM.jl new file mode 100644 index 00000000..b19871da --- /dev/null +++ b/src/POGM.jl @@ -0,0 +1,239 @@ +export pogm, POGM + +mutable struct POGM{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA} <: AbstractLinearSolver + A::matA + AᴴA::matAHA + reg::Regularization + x::vecT + x₀::vecT + xᵒˡᵈ::vecT + y::vecT + z::vecT + w::vecT + res::vecT + ρ::rT + t::rT + tᵒˡᵈ::rT + α::rT + β::rT + γ::rT + γᵒˡᵈ::rT + σ::rT + σ_fac::rT + iterations::Int64 + relTol::rT + normalizeReg::Bool + regFac::rT + norm_x₀::rT + rel_res_norm::rT + verbose::Bool + restart::Symbol +end + +""" + POGM(A, x::vecT=zeros(eltype(A),size(A,2)) + ; reg=nothing, regName=["L1"], λ=[0.0], kargs...) + +creates a `POGM` object for the system matrix `A`. +POGM has a 2x better worst-case bound than FISTA, but actual performance varies by application. +It stores 3 extra intermediate variables the size of the image compared to FISTA +Only gradient restart scheme is implemented for now + +References: +- A.B. Taylor, J.M. Hendrickx, F. Glineur, + "Exact worst-case performance of first-order algorithms + for composite convex optimization," Arxiv:1512.07516, 2015, + SIAM J. Opt. 
2017 + [http://doi.org/10.1137/16m108104x] +- Kim, D., & Fessler, J. A. (2018). + Adaptive Restart of the Optimized Gradient Method for Convex Optimization. + Journal of Optimization Theory and Applications, 178(1), 240–263. + [https://doi.org/10.1007/s10957-018-1287-4] + +# Arguments +* `A` - system matrix +* `x::vecT` - array with the same type and size as the solution +* (`reg=nothing`) - regularization object +* (`regName=["L1"]`) - name of the Regularization to use (if reg==nothing) +* (`AᴴA=A'*A`) - specialized normal operator, default is `A'*A` +* (`λ=0`) - regularization parameter +* (`ρ=0.95`) - step size for gradient step +* (`normalize_ρ=false`) - normalize step size by the maximum eigenvalue of `AᴴA` +* (`t=1.0`) - parameter for predictor-corrector step +* (`σ_fac=1.0`) - parameter for decreasing γ-momentum ∈ [0,1] +* (`relTol::Float64=1.e-5`) - tolerance for stopping criterion +* (`iterations::Int64=50`) - maximum number of iterations +* (`restart::Symbol=:none`) - :none, :gradient options for restarting +""" +function POGM(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=nothing, regName=["L1"] + , AᴴA=A'*A + , λ=0 + , ρ=0.95 + , normalize_ρ=true + , t=1 + , σ_fac=1.0 + , relTol=eps(real(T)) + , iterations=50 + , normalizeReg=false + , restart = :none + , verbose = false + , kargs...) where {T} + + rT = real(T) + if reg == nothing + reg = Regularization(regName, λ, kargs...) + end + + x₀ = similar(x) + xᵒˡᵈ = similar(x) + y = similar(x) + z = similar(x) + w = similar(x) + res = similar(x) + res[1] = Inf # avoid spurious convergence in first iterations + + if normalize_ρ + ρ /= abs(power_iterations(AᴴA)) + end + + return POGM(A, AᴴA, vec(reg)[1], x, x₀, xᵒˡᵈ, y, z, w, res, rT(ρ),rT(t),rT(t),rT(0),rT(1),rT(1),rT(1),rT(1),rT(σ_fac), + iterations,rT(relTol),normalizeReg,one(rT),one(rT),rT(Inf),verbose,restart) +end + +""" + init!(it::POGM, b::vecT + ; A=solver.A + , x::vecT=similar(b,0) + , t::Float64=1.0) where T + +(re-) initializes the POGM iterator +""" +function init!(solver::POGM{rT,vecT,matA,matAHA}, b::vecT + ; x::vecT=similar(b,0) + , t=1 + ) where {rT,vecT,matA,matAHA} + + solver.x₀ .= adjoint(solver.A) * b + solver.norm_x₀ = norm(solver.x₀) + + if isempty(x) + solver.x .= 0 + else + solver.x .= x + end + solver.xᵒˡᵈ .= 0 # makes no difference in 1st iteration what this is set to + solver.y .= 0 + solver.z .= 0 + if solver.restart != :none #save memory if not using restart + solver.w .= 0 + end + + solver.t = t + solver.tᵒˡᵈ = t + solver.γ = 1 #doesn't matter + solver.γᵒˡᵈ = 1 #doesn't matter + solver.σ = 1 + # normalization of regularization parameters + if solver.normalizeReg + solver.regFac = norm(solver.x₀,1)/length(solver.x₀) + else + solver.regFac = 1 + end +end + +""" + solve(solver::POGM, b::Vector) + +solves an inverse problem using POGM. + +# Arguments +* `solver::POGM` - the solver containing both system matrix and regularizer +* `b::vecT` - data vector +* `A=solver.A` - operator for the data-term of the problem +* (`startVector::vecT=similar(b,0)`) - initial guess for the solution +* (`solverInfo=nothing`) - solverInfo object + +when a `SolverInfo` objects is passed, the residuals are stored in `solverInfo.convMeas`. +""" +function solve(solver::POGM, b; A=solver.A, startVector=similar(b,0), solverInfo=nothing, kargs...) 
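+  # Illustrative usage (a sketch; `A`, `b`, and the L1 weight are placeholders,
+  # not part of this patch):
+  #   reg = Regularization("L1", 1e-3)
+  #   S   = POGM(A; reg = reg, iterations = 50, restart = :gradient, σ_fac = 0.8)
+  #   x   = solve(S, b)
+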
+ # initialize solver parameters + init!(solver, b; x=startVector) + + # log solver information + solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + + # perform POGM iterations + for (iteration, item) = enumerate(solver) + solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + end + + return solver.x +end + +""" + iterate(it::POGM, iteration::Int=0) + +performs one POGM iteration. +""" +function iterate(solver::POGM, iteration::Int=0) + if done(solver, iteration) return nothing end + + # calculate residuum and do gradient step + # solver.x .-= solver.ρ .* (solver.AᴴA * solver.x .- solver.x₀) + solver.xᵒˡᵈ .= solver.x #save this for inertia step later + mul!(solver.res, solver.AᴴA, solver.x) + solver.res .-= solver.x₀ + solver.x .-= solver.ρ .* solver.res + + solver.rel_res_norm = norm(solver.res) / solver.norm_x₀ + solver.verbose && println("Iteration $iteration; rel. residual = $(solver.rel_res_norm)") + + # inertial parameters + solver.tᵒˡᵈ = solver.t + if iteration == solver.iterations - 1 && solver.restart != :none #the convergence rate depends on choice of # iterations! + solver.t = (1 + sqrt(1 + 8 * solver.tᵒˡᵈ^2)) / 2 + else + solver.t = (1 + sqrt(1 + 4 * solver.tᵒˡᵈ^2)) / 2 + end + solver.α = (solver.tᵒˡᵈ - 1) / solver.t + solver.β = solver.σ * solver.tᵒˡᵈ / solver.t + solver.γᵒˡᵈ = solver.γ + if solver.restart == :gradient + solver.γ = solver.ρ * (1 + solver.α + solver.β) + else + solver.γ = solver.ρ * (2solver.tᵒˡᵈ + solver.t - 1) / solver.t + end + + # inertia steps + # x + α * (x - y) + β * (x - xᵒˡᵈ) + ρα/γᵒˡᵈ * (z - xᵒˡᵈ) + tmp = solver.y + solver.y = solver.x + solver.x = tmp # swap x and y + solver.x .*= -solver.α # here we calculate -α * y, where y is now stored in x + solver.x .+= (1 + solver.α + solver.β) .* solver.y + solver.x .-= (solver.β + solver.ρ * solver.α / solver.γᵒˡᵈ) .* solver.xᵒˡᵈ + solver.x .+= solver.ρ * solver.α / solver.γᵒˡᵈ .* solver.z + solver.z .= solver.x #store this for next iteration and GR + + # proximal map + solver.reg.prox!(solver.x, solver.regFac*solver.reg.λ*solver.γ; solver.reg.params...) + + # gradient restart conditions + if solver.restart == :gradient + if real((solver.y + solver.ρ / solver.γ .* (solver.x .- solver.z) .- solver.w) ⋅ ((solver.x .- solver.z) ./ solver.γ .- solver.res)) < 0 + solver.verbose && println("Gradient restart at iter $iteration") + solver.σ = 1 + solver.t = 1 + else #decreasing γ + solver.σ *= solver.σ_fac + end + solver.w = solver.y + solver.ρ / solver.γ .* (solver.x .- solver.z) #this computation is doubled to avoid having to store wᵒˡᵈ + end + + # return the residual-norm as item and iteration number as state + return solver, iteration+1 +end + +@inline converged(solver::POGM) = (solver.rel_res_norm < solver.relTol) + +@inline done(solver::POGM,iteration) = converged(solver) || iteration>=solver.iterations diff --git a/src/RegularizedLeastSquares.jl b/src/RegularizedLeastSquares.jl index 76bc4ebf..80e3b06a 100644 --- a/src/RegularizedLeastSquares.jl +++ b/src/RegularizedLeastSquares.jl @@ -49,6 +49,7 @@ include("CGNR.jl") include("Direct.jl") include("FusedLasso.jl") include("FISTA.jl") +include("POGM.jl") include("ADMM.jl") include("SplitBregman.jl") include("PrimalDualSolver.jl") @@ -83,6 +84,7 @@ Function returns choosen solver. 
* `"pseudoinverse"` - approximates a solution using the More-Penrose pseudo inverse * `"fusedlasso"` - solver for the Fused-Lasso problem * `"fista"` - Fast Iterative Shrinkage Thresholding Algorithm +* `"pogm"` - Proximal Optimal Gradient Method * `"admm"` - Alternating Direcion of Multipliers Method * `"splitBregman"` - Split Bregman method for constrained & regularized inverse problems * `"primaldualsolver"`- First order primal dual method @@ -110,6 +112,8 @@ function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A, return FusedLasso(A; kargs...) elseif solver == "fista" return FISTA(A, x; kargs...) + elseif solver == "pogm" + return POGM(A, x; kargs...) elseif solver == "admm" return ADMM(A, x; kargs...) elseif solver == "splitBregman" From edfa225abad2f20dc9f08ccfe07c540e2f84283a Mon Sep 17 00:00:00 2001 From: Andrew Mao Date: Wed, 31 May 2023 12:13:34 -0400 Subject: [PATCH 2/5] change restart options to symbol & add tests --- src/FISTA.jl | 8 ++++---- test/testSolvers.jl | 21 +++++++++++++++++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/FISTA.jl b/src/FISTA.jl index 1a5d5d9a..87638a9c 100644 --- a/src/FISTA.jl +++ b/src/FISTA.jl @@ -18,7 +18,7 @@ mutable struct FISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVecto norm_x₀::rT rel_res_norm::rT verbose::Bool - restart::Bool + restart::Symbol end """ @@ -39,7 +39,7 @@ creates a `FISTA` object for the system matrix `A`. * (`t=1.0`) - parameter for predictor-corrector step * (`relTol::Float64=1.e-5`) - tolerance for stopping criterion * (`iterations::Int64=50`) - maximum number of iterations -* (`restart=true`) - toggle whether to use adaptive GR scheme +* (`restart::Symbol=:none`) - :none, :gradient options for restarting """ function FISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=nothing, regName=["L1"] , AᴴA=A'*A @@ -50,7 +50,7 @@ function FISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=n , relTol=eps(real(T)) , iterations=50 , normalizeReg=false - , restart=false + , restart = :none , verbose = false , kargs...) where {T} @@ -166,7 +166,7 @@ function iterate(solver::FISTA, iteration::Int=0) solver.reg.prox!(solver.x, solver.regFac*solver.ρ*solver.reg.λ; solver.reg.params...) 
# gradient restart conditions - if solver.restart + if solver.restart == :gradient if real(solver.res ⋅ (solver.x .- solver.xᵒˡᵈ) ) > 0 #if momentum is at an obtuse angle to the negative gradient solver.verbose && println("Gradient restart at iter $iteration") solver.t = 1 diff --git a/test/testSolvers.jl b/test/testSolvers.jl index 20c9e226..93736009 100644 --- a/test/testSolvers.jl +++ b/test/testSolvers.jl @@ -54,7 +54,7 @@ end b = b[idx] F = F[idx, :] - for solver in ["fista", "admm"] + for solver in ["pogm", "fista", "admm"] reg = Regularization("L1", 1e-3) solverInfo = SolverInfo(ComplexF64) S = createLinearSolver( @@ -69,8 +69,25 @@ end @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" @test x ≈ x_approx rtol = 0.1 + #additionally test the gradient restarting scheme + if solver == "pogm" || solver == "fista" + S = createLinearSolver( + solver, + F; + reg = reg, + iterations = 200, + solverInfo = solverInfo, + normalizeReg = false, + restart = :gradient, + ) + x_approx = solve(S, b) + @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" + @test x ≈ x_approx rtol = 0.1 + end + + # test invariance to the maximum eigenvalue reg.λ *= length(b) / norm(b, 1) - scale_F = 1e3 # test invariance to the maximum eigenvalue + scale_F = 1e3 S = createLinearSolver( solver, F .* scale_F; From 0db04c47b52b9ccea334e6ca4fda2b5a33490519 Mon Sep 17 00:00:00 2001 From: Jakob Asslaender Date: Wed, 31 May 2023 17:15:25 -0400 Subject: [PATCH 3/5] remove some allocations --- src/POGM.jl | 13 +++++++------ test/testSolvers.jl | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/POGM.jl b/src/POGM.jl index b19871da..401238c9 100644 --- a/src/POGM.jl +++ b/src/POGM.jl @@ -45,9 +45,9 @@ References: for composite convex optimization," Arxiv:1512.07516, 2015, SIAM J. Opt. 2017 [http://doi.org/10.1137/16m108104x] -- Kim, D., & Fessler, J. A. (2018). +- Kim, D., & Fessler, J. A. (2018). Adaptive Restart of the Optimized Gradient Method for Convex Optimization. - Journal of Optimization Theory and Applications, 178(1), 240–263. + Journal of Optimization Theory and Applications, 178(1), 240–263. 
[https://doi.org/10.1007/s10957-018-1287-4] # Arguments @@ -124,7 +124,7 @@ function init!(solver::POGM{rT,vecT,matA,matAHA}, b::vecT solver.xᵒˡᵈ .= 0 # makes no difference in 1st iteration what this is set to solver.y .= 0 solver.z .= 0 - if solver.restart != :none #save memory if not using restart + if solver.restart != :none #save time if not using restart solver.w .= 0 end @@ -220,14 +220,15 @@ function iterate(solver::POGM, iteration::Int=0) # gradient restart conditions if solver.restart == :gradient - if real((solver.y + solver.ρ / solver.γ .* (solver.x .- solver.z) .- solver.w) ⋅ ((solver.x .- solver.z) ./ solver.γ .- solver.res)) < 0 + solver.w .+= solver.y .+ solver.ρ ./ solver.γ .* (solver.x .- solver.z) + if real((solver.w ⋅ solver.x - solver.w ⋅ solver.z) / solver.γ - solver.w ⋅ solver.res) < 0 solver.verbose && println("Gradient restart at iter $iteration") solver.σ = 1 solver.t = 1 - else #decreasing γ + else # decreasing γ solver.σ *= solver.σ_fac end - solver.w = solver.y + solver.ρ / solver.γ .* (solver.x .- solver.z) #this computation is doubled to avoid having to store wᵒˡᵈ + solver.w .= solver.ρ / solver.γ .* (solver.z .- solver.x) .- solver.y end # return the residual-norm as item and iteration number as state diff --git a/test/testSolvers.jl b/test/testSolvers.jl index 93736009..abdaa686 100644 --- a/test/testSolvers.jl +++ b/test/testSolvers.jl @@ -66,7 +66,7 @@ end normalizeReg = false, ) x_approx = solve(S, b) - @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" + @info "Testing solver $solver w/o restart: relative error = $(norm(x - x_approx) / norm(x))" @test x ≈ x_approx rtol = 0.1 #additionally test the gradient restarting scheme @@ -81,7 +81,7 @@ end restart = :gradient, ) x_approx = solve(S, b) - @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" + @info "Testing solver $solver w/ gradient restart: relative error = $(norm(x - x_approx) / norm(x))" @test x ≈ x_approx rtol = 0.1 end @@ -98,7 +98,7 @@ end ) x_approx = solve(S, b) x_approx .*= scale_F - @info "Testing solver $solver ...: relative error = $(norm(x - x_approx) / norm(x))" + @info "Testing solver $solver w/o restart and after re-scaling: relative error = $(norm(x - x_approx) / norm(x))" @test x ≈ x_approx rtol = 0.1 end From a2641056b35499fdc18d9ff1652a3416a464e6ce Mon Sep 17 00:00:00 2001 From: Andrew Mao Date: Thu, 1 Jun 2023 20:27:50 -0400 Subject: [PATCH 4/5] implement optista --- src/OptISTA.jl | 211 +++++++++++++++++++++++++++++++++ src/RegularizedLeastSquares.jl | 3 + 2 files changed, 214 insertions(+) create mode 100644 src/OptISTA.jl diff --git a/src/OptISTA.jl b/src/OptISTA.jl new file mode 100644 index 00000000..0784af7e --- /dev/null +++ b/src/OptISTA.jl @@ -0,0 +1,211 @@ +export OptISTA + +mutable struct OptISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA} <: AbstractLinearSolver + A::matA + AᴴA::matAHA + reg::Regularization + x::vecT + x₀::vecT + y::vecT + z::vecT + zᵒˡᵈ::vecT + res::vecT + ρ::rT + θ::rT + θᵒˡᵈ::rT + θn::rT + α::rT + β::rT + γ::rT + iterations::Int64 + relTol::rT + normalizeReg::Bool + regFac::rT + norm_x₀::rT + rel_res_norm::rT + verbose::Bool +end + +""" + OptISTA(A, x::vecT=zeros(eltype(A),size(A,2)) + ; reg=nothing, regName=["L1"], λ=[0.0], kargs...) + +creates a `OptISTA` object for the system matrix `A`. +OptISTA has a 2x better worst-case bound than FISTA, but actual performance varies by application. 
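+A minimal construction sketch (`A` and `b` are placeholders, the L1 weight is arbitrary):
+
+    S = OptISTA(A; regName=["L1"], λ=1e-3, iterations=50)
+    x = solve(S, b)
+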
+It stores 3 extra intermediate variables the size of the image compared to FISTA + +Reference: +- Uijeong Jang, Shuvomoy Das Gupta, Ernest K. Ryu, + "Computer-Assisted Design of Accelerated Composite + Optimization Methods: OptISTA," arXiv:2305.15704, 2023, + [https://arxiv.org/abs/2305.15704] + +# Arguments +* `A` - system matrix +* `x::vecT` - array with the same type and size as the solution +* (`reg=nothing`) - regularization object +* (`regName=["L1"]`) - name of the Regularization to use (if reg==nothing) +* (`AᴴA=A'*A`) - specialized normal operator, default is `A'*A` +* (`λ=0`) - regularization parameter +* (`ρ=0.95`) - step size for gradient step +* (`normalize_ρ=false`) - normalize step size by the maximum eigenvalue of `AᴴA` +* (`θ=1.0`) - parameter for predictor-corrector step +* (`relTol::Float64=1.e-5`) - tolerance for stopping criterion +* (`iterations::Int64=50`) - maximum number of iterations +""" +function OptISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=nothing, regName=["L1"] + , AᴴA=A'*A + , λ=0 + , ρ=0.95 + , normalize_ρ=true + , θ=1 + , relTol=eps(real(T)) + , iterations=50 + , normalizeReg=false + , verbose = false + , kargs...) where {T} + + rT = real(T) + if reg == nothing + reg = Regularization(regName, λ, kargs...) + end + + x₀ = similar(x) + y = similar(x) + z = similar(x) + zᵒˡᵈ = similar(x) + res = similar(x) + res[1] = Inf # avoid spurious convergence in first iterations + + if normalize_ρ + ρ /= abs(power_iterations(AᴴA)) + end + θn = 1 + for i = 1:(iterations-1) + θn = (1 + sqrt(1 + 4 * θn^2)) / 2 + end + θn = (1 + sqrt(1 + 8 * θn^2)) / 2 + + return OptISTA(A, AᴴA, vec(reg)[1], x, x₀, y, z, zᵒˡᵈ, res, rT(ρ),rT(θ),rT(θ),rT(θn),rT(0),rT(1),rT(1), + iterations,rT(relTol),normalizeReg,one(rT),one(rT),rT(Inf),verbose) +end + +""" + init!(it::OptISTA, b::vecT + ; A=solver.A + , x::vecT=similar(b,0) + , t::Float64=1.0) where T + +(re-) initializes the OptISTA iterator +""" +function init!(solver::OptISTA{rT,vecT,matA,matAHA}, b::vecT + ; x::vecT=similar(b,0) + , θ=1 + ) where {rT,vecT,matA,matAHA} + + solver.x₀ .= adjoint(solver.A) * b + solver.norm_x₀ = norm(solver.x₀) + + if isempty(x) + solver.x .= 0 + else + solver.x .= x + end + solver.y .= 0 + solver.z .= 0 + solver.zᵒˡᵈ .= 0 + + solver.θ = θ + solver.θᵒˡᵈ = θ + solver.θn = θ + for i = 1:(solver.iterations-1) + solver.θn = (1 + sqrt(1 + 4 * solver.θn^2)) / 2 + end + solver.θn = (1 + sqrt(1 + 8 * solver.θn^2)) / 2 + + # normalization of regularization parameters + if solver.normalizeReg + solver.regFac = norm(solver.x₀,1)/length(solver.x₀) + else + solver.regFac = 1 + end +end + +""" + solve(solver::OptISTA, b::Vector) + +solves an inverse problem using OptISTA. + +# Arguments +* `solver::OptISTA` - the solver containing both system matrix and regularizer +* `b::vecT` - data vector +* `A=solver.A` - operator for the data-term of the problem +* (`startVector::vecT=similar(b,0)`) - initial guess for the solution +* (`solverInfo=nothing`) - solverInfo object + +when a `SolverInfo` objects is passed, the residuals are stored in `solverInfo.convMeas`. +""" +function solve(solver::OptISTA, b; A=solver.A, startVector=similar(b,0), solverInfo=nothing, kargs...) 
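+  # Illustrative call that records per-iteration residuals (names are placeholders):
+  #   si = SolverInfo(ComplexF64)
+  #   S  = OptISTA(A; regName = ["L1"], λ = 1e-3)
+  #   x  = solve(S, b; solverInfo = si)  # residual norms are collected in si.convMeas
+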
+ # initialize solver parameters + init!(solver, b; x=startVector) + + # log solver information + solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + + # perform OptISTA iterations + for (iteration, item) = enumerate(solver) + solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + end + + return solver.x +end + +""" + iterate(it::OptISTA, iteration::Int=0) + +performs one OptISTA iteration. +""" +function iterate(solver::OptISTA, iteration::Int=0) + if done(solver, iteration) return nothing end + + # inertial parameters + solver.θᵒˡᵈ = solver.θ + if iteration == solver.iterations - 1 #the convergence rate depends on choice of # iterations! + solver.θ = (1 + sqrt(1 + 8 * solver.θᵒˡᵈ^2)) / 2 + else + solver.θ = (1 + sqrt(1 + 4 * solver.θᵒˡᵈ^2)) / 2 + end + solver.α = (solver.θᵒˡᵈ - 1) / solver.θ + solver.β = solver.θᵒˡᵈ / solver.θ + solver.γ = 2solver.θ / solver.θn^2 * (solver.θn^2 - 2solver.θ^2 + solver.θ) + + # calculate residuum and do gradient step + # solver.y .-= solver.ρ * solver.γ .* (solver.AᴴA * solver.x .- solver.x₀) + solver.zᵒˡᵈ .= solver.z #store this for inertia step + solver.z .= solver.y #save yᵒˡᵈ in the variable z + mul!(solver.res, solver.AᴴA, solver.x) + solver.res .-= solver.x₀ + solver.y .-= solver.ρ * solver.γ .* solver.res + + solver.rel_res_norm = norm(solver.res) / solver.norm_x₀ + solver.verbose && println("Iteration $iteration; rel. residual = $(solver.rel_res_norm)") + + # proximal map + solver.reg.prox!(solver.y, solver.regFac*solver.reg.λ*solver.ρ*solver.γ; solver.reg.params...) + + # inertia steps + # z = x + (y - yᵒˡᵈ) / γ + # x = z + α * (z - zᵒˡᵈ) + β * (z - x) + solver.z ./= -solver.γ #yᵒˡᵈ is already stored in z + solver.z .+= solver.x .+ solver.y ./ solver.γ + solver.x .*= -solver.β + solver.x .+= (1 + solver.α + solver.β) .* solver.z + solver.x .-= solver.α .* solver.zᵒˡᵈ + + # return the residual-norm as item and iteration number as state + return solver, iteration+1 +end + +@inline converged(solver::OptISTA) = (solver.rel_res_norm < solver.relTol) + +@inline done(solver::OptISTA,iteration) = converged(solver) || iteration>=solver.iterations diff --git a/src/RegularizedLeastSquares.jl b/src/RegularizedLeastSquares.jl index 80e3b06a..c9b3c162 100644 --- a/src/RegularizedLeastSquares.jl +++ b/src/RegularizedLeastSquares.jl @@ -49,6 +49,7 @@ include("CGNR.jl") include("Direct.jl") include("FusedLasso.jl") include("FISTA.jl") +include("OptISTA.jl") include("POGM.jl") include("ADMM.jl") include("SplitBregman.jl") @@ -112,6 +113,8 @@ function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A, return FusedLasso(A; kargs...) elseif solver == "fista" return FISTA(A, x; kargs...) + elseif solver == "optfista" + return OptFISTA(A, x; kargs...) elseif solver == "pogm" return POGM(A, x; kargs...) elseif solver == "admm" From 73e740b519ba26fa2101460264e7e40ede95113c Mon Sep 17 00:00:00 2001 From: Andrew Mao Date: Fri, 2 Jun 2023 11:40:58 -0400 Subject: [PATCH 5/5] pass tests & cleanup --- src/FISTA.jl | 6 +++--- src/OptISTA.jl | 22 +++++++++++----------- src/POGM.jl | 8 +++----- src/RegularizedLeastSquares.jl | 5 +++-- test/testSolvers.jl | 2 +- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/FISTA.jl b/src/FISTA.jl index 87638a9c..19db113e 100644 --- a/src/FISTA.jl +++ b/src/FISTA.jl @@ -55,7 +55,7 @@ function FISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=n , kargs...) 
where {T} rT = real(T) - if reg == nothing + if reg === nothing reg = Regularization(regName, λ, kargs...) end @@ -123,11 +123,11 @@ function solve(solver::FISTA, b; A=solver.A, startVector=similar(b,0), solverInf init!(solver, b; x=startVector) # log solver information - solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + solverInfo !== nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) # perform FISTA iterations for (iteration, item) = enumerate(solver) - solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + solverInfo !== nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) end return solver.x diff --git a/src/OptISTA.jl b/src/OptISTA.jl index 0784af7e..c81e3bf9 100644 --- a/src/OptISTA.jl +++ b/src/OptISTA.jl @@ -1,4 +1,4 @@ -export OptISTA +export optista, OptISTA mutable struct OptISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA} <: AbstractLinearSolver A::matA @@ -32,7 +32,7 @@ end creates a `OptISTA` object for the system matrix `A`. OptISTA has a 2x better worst-case bound than FISTA, but actual performance varies by application. -It stores 3 extra intermediate variables the size of the image compared to FISTA +It stores 2 extra intermediate variables the size of the image compared to FISTA Reference: - Uijeong Jang, Shuvomoy Das Gupta, Ernest K. Ryu, @@ -66,7 +66,7 @@ function OptISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg , kargs...) where {T} rT = real(T) - if reg == nothing + if reg === nothing reg = Regularization(regName, λ, kargs...) end @@ -81,7 +81,7 @@ function OptISTA(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg ρ /= abs(power_iterations(AᴴA)) end θn = 1 - for i = 1:(iterations-1) + for _ = 1:(iterations-1) θn = (1 + sqrt(1 + 4 * θn^2)) / 2 end θn = (1 + sqrt(1 + 8 * θn^2)) / 2 @@ -111,14 +111,14 @@ function init!(solver::OptISTA{rT,vecT,matA,matAHA}, b::vecT else solver.x .= x end - solver.y .= 0 - solver.z .= 0 - solver.zᵒˡᵈ .= 0 + solver.y .= solver.x + solver.z .= solver.x + solver.zᵒˡᵈ .= solver.x solver.θ = θ solver.θᵒˡᵈ = θ solver.θn = θ - for i = 1:(solver.iterations-1) + for _ = 1:(solver.iterations-1) solver.θn = (1 + sqrt(1 + 4 * solver.θn^2)) / 2 end solver.θn = (1 + sqrt(1 + 8 * solver.θn^2)) / 2 @@ -150,11 +150,11 @@ function solve(solver::OptISTA, b; A=solver.A, startVector=similar(b,0), solverI init!(solver, b; x=startVector) # log solver information - solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + solverInfo !== nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) # perform OptISTA iterations for (iteration, item) = enumerate(solver) - solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + solverInfo !== nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) end return solver.x @@ -169,6 +169,7 @@ function iterate(solver::OptISTA, iteration::Int=0) if done(solver, iteration) return nothing end # inertial parameters + solver.γ = 2solver.θ / solver.θn^2 * (solver.θn^2 - 2solver.θ^2 + solver.θ) solver.θᵒˡᵈ = solver.θ if iteration == solver.iterations - 1 #the convergence rate depends on choice of # iterations! 
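    # NB: the terminal rule θ = (1 + sqrt(1 + 8θᵒˡᵈ^2))/2 below is applied only on
    # the last iteration; all earlier iterations take the standard Nesterov update
    # θ = (1 + sqrt(1 + 4θᵒˡᵈ^2))/2, consistent with θn precomputed in init!.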
solver.θ = (1 + sqrt(1 + 8 * solver.θᵒˡᵈ^2)) / 2 @@ -177,7 +178,6 @@ function iterate(solver::OptISTA, iteration::Int=0) end solver.α = (solver.θᵒˡᵈ - 1) / solver.θ solver.β = solver.θᵒˡᵈ / solver.θ - solver.γ = 2solver.θ / solver.θn^2 * (solver.θn^2 - 2solver.θ^2 + solver.θ) # calculate residuum and do gradient step # solver.y .-= solver.ρ * solver.γ .* (solver.AᴴA * solver.x .- solver.x₀) diff --git a/src/POGM.jl b/src/POGM.jl index 401238c9..ff3aa837 100644 --- a/src/POGM.jl +++ b/src/POGM.jl @@ -80,7 +80,7 @@ function POGM(A, x::AbstractVector{T}=Vector{eltype(A)}(undef,size(A,2)); reg=no , kargs...) where {T} rT = real(T) - if reg == nothing + if reg === nothing reg = Regularization(regName, λ, kargs...) end @@ -130,8 +130,6 @@ function init!(solver::POGM{rT,vecT,matA,matAHA}, b::vecT solver.t = t solver.tᵒˡᵈ = t - solver.γ = 1 #doesn't matter - solver.γᵒˡᵈ = 1 #doesn't matter solver.σ = 1 # normalization of regularization parameters if solver.normalizeReg @@ -160,11 +158,11 @@ function solve(solver::POGM, b; A=solver.A, startVector=similar(b,0), solverInfo init!(solver, b; x=startVector) # log solver information - solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + solverInfo !== nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) # perform POGM iterations for (iteration, item) = enumerate(solver) - solverInfo != nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) + solverInfo !== nothing && storeInfo(solverInfo,solver.x,norm(solver.res)) end return solver.x diff --git a/src/RegularizedLeastSquares.jl b/src/RegularizedLeastSquares.jl index c9b3c162..9556e8d8 100644 --- a/src/RegularizedLeastSquares.jl +++ b/src/RegularizedLeastSquares.jl @@ -85,6 +85,7 @@ Function returns choosen solver. * `"pseudoinverse"` - approximates a solution using the More-Penrose pseudo inverse * `"fusedlasso"` - solver for the Fused-Lasso problem * `"fista"` - Fast Iterative Shrinkage Thresholding Algorithm +* `"optista"` - "Optimal" ISTA * `"pogm"` - Proximal Optimal Gradient Method * `"admm"` - Alternating Direcion of Multipliers Method * `"splitBregman"` - Split Bregman method for constrained & regularized inverse problems @@ -113,8 +114,8 @@ function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A, return FusedLasso(A; kargs...) elseif solver == "fista" return FISTA(A, x; kargs...) - elseif solver == "optfista" - return OptFISTA(A, x; kargs...) + elseif solver == "optista" + return OptISTA(A, x; kargs...) elseif solver == "pogm" return POGM(A, x; kargs...) elseif solver == "admm" diff --git a/test/testSolvers.jl b/test/testSolvers.jl index abdaa686..3f6ec85e 100644 --- a/test/testSolvers.jl +++ b/test/testSolvers.jl @@ -54,7 +54,7 @@ end b = b[idx] F = F[idx, :] - for solver in ["pogm", "fista", "admm"] + for solver in ["pogm", "optista", "fista", "admm"] reg = Regularization("L1", 1e-3) solverInfo = SolverInfo(ComplexF64) S = createLinearSolver(
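
A minimal end-to-end sketch of the solvers added in this series (the toy problem,
L1 weight, and iteration count are illustrative and patterned on test/testSolvers.jl,
not part of the patches):

    using RegularizedLeastSquares
    using LinearAlgebra: norm

    A = randn(ComplexF64, 64, 32)   # toy system matrix
    x = randn(ComplexF64, 32)       # ground-truth coefficients
    b = A * x                       # simulated data

    reg = Regularization("L1", 1e-3)
    for solver in ["fista", "optista", "pogm"]
        S = createLinearSolver(solver, A; reg = reg, iterations = 200)
        x_approx = solve(S, b)
        @info "$solver: relative error = $(norm(x - x_approx) / norm(x))"
    end

    # FISTA and POGM (but not OptISTA) additionally accept the adaptive restart
    # introduced in this series:
    S = createLinearSolver("pogm", A; reg = reg, iterations = 200, restart = :gradient)
    x_pogm = solve(S, b)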