Skip to content

Commit

Permalink
Fix verbose printing in ADMM, FISTA, OptISTA, and POGM solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Aug 23, 2024
1 parent 4f551dc commit 59a3049
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 10 deletions.
10 changes: 5 additions & 5 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ performs one ADMM iteration.
"""
function iterate(solver::ADMM, state::ADMMState)
done(solver, state) && return nothing
solver.verbose && println("Outer ADMM Iteration #$iteration")
solver.verbose && println("Outer ADMM Iteration #$(state.iteration)")

# 1. solve arg min_x 1/2|| Ax-b ||² + ρ/2 Σ_i||Φi*x+ui-zi||²
# <=> (A'A+ρ Σ_i Φi'Φi)*x = A'b+ρΣ_i Φi'(zi-ui)
Expand Down Expand Up @@ -264,10 +264,10 @@ function iterate(solver::ADMM, state::ADMMState)
end

if solver.verbose
println("rᵏ[$i]/ɛᵖʳⁱ[$i] = $(solver.rᵏ[i]/solver.ɛᵖʳⁱ[i])")
println("sᵏ[$i]/ɛᵈᵘᵃ[$i] = $(solver.sᵏ[i]/solver.ɛᵈᵘᵃ[i])")
println("Δ[$i]/Δᵒˡᵈ[$i] = $(solver.Δ[i]/Δᵒˡᵈ)")
println("new ρ[$i] = $(solver.ρ[i])")
println("rᵏ[$i]/ɛᵖʳⁱ[$i] = $(state.rᵏ[i]/state.ɛᵖʳⁱ[i])")
println("sᵏ[$i]/ɛᵈᵘᵃ[$i] = $(state.sᵏ[i]/state.ɛᵈᵘᵃ[i])")
println("Δ[$i]/Δᵒˡᵈ[$i] = $(state.Δ[i]/Δᵒˡᵈ)")
println("new ρ[$i] = $(state.ρ[i])")
flush(stdout)
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/FISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function iterate(solver::FISTA, state::FISTAState)
state.x .-= state.ρ .* state.res

state.rel_res_norm = norm(state.res) / state.norm_x₀
solver.verbose && println("Iteration $iteration; rel. residual = $(state.rel_res_norm)")
solver.verbose && println("Iteration $(state.iteration); rel. residual = $(state.rel_res_norm)")

# the two lines below are equivalent to the ones above and non-allocating, but require the 5-argument mul! function to implemented for AHA, i.e. if AHA is LinearOperator, it requires LinearOperators.jl v2
# mul!(solver.x, solver.AHA, solver.xᵒˡᵈ, -solver.ρ, 1)
Expand All @@ -170,7 +170,7 @@ function iterate(solver::FISTA, state::FISTAState)
# gradient restart conditions
if solver.restart == :gradient
if real(state.res (state.x .- state.xᵒˡᵈ) ) > 0 #if momentum is at an obtuse angle to the negative gradient
solver.verbose && println("Gradient restart at iter $iteration")
solver.verbose && println("Gradient restart at iter $(state.iteration)")
state.theta = 1
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/OptISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ function iterate(solver::OptISTA, state::OptISTAState)
state.y .-= state.ρ * state.γ .* state.res

state.rel_res_norm = norm(state.res) / state.norm_x₀
solver.verbose && println("Iteration $iteration; rel. residual = $(state.rel_res_norm)")
solver.verbose && println("Iteration $(state.iteration); rel. residual = $(state.rel_res_norm)")

# proximal map
prox!(solver.reg, state.y, state.ρ * state.γ * λ(solver.reg))
Expand Down
4 changes: 2 additions & 2 deletions src/POGM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function iterate(solver::POGM, state::POGMState)
state.x .-= state.ρ .* state.res

state.rel_res_norm = norm(state.res) / state.norm_x₀
solver.verbose && println("Iteration $iteration; rel. residual = $(state.rel_res_norm)")
solver.verbose && println("Iteration $(state.iteration); rel. residual = $(state.rel_res_norm)")

# inertial parameters
state.thetaᵒˡᵈ = state.theta
Expand Down Expand Up @@ -222,7 +222,7 @@ function iterate(solver::POGM, state::POGMState)
if solver.restart == :gradient
state.w .+= state.y .+ state.ρ ./ state.γ .* (state.x .- state.z)
if real((state.w state.x - state.w state.z) / state.γ - state.w state.res) < 0
solver.verbose && println("Gradient restart at iter $iteration")
solver.verbose && println("Gradient restart at iter $(state.iteration)")
state.σ = 1
state.theta = 1
else # decreasing γ
Expand Down
20 changes: 20 additions & 0 deletions test/testSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,25 @@ function testConvexLinearSolver(; arrayType = Array, elType = Float32)
=#
end

function testVerboseSolvers(; arrayType = Array, elType = Float32)
A = rand(elType, 3, 2)
x = rand(elType, 2)
b = A * x

solvers = [ADMM, FISTA, POGM, OptISTA, SplitBregman]

for solver in solvers
@test try
S = createLinearSolver(solver, arrayType(A), iterations = 3, verbose = true)
solve!(S, arrayType(b))
true
catch e
@error e
false
end
end
end


@testset "Test Solvers" begin
for arrayType in arrayTypes
Expand All @@ -240,4 +259,5 @@ end
end
end
end
testVerboseSolvers()
end

0 comments on commit 59a3049

Please sign in to comment.