Skip to content

Commit

Permalink
add a user define option for threshold in bregman (#3)
Browse files Browse the repository at this point in the history
* add a user define option for threshold in bregman

* have a custom thresholding function for first iter

* update options for lambda function

* sorry about slightly changed api

* dispatch, make kwargs into options

* doesnt change API

* pre-process at bregmanparams

* clean up tests and documentation

* TD is at the end for funobj bregman

* do defaults

* fixed all

* don't need to be in same type
  • Loading branch information
ziyiyin97 authored May 6, 2022
1 parent 3f2822e commit c43c82b
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 101 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "SlimOptim"
uuid = "e4c7bc62-5b23-4522-a1b9-71c2be45f1df"
authors = ["Mathias Louboutin <[email protected]>"]
version = "0.1.7"
version = "0.1.8"

This comment has been minimized.

Copy link
@ziyiyin97

ziyiyin97 May 8, 2022

Author Member

[deps]
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1"
LineSearches = "7.1.1"
julia = "1"
10 changes: 5 additions & 5 deletions examples/denoising.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ n = 256
k = 4
# Sparse in wavelet domain
W = joDWT(n, n; DDT=Float32, RDT=Float32)
# Or with curvelet ifi nstalled
# Or with curvelet if installed
# W = joCurvelet2D(128, 128; DDT=Float32, RDT=Float32)
A = vcat([joRomberg(n, n; DDT=Float32, RDT=Float32) for i=1:k]...)

Expand All @@ -20,11 +20,11 @@ imgn= img .+ .01f0*randn(Float32, size(img))
b = A*vec(imgn)

# setup bregman
opt = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true)
opt2 = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, spg=true)
opt = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, TD=W)
opt2 = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, spg=true, TD=W)

sol = bregman(A, W, zeros(Float32, n*n), b, opt)
sol2 = bregman(A, W, zeros(Float32, n*n), b, opt2)
sol = bregman(A, zeros(Float32, n*n), b, opt)
sol2 = bregman(A, zeros(Float32, n*n), b, opt2)

figure()
subplot(121)
Expand Down
4 changes: 2 additions & 2 deletions examples/lsrtm_marmousi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ function breg_obj(x)
return .5f0*norm(r)^2, g[1:end]
end

opt = bregman_options(maxIter=5, verbose=2, quantile=.9, alpha=1, antichatter=true)#, spg=true)
sol = bregman(breg_obj, 0f0.*vec(m0), opt, C)
opt = bregman_options(maxIter=5, verbose=2, quantile=.9, alpha=1, antichatter=true, TD=C)#, spg=true)
sol = bregman(breg_obj, 0f0.*vec(m0), opt)
2 changes: 1 addition & 1 deletion src/SlimOptim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

module SlimOptim

using Printf, LinearAlgebra, LineSearches
using Printf, LinearAlgebra, LineSearches, Statistics

import LineSearches: BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
Expand Down
123 changes: 63 additions & 60 deletions src/bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ mutable struct BregmanParams
maxIter
store_trace
antichatter
quantile
alpha
spg
TD
λfunc
end

