Improve performance of forward multi patch on GPU
nHackel committed Nov 7, 2024
1 parent 1f805e1 commit 6f3f3f0
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
@@ -42,18 +42,39 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
shared[localIdx] = zero(eltype(b))

# First we iterate over the sparse indices
@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
# The following code does essentially this:
#tmp = zero(eltype(b))
#@unroll for i = localIdx:grid_stride:N
# tmp += sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
#end
#shared[localIdx] = tmp
# We first sum into a temporary variable, hoping that it is kept in a register, since registers are faster than shared memory

# In this variant we additionally try to use multiple registers for independent sums, to get more instruction-level parallelism
tmp = @private eltype(b) 8
@unroll for j = 1:8
tmp[j] = zero(eltype(b))
end
@unroll for i = localIdx:grid_stride*8:N
@unroll for j = 1:8
index = i + (j - 1) * grid_stride
if index <= N
tmp[j] = tmp[j] + sign * S[patch_row, xss[index, patch], smIdx] * x[xcc[index, patch]]
end
end
end
@unroll for j = 1:8
shared[localIdx] += tmp[j]
end
@synchronize

# Now we need to reduce the shared memory to get the final result
full_reduction = grid_stride < N
if full_reduction

# For a full reduction we know s = 1024 and can (manually) unroll our loop
localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
@synchronize
# For a full reduction we know s = 512 and can (manually) unroll our loop
#localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
#@synchronize
localIdx <= 256 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 256])
@synchronize
localIdx <= 128 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 128])
@@ -90,8 +111,8 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
end
end

kernel = dense_mul!(backend, 1024)
kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (1024, size(op, 1)))
kernel = dense_mul!(backend, 512)
kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (512, size(op, 1)))
synchronize(backend)
return b
end
@@ -210,9 +231,11 @@ end
shared = @localmem eltype(energy) grid_stride
shared[localIdx] = zero(eltype(energy))

tmp = zero(eltype(energy))
@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + abs2(sign * S[patch_row, xss[i, patch], smIdx])
tmp += abs2(sign * S[patch_row, xss[i, patch], smIdx])
end
shared[localIdx] = tmp
@synchronize

@private s = div(min(grid_stride, N), Int32(2))
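As a note for readers, the accumulation trick in the first hunk can be illustrated with a minimal CPU sketch. The helper name strided_partial_sum and its nacc keyword are hypothetical, invented for this illustration; the actual kernel keeps its eight partial sums in @private registers and accumulates system-matrix products rather than raw vector entries.

# Hypothetical CPU sketch of the multi-accumulator grid-stride pattern
# used in the kernel above (not part of the commit).
function strided_partial_sum(x::AbstractVector{T}, start::Integer, stride::Integer;
                             nacc::Integer = 8) where {T}
    N = length(x)
    tmp = zeros(T, nacc)           # stands in for the kernel's 8 @private registers
    i = start
    while i <= N
        for j in 1:nacc
            index = i + (j - 1) * stride
            if index <= N
                tmp[j] += x[index] # independent add chains -> more instruction-level parallelism
            end
        end
        i += stride * nacc         # advance by 8 grid strides, matching the kernel
    end
    return sum(tmp)                # combine the partial sums; the kernel writes this into shared memory
end

With start = localIdx and stride = grid_stride this computes the same partial sum as the plain grid-stride loop it replaces; only the order of the floating-point additions changes, so results can differ in the last bits.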

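Similarly, the manually unrolled reduction is a standard shared-memory tree reduction. Below is a single-threaded model of the same idea; the function tree_reduce! is hypothetical, and on the GPU each iteration of the inner loop runs on a separate thread with @synchronize barriers between steps.

# Hypothetical single-threaded model of the unrolled tree reduction
# (not part of the commit). `shared` plays the role of the @localmem
# buffer; its length is the workgroup size (512 after this commit).
function tree_reduce!(shared::AbstractVector{T}) where {T}
    s = length(shared) ÷ 2         # first active offset: 256 for a 512-wide workgroup
    while s >= 1
        for localIdx in 1:s        # on the GPU, each localIdx is a separate thread
            shared[localIdx] += shared[localIdx + s]
        end
        s ÷= 2                     # halve the number of active threads each step
    end
    return shared[1]               # thread 1 ends up holding the full sum
end

Halving the workgroup from 1024 to 512 threads makes the first unrolled step (the offset-512 additions) unnecessary, which is why that step and its barrier are commented out in the hunk above.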