Skip to content

Commit

Permalink
Merge pull request #90 from JuliaImageRecon/nh/multiThreading
Browse files Browse the repository at this point in the history
State-based multi-threading
  • Loading branch information
nHackel authored Aug 13, 2024
2 parents 7c20a71 + 6b17ba9 commit 450cac0
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 34 deletions.
6 changes: 3 additions & 3 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ function ADMM(A
return ADMM(A, reg, regTrafo, proj, AHA, precon, normalizeReg, vary_rho, verbose, iterations, iterationsCG, state)
end

function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT}
function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...)
β = similar(b, size(state.β)...)
Expand All @@ -165,7 +165,7 @@ end
(re-) initializes the ADMM iterator
"""
function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT}
function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT <: AbstractVector}
state.x .= x0

# right hand side for the x-update
Expand Down Expand Up @@ -202,7 +202,7 @@ solverconvergence(state::ADMMState) = (; :primal => state.rᵏ, :dual => state.s
performs one ADMM iteration.
"""
function iterate(solver::ADMM, state::S = solver.state) where S <: AbstractSolverState{<:ADMM}
function iterate(solver::ADMM, state::ADMMState)
done(solver, state) && return nothing
solver.verbose && println("Outer ADMM Iteration #$iteration")

Expand Down
8 changes: 4 additions & 4 deletions src/CGNR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function CGNR(A
return CGNR(A, AHA, L2, other, normalizeReg, iterations, state)
end

function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::otherTc; kwargs...) where {T, Tc, vecTc, otherTc}
function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::otherTc; kwargs...) where {T, Tc, vecTc, otherTc <: AbstractVector{Tc}}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
pl = similar(b, size(state.pl)...)
Expand All @@ -104,7 +104,7 @@ end
(re-) initializes the CGNR iterator
"""
function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::vecTc; x0 = 0) where {T, Tc <: Union{T, Complex{T}}, vecTc<:AbstractArray{Tc}}
function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::vecTc; x0 = 0) where {T, Tc <: Union{T, Complex{T}}, vecTc<:AbstractVector{Tc}}
state.pl .= 0 #temporary vector
state.vl .= 0 #temporary vector
state.αl = 0 #temporary scalar
Expand All @@ -131,7 +131,7 @@ end

initCGNR(x₀, A, b) = mul!(x₀, adjoint(A), b)
#initCGNR(x₀, prod::ProdOp{T, <:WeightingOp, matT}, b) where {T, matT} = mul!(x₀, adjoint(prod.B), b)
initCGNR(x₀, ::Nothing, b) = x₀ .= one(eltype(x₀))
initCGNR(x₀, ::Nothing, b) = x₀ .= b

solverconvergence(state::CGNRState) = (; :residual => norm(state.x₀))

Expand All @@ -140,7 +140,7 @@ solverconvergence(state::CGNRState) = (; :residual => norm(state.x₀))
performs one CGNR iteration.
"""
function iterate(solver::CGNR, state=solver.state)
function iterate(solver::CGNR, state::CGNRState)
if done(solver, state)
for r in solver.constr
prox!(r, state.x)
Expand Down
12 changes: 6 additions & 6 deletions src/Direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [L2Regularizati
return DirectSolver(A, L2, normalizeReg, other, DirectSolverState(x, b))
end

function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT}
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
bvecT = similar(b, size(state.b)...)
solver.state = DirectSolverState(x, bvecT)
init!(solver, solver.state, b; kwargs...)
end
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector
solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.A, b)
state.b .= b
state.x .= x0
end

function iterate(solver::DirectSolver, state = solver.state)
function iterate(solver::DirectSolver, state::DirectSolverState)
A = solver.A
λ_ = λ(solver.l2)
lufact = lu(A'*A .+ λ_)
Expand Down Expand Up @@ -138,18 +138,18 @@ function PseudoInverse(A::AbstractMatrix, x, b, l2, norm, proj)
return PseudoInverse(temp, l2, norm, proj, DirectSolverState(x, b))
end

function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT}
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
bvecT = similar(b, size(state.b)...)
solver.state = DirectSolverState(x, bvecT)
init!(solver, solver.state, b; kwargs...)
end
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector
solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.svd, b)
state.b .= b
end

function iterate(solver::PseudoInverse, state = solver.state)
function iterate(solver::PseudoInverse, state::DirectSolverState)
# Inversion by using the pseudoinverse of the SVD
svd = solver.svd

