Skip to content

Commit

Permalink
Readded regtrafo field for ADMM and SplitBregman
Browse files Browse the repository at this point in the history
nHackel committed Sep 21, 2023
1 parent aab0431 commit 309cb0b
Showing 2 changed files with 54 additions and 59 deletions.
52 changes: 25 additions & 27 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
export admm, ADMM

mutable struct ADMM{rT,matT,opT,R,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
mutable struct ADMM{rT,matT,opT,R,ropT,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
# operators and regularization
A::matT
reg::Vector{R}
regTrafo::Vector{ropT}
# fields and operators for x update
AᴴA::opT
β::vecT
@@ -86,17 +87,17 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2));

reg = vec(reg) # using a custom method of vec(.)

regs = AbstractRegularization[]
for (i, r) in enumerate(reg)
regTrafo = []
# Retrieve constraint trafos
for r in reg
trafoReg = findfirst(ConstraintTransformedRegularization, r)
if isnothing(trafoReg)
regTrafo = opEye(eltype(x),size(A,2))
push!(regs, ConstraintTransformedRegularization(r, regTrafo))
push!(regTrafo, opEye(eltype(x),size(A,2)))
else
push!(regs, r)
push!(regTrafo, trafoReg)

Check warning on line 97 in src/ADMM.jl

Codecov / codecov/patch

src/ADMM.jl#L97

Added line #L97 was not covered by tests
end
end
reg = identity.(regs)
regTrafo = identity.(regTrafo)

xᵒˡᵈ = similar(x)

@@ -127,7 +128,7 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2));
# normalization parameters
reg = normalize(ADMM, normalizeReg, reg, A, nothing)

return ADMM(A,reg,AᴴA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,ρ_vec,iterations
return ADMM(A,reg, regTrafo,AᴴA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,ρ_vec,iterations
,iterationsInner,statevars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(real(T)),Δ,absTol,relTol,tolInner
,normalizeReg, vary_ρ, verbose)
end
@@ -141,11 +142,11 @@ end
(re-) initializes the ADMM iterator
"""
function init!(solver::ADMM{rT,matT,opT,R,vecT,rvecT,preconT}, b::vecT
function init!(solver::ADMM{rT,matT,opT,R, ropT,vecT,rvecT,preconT}, b::vecT
; A::matT=solver.A
, AᴴA::opT=solver.AᴴA
, x::vecT=similar(b,0)
, kargs...) where {rT,matT,opT,R,vecT,rvecT,preconT}
, kargs...) where {rT,matT,opT,R,ropT,vecT,rvecT,preconT}

# operators
if A != solver.A
@@ -161,9 +162,8 @@ function init!(solver::ADMM{rT,matT,opT,R,vecT,rvecT,preconT}, b::vecT
end

# primal and dual variables
for (i, reg) in enumerate(solver.reg)
regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg))
solver.z[i] .= regTrafo*solver.x
for i=1:length(solver.reg)
solver.z[i] .= solver.regTrafo[i]*solver.x
solver.u[i] .= 0
end

@@ -202,7 +202,7 @@ solves an inverse problem using ADMM.
when a `SolverInfo` objects is passed, the primal residuals `solver.rᵏ`
and the dual residual `norm(solver.sᵏ)` are stored in `solverInfo.convMeas`.
"""
function solve(solver::ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT}, b::vecT; A=solver.A, AᴴA=solver.AᴴA, startVector::vecT=similar(b,0), solverInfo=nothing, kargs...) where {rT,matT,opT,ropT,vecT,rvecT,preconT}
function solve(solver::ADMM{rT,matT,opT,R,ropT,vecT,rvecT,preconT}, b::vecT; A=solver.A, AᴴA=solver.AᴴA, startVector::vecT=similar(b,0), solverInfo=nothing, kargs...) where {rT,matT,opT,R,ropT,vecT,rvecT,preconT}
# initialize solver parameters
init!(solver, b; A=A, AᴴA=AᴴA, x=startVector)

@@ -226,39 +226,37 @@ function iterate(solver::ADMM, iteration::Integer=0)
if done(solver, iteration) return nothing end
solver.verbose && println("Outer ADMM Iteration #$iteration")

regTrafos = map(reg -> transform(findfirst(ConstraintTransformedRegularization, reg)), solver.reg)

# 1. solve arg min_x 1/2|| Ax-b ||² + ρ/2 Σ_i||Φi*x+ui-zi||²
# <=> (A'A+ρ Σ_i Φi'Φi)*x = A'b+ρΣ_i Φi'(zi-ui)
solver.β .= solver.β_y
AᴴA = solver.AᴴA
for (i, reg) in enumerate(solver.reg)
solver.β[:] .+= solver.ρ[i]*adjoint(i)*(solver.z[i].-solver.u[i])
AᴴA += solver.ρ[i] * adjoint(regTrafos[i]) * regTrafos[i]
for i=1:length(solver.reg)
solver.β[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.z[i].-solver.u[i])
AᴴA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i]
end
solver.verbose && println("conjugated gradients: ")
solver.xᵒˡᵈ .= solver.x
cg!(solver.x, AᴴA, solver.β, Pl=solver.precon
, maxiter=solver.iterationsInner, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose)

