Skip to content

Commit

Permalink
Fix multi threading state bugs in Kaczmarz and direct solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Aug 13, 2024
1 parent 9b9e47f commit ef0c58f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/Direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [L2Regularizati
return DirectSolver(A, L2, normalizeReg, other, DirectSolverState(x, b))
end

function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT}
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
bvecT = similar(b, size(state.b)...)
solver.state = DirectSolverState(x, bvecT)
init!(solver, solver.state, b; kwargs...)
end
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT
function init!(solver::DirectSolver, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector
solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.A, b)
state.b .= b
state.x .= x0
end

function iterate(solver::DirectSolver, state = solver.state)
function iterate(solver::DirectSolver, state::DirectSolverState)
A = solver.A
λ_ = λ(solver.l2)
lufact = lu(A'*A .+ λ_)
Expand Down Expand Up @@ -138,18 +138,18 @@ function PseudoInverse(A::AbstractMatrix, x, b, l2, norm, proj)
return PseudoInverse(temp, l2, norm, proj, DirectSolverState(x, b))
end

function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT}
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::otherT; kwargs...) where {vecT, otherT <: AbstractVector}
x = similar(b, size(state.x)...)
bvecT = similar(b, size(state.b)...)
solver.state = DirectSolverState(x, bvecT)
init!(solver, solver.state, b; kwargs...)
end
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT
function init!(solver::PseudoInverse, state::DirectSolverState{vecT}, b::vecT; x0=0) where vecT <: AbstractVector
solver.l2 = normalize(solver, solver.normalizeReg, solver.l2, solver.svd, b)
state.b .= b
end

function iterate(solver::PseudoInverse, state = solver.state)
function iterate(solver::PseudoInverse, state::DirectSolverState)
# Inversion by using the pseudoinverse of the SVD
svd = solver.svd

Expand Down
4 changes: 2 additions & 2 deletions src/Kaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ end


function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:AbstractVector}, RN <: Union{R, AbstractNestedRegularization{<:R}}}
return solver.state.x .* (1 ./ sqrt.(λ(solver.L2)))
return solversolution(solver.state) .* (1 ./ sqrt.(λ(solver.L2)))
end
solversolution(solver::Kaczmarz) = solver.state.x
solversolution(solver::Kaczmarz) = solversolution(solver.state)
solverconvergence(state::KaczmarzState) = (; :residual => norm(state.vl))

function iterate(solver::Kaczmarz, state::KaczmarzState)
Expand Down

0 comments on commit ef0c58f

Please sign in to comment.