[CUDA/ROCm/Migraphx] consolidate gpu data transfer #22609
Merged
Conversation
tianleiwu requested review from souptc, PeixuanZuo, mindest, weixingzhang and cloudhan · October 26, 2024 02:30
PeixuanZuo approved these changes · Oct 31, 2024
LGTM! Thanks!
mindest reviewed · Oct 31, 2024
ishwar-raut1 pushed a commit to ishwar-raut1/onnxruntime that referenced this pull request · Nov 19, 2024
ankitm3k pushed commits to intel/onnxruntime that referenced this pull request · Dec 11, 2024
Description
Consolidate the GPU data transfer code in the CUDA, ROCm and MIGraphX execution providers.
(1) Remove some redundant stream synchronizations on the default stream, per the synchronization behavior spec of cudaMemcpy (see the sketch after this list).
(2) Consolidate the CUDA, ROCm and MIGraphX EPs to use the same copy logic.
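For illustration, a minimal sketch of the kind of redundancy involved, under the cudaMemcpy behavior quoted in the Context section below (buffer names are hypothetical; this is not the actual EP code):

```cpp
#include <cuda_runtime.h>

// Per the spec below, cudaMemcpy device-to-host returns only once the copy
// has completed, so a trailing default-stream synchronize adds nothing.
// Error checks omitted for brevity.
void copy_d2h(void* dst_host, const void* src_dev, size_t bytes) {
  cudaMemcpy(dst_host, src_dev, bytes, cudaMemcpyDeviceToHost);
  // cudaStreamSynchronize(nullptr);  // redundant: the copy is already done
}
```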
Motivation
This is a follow-up from the review of #22589.
Context
https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior
cudaMemcpy()
* For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, but the DMA to the final destination may not have completed.
* For transfers from pinned host memory to device memory, the function is synchronous with respect to the host.
* For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed.
* For transfers from device memory to device memory, no host-side synchronization is performed.
* For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host.
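To make the device-to-device point concrete, a hedged sketch (hypothetical buffers): since cudaMemcpy D2D performs no host-side synchronization, the host must synchronize explicitly before it depends on the copy having finished.

```cpp
#include <cuda_runtime.h>

// cudaMemcpy device-to-device enqueues on the default stream and returns
// without host-side synchronization; later work on the same stream is
// ordered after it, but the host itself must sync to observe completion.
void copy_d2d_and_wait(void* dst_dev, const void* src_dev, size_t bytes) {
  cudaMemcpy(dst_dev, src_dev, bytes, cudaMemcpyDeviceToDevice);
  cudaDeviceSynchronize();  // needed only if the host depends on the result
}
```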
cudaMemcpyAsync()
* For transfers between device memory and pageable host memory, the function might be synchronous with respect to the host.
* For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host.
* If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory.
* For all other transfers, the function should be fully asynchronous.
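A sketch of the practical upshot (hypothetical names): allocating the host buffer as pinned memory keeps cudaMemcpyAsync fully asynchronous, instead of risking the pageable-staging synchronization described above.

```cpp
#include <cuda_runtime.h>

// With pinned host memory, cudaMemcpyAsync overlaps with host work; with
// pageable memory the driver may stage through pinned memory and synchronize.
void async_h2d(void* dst_dev, size_t bytes, cudaStream_t stream) {
  void* pinned = nullptr;
  cudaMallocHost(&pinned, bytes);  // pinned (page-locked) host allocation
  // ... fill `pinned` with input data ...
  cudaMemcpyAsync(dst_dev, pinned, bytes, cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);   // wait before reusing or freeing `pinned`
  cudaFreeHost(pinned);
}
```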
https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html
hipMemcpyAsync()
If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
On HCC, hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream.
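A HIP-side sketch of the same pattern (hypothetical names): hipHostMalloc provides the pinned host buffer that lets hipMemcpyAsync actually run asynchronously, as the docs above recommend.

```cpp
#include <hip/hip_runtime.h>

// Without a pinned host buffer, hipMemcpyAsync falls back to a synchronous
// copy; hipHostMalloc allocates pinned memory so the transfer can overlap.
void async_h2d_hip(void* dst_dev, size_t bytes, hipStream_t stream) {
  void* pinned = nullptr;
  hipHostMalloc(&pinned, bytes, hipHostMallocDefault);
  // ... fill `pinned` with input data ...
  hipMemcpyAsync(dst_dev, pinned, bytes, hipMemcpyHostToDevice, stream);
  hipStreamSynchronize(stream);  // wait before reusing or freeing `pinned`
  hipHostFree(pinned);
}
```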
hipMemcpy()
For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice).
https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md
ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
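An illustrative consequence of that release note (hypothetical buffers): since ROCm 5.6.1, host code can no longer assume an intra-device hipMemcpy has finished when the call returns, so it must synchronize before depending on the result.

```cpp
#include <hip/hip_runtime.h>

// Since ROCm 5.6.1, intra-device hipMemcpy is asynchronous with respect to
// the host; synchronize before the host relies on the copied data.
void copy_d2d_hip(void* dst_dev, const void* src_dev, size_t bytes) {
  hipMemcpy(dst_dev, src_dev, bytes, hipMemcpyDeviceToDevice);
  hipDeviceSynchronize();  // ensure the device-to-device copy has completed
}
```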