From 6f3f3f0c6c1c6880be06d7f95836a4baf46da1e6 Mon Sep 17 00:00:00 2001
From: nHackel
Date: Thu, 7 Nov 2024 10:23:27 +0100
Subject: [PATCH] Improve performance of forward multi patch on GPU

---
 .../MultiPatch.jl | 39 +++++++++++++++----
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl b/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
index 0053a43..765e23b 100644
--- a/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
+++ b/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
@@ -42,8 +42,29 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
     shared[localIdx] = zero(eltype(b))
 
     # First we iterate over the sparse indices
-    @unroll for i = localIdx:grid_stride:N
-      shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
+    # The following code does essentially this:
+    #tmp = zero(eltype(b))
+    #@unroll for i = localIdx:grid_stride:N
+    #  tmp += sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
+    #end
+    #shared[localIdx] = tmp
+    # We first sum into a temp variable, hoping that it is accumulated in a register, since registers are faster than shared memory.
+
+    # In this variant we further try to use multiple registers for independent sums, to get more instruction-level parallelism.
+    tmp = @private eltype(b) 8
+    @unroll for j = 1:8
+      tmp[j] = zero(eltype(b))
+    end
+    @unroll for i = localIdx:grid_stride*8:N
+      @unroll for j = 1:8
+        index = i + (j - 1) * grid_stride
+        if index <= N
+          tmp[j] = tmp[j] + sign * S[patch_row, xss[index, patch], smIdx] * x[xcc[index, patch]]
+        end
+      end
+    end
+    @unroll for j = 1:8
+      shared[localIdx] += tmp[j]
     end
 
     @synchronize
@@ -51,9 +72,9 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
     full_reduction = grid_stride < N
 
     if full_reduction
-      # For a full reduction we know s = 1024 and can (manually) unroll our loop
-      localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
-      @synchronize
+      # For a full reduction we know s = 512 and can (manually) unroll our loop
+      #localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
+      #@synchronize
       localIdx <= 256 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 256])
       @synchronize
       localIdx <= 128 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 128])
@@ -90,8 +111,8 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
     end
   end
 
-  kernel = dense_mul!(backend, 1024)
-  kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (1024, size(op, 1)))
+  kernel = dense_mul!(backend, 512)
+  kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (512, size(op, 1)))
   synchronize(backend)
   return b
 end
@@ -210,9 +231,11 @@ end
   shared = @localmem eltype(energy) grid_stride
   shared[localIdx] = zero(eltype(energy))
 
+  tmp = zero(eltype(energy))
   @unroll for i = localIdx:grid_stride:N
-    shared[localIdx] = shared[localIdx] + abs2(sign * S[patch_row, xss[i, patch], smIdx])
+    tmp += abs2(sign * S[patch_row, xss[i, patch], smIdx])
   end
+  shared[localIdx] = tmp
   @synchronize
 
   @private s = div(min(grid_stride, N), Int32(2))
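
Note on the accumulation pattern: the kernel now keeps eight independent partial sums in a @private array and writes their total to shared memory only once. The sketch below is a minimal CPU analogue of that pattern, not code from the patch; the name strided_sum_multiacc and its arguments are made up for illustration. It shows how splitting one running sum into several independent accumulators removes the single serial dependency chain, which is what gives the GPU kernel more instruction-level parallelism.

    # Hypothetical CPU sketch of the multi-accumulator strided sum used in dense_mul! above.
    function strided_sum_multiacc(vals::AbstractVector{T}, localIdx::Integer, stride::Integer) where {T}
      tmp = zeros(T, 8)               # stands in for `tmp = @private eltype(b) 8`
      N = length(vals)
      for i in localIdx:(8 * stride):N
        for j in 1:8
          index = i + (j - 1) * stride
          if index <= N
            tmp[j] += vals[index]     # eight independent dependency chains instead of one
          end
        end
      end
      return sum(tmp)                 # corresponds to the final `shared[localIdx] += tmp[j]` loop
    end

    # Each "lane" sums a strided slice; together the lanes cover the whole vector:
    v = rand(Float32, 1000)
    @assert sum(strided_sum_multiacc(v, l, 4) for l in 1:4) ≈ sum(v)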
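
Note on the reduction: the workgroup size drops from 1024 to 512 threads, so the first manually unrolled step of the shared-memory reduction (combining elements 512 apart) no longer has anything to add and is commented out; the tree now starts at an offset of 256. A serial model of that power-of-two tree reduction is sketched below; tree_reduce! is a hypothetical helper for illustration and not part of the patch.

    # Hypothetical serial model of the manually unrolled shared-memory reduction.
    function tree_reduce!(shared::AbstractVector{T}) where {T}
      offset = length(shared) ÷ 2     # 256 for a 512-element workgroup buffer
      while offset >= 1               # offsets 256, 128, 64, ..., 1
        for localIdx in 1:offset
          shared[localIdx] += shared[localIdx + offset]
        end
        offset ÷= 2
      end
      return shared[1]                # the fully reduced row value, written to b[operator_row] in the kernel
    end

    @assert tree_reduce!(ones(Float32, 512)) == 512f0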