Skip to content

Commit

Permalink
Merge pull request #68 from JuliaImageRecon/FixSplitBregman
Browse files Browse the repository at this point in the history
FixSplitBregman
  • Loading branch information
tknopp authored Jan 2, 2024
2 parents 483de3e + 2c6c906 commit 4692422
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 109 deletions.
3 changes: 3 additions & 0 deletions docs/src/API/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ RegularizedLeastSquares.SplitBregman

## Miscellaneous Functions
```@docs
RegularizedLeastSquares.StoreSolutionCallback
RegularizedLeastSquares.StoreConvergenceCallback
RegularizedLeastSquares.CompareSolutionCallback
RegularizedLeastSquares.linearSolverList
RegularizedLeastSquares.createLinearSolver
RegularizedLeastSquares.applicableSolverList
Expand Down
85 changes: 39 additions & 46 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export ADMM

mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
mutable struct ADMM{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
# operators and regularization
A::matT
reg::Vector{R}
Expand All @@ -19,7 +19,7 @@ mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDu
uᵒˡᵈ::Vector{vecT}
# other parameters
precon::preconT
ρ::rvecT # TODO: Switch all these vectors to Tuple
ρ::rvecT
iterations::Int64
iterationsCG::Int64
# state variables for CG
Expand All @@ -40,10 +40,10 @@ mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDu
end

"""
ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 50, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 50, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
Creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
# Required Arguments
* `A` - forward operator
Expand All @@ -58,10 +58,10 @@ creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
* `rho::Real` - penalty of the augmented Lagrangian
* `vary_rho::Symbol` - vary rho to balance primal and dual feasibility; options `:none`, `:balance`, `:PnP`
* `iterations::Int` - maximum number of (outer) ADMM iterations
* `iterationsCG::Int` - max number of (inner) CG iterations
* `absTol::Real` - abs tolerance for stopping criterion
* `relTol::Real` - tolerance for stopping criterion
* `tolInner::Real` - rel tolerance for CG stopping criterion
* `iterationsCG::Int` - maximum number of (inner) CG iterations
* `absTol::Real` - absolute tolerance for stopping criterion
* `relTol::Real` - relative tolerance for stopping criterion
* `tolInner::Real` - relative tolerance for CG stopping criterion
* `verbose::Bool` - print residual in each iteration
See also [`createLinearSolver`](@ref), [`solve!`](@ref).
Expand All @@ -75,19 +75,18 @@ function ADMM(A
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, rho = 1e-1
, vary_rho::Symbol = :none
, iterations::Int = 50
, iterations::Int = 10
, iterationsCG::Int = 10
, absTol::Real = eps(real(eltype(AHA)))
, relTol::Real = eps(real(eltype(AHA)))
, tolInner::Real = 1e-5
, verbose = false
)
# TODO: The constructor is not type stable

T = eltype(AHA)
rT = real(T)

reg = vec(reg) # using a custom method of vec(.)
reg = vec(reg)

regTrafo = []
indices = findsinks(AbstractProjectionRegularization, reg)
Expand All @@ -98,9 +97,9 @@ function ADMM(A
for r in reg
trafoReg = findfirst(ConstraintTransformedRegularization, r)
if isnothing(trafoReg)
push!(regTrafo, opEye(eltype(AHA),size(AHA,2)))
push!(regTrafo, opEye(T,size(AHA,2)))
else
push!(regTrafo, trafoReg)
push!(regTrafo, trafoReg.trafo)
end
end
regTrafo = identity.(regTrafo)
Expand All @@ -111,17 +110,16 @@ function ADMM(A
rho = rT.(rho)
end

x = Vector{T}(undef,size(AHA,2))
x = Vector{T}(undef, size(AHA,2))
xᵒˡᵈ = similar(x)
β = similar(x)
β_y = similar(x)

# fields for primal & dual variables
z = [similar(x, size(AHA,2)) for i=1:length(reg)]
zᵒˡᵈ = [similar(z[i]) for i=1:length(reg)]
u = [similar(z[i]) for i=1:length(reg)]
uᵒˡᵈ = [similar(u[i]) for i=1:length(reg)]

z = [similar(x, size(regTrafo[i],1)) for i in eachindex(vec(reg))]
zᵒˡᵈ = [similar(z[i]) for i in eachindex(vec(reg))]
u = [similar(z[i]) for i in eachindex(vec(reg))]
uᵒˡᵈ = [similar(u[i]) for i in eachindex(vec(reg))]

# statevariables for CG
# we store them here to prevent CG from allocating new fields at each call
Expand All @@ -138,16 +136,15 @@ function ADMM(A
reg = normalize(ADMM, normalizeReg, reg, A, nothing)

return ADMM(A,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations
,iterationsCG,cgStateVars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(rT),Δ,rT(absTol),rT(relTol),rT(tolInner)
,normalizeReg, vary_rho, verbose)
,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),Δ,rT(absTol),rT(relTol),rT(tolInner),normalizeReg,vary_rho,verbose)
end

"""
init!(solver::ADMM, b; x=similar(b,0))
init!(solver::ADMM, b; x0 = 0)
(re-) initializes the ADMM iterator
"""
function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T
function init!(solver::ADMM, b; x0 = 0)
solver.x .= x0

# right hand side for the x-update
Expand All @@ -158,7 +155,7 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T
end

# primal and dual variables
for i=1:length(solver.reg)
for i in eachindex(solver.reg)
solver.z[i] .= solver.regTrafo[i]*solver.x
solver.u[i] .= 0
end
Expand All @@ -168,22 +165,22 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T
solver.sᵏ .= Inf
solver.ɛᵖʳⁱ .= 0
solver.ɛᵈᵘᵃ .= 0
solver.σᵃᵇˢ = sqrt(length(b))*solver.absTol
solver.σᵃᵇˢ = sqrt(length(b)) * solver.absTol
solver.Δ .= Inf

# normalization of regularization parameters
solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b)
end

solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => norm(solver.sᵏ))
solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => solver.sᵏ)


"""
iterate(solver::ADMM, iteration=1)
Performs one ADMM iteration.
"""
function iterate(solver::ADMM, iteration=0)
function iterate(solver::ADMM, iteration=1)
if done(solver, iteration) return nothing end
solver.verbose && println("Outer ADMM Iteration #$iteration")

Expand All @@ -194,18 +191,19 @@ function iterate(solver::ADMM, iteration=0)
for i in eachindex(solver.reg)
mul!(solver.β, adjoint(solver.regTrafo[i]), solver.z[i], solver.ρ[i], 1)
mul!(solver.β, adjoint(solver.regTrafo[i]), solver.u[i], -solver.ρ[i], 1)
AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
end
solver.verbose && println("conjugated gradients: ")
solver.xᵒˡᵈ .= solver.x
cg!(solver.x, AHA, solver.β, Pl=solver.precon, maxiter=solver.iterationsCG, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose)
cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose)

for proj in solver.proj
prox!(proj, solver.x)
end

# proximal map for regularization terms
for i in eachindex(solver.reg)
# swap v and vᵒˡᵈ w/o copying data
# swap z and zᵒˡᵈ w/o copying data
tmp = solver.zᵒˡᵈ[i]
solver.zᵒˡᵈ[i] = solver.z[i]
solver.z[i] = tmp
Expand All @@ -214,19 +212,19 @@ function iterate(solver::ADMM, iteration=0)
mul!(solver.z[i], solver.regTrafo[i], solver.x)
solver.z[i] .+= solver.u[i]
if solver.ρ[i] != 0
prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/solver.ρ[i])
prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms
end

# 3. update u
solver.uᵒˡᵈ[i] .= solver.u[i]
mul!(solver.u[i], solver.regTrafo[i], solver.x, 1, 1)
solver.u[i] .-= solver.z[i]

# update convergence measures (one for each constraint)
solver.rᵏ[i] = norm(solver.regTrafo[i]*solver.x-solver.z[i]) # primal residual (x-z)
# update convergence criteria (one for each constraint)
solver.rᵏ[i] = norm(solver.regTrafo[i] * solver.x - solver.z[i]) # primal residual (x-z)
solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x))

solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i]))
solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * solver.x), norm(solver.z[i]))
solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i])

Δᵒˡᵈ = solver.Δ[i]
Expand All @@ -244,28 +242,23 @@ function iterate(solver::ADMM, iteration=0)
end

if solver.verbose
println("rᵏ[$i] = $(solver.rᵏ[i])")
println("sᵏ[$i] = $(solver.sᵏ[i])")
println("ɛᵖʳⁱ[$i] = $(solver.ɛᵖʳⁱ[i])")
println("ɛᵈᵘᵃ[$i] = $(solver.ɛᵈᵘᵃ[i])")
println("Δᵒˡᵈ = $(Δᵒˡᵈ)")
println("Δ[$i] = $(solver.Δ[i])")
println("Δ/Δᵒˡᵈ = $(solver.Δ[i]/Δᵒˡᵈ)")
println("current ρ[$i] = $(solver.ρ[i])")
println("rᵏ[$i]/ɛᵖʳⁱ[$i] = $(solver.rᵏ[i]/solver.ɛᵖʳⁱ[i])")
println("sᵏ[$i]/ɛᵈᵘᵃ[$i] = $(solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i])")
println("Δ[$i]/Δᵒˡᵈ[$i] = $(solver.Δ[i]/Δᵒˡᵈ)")
println("new ρ[$i] = $(solver.ρ[i])")
flush(stdout)
end
end

# return the primal feasibility measure as item and iteration number as state
return solver.rᵏ, iteration+1
end

function converged(solver::ADMM)
for i=1:length(solver.reg)
for i in eachindex(solver.reg)
(solver.rᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵖʳⁱ[i]) && return false
(solver.sᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵈᵘᵃ[i]) && return false
end
return true
end

@inline done(solver::ADMM,iteration::Int) = converged(solver) || iteration>=solver.iterations
@inline done(solver::ADMM,iteration::Int) = converged(solver) || iteration >= solver.iterations
Loading

0 comments on commit 4692422

Please sign in to comment.