Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FixSplitBregman #68

Merged
merged 5 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/src/API/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ RegularizedLeastSquares.SplitBregman

## Miscellaneous Functions
```@docs
RegularizedLeastSquares.StoreSolutionCallback
RegularizedLeastSquares.StoreConvergenceCallback
RegularizedLeastSquares.CompareSolutionCallback
RegularizedLeastSquares.linearSolverList
RegularizedLeastSquares.createLinearSolver
RegularizedLeastSquares.applicableSolverList
Expand Down
85 changes: 39 additions & 46 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export ADMM

mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
mutable struct ADMM{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
# operators and regularization
A::matT
reg::Vector{R}
Expand All @@ -19,7 +19,7 @@ mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDu
uᵒˡᵈ::Vector{vecT}
# other parameters
precon::preconT
ρ::rvecT # TODO: Switch all these vectors to Tuple
ρ::rvecT
iterations::Int64
iterationsCG::Int64
# state variables for CG
Expand All @@ -40,10 +40,10 @@ mutable struct ADMM{rT,matT,opT,R,ropT,P,vecT,rvecT,preconT} <: AbstractPrimalDu
end

"""
ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 50, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 50, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(eltype(AHA))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)

creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
Creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.

# Required Arguments
* `A` - forward operator
Expand All @@ -58,10 +58,10 @@ creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
* `rho::Real` - penalty of the augmented Lagrangian
* `vary_rho::Symbol` - vary rho to balance primal and dual feasibility; options `:none`, `:balance`, `:PnP`
* `iterations::Int` - maximum number of (outer) ADMM iterations
* `iterationsCG::Int` - max number of (inner) CG iterations
* `absTol::Real` - abs tolerance for stopping criterion
* `relTol::Real` - tolerance for stopping criterion
* `tolInner::Real` - rel tolerance for CG stopping criterion
* `iterationsCG::Int` - maximum number of (inner) CG iterations
* `absTol::Real` - absolute tolerance for stopping criterion
* `relTol::Real` - relative tolerance for stopping criterion
* `tolInner::Real` - relative tolerance for CG stopping criterion
* `verbose::Bool` - print residual in each iteration

See also [`createLinearSolver`](@ref), [`solve!`](@ref).
Expand All @@ -75,19 +75,18 @@ function ADMM(A
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, rho = 1e-1
, vary_rho::Symbol = :none
, iterations::Int = 50
, iterations::Int = 10
, iterationsCG::Int = 10
, absTol::Real = eps(real(eltype(AHA)))
, relTol::Real = eps(real(eltype(AHA)))
, tolInner::Real = 1e-5
, verbose = false
)
# TODO: The constructor is not type stable

T = eltype(AHA)
rT = real(T)

reg = vec(reg) # using a custom method of vec(.)
reg = vec(reg)

regTrafo = []
indices = findsinks(AbstractProjectionRegularization, reg)
Expand All @@ -98,9 +97,9 @@ function ADMM(A
for r in reg
trafoReg = findfirst(ConstraintTransformedRegularization, r)
if isnothing(trafoReg)
push!(regTrafo, opEye(eltype(AHA),size(AHA,2)))
push!(regTrafo, opEye(T,size(AHA,2)))
else
push!(regTrafo, trafoReg)
push!(regTrafo, trafoReg.trafo)
end
end
regTrafo = identity.(regTrafo)
Expand All @@ -111,17 +110,16 @@ function ADMM(A
rho = rT.(rho)
end

x = Vector{T}(undef,size(AHA,2))
x = Vector{T}(undef, size(AHA,2))
xᵒˡᵈ = similar(x)
β = similar(x)
β_y = similar(x)

# fields for primal & dual variables
z = [similar(x, size(AHA,2)) for i=1:length(reg)]
zᵒˡᵈ = [similar(z[i]) for i=1:length(reg)]
u = [similar(z[i]) for i=1:length(reg)]
uᵒˡᵈ = [similar(u[i]) for i=1:length(reg)]

z = [similar(x, size(regTrafo[i],1)) for i ∈ eachindex(vec(reg))]
zᵒˡᵈ = [similar(z[i]) for i ∈ eachindex(vec(reg))]
u = [similar(z[i]) for i ∈ eachindex(vec(reg))]
uᵒˡᵈ = [similar(u[i]) for i ∈ eachindex(vec(reg))]

# statevariables for CG
# we store them here to prevent CG from allocating new fields at each call
Expand All @@ -138,16 +136,15 @@ function ADMM(A
reg = normalize(ADMM, normalizeReg, reg, A, nothing)

return ADMM(A,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations
,iterationsCG,cgStateVars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(rT),Δ,rT(absTol),rT(relTol),rT(tolInner)
,normalizeReg, vary_rho, verbose)
,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),Δ,rT(absTol),rT(relTol),rT(tolInner),normalizeReg,vary_rho,verbose)
end

"""
init!(solver::ADMM, b; x=similar(b,0))
init!(solver::ADMM, b; x0 = 0)

(re-) initializes the ADMM iterator
"""
function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T
function init!(solver::ADMM, b; x0 = 0)
solver.x .= x0

# right hand side for the x-update
Expand All @@ -158,7 +155,7 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T
end

# primal and dual variables
for i=1:length(solver.reg)
for i ∈ eachindex(solver.reg)
solver.z[i] .= solver.regTrafo[i]*solver.x
solver.u[i] .= 0
end
Expand All @@ -168,22 +165,22 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T
solver.sᵏ .= Inf
solver.ɛᵖʳⁱ .= 0
solver.ɛᵈᵘᵃ .= 0
solver.σᵃᵇˢ = sqrt(length(b))*solver.absTol
solver.σᵃᵇˢ = sqrt(length(b)) * solver.absTol
solver.Δ .= Inf

# normalization of regularization parameters
solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b)
end

solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => norm(solver.sᵏ))
solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => solver.sᵏ)


"""
iterate(it::ADMM, iteration::Int=0)

performs one ADMM iteration.
"""
function iterate(solver::ADMM, iteration=0)
function iterate(solver::ADMM, iteration=1)
if done(solver, iteration) return nothing end
solver.verbose && println("Outer ADMM Iteration #$iteration")

Expand All @@ -194,18 +191,19 @@ function iterate(solver::ADMM, iteration=0)
for i ∈ eachindex(solver.reg)
mul!(solver.β, adjoint(solver.regTrafo[i]), solver.z[i], solver.ρ[i], 1)
mul!(solver.β, adjoint(solver.regTrafo[i]), solver.u[i], -solver.ρ[i], 1)
AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
AHA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
end
solver.verbose && println("conjugated gradients: ")
solver.xᵒˡᵈ .= solver.x
cg!(solver.x, AHA, solver.β, Pl=solver.precon, maxiter=solver.iterationsCG, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose)
cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose)

for proj in solver.proj
prox!(proj, solver.x)
end

# proximal map for regularization terms
for i ∈ eachindex(solver.reg)
# swap v and vᵒˡᵈ w/o copying data
# swap z and zᵒˡᵈ w/o copying data
tmp = solver.zᵒˡᵈ[i]
solver.zᵒˡᵈ[i] = solver.z[i]
solver.z[i] = tmp
Expand All @@ -214,19 +212,19 @@ function iterate(solver::ADMM, iteration=0)
mul!(solver.z[i], solver.regTrafo[i], solver.x)
solver.z[i] .+= solver.u[i]
if solver.ρ[i] != 0
prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/solver.ρ[i])
prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms
end

# 3. update u
solver.uᵒˡᵈ[i] .= solver.u[i]
mul!(solver.u[i], solver.regTrafo[i], solver.x, 1, 1)
solver.u[i] .-= solver.z[i]

# update convergence measures (one for each constraint)
solver.rᵏ[i] = norm(solver.regTrafo[i]*solver.x-solver.z[i]) # primal residual (x-z)
# update convergence criteria (one for each constraint)
solver.rᵏ[i] = norm(solver.regTrafo[i] * solver.x - solver.z[i]) # primal residual (x-z)
solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x))

solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i]))
solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * solver.x), norm(solver.z[i]))
solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i])

Δᵒˡᵈ = solver.Δ[i]
Expand All @@ -244,28 +242,23 @@ function iterate(solver::ADMM, iteration=0)
end

if solver.verbose
println("rᵏ[$i] = $(solver.rᵏ[i])")
println("sᵏ[$i] = $(solver.sᵏ[i])")
println("ɛᵖʳⁱ[$i] = $(solver.ɛᵖʳⁱ[i])")
println("ɛᵈᵘᵃ[$i] = $(solver.ɛᵈᵘᵃ[i])")
println("Δᵒˡᵈ = $(Δᵒˡᵈ)")
println("Δ[$i] = $(solver.Δ[i])")
println("Δ/Δᵒˡᵈ = $(solver.Δ[i]/Δᵒˡᵈ)")
println("current ρ[$i] = $(solver.ρ[i])")
println("rᵏ[$i]/ɛᵖʳⁱ[$i] = $(solver.rᵏ[i]/solver.ɛᵖʳⁱ[i])")
println("sᵏ[$i]/ɛᵈᵘᵃ[$i] = $(solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i])")
println("Δ[$i]/Δᵒˡᵈ[$i] = $(solver.Δ[i]/Δᵒˡᵈ)")
println("new ρ[$i] = $(solver.ρ[i])")
flush(stdout)
end
end

# return the primal feasibility measure as item and iteration number as state
return solver.rᵏ, iteration+1
end

function converged(solver::ADMM)
for i=1:length(solver.reg)
for i ∈ eachindex(solver.reg)
(solver.rᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵖʳⁱ[i]) && return false
(solver.sᵏ[i] >= solver.σᵃᵇˢ + solver.relTol * solver.ɛᵈᵘᵃ[i]) && return false
end
return true
end

@inline done(solver::ADMM,iteration::Int) = converged(solver) || iteration>=solver.iterations
@inline done(solver::ADMM,iteration::Int) = converged(solver) || iteration >= solver.iterations
Loading