From 48c2df781ab740ed99d66db6083381eaa9cebc5c Mon Sep 17 00:00:00 2001 From: nHackel Date: Mon, 12 Aug 2024 16:22:41 +0200 Subject: [PATCH 1/6] Restrict solvers iterate and init function to vectors and make iterate overwriteable by deleting default arg --- src/ADMM.jl | 6 +++--- src/CGNR.jl | 6 +++--- src/FISTA.jl | 6 +++--- src/Kaczmarz.jl | 6 +++--- src/OptISTA.jl | 6 +++--- src/POGM.jl | 6 +++--- src/RegularizedLeastSquares.jl | 1 + src/SplitBregman.jl | 6 +++--- 8 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/ADMM.jl b/src/ADMM.jl index 48f5d355..447753e9 100644 --- a/src/ADMM.jl +++ b/src/ADMM.jl @@ -140,7 +140,7 @@ function ADMM(A return ADMM(A, reg, regTrafo, proj, AHA, precon, normalizeReg, vary_rho, verbose, iterations, iterationsCG, state) end -function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT} +function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...) β = similar(b, size(state.β)...) @@ -165,7 +165,7 @@ end (re-) initializes the ADMM iterator """ -function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT} +function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT <: AbstractVector} state.x .= x0 # right hand side for the x-update @@ -202,7 +202,7 @@ solverconvergence(state::ADMMState) = (; :primal => state.rᵏ, :dual => state.s performs one ADMM iteration. """ -function iterate(solver::ADMM, state::S = solver.state) where S <: AbstractSolverState{<:ADMM} +function iterate(solver::ADMM, state::ADMMState) done(solver, state) && return nothing solver.verbose && println("Outer ADMM Iteration #$iteration") diff --git a/src/CGNR.jl b/src/CGNR.jl index 83112b18..939e85b8 100644 --- a/src/CGNR.jl +++ b/src/CGNR.jl @@ -88,7 +88,7 @@ function CGNR(A return CGNR(A, AHA, L2, other, normalizeReg, iterations, state) end -function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::otherTc; kwargs...) where {T, Tc, vecTc, otherTc} +function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::otherTc; kwargs...) where {T, Tc, vecTc, otherTc <: AbstractVector{Tc}} x = similar(b, size(state.x)...) x₀ = similar(b, size(state.x₀)...) pl = similar(b, size(state.pl)...) @@ -104,7 +104,7 @@ end (re-) initializes the CGNR iterator """ -function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::vecTc; x0 = 0) where {T, Tc <: Union{T, Complex{T}}, vecTc<:AbstractArray{Tc}} +function init!(solver::CGNR, state::CGNRState{T, Tc, vecTc}, b::vecTc; x0 = 0) where {T, Tc <: Union{T, Complex{T}}, vecTc<:AbstractVector{Tc}} state.pl .= 0 #temporary vector state.vl .= 0 #temporary vector state.αl = 0 #temporary scalar @@ -140,7 +140,7 @@ solverconvergence(state::CGNRState) = (; :residual => norm(state.x₀)) performs one CGNR iteration. """ -function iterate(solver::CGNR, state=solver.state) +function iterate(solver::CGNR, state::CGNRState) if done(solver, state) for r in solver.constr prox!(r, state.x) diff --git a/src/FISTA.jl b/src/FISTA.jl index c0835b19..db4d4cf4 100644 --- a/src/FISTA.jl +++ b/src/FISTA.jl @@ -91,7 +91,7 @@ function FISTA(A return FISTA(A, AHA, reg[1], other, normalizeReg, verbose, restart, iterations, state) end -function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT} +function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) x₀ = similar(b, size(state.x₀)...) xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...) @@ -107,7 +107,7 @@ end (re-) initializes the FISTA iterator """ -function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT} +function init!(solver::FISTA, state::FISTAState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT <: AbstractVector} if solver.A === nothing state.x₀ .= b else @@ -136,7 +136,7 @@ solverconvergence(state::FISTAState) = (; :residual => norm(state.res)) performs one fista iteration. """ -function iterate(solver::FISTA, state = solver.state) +function iterate(solver::FISTA, state::FISTAState) if done(solver, state) return nothing end # momentum / Nesterov step diff --git a/src/Kaczmarz.jl b/src/Kaczmarz.jl index 614508c9..3ddcd979 100644 --- a/src/Kaczmarz.jl +++ b/src/Kaczmarz.jl @@ -110,7 +110,7 @@ function Kaczmarz(A Int64(seed), normalizeReg, iterations, state) end -function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::otherT; kwargs...) where {T, vecT, otherT} +function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::otherT; kwargs...) where {T, vecT, otherT <: AbstractVector} u = similar(b, size(state.u)...) x = similar(b, size(state.x)...) vl = similar(b, size(state.vl)...) @@ -125,7 +125,7 @@ end (re-) initializes the Kacmarz iterator """ -function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::vecT; x0 = 0) where {T, vecT} +function init!(solver::Kaczmarz, state::KaczmarzState{T, vecT}, b::vecT; x0 = 0) where {T, vecT <: AbstractVector} λ_prev = λ(solver.L2) solver.L2 = normalize(solver, solver.normalizeReg, solver.L2, solver.A, b) solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b) @@ -169,7 +169,7 @@ end solversolution(solver::Kaczmarz) = solver.state.x solverconvergence(state::KaczmarzState) = (; :residual => norm(state.vl)) -function iterate(solver::Kaczmarz, state = solver.state) +function iterate(solver::Kaczmarz, state::KaczmarzState) if done(solver,state) return nothing end if solver.randomized diff --git a/src/OptISTA.jl b/src/OptISTA.jl index 509fde2e..d835f0e3 100644 --- a/src/OptISTA.jl +++ b/src/OptISTA.jl @@ -104,7 +104,7 @@ function OptISTA(A return OptISTA(A, AHA, reg[1], other, normalizeReg, verbose, iterations, state) end -function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT} +function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) x₀ = similar(b, size(state.x₀)...) y = similar(b, size(state.y)...) @@ -125,7 +125,7 @@ end (re-) initializes the OptISTA iterator """ -function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::vecT; x0 = 0, θ=1) where {rT, vecT} +function init!(solver::OptISTA, state::OptISTAState{rT, vecT}, b::vecT; x0 = 0, θ=1) where {rT, vecT <: AbstractVector} if solver.A === nothing state.x₀ .= b else @@ -161,7 +161,7 @@ solverconvergence(state::OptISTAState) = (; :residual => norm(state.res)) performs one OptISTA iteration. """ -function iterate(solver::OptISTA, state::OptISTAState = solver.state) +function iterate(solver::OptISTA, state::OptISTAState) if done(solver, state) return nothing end # inertial parameters diff --git a/src/POGM.jl b/src/POGM.jl index a99efb56..c5141816 100644 --- a/src/POGM.jl +++ b/src/POGM.jl @@ -113,7 +113,7 @@ function POGM(A return POGM(A, AHA, reg[1], other, normalizeReg, verbose, restart, iterations, state) end -function init!(solver::POGM, state::POGMState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT} +function init!(solver::POGM, state::POGMState{rT, vecT}, b::otherT; kwargs...) where {rT, vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) x₀ = similar(b, size(state.x₀)...) xᵒˡᵈ = similar(b, size(state.xᵒˡᵈ)...) @@ -135,7 +135,7 @@ end (re-) initializes the POGM iterator """ -function init!(solver::POGM, state::POGMState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT} +function init!(solver::POGM, state::POGMState{rT, vecT}, b::vecT; x0 = 0, theta=1) where {rT, vecT <: AbstractVector} if solver.A === nothing state.x₀ .= b else @@ -170,7 +170,7 @@ solverconvergence(state::POGMState) = (; :residual => norm(state.res)) performs one POGM iteration. """ -function iterate(solver::POGM, state = solver.state) +function iterate(solver::POGM, state::POGMState) if done(solver, state) return nothing end diff --git a/src/RegularizedLeastSquares.jl b/src/RegularizedLeastSquares.jl index 5cb7d4be..b3c43aa5 100644 --- a/src/RegularizedLeastSquares.jl +++ b/src/RegularizedLeastSquares.jl @@ -171,6 +171,7 @@ solverstate(solver::AbstractLinearSolver) = solver.state solverconvergence(solver::AbstractLinearSolver) = solverconvergence(solverstate(solver)) init!(solver::AbstractLinearSolver, b; kwargs...) = init!(solver, solverstate(solver), b; kwargs...) +iterate(solver::AbstractLinearSolver) = iterate(solver, solverstate(solver)) include("Utils.jl") include("Kaczmarz.jl") diff --git a/src/SplitBregman.jl b/src/SplitBregman.jl index 8b4602fd..7a078061 100644 --- a/src/SplitBregman.jl +++ b/src/SplitBregman.jl @@ -143,7 +143,7 @@ function SplitBregman(A return SplitBregman(A,reg,regTrafo,proj,AHA,precon,normalizeReg,verbose,iterations,iterationsInner,iterationsCG,state) end -function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT} +function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::otherT; kwargs...) where {rT, rvecT, vecT, otherT <: AbstractVector} y = similar(b, size(state.y)...) x = similar(b, size(state.x)...) β = similar(b, size(state.β)...) @@ -167,7 +167,7 @@ end (re-) initializes the SplitBregman iterator """ -function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT} +function init!(solver::SplitBregman, state::SplitBregmanState{rT, rvecT, vecT}, b::vecT; x0 = 0) where {rT, rvecT, vecT <: AbstractVector} state.x .= x0 # right hand side for the x-update @@ -201,7 +201,7 @@ end solverconvergence(state::SplitBregmanState) = (; :primal => state.rᵏ, :dual => state.sᵏ) -function iterate(solver::SplitBregman, state=solver.state) +function iterate(solver::SplitBregman, state::SplitBregmanState) if done(solver, state) return nothing end solver.verbose && println("SplitBregman Iteration #$(state.iteration) – Outer iteration $(state.iter_cnt)") From 99299af02bc7ea1a34d64bc61b1bf86efd568fe6 Mon Sep 17 00:00:00 2001 From: nHackel Date: Mon, 12 Aug 2024 16:43:23 +0200 Subject: [PATCH 2/6] Add MultiThreading states for parallel processing support --- src/MultiThreading.jl | 68 ++++++++++++++++++++++++++++++++++ src/RegularizedLeastSquares.jl | 14 +++++-- test/runtests.jl | 1 + test/testMultiThreading.jl | 32 ++++++++++++++++ 4 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 src/MultiThreading.jl create mode 100644 test/testMultiThreading.jl diff --git a/src/MultiThreading.jl b/src/MultiThreading.jl new file mode 100644 index 00000000..383f8b39 --- /dev/null +++ b/src/MultiThreading.jl @@ -0,0 +1,68 @@ +export SequentialState, MultiThreadingState +abstract type AbstractMatrixSolverState{S} <: AbstractSolverState{S} end +mutable struct SequentialState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S} + states::Vector{ST} + active::Vector{Bool} + SequentialState(states::Vector{ST}) where {S, ST <: AbstractSolverState{S}} = new{S, ST}(states, fill(true, length(states))) +end + +mutable struct MultiThreadingState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S} + states::Vector{ST} + active::Vector{Bool} + MultiThreadingState(states::Vector{ST}) where {S, ST <: AbstractSolverState{S}} = new{S, ST}(states, fill(true, length(states))) +end + +function init!(solver::AbstractLinearSolver, state::AbstractSolverState, b::AbstractMatrix; scheduler = SequentialState, kwargs...) + states = prepareMultiStates(solver, state, b) + multiState = scheduler(states) + solver.state = multiState + init!(solver, multiState, b; kwargs...) +end +function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractVector; kwargs...) + singleState = first(state.states) + solver.state = singleState + init!(solver, singleState, b; kwargs...) +end + +function prepareMultiStates(solver::AbstractLinearSolver, state::AbstractSolverState, b::AbstractMatrix) + states = [deepcopy(state) for _ in 1:size(b, 2)] + return states +end +prepareMultiStates(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix) = prepareMultiStates(solver, first(state.states), b) + +function init!(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix; kwargs...) + for (i, s) in enumerate(state.states) + init!(solver, s, b[:, i]; kwargs...) + end + state.active .= true +end + +function iterate(solver::S, state::Union{SequentialState, MultiThreadingState}) where {S <: AbstractLinearSolver} + activeIdx = findall(state.active) + if isempty(activeIdx) + return nothing + end + return iterate(solver, state, activeIdx) +end + +function iterate(solver::AbstractLinearSolver, state::SequentialState, activeIdx) + for i in activeIdx + res = iterate(solver, state.states[i]) + if isnothing(res) + state.active[i] = false + end + end + return state.active, state +end + +function iterate(solver::AbstractLinearSolver, state::MultiThreadingState, activeIdx) + Threads.@threads for i in activeIdx + res = iterate(solver, state.states[i]) + if isnothing(res) + state.active[i] = false + end + end + return state.active, state +end + +solversolution(state::Union{SequentialState, MultiThreadingState}) = mapreduce(solversolution, hcat, state.states) \ No newline at end of file diff --git a/src/RegularizedLeastSquares.jl b/src/RegularizedLeastSquares.jl index b3c43aa5..39431f94 100644 --- a/src/RegularizedLeastSquares.jl +++ b/src/RegularizedLeastSquares.jl @@ -99,13 +99,13 @@ The keyword `callbacks` allows you to pass a (vector of) callable objects that t See also [`StoreSolutionCallback`](@ref), [`StoreConvergenceCallback`](@ref), [`CompareSolutionCallback`](@ref) for a number of provided callback options. """ -function solve!(solver::AbstractLinearSolver, b; x0 = 0, callbacks = (_, _) -> nothing) +function solve!(solver::AbstractLinearSolver, b; callbacks = (_, _) -> nothing, kwargs...) if !(callbacks isa Vector) callbacks = [callbacks] end - init!(solver, b; x0) + init!(solver, b; kwargs...) foreach(cb -> cb(solver, 0), callbacks) for (iteration, _) = enumerate(solver) @@ -129,7 +129,7 @@ end """ solve!(cb, solver::AbstractLinearSolver, b; kwargs...) = solve!(solver, b; kwargs..., callbacks = cb) - +include("MultiThreading.jl") export AbstractRowActionSolver abstract type AbstractRowActionSolver <: AbstractLinearSolver end @@ -159,7 +159,13 @@ export solversolution, solverconvergence, solverstate Return the current solution of the solver """ -solversolution(solver::AbstractLinearSolver) = solverstate(solver).x +solversolution(solver::AbstractLinearSolver) = solversolution(solverstate(solver)) +""" + solversolution(state::AbstractSolverState) + +Return the current solution of the solver's state +""" +solversolution(state::AbstractSolverState) = state.x """ solverconvergence(solver::AbstractLinearSolver) diff --git a/test/runtests.jl b/test/runtests.jl index ddb6ad34..14b6028a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,5 @@ arrayTypes = areTypesDefined ? arrayTypes : [Array, JLArray] include("testProxMaps.jl") include("testSolvers.jl") include("testRegularization.jl") + include("testMultiThreading.jl") end \ No newline at end of file diff --git a/test/testMultiThreading.jl b/test/testMultiThreading.jl new file mode 100644 index 00000000..48419576 --- /dev/null +++ b/test/testMultiThreading.jl @@ -0,0 +1,32 @@ +function testMultiThreadingSolver(; arrayType = Array, scheduler = MultiDataState) + A = rand(ComplexF32, 3, 2) + x = rand(ComplexF32, 2, 4) + b = A * x + + solvers = [CGNR] # linearSolverList() + @testset "$(solvers[i])" for i = 1:length(solvers) + S = createLinearSolver(solvers[i], arrayType(A), iterations = 100) + + x_sequential = hcat([Array(solve!(S, arrayType(b[:, j]))) for j = 1:size(b, 2)]...) + @test x_sequential ≈ x rtol = 0.1 + + x_approx = Array(solve!(S, arrayType(b), scheduler=scheduler)) + @test x_approx ≈ x rtol = 0.1 + + # Does sequential/normal reco still works after multi-threading + x_vec = Array(solve!(S, arrayType(b[:, 1]))) + @test x_vec ≈ x[:, 1] rtol = 0.1 + end +end + +@testset "Test MultiThreading Support" begin + for arrayType in arrayTypes + @testset "$arrayType" begin + for scheduler in [SequentialState, MultiThreadingState] + @testset "$scheduler" begin + testMultiThreadingSolver(; arrayType, scheduler) + end + end + end + end +end \ No newline at end of file From a7c4cf1e1268005d959035bacf4e15bb0a171f03 Mon Sep 17 00:00:00 2001 From: nHackel Date: Tue, 13 Aug 2024 09:58:14 +0200 Subject: [PATCH 3/6] Use type hierachy for matrix iteration --- src/MultiThreading.jl | 6 +++--- test/testMultiThreading.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/MultiThreading.jl b/src/MultiThreading.jl index 383f8b39..27a77149 100644 --- a/src/MultiThreading.jl +++ b/src/MultiThreading.jl @@ -1,4 +1,4 @@ -export SequentialState, MultiThreadingState +export SequentialState, MultiThreadingState, prepareMultiStates abstract type AbstractMatrixSolverState{S} <: AbstractSolverState{S} end mutable struct SequentialState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S} states::Vector{ST} @@ -30,14 +30,14 @@ function prepareMultiStates(solver::AbstractLinearSolver, state::AbstractSolverS end prepareMultiStates(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix) = prepareMultiStates(solver, first(state.states), b) -function init!(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix; kwargs...) +function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractMatrix; kwargs...) for (i, s) in enumerate(state.states) init!(solver, s, b[:, i]; kwargs...) end state.active .= true end -function iterate(solver::S, state::Union{SequentialState, MultiThreadingState}) where {S <: AbstractLinearSolver} +function iterate(solver::S, state::AbstractMatrixSolverState) where {S <: AbstractLinearSolver} activeIdx = findall(state.active) if isempty(activeIdx) return nothing diff --git a/test/testMultiThreading.jl b/test/testMultiThreading.jl index 48419576..465bd9d9 100644 --- a/test/testMultiThreading.jl +++ b/test/testMultiThreading.jl @@ -3,7 +3,7 @@ function testMultiThreadingSolver(; arrayType = Array, scheduler = MultiDataStat x = rand(ComplexF32, 2, 4) b = A * x - solvers = [CGNR] # linearSolverList() + solvers = linearSolverList() @testset "$(solvers[i])" for i = 1:length(solvers) S = createLinearSolver(solvers[i], arrayType(A), iterations = 100) From 9b9e47f027dd60643b7d5397a47436654c0da4f1 Mon Sep 17 00:00:00 2001 From: nHackel Date: Tue, 13 Aug 2024 10:02:16 +0200 Subject: [PATCH 4/6] Refactor solversolution function to accept AbstractMatrixSolverState --- src/MultiThreading.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MultiThreading.jl b/src/MultiThreading.jl index 27a77149..762f6838 100644 --- a/src/MultiThreading.jl +++ b/src/MultiThreading.jl @@ -65,4 +65,4 @@ function iterate(solver::AbstractLinearSolver, state::MultiThreadingState, activ return state.active, state end -solversolution(state::Union{SequentialState, MultiThreadingState}) = mapreduce(solversolution, hcat, state.states) \ No newline at end of file +solversolution(state::AbstractMatrixSolverState) = mapreduce(solversolution, hcat, state.states) \ No newline at end of file From ef0c58f0fd38ed127420a039b7b47dc233e8be25 Mon Sep 17 00:00:00 2001 From: nHackel Date: Tue, 13 Aug 2024 10:30:29 +0200 Subject: [PATCH 5/6] Fix multi threading state bugs in Kaczmarz and direct solvers --- src/Direct.jl | 12 ++++++------ src/Kaczmarz.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Direct.jl b/src/Direct.jl index e420a3ec..45221a7f 100644 --- a/src/Direct.jl +++ b/src/Direct.jl @@ -41,19 +41,19 @@ function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [L2Regularizati return DirectSolver(A, L2, normalizeReg, other, DirectSolverState(x, b)) end -function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT} +function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) bvecT = similar(b, size(state.b)...) solver.state = DirectSolverState(x, bvecT) init!(solver, solver.state, b; kwargs...) end -function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT +function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.A, b) state.b .= b state.x .= x0 end -function iterate(solver::DirectSolver, state = solver.state) +function iterate(solver::DirectSolver, state::DirectSolverState) A = solver.A λ_ = λ(solver.l2) lufact = lu(A'*A .+ λ_) @@ -138,18 +138,18 @@ function PseudoInverse(A::AbstractMatrix, x, b, l2, norm, proj) return PseudoInverse(temp, l2, norm, proj, DirectSolverState(x, b)) end -function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT} +function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) bvecT = similar(b, size(state.b)...) solver.state = DirectSolverState(x, bvecT) init!(solver, solver.state, b; kwargs...) end -function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT +function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.svd, b) state.b .= b end -function iterate(solver::PseudoInverse, state = solver.state) +function iterate(solver::PseudoInverse, state::DirectSolverState) # Inversion by using the pseudoinverse of the SVD svd = solver.svd diff --git a/src/Kaczmarz.jl b/src/Kaczmarz.jl index 3ddcd979..3584c85c 100644 --- a/src/Kaczmarz.jl +++ b/src/Kaczmarz.jl @@ -164,9 +164,9 @@ end function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:AbstractVector}, RN <: Union{R, AbstractNestedRegularization{<:R}}} - return solver.state.x .* (1 ./ sqrt.(λ(solver.L2))) + return solversolution(solver.state) .* (1 ./ sqrt.(λ(solver.L2))) end -solversolution(solver::Kaczmarz) = solver.state.x +solversolution(solver::Kaczmarz) = solversolution(solver.state) solverconvergence(state::KaczmarzState) = (; :residual => norm(state.vl)) function iterate(solver::Kaczmarz, state::KaczmarzState) From 6b17ba93cbb7a4d4ef381deeef693c80f864a27b Mon Sep 17 00:00:00 2001 From: nHackel Date: Tue, 13 Aug 2024 11:37:19 +0200 Subject: [PATCH 6/6] Fix bug in CGNR with only AHA, b is now used to init x0, similar to FISTA and co --- src/CGNR.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CGNR.jl b/src/CGNR.jl index 939e85b8..18343fb2 100644 --- a/src/CGNR.jl +++ b/src/CGNR.jl @@ -131,7 +131,7 @@ end initCGNR(x₀, A, b) = mul!(x₀, adjoint(A), b) #initCGNR(x₀, prod::ProdOp{T, <:WeightingOp, matT}, b) where {T, matT} = mul!(x₀, adjoint(prod.B), b) -initCGNR(x₀, ::Nothing, b) = x₀ .= one(eltype(x₀)) +initCGNR(x₀, ::Nothing, b) = x₀ .= b solverconvergence(state::CGNRState) = (; :residual => norm(state.x₀))