diff --git a/Project.toml b/Project.toml
index 895f00e..8d407d6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,9 +31,11 @@ DSP = "0.6, 0.7"
 Distributed = "1"
 DistributedArrays = "0.6"
 FFTW = "1.3"
+GPUArrays = "8, 9, 10"
 ImageUtils = "0.2"
 IniFile = "0.5"
 JLArrays = "0.1"
+KernelAbstractions = "0.8, 0.9"
 LinearAlgebra = "1"
 LinearOperators = "2.3"
 LinearOperatorCollection = "2"
@@ -58,5 +60,12 @@ ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
 ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
 Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
 
+[weakdeps]
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+
+[extensions]
+MPIRecoKernelAbstractionsExt = ["KernelAbstractions", "GPUArrays"]
+
 [targets]
 test = ["Test", "HTTP", "FileIO", "LazyArtifacts", "Scratch", "ImageMagick", "ImageQualityIndexes", "Unitful", "JLArrays"]
diff --git a/ext/MPIRecoKernelAbstractionsExt/MPIRecoKernelAbstractionsExt.jl b/ext/MPIRecoKernelAbstractionsExt/MPIRecoKernelAbstractionsExt.jl
new file mode 100644
index 0000000..cf0f4fb
--- /dev/null
+++ b/ext/MPIRecoKernelAbstractionsExt/MPIRecoKernelAbstractionsExt.jl
@@ -0,0 +1,8 @@
+module MPIRecoKernelAbstractionsExt
+
+using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra
+using KernelAbstractions, GPUArrays
+
+include("MultiPatch.jl")
+
+end
diff --git a/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl b/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
new file mode 100644
index 0000000..95f3ff4
--- /dev/null
+++ b/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
@@ -0,0 +1,69 @@
+function Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT <: AbstractGPUArray}
+  validSMs = all(x -> size(x) == size(op.S[1]), op.S)
+  validXCC = all(x -> length(x) == length(op.xcc[1]), op.xcc)
+  validXSS = all(x -> length(x) == length(op.xss[1]), op.xss)
+
+  # Ideally we create a DenseMultiPatchOperator on the GPU
+  if validSMs && validXCC && validXSS
+    S = adapt(arrT, stack(op.S))
+    # We want to use Int32 for better GPU performance
+    xcc = Int32.(adapt(arrT, stack(op.xcc)))
+    xss = Int32.(adapt(arrT, stack(op.xss)))
+    sign = Int32.(adapt(arrT, op.sign))
+    RowToPatch = Int32.(adapt(arrT, op.RowToPatch))
+    patchToSMIdx = Int32.(adapt(arrT, op.patchToSMIdx))
+    return DenseMultiPatchOperator(S, op.grid, Int32(op.N), Int32(op.M), RowToPatch, xcc, xss, sign, Int32(op.nPatches), patchToSMIdx)
+  else
+    throw(ArgumentError("Cannot adapt MultiPatchOperator to $arrT, since it cannot be represented as a DenseMultiPatchOperator"))
+  end
+end
+
+@kernel function dense_mul!(b, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
+  # Each group/block handles a single row of the operator
+  operator_row = @index(Group, Linear) # k
+  patch = RowToPatch[operator_row] # p
+  patch_row = mod1(operator_row, M) # j
+  smIdx = patchToSMIdx[patch]
+  sign = signs[patch_row, smIdx]
+  grid_stride = prod(@groupsize())
+  N = size(xss, 1)
+
+  # We want to use a grid-stride loop to perform the sparse matrix-vector product.
+  # Each thread performs a single element-wise multiplication and reduction in its shared spot.
+  # Afterwards we reduce over the shared memory.
+  localIdx = @index(Local, Linear)
+  shared = @localmem eltype(b) prod(@groupsize())
+  shared[localIdx] = zero(eltype(b))
+
+  # First we iterate over the sparse indices
+  i = localIdx
+  while i <= N
+    shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
+    i += grid_stride
+  end
+  @synchronize
+
+  # Tree-reduce over the full (power-of-two) groupsize; unused slots stay zero, so this is safe for any N
+  @private s = div(grid_stride, Int32(2))
+  while s > Int32(0)
+    if localIdx <= s
+      shared[localIdx] = shared[localIdx] + shared[localIdx + s]
+    end
+    s >>= 1
+    @synchronize
+  end
+
+  # Write the result out to b
+  if localIdx == 1
+    b[operator_row] = shared[localIdx]
+  end
+end
+
+function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T, V}, x::AbstractVector{T}) where {T, V}
+  fill!(b, zero(T))
+  backend = get_backend(b)
+  kernel = dense_mul!(backend, 256)
+  kernel(b, x, op.S, op.xcc, op.xss, op.sign, div(op.M, op.nPatches), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
+  synchronize(backend)
+  return b
+end