Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU support for sparse single patch and for multi patch reconstructions #35

Merged
merged 18 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ steps:
include("test/gpu/cuda.jl")'
timeout_in_minutes: 30

- label: "AMD GPUs -- MPIReco.jl"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
command: |
julia --color=yes --project -e '
using Pkg
Pkg.add("TestEnv")
using TestEnv
TestEnv.activate();
Pkg.add("AMDGPU")
Pkg.instantiate()
include("test/gpu/rocm.jl")'
timeout_in_minutes: 30
# - label: "AMD GPUs -- MPIReco.jl"
# plugins:
# - JuliaCI/julia#v1:
# version: "1.10"
# agents:
# queue: "juliagpu"
# rocm: "*"
# rocmgpu: "*"
# command: |
# julia --color=yes --project -e '
# using Pkg
# Pkg.add("TestEnv")
# using TestEnv
# TestEnv.activate();
# Pkg.add("AMDGPU")
# Pkg.instantiate()
# include("test/gpu/rocm.jl")'
# timeout_in_minutes: 30
13 changes: 13 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Tobias Knopp <[email protected]>"]
version = "0.6.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
Expand All @@ -25,13 +26,17 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
AbstractImageReconstruction = "0.3"
Adapt = "3, 4"
Atomix = "0.1"
DSP = "0.6, 0.7"
Distributed = "1"
DistributedArrays = "0.6"
FFTW = "1.3"
GPUArrays = "8, 9, 10"
ImageUtils = "0.2"
IniFile = "0.5"
JLArrays = "0.1"
KernelAbstractions = "0.8, 0.9"
LinearAlgebra = "1"
LinearOperators = "2.3"
LinearOperatorCollection = "2"
Expand All @@ -56,5 +61,13 @@ ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

[targets]
test = ["Test", "HTTP", "FileIO", "LazyArtifacts", "Scratch", "ImageMagick", "ImageQualityIndexes", "Unitful", "JLArrays"]

[extensions]
MPIRecoKernelAbstractionsExt = ["Atomix","KernelAbstractions", "GPUArrays"]
2 changes: 1 addition & 1 deletion config/MultiPatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ _module = "MPIReco"
_module = "MPIReco"

[parameter.reco.solverParams]
_type = "RecoPlan{SimpleSolverParameters}"
_type = "RecoPlan{ElaborateSolverParameters}"
_module = "MPIReco"

[parameter.pre]
Expand Down
10 changes: 10 additions & 0 deletions ext/MPIRecoKernelAbstractionsExt/MPIRecoKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Package extension: GPU support for MPIReco via KernelAbstractions.jl.
# Loaded when Atomix, KernelAbstractions and GPUArrays are available (see the
# [extensions] entry in Project.toml). Provides adapt rules and GPU kernels
# for the dense multi-patch operator (see MultiPatch.jl).
module MPIRecoKernelAbstractionsExt

using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra, MPIReco.RegularizedLeastSquares
using KernelAbstractions, GPUArrays
using KernelAbstractions.Extras: @unroll
using Atomix

include("MultiPatch.jl")

end
143 changes: 143 additions & 0 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
    Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT <: AbstractGPUArray}

