[breaking] Remove ConstraintTransformRegularization #70

Merged · 2 commits · Jan 26, 2024
1 change: 0 additions & 1 deletion docs/src/API/regularization.md
@@ -39,7 +39,6 @@ RegularizedLeastSquares.FixedParameterRegularization
```@docs
RegularizedLeastSquares.MaskedRegularization
RegularizedLeastSquares.TransformedRegularization
RegularizedLeastSquares.ConstraintTransformedRegularization
RegularizedLeastSquares.PlugAndPlayRegularization
```

43 changes: 19 additions & 24 deletions src/ADMM.jl
@@ -40,8 +40,8 @@ mutable struct ADMM{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDu
end

"""
ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(real(eltype(AHA)))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(real(eltype(AHA)))), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM(A; AHA = A'*A, precon = Identity(), reg = L1Regularization(zero(real(eltype(AHA)))), regTrafo = opEye(eltype(AHA), size(AHA,1)), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)
ADMM( ; AHA = , precon = Identity(), reg = L1Regularization(zero(real(eltype(AHA)))), regTrafo = opEye(eltype(AHA), size(AHA,1)), normalizeReg = NoNormalization(), rho = 1e-1, vary_rho = :none, iterations = 10, iterationsCG = 10, absTol = eps(real(eltype(AHA))), relTol = eps(real(eltype(AHA))), tolInner = 1e-5, verbose = false)

Creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.

@@ -54,6 +54,7 @@ Creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
* `AHA` - normal operator is optional if `A` is supplied
* `precon` - preconditioner for the internal CG algorithm
* `reg::AbstractParameterizedRegularization` - regularization term; can also be a vector of regularization terms
* `regTrafo` - transformation to a space in which `reg` is applied; if `reg` is a vector, `regTrafo` has to be a vector of the same length. Use `opEye(eltype(AHA), size(AHA,1))` if no transformation is desired.
* `normalizeReg::AbstractRegularizationNormalization` - regularization normalization scheme; options are `NoNormalization()`, `MeasurementBasedNormalization()`, `SystemMatrixBasedNormalization()`
* `rho::Real` - penalty of the augmented Lagrangian
* `vary_rho::Symbol` - vary rho to balance primal and dual feasibility; options `:none`, `:balance`, `:PnP`
@@ -64,6 +65,8 @@ Creates an `ADMM` object for the forward operator `A` or normal operator `AHA`.
* `tolInner::Real` - relative tolerance for CG stopping criterion
* `verbose::Bool` - print residual in each iteration

ADMM differs from ISTA-type algorithms in the sense that the proximal operation is applied separately from the transformation to the space in which the penalty is applied. This is reflected by the interface which has `reg` and `regTrafo` as separate arguments. E.g., for a TV penalty, you should NOT set `reg=TVRegularization`, but instead use `reg=L1Regularization(λ), regTrafo=RegularizedLeastSquares.GradientOp(Float64; shape=(Nx,Ny,Nz))`.

See also [`createLinearSolver`](@ref), [`solve!`](@ref).
"""
ADMM(; AHA, kwargs...) = ADMM(nothing; kwargs..., AHA = AHA)
@@ -72,6 +75,7 @@ function ADMM(A
; AHA = A'*A
, precon = Identity()
, reg = L1Regularization(zero(real(eltype(AHA))))
, regTrafo = opEye(eltype(AHA), size(AHA,1))
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, rho = 1e-1
, vary_rho::Symbol = :none
@@ -86,23 +90,15 @@ function ADMM(A
T = eltype(AHA)
rT = real(T)

reg = vec(reg)
reg = isa(reg, AbstractVector) ? reg : [reg]
regTrafo = isa(regTrafo, AbstractVector) ? regTrafo : [regTrafo]
@assert length(reg) == length(regTrafo) "reg and regTrafo must have the same length"

regTrafo = []
indices = findsinks(AbstractProjectionRegularization, reg)
proj = [reg[i] for i in indices]
proj = identity.(proj)
deleteat!(reg, indices)
# Retrieve constraint trafos
for r in reg
trafoReg = findfirst(ConstraintTransformedRegularization, r)
if isnothing(trafoReg)
push!(regTrafo, opEye(T,size(AHA,2)))
else
push!(regTrafo, trafoReg.trafo)
end
end
regTrafo = identity.(regTrafo)
deleteat!(regTrafo, indices)

if typeof(rho) <: Number
rho = [rT.(rho) for _ ∈ eachindex(reg)]
@@ -116,10 +112,10 @@ function ADMM(A
β_y = similar(x)

# fields for primal & dual variables
z = [similar(x, size(regTrafo[i],1)) for i ∈ eachindex(vec(reg))]
zᵒˡᵈ = [similar(z[i]) for i ∈ eachindex(vec(reg))]
u = [similar(z[i]) for i ∈ eachindex(vec(reg))]
uᵒˡᵈ = [similar(u[i]) for i ∈ eachindex(vec(reg))]
z = [similar(x, size(regTrafo[i],1)) for i ∈ eachindex(reg)]
zᵒˡᵈ = [similar(z[i]) for i ∈ eachindex(reg)]
u = [similar(z[i]) for i ∈ eachindex(reg)]
uᵒˡᵈ = [similar(u[i]) for i ∈ eachindex(reg)]

# state variables for CG
# we store them here to prevent CG from allocating new fields at each call
@@ -135,16 +131,15 @@ function ADMM(A
# normalization parameters
reg = normalize(ADMM, normalizeReg, reg, A, nothing)

return ADMM(A,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations
,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),Δ,rT(absTol),rT(relTol),rT(tolInner),normalizeReg,vary_rho,verbose)
return ADMM(A, reg, regTrafo, proj, AHA, β, β_y, x, xᵒˡᵈ, z, zᵒˡᵈ, u, uᵒˡᵈ, precon, rho, iterations, iterationsCG, cgStateVars, rᵏ, sᵏ, ɛᵖʳⁱ, ɛᵈᵘᵃ, rT(0), Δ, rT(absTol), rT(relTol), rT(tolInner), normalizeReg, vary_rho, verbose)
end

"""
init!(solver::ADMM, b; x0 = 0)

(re-) initializes the ADMM iterator
"""
function init!(solver::ADMM, b; x0 = 0)
function init!(solver::ADMM, b; x0=0)
solver.x .= x0

# right hand side for the x-update
Expand All @@ -156,7 +151,7 @@ function init!(solver::ADMM, b; x0 = 0)

# primal and dual variables
for i ∈ eachindex(solver.reg)
solver.z[i] .= solver.regTrafo[i]*solver.x
solver.z[i] .= solver.regTrafo[i] * solver.x
solver.u[i] .= 0
end

@@ -181,7 +176,7 @@ solverconvergence(solver::ADMM) = (; :primal => solver.rᵏ, :dual => solver.s
performs one ADMM iteration.
"""
function iterate(solver::ADMM, iteration=1)
if done(solver, iteration) return nothing end
done(solver, iteration) && return nothing
solver.verbose && println("Outer ADMM Iteration #$iteration")

# 1. solve arg min_x 1/2|| Ax-b ||² + ρ/2 Σ_i||Φi*x+ui-zi||²
@@ -195,7 +190,7 @@ function iterate(solver::ADMM, iteration=1)
end
solver.verbose && println("conjugated gradients: ")
solver.xᵒˡᵈ .= solver.x
cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose)
cg!(solver.x, AHA, solver.β, Pl=solver.precon, maxiter=solver.iterationsCG, reltol=solver.tolInner, statevars=solver.cgStateVars, verbose=solver.verbose)

for proj in solver.proj
prox!(proj, solver.x)
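
The ADMM docstring above now documents the explicit `reg`/`regTrafo` split. Below is a minimal illustrative sketch of that interface; the operator `A`, the problem size, the 2D `shape`, and the regularization weight are placeholder choices for this example, not part of the PR.

```julia
using RegularizedLeastSquares

# Placeholder toy problem: a dense "imaging" operator and data for a 16×16 image.
Nx, Ny = 16, 16
A = randn(Float64, 200, Nx * Ny)
b = A * vec(rand(Float64, Nx, Ny))

# TV-style penalty: the L1 prox is applied in the gradient domain supplied via `regTrafo`.
solver = createLinearSolver(ADMM, A;
    reg        = L1Regularization(1e-2),
    regTrafo   = RegularizedLeastSquares.GradientOp(Float64; shape = (Nx, Ny)),
    iterations = 20)

x_approx = solve!(solver, b)
```

Vectors work the same way, e.g. `reg = [L1Regularization(1e-2), L2Regularization(1e-3)]` together with a two-element `regTrafo`, matching the length check the constructor now asserts.
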
2 changes: 1 addition & 1 deletion src/CGNR.jl
@@ -62,7 +62,7 @@ function CGNR(A
ζl = zero(T) #temporary scalar

# Prepare regularization terms
reg = vec(reg)
reg = isa(reg, AbstractVector) ? reg : [reg]
reg = normalize(CGNR, normalizeReg, reg, A, nothing)
idx = findsink(L2Regularization, reg)
if isnothing(idx)
2 changes: 1 addition & 1 deletion src/FISTA.jl
@@ -76,7 +76,7 @@ function FISTA(A
end

# Prepare regularization terms
reg = vec(reg)
reg = isa(reg, AbstractVector) ? reg : [reg]
indices = findsinks(AbstractProjectionRegularization, reg)
other = [reg[i] for i in indices]
deleteat!(reg, indices)
2 changes: 1 addition & 1 deletion src/Kaczmarz.jl
@@ -66,7 +66,7 @@ function Kaczmarz(A
end

# Prepare regularization terms
reg = vec(reg)
reg = isa(reg, AbstractVector) ? reg : [reg]
reg = normalize(Kaczmarz, normalizeReg, reg, A, nothing)
idx = findsink(L2Regularization, reg)
if isnothing(idx)
2 changes: 1 addition & 1 deletion src/OptISTA.jl
@@ -88,7 +88,7 @@ function OptISTA(A
θn = (1 + sqrt(1 + 8 * θn^2)) / 2

# Prepare regularization terms
reg = vec(reg)
reg = isa(reg, AbstractVector) ? reg : [reg]
indices = findsinks(AbstractProjectionRegularization, reg)
other = [reg[i] for i in indices]
deleteat!(reg, indices)
2 changes: 1 addition & 1 deletion src/POGM.jl
@@ -99,7 +99,7 @@ function POGM(A
rho /= abs(power_iterations(AHA))
end

reg = vec(reg)
reg = isa(reg, AbstractVector) ? reg : [reg]
indices = findsinks(AbstractProjectionRegularization, reg)
other = [reg[i] for i in indices]
deleteat!(reg, indices)
10 changes: 6 additions & 4 deletions src/PrimalDualSolver.jl
@@ -50,9 +50,11 @@
) where T
M,N = size(A)

if (reg isa Vector && reg[1] isa L1Regularization) || reg isa L1Regularization
reg = isa(reg, AbstractVector) ? reg : [reg]

if reg[1] isa L1Regularization
gradientOp = opEye(T,N) #UniformScaling(one(T))
elseif (reg isa Vector && reg[1] isa TVRegularization) || reg isa TVRegularization
elseif reg[1] isa TVRegularization

Codecov (codecov/patch) check warning on src/PrimalDualSolver.jl line 57: added line #L57 was not covered by tests.
gradientOp = gradientOperator(T,shape)
end

@@ -63,9 +65,9 @@
y2 = zeros(T,size(gradientOp*x,1))

# normalization parameters
reg = normalize(PrimalDualSolver, normalizeReg, vec(reg), A, nothing)
reg = normalize(PrimalDualSolver, normalizeReg, reg, A, nothing)

return PrimalDualSolver(A,vec(reg),gradientOp,u,x,cO,y1,y2,T(σ),T(τ),T(ϵ),T(PrimalDualGap),enforceReal,enforcePositive,iterations,shape,
return PrimalDualSolver(A,reg,gradientOp,u,x,cO,y1,y2,T(σ),T(τ),T(ϵ),T(PrimalDualGap),enforceReal,enforcePositive,iterations,shape,
normalizeReg)
end

24 changes: 0 additions & 24 deletions src/Regularization/ConstraintTransformedRegularization.jl

This file was deleted.

4 changes: 0 additions & 4 deletions src/Regularization/Regularization.jl
@@ -65,7 +65,6 @@ include("ScaledRegularization.jl")
include("NormalizedRegularization.jl")
include("TransformedRegularization.jl")
include("MaskedRegularization.jl")
include("ConstraintTransformedRegularization.jl")
include("PlugAndPlayRegularization.jl")


@@ -88,9 +87,6 @@ end
findsinks(::Type{S}, reg::Vector{<:AbstractRegularization}) where S <: AbstractRegularization = findall(x -> sinktype(x) <: S, reg)


Base.vec(reg::AbstractRegularization) = AbstractRegularization[reg]
Base.vec(reg::AbstractVector{AbstractRegularization}) = reg

"""
RegularizationList()

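
The removed `Base.vec` overloads are replaced by an explicit wrap-into-a-vector check inside each solver constructor (see the identical one-line changes in ADMM.jl, CGNR.jl, FISTA.jl, Kaczmarz.jl, OptISTA.jl, and POGM.jl above). A hypothetical standalone helper, not part of the package, illustrating that pattern:

```julia
using RegularizedLeastSquares

# Wrap a single regularization term into a vector; leave vectors untouched.
wrap_reg(reg) = isa(reg, AbstractVector) ? reg : [reg]

wrap_reg(L1Regularization(0.1))                            # -> 1-element Vector
wrap_reg([L1Regularization(0.1), L2Regularization(0.05)])  # -> returned unchanged
```
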
8 changes: 4 additions & 4 deletions src/RegularizedLeastSquares.jl
@@ -120,7 +120,7 @@ end
Pass `cb` as the callback to `solve!`

# Examples
```julia
```julia
julia> x_approx = solve!(solver, b) do solver, iteration
println(iteration)
end
@@ -181,6 +181,8 @@ include("PrimalDualSolver.jl")

include("Callbacks.jl")

include("deprecated.jl")

"""
Return a list of all available linear solvers
"""
@@ -259,6 +261,4 @@ function createLinearSolver(solver::Type{T}; AHA, kargs...) where {T<:AbstractLi
return solver(; [key=>kargs[key] for key in filtered]..., AHA = AHA)
end

@deprecate createLinearSolver(solver, A, x; kargs...) createLinearSolver(solver, A; kargs...)

end
end
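
The `@deprecate` removed here presumably moves to the newly included `src/deprecated.jl`. A hedged migration sketch, assuming `solve!` forwards keyword arguments such as `x0` to `init!` as shown in the ADMM diff above (`A` and `b` are placeholders):

```julia
# Old, deprecated form: the extra positional argument is simply dropped.
# solver = createLinearSolver(ADMM, A, x)

# Current form: construct the solver without a start vector ...
solver = createLinearSolver(ADMM, A; reg = L1Regularization(1e-2))

# ... and pass the start vector to solve! / init! instead.
x_approx = solve!(solver, b; x0 = zeros(eltype(b), size(A, 2)))
```
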