Skip to content

Commit

Permalink
progress bar callback
Browse files Browse the repository at this point in the history
  • Loading branch information
atsanda committed Mar 21, 2024
1 parent 94d2bf6 commit 1a24382
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 19 deletions.
36 changes: 19 additions & 17 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,34 @@ authors = ["Tobias Knopp <[email protected]>"]
version = "0.13.1-DEV"

[deps]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
LinearOperatorCollection = "a4a2c56f-fead-462a-a3ab-85921a5f2575"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperatorCollection = "a4a2c56f-fead-462a-a3ab-85921a5f2575"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

[compat]
IterativeSolvers = "0.9"
julia = "1.9"
StatsBase = "0.33, 0.34"
FFTW = "1.0"
FLoops = "0.2"
VectorizationBase = "0.19, 0.21"
IterativeSolvers = "0.9"
LinearOperatorCollection = "1.2"
LinearOperators = "2.3.3"
FFTW = "1.0"
StatsBase = "0.33, 0.34"
VectorizationBase = "0.19, 0.21"
julia = "1.9"

[extras]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random", "FFTW"]
34 changes: 33 additions & 1 deletion src/Callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using ProgressMeter


export CompareSolutionCallback
mutable struct CompareSolutionCallback{T, F}
ref::Vector{T}
Expand Down Expand Up @@ -49,4 +52,33 @@ function (cb::StoreConvergenceCallback)(solver::AbstractLinearSolver, _)
push!(values, meas[key])
cb.convMeas[key] = values
end
end
end


export ProgressBarCallback
"""
ProgressBarCallback()
Callback that displays a progress bar for a solver.
"""
Base.@kwdef mutable struct ProgressBarCallback
meter::Union{Progress,Nothing} = nothing
end
ProgressBarCallback(solver::AbstractLinearSolver) = ProgressBarCallback(Progress(solver.iterations))
ProgressBarCallback(iterations::Int) = ProgressBarCallback(Progress(iterations))

"""
(self::ProgressBarCallback)(solver::AbstractLinearSolver, iter_n::Int)
Initializes the callback when `iter_n` is zero, then updates the progress bar.
"""
function (self::ProgressBarCallback)(solver::AbstractLinearSolver, iter_n::Int)
if iter_n != 0
next!(self.meter)
end

# lazy init for iter_n = 0
if iter_n == 0 && isnothing(self.meter)
self.meter = Progress(solver.iterations)
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ using FFTW
include("testKaczmarz.jl")
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
include("testRegularization.jl")
include("testCallbacks.jl")
15 changes: 15 additions & 0 deletions test/testCallbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@testset "ProgressBarCallback" begin
A = [
0.831658 0.96717
0.383056 0.39043
0.820692 0.08118
]
x = [0.593; 0.269]
b = A * x

solver = ADMM(A; iterations=50)

_ = solve!(solver, b, callbacks=ProgressBarCallback())
_ = solve!(solver, b, callbacks=ProgressBarCallback(solver))
_ = solve!(solver, b, callbacks=ProgressBarCallback(50))
end

0 comments on commit 1a24382

Please sign in to comment.