Convert a `MultiPatchOperator` into a `DenseMultiPatchOperator` backed by GPU
arrays of type `arrT`. This is only possible when every patch has a
system matrix of identical size and sparse index vectors of identical length;
otherwise an `ArgumentError` is thrown.
"""
function Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT <: AbstractGPUArray}
  # A dense stacked representation requires uniform per-patch dimensions.
  uniformSM = all(sm -> size(sm) == size(op.S[1]), op.S)
  uniformXCC = all(v -> length(v) == length(op.xcc[1]), op.xcc)
  uniformXSS = all(v -> length(v) == length(op.xss[1]), op.xss)

  if !(uniformSM && uniformXCC && uniformXSS)
    throw(ArgumentError("Cannot adapt MultiPatchOperator to $arrT, since it cannot be represented as a DenseMultiPatchOperator"))
  end

  # Stack the per-patch data into dense arrays and move them to the GPU.
  # Index/sign arrays are narrowed to Int32 for better GPU performance.
  stackedS = adapt(arrT, stack(op.S))
  xcc = Int32.(adapt(arrT, stack(op.xcc)))
  xss = Int32.(adapt(arrT, stack(op.xss)))
  signs = Int32.(adapt(arrT, op.sign))
  rowToPatch = Int32.(adapt(arrT, op.RowToPatch))
  patchToSMIdx = Int32.(adapt(arrT, op.patchToSMIdx))
  return DenseMultiPatchOperator(stackedS, op.grid, op.N, op.M, rowToPatch, xcc, xss, signs, Int32(op.nPatches), patchToSMIdx)
end

# Forward matrix-vector kernel: b = op * x for a DenseMultiPatchOperator.
# One workgroup per operator row; threads of the group grid-stride over the
# row's sparse entries, accumulate partial sums in shared memory, then
# tree-reduce. Assumes the workgroup size (grid_stride) is a power of two,
# which holds for the 256-thread launch in `mul!`.
@kernel cpu = false inbounds = true function dense_mul!(b, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
  # Each group/block handles a single row of the operator
  operator_row = @index(Group, Linear) # k
  patch = RowToPatch[operator_row] # p
  patch_row = mod1(operator_row, M) # j
  smIdx = patchToSMIdx[patch]
  sign = eltype(b)(signs[patch_row, smIdx])
  grid_stride = prod(@groupsize())
  N = Int32(size(xss, 1))

  # We want to use a grid-stride loop to perform the sparse matrix-vector product.
  # Each thread performs a single element-wise multiplication and reduction in its shared spot.
  # Afterwards we reduce over the shared memory.
  localIdx = @index(Local, Linear)
  shared = @localmem eltype(b) grid_stride
  shared[localIdx] = zero(eltype(b))

  # First we iterate over the sparse indices
  @unroll for i = localIdx:grid_stride:N
    shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
  end
  @synchronize

  # Tree reduction over shared memory to get the final result.
  # BUGFIX: previously the reduction started at div(min(grid_stride, N), 2),
  # which drops trailing entries whenever min(grid_stride, N) is not a power
  # of two (e.g. N = 3 never adds shared[3]). Starting at grid_stride ÷ 2 is
  # correct because shared is zero-initialized over the full workgroup, and
  # grid_stride is a power of two by construction of the launch.
  @private s = div(grid_stride, Int32(2))
  while s > Int32(0)
    if localIdx <= s
      shared[localIdx] = shared[localIdx] + shared[localIdx + s]
    end
    s >>= 1
    @synchronize
  end

  # Thread 1 writes the reduced row result out to b
  if localIdx == 1
    b[operator_row] = shared[localIdx]
  end
end

"""
    mul!(b, op::DenseMultiPatchOperator, x)

GPU forward product `b = op * x`. Launches one 256-thread workgroup per
operator row and blocks until the kernel has finished.
"""
function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T, V}, x::AbstractVector{T}) where {T, V <: AbstractGPUArray}
  backend = get_backend(b)
  # Rows per patch: the kernel uses this to recover the patch-local row index.
  rowsPerPatch = Int32(div(op.M, op.nPatches))
  rowKernel = dense_mul!(backend, 256)
  rowKernel(b, x, op.S, op.xcc, op.xss, op.sign, rowsPerPatch, op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
  synchronize(backend)
  return b
end

# Adjoint matrix-vector kernel: accumulates adjoint(op) * t into `res`.
# `res` is a 2×N real view of the complex result vector (see the adjoint
# `mul!` below); real and imaginary parts are updated with separate atomics.
# NOTE(review): unlike dense_mul!, this kernel is not marked `cpu = false` —
# presumably to keep a CPU fallback available; confirm this is intentional.
@kernel inbounds = true function dense_mul_adj!(res, @Const(t), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
  # Each group/block handles a single column of the adjoint(operator)
  # i.e. a row of the operator
  localIdx = @index(Local, Linear)
  groupIdx = @index(Group, Linear) # k
  patch = RowToPatch[groupIdx] # p
  patch_row = mod1(groupIdx, M) # j
  smIdx = patchToSMIdx[patch]
  sign = eltype(res)(signs[patch_row, smIdx])
  grid_stride = prod(@groupsize())
  N = Int32(size(xss, 1))


  # Each thread within the block will add the same value of t
  val = t[groupIdx]

  # Since we go along the columns during a matrix-vector product,
  # we have a race condition with other threads writing to the same result.
  for i = localIdx:grid_stride:N
    tmp = sign * conj(S[patch_row, xss[i, patch], smIdx]) * val
    # @atomic is not supported for ComplexF32 numbers, so the real and
    # imaginary components are accumulated separately into the real view.
    Atomix.@atomic res[1, xcc[i, patch]] += tmp.re
    Atomix.@atomic res[2, xcc[i, patch]] += tmp.im
  end
end

"""
    mul!(res, adj::Adjoint{T, <:DenseMultiPatchOperator}, t)

