Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State-based multi-threading #90

Merged
merged 6 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ function ADMM(A
return ADMM(A, reg, regTrafo, proj, AHA, precon, normalizeReg, vary_rho, verbose, iterations, iterationsCG, state)
end

function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT}
function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...)
β = similar(b, size(state.β)...)
Expand All @@ -165,7 +165,7 @@ end

(re-) initializes the ADMM iterator
"""
function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT}
function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT <: AbstractVector}
state.x .= x0

# right hand side for the x-update
Expand Down Expand Up @@ -202,7 +202,7 @@ solverconvergence(state::ADMMState) = (; :primal => state.rᵏ, :dual => state.s

performs one ADMM iteration.
"""
function iterate(solver::ADMM, state::S = solver.state) where S <: AbstractSolverState{<:ADMM}
function iterate(solver::ADMM, state::ADMMState)
done(solver, state) && return nothing
solver.verbose && println("Outer ADMM Iteration #$iteration")

Expand Down
8 changes: 4 additions & 4 deletions src/CGNR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function CGNR(A
return CGNR(A, AHA, L2, other, normalizeReg, iterations, state)
end

function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::otherTc; kwargs...) where {T, Tc, vecTc, otherTc}
function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::otherTc; kwargs...) where {T, Tc, vecTc, otherTc <: AbstractVector{Tc}}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
pl = similar(b, size(state.pl)...)
Expand All @@ -104,7 +104,7 @@ end

(re-) initializes the CGNR iterator
"""
function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::vecTc; x0 = 0) where {T, Tc <: Union{T, Complex{T}}, vecTc<:AbstractArray{Tc}}
function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::vecTc; x0 = 0) where {T, Tc <: Union{T, Complex{T}}, vecTc<:AbstractVector{Tc}}
state.pl .= 0 #temporary vector
state.vl .= 0 #temporary vector
state.αl = 0 #temporary scalar
Expand All @@ -131,7 +131,7 @@ end

initCGNR(x₀, A, b) = mul!(x₀, adjoint(A), b)
#initCGNR(x₀, prod::ProdOp{T, <:WeightingOp, matT}, b) where {T, matT} = mul!(x₀, adjoint(prod.B), b)
initCGNR(x₀, ::Nothing, b) = x₀ .= one(eltype(x₀))
initCGNR(x₀, ::Nothing, b) = x₀ .= b

solverconvergence(state::CGNRState) = (; :residual => norm(state.x₀))

Expand All @@ -140,7 +140,7 @@ solverconvergence(state::CGNRState) = (; :residual => norm(state.x₀))

performs one CGNR iteration.
"""
function iterate(solver::CGNR, state=solver.state)
function iterate(solver::CGNR, state::CGNRState)
if done(solver, state)
for r in solver.constr
prox!(r, state.x)
Expand Down
12 changes: 6 additions & 6 deletions src/Direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [L2Regularizati
return DirectSolver(A, L2, normalizeReg, other, DirectSolverState(x, b))
end

function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT}
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
bvecT = similar(b, size(state.b)...)
solver.state = DirectSolverState(x, bvecT)
init!(solver, solver.state, b; kwargs...)
end
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector
solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.A, b)
state.b .= b
state.x .= x0
end

function iterate(solver::DirectSolver, state = solver.state)
function iterate(solver::DirectSolver, state::DirectSolverState)
A = solver.A
λ_ = λ(solver.l2)
lufact = lu(A'*A .+ λ_)
Expand Down Expand Up @@ -138,18 +138,18 @@ function PseudoInverse(A::AbstractMatrix, x, b, l2, norm, proj)
return PseudoInverse(temp, l2, norm, proj, DirectSolverState(x, b))
end

function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT}
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
bvecT = similar(b, size(state.b)...)
solver.state = DirectSolverState(x, bvecT)
init!(solver, solver.state, b; kwargs...)
end
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector
solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.svd, b)
state.b .= b
end

function iterate(solver::PseudoInverse, state = solver.state)
function iterate(solver::PseudoInverse, state::DirectSolverState)
# Inversion by using the pseudoinverse of the SVD
svd = solver.svd

Expand Down
6 changes: 3 additions & 3 deletions src/FISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function FISTA(A
return FISTA(A, AHA, reg[1], other, normalizeReg, verbose, restart, iterations, state)
end

function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT}
function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...)
Expand All @@ -107,7 +107,7 @@ end

(re-) initializes the FISTA iterator
"""
function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT}
function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT <: AbstractVector}
if solver.A === nothing
state.x₀ .= b
else
Expand Down Expand Up @@ -136,7 +136,7 @@ solverconvergence(state::FISTAState) = (; :residual => norm(state.res))

performs one fista iteration.
"""
function iterate(solver::FISTA, state = solver.state)
function iterate(solver::FISTA, state::FISTAState)
if done(solver, state) return nothing end

