From c43c82b4758f8ebe95d65086513a5db9062ae74e Mon Sep 17 00:00:00 2001 From: "Ziyi (Francis) Yin" <54320031+ziyiyin97@users.noreply.github.com> Date: Fri, 6 May 2022 18:41:29 -0400 Subject: [PATCH] add a user define option for threshold in bregman (#3) * add a user define option for threshold in bregman * have a custom thresholding function for first iter * update options for lambda function * sorry about slightly changed api * dispatch, make kwargs into options * doesnt change API * pre-process at bregmanparams * clean up tests and documentation * TD is at the end for funobj bregman * do defaults * fixed all * don't need to be in same type --- Project.toml | 5 +- examples/denoising.jl | 10 +-- examples/lsrtm_marmousi.jl | 4 +- src/SlimOptim.jl | 2 +- src/bregman.jl | 123 +++++++++++++++++++------------------ test/test_bregman.jl | 66 ++++++++++---------- 6 files changed, 109 insertions(+), 101 deletions(-) diff --git a/Project.toml b/Project.toml index 526f76b..02f35a4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,15 @@ name = "SlimOptim" uuid = "e4c7bc62-5b23-4522-a1b9-71c2be45f1df" authors = ["Mathias Louboutin "] -version = "0.1.7" +version = "0.1.8" [deps] LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -julia = "1" LineSearches = "7.1.1" +julia = "1" diff --git a/examples/denoising.jl b/examples/denoising.jl index d1ee7bf..b4bd268 100644 --- a/examples/denoising.jl +++ b/examples/denoising.jl @@ -11,7 +11,7 @@ n = 256 k = 4 # Sparse in wavelet domain W = joDWT(n, n; DDT=Float32, RDT=Float32) -# Or with curvelet ifi nstalled +# Or with curvelet if installed # W = joCurvelet2D(128, 128; DDT=Float32, RDT=Float32) A = vcat([joRomberg(n, n; DDT=Float32, RDT=Float32) for i=1:k]...) @@ -20,11 +20,11 @@ imgn= img .+ .01f0*randn(Float32, size(img)) b = A*vec(imgn) # setup bregamn -opt = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true) -opt2 = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, spg=true) +opt = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, TD=W) +opt2 = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, spg=true, TD=W) -sol = bregman(A, W, zeros(Float32, n*n), b, opt) -sol2 = bregman(A, W, zeros(Float32, n*n), b, opt2) +sol = bregman(A, zeros(Float32, n*n), b, opt) +sol2 = bregman(A, zeros(Float32, n*n), b, opt2) figure() subplot(121) diff --git a/examples/lsrtm_marmousi.jl b/examples/lsrtm_marmousi.jl index 8fd68e9..59d7883 100644 --- a/examples/lsrtm_marmousi.jl +++ b/examples/lsrtm_marmousi.jl @@ -52,5 +52,5 @@ function breg_obj(x) return .5f0*norm(r)^2, g[1:end] end -opt = bregman_options(maxIter=5, verbose=2, quantile=.9, alpha=1, antichatter=true)#, spg=true) -sol = bregman(breg_obj, 0f0.*vec(m0), opt, C) \ No newline at end of file +opt = bregman_options(maxIter=5, verbose=2, quantile=.9, alpha=1, antichatter=true, TD=C)#, spg=true) +sol = bregman(breg_obj, 0f0.*vec(m0), opt) \ No newline at end of file diff --git a/src/SlimOptim.jl b/src/SlimOptim.jl index 9d35f4c..feb2d02 100644 --- a/src/SlimOptim.jl +++ b/src/SlimOptim.jl @@ -3,7 +3,7 @@ module SlimOptim -using Printf, LinearAlgebra, LineSearches +using Printf, LinearAlgebra, LineSearches, Statistics import LineSearches: BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe diff --git a/src/bregman.jl b/src/bregman.jl index 399f7ea..87c8aea 100644 --- a/src/bregman.jl +++ b/src/bregman.jl @@ -7,14 +7,15 @@ mutable struct BregmanParams maxIter store_trace antichatter - quantile alpha spg + TD + λfunc end """ - bregman_options(;verbose=1, optTol=1e-6, progTol=1e-8, maxIter=20 - store_trace=false, linesearch=false, alpha=.25, spg=false) + bregman_options(;verbose=1, optTol=1e-6, progTol=1e-8, maxIter=20, + store_trace=false, quantile=.5, alpha=.25, spg=false) Options structure for the bregman iteration algorithm @@ -25,15 +26,27 @@ Options structure for the bregman iteration algorithm - `maxIter`: maximum number of iterations (default: 20) - `store_trace`: Whether to store the trace/history of x (default: false) - `antichatter`: Whether to use anti-chatter step length correction -- `quantile`: Thresholding level as quantile value, (default=.95 i.e thresholds 95% of the vector) - `alpha`: Strong convexity modulus. (step length is ``α \\frac{||r||_2^2}{||g||_2^2}``) +- `spg`: whether to use spg, default is false +- `TD`: sparsifying transform (e.g. curvelet), default is identity (LinearAlgebra.I) +- `λfunc`: a function to calculate threshold value, default is nothing +- `λ`: a pre-set threshold, will only be used if `λfunc` is not defined, default is nothing +- `quantile`: a percentage to calculate the threshold by quantile of the dual variable in 1st iteration, will only be used if neither `λfunc` nor `λ` are defined, default is .95 i.e thresholds 95% of the vector """ -bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, quantile=.95, alpha=.5, spg=false) = - BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, quantile, alpha, spg) +function bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing) + if isnothing(λfunc) + if ~isnothing(λ) + λfunc = z->λ + else + λfunc = z->Statistics.quantile(abs.(z), quantile) + end + end + return BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, alpha, spg, TD, λfunc) +end """ - bregman(A, TD, x, b, options) + bregman(A, x, b, options) Linearized bregman iteration for the system @@ -41,14 +54,17 @@ Linearized bregman iteration for the system For example, for sparsity promoting denoising (i.e LSRTM) -# Arguments +# Required arguments -- `TD`: curvelet transform -- `A`: Forward operator (J or preconditioned J for LSRTM) -- `b`: observed data +- `A`: Forward operator (e.g. J or preconditioned J for LSRTM) - `x`: Initial guess +- `b`: observed data + +# Non-required arguments + +- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet) """ -function bregman(A, TD, x::Array{T}, b, options) where {T} +function bregman(A, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number} # residual function wrapper function obj(x) d = A*x @@ -56,28 +72,32 @@ function bregman(A, TD, x::Array{T}, b, options) where {T} grad = A'*(d - b) return fun, grad end - - return bregman(obj, x, options, TD) + return bregman(obj, x, options) +end + +function bregman(A, TD, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number} + @warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams" + options.TD = TD + return bregman(A, x, b, options) end """ - bregman(fun, TD, x, b, options) + bregman(funobj, x, options) Linearized bregman iteration for the system ``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b`` -For example, for sparsity promoting denoising (i.e LSRTM) +# Required arguments -# Arguments - -- `TD`: curvelet transform -- `fun`: residual function, return the tuple (``f = \\frac{1}{2}||Ax - b||_2``, ``g = A^T(Ax - b)``) -- `b`: observed data +- `funobj`: a function that calculates the objective value (`0.5 * norm(Ax-b)^2`) and the gradient (`A'(Ax-b)`) - `x`: Initial guess +# Non-required arguments + +- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet) """ -function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, TD=nothing) where {T} +function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=bregman_options()) where {T} # Output Parameter Settings if options.verbose > 0 @printf("Running linearized bregman...\n"); @@ -85,9 +105,8 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, @printf("Maximum number of iterations: %d\n",options.maxIter) @printf("Anti-chatter correction: %d\n",options.antichatter) end - isnothing(TD) && (TD = LinearAlgebra.I) - # Intitalize variables - z = TD*x + # Initialize variables + z = options.TD*x d = similar(z) options.spg && (gold = similar(x); xold=similar(x)) if options.antichatter @@ -96,8 +115,6 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, # Result structure sol = breglog(x, z) - # Initialize λ - λ = abs(T(0)) # Output Log if options.verbose > 0 @@ -108,7 +125,7 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, for i=1:options.maxIter f, g = funobj(x) # Preconditionned ipdate direction - d .= -TD*g + d .= -options.TD*g # Step length t = (options.spg && i> 1) ? T(dot(x-xold, x-xold)/dot(x-xold, g-gold)) : T(options.alpha*f/norm(d)^2) t = abs(t) @@ -116,52 +133,37 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, # Anti-chatter if options.antichatter + @assert isreal(z) "we currently do not support anti-chatter for complex numbers" @. tk = tk - sign(d) - # Chatter correction - inds_z = findall(abs.(z) .> λ) - @views d[inds_z] .*= abs.(tk[inds_z])/i + # Chatter correction after 1st iteration + if i > 1 + inds_z = findall(abs.(z) .> sol.λ) + @views d[inds_z] .*= abs.(tk[inds_z])/i + end end # Update z variable @. z = z + d # Get λ at first iteration - i == 1 && (λ = abs(T(quantile(abs.(z), options.quantile)))) + (i == 1) && (sol.λ = abs.(T.(options.λfunc(z)))) # Save curent state options.spg && (gold .= g; xold .= x) # Update x - x = TD'*soft_thresholding(z, λ) + x = options.TD'*soft_thresholding(z, sol.λ) - obj_fun = λ * norm(z, 1) + .5 * norm(z, 2)^2 - if options.verbose > 0 - @printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, λ) - end + obj_fun = norm(sol.λ .* z, 1) + .5 * norm(z, 2)^2 + (options.verbose > 0) && (@printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, maximum(sol.λ))) norm(x - sol.x) < options.progTol && (@printf("Step size below progTol\n"); break;) update!(sol; iter=i, ϕ=obj_fun, residual=f, x=x, z=z, g=g, store_trace=options.store_trace) end return sol end -# Utility functions -""" -Simplified Quantile from Statistics.jl since we only need simplified version of it. -""" -function quantile(u::AbstractVector, p::Real) - 0 <= p <= 1 || throw(ArgumentError("input probability out of [0,1] range")) - n = length(u) - v = sort(u; alg=Base.QuickSort) - - m = 1 - p - aleph = n*p + oftype(p, m) - j = clamp(trunc(Int, aleph), 1, n-1) - γ = clamp(aleph - j, 0, 1) - - n == 1 ? a = v[1] : a = v[j] - n == 1 ? b = v[1] : b = v[j+1] - - (isfinite(a) && isfinite(b)) ? q = a + γ*(b-a) : q = (1-γ)*a + γ*b - return q +function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams, TD) where {T} + @warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams" + options.TD = TD + return bregman(funobj, x, options) end - """ Bregman result structure """ @@ -170,6 +172,7 @@ mutable struct BregmanIterations z g ϕ + λ residual ϕ_trace r_trace @@ -189,6 +192,6 @@ function update!(r::BregmanIterations; x=nothing, z=nothing, ϕ=nothing, residua (~isnothing(residual) && length(r.r_trace) == iter-1) && (push!(r.r_trace, residual)) end -function breglog(init_x, init_z; f0=0, obj0=0) - return BregmanIterations(1*init_x, 1*init_z, 0*init_z, f0, obj0, Vector{}(), Vector{}(), Vector{}(), Vector{}()) -end +function breglog(init_x, init_z; lambda0=0, f0=0, obj0=0) + return BregmanIterations(1*init_x, 1*init_z, 0*init_z, f0, lambda0, obj0, Vector{}(), Vector{}(), Vector{}(), Vector{}()) +end \ No newline at end of file diff --git a/test/test_bregman.jl b/test/test_bregman.jl index d0d5f9c..91abdf1 100644 --- a/test/test_bregman.jl +++ b/test/test_bregman.jl @@ -5,34 +5,38 @@ using LinearAlgebra N1 = 100 N2 = div(N1, 2) + 5 -A = randn(N1, N2) - -x0 = 10 .* randn(N2) -x0[abs.(x0) .< 1f-6] .= 1.0 -inds = rand(1:N2, div(N2, 4)) -ninds = [i for i=1:N2 if i ∉ inds] -x0[inds] .= 0 -b = A*x0 - -function obj(x) - fun = .5*norm(A*x - b)^2 - grad = A'*(A*x - b) - return fun, grad -end - -opt = bregman_options(maxIter=200, progTol=0, verbose=2) -sol = bregman(obj, 1 .+ randn(N2), opt) - -@show sol.x[inds] -@show x0[inds] -@show sol.x[ninds] -@show x0[ninds] - -part_n = i -> norm(sol.x[i] - x0[i])/(norm(x0[i]) + norm(sol.x[i]) + eps(Float64)) -part_nz = i -> norm(sol.x[i], 1)/N2 -@show part_nz(inds) -@show part_n(ninds) - -@test part_nz(inds) < 1f-1 -@test part_n(ninds) < 1f-1 -@test sol.residual/sol.r_trace[1] < 1f-1 + +@testset "Bregman test for type $(T)" for T = [Float32, ComplexF32] + + A = randn(T, N1, N2) + x0 = 10 .* randn(T, N2) + x0[abs.(x0) .< 1f-6] .= 1.0 + inds = rand(1:N2, div(N2, 4)) + ninds = [i for i=1:N2 if i ∉ inds] + x0[inds] .= 0 + b = A*x0 + + function obj(x) + fun = .5*norm(A*x - b)^2 + grad = A'*(A*x - b) + return fun, grad + end + + opt = bregman_options(maxIter=200, progTol=0, verbose=2, antichatter=T==Float32) + sol = bregman(obj, 1 .+ randn(T, N2), opt) + + @show sol.x[inds] + @show x0[inds] + @show sol.x[ninds] + @show x0[ninds] + + part_n = i -> norm(sol.x[i] - x0[i])/(norm(x0[i]) + norm(sol.x[i]) + eps(Float32)) + part_nz = i -> norm(sol.x[i], 1)/N2 + @show part_nz(inds) + @show part_n(ninds) + + @test part_nz(inds) < 1f-1 + @test part_n(ninds) < 1f-1 + @test sol.residual/sol.r_trace[1] < 1f-1 + +end \ No newline at end of file