diff --git a/docs/src/API/solvers.md b/docs/src/API/solvers.md index fd7b6003..3f0e4fbf 100644 --- a/docs/src/API/solvers.md +++ b/docs/src/API/solvers.md @@ -44,6 +44,9 @@ RegularizedLeastSquares.SplitBregman ## Miscellaneous Functions ```@docs +RegularizedLeastSquares.StoreSolutionCallback +RegularizedLeastSquares.StoreConvergenceCallback +RegularizedLeastSquares.CompareSolutionCallback RegularizedLeastSquares.linearSolverList RegularizedLeastSquares.createLinearSolver RegularizedLeastSquares.applicableSolverList diff --git a/src/ADMM.jl b/src/ADMM.jl index ab57b36d..b185fd6c 100644 --- a/src/ADMM.jl +++ b/src/ADMM.jl @@ -1,6 +1,6 @@ export ADMM -mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}} +mutable struct ADMM{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}} # operators and regularization A::matT reg::Vector{R} @@ -19,7 +19,7 @@ mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDu uᵒˡᵈ::Vector{vecT} # other parameters precon::preconT - ρ::rvecT # TODO: Switch all these vectors to Tuple + ρ::rvecT iterations::Int64 iterationsCG::Int64 # state variables for CG @@ -40,10 +40,10 @@ mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDu end """ - ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 50, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false) - ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 50, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false) + ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false) + ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false) -creates an `ADMM` object for the forward operator `A` or normal operator `AHA`. +Creates an `ADMM` object for the forward operator `A` or normal operator `AHA`. # Required Arguments * `A` - forward operator @@ -58,10 +58,10 @@ creates an `ADMM` object for the forward operator `A` or normal operator `AHA`. * `rho::Real` - penalty of the augmented Lagrangian * `vary_rho::Symbol` - vary rho to balance primal and dual feasibility; options `:none`, `:balance`, `:PnP` * `iterations::Int` - maximum number of (outer) ADMM iterations - * `iterationsCG::Int` - max number of (inner) CG iterations - * `absTol::Real` - abs tolerance for stopping criterion - * `relTol::Real` - tolerance for stopping criterion - * `tolInner::Real` - rel tolerance for CG stopping criterion + * `iterationsCG::Int` - maximum number of (inner) CG iterations + * `absTol::Real` - absolute tolerance for stopping criterion + * `relTol::Real` - relative tolerance for stopping criterion + * `tolInner::Real` - relative tolerance for CG stopping criterion * `verbose::Bool` - print residual in each iteration See also [`createLinearSolver`](@ref), [`solve!`](@ref). @@ -75,19 +75,18 @@ function ADMM(A , normalizeReg::AbstractRegularizationNormalization = NoNormalization() , rho = 1e-1 , vary_rho::Symbol = :none - , iterations::Int = 50 + , iterations::Int = 10 , iterationsCG::Int = 10 , absTol::Real = eps(real(eltype(AHA))) , relTol::Real = eps(real(eltype(AHA))) , tolInner::Real = 1e-5 , verbose = false ) - # TODO: The constructor is not type stable T = eltype(AHA) rT = real(T) - reg = vec(reg) # using a custom method of vec(.) + reg = vec(reg) regTrafo = [] indices = findsinks(AbstractProjectionRegularization, reg) @@ -98,9 +97,9 @@ function ADMM(A for r in reg trafoReg = findfirst(ConstraintTransformedRegularization, r) if isnothing(trafoReg) - push!(regTrafo, opEye(eltype(AHA),size(AHA,2))) + push!(regTrafo, opEye(T,size(AHA,2))) else - push!(regTrafo, trafoReg) + push!(regTrafo, trafoReg.trafo) end end regTrafo = identity.(regTrafo) @@ -111,17 +110,16 @@ function ADMM(A rho = rT.(rho) end - x = Vector{T}(undef,size(AHA,2)) + x = Vector{T}(undef, size(AHA,2)) xᵒˡᵈ = similar(x) β = similar(x) β_y = similar(x) # fields for primal & dual variables - z = [similar(x, size(AHA,2)) for i=1:length(reg)] - zᵒˡᵈ = [similar(z[i]) for i=1:length(reg)] - u = [similar(z[i]) for i=1:length(reg)] - uᵒˡᵈ = [similar(u[i]) for i=1:length(reg)] - + z = [similar(x, size(regTrafo[i],1)) for i ∈ eachindex(vec(reg))] + zᵒˡᵈ = [similar(z[i]) for i ∈ eachindex(vec(reg))] + u = [similar(z[i]) for i ∈ eachindex(vec(reg))] + uᵒˡᵈ = [similar(u[i]) for i ∈ eachindex(vec(reg))] # statevariables for CG # we store them here to prevent CG from allocating new fields at each call @@ -138,16 +136,15 @@ function ADMM(A reg = normalize(ADMM, normalizeReg, reg, A, nothing) return ADMM(A,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations - ,iterationsCG,cgStateVars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(rT),Δ,rT(absTol),rT(relTol),rT(tolInner) - ,normalizeReg, vary_rho, verbose) + ,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),Δ,rT(absTol),rT(relTol),rT(tolInner),normalizeReg,vary_rho,verbose) end """ - init!(solver::ADMM, b; x=similar(b,0)) + init!(solver::ADMM, b; x0 = 0) (re-) initializes the ADMM iterator """ -function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T +function init!(solver::ADMM, b; x0 = 0) solver.x .= x0 # right hand side for the x-update @@ -158,7 +155,7 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T end # primal and dual variables - for i=1:length(solver.reg) + for i ∈ eachindex(solver.reg) solver.z[i] .= solver.regTrafo[i]*solver.x solver.u[i] .= 0 end @@ -168,14 +165,14 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T solver.sᵏ .= Inf solver.ɛᵖʳⁱ .= 0 solver.ɛᵈᵘᵃ .= 0 - solver.σᵃᵇˢ = sqrt(length(b))*solver.absTol + solver.σᵃᵇˢ = sqrt(length(b)) * solver.absTol solver.Δ .= Inf # normalization of regularization parameters solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b) end -solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => norm(solver.sᵏ)) +solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => solver.sᵏ) """ @@ -183,7 +180,7 @@ solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => norm(solve performs one ADMM iteration. """ -function iterate(solver::ADMM, iteration=0) +function iterate(solver::ADMM, iteration=1) if done(solver, iteration) return nothing end solver.verbose && println("Outer ADMM Iteration #$iteration") @@ -194,18 +191,19 @@ function iterate(solver::ADMM, iteration=0) for i ∈ eachindex(solver.reg) mul!(solver.β, adjoint(solver.regTrafo[i]), solver.z[i], solver.ρ[i], 1) mul!(solver.β, adjoint(solver.regTrafo[i]), solver.u[i], -solver.ρ[i], 1) - AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i] + AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i] end solver.verbose && println("conjugated gradients: ") solver.xᵒˡᵈ .= solver.x - cg!(solver.x, AHA, solver.β, Pl=solver.precon, maxiter=solver.iterationsCG, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose) + cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose) for proj in solver.proj prox!(proj, solver.x) end + # proximal map for regularization terms for i ∈ eachindex(solver.reg) - # swap v and vᵒˡᵈ w/o copying data + # swap z and zᵒˡᵈ w/o copying data tmp = solver.zᵒˡᵈ[i] solver.zᵒˡᵈ[i] = solver.z[i] solver.z[i] = tmp @@ -214,7 +212,7 @@ function iterate(solver::ADMM, iteration=0) mul!(solver.z[i], solver.regTrafo[i], solver.x) solver.z[i] .+= solver.u[i] if solver.ρ[i] != 0 - prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/solver.ρ[i]) + prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms end # 3. update u @@ -222,11 +220,11 @@ function iterate(solver::ADMM, iteration=0) mul!(solver.u[i], solver.regTrafo[i], solver.x, 1, 1) solver.u[i] .-= solver.z[i] - # update convergence measures (one for each constraint) - solver.rᵏ[i] = norm(solver.regTrafo[i]*solver.x-solver.z[i]) # primal residual (x-z) + # update convergence criteria (one for each constraint) + solver.rᵏ[i] = norm(solver.regTrafo[i] * solver.x - solver.z[i]) # primal residual (x-z) solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x)) - solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i])) + solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * solver.x), norm(solver.z[i])) solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i]) Δᵒˡᵈ = solver.Δ[i] @@ -244,28 +242,23 @@ function iterate(solver::ADMM, iteration=0) end if solver.verbose - println("rᵏ[$i] = $(solver.rᵏ[i])") - println("sᵏ[$i] = $(solver.sᵏ[i])") - println("ɛᵖʳⁱ[$i] = $(solver.ɛᵖʳⁱ[i])") - println("ɛᵈᵘᵃ[$i] = $(solver.ɛᵈᵘᵃ[i])") - println("Δᵒˡᵈ = $(Δᵒˡᵈ)") - println("Δ[$i] = $(solver.Δ[i])") - println("Δ/Δᵒˡᵈ = $(solver.Δ[i]/Δᵒˡᵈ)") - println("current ρ[$i] = $(solver.ρ[i])") + println("rᵏ[$i]/ɛᵖʳⁱ[$i] = $(solver.rᵏ[i]/solver.ɛᵖʳⁱ[i])") + println("sᵏ[$i]/ɛᵈᵘᵃ[$i] = $(solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i])") + println("Δ[$i]/Δᵒˡᵈ[$i] = $(solver.Δ[i]/Δᵒˡᵈ)") + println("new ρ[$i] = $(solver.ρ[i])") flush(stdout) end end - # return the primal feasibility measure as item and iteration number as state return solver.rᵏ, iteration+1 end function converged(solver::ADMM) - for i=1:length(solver.reg) + for i ∈ eachindex(solver.reg) (solver.rᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵖʳⁱ[i]) && return false (solver.sᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵈᵘᵃ[i]) && return false end return true end -@inline done(solver::ADMM,iteration::Int) = converged(solver) || iteration>=solver.iterations +@inline done(solver::ADMM,iteration::Int) = converged(solver) || iteration >= solver.iterations \ No newline at end of file diff --git a/src/SplitBregman.jl b/src/SplitBregman.jl index f86d8dfc..f9cb853a 100644 --- a/src/SplitBregman.jl +++ b/src/SplitBregman.jl @@ -1,7 +1,7 @@ export SplitBregman -mutable struct SplitBregman{matT,vecT,opT,R,ropT,P,rvecT,preconT,rT} <: AbstractPrimalDualSolver - # oerators and regularization +mutable struct SplitBregman{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver + # operators and regularization A::matT reg::Vector{R} regTrafo::Vector{ropT} @@ -25,11 +25,10 @@ mutable struct SplitBregman{matT,vecT,opT,R,ropT,P,rvecT,preconT,rT} <: Abstract # state variables for CG cgStateVars::CGStateVariables # convergence parameters - rk::rvecT - sk::vecT - eps_pri::rvecT - eps_dt::vecT - # eps_dual::Float64 + rᵏ::rvecT + sᵏ::rvecT + ɛᵖʳⁱ::rvecT + ɛᵈᵘᵃ::rvecT σᵃᵇˢ::rT absTol::rT relTol::rT @@ -41,26 +40,28 @@ mutable struct SplitBregman{matT,vecT,opT,R,ropT,P,rvecT,preconT,rT} <: Abstract end """ - SplitBregman(A; AHA = A'*A, reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), precon = Identity(), rho = 1.e2absTol = eps(), relTol = eps(), tolInner = 1.e-6, iterations::Int = 10, iterationsInner::Int = 50, iterationsCG::Int = 10, verbose = false) - SplitBregman( ; AHA = A'*A, reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), precon = Identity(), rho = 1.e2absTol = eps(), relTol = eps(), tolInner = 1.e-6, iterations::Int = 10, iterationsInner::Int = 50, iterationsCG::Int = 10, verbose = false) + SplitBregman(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, iterations = 1, iterationsInner = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false) + SplitBregman( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, iterations = 1, iterationsInner = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false) -Creates a `SplitBregman` object for the forward operator `A`. +Creates a `SplitBregman` object for the forward operator `A` or normal operator `AHA`. # Required Arguments * `A` - forward operator + OR + * `AHA` - normal operator (as a keyword argument) # Optional Keyword Arguments * `AHA` - normal operator is optional if `A` is supplied - * `reg::AbstractParameterizedRegularization` - regularization term - * `normalizeReg::AbstractRegularizationNormalization` - regularization normalization scheme; options are `NoNormalization()`, `MeasurementBasedNormalization()`, `SystemMatrixBasedNormalization()` * `precon` - preconditionner for the internal CG algorithm + * `reg::AbstractParameterizedRegularization` - regularization term; can also be a vector of regularization terms + * `normalizeReg::AbstractRegularizationNormalization` - regularization normalization scheme; options are `NoNormalization()`, `MeasurementBasedNormalization()`, `SystemMatrixBasedNormalization()` * `rho::Real` - weights for condition on regularized variables; can also be a vector for multiple regularization terms - * `absTol::Float64` - absolute tolerance for stopping criterion - * `relTol::Float64` - relative tolerance for stopping criterion - * `tolInner::Float64` - tolerance for CG stopping criterion - * `iterations::Int` - maximum number of iterations + * `iterations::Int` - maximum number of outer iterations. Set to 1 for unconstraint split Bregman * `iterationsInner::Int` - maximum number of inner iterations - * `iterationsCG::Int` - maximum number of CG iterations + * `iterationsCG::Int` - maximum number of (inner) CG iterations + * `absTol::Real` - absolute tolerance for stopping criterion + * `relTol::Real` - relative tolerance for stopping criterion + * `tolInner::Real` - relative tolerance for CG stopping criterion * `verbose::Bool` - print residual in each iteration See also [`createLinearSolver`](@ref), [`solve!`](@ref). @@ -69,16 +70,16 @@ SplitBregman(; AHA, kwargs...) = SplitBregman(nothing; kwargs..., AHA = AHA) function SplitBregman(A ; AHA = A'*A + , precon = Identity() , reg = L1Regularization(zero(eltype(AHA))) , normalizeReg::AbstractRegularizationNormalization = NoNormalization() - , precon = Identity() - , rho = 1.e2 - , absTol = eps() - , relTol = eps() - , tolInner = 1.e-6 - , iterations::Int = 10 - , iterationsInner::Int = 50 + , rho = 1e-1 + , iterations::Int = 1 + , iterationsInner::Int = 10 , iterationsCG::Int = 10 + , absTol::Real = eps(real(eltype(AHA))) + , relTol::Real = eps(real(eltype(AHA))) + , tolInner::Real = 1e-5 , verbose = false ) @@ -98,7 +99,7 @@ function SplitBregman(A if isnothing(trafoReg) push!(regTrafo, opEye(T,size(AHA,2))) else - push!(regTrafo, trafoReg) + push!(regTrafo, trafoReg.trafo) end end regTrafo = identity.(regTrafo) @@ -109,33 +110,33 @@ function SplitBregman(A rho = rT.(rho) end - x = Vector{T}(undef,size(AHA,2)) + x = Vector{T}(undef, size(AHA,2)) y = similar(x) β = similar(x) β_y = similar(x) # fields for primal & dual variables - z = [similar(x, size(AHA,2)) for i ∈ eachindex(vec(reg))] - zᵒˡᵈ = [similar(z[i]) for i ∈ eachindex(vec(reg))] - u = [similar(z[i]) for i ∈ eachindex(vec(reg))] + z = [similar(x, size(regTrafo[i],1)) for i ∈ eachindex(vec(reg))] + zᵒˡᵈ = [similar(z[i]) for i ∈ eachindex(vec(reg))] + u = [similar(z[i]) for i ∈ eachindex(vec(reg))] # statevariables for CG # we store them here to prevent CG from allocating new fields at each call cgStateVars = CGStateVariables(zero(x),similar(x),similar(x)) # convergence parameters - rk = similar(x, rT, length(reg)) - sk = similar(x) - eps_pri = similar(x, rT, length(reg)) - eps_dt = similar(x) + rᵏ = Array{rT}(undef, length(reg)) + sᵏ = similar(rᵏ) + ɛᵖʳⁱ = similar(rᵏ) + ɛᵈᵘᵃ = similar(rᵏ) iter_cnt = 1 # normalization parameters - reg = normalize(SplitBregman, normalizeReg, vec(reg), A, nothing) + reg = normalize(SplitBregman, normalizeReg, reg, A, nothing) - return SplitBregman(A,reg,regTrafo,proj,y,AHA,β,β_y,x,z,zᵒˡᵈ,u,precon,rho,iterations,iterationsInner,iterationsCG,cgStateVars,rk,sk,eps_pri,eps_dt,rT(0),rT(absTol),rT(relTol),rT(tolInner),iter_cnt,normalizeReg,verbose) + return SplitBregman(A,reg,regTrafo,proj,y,AHA,β,β_y,x,z,zᵒˡᵈ,u,precon,rho,iterations,iterationsInner,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),rT(absTol),rT(relTol),rT(tolInner),iter_cnt,normalizeReg,verbose) end """ @@ -161,7 +162,11 @@ function init!(solver::SplitBregman, b; x0 = 0) end # convergence parameter - solver.σᵃᵇˢ = sqrt(length(b))*solver.absTol + solver.rᵏ .= Inf + solver.sᵏ .= Inf + solver.ɛᵖʳⁱ .= 0 + solver.ɛᵈᵘᵃ .= 0 + solver.σᵃᵇˢ = sqrt(length(b)) * solver.absTol # normalization of regularization parameters solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b) @@ -170,10 +175,11 @@ function init!(solver::SplitBregman, b; x0 = 0) solver.iter_cnt = 1 end -solverconvergence(solver::SplitBregman) = (; :primal => solver.rk, :dual => norm(solver.sk)) +solverconvergence(solver::SplitBregman) = (; :primal => solver.rᵏ, :dual => solver.sᵏ) function iterate(solver::SplitBregman, iteration=1) if done(solver, iteration) return nothing end + solver.verbose && println("SplitBregman Iteration #$iteration – Outer iteration $(solver.iter_cnt)") # update x solver.β .= solver.β_y @@ -184,7 +190,7 @@ function iterate(solver::SplitBregman, iteration=1) AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i] end solver.verbose && println("conjugated gradients: ") - cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsInner, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose) + cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose) for proj in solver.proj prox!(proj, solver.x) @@ -201,27 +207,25 @@ function iterate(solver::SplitBregman, iteration=1) mul!(solver.z[i], solver.regTrafo[i], solver.x) solver.z[i] .+= solver.u[i] if solver.ρ[i] != 0 - prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/solver.ρ[i]) + prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms end # 3. update u mul!(solver.u[i], solver.regTrafo[i], solver.x, 1, 1) solver.u[i] .-= solver.z[i] - # update convergence criteria - # primal residuals norms (one for each constraint) - solver.rk[i] = norm(solver.regTrafo[i] * solver.x - solver.z[i]) - solver.eps_pri[i] = solver.σᵃᵇˢ + solver.relTol * max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i])) - end + # update convergence criteria (one for each constraint) + solver.rᵏ[i] = norm(solver.regTrafo[i] * solver.x - solver.z[i]) # primal residual (x-z) + solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x)) - # accumulated dual residual - # effectively this corresponds to combining all constraints into one larger constraint. - solver.sk .= 0 - solver.eps_dt .= 0 - for i ∈ eachindex(solver.reg) - mul!(solver.sk, adjoint(solver.regTrafo[i]), solver.z[i], solver.ρ[i], 1) - mul!(solver.sk, adjoint(solver.regTrafo[i]), solver.zᵒˡᵈ[i], -solver.ρ[i], 1) - mul!(solver.eps_dt, adjoint(solver.regTrafo[i]), solver.u[i], solver.ρ[i], 1) + solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * solver.x), norm(solver.z[i])) + solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i]) + + if solver.verbose + println("rᵏ[$i]/ɛᵖʳⁱ[$i] = $(solver.rᵏ[i]/solver.ɛᵖʳⁱ[i])") + println("sᵏ[$i]/ɛᵈᵘᵃ[$i] = $(solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i])") + flush(stdout) + end end @@ -237,19 +241,15 @@ function iterate(solver::SplitBregman, iteration=1) iteration = 0 end - return solver.rk[1], iteration+1 + return solver.rᵏ, iteration+1 end function converged(solver::SplitBregman) - if norm(solver.sk) >= solver.σᵃᵇˢ+solver.relTol*norm(solver.eps_dt) - return false - else - for i=1:length(solver.reg) - (solver.rk[i] >= solver.eps_pri[i]) && return false + for i ∈ eachindex(solver.reg) + (solver.rᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵖʳⁱ[i]) && return false + (solver.sᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵈᵘᵃ[i]) && return false end - end - return true end -@inline done(solver::SplitBregman,iteration::Int) = (iteration==1 && solver.iter_cnt>solver.iterations) \ No newline at end of file +@inline done(solver::SplitBregman,iteration::Int) = converged(solver) || (iteration == 1 && solver.iter_cnt > solver.iterations) \ No newline at end of file diff --git a/src/proximalMaps/ProxTV.jl b/src/proximalMaps/ProxTV.jl index fb78c274..48cf39f1 100644 --- a/src/proximalMaps/ProxTV.jl +++ b/src/proximalMaps/ProxTV.jl @@ -3,7 +3,7 @@ export TVRegularization """ TVRegularization -Regularization term implementing the proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise. +Regularization term implementing the proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one real-valued dimension and with the Fast Gradient Projection algorithm otherwise. Reference for the Condat algorithm: https://lcondat.github.io/publis/Condat-fast_TV-SPL-2013.pdf diff --git a/test/testSolvers.jl b/test/testSolvers.jl index 65645ea0..553ec496 100644 --- a/test/testSolvers.jl +++ b/test/testSolvers.jl @@ -154,7 +154,7 @@ end ## solver = SplitBregman - reg = L1Regularization(1.e-3) + reg = L1Regularization(2e-3) S = createLinearSolver( solver, F;