Expand Down
6 changes: 3 additions & 3 deletions src/FISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function FISTA(A
return FISTA(A, AHA, reg[1], other, normalizeReg, verbose, restart, iterations, state)
end

function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT}
function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...)
Expand All @@ -107,7 +107,7 @@ end
(re-) initializes the FISTA iterator
"""
function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT}
function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT <: AbstractVector}
if solver.A === nothing
state.x₀ .= b
else
Expand Down Expand Up @@ -136,7 +136,7 @@ solverconvergence(state::FISTAState) = (; :residual => norm(state.res))
performs one fista iteration.
"""
function iterate(solver::FISTA, state = solver.state)
function iterate(solver::FISTA, state::FISTAState)
if done(solver, state) return nothing end

# momentum / Nesterov step
Expand Down
10 changes: 5 additions & 5 deletions src/Kaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function Kaczmarz(A
Int64(seed), normalizeReg, iterations, state)
end

function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::otherT; kwargs...) where {T, vecT, otherT}
function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::otherT; kwargs...) where {T, vecT, otherT <: AbstractVector}
u = similar(b, size(state.u)...)
x = similar(b, size(state.x)...)
vl = similar(b, size(state.vl)...)
Expand All @@ -125,7 +125,7 @@ end
(re-) initializes the Kacmarz iterator
"""
function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::vecT; x0 = 0) where {T, vecT}
function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::vecT; x0 = 0) where {T, vecT <: AbstractVector}
λ_prev = λ(solver.L2)
solver.L2 = normalize(solver, solver.normalizeReg, solver.L2, solver.A, b)
solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b)
Expand Down Expand Up @@ -164,12 +164,12 @@ end


function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:AbstractVector}, RN <: Union{R, AbstractNestedRegularization{<:R}}}
return solver.state.x .* (1 ./ sqrt.(λ(solver.L2)))
return solversolution(solver.state) .* (1 ./ sqrt.(λ(solver.L2)))
end
solversolution(solver::Kaczmarz) = solver.state.x
solversolution(solver::Kaczmarz) = solversolution(solver.state)
solverconvergence(state::KaczmarzState) = (; :residual => norm(state.vl))

function iterate(solver::Kaczmarz, state = solver.state)
function iterate(solver::Kaczmarz, state::KaczmarzState)
if done(solver,state) return nothing end

if solver.randomized
Expand Down
68 changes: 68 additions & 0 deletions src/MultiThreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
export SequentialState, MultiThreadingState, prepareMultiStates
abstract type AbstractMatrixSolverState{S} <: AbstractSolverState{S} end
mutable struct SequentialState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S}
states::Vector{ST}
active::Vector{Bool}
SequentialState(states::Vector{ST}) where {S, ST <: AbstractSolverState{S}} = new{S, ST}(states, fill(true, length(states)))
end

mutable struct MultiThreadingState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S}
states::Vector{ST}
active::Vector{Bool}
MultiThreadingState(states::Vector{ST}) where {S, ST <: AbstractSolverState{S}} = new{S, ST}(states, fill(true, length(states)))
end

function init!(solver::AbstractLinearSolver, state::AbstractSolverState, b::AbstractMatrix; scheduler = SequentialState, kwargs...)
states = prepareMultiStates(solver, state, b)
multiState = scheduler(states)
solver.state = multiState
init!(solver, multiState, b; kwargs...)
end
function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractVector; kwargs...)
singleState = first(state.states)
solver.state = singleState
init!(solver, singleState, b; kwargs...)
end

function prepareMultiStates(solver::AbstractLinearSolver, state::AbstractSolverState, b::AbstractMatrix)
states = [deepcopy(state) for _ in 1:size(b, 2)]
return states
end
prepareMultiStates(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix) = prepareMultiStates(solver, first(state.states), b)

function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractMatrix; kwargs...)
for (i, s) in enumerate(state.states)
init!(solver, s, b[:, i]; kwargs...)
end
state.active .= true
end

function iterate(solver::S, state::AbstractMatrixSolverState) where {S <: AbstractLinearSolver}
activeIdx = findall(state.active)
if isempty(activeIdx)
return nothing
end
return iterate(solver, state, activeIdx)
end

function iterate(solver::AbstractLinearSolver, state::SequentialState, activeIdx)
for i in activeIdx
res = iterate(solver, state.states[i])
if isnothing(res)
state.active[i] = false
end
end
return state.active, state
end

function iterate(solver::AbstractLinearSolver, state::MultiThreadingState, activeIdx)
Threads.@threads for i in activeIdx
res = iterate(solver, state.states[i])
if isnothing(res)
state.active[i] = false
end
end
return state.active, state
end

solversolution(state::AbstractMatrixSolverState) = mapreduce(solversolution, hcat, state.states)
6 changes: 3 additions & 3 deletions src/OptISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function OptISTA(A
return OptISTA(A, AHA, reg[1], other, normalizeReg, verbose, iterations, state)
end

function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT}
function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
y = similar(b, size(state.y)...)
Expand All @@ -125,7 +125,7 @@ end
(re-) initializes the OptISTA iterator
"""
function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::vecT; x0 = 0, θ=1) where {rT, vecT}
function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::vecT; x0 = 0, θ=1) where {rT, vecT <: AbstractVector}
if solver.A === nothing
state.x₀ .= b
else
Expand Down Expand Up @@ -161,7 +161,7 @@ solverconvergence(state::OptISTAState) = (; :residual => norm(state.res))
performs one OptISTA iteration.
"""
function iterate(solver::OptISTA, state::OptISTAState = solver.state)
function iterate(solver::OptISTA, state::OptISTAState)
if done(solver, state) return nothing end

# inertial parameters
Expand Down
6 changes: 3 additions & 3 deletions src/POGM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function POGM(A
return POGM(A, AHA, reg[1], other, normalizeReg, verbose, restart, iterations, state)
end

function init!(solver::POGM, state::POGMState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT}
function init!(solver::POGM, state::POGMState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...)
Expand All @@ -135,7 +135,7 @@ end
(re-) initializes the POGM iterator
"""
function init!(solver::POGM, state::POGMState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT}
function init!(solver::POGM, state::POGMState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT <: AbstractVector}
if solver.A === nothing
state.x₀ .= b
else
Expand Down Expand Up @@ -170,7 +170,7 @@ solverconvergence(state::POGMState) = (; :residual => norm(state.res))
performs one POGM iteration.
"""
function iterate(solver::POGM, state = solver.state)
function iterate(solver::POGM, state::POGMState)
if done(solver, state)
return nothing
end
Expand Down
15 changes: 11 additions & 4 deletions src/RegularizedLeastSquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ The keyword `callbacks` allows you to pass a (vector of) callable objects that t
See also [`StoreSolutionCallback`](@ref), [`StoreConvergenceCallback`](@ref), [`CompareSolutionCallback`](@ref) for a number of provided callback options.
"""
function solve!(solver::AbstractLinearSolver, b; x0 = 0, callbacks = (_, _) -> nothing)
function solve!(solver::AbstractLinearSolver, b; callbacks = (_, _) -> nothing, kwargs...)
if !(callbacks isa Vector)
callbacks = [callbacks]
end


