From ef0c58f0fd38ed127420a039b7b47dc233e8be25 Mon Sep 17 00:00:00 2001 From: nHackel Date: Tue, 13 Aug 2024 10:30:29 +0200 Subject: [PATCH] Fix multi threading state bugs in Kaczmarz and direct solvers --- src/Direct.jl | 12 ++++++------ src/Kaczmarz.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Direct.jl b/src/Direct.jl index e420a3ec..45221a7f 100644 --- a/src/Direct.jl +++ b/src/Direct.jl @@ -41,19 +41,19 @@ function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [L2Regularizati return DirectSolver(A, L2, normalizeReg, other, DirectSolverState(x, b)) end -function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT} +function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) bvecT = similar(b, size(state.b)...) solver.state = DirectSolverState(x, bvecT) init!(solver, solver.state, b; kwargs...) end -function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT +function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.A, b) state.b .= b state.x .= x0 end -function iterate(solver::DirectSolver, state = solver.state) +function iterate(solver::DirectSolver, state::DirectSolverState) A = solver.A λ_ = λ(solver.l2) lufact = lu(A'*A .+ λ_) @@ -138,18 +138,18 @@ function PseudoInverse(A::AbstractMatrix, x, b, l2, norm, proj) return PseudoInverse(temp, l2, norm, proj, DirectSolverState(x, b)) end -function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT} +function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector} x = similar(b, size(state.x)...) bvecT = similar(b, size(state.b)...) solver.state = DirectSolverState(x, bvecT) init!(solver, solver.state, b; kwargs...) end -function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT +function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.svd, b) state.b .= b end -function iterate(solver::PseudoInverse, state = solver.state) +function iterate(solver::PseudoInverse, state::DirectSolverState) # Inversion by using the pseudoinverse of the SVD svd = solver.svd diff --git a/src/Kaczmarz.jl b/src/Kaczmarz.jl index 3ddcd979..3584c85c 100644 --- a/src/Kaczmarz.jl +++ b/src/Kaczmarz.jl @@ -164,9 +164,9 @@ end function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:AbstractVector}, RN <: Union{R, AbstractNestedRegularization{<:R}}} - return solver.state.x .* (1 ./ sqrt.(λ(solver.L2))) + return solversolution(solver.state) .* (1 ./ sqrt.(λ(solver.L2))) end -solversolution(solver::Kaczmarz) = solver.state.x +solversolution(solver::Kaczmarz) = solversolution(solver.state) solverconvergence(state::KaczmarzState) = (; :residual => norm(state.vl)) function iterate(solver::Kaczmarz, state::KaczmarzState)