GPU adjoint product `res = adjoint(op) * t` for complex element types.
`res` is zeroed first because the kernel accumulates with atomic `+=`.
"""
function LinearAlgebra.mul!(res::AbstractVector{T}, adj::Adjoint{T, OP}, t::AbstractVector{T}) where {T <: Complex, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{T, V}}
  op = adj.parent
  backend = get_backend(res)
  # The kernel uses atomic accumulation, so start from a zeroed result.
  fill!(res, zero(T))
  # Atomic operations on complex numbers are unsupported: expose `res` to the
  # kernel as a (2, length(res)) real array and update re/im separately.
  realRes = reinterpret(reshape, real(eltype(res)), res)
  adjKernel = dense_mul_adj!(backend, 256)
  adjKernel(realRes, t, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
  synchronize(backend)
  return res
end

# Kaczmarz specific functions

# Dot product of operator row `k` with vector `x`, without materializing the
# row. Scalar row metadata (patch, system-matrix index, sign) is read with
# @allowscalar since these are single-element GPU reads.
function RegularizedLeastSquares.dot_with_matrix_row(op::DenseMultiPatchOperator{T, V}, x::AbstractArray{T}, k::Int) where {T, V <: AbstractGPUArray}
  patch = @allowscalar op.RowToPatch[k]
  patch_row = mod1(k, div(op.M,op.nPatches))
  smIdx = @allowscalar op.patchToSMIdx[patch]
  sign = @allowscalar op.sign[patch_row, smIdx]
  S = op.S
  # Inplace reduce-broadcast: https://github.com/JuliaLang/julia/pull/31020
  # The lazy broadcast is summed directly on the GPU, avoiding a temporary
  # array for the element-wise products.
  return sum(Broadcast.instantiate(Base.broadcasted(view(op.xss, :, patch), view(op.xcc, :, patch)) do xs, xc
    @inbounds sign * S[patch_row, xs, smIdx] * x[xc]
  end))
end

"""
    rownorm²(op::DenseMultiPatchOperator, row)

Squared 2-norm of operator row `row`, computed on the GPU without
materializing the row.
"""
# CONSISTENCY FIX: the row index was annotated `row::Int64`, while the sibling
# `dot_with_matrix_row` uses `k::Int`. `Int` is the idiomatic (and on 32-bit
# platforms the only dispatch-matching) annotation for native integer indices.
function RegularizedLeastSquares.rownorm²(op::DenseMultiPatchOperator{T, V}, row::Int) where {T, V <: AbstractGPUArray}
  # Scalar metadata reads require @allowscalar on GPU arrays.
  patch = @allowscalar op.RowToPatch[row]
  patch_row = mod1(row, div(op.M,op.nPatches))
  smIdx = @allowscalar op.patchToSMIdx[patch]
  sign = @allowscalar op.sign[patch_row, smIdx]
  S = op.S
  # Sum of |sign * S[row entry]|² over the sparse column indices of this patch.
  return mapreduce(xs -> abs2(sign * S[patch_row, xs, smIdx]), +, view(op.xss, :, patch))
end

# Kaczmarz row update: x[xcc[idx, patch]] += beta * conj(row entry), with one
# thread per sparse entry of operator row `row`.
# NOTE(review): threads write to distinct x entries only if xcc[:, patch]
# contains no duplicate indices — presumably guaranteed by construction of the
# operator; confirm, since duplicates would race without atomics.
@kernel cpu = false function kaczmarz_update_kernel!(x, @Const(S), @Const(row), @Const(beta), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
  # Each thread handles one element of the kaczmarz update
  idx = @index(Global, Linear)
  patch = RowToPatch[row]
  patch_row = mod1(row, M)
  smIdx = patchToSMIdx[patch]
  sign = eltype(x)(signs[patch_row, smIdx])
  x[xcc[idx, patch]] += beta * conj(sign * S[patch_row, xss[idx, patch], smIdx])
end

"""
    kaczmarz_update!(op::DenseMultiPatchOperator, x, row, beta)