for (i, reg) in enumerate(solver.reg)
for i=1:length(solver.reg)
# 2. update z using the proximal map of 1/ρ*g(x)
solver.zᵒˡᵈ[i] .= solver.z[i]
solver.z[i] .= regTrafos[i]*solver.x .+ solver.u[i]
solver.z[i] .= solver.regTrafo[i]*solver.x .+ solver.u[i]
if solver.ρ[i] != 0
prox!(reg, solver.z[i], λ(reg)/solver.ρ[i])
prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/solver.ρ[i])
end

# 3. update u
solver.uᵒˡᵈ[i] .= solver.u[i]
solver.u[i] .+= regTrafos[i]*solver.x .- solver.z[i]
solver.u[i] .+= solver.regTrafo[i]*solver.x .- solver.z[i]

# update convergence measures (one for each constraint)
solver.rᵏ[i] = norm(regTrafos[i]*solver.x-solver.z[i]) # primal residual (x-z)
solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(regTrafos[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x))
solver.rᵏ[i] = norm(solver.regTrafo[i]*solver.x-solver.z[i]) # primal residual (x-z)
solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x))

solver.ɛᵖʳⁱ[i] = max(norm(regTrafos[i]*solver.x), norm(solver.z[i]))
solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(regTrafos[i]) * solver.u[i])
solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i]))
solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i])

Δᵒˡᵈ = solver.Δ[i]
solver.Δ[i] = norm(solver.x .- solver.xᵒˡᵈ ) +
61 changes: 29 additions & 32 deletions src/SplitBregman.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
export SplitBregman

mutable struct SplitBregman{matT,vecT,opT,R,rvecT,preconT,rT} <: AbstractPrimalDualSolver
mutable struct SplitBregman{matT,vecT,opT,R,ropT,rvecT,preconT,rT} <: AbstractPrimalDualSolver
# oerators and regularization
A::matT
reg::Vector{R}
regTrafo::Vector{ropT}
y::vecT
# fields and operators for x update
op::opT
@@ -72,17 +73,18 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing;
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, kargs...) where {matT, vecT<:AbstractVector}

regs = AbstractRegularization[]
for (i, r) in enumerate(reg)
regTrafo = []
# Retrieve constraint trafos
for r in reg
trafoReg = findfirst(ConstraintTransformedRegularization, r)
if isnothing(trafoReg)
regTrafo = opEye(eltype(x),size(A,2))
push!(regs, ConstraintTransformedRegularization(r, regTrafo))
push!(regTrafo, opEye(eltype(x),size(A,2)))
else
push!(regs, r)
push!(regTrafo, trafoReg)

Check warning on line 83 in src/SplitBregman.jl

Codecov / codecov/patch

src/SplitBregman.jl#L83

Added line #L83 was not covered by tests
end
end
reg = identity.(regs)
regTrafo = identity.(regTrafo)



if b==nothing
@@ -93,9 +95,8 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing;

# operator and fields for the update of x
op = A'*A
for (i, reg) in enumerate(reg)
regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg))
op += ρ[i]*adjoint(regTrafo)*regTrafo
for i=1:length(vec(reg))
op += ρ[i]*adjoint(regTrafo[i])*regTrafo[i]
end
β = similar(x)
β_yj = similar(x)
@@ -130,7 +131,7 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing;
# normalization parameters
reg = normalize(SplitBregman, normalizeReg, vec(reg), A, nothing)

