Skip to content

Commit

Permalink
add callback to all three algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jun 30, 2022
1 parent c43c82b commit 1a9a33a
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SlimOptim"
uuid = "e4c7bc62-5b23-4522-a1b9-71c2be45f1df"
authors = ["Mathias Louboutin <[email protected]>"]
version = "0.1.8"
version = "0.1.9"

This comment has been minimized.

Copy link
@mloubout

mloubout Jun 30, 2022

Author Member

[deps]
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Expand Down
49 changes: 49 additions & 0 deletions examples/callback.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Author: Mathias Louboutin, [email protected]
# Date: December 2020

using LinearAlgebra, SlimOptim

N = 10

A = Diagonal(1:N)

function proj(x)
xp = deepcopy(x)
xp[xp .< 0] .= 0
return xp
end

x0 = 10 .+ 10 .* rand(N)
b = A*x0

function obj(x)
fun = .5*norm(A*x - b)^2
grad = A'*(A*x - b)
return fun, grad
end

function mycallback(sol::result)
# Print some info. ϕ_trace contains initial value so iteration is lenght-1
println("Bonjour at iteration $(length(sol.ϕ_trace)-1) with misfit value of $(sol.ϕ)")
println("Norm of solution is $(norm(sol.x))")
nothing
end

function mycallback(sol::BregmanIterations)
# Print some info. ϕ_trace contains initial value so iteration is lenght-1
println("Bonjour at iteration $(length(sol.ϕ_trace)-1) with misfit value of $(sol.ϕ)")
println("Norm of solution are $(norm(sol.x)), $(norm(sol.z))")
nothing
end

# PQN
sol = pqn(obj, randn(N), proj)
sol = pqn(obj, randn(N), proj, callback=mycallback)

# SPG
sol = spg(obj, randn(N), proj)
sol = spg(obj, randn(N), proj, callback=mycallback)