Apply the Kaczmarz update for operator row `row` with step `beta` to the GPU
vector `x` in place. Launches one thread per sparse entry of the row.
"""
function RegularizedLeastSquares.kaczmarz_update!(op::DenseMultiPatchOperator{T, V}, x::vecT, row, beta) where {T, vecT <: AbstractGPUVector{T}, V <: AbstractGPUArray{T}}
  backend = get_backend(x)
  rowsPerPatch = Int32(div(op.M, op.nPatches))
  updateKernel = kaczmarz_update_kernel!(backend, 256)
  updateKernel(x, op.S, row, beta, op.xcc, op.xss, op.sign, rowsPerPatch, op.RowToPatch, op.patchToSMIdx; ndrange = size(op.xss, 1))
  synchronize(backend)
  return x
end
26 changes: 17 additions & 9 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchAlgorithm.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
export MultiPatchReconstructionAlgorithm, MultiPatchReconstructionParameter
Base.@kwdef struct MultiPatchReconstructionParameter{F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions} <: AbstractMultiPatchReconstructionParameters
Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions, R <: AbstractRegularization} <: AbstractMultiPatchReconstructionParameters
arrayType::Type{matT} = Array
# File
sf::MultiMPIFile
freqFilter::F
opParams::O
ffPos::FF = DefaultFocusFieldPositions()
ffPosSF::FFSF = DefaultFocusFieldPositions()
solverParams::S
λ::Float32
reg::Vector{R} = AbstractRegularization[]
# weightingType::WeightingType = WeightingType.None
end

Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
opParams::Union{AbstractMultiPatchOperatorParameter, Nothing} = nothing
sf::MultiMPIFile
ffOp::Union{Nothing, MultiPatchOperator}
arrayType::Type{matT}
ffOp::Union{Nothing, AbstractMultiPatchOperator}
ffPos::Union{Nothing,AbstractArray}
ffPosSF::Union{Nothing,AbstractArray}
freqs::Vector{CartesianIndex{2}}
Expand All @@ -38,7 +40,7 @@ function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:Abstra
ffPosSF = [vec(ffPos(SF))[l] for l=1:L, SF in reco.sf]
end

return MultiPatchReconstructionAlgorithm(params, reco.opParams, reco.sf, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm(params, reco.opParams, reco.sf, reco.arrayType, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{MultiPatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::MultiPatchReconstructionAlgorithm) = algo.origParam
Expand Down Expand Up @@ -76,8 +78,14 @@ function process(algo::MultiPatchReconstructionAlgorithm, params::AbstractMultiP

ffPosSF = algo.ffPosSF

return MultiPatchOperator(algo.sf, frequencies; toKwargs(params)...,
gradient = gradient, FFPos = ffPos_, FFPosSF = ffPosSF)
operator = MultiPatchOperator(algo.sf, frequencies; toKwargs(params)..., gradient = gradient, FFPos = ffPos_, FFPosSF = ffPosSF)
return adapt(algo.arrayType, operator)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{A, ProcessResultCache{<:A}}, f::MPIFile, args...) where A <: AbstractMPIPreProcessingParameters
result = process(typeof(algo), params, f, args...)
result = adapt(algo.arrayType, result)
return result
end

function process(t::Type{<:MultiPatchReconstructionAlgorithm}, params::CommonPreProcessingParameters{NoBackgroundCorrectionParameters}, f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
Expand Down Expand Up @@ -114,8 +122,8 @@ function process(::Type{<:MultiPatchReconstructionAlgorithm}, params::ExternalPr
return data
end

function process(algo::MultiPatchReconstructionAlgorithm, params::MultiPatchReconstructionParameter, u::Array)
solver = LeastSquaresParameters(S = algo.ffOp, reg = [L2Regularization(params.λ)], solverParams = params.solverParams)
function process(algo::MultiPatchReconstructionAlgorithm, params::MultiPatchReconstructionParameter, u::AbstractArray)
solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams)

result = process(algo, solver, u)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end
function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:PeriodicMotionPreProcessing,<:PeriodicMotionReconstructionParameter,<:AbstractMPIPostProcessingParameters})
reco = params.reco
freqs = process(MultiPatchReconstructionAlgorithm, reco.freqFilter, reco.sf)
return MultiPatchReconstructionAlgorithm(params, nothing, reco.sf, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm(params, nothing, reco.sf, Array, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
end

function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorithm{MultiPatchParameters{PT, R, T}}, data::MPIFile) where {R, T, PT <: PeriodicMotionPreProcessing}
Expand Down
20 changes: 11 additions & 9 deletions src/Algorithms/SinglePatchAlgorithms/SinglePatchAlgorithm.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
Base.@kwdef struct SinglePatchReconstructionParameter{L<:AbstractSystemMatrixLoadingParameter, SL<:AbstractLinearSolver,
SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
matT <: AbstractArray, SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
# File
sf::MPIFile
sfLoad::Union{L, ProcessResultCache{L}}
arrayType::Type{matT} = Array
# Solver
solverParams::SP
reg::Vector{R} = AbstractRegularization[]
weightingParams::W = NoWeightingParameters()
end

Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
sf::Union{MPIFile, Vector{MPIFile}}
S::AbstractArray
arrayType::Type{matT}
grid::RegularGridPositions
freqs::Vector{CartesianIndex{2}}
output::Channel{Any}
Expand All @@ -23,16 +25,16 @@ function SinglePatchReconstruction(params::SinglePatchParameters{<:AbstractMPIPr
return SinglePatchReconstructionAlgorithm(params)
end
function SinglePatchReconstructionAlgorithm(params::SinglePatchParameters{<:AbstractMPIPreProcessingParameters, R, PT}) where {R<:AbstractSinglePatchReconstructionParameters, PT <:AbstractMPIPostProcessingParameters}
freqs, S, grid = prepareSystemMatrix(params.reco)
return SinglePatchReconstructionAlgorithm(params, params.reco.sf, S, grid, freqs, Channel{Any}(Inf))
freqs, S, grid, arrayType = prepareSystemMatrix(params.reco)
return SinglePatchReconstructionAlgorithm(params, params.reco.sf, S, arrayType, grid, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{SinglePatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::SinglePatchReconstructionAlgorithm) = algo.params

function prepareSystemMatrix(reco::SinglePatchReconstructionParameter{L,S}) where {L<:AbstractSystemMatrixLoadingParameter, S<:AbstractLinearSolver}
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf)
sf, grid = prepareSF(S, sf, grid)
return freqs, sf, grid
sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, S, sf, grid, reco.arrayType)
return freqs, sf, grid, reco.arrayType
end


Expand All @@ -44,13 +46,13 @@ function process(algo::SinglePatchReconstructionAlgorithm, params::Union{A, Proc
@warn "System matrix and measurement have different element data type. Mapping measurment data to system matrix element type."
result = map(eltype(algo.S),result)
end
result = copyto!(similar(algo.S, size(result)...), result)
result = adapt(algo.arrayType, result)
return result
end


function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter, u)
weights = process(algo, params.weightingParams, u)
weights = adapt(algo.arrayType, process(algo, params.weightingParams, u))

B = getLinearOperator(algo, params)

Expand All @@ -68,5 +70,5 @@ function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::Sin
end

function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter{<:SparseSystemMatrixLoadingParameter, S}) where {S}
return process(algo, params.sfLoad, eltype(algo.S), tuple(shape(algo.grid)...))
return process(algo, params.sfLoad, eltype(algo.S), algo.arrayType, tuple(shape(algo.grid)...))
end
Loading
Loading