From 309cb0bd96d999a34887e9b2f51add36d0f34cd1 Mon Sep 17 00:00:00 2001 From: nHackel Date: Thu, 21 Sep 2023 10:49:21 +0200 Subject: [PATCH] Readded regtrafo field for ADMM and SplitBregman --- src/ADMM.jl | 52 +++++++++++++++++++------------------- src/SplitBregman.jl | 61 +++++++++++++++++++++------------------------ 2 files changed, 54 insertions(+), 59 deletions(-) diff --git a/src/ADMM.jl b/src/ADMM.jl index 1fecd16d..3f0dd457 100644 --- a/src/ADMM.jl +++ b/src/ADMM.jl @@ -1,9 +1,10 @@ export admm, ADMM -mutable struct ADMM{rT,matT,opT,R,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}} +mutable struct ADMM{rT,matT,opT,R,ropT,vecT,rvecT,preconT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}} # operators and regularization A::matT reg::Vector{R} + regTrafo::Vector{ropT} # fields and operators for x update AᴴA::opT β::vecT @@ -86,17 +87,17 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); reg = vec(reg) # using a custom method of vec(.) - regs = AbstractRegularization[] - for (i, r) in enumerate(reg) + regTrafo = [] + # Retrieve constraint trafos + for r in reg trafoReg = findfirst(ConstraintTransformedRegularization, r) if isnothing(trafoReg) - regTrafo = opEye(eltype(x),size(A,2)) - push!(regs, ConstraintTransformedRegularization(r, regTrafo)) + push!(regTrafo, opEye(eltype(x),size(A,2))) else - push!(regs, r) + push!(regTrafo, trafoReg) end end - reg = identity.(regs) + regTrafo = identity.(regTrafo) xᵒˡᵈ = similar(x) @@ -127,7 +128,7 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2)); # normalization parameters reg = normalize(ADMM, normalizeReg, reg, A, nothing) - return ADMM(A,reg,AᴴA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,ρ_vec,iterations + return ADMM(A,reg, regTrafo,AᴴA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,ρ_vec,iterations ,iterationsInner,statevars, rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,zero(real(T)),Δ,absTol,relTol,tolInner ,normalizeReg, vary_ρ, verbose) end @@ -141,11 +142,11 @@ end (re-) initializes the ADMM iterator """ -function init!(solver::ADMM{rT,matT,opT,R,vecT,rvecT,preconT}, b::vecT +function init!(solver::ADMM{rT,matT,opT,R, ropT,vecT,rvecT,preconT}, b::vecT ; A::matT=solver.A , AᴴA::opT=solver.AᴴA , x::vecT=similar(b,0) - , kargs...) where {rT,matT,opT,R,vecT,rvecT,preconT} + , kargs...) where {rT,matT,opT,R,ropT,vecT,rvecT,preconT} # operators if A != solver.A @@ -161,9 +162,8 @@ function init!(solver::ADMM{rT,matT,opT,R,vecT,rvecT,preconT}, b::vecT end # primal and dual variables - for (i, reg) in enumerate(solver.reg) - regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg)) - solver.z[i] .= regTrafo*solver.x + for i=1:length(solver.reg) + solver.z[i] .= solver.regTrafo[i]*solver.x solver.u[i] .= 0 end @@ -202,7 +202,7 @@ solves an inverse problem using ADMM. when a `SolverInfo` objects is passed, the primal residuals `solver.rᵏ` and the dual residual `norm(solver.sᵏ)` are stored in `solverInfo.convMeas`. """ -function solve(solver::ADMM{rT,matT,opT,ropT,vecT,rvecT,preconT}, b::vecT; A=solver.A, AᴴA=solver.AᴴA, startVector::vecT=similar(b,0), solverInfo=nothing, kargs...) where {rT,matT,opT,ropT,vecT,rvecT,preconT} +function solve(solver::ADMM{rT,matT,opT,R,ropT,vecT,rvecT,preconT}, b::vecT; A=solver.A, AᴴA=solver.AᴴA, startVector::vecT=similar(b,0), solverInfo=nothing, kargs...) where {rT,matT,opT,R,ropT,vecT,rvecT,preconT} # initialize solver parameters init!(solver, b; A=A, AᴴA=AᴴA, x=startVector) @@ -226,39 +226,37 @@ function iterate(solver::ADMM, iteration::Integer=0) if done(solver, iteration) return nothing end solver.verbose && println("Outer ADMM Iteration #$iteration") - regTrafos = map(reg -> transform(findfirst(ConstraintTransformedRegularization, reg)), solver.reg) - # 1. solve arg min_x 1/2|| Ax-b ||² + ρ/2 Σ_i||Φi*x+ui-zi||² # <=> (A'A+ρ Σ_i Φi'Φi)*x = A'b+ρΣ_i Φi'(zi-ui) solver.β .= solver.β_y AᴴA = solver.AᴴA - for (i, reg) in enumerate(solver.reg) - solver.β[:] .+= solver.ρ[i]*adjoint(i)*(solver.z[i].-solver.u[i]) - AᴴA += solver.ρ[i] * adjoint(regTrafos[i]) * regTrafos[i] + for i=1:length(solver.reg) + solver.β[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.z[i].-solver.u[i]) + AᴴA += solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.regTrafo[i] end solver.verbose && println("conjugated gradients: ") solver.xᵒˡᵈ .= solver.x cg!(solver.x, AᴴA, solver.β, Pl=solver.precon , maxiter=solver.iterationsInner, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose = solver.verbose) - for (i, reg) in enumerate(solver.reg) + for i=1:length(solver.reg) # 2. update z using the proximal map of 1/ρ*g(x) solver.zᵒˡᵈ[i] .= solver.z[i] - solver.z[i] .= regTrafos[i]*solver.x .+ solver.u[i] + solver.z[i] .= solver.regTrafo[i]*solver.x .+ solver.u[i] if solver.ρ[i] != 0 - prox!(reg, solver.z[i], λ(reg)/solver.ρ[i]) + prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/solver.ρ[i]) end # 3. update u solver.uᵒˡᵈ[i] .= solver.u[i] - solver.u[i] .+= regTrafos[i]*solver.x .- solver.z[i] + solver.u[i] .+= solver.regTrafo[i]*solver.x .- solver.z[i] # update convergence measures (one for each constraint) - solver.rᵏ[i] = norm(regTrafos[i]*solver.x-solver.z[i]) # primal residual (x-z) - solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(regTrafos[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x)) + solver.rᵏ[i] = norm(solver.regTrafo[i]*solver.x-solver.z[i]) # primal residual (x-z) + solver.sᵏ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * (solver.z[i] .- solver.zᵒˡᵈ[i])) # dual residual (concerning f(x)) - solver.ɛᵖʳⁱ[i] = max(norm(regTrafos[i]*solver.x), norm(solver.z[i])) - solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(regTrafos[i]) * solver.u[i]) + solver.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i]*solver.x), norm(solver.z[i])) + solver.ɛᵈᵘᵃ[i] = norm(solver.ρ[i] * adjoint(solver.regTrafo[i]) * solver.u[i]) Δᵒˡᵈ = solver.Δ[i] solver.Δ[i] = norm(solver.x .- solver.xᵒˡᵈ ) + diff --git a/src/SplitBregman.jl b/src/SplitBregman.jl index 01f04ae8..7908a699 100644 --- a/src/SplitBregman.jl +++ b/src/SplitBregman.jl @@ -1,9 +1,10 @@ export SplitBregman -mutable struct SplitBregman{matT,vecT,opT,R,rvecT,preconT,rT} <: AbstractPrimalDualSolver +mutable struct SplitBregman{matT,vecT,opT,R,ropT,rvecT,preconT,rT} <: AbstractPrimalDualSolver # oerators and regularization A::matT reg::Vector{R} + regTrafo::Vector{ropT} y::vecT # fields and operators for x update op::opT @@ -72,17 +73,18 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing; , normalizeReg::AbstractRegularizationNormalization = NoNormalization() , kargs...) where {matT, vecT<:AbstractVector} - regs = AbstractRegularization[] - for (i, r) in enumerate(reg) + regTrafo = [] + # Retrieve constraint trafos + for r in reg trafoReg = findfirst(ConstraintTransformedRegularization, r) if isnothing(trafoReg) - regTrafo = opEye(eltype(x),size(A,2)) - push!(regs, ConstraintTransformedRegularization(r, regTrafo)) + push!(regTrafo, opEye(eltype(x),size(A,2))) else - push!(regs, r) + push!(regTrafo, trafoReg) end end - reg = identity.(regs) + regTrafo = identity.(regTrafo) + if b==nothing @@ -93,9 +95,8 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing; # operator and fields for the update of x op = A'*A - for (i, reg) in enumerate(reg) - regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg)) - op += ρ[i]*adjoint(regTrafo)*regTrafo + for i=1:length(vec(reg)) + op += ρ[i]*adjoint(regTrafo[i])*regTrafo[i] end β = similar(x) β_yj = similar(x) @@ -130,7 +131,7 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing; # normalization parameters reg = normalize(SplitBregman, normalizeReg, vec(reg), A, nothing) - return SplitBregman(A,reg,y,op,β,β_yj,y_j,u,v,vᵒˡᵈ,b,precon,ρ_vec + return SplitBregman(A,reg,regTrafo,y,op,β,β_yj,y_j,u,v,vᵒˡᵈ,b,precon,ρ_vec ,iterations,iterationsInner,iterationsCG,statevars, rk,sk ,eps_pri,eps_dt,0.0,absTol,relTol,tolInner,iter_cnt,normalizeReg) end @@ -143,18 +144,17 @@ end (re-) initializes the SplitBregman iterator """ -function init!(solver::SplitBregman{matT,vecT,opT,ropT,rvecT,preconT}, b::vecT +function init!(solver::SplitBregman{matT,vecT,opT,R,ropT,rvecT,preconT}, b::vecT ; A::matT=solver.A , u::vecT=similar(b,0) - , kargs...) where {matT,vecT,opT,ropT,rvecT,preconT} + , kargs...) where {matT,vecT,opT,R,ropT,rvecT,preconT} # operators if A != solver.A solver.A = A solver.op = A'*A - for (i, reg) in enumerate(solver.reg) - regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg)) - solver.op += ρ[i]*adjoint(regTrafo)*regTrafo + for i=1:length(vec(reg)) + solver.op += ρ[i]*adjoint(regTrafo[i])*regTrafo[i] end end solver.y = b @@ -171,9 +171,8 @@ function init!(solver::SplitBregman{matT,vecT,opT,ropT,rvecT,preconT}, b::vecT end # primal and dual variables - for (i, reg) in enumerate(solver.reg) - regTrafo = transform(findfirst(ConstraintTransformedRegularization, reg)) - solver.v[i][:] .= regTrafo*solver.u + for i=1:length(solver.reg) + solver.v[i][:] .= solver.regTrafo[i]*solver.u solver.vᵒˡᵈ[i][:] .= 0 solver.b[i][:] .= 0 end @@ -255,45 +254,43 @@ The Split Bregman Method for l1 Regularized Problems * (`tolInner::Float64=1.e-3`) - relative tolerance for CG * (`solverInfo = nothing`) - `solverInfo` object used to store convergence metrics """ -function iterate(solver::SplitBregman{matT,vecT,opT,rvecT,preconT}, iteration::Int=1) where {matT,vecT,opT,rvecT,preconT} +function iterate(solver::SplitBregman{matT, vecT, opT, R, ropT, rvecT, preconT, rT}, iteration::Int=1) where {matT, vecT, opT, R, ropT, rvecT, preconT, rT} if done(solver, iteration) return nothing end - regTrafos = map(reg -> transform(findfirst(ConstraintTransformedRegularization, reg)), solver.reg) - # update u solver.β[:] .= solver.β_yj for i=1:length(solver.reg) - solver.β[:] .+= solver.ρ[i]*adjoint(regTrafos[i])*(solver.v[i].-solver.b[i]) + solver.β[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.v[i].-solver.b[i]) end cg!(solver.u,solver.op,solver.β,Pl=solver.precon,maxiter=solver.iterationsCG,reltol=solver.tolInner) # proximal map for regularization terms - for (i, reg) in enumerate(solver.reg) + for i=1:length(solver.reg) copyto!(solver.vᵒˡᵈ[i], solver.v[i]) - solver.v[i][:] .= regTrafos[i]*solver.u .+ solver.b[i] + solver.v[i][:] .= solver.regTrafo[i]*solver.u .+ solver.b[i] if solver.ρ[i] != 0 - prox!(reg, solver.v[i], λ(reg)/solver.ρ[i]) + prox!(solver.reg[i], solver.v[i], λ(solver.reg[i])/solver.ρ[i]) end end # update b for i=1:length(solver.reg) - solver.b[i] .+= regTrafos[i]*solver.u .- solver.v[i] + solver.b[i] .+= solver.regTrafo[i]*solver.u .- solver.v[i] end # update convergence criteria # primal residuals norms (one for each constraint) for i=1:length(solver.reg) - solver.rk[i] = norm(regTrafos[i]*solver.u-solver.v[i]) - solver.eps_pri[i] = solver.σᵃᵇˢ + solver.relTol*max( norm(regTrafos[i]*solver.u), norm(solver.v[i]) ) + solver.rk[i] = norm(solver.regTrafo[i]*solver.u-solver.v[i]) + solver.eps_pri[i] = solver.σᵃᵇˢ + solver.relTol*max( norm(solver.regTrafo[i]*solver.u), norm(solver.v[i]) ) end # accumulated dual residual # effectively this corresponds to combining all constraints into one larger constraint. solver.sk[:] .= 0.0 solver.eps_dt[:] .= 0.0 for i=1:length(solver.reg) - solver.sk[:] .+= solver.ρ[i]*adjoint(regTrafos[i])*(solver.v[i].-solver.vᵒˡᵈ[i]) - solver.eps_dt[:] .+= solver.ρ[i]*adjoint(regTrafos[i])*solver.b[i] + solver.sk[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*(solver.v[i].-solver.vᵒˡᵈ[i]) + solver.eps_dt[:] .+= solver.ρ[i]*adjoint(solver.regTrafo[i])*solver.b[i] end if update_y(solver,iteration) @@ -301,7 +298,7 @@ function iterate(solver::SplitBregman{matT,vecT,opT,rvecT,preconT}, iteration::I solver.β_yj[:] .= adjoint(solver.A) * solver.y_j # reset v and b for i=1:length(solver.reg) - solver.v[i][:] .= regTrafos[i]*solver.u + solver.v[i][:] .= solver.regTrafo[i]*solver.u solver.b[i] .= 0 end solver.iter_cnt += 1