# momentum / Nesterov step
Expand Down
10 changes: 5 additions & 5 deletions src/Kaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function Kaczmarz(A
Int64(seed), normalizeReg, iterations, state)
end

function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::otherT; kwargs...) where {T, vecT, otherT}
function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::otherT; kwargs...) where {T, vecT, otherT <: AbstractVector}
u = similar(b, size(state.u)...)
x = similar(b, size(state.x)...)
vl = similar(b, size(state.vl)...)
Expand All @@ -125,7 +125,7 @@ end

(re-) initializes the Kacmarz iterator
"""
function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::vecT; x0 = 0) where {T, vecT}
function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::vecT; x0 = 0) where {T, vecT <: AbstractVector}
λ_prev = λ(solver.L2)
solver.L2 = normalize(solver, solver.normalizeReg, solver.L2, solver.A, b)
solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b)
Expand Down Expand Up @@ -164,12 +164,12 @@ end


function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:AbstractVector}, RN <: Union{R, AbstractNestedRegularization{<:R}}}
return solver.state.x .* (1 ./ sqrt.(λ(solver.L2)))
return solversolution(solver.state) .* (1 ./ sqrt.(λ(solver.L2)))
end
solversolution(solver::Kaczmarz) = solver.state.x
solversolution(solver::Kaczmarz) = solversolution(solver.state)
solverconvergence(state::KaczmarzState) = (; :residual => norm(state.vl))

function iterate(solver::Kaczmarz, state = solver.state)
function iterate(solver::Kaczmarz, state::KaczmarzState)
if done(solver,state) return nothing end

if solver.randomized
Expand Down
68 changes: 68 additions & 0 deletions src/MultiThreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
export SequentialState, MultiThreadingState, prepareMultiStates
abstract type AbstractMatrixSolverState{S} <: AbstractSolverState{S} end
mutable struct SequentialState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S}
states::Vector{ST}
active::Vector{Bool}
SequentialState(states::Vector{ST}) where {S, ST <: AbstractSolverState{S}} = new{S, ST}(states, fill(true, length(states)))
end

mutable struct MultiThreadingState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S}
states::Vector{ST}
active::Vector{Bool}
MultiThreadingState(states::Vector{ST}) where {S, ST <: AbstractSolverState{S}} = new{S, ST}(states, fill(true, length(states)))
end

function init!(solver::AbstractLinearSolver, state::AbstractSolverState, b::AbstractMatrix; scheduler = SequentialState, kwargs...)
states = prepareMultiStates(solver, state, b)
multiState = scheduler(states)
solver.state = multiState
init!(solver, multiState, b; kwargs...)
end
function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractVector; kwargs...)
singleState = first(state.states)
solver.state = singleState
init!(solver, singleState, b; kwargs...)
end

function prepareMultiStates(solver::AbstractLinearSolver, state::AbstractSolverState, b::AbstractMatrix)
states = [deepcopy(state) for _ in 1:size(b, 2)]
return states
end
prepareMultiStates(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix) = prepareMultiStates(solver, first(state.states), b)

function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractMatrix; kwargs...)
for (i, s) in enumerate(state.states)
init!(solver, s, b[:, i]; kwargs...)
end
state.active .= true
end

function iterate(solver::S, state::AbstractMatrixSolverState) where {S <: AbstractLinearSolver}
activeIdx = findall(state.active)
if isempty(activeIdx)
return nothing
end
return iterate(solver, state, activeIdx)
end

function iterate(solver::AbstractLinearSolver, state::SequentialState, activeIdx)
for i in activeIdx
res = iterate(solver, state.states[i])
if isnothing(res)
state.active[i] = false
end
end
return state.active, state
end

function iterate(solver::AbstractLinearSolver, state::MultiThreadingState, activeIdx)
Threads.@threads for i in activeIdx
res = iterate(solver, state.states[i])
if isnothing(res)
state.active[i] = false
end
end
return state.active, state
end

solversolution(state::AbstractMatrixSolverState) = mapreduce(solversolution, hcat, state.states)
6 changes: 3 additions & 3 deletions src/OptISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function OptISTA(A
return OptISTA(A, AHA, reg[1], other, normalizeReg, verbose, iterations, state)
end

function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT}
function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
y = similar(b, size(state.y)...)
Expand All @@ -125,7 +125,7 @@ end

(re-) initializes the OptISTA iterator
"""
function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::vecT; x0 = 0, θ=1) where {rT, vecT}
function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::vecT; x0 = 0, θ=1) where {rT, vecT <: AbstractVector}
if solver.A === nothing
state.x₀ .= b
else
Expand Down Expand Up @@ -161,7 +161,7 @@ solverconvergence(state::OptISTAState) = (; :residual => norm(state.res))

