From a7c4cf1e1268005d959035bacf4e15bb0a171f03 Mon Sep 17 00:00:00 2001 From: nHackel Date: Tue, 13 Aug 2024 09:58:14 +0200 Subject: [PATCH] Use type hierachy for matrix iteration --- src/MultiThreading.jl | 6 +++--- test/testMultiThreading.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/MultiThreading.jl b/src/MultiThreading.jl index 383f8b39..27a77149 100644 --- a/src/MultiThreading.jl +++ b/src/MultiThreading.jl @@ -1,4 +1,4 @@ -export SequentialState, MultiThreadingState +export SequentialState, MultiThreadingState, prepareMultiStates abstract type AbstractMatrixSolverState{S} <: AbstractSolverState{S} end mutable struct SequentialState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S} states::Vector{ST} @@ -30,14 +30,14 @@ function prepareMultiStates(solver::AbstractLinearSolver, state::AbstractSolverS end prepareMultiStates(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix) = prepareMultiStates(solver, first(state.states), b) -function init!(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix; kwargs...) +function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractMatrix; kwargs...) for (i, s) in enumerate(state.states) init!(solver, s, b[:, i]; kwargs...) end state.active .= true end -function iterate(solver::S, state::Union{SequentialState, MultiThreadingState}) where {S <: AbstractLinearSolver} +function iterate(solver::S, state::AbstractMatrixSolverState) where {S <: AbstractLinearSolver} activeIdx = findall(state.active) if isempty(activeIdx) return nothing diff --git a/test/testMultiThreading.jl b/test/testMultiThreading.jl index 48419576..465bd9d9 100644 --- a/test/testMultiThreading.jl +++ b/test/testMultiThreading.jl @@ -3,7 +3,7 @@ function testMultiThreadingSolver(; arrayType = Array, scheduler = MultiDataStat x = rand(ComplexF32, 2, 4) b = A * x - solvers = [CGNR] # linearSolverList() + solvers = linearSolverList() @testset "$(solvers[i])" for i = 1:length(solvers) S = createLinearSolver(solvers[i], arrayType(A), iterations = 100)