init!(solver, b; x0)
init!(solver, b; kwargs...)
foreach(cb -> cb(solver, 0), callbacks)

for (iteration, _) = enumerate(solver)
Expand All @@ -129,7 +129,7 @@ end
"""
solve!(cb, solver::AbstractLinearSolver, b; kwargs...) = solve!(solver, b; kwargs..., callbacks = cb)


include("MultiThreading.jl")

export AbstractRowActionSolver
abstract type AbstractRowActionSolver <: AbstractLinearSolver end
Expand Down Expand Up @@ -159,7 +159,13 @@ export solversolution, solverconvergence, solverstate
Return the current solution of the solver
"""
solversolution(solver::AbstractLinearSolver) = solverstate(solver).x
solversolution(solver::AbstractLinearSolver) = solversolution(solverstate(solver))
"""
solversolution(state::AbstractSolverState)
Return the current solution of the solver's state
"""
solversolution(state::AbstractSolverState) = state.x
"""
solverconvergence(solver::AbstractLinearSolver)
Expand All @@ -171,6 +177,7 @@ solverstate(solver::AbstractLinearSolver) = solver.state
solverconvergence(solver::AbstractLinearSolver) = solverconvergence(solverstate(solver))

init!(solver::AbstractLinearSolver, b; kwargs...) = init!(solver, solverstate(solver), b; kwargs...)
iterate(solver::AbstractLinearSolver) = iterate(solver, solverstate(solver))

include("Utils.jl")
include("Kaczmarz.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/SplitBregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ function SplitBregman(A
return SplitBregman(A,reg,regTrafo,proj,AHA,precon,normalizeReg,verbose,iterations,iterationsInner,iterationsCG,state)
end

function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT}
function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT <: AbstractVector}
y = similar(b, size(state.y)...)
x = similar(b, size(state.x)...)
β = similar(b, size(state.β)...)
Expand All @@ -167,7 +167,7 @@ end
(re-) initializes the SplitBregman iterator
"""
function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT}
function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT <: AbstractVector}
state.x .= x0

# right hand side for the x-update
Expand Down Expand Up @@ -201,7 +201,7 @@ end

solverconvergence(state::SplitBregmanState) = (; :primal => state.rᵏ, :dual => state.sᵏ)

function iterate(solver::SplitBregman, state=solver.state)
function iterate(solver::SplitBregman, state::SplitBregmanState)
if done(solver, state) return nothing end
solver.verbose && println("SplitBregman Iteration #$(state.iteration) – Outer iteration $(state.iter_cnt)")

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ arrayTypes = areTypesDefined ? arrayTypes : [Array, JLArray]
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
include("testMultiThreading.jl")
end
Loading

0 comments on commit 450cac0

Please sign in to comment.