performs one OptISTA iteration.
"""
function iterate(solver::OptISTA, state::OptISTAState = solver.state)
function iterate(solver::OptISTA, state::OptISTAState)
if done(solver, state) return nothing end

# inertial parameters
Expand Down
6 changes: 3 additions & 3 deletions src/POGM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function POGM(A
return POGM(A, AHA, reg[1], other, normalizeReg, verbose, restart, iterations, state)
end

function init!(solver::POGM, state::POGMState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT}
function init!(solver::POGM, state::POGMState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
x₀ = similar(b, size(state.x₀)...)
xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...)
Expand All @@ -135,7 +135,7 @@ end

(re-) initializes the POGM iterator
"""
function init!(solver::POGM, state::POGMState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT}
function init!(solver::POGM, state::POGMState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT <: AbstractVector}
if solver.A === nothing
state.x₀ .= b
else
Expand Down Expand Up @@ -170,7 +170,7 @@ solverconvergence(state::POGMState) = (; :residual => norm(state.res))

performs one POGM iteration.
"""
function iterate(solver::POGM, state = solver.state)
function iterate(solver::POGM, state::POGMState)
if done(solver, state)
return nothing
end
Expand Down
15 changes: 11 additions & 4 deletions src/RegularizedLeastSquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ The keyword `callbacks` allows you to pass a (vector of) callable objects that t

See also [`StoreSolutionCallback`](@ref), [`StoreConvergenceCallback`](@ref), [`CompareSolutionCallback`](@ref) for a number of provided callback options.
"""
function solve!(solver::AbstractLinearSolver, b; x0 = 0, callbacks = (_, _) -> nothing)
function solve!(solver::AbstractLinearSolver, b; callbacks = (_, _) -> nothing, kwargs...)
if !(callbacks isa Vector)
callbacks = [callbacks]
end


init!(solver, b; x0)
init!(solver, b; kwargs...)
foreach(cb -> cb(solver, 0), callbacks)

for (iteration, _) = enumerate(solver)
Expand All @@ -129,7 +129,7 @@ end
"""
solve!(cb, solver::AbstractLinearSolver, b; kwargs...) = solve!(solver, b; kwargs..., callbacks = cb)


include("MultiThreading.jl")

export AbstractRowActionSolver
abstract type AbstractRowActionSolver <: AbstractLinearSolver end
Expand Down Expand Up @@ -159,7 +159,13 @@ export solversolution, solverconvergence, solverstate

Return the current solution of the solver
"""
solversolution(solver::AbstractLinearSolver) = solverstate(solver).x
solversolution(solver::AbstractLinearSolver) = solversolution(solverstate(solver))
"""
solversolution(state::AbstractSolverState)

Return the current solution of the solver's state
"""
solversolution(state::AbstractSolverState) = state.x
"""
solverconvergence(solver::AbstractLinearSolver)

Expand All @@ -171,6 +177,7 @@ solverstate(solver::AbstractLinearSolver) = solver.state
solverconvergence(solver::AbstractLinearSolver) = solverconvergence(solverstate(solver))

init!(solver::AbstractLinearSolver, b; kwargs...) = init!(solver, solverstate(solver), b; kwargs...)
iterate(solver::AbstractLinearSolver) = iterate(solver, solverstate(solver))

include("Utils.jl")
include("Kaczmarz.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/SplitBregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ function SplitBregman(A
return SplitBregman(A,reg,regTrafo,proj,AHA,precon,normalizeReg,verbose,iterations,iterationsInner,iterationsCG,state)
end

function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT}
function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT <: AbstractVector}
y = similar(b, size(state.y)...)
x = similar(b, size(state.x)...)
β = similar(b, size(state.β)...)
Expand All @@ -167,7 +167,7 @@ end

(re-) initializes the SplitBregman iterator
"""
function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT}
function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT <: AbstractVector}
state.x .= x0

# right hand side for the x-update
Expand Down Expand Up @@ -201,7 +201,7 @@ end

solverconvergence(state::SplitBregmanState) = (; :primal => state.rᵏ, :dual => state.sᵏ)

function iterate(solver::SplitBregman, state=solver.state)
function iterate(solver::SplitBregman, state::SplitBregmanState)
if done(solver, state) return nothing end
solver.verbose && println("SplitBregman Iteration #$(state.iteration) – Outer iteration $(state.iter_cnt)")

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ arrayTypes = areTypesDefined ? arrayTypes : [Array, JLArray]
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
include("testMultiThreading.jl")
end
Loading
Loading