Skip to content

Commit

Permalink
Increased block size for multi patch gpu operator
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Oct 10, 2024
1 parent a8a2d6a commit 1f805e1
Showing 1 changed file with 74 additions and 43 deletions.
117 changes: 74 additions & 43 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,49 +18,80 @@ function Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT
end
end

@kernel cpu = false inbounds = true function dense_mul!(b, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each group/block handles a single row of the operator
operator_row = @index(Group, Linear) # k
patch = RowToPatch[operator_row] # p
patch_row = mod1(operator_row, M) # j
smIdx = patchToSMIdx[patch]
sign = eltype(b)(signs[patch_row, smIdx])
grid_stride = prod(@groupsize())
N = Int32(size(xss, 1))

# We want to use a grid-stride loop to perform the sparse matrix-vector product.
# Each thread performs a single element-wise multiplication and reduction in its shared spot.
# Afterwards we reduce over the shared memory.
localIdx = @index(Local, Linear)
shared = @localmem eltype(b) grid_stride
shared[localIdx] = zero(eltype(b))

# First we iterate over the sparse indices
@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
end
@synchronize
function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T, V}, x::AbstractVector{T}) where {T, V <: AbstractGPUArray}
backend = get_backend(b)

# Now we need to reduce the shared memory to get the final result
@private s = div(min(grid_stride, N), Int32(2))
while s > Int32(0)
if localIdx <= s
shared[localIdx] = shared[localIdx] + shared[localIdx + s]

# Define kernel within the mul! function to make sure we know the workgroup size for manual unrolling
@kernel cpu = false inbounds = true function dense_mul!(b, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each group/block handles a single row of the operator
operator_row = @index(Group, Linear) # k
patch = RowToPatch[operator_row] # p
patch_row = mod1(operator_row, M) # j
smIdx = patchToSMIdx[patch]
sign = eltype(b)(signs[patch_row, smIdx])
grid_stride = prod(@groupsize())
N = Int32(size(xss, 1))

# We want to use a grid-stride loop to perform the sparse matrix-vector product.
# Each thread performs a single element-wise multiplication and reduction in its shared spot.
# Afterwards we reduce over the shared memory.
localIdx = @index(Local, Linear)
shared = @localmem eltype(b) grid_stride
shared[localIdx] = zero(eltype(b))

# First we iterate over the sparse indices
@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
end
s >>= 1
@synchronize

# Now we need to reduce the shared memory to get the final result
full_reduction = grid_stride < N
if full_reduction

# For a full reduction we know s = 1024 and can (manually) unroll our loop
localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
@synchronize
localIdx <= 256 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 256])
@synchronize
localIdx <= 128 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 128])
@synchronize
localIdx <= 64 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 64])
@synchronize
localIdx <= 32 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 32])
@synchronize
localIdx <= 16 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 16])
@synchronize
localIdx <= 8 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 8])
@synchronize
localIdx <= 4 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 4])
@synchronize
localIdx <= 2 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 2])
@synchronize
localIdx == 1 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 1])


else
@private s = div(min(grid_stride, N), Int32(2))
while s > Int32(0)
if localIdx <= s
shared[localIdx] = shared[localIdx] + shared[localIdx + s]
end
s >>= 1
@synchronize
end
end

# Write the result out to b
if localIdx == 1
b[operator_row] = shared[localIdx]
end
end

# Write the result out to b
if localIdx == 1
b[operator_row] = shared[localIdx]
end
end

function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T, V}, x::AbstractVector{T}) where {T, V <: AbstractGPUArray}
backend = get_backend(b)
kernel = dense_mul!(backend, 256)
kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
kernel = dense_mul!(backend, 1024)
kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (1024, size(op, 1)))
synchronize(backend)
return b
end
Expand All @@ -83,7 +114,7 @@ end

# Since we go along the columns during a matrix-vector product,
# we have a race condition with other threads writing to the same result.
for i = localIdx:grid_stride:N
@unroll for i = localIdx:grid_stride:N
tmp = sign * conj(S[patch_row, xss[i, patch], smIdx]) * val
# @atomic is not supported for ComplexF32 numbers
Atomix.@atomic res[1, xcc[i, patch]] += tmp.re
Expand All @@ -95,9 +126,9 @@ function LinearAlgebra.mul!(res::AbstractVector{T}, adj::Adjoint{T, OP}, t::Abst
backend = get_backend(res)
op = adj.parent
res .= zero(T) # We need to zero the result, because we are using += in the kernel
kernel = dense_mul_adj!(backend, 256)
kernel = dense_mul_adj!(backend, 1024)
# We have to reinterpret the result as a real array, because atomic operations on Complex numbers are not supported
kernel(reinterpret(reshape, real(eltype(res)), res), t, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
kernel(reinterpret(reshape, real(eltype(res)), res), t, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (1024, size(op, 1)))
synchronize(backend)
return res
end
Expand Down Expand Up @@ -136,7 +167,7 @@ end

function RegularizedLeastSquares.kaczmarz_update!(op::DenseMultiPatchOperator{T, V}, x::vecT, row, beta) where {T, vecT <: AbstractGPUVector{T}, V <: AbstractGPUArray{T}}
backend = get_backend(x)
kernel = kaczmarz_update_kernel!(backend, 256)
kernel = kaczmarz_update_kernel!(backend, 1024)
kernel(x, op.S, row, beta, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = size(op.xss, 1))
synchronize(backend)
return x
Expand All @@ -157,9 +188,9 @@ end

function normalize_dense_op(op::DenseMultiPatchOperator{T, V}, weights) where {T, V <: AbstractGPUArray{T}}
backend = get_backend(op.S)
kernel = normalize_kernel!(backend, 256)
kernel = normalize_kernel!(backend, 1024)
energy = KernelAbstractions.zeros(backend, real(eltype(op)), size(op, 1))
kernel(energy, weights, op.S, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
kernel(energy, weights, op.S, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (1024, size(op, 1)))
synchronize(backend)
return energy
end
Expand Down

0 comments on commit 1f805e1

Please sign in to comment.