Skip to content

Commit

Permalink
Merge pull request #38 from JakobAsslaender/OverlappingLLR
Browse files Browse the repository at this point in the history
Overlapping LLR + ADMM update scheme
  • Loading branch information
tknopp authored May 17, 2022
2 parents 278fa22 + 77ab4d6 commit 3545262
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 178 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.8.4"
[deps]
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
57 changes: 38 additions & 19 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ mutable struct ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT} <: AbstractLinearSolver
β_y::vecT
# fields for primal & dual variables
x::vecT
xᵒˡᵈ::vecT
z::Vector{vecT}
zᵒˡᵈ::Vector{vecT}
u::Vector{vecT}
uᵒˡᵈ::Vector{vecT}
# other parameters
precon::preconT
ρ::rvecT # TODO: Switch all these vectors to Tuple
Expand All @@ -27,12 +29,13 @@ mutable struct ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT} <: AbstractLinearSolver
ɛᵖʳⁱ::rvecT
ɛᵈᵘᵃ::rvecT
σᵃᵇˢ::rT
Δ::rvecT
absTol::rT
relTol::rT
tolInner::rT
normalizeReg::Bool
regFac::rT
vary_ρ::Bool
vary_ρ::Symbol
verbose::Bool
end

Expand Down Expand Up @@ -70,10 +73,10 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg=nothing, reg
, relTol::Real=eps(real(T))
, tolInner::Real=1e-5
, normalizeReg::Bool=false
, vary_ρ::Bool=false
, vary_ρ::Symbol=:none
, verbose::Bool=false
, kargs...) where {T,matT,opT}
# TODO: The constructor is not type stable
# TODO: The constructor is not type stable

# unify Floating types
if typeof(ρ) <: Number
Expand All @@ -96,10 +99,13 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg=nothing, reg
regTrafo = [opEye(eltype(x),size(A,2)) for _=1:length(reg)]
end

xᵒˡᵈ = similar(x)

# fields for primal & dual variables
z = [similar(x, size(regTrafo[i],1)) for i=1:length(reg)]
zᵒˡᵈ = [similar(z[i]) for i=1:length(reg)]
u = [similar(z[i]) for i=1:length(reg)]
uᵒˡᵈ = [similar(u[i]) for i=1:length(reg)]

# operator and fields for the update of x
if AᴴA == nothing
Expand All @@ -117,10 +123,10 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg=nothing, reg
sᵏ = similar(rᵏ)
ɛᵖʳⁱ = similar(rᵏ)
ɛᵈᵘᵃ = similar(rᵏ)
Δ = similar(rᵏ)