"""
bregman_options(;verbose=1, optTol=1e-6, progTol=1e-8, maxIter=20
store_trace=false, linesearch=false, alpha=.25, spg=false)
bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true,
alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing)
Options structure for the bregman iteration algorithm
Expand All @@ -25,69 +26,87 @@ Options structure for the bregman iteration algorithm
- `maxIter`: maximum number of iterations (default: 20)
- `store_trace`: Whether to store the trace/history of x (default: false)
- `antichatter`: Whether to use anti-chatter step length correction
- `quantile`: Thresholding level as quantile value, (default=.95 i.e thresholds 95% of the vector)
- `alpha`: Strong convexity modulus. (step length is ``α \\frac{||r||_2^2}{||g||_2^2}``)
- `spg`: whether to use spg, default is false
- `TD`: sparsifying transform (e.g. curvelet), default is identity (LinearAlgebra.I)
- `λfunc`: a function to calculate threshold value, default is nothing
- `λ`: a pre-set threshold, will only be used if `λfunc` is not defined, default is nothing
- `quantile`: a percentage to calculate the threshold by quantile of the dual variable in 1st iteration, will only be used if neither `λfunc` nor `λ` are defined, default is .95 i.e thresholds 95% of the vector
"""
bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, quantile=.95, alpha=.5, spg=false) =
BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, quantile, alpha, spg)
function bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing)
    # Pick the thresholding rule once, in order of precedence:
    #   1. a user-supplied `λfunc` is used untouched,
    #   2. otherwise a fixed `λ` is wrapped in a function that ignores its argument,
    #   3. otherwise the threshold is the `quantile`-level quantile of `abs.(z)`.
    threshold = if !isnothing(λfunc)
        λfunc
    elseif !isnothing(λ)
        _ -> λ
    else
        # `Statistics.quantile` is qualified because the keyword `quantile` shadows it here
        z -> Statistics.quantile(abs.(z), quantile)
    end
    # Only the resolved rule is stored; `λ` and `quantile` are not kept as separate fields
    return BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, alpha, spg, TD, threshold)
end

"""
bregman(A, TD, x, b, options)
bregman(A, x, b, options)
Linearized bregman iteration for the system
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b``
For example, for sparsity promoting denoising (i.e LSRTM)
# Arguments
# Required arguments
- `TD`: curvelet transform
- `A`: Forward operator (J or preconditioned J for LSRTM)
- `b`: observed data
- `A`: Forward operator (e.g. J or preconditioned J for LSRTM)
- `x`: Initial guess
- `b`: observed data
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
"""
function bregman(A, TD, x::Array{T}, b, options) where {T}
function bregman(A, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number}
# residual function wrapper
function obj(x)
d = A*x
fun = .5*norm(d - b)^2
grad = A'*(d - b)
return fun, grad
end

return bregman(obj, x, options, TD)
return bregman(obj, x, options)
end

# Legacy entry point that still takes the sparsifying transform `TD` as a positional
# argument. It emits a warning, writes `TD` into `options` (the caller's object is
# mutated) and forwards to the `bregman(A, x, b, options)` method.
function bregman(A, TD, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number}
    @warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams"
    # The positional transform takes precedence over whatever `options.TD` held before
    setproperty!(options, :TD, TD)
    return bregman(A, x, b, options)
end

"""
bregman(fun, TD, x, b, options)
bregman(funobj, x, options)
Linearized bregman iteration for the system
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b``
For example, for sparsity promoting denoising (i.e LSRTM)
# Required arguments
# Arguments
- `TD`: curvelet transform
- `fun`: residual function, return the tuple (``f = \\frac{1}{2}||Ax - b||_2``, ``g = A^T(Ax - b)``)
- `b`: observed data
- `funobj`: a function that calculates the objective value (`0.5 * norm(Ax-b)^2`) and the gradient (`A'(Ax-b)`)
- `x`: Initial guess
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
"""
function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, TD=nothing) where {T}
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=bregman_options()) where {T}
# Output Parameter Settings
if options.verbose > 0
@printf("Running linearized bregman...\n");
@printf("Progress tolerance: %.2e\n",options.progTol)
@printf("Maximum number of iterations: %d\n",options.maxIter)
@printf("Anti-chatter correction: %d\n",options.antichatter)
end
isnothing(TD) && (TD = LinearAlgebra.I)
# Intitalize variables
z = TD*x
# Initialize variables
z = options.TD*x
d = similar(z)
options.spg && (gold = similar(x); xold=similar(x))
if options.antichatter
Expand All @@ -96,8 +115,6 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams,

# Result structure
sol = breglog(x, z)
# Initialize λ
λ = abs(T(0))

