Merge pull request #35 from MagneticParticleImaging/nh/multiGPU
GPU support for sparse single-patch and multi-patch reconstructions
nHackel authored Jul 23, 2024
2 parents e964253 + 62a6813 commit 7b80497
Showing 18 changed files with 454 additions and 126 deletions.
36 changes: 18 additions & 18 deletions .buildkite/pipeline.yml
@@ -17,21 +17,21 @@ steps:
include("test/gpu/cuda.jl")'
timeout_in_minutes: 30

- label: "AMD GPUs -- MPIReco.jl"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
command: |
julia --color=yes --project -e '
using Pkg
Pkg.add("TestEnv")
using TestEnv
TestEnv.activate();
Pkg.add("AMDGPU")
Pkg.instantiate()
include("test/gpu/rocm.jl")'
timeout_in_minutes: 30
# - label: "AMD GPUs -- MPIReco.jl"
# plugins:
# - JuliaCI/julia#v1:
# version: "1.10"
# agents:
# queue: "juliagpu"
# rocm: "*"
# rocmgpu: "*"
# command: |
# julia --color=yes --project -e '
# using Pkg
# Pkg.add("TestEnv")
# using TestEnv
# TestEnv.activate();
# Pkg.add("AMDGPU")
# Pkg.instantiate()
# include("test/gpu/rocm.jl")'
# timeout_in_minutes: 30
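
The CUDA counterpart of this disabled step remains active. A minimal sketch for running the GPU tests locally from a Julia session in the repository root, mirroring the pipeline commands above (assumes NVIDIA hardware; swap CUDA for AMDGPU and cuda.jl for rocm.jl once the ROCm step is re-enabled):

using Pkg
Pkg.add("TestEnv")
using TestEnv
TestEnv.activate()
Pkg.add("CUDA")
Pkg.instantiate()
include("test/gpu/cuda.jl")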
13 changes: 13 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Tobias Knopp <[email protected]>"]
version = "0.6.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
@@ -25,13 +26,17 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
AbstractImageReconstruction = "0.3"
Adapt = "3, 4"
Atomix = "0.1"
DSP = "0.6, 0.7"
Distributed = "1"
DistributedArrays = "0.6"
FFTW = "1.3"
GPUArrays = "8, 9, 10"
ImageUtils = "0.2"
IniFile = "0.5"
JLArrays = "0.1"
KernelAbstractions = "0.8, 0.9"
LinearAlgebra = "1"
LinearOperators = "2.3"
LinearOperatorCollection = "2"
@@ -56,5 +61,13 @@ ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

[targets]
test = ["Test", "HTTP", "FileIO", "LazyArtifacts", "Scratch", "ImageMagick", "ImageQualityIndexes", "Unitful", "JLArrays"]

[extensions]
MPIRecoKernelAbstractionsExt = ["Atomix", "KernelAbstractions", "GPUArrays"]
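
The weakdeps/extensions pair means MPIRecoKernelAbstractionsExt is compiled and loaded only once Atomix, KernelAbstractions, and GPUArrays are all present in the session; loading a GPU backend such as CUDA.jl typically brings in all three. A hedged sketch of checking this (Base.get_extension is the standard query on Julia ≥ 1.9; the transitive-dependency claim about CUDA.jl is an assumption):

using MPIReco
using CUDA   # assumed to pull in GPUArrays, KernelAbstractions, and Atomix transitively

ext = Base.get_extension(MPIReco, :MPIRecoKernelAbstractionsExt)
ext === nothing && @warn "GPU extension not loaded; check that all weak dependencies are available"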
2 changes: 1 addition & 1 deletion config/MultiPatch.toml
@@ -26,7 +26,7 @@ _module = "MPIReco"
_module = "MPIReco"

[parameter.reco.solverParams]
_type = "RecoPlan{SimpleSolverParameters}"
_type = "RecoPlan{ElaborateSolverParameters}"
_module = "MPIReco"

[parameter.pre]
10 changes: 10 additions & 0 deletions ext/MPIRecoKernelAbstractionsExt/MPIRecoKernelAbstractionsExt.jl
@@ -0,0 +1,10 @@
module MPIRecoKernelAbstractionsExt

using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra, MPIReco.RegularizedLeastSquares
using KernelAbstractions, GPUArrays
using KernelAbstractions.Extras: @unroll
using Atomix

include("MultiPatch.jl")

end
143 changes: 143 additions & 0 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
@@ -0,0 +1,143 @@
function Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT <: AbstractGPUArray}
validSMs = all(x -> size(x) == size(op.S[1]), op.S)
validXCC = all(x -> length(x) == length(op.xcc[1]), op.xcc)
validXSS = all(x -> length(x) == length(op.xss[1]), op.xss)

# Ideally we create a DenseMultiPatchOperator on the GPU
if validSMs && validXCC && validXSS
S = adapt(arrT, stack(op.S))
# We want to use Int32 for better GPU performance
xcc = Int32.(adapt(arrT, stack(op.xcc)))
xss = Int32.(adapt(arrT, stack(op.xss)))
sign = Int32.(adapt(arrT, op.sign))
RowToPatch = Int32.(adapt(arrT, op.RowToPatch))
patchToSMIdx = Int32.(adapt(arrT, op.patchToSMIdx))
return DenseMultiPatchOperator(S, op.grid, op.N, op.M, RowToPatch, xcc, xss, sign, Int32(op.nPatches), patchToSMIdx)
else
throw(ArgumentError("Cannot adapt MultiPatchOperator to $arrT, since it cannot be represented as a DenseMultiPatchOperator"))
end
end
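
A minimal usage sketch of this adapt rule, assuming `op` is a previously constructed MultiPatchOperator whose patches share system-matrix sizes (otherwise the ArgumentError above is thrown). JLArray, the CPU-backed test array type from the test dependencies, stands in for a real GPU type such as CuArray:

using JLArrays, Adapt
dop = adapt(JLArray, op)  # DenseMultiPatchOperator: stacked SMs, index arrays narrowed to Int32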

@kernel cpu = false inbounds = true function dense_mul!(b, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each group/block handles a single row of the operator
operator_row = @index(Group, Linear) # k
patch = RowToPatch[operator_row] # p
patch_row = mod1(operator_row, M) # j
smIdx = patchToSMIdx[patch]
sign = eltype(b)(signs[patch_row, smIdx])
grid_stride = prod(@groupsize())
N = Int32(size(xss, 1))

# We want to use a grid-stride loop to perform the sparse matrix-vector product.
# Each thread performs a single element-wise multiplication and reduction in its shared spot.
# Afterwards we reduce over the shared memory.
localIdx = @index(Local, Linear)
shared = @localmem eltype(b) grid_stride
shared[localIdx] = zero(eltype(b))

# First we iterate over the sparse indices
@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
end
@synchronize

# Now we need to reduce the shared memory to get the final result
@private s = div(min(grid_stride, N), Int32(2))
while s > Int32(0)
if localIdx <= s
shared[localIdx] = shared[localIdx] + shared[localIdx + s]
end
s >>= 1
@synchronize
end

# Write the result out to b
if localIdx == 1
b[operator_row] = shared[localIdx]
end
end

function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T, V}, x::AbstractVector{T}) where {T, V <: AbstractGPUArray}
backend = get_backend(b)
kernel = dense_mul!(backend, 256)
kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
synchronize(backend)
return b
end
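
Continuing the sketch above (`dop` from the adapt example), the forward product then runs entirely through the kernel; element type ComplexF32 is an assumption here:

using LinearAlgebra
x = adapt(JLArray, rand(ComplexF32, size(dop, 2)))
b = similar(x, size(dop, 1))
mul!(b, dop, x)  # launches dense_mul! with one 256-thread group per operator row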

@kernel inbounds = true function dense_mul_adj!(res, @Const(t), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each group/block handles a single column of the adjoint(operator)
# i.e. a row of the operator
localIdx = @index(Local, Linear)
groupIdx = @index(Group, Linear) # k
patch = RowToPatch[groupIdx] # p
patch_row = mod1(groupIdx, M) # j
smIdx = patchToSMIdx[patch]
sign = eltype(res)(signs[patch_row, smIdx])
grid_stride = prod(@groupsize())
N = Int32(size(xss, 1))


# Each thread within the block will add the same value of t
val = t[groupIdx]

# Since we go along the columns during a matrix-vector product,
# we have a race condition with other threads writing to the same result.
for i = localIdx:grid_stride:N
tmp = sign * conj(S[patch_row, xss[i, patch], smIdx]) * val
# @atomic is not supported for ComplexF32 numbers
Atomix.@atomic res[1, xcc[i, patch]] += tmp.re
Atomix.@atomic res[2, xcc[i, patch]] += tmp.im
end
end

function LinearAlgebra.mul!(res::AbstractVector{T}, adj::Adjoint{T, OP}, t::AbstractVector{T}) where {T <: Complex, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{T, V}}
backend = get_backend(res)
op = adj.parent
res .= zero(T) # We need to zero the result, because we are using += in the kernel
kernel = dense_mul_adj!(backend, 256)
# We have to reinterpret the result as a real array, because atomic operations on Complex numbers are not supported
kernel(reinterpret(reshape, real(eltype(res)), res), t, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
synchronize(backend)
return res
end
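
The reinterpret trick in isolation: a ComplexF32 vector viewed as a 2×N Float32 array whose rows are the real and imaginary parts, which is what makes the per-component Atomix.@atomic additions above legal:

res = zeros(ComplexF32, 4)
resr = reinterpret(reshape, Float32, res)  # 2×4 view; resr[1, i] is the real part of res[i]
resr[1, 2] += 1f0
resr[2, 2] += 2f0
res[2] == 1f0 + 2f0im  # true: writes through the view land in the complex vector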

# Kaczmarz specific functions
function RegularizedLeastSquares.dot_with_matrix_row(op::DenseMultiPatchOperator{T, V}, x::AbstractArray{T}, k::Int) where {T, V <: AbstractGPUArray}
patch = @allowscalar op.RowToPatch[k]
patch_row = mod1(k, div(op.M,op.nPatches))
smIdx = @allowscalar op.patchToSMIdx[patch]
sign = @allowscalar op.sign[patch_row, smIdx]
S = op.S
# Inplace reduce-broadcast: https://github.com/JuliaLang/julia/pull/31020
return sum(Broadcast.instantiate(Base.broadcasted(view(op.xss, :, patch), view(op.xcc, :, patch)) do xs, xc
@inbounds sign * S[patch_row, xs, smIdx] * x[xc]
end))
end

function RegularizedLeastSquares.rownorm²(op::DenseMultiPatchOperator{T, V}, row::Int64) where {T, V <: AbstractGPUArray}
patch = @allowscalar op.RowToPatch[row]
patch_row = mod1(row, div(op.M,op.nPatches))
smIdx = @allowscalar op.patchToSMIdx[patch]
sign = @allowscalar op.sign[patch_row, smIdx]
S = op.S
return mapreduce(xs -> abs2(sign * S[patch_row, xs, smIdx]), +, view(op.xss, :, patch))
end
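
The fused reduce-broadcast pattern from dot_with_matrix_row, standalone: wrapping Base.broadcasted in Broadcast.instantiate lets sum fuse the element-wise map with the reduction instead of materializing a temporary array (the julia#31020 link above):

xs = Float32[1, 2, 3]; ys = Float32[4, 5, 6]
sum(Broadcast.instantiate(Base.broadcasted(*, xs, ys)))  # 32.0f0, no intermediate array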

@kernel cpu = false function kaczmarz_update_kernel!(x, @Const(S), @Const(row), @Const(beta), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each thread handles one element of the kaczmarz update
idx = @index(Global, Linear)
patch = RowToPatch[row]
patch_row = mod1(row, M)
smIdx = patchToSMIdx[patch]
sign = eltype(x)(signs[patch_row, smIdx])
x[xcc[idx, patch]] += beta * conj(sign * S[patch_row, xss[idx, patch], smIdx])
end

function RegularizedLeastSquares.kaczmarz_update!(op::DenseMultiPatchOperator{T, V}, x::vecT, row, beta) where {T, vecT <: AbstractGPUVector{T}, V <: AbstractGPUArray{T}}
backend = get_backend(x)
kernel = kaczmarz_update_kernel!(backend, 256)
kernel(x, op.S, row, beta, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = size(op.xss, 1))
synchronize(backend)
return x
end
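
Schematically, these three methods plug into the Kaczmarz iteration of RegularizedLeastSquares.jl roughly as below. This is a sketch under stated assumptions, not the library's exact loop: u is the measurement vector, λ the regularization weight, x and dop as in the sketches above, and scalar indexing of u is shown for clarity only:

import RegularizedLeastSquares as RLS
for row in 1:size(dop, 1)
    τ = RLS.dot_with_matrix_row(dop, x, row)         # a_row ⋅ x via the fused broadcast
    β = (u[row] - τ) / (RLS.rownorm²(dop, row) + λ)  # step size for this row
    RLS.kaczmarz_update!(dop, x, row, β)             # x += β * conj(a_row), one thread per entry
end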
26 changes: 17 additions & 9 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchAlgorithm.jl
@@ -1,22 +1,24 @@
export MultiPatchReconstructionAlgorithm, MultiPatchReconstructionParameter
Base.@kwdef struct MultiPatchReconstructionParameter{F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions} <: AbstractMultiPatchReconstructionParameters
Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions, R <: AbstractRegularization} <: AbstractMultiPatchReconstructionParameters
arrayType::Type{matT} = Array
# File
sf::MultiMPIFile
freqFilter::F
opParams::O
ffPos::FF = DefaultFocusFieldPositions()
ffPosSF::FFSF = DefaultFocusFieldPositions()
solverParams::S
λ::Float32
reg::Vector{R} = AbstractRegularization[]
# weightingType::WeightingType = WeightingType.None
end

Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
opParams::Union{AbstractMultiPatchOperatorParameter, Nothing} = nothing
sf::MultiMPIFile
ffOp::Union{Nothing, MultiPatchOperator}
arrayType::Type{matT}
ffOp::Union{Nothing, AbstractMultiPatchOperator}
ffPos::Union{Nothing,AbstractArray}
ffPosSF::Union{Nothing,AbstractArray}
freqs::Vector{CartesianIndex{2}}
@@ -38,7 +40,7 @@ function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:Abstra
ffPosSF = [vec(ffPos(SF))[l] for l=1:L, SF in reco.sf]
end

return MultiPatchReconstructionAlgorithm(params, reco.opParams, reco.sf, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm(params, reco.opParams, reco.sf, reco.arrayType, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{MultiPatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::MultiPatchReconstructionAlgorithm) = algo.origParam
@@ -76,8 +78,14 @@ function process(algo::MultiPatchReconstructionAlgorithm, params::AbstractMultiP

ffPosSF = algo.ffPosSF

return MultiPatchOperator(algo.sf, frequencies; toKwargs(params)...,
gradient = gradient, FFPos = ffPos_, FFPosSF = ffPosSF)
operator = MultiPatchOperator(algo.sf, frequencies; toKwargs(params)..., gradient = gradient, FFPos = ffPos_, FFPosSF = ffPosSF)
return adapt(algo.arrayType, operator)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{A, ProcessResultCache{<:A}}, f::MPIFile, args...) where A <: AbstractMPIPreProcessingParameters
result = process(typeof(algo), params, f, args...)
result = adapt(algo.arrayType, result)
return result
end

function process(t::Type{<:MultiPatchReconstructionAlgorithm}, params::CommonPreProcessingParameters{NoBackgroundCorrectionParameters}, f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
Expand Down Expand Up @@ -114,8 +122,8 @@ function process(::Type{<:MultiPatchReconstructionAlgorithm}, params::ExternalPr
return data
end

function process(algo::MultiPatchReconstructionAlgorithm, params::MultiPatchReconstructionParameter, u::Array)
solver = LeastSquaresParameters(S = algo.ffOp, reg = [L2Regularization(params.λ)], solverParams = params.solverParams)
function process(algo::MultiPatchReconstructionAlgorithm, params::MultiPatchReconstructionParameter, u::AbstractArray)
solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams)

result = process(algo, solver, u)

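With λ replaced by a reg vector and the new arrayType field, constructing the parameters might look like this hedged sketch (sf, freqFilter, opParams, and solverParams are assumed to be configured as before):

params = MultiPatchReconstructionParameter(;
    arrayType = JLArray,  # Array (default) on CPU; CuArray on NVIDIA hardware
    sf = sf, freqFilter = freqFilter, opParams = opParams,
    solverParams = solverParams,
    reg = [L2Regularization(0.1f0)])

The L2 weight that used to live in the dedicated λ::Float32 field now travels inside the regularization terms, matching the solver call above.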
src/Algorithms/MultiPatchAlgorithms/MultiPatchPeriodicMotionAlgorithm.jl
@@ -24,7 +24,7 @@ end
function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:PeriodicMotionPreProcessing,<:PeriodicMotionReconstructionParameter,<:AbstractMPIPostProcessingParameters})
reco = params.reco
freqs = process(MultiPatchReconstructionAlgorithm, reco.freqFilter, reco.sf)
return MultiPatchReconstructionAlgorithm(params, nothing, reco.sf, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm(params, nothing, reco.sf, Array, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
end

function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorithm{MultiPatchParameters{PT, R, T}}, data::MPIFile) where {R, T, PT <: PeriodicMotionPreProcessing}
20 changes: 11 additions & 9 deletions src/Algorithms/SinglePatchAlgorithms/SinglePatchAlgorithm.jl
@@ -1,19 +1,21 @@
Base.@kwdef struct SinglePatchReconstructionParameter{L<:AbstractSystemMatrixLoadingParameter, SL<:AbstractLinearSolver,
SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
matT <: AbstractArray, SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
# File
sf::MPIFile
sfLoad::Union{L, ProcessResultCache{L}}
arrayType::Type{matT} = Array
# Solver
solverParams::SP
reg::Vector{R} = AbstractRegularization[]
weightingParams::W = NoWeightingParameters()
end

Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
sf::Union{MPIFile, Vector{MPIFile}}
S::AbstractArray
arrayType::Type{matT}
grid::RegularGridPositions
freqs::Vector{CartesianIndex{2}}
output::Channel{Any}
@@ -23,16 +25,16 @@ function SinglePatchReconstruction(params::SinglePatchParameters{<:AbstractMPIPr
return SinglePatchReconstructionAlgorithm(params)
end
function SinglePatchReconstructionAlgorithm(params::SinglePatchParameters{<:AbstractMPIPreProcessingParameters, R, PT}) where {R<:AbstractSinglePatchReconstructionParameters, PT <:AbstractMPIPostProcessingParameters}
freqs, S, grid = prepareSystemMatrix(params.reco)
return SinglePatchReconstructionAlgorithm(params, params.reco.sf, S, grid, freqs, Channel{Any}(Inf))
freqs, S, grid, arrayType = prepareSystemMatrix(params.reco)
return SinglePatchReconstructionAlgorithm(params, params.reco.sf, S, arrayType, grid, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{SinglePatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::SinglePatchReconstructionAlgorithm) = algo.params

function prepareSystemMatrix(reco::SinglePatchReconstructionParameter{L,S}) where {L<:AbstractSystemMatrixLoadingParameter, S<:AbstractLinearSolver}
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf)
sf, grid = prepareSF(S, sf, grid)
return freqs, sf, grid
sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, S, sf, grid, reco.arrayType)
return freqs, sf, grid, reco.arrayType
end


@@ -44,13 +46,13 @@ function process(algo::SinglePatchReconstructionAlgorithm, params::Union{A, Proc
@warn "System matrix and measurement have different element data type. Mapping measurment data to system matrix element type."
result = map(eltype(algo.S),result)
end
result = copyto!(similar(algo.S, size(result)...), result)
result = adapt(algo.arrayType, result)
return result
end


function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter, u)
weights = process(algo, params.weightingParams, u)
weights = adapt(algo.arrayType, process(algo, params.weightingParams, u))

B = getLinearOperator(algo, params)

@@ -68,5 +70,5 @@ function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::Sin
end

function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter{<:SparseSystemMatrixLoadingParameter, S}) where {S}
return process(algo, params.sfLoad, eltype(algo.S), tuple(shape(algo.grid)...))
return process(algo, params.sfLoad, eltype(algo.S), algo.arrayType, tuple(shape(algo.grid)...))
end
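
The single-patch side threads arrayType through in the same way; a hedged construction sketch (sfLoad and solverParams assumed configured as usual):

params = SinglePatchReconstructionParameter(;
    sf = sf, sfLoad = sfLoad,
    arrayType = JLArray,  # passed through to the system matrix, weights, and measurement data
    solverParams = solverParams,
    reg = [L2Regularization(0.01f0)])

prepareSystemMatrix, the preprocessing adapt call, and the sparse operator construction above then all allocate on the selected array type.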