return SplitBregman(A,reg,y,op,β,β_yj,y_j,u,v,vᵒˡᵈ,b,precon,ρ_vec
return SplitBregman(A,reg,regTrafo,y,op,β,β_yj,y_j,u,v,vᵒˡᵈ,b,precon,ρ_vec
,iterations,iterationsInner,iterationsCG,statevars, rk,sk
,eps_pri,eps_dt,0.0,absTol,relTol,tolInner,iter_cnt,normalizeReg)
end
@@ -143,18 +144,17 @@ end
(re-) initializes the SplitBregman iterator
"""
function init!(solver::SplitBregman{matT,vecT,opT,ropT,rvecT,preconT}, b::vecT
function init!(solver::SplitBregman{matT,vecT,opT,R,ropT,rvecT,preconT}, b::vecT
; A::matT=solver.A
, u::vecT=similar(b,0)
, kargs...) where {matT,vecT,opT,ropT,rvecT,preconT}
, kargs...) where {matT,vecT,opT,R,ropT,rvecT,preconT}

# operators
if A != solver.A
solver.A = A
solver.op = A'*A
for (i, reg) in enumerate(solver.reg)
regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg))
solver.op += ρ[i]*adjoint(regTrafo)*regTrafo
for i=1:length(vec(reg))
solver.op += ρ[i]*adjoint(regTrafo[i])*regTrafo[i]
end
end
solver.y = b
@@ -171,9 +171,8 @@ function init!(solver::SplitBregman{matT,vecT,opT,ropT,rvecT,preconT}, b::vecT
end

# primal and dual variables
for (i, reg) in enumerate(solver.reg)
regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg))
solver.v[i][:] .= regTrafo*solver.u
for i=1:length(solver.reg)
solver.v[i][:] .= solver.regTrafo[i]*solver.u
solver.vᵒˡᵈ[i][:] .= 0
solver.b[i][:] .= 0
end
@@ -255,53 +254,51 @@ The Split Bregman Method for l1 Regularized Problems
* (`tolInner::Float64=1.e-3`) - relative tolerance for CG
* (`solverInfo = nothing`) - `solverInfo` object used to store convergence metrics
"""
function iterate(solver::SplitBregman{matT,vecT,opT,rvecT,preconT}, iteration::Int=1) where {matT,vecT,opT,rvecT,preconT}
function iterate(solver::SplitBregman{matT, vecT, opT, R, ropT, rvecT, preconT, rT}, iteration::Int=1) where {matT, vecT, opT, R, ropT, rvecT, preconT, rT}
if done(solver, iteration) return nothing end

regTrafos = map(reg -> transform(findfirst(ConstraintTransformedRegularization, reg)), solver.reg)

# update u
solver.β[:] .= solver.β_yj
for i=1:length(solver.reg)
solver.β[:] .+= solver.ρ[i]*adjoint(regTrafos[i])*(solver.v[i].-solver.b[i])
solver.β[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.v[i].-solver.b[i])
end
cg!(solver.u,solver.op,solver.β,Pl=solver.precon,maxiter=solver.iterationsCG,reltol=solver.tolInner)

# proximal map for regularization terms
for (i, reg) in enumerate(solver.reg)
for i=1:length(solver.reg)
copyto!(solver.vᵒˡᵈ[i], solver.v[i])
solver.v[i][:] .= regTrafos[i]*solver.u .+ solver.b[i]
solver.v[i][:] .= solver.regTrafo[i]*solver.u .+ solver.b[i]
if solver.ρ[i] != 0
prox!(reg, solver.v[i], λ(reg)/solver.ρ[i])
prox!(solver.reg[i], solver.v[i], λ(solver.reg[i])/solver.ρ[i])
end
end

# update b
for i=1:length(solver.reg)
solver.b[i] .+= regTrafos[i]*solver.u .- solver.v[i]
solver.b[i] .+= solver.regTrafo[i]*solver.u .- solver.v[i]
end

# update convergence criteria
# primal residuals norms (one for each constraint)
for i=1:length(solver.reg)
solver.rk[i] = norm(regTrafos[i]*solver.u-solver.v[i])
solver.eps_pri[i] = solver.σᵃᵇˢ + solver.relTol*max( norm(regTrafos[i]*solver.u), norm(solver.v[i]) )
solver.rk[i] = norm(solver.regTrafo[i]*solver.u-solver.v[i])
solver.eps_pri[i] = solver.σᵃᵇˢ + solver.relTol*max( norm(solver.regTrafo[i]*solver.u), norm(solver.v[i]) )
end
# accumulated dual residual
# effectively this corresponds to combining all constraints into one larger constraint.
solver.sk[:] .= 0.0
solver.eps_dt[:] .= 0.0
for i=1:length(solver.reg)
solver.sk[:] .+= solver.ρ[i]*adjoint(regTrafos[i])*(solver.v[i].-solver.vᵒˡᵈ[i])
solver.eps_dt[:] .+= solver.ρ[i]*adjoint(regTrafos[i])*solver.b[i]
solver.sk[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.v[i].-solver.vᵒˡᵈ[i])
solver.eps_dt[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*solver.b[i]
end

if update_y(solver,iteration)
solver.y_j[:] .+= solver.y .- solver.A*solver.u
solver.β_yj[:] .= adjoint(solver.A) * solver.y_j
# reset v and b
for i=1:length(solver.reg)
solver.v[i][:] .= regTrafos[i]*solver.u
solver.v[i][:] .= solver.regTrafo[i]*solver.u
solver.b[i] .= 0
end
solver.iter_cnt += 1

0 comments on commit 309cb0b

Please sign in to comment.