return ADMM(A,reg,regTrafo,AᴴA,β,β_y,x,z,zᵒˡᵈ,u,precon,ρ_vec,iterations
,iterationsInner,statevars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(real(T)),absTol,relTol,tolInner
return ADMM(A,reg,regTrafo,AᴴA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,ρ_vec,iterations
,iterationsInner,statevars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(real(T)),Δ,absTol,relTol,tolInner
,normalizeReg,one(real(T)), vary_ρ, verbose)
end

Expand Down Expand Up @@ -154,8 +160,7 @@ function init!(solver::ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT}, b::vecT

# primal and dual variables
for i=1:length(solver.reg)
solver.z[i][:] .= solver.regTrafo[i]*solver.x
solver.zᵒˡᵈ[i][:] .= 0
solver.z[i] .= solver.regTrafo[i]*solver.x
solver.u[i] .= 0
end

Expand All @@ -168,6 +173,7 @@ function init!(solver::ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT}, b::vecT
solver.ɛᵖʳⁱ .= 0
solver.ɛᵈᵘᵃ .= 0
solver.σᵃᵇˢ = sqrt(length(b))*solver.absTol
solver.Δ .= Inf

# normalization of regularization parameters
if solver.normalizeReg
Expand Down Expand Up @@ -218,28 +224,31 @@ performs one ADMM iteration.
"""
function iterate(solver::ADMM, iteration::Integer=0)
if done(solver, iteration) return nothing end
solver.verbose && println("Outer ADMM Iteration #$iteration")

# 1. solve arg min_x 1/2|| Ax-b ||² + ρ/2 Σ_i||Φi*x+ui-zi||²
# <=> (A'A+ρ Σ_i Φi'Φi)*x = A'b+ρΣ_i Φi'(zi-ui)
copyto!(solver.β, solver.β_y)
solver.β .= solver.β_y
AᴴA = solver.AᴴA
for i=1:length(solver.reg)
solver.β[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.z[i].-solver.u[i])
AᴴA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
end
solver.verbose && println("conjugated gardients: ")
solver.xᵒˡᵈ .= solver.x
cg!(solver.x, AᴴA, solver.β, Pl=solver.precon
, maxiter=solver.iterationsInner, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose)

for i=1:length(solver.reg)
# 2. update z using the proximal map of 1/ρ*g(x)
copyto!(solver.zᵒˡᵈ[i], solver.z[i])
solver.zᵒˡᵈ[i] .= solver.z[i]
solver.z[i] .= solver.regTrafo[i]*solver.x .+ solver.u[i]
if solver.ρ[i] != 0
solver.reg[i].prox!(solver.z[i], solver.regFac*solver.reg[i].λ/solver.ρ[i]; solver.reg[i].params...)
end

# 3. update u
solver.uᵒˡᵈ[i] .= solver.u[i]
solver.u[i] .+= solver.regTrafo[i]*solver.x .- solver.z[i]

# update convergence measures (one for each constraint)
Expand All @@ -249,20 +258,30 @@ function iterate(solver::ADMM, iteration::Integer=0)
solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i]))
solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i])

if solver.verbose
println("rᵏ[$i] = $(solver.rᵏ[i])")
println("sᵏ[$i] = $(solver.sᵏ[i])")
end
Δᵒˡᵈ = solver.Δ[i]
solver.Δ[i] = norm(solver.x .- solver.xᵒˡᵈ ) +
norm(solver.z[i] .- solver.zᵒˡᵈ[i]) +
norm(solver.u[i] .- solver.uᵒˡᵈ[i])

# adapt ρ according to Boyd et al.
if solver.vary_ρ && solver.rᵏ[i] > 10solver.sᵏ[i]
if (solver.vary_ρ == :balance && solver.rᵏ[i]/solver.ɛᵖʳⁱ[i] > 10solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i]) || # adapt ρ according to Boyd et al.
(solver.vary_ρ == :PnP && solver.Δ[i]/Δᵒˡᵈ > 0.9) # adapt ρ according to Chang et al.
solver.ρ[i] *= 2
solver.u[i] ./= 2
solver.verbose && println("updated ρ[$i] = $(solver.ρ[i])")
elseif solver.vary_ρ && solver.sᵏ[i] > 10solver.rᵏ[i]
elseif solver.vary_ρ == :balance && solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i] > 10solver.rᵏ[i]/solver.ɛᵖʳⁱ[i]
solver.ρ[i] /= 2
solver.u[i] .*= 2
solver.verbose && println("updated ρ[$i] = $(solver.ρ[i])")
end

if solver.verbose
println("rᵏ[$i] = $(solver.rᵏ[i])")
println("sᵏ[$i] = $(solver.sᵏ[i])")
println("ɛᵖʳⁱ[$i] = $(solver.ɛᵖʳⁱ[i])")
println("ɛᵈᵘᵃ[$i] = $(solver.ɛᵈᵘᵃ[i])")
println("Δᵒˡᵈ = $(Δᵒˡᵈ)")
println("Δ[$i] = $(solver.Δ[i])")
println("Δ/Δᵒˡᵈ = $(solver.Δ[i]/Δᵒˡᵈ)")
println("current ρ[$i] = $(solver.ρ[i])")
flush(stdout)
end
end

Expand Down
5 changes: 3 additions & 2 deletions src/RegularizedLeastSquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using IterativeSolvers
using Random
using VectorizationBase
using VectorizationBase: shufflevector, zstridedpointer
using Polyester
#@reexport using SparsityOperators
using SparsityOperators: normalOperator, opEye
using ProgressMeter
Expand Down Expand Up @@ -81,7 +82,7 @@ Function returns choosen solver.
* `"fista"` - Fast Iterative Shrinkage Thresholding Algorithm
* `"admm"` - Alternating Direcion of Multipliers Method
* `"splitBregman"` - Split Bregman method for constrained & regularized inverse problems
* `"primaldualsolver"`- First order primal dual method
* `"primaldualsolver"`- First order primal dual method
"""
function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A,2));
log::Bool=false, kargs...)
Expand All @@ -90,7 +91,7 @@ function createLinearSolver(solver::AbstractString, A, x=zeros(eltype(A),size(A,

if solver == "kaczmarz"
return Kaczmarz(A; kargs...)
elseif solver == "kaczmarzUpdated"
elseif solver == "kaczmarzUpdated"
return KaczmarzUpdated(A; kargs...)
elseif solver == "cgnr"
return CGNR(A, x; kargs...)
Expand Down
Loading

0 comments on commit 3545262

Please sign in to comment.