# Bregman
sol = bregman(A, zeros(Float32, N), A*randn(Float32, N))
sol = bregman(A, zeros(Float32, N), A*randn(Float32, N); callback=mycallback)
30 changes: 23 additions & 7 deletions src/PQNSlim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function pqn_options(;verbose=1, optTol=1f-5, progTol=1f-7,
end

"""
pqn(objective, x, projection, options)
pqn(objective, x, projection, options; ls=nothing, callback=nothing)
Function for using a limited-memory projected quasi-Newton to solve problems of the form
min objective(x) s.t. x in C
Expand All @@ -73,10 +73,14 @@ gradient algorithm
- `x`: Initial guess
- `options`: pqn_options structure
# Optional Arguments
- `ls` `: User provided linesearch function
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Notes:
Adapted fromt he matlab implementation of minConf_PQN
"""
function pqn(funObj, x::AbstractArray{T}, funProj::Function, options::PQN_params, ls=nothing) where {T}
function pqn(funObj, x::AbstractArray{T}, funProj::Function, options::PQN_params=pqn_options(); ls=nothing, callback=noop_callback) where {T}
# Result structure
sol = result(x)
G = similar(x)
Expand All @@ -88,11 +92,13 @@ function pqn(funObj, x::AbstractArray{T}, funProj::Function, options::PQN_params
obj(x) = objgrad!(G, x)

# Solve optimization
return _pqn(obj, grad!, objgrad!, projection, x, G, sol, ls, options)
return _pqn(obj, grad!, objgrad!, projection, x, G, sol, ls, options; callback=callback)
end

pqn(funObj, x, funProj, options, ls) = pqn(funObj, x, funProj, options;ls=ls)

"""
pqn(f, g!, fg!, x, projection,options)
pqn(f, g!, fg!, x, projection, options; ls=nothing, callback=nothing)
Function for using a limited-memory projected quasi-Newton to solve problems of the form
min objective(x) s.t. x in C
Expand All @@ -109,11 +115,16 @@ gradient algorithm.
- `x`: Initial guess
- `options`: pqn_options structure
# Optional Arguments
- `ls` `: User provided linesearch function
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Notes:
Adapted fromt he matlab implementation of minConf_PQN
"""
function pqn(f::Function, g!::Function, fg!::Function, x::AbstractArray{T},
funProj::Function, options::PQN_params, ls=nothing) where {T}
funProj::Function, options::PQN_params=pqn_options();
ls=nothing, callback=noop_callback) where {T}
# Result structure
sol = result(x)
G = similar(x)
Expand All @@ -125,14 +136,17 @@ function pqn(f::Function, g!::Function, fg!::Function, x::AbstractArray{T},
objgrad!(g, x) = (sol.n_ϕeval +=1;sol.n_geval +=1 ; return fg!(g, x))

# Solve optimization
return _pqn(obj, grad!, objgrad!, projection, x, G, sol, ls, options)
return _pqn(obj, grad!, objgrad!, projection, x, G, sol, ls, options; callback=callback)
end

pqn(f, g, fg!, x, funProj, options, ls) = pqn(f, g, fg!, x, funProj, options; ls=ls)

"""
Low level PQN solver
"""
function _pqn(obj::Function, grad!::Function, objgrad!::Function, projection::Function,
x::AbstractArray{T}, g::AbstractArray{T}, sol::result, ls, options::PQN_params) where {T}
x::AbstractArray{T}, g::AbstractArray{T}, sol::result, ls, options::PQN_params;
callback=noop_callback) where {T}
nVars = length(x)
old_ϕvals = -T(Inf)*ones(T, options.memory)
spg_opt = spg_options(optTol=options.SPGoptTol,progTol=options.SPGprogTol, maxIter=options.SPGiters,
Expand Down Expand Up @@ -254,6 +268,8 @@ function _pqn(obj::Function, grad!::Function, objgrad!::Function, projection::Fu
@printf("%10d %10d %10d %10d %15.5e %15.5e %15.5e\n",i,sol.n_ϕeval, sol.n_geval, sol.n_project, t, ϕ, optCond)
end

# Optional callback
callback(sol)
end
isLegal(x) && update!(sol; iter=options.maxIter+1, ϕ=ϕ, x=x, g=g, store_trace=options.store_trace)
return return sol
Expand Down
31 changes: 24 additions & 7 deletions src/SPGSlim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end


"""
spg(funObj, x, funProj, options)
spg(funObj, x, funProj, options; ls=nothing, callback=nothing)
Function for using Spectral Projected Gradient to solve problems of the form
min funObj(x) s.t. x in C
Expand All @@ -75,6 +75,10 @@ Function for using Spectral Projected Gradient to solve problems of the form
- `x`: Initial guess
- `options`: spg_options structure
# Optional Arguments
- `ls` `: User provided linesearch function
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Notes:
- if the projection is expensive to compute, you can reduce the
Expand All @@ -83,7 +87,7 @@ Function for using Spectral Projected Gradient to solve problems of the form
- Adapted fromt he matlab implementation of minConf_SPG
"""
function spg(funObj::Function, x::AbstractArray{T}, funProj::Function,
options::SPG_params, ls=nothing) where {T}
options::SPG_params=spg_options(); ls=nothing, callback=noop_callback) where {T}
# Result structure
sol = result(x)
# Initialize array for gradient
Expand All @@ -96,11 +100,13 @@ function spg(funObj::Function, x::AbstractArray{T}, funProj::Function,
obj(x) = objgrad!(G, x)

# Solve optimization
return _spg(obj, grad!, objgrad!, projection, x, G, sol, ls, options)
return _spg(obj, grad!, objgrad!, projection, x, G, sol, ls, options; callback=callback)
end

spg(funObj, x, funProj, options, ls) = spg(funObj, x, funProj, options;ls=ls)

"""
spg(f, g!, fg!, x, funProj, options)
spg(f, g!, fg!, x, funProj, options; ls=nothing, callback=nothing)
Function for using Spectral Projected Gradient to solve problems of the form
min funObj(x) s.t. x in C
Expand All @@ -113,14 +119,19 @@ min funObj(x) s.t. x in C
- `x`: Initial guess
- `options`: spg_options structure
# Optional Arguments
- `ls` `: User provided linesearch function
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Notes:
- if the projection is expensive to compute, you can reduce the
number of projections by setting testOpt to 0 in the options
- Adapted fromt he matlab implementation of minConf_SPG
"""
function spg(f::Function, g!::Function, fg!::Function, x::AbstractArray{T},
funProj::Function, options::SPG_params, ls=nothing) where {T}
funProj::Function, options::SPG_params=spg_options();
ls=nothing, callback=noop_callback) where {T}
# Result structure
sol = result(x)
# Initialize array for gradient
Expand All @@ -133,14 +144,17 @@ function spg(f::Function, g!::Function, fg!::Function, x::AbstractArray{T},
objgrad!(g, x) = (sol.n_ϕeval +=1;sol.n_geval +=1 ; return fg!(g, x))

# Solve optimization
return _spg(obj, grad!, objgrad!, projection, x, G, sol, ls, options)
return _spg(obj, grad!, objgrad!, projection, x, G, sol, ls, options; callback=callback)
end

spg(f, g!, fg!, x, funProj, options, ls) = spg(f, g!, fg!, x, funProj, options; ls=ls)

"""
Low level SPG solver
"""
function _spg(obj::Function, grad!::Function, objgrad!::Function, projection::Function,
x::AbstractArray{T}, g::AbstractArray{T}, sol::result, ls, options::SPG_params) where {T}
x::AbstractArray{T}, g::AbstractArray{T}, sol::result, ls, options::SPG_params;
callback=noop_callback) where {T}
# Initialize local variables
nVars = length(x)
old_ϕvals = -T(Inf)*ones(T, options.memory)
Expand Down Expand Up @@ -216,6 +230,9 @@ function _spg(obj::Function, grad!::Function, objgrad!::Function, projection::Fu

# Output Log
iter_log(i, sol, t, alpha, ϕ, optCond, options)

# Potential callback
callback(sol)
end

# Restore best iteration
Expand Down
2 changes: 1 addition & 1 deletion src/SlimOptim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe

export pqn, pqn_options
export spg, spg_options
export bregman, bregman_options
export bregman, bregman_options, BregmanIterations
#############################################################################
# Optimization algorithms
include("utils.jl") # common functions
Expand Down
24 changes: 18 additions & 6 deletions src/bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,28 @@ For example, for sparsity promoting denoising (i.e LSRTM)
- `x`: Initial guess
- `b`: observed data
# Optional Arguments
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
"""
function bregman(A, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number}
function bregman(A, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options(); callback=noop_callback) where {T1<:Number, T2<:Number}
# residual function wrapper
function obj(x)
d = A*x
fun = .5*norm(d - b)^2
grad = A'*(d - b)
return fun, grad
end
return bregman(obj, x, options)
return bregman(obj, x, options; callback=callback)
end

function bregman(A, TD, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number}
function bregman(A, TD, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options(); callback=noop_callback) where {T1<:Number, T2<:Number}
@warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams"
options.TD = TD
return bregman(A, x, b, options)
return bregman(A, x, b, options; callback=callback)
end

"""
Expand All @@ -96,8 +99,12 @@ Linearized bregman iteration for the system
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
# Optional Arguments
- `callback` : Callback function. Must take as input a `result` callback(x::result)
"""
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=bregman_options()) where {T}
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=bregman_options(); callback=noop_callback) where {T}
# Output Parameter Settings
if options.verbose > 0
@printf("Running linearized bregman...\n");
Expand Down Expand Up @@ -154,6 +161,9 @@ function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=
(options.verbose > 0) && (@printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, maximum(sol.λ)))
norm(x - sol.x) < options.progTol && (@printf("Step size below progTol\n"); break;)
update!(sol; iter=i, ϕ=obj_fun, residual=f, x=x, z=z, g=g, store_trace=options.store_trace)

# Optional callback
callback(sol)
end
return sol
end
Expand Down Expand Up @@ -194,4 +204,6 @@ end

function breglog(init_x, init_z; lambda0=0, f0=0, obj0=0)
return BregmanIterations(1*init_x, 1*init_z, 0*init_z, f0, lambda0, obj0, Vector{}(), Vector{}(), Vector{}(), Vector{}())
end
end

noop_callback(::BregmanIterations) = nothing
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ function result(init_x::AbstractArray{T}; ϕ0=0, ϕeval=0, δϕeval=0) where T
return result(deepcopy(init_x), T(0)*init_x, T(ϕ0), Vector{T}(), Vector{AbstractArray{T}}(), 0, ϕeval, δϕeval)
end

noop_callback(::result) = nothing

function isLegal(v::AbstractArray{T}) where T
nv = norm(v)
return !isnan(nv) && !isinf(nv)
Expand Down

1 comment on commit 1a9a33a

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/63419

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.9 -m "<description of version>" 1a9a33a7470e54012a023f69d704a6ecc7d2a0af
git push origin v0.1.9

Please sign in to comment.