How to reduce time spent on cutlass::DeviceAllocation::copy_from_host?
#998
Replies: 3 comments
-
Hi, @wangruohui.
If your application has multiple kernels running one after another (e.g., a preceding layer feeding into the grouped GEMM), you could consider having the preceding kernel populate the device pointer arrays so as to avoid the memcpys. Otherwise, you will indeed need to copy from host to device.
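As a purely illustrative sketch of that first suggestion (not CUTLASS code; all names are made up), the per-problem pointer arrays can be filled by a tiny device kernel instead of a host-to-device memcpy, assuming the per-problem offsets already live on the device (e.g., produced by the preceding layer):

```cpp
#include <cstdint>

// Illustrative helper kernel (not part of CUTLASS): fill the per-problem
// pointer arrays directly on the device, so no host-to-device memcpy of
// ptr_A/ptr_B/ptr_C is needed before launching the grouped GEMM.
__global__ void fill_grouped_ptrs(float const *base_A,      // storage for all A operands
                                  float const *base_B,      // storage for all B operands
                                  float       *base_C,      // storage for all C/D operands
                                  int64_t const *offset_A,  // per-problem element offsets,
                                  int64_t const *offset_B,  // assumed already resident on
                                  int64_t const *offset_C,  // the device
                                  float const **ptr_A,      // outputs: device pointer arrays
                                  float const **ptr_B,      // consumed by the grouped GEMM
                                  float       **ptr_C,
                                  int problem_count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < problem_count) {
    ptr_A[i] = base_A + offset_A[i];
    ptr_B[i] = base_B + offset_B[i];
    ptr_C[i] = base_C + offset_C[i];
  }
}

// Launched once per grouped GEMM call, on the same stream as the GEMM:
//   fill_grouped_ptrs<<<(problem_count + 255) / 256, 256, 0, stream>>>(...);
```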
One way to reduce these overheads is to allocate a single buffer of memory on the host and on the device which is used to populate all of the grouped GEMM args (e.g., all values for ptrA, then all values for ptrB, etc.). The pointers passed in to the grouped GEMM arguments are then just offsets within this large buffer (e.g., the ptrB argument would simply point into the buffer just past the ptrA values). We have an example of doing this in the CUTLASS Python interface: cutlass/python/cutlass/emit/pytorch.py, lines 237 to 357 (at commit f679663).

In fact, if you're interested in using this to build a PyTorch CUDA extension that uses grouped GEMM, you may be interested in the Python interface's PyTorch emitter. See this example for details: https://github.com/NVIDIA/cutlass/blob/main/examples/python/02_pytorch_extension_grouped_gemm.ipynb

I hope this helps!
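To make the single-buffer idea concrete, here is a rough sketch in plain CUDA C++ (not taken from the linked Python code; the packing layout, types, and names are assumptions): all argument arrays are packed back to back into one host buffer, moved with a single cudaMemcpy, and each grouped GEMM argument is then just an offset into the device copy.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>
#include <cuda_runtime.h>

// Sketch of the "single packed buffer" trick for 'problem_count' problems with
// float operands. Device buffer layout:
//   [ ptr_A | ptr_B | ptr_C | lda | ldb | ldc ]
struct PackedGroupedArgs {
  float const **ptr_A;
  float const **ptr_B;
  float       **ptr_C;
  int64_t      *lda;
  int64_t      *ldb;
  int64_t      *ldc;
};

PackedGroupedArgs pack_and_copy_args(int problem_count,
                                     std::vector<float const *> const &A,
                                     std::vector<float const *> const &B,
                                     std::vector<float *> const &C,
                                     std::vector<int64_t> const &lda,
                                     std::vector<int64_t> const &ldb,
                                     std::vector<int64_t> const &ldc) {
  size_t ptr_bytes = sizeof(void *) * problem_count;
  size_t ld_bytes  = sizeof(int64_t) * problem_count;
  size_t total     = 3 * ptr_bytes + 3 * ld_bytes;

  // Pack everything contiguously on the host.
  std::vector<uint8_t> host(total);
  uint8_t *h = host.data();
  std::memcpy(h,                                A.data(),   ptr_bytes);
  std::memcpy(h + ptr_bytes,                    B.data(),   ptr_bytes);
  std::memcpy(h + 2 * ptr_bytes,                C.data(),   ptr_bytes);
  std::memcpy(h + 3 * ptr_bytes,                lda.data(), ld_bytes);
  std::memcpy(h + 3 * ptr_bytes + ld_bytes,     ldb.data(), ld_bytes);
  std::memcpy(h + 3 * ptr_bytes + 2 * ld_bytes, ldc.data(), ld_bytes);

  // One allocation, one copy, instead of six (or eight) separate ones.
  uint8_t *d = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d), total);
  cudaMemcpy(d, host.data(), total, cudaMemcpyHostToDevice);

  // The grouped GEMM arguments are just offsets into the device buffer.
  PackedGroupedArgs args;
  args.ptr_A = reinterpret_cast<float const **>(d);
  args.ptr_B = reinterpret_cast<float const **>(d + ptr_bytes);
  args.ptr_C = reinterpret_cast<float **>(d + 2 * ptr_bytes);
  args.lda   = reinterpret_cast<int64_t *>(d + 3 * ptr_bytes);
  args.ldb   = reinterpret_cast<int64_t *>(d + 3 * ptr_bytes + ld_bytes);
  args.ldc   = reinterpret_cast<int64_t *>(d + 3 * ptr_bytes + 2 * ld_bytes);
  return args;
}
```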
-
Thank you very much! I have checked the code and understood the single-buffer trick. My ultimate goal is to port examples/41 (grouped attention) to PyTorch, so I will probably still need to deal with some C++. For the first trick, did you mean something like this?

```python
# For the first call, memcpy ptr_A/B/C to the device and keep them alive there,
# e.g., by wrapping the device allocation as a Python object and returning it.
pointers = cutlass_module.foo(matA, matB, out=matC)

# Same data, same pointers, but another operation.
# In C++, construct the arguments from `pointers` directly and launch the kernel.
cutlass_module.bar(matA, matB, out=matC, pointers=pointers)
```
-
Forgot to answer this question: you can use plain CUDA mallocs and memcpys (including the async versions) if you'd prefer.

Regarding your code snippet above: are you suggesting allocating one region of device memory for each operand A in the group and reusing these allocations across invocations (e.g., the previous layer writes its output to the same buffer each time, which is then reused in the call to the grouped kernel)? I think that could work, provided your problem sizes won't change and you can make sure that the buffer isn't overwritten while the grouped GEMM is still using it.
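For completeness, a hedged sketch of what that reuse could look like on the C++ side (names are made up; this is not the CUTLASS example code): the packed argument buffer is uploaded once with an async copy on the kernel's stream, cached, and then reused for every subsequent launch as long as the problem sizes and base addresses stay the same.

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Hypothetical cache of device-resident grouped-GEMM metadata. Filled once,
// then reused across invocations while the problem sizes stay fixed.
struct GroupedArgsCache {
  void  *device_args = nullptr;  // packed ptr_A/ptr_B/ptr_C/ld arrays
  size_t bytes       = 0;
  bool   initialized = false;
};

// Upload the packed host-side argument buffer once, asynchronously on the same
// stream the grouped GEMM will run on, so the copy is ordered before the kernel
// without blocking the host. For true copy/compute overlap the host buffer
// should be pinned (cudaMallocHost / cudaHostAlloc).
void ensure_args_on_device(GroupedArgsCache &cache,
                           void const *host_packed_args, size_t bytes,
                           cudaStream_t stream) {
  if (!cache.initialized) {
    cudaMalloc(&cache.device_args, bytes);
    cudaMemcpyAsync(cache.device_args, host_packed_args, bytes,
                    cudaMemcpyHostToDevice, stream);
    cache.bytes = bytes;
    cache.initialized = true;
  }
  // Subsequent calls skip the copy entirely: the previous layer keeps writing
  // its outputs into the same base buffers, so the cached pointers stay valid.
}
```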
-
Hello,

I am studying examples/24_gemm_grouped and making it available for PyTorch. But I found that copying ptrA/B/C/D and lda/b/c/d from host to device takes much more time than the kernel launch itself. In detail, there are 8 memcpys in total in the example: https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu#L654-L685. Based on my understanding, this is unavoidable even when the actual data is already on the GPU, because the kernel runs on the device, so the pointers to the data must themselves be accessible on the device. (Am I right?)

So I am wondering whether there are methods to reduce the time spent moving this metadata around, for example by making copy_from_host async and overlapping it with other work. But it seems only a synchronous memcpy is implemented for DeviceAllocation.

Thank you very much!
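For reference, one possible direction for the async variant asked about here, as a hedged sketch (not an existing CUTLASS utility): since cutlass::DeviceAllocation exposes its raw device pointer via get(), the blocking copy_from_host can be bypassed by issuing cudaMemcpyAsync on the kernel's stream, assuming the host-side arrays are kept in pinned memory for real overlap.

```cpp
#include <cuda_runtime.h>
#include <cutlass/util/device_memory.h>

// Sketch: replace the blocking DeviceAllocation::copy_from_host with an
// explicit cudaMemcpyAsync on the allocation's raw device pointer, so the
// eight small metadata copies are queued on the same stream as the kernel
// instead of synchronizing the host each time. For real overlap the host
// arrays should live in pinned memory (cudaMallocHost / cudaHostAlloc).
template <typename T>
void copy_from_host_async(cutlass::DeviceAllocation<T> &dst,
                          T const *host_src, size_t count,
                          cudaStream_t stream) {
  cudaMemcpyAsync(dst.get(), host_src, sizeof(T) * count,
                  cudaMemcpyHostToDevice, stream);
}

// Usage (hypothetical names, mirroring the example's ptr_A / lda arrays):
//   copy_from_host_async(ptr_A_device, ptr_A_host.data(), problem_count, stream);
//   copy_from_host_async(lda_device,   lda_host.data(),   problem_count, stream);
//   ... then launch the grouped GEMM on the same stream.
```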