Skip to content

Commit

Permalink
Move strided batch pointer conversion to GPU (#2608)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
THargreaves and maleadt authored Jan 8, 2025
1 parent 14ae82d commit 74b8eff
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1213,10 +1213,27 @@ end

# create a batch of pointers in device memory from a strided device array
@inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
batchsize = last(size(strided))
stride = prod(size(strided)[1:end-1])
ptrs = [pointer(strided, (i-1)*stride + 1) for i in 1:batchsize]
return CuArray(ptrs)
batch_size = last(size(strided))
batch_stride = prod(size(strided)[1:end-1])
#ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]
# fill the array on the GPU to avoid synchronous copies and support larger batch sizes
ptrs = CuArray{CuPtr{T}}(undef, batch_size)
function compute_pointers()
i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
grid_stride = gridDim().x * blockDim().x
while i <= length(ptrs)
@inbounds ptrs[i] =
reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
i += grid_stride
end
return
end
kernel = @cuda launch = false compute_pointers()
config = launch_configuration(kernel.fun)
threads = min(config.threads, batch_size)
blocks = min(config.blocks, cld(batch_size, threads))
@cuda threads blocks compute_pointers()
return ptrs
end

## (GE) general matrix-matrix multiplication grouped batched
Expand Down

0 comments on commit 74b8eff

Please sign in to comment.