# Output Log
if options.verbose > 0
Expand All @@ -108,60 +125,45 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams,
for i=1:options.maxIter
f, g = funobj(x)
# Preconditioned update direction
d .= -TD*g
d .= -options.TD*g
# Step length
t = (options.spg && i> 1) ? T(dot(x-xold, x-xold)/dot(x-xold, g-gold)) : T(options.alpha*f/norm(d)^2)
t = abs(t)
mul!(d, d, t)

# Anti-chatter
if options.antichatter
@assert isreal(z) "we currently do not support anti-chatter for complex numbers"
@. tk = tk - sign(d)
# Chatter correction
inds_z = findall(abs.(z) .> λ)
@views d[inds_z] .*= abs.(tk[inds_z])/i
# Chatter correction after 1st iteration
if i > 1
inds_z = findall(abs.(z) .> sol.λ)
@views d[inds_z] .*= abs.(tk[inds_z])/i
end
end
# Update z variable
@. z = z + d
# Get λ at first iteration
i == 1 && (λ = abs(T(quantile(abs.(z), options.quantile))))
(i == 1) && (sol.λ = abs.(T.(options.λfunc(z))))
# Save current state
options.spg && (gold .= g; xold .= x)
# Update x
x = TD'*soft_thresholding(z, λ)
x = options.TD'*soft_thresholding(z, sol.λ)

obj_fun = λ * norm(z, 1) + .5 * norm(z, 2)^2
if options.verbose > 0
@printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, λ)
end
obj_fun = norm(sol.λ .* z, 1) + .5 * norm(z, 2)^2
(options.verbose > 0) && (@printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, maximum(sol.λ)))
norm(x - sol.x) < options.progTol && (@printf("Step size below progTol\n"); break;)
update!(sol; iter=i, ϕ=obj_fun, residual=f, x=x, z=z, g=g, store_trace=options.store_trace)
end
return sol
end

# Utility functions
"""
Simplified Quantile from Statistics.jl since we only need simplified version of it.
"""
function quantile(u::AbstractVector, p::Real)
0 <= p <= 1 || throw(ArgumentError("input probability out of [0,1] range"))
n = length(u)
v = sort(u; alg=Base.QuickSort)

m = 1 - p
aleph = n*p + oftype(p, m)
j = clamp(trunc(Int, aleph), 1, n-1)
γ = clamp(aleph - j, 0, 1)

n == 1 ? a = v[1] : a = v[j]
n == 1 ? b = v[1] : b = v[j+1]

(isfinite(a) && isfinite(b)) ? q = a + γ*(b-a) : q = (1-γ)*a + γ*b
return q
# Legacy entry point for the objective-function form that still takes the sparsifying
# transform `TD` as a trailing positional argument. It emits a warning, writes `TD`
# into `options` (the caller's object is mutated) and forwards to
# `bregman(funobj, x, options)`.
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams, TD) where {T}
    @warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams"
    # The positional transform takes precedence over whatever `options.TD` held before
    setproperty!(options, :TD, TD)
    return bregman(funobj, x, options)
end


"""
Bregman result structure
"""
Expand All @@ -170,6 +172,7 @@ mutable struct BregmanIterations
z
g
ϕ
λ
residual
ϕ_trace
r_trace
Expand All @@ -189,6 +192,6 @@ function update!(r::BregmanIterations; x=nothing, z=nothing, ϕ=nothing, residua
(~isnothing(residual) && length(r.r_trace) == iter-1) && (push!(r.r_trace, residual))
end

function breglog(init_x, init_z; f0=0, obj0=0)
return BregmanIterations(1*init_x, 1*init_z, 0*init_z, f0, obj0, Vector{}(), Vector{}(), Vector{}(), Vector{}())
end
# Build the initial `BregmanIterations` result/log object.
#
# Arguments
# - `init_x`: starting primal variable; stored as `1*init_x`
# - `init_z`: starting dual variable; stored as `1*init_z`, and `0*init_z` seeds the `g` slot
# Keywords
# - `lambda0`: initial threshold stored in `λ` (default 0)
# - `f0`: initial data misfit stored in `residual` (default 0)
# - `obj0`: initial objective value stored in `ϕ` (default 0)
function breglog(init_x, init_z; lambda0=0, f0=0, obj0=0)
    # Positional order follows the struct fields: x, z, g, ϕ, λ, residual, then four empty
    # trace vectors. The solver loop stores the objective in `ϕ` and the misfit `f` in
    # `residual`, so `obj0` must seed `ϕ` and `f0` must seed `residual` (they were swapped).
    return BregmanIterations(1*init_x, 1*init_z, 0*init_z, obj0, lambda0, f0, Vector{}(), Vector{}(), Vector{}(), Vector{}())
end
66 changes: 35 additions & 31 deletions test/test_bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,38 @@ using LinearAlgebra

N1 = 100
N2 = div(N1, 2) + 5
A = randn(N1, N2)

x0 = 10 .* randn(N2)
x0[abs.(x0) .< 1f-6] .= 1.0
inds = rand(1:N2, div(N2, 4))
ninds = [i for i=1:N2 if i ∉ inds]
x0[inds] .= 0
b = A*x0

function obj(x)
fun = .5*norm(A*x - b)^2
grad = A'*(A*x - b)
return fun, grad
end

opt = bregman_options(maxIter=200, progTol=0, verbose=2)
sol = bregman(obj, 1 .+ randn(N2), opt)

@show sol.x[inds]
@show x0[inds]
@show sol.x[ninds]
@show x0[ninds]

part_n = i -> norm(sol.x[i] - x0[i])/(norm(x0[i]) + norm(sol.x[i]) + eps(Float64))
part_nz = i -> norm(sol.x[i], 1)/N2
@show part_nz(inds)
@show part_n(ninds)

@test part_nz(inds) < 1f-1
@test part_n(ninds) < 1f-1
@test sol.residual/sol.r_trace[1] < 1f-1

# Recover a sparse vector from `b = A*x0` with linearized Bregman, once for real and once
# for complex single-precision data. `N1` and `N2` are defined above this block.
@testset "Bregman test for type $(T)" for T = [Float32, ComplexF32]

    # Random N1 × N2 system with a sparse ground truth x0
    A = randn(T, N1, N2)
    x0 = 10 .* randn(T, N2)
    # Push any near-zero entry away from zero so entries outside `inds` are clearly nonzero
    x0[abs.(x0) .< 1f-6] .= 1.0
    # `inds` are the entries forced to zero (may contain repeats); `ninds` is the complement
    inds = rand(1:N2, div(N2, 4))
    ninds = [i for i=1:N2 if i ∉ inds]
    x0[inds] .= 0
    b = A*x0

    # Least-squares objective and gradient in the form `bregman` expects
    function obj(x)
        fun = .5*norm(A*x - b)^2
        grad = A'*(A*x - b)
        return fun, grad
    end

    # Anti-chatter is only enabled for the real case: the solver asserts a real `z` for it
    opt = bregman_options(maxIter=200, progTol=0, verbose=2, antichatter=T==Float32)
    sol = bregman(obj, 1 .+ randn(T, N2), opt)

    @show sol.x[inds]
    @show x0[inds]
    @show sol.x[ninds]
    @show x0[ninds]

    # Relative error on a set of entries, and mean magnitude (normalised by N2) on a set
    part_n = i -> norm(sol.x[i] - x0[i])/(norm(x0[i]) + norm(sol.x[i]) + eps(Float32))
    part_nz = i -> norm(sol.x[i], 1)/N2
    @show part_nz(inds)
    @show part_n(ninds)

    # Zeroed entries stay small, nonzero entries are recovered, residual drops by 10x
    @test part_nz(inds) < 1f-1
    @test part_n(ninds) < 1f-1
    @test sol.residual/sol.r_trace[1] < 1f-1

end

1 comment on commit c43c82b

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/59916

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.8 -m "<description of version>" c43c82b4758f8ebe95d65086513a5db9062ae74e
git push origin v0.1.8

Please sign in to comment.