diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 11e56d1530..eb63760618 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -1213,10 +1213,32 @@ end
 
 # create a batch of pointers in device memory from a strided device array
 @inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
-    batchsize = last(size(strided))
-    stride = prod(size(strided)[1:end-1])
-    ptrs = [pointer(strided, (i-1)*stride + 1) for i in 1:batchsize]
-    return CuArray(ptrs)
+    # the last dimension indexes the batch; the leading dimensions span one batch element
+    batch_size = last(size(strided))
+    batch_stride = prod(size(strided)[1:end-1])
+    # fill the array on the GPU to avoid synchronous copies and support larger batch sizes
+    ptrs = CuArray{CuPtr{T}}(undef, batch_size)
+    # an empty batch needs no kernel: the launch configuration below would otherwise
+    # compute `cld(batch_size, 0)`, which throws a `DivideError`
+    batch_size == 0 && return ptrs
+    function compute_pointers()
+        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        grid_stride = gridDim().x * blockDim().x
+        # each thread starts at its global index and advances by the total thread count
+        while i <= length(ptrs)
+            @inbounds ptrs[i] =
+                reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
+            i += grid_stride
+        end
+        return
+    end
+    # compile once without launching, then size the launch to the batch
+    kernel = @cuda launch = false compute_pointers()
+    config = launch_configuration(kernel.fun)
+    threads = min(config.threads, batch_size)
+    blocks = min(config.blocks, cld(batch_size, threads))
+    kernel(; threads, blocks)
+    return ptrs
 end
 
 ## (GE) general matrix-matrix multiplication grouped batched