From 1b60209938fb62cd067d1e7105bf1eef08b083a6 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 31 Oct 2024 09:52:50 -0700
Subject: [PATCH] [CUDA/ROCm/Migraphx] consolidate gpu data transfer (#22609)

### Description

Consolidate the GPU data transfer in the CUDA, ROCm and MigraphX EPs:
(1) Remove some redundant stream synchronization on the default stream, per the cudaMemcpy spec.
(2) Consolidate CUDA, ROCm and MigraphX to use the same logic.

### Motivation

This is a follow-up to the review of https://github.com/microsoft/onnxruntime/pull/22589.

### Context

https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior

##### cudaMemcpy()

* For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**.
* For transfers from pinned host memory to device memory, the function is synchronous with respect to the host.
* For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed.
* For transfers from device memory to device memory, **no host-side synchronization is performed**.
* For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host.

##### cudaMemcpyAsync()

* For transfers between device memory and pageable host memory, the function might be synchronous with respect to the host.
* For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host.
* If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory.
* For all other transfers, the function should be fully asynchronous.

https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html

##### hipMemcpyAsync()

* If the host source or destination is not pinned, the memory copy is performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
* On HCC, hipMemcpyAsync does not support overlapped H2D and D2H copies.
* For hipMemcpyAsync, the copy is always performed by the device associated with the specified stream.

##### hipMemcpy()

* For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice).

https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md

ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host.
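To make the synchronization rules above concrete, here is a minimal sketch (illustrative only, not part of this patch; the `CHECK` macro and the 1 MiB buffer size are invented for the example). It shows where an explicit default-stream synchronization is still needed after `cudaMemcpy`, which is the pattern the consolidated `GPUDataTransfer::CopyTensor` implementations below follow:

```cpp
// Minimal, self-contained sketch of the cudaMemcpy synchronization rules quoted above.
// Not part of this patch; CHECK and the buffer size are illustrative only.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK(expr)                                                             \
  do {                                                                          \
    cudaError_t err_ = (expr);                                                  \
    if (err_ != cudaSuccess) {                                                  \
      std::fprintf(stderr, "%s failed: %s\n", #expr, cudaGetErrorString(err_)); \
      std::exit(1);                                                             \
    }                                                                           \
  } while (0)

int main() {
  constexpr size_t bytes = 1 << 20;
  void* pageable = std::malloc(bytes);    // pageable host memory
  void* pinned = nullptr;
  CHECK(cudaMallocHost(&pinned, bytes));  // pinned host memory
  void* d_src = nullptr;
  void* d_dst = nullptr;
  CHECK(cudaMalloc(&d_src, bytes));
  CHECK(cudaMalloc(&d_dst, bytes));

  // Pageable host -> device: cudaMemcpy may return before the DMA to the final
  // destination has completed, so a blocking copy still needs a default-stream sync.
  CHECK(cudaMemcpy(d_dst, pageable, bytes, cudaMemcpyHostToDevice));
  CHECK(cudaStreamSynchronize(nullptr));

  // Pinned host -> device: cudaMemcpy is synchronous with respect to the host,
  // so no extra synchronization is needed (this is the sync this PR skips).
  CHECK(cudaMemcpy(d_dst, pinned, bytes, cudaMemcpyHostToDevice));

  // Device -> device: no host-side synchronization is performed by cudaMemcpy,
  // so the default stream is synchronized before the copy is treated as complete.
  CHECK(cudaMemcpy(d_dst, d_src, bytes, cudaMemcpyDeviceToDevice));
  CHECK(cudaStreamSynchronize(nullptr));

  // Device -> host (pageable or pinned): cudaMemcpy returns only once the copy
  // has completed, so no extra synchronization is needed (sync removed by this PR).
  CHECK(cudaMemcpy(pageable, d_src, bytes, cudaMemcpyDeviceToHost));

  CHECK(cudaFree(d_src));
  CHECK(cudaFree(d_dst));
  CHECK(cudaFreeHost(pinned));
  std::free(pageable);
  return 0;
}
```

The ROCm and MigraphX `CopyTensor` implementations in this patch follow the same pattern with `hipMemcpy` and `hipStreamSynchronize`.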
---
 .../providers/cuda/cuda_execution_provider.cc |  1 -
 .../core/providers/cuda/gpu_data_transfer.cc  | 20 +++++----
 .../core/providers/cuda/gpu_data_transfer.h   |  4 +-
 .../providers/migraphx/gpu_data_transfer.cc   | 44 +++++++++++++------
 .../core/providers/migraphx/migraphx_call.h   |  2 +
 .../migraphx/migraphx_execution_provider.cc   |  2 +
 .../migraphx/migraphx_stream_handle.h         |  2 -
 .../core/providers/rocm/gpu_data_transfer.cc  | 23 +++++-----
 .../core/providers/rocm/gpu_data_transfer.h   |  4 +-
 .../providers/rocm/rocm_execution_provider.cc |  5 +--
 10 files changed, 65 insertions(+), 42 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index d3f01c1f7adc1..497d0014795ec 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -50,7 +50,6 @@ class Memcpy final : public OpKernel {
     ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
     Tensor* Y = ctx->Output(0, X->Shape());
     ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
-    // do we support async copy?
     // The cudaMemCpyAsync will handle the pinned memory and non-pinned memory,
     // so we don't need the check here.
     auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
index 71610634577ca..4dafbda409cd3 100644
--- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
@@ -7,10 +7,6 @@
 #include "cuda_common.h"
 
 namespace onnxruntime {
-GPUDataTransfer::GPUDataTransfer() {}
-
-GPUDataTransfer::~GPUDataTransfer() {}
-
 bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
   return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED ||
          dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
@@ -30,19 +26,25 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice));
+        // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy.
+        // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
         CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
-      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
+      if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) {
+        // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed.
+        // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
+        CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
+      }
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
     CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost));
-    CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
   } else {
     // copying between cpu memory
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
@@ -59,7 +61,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
 
   if (dst_device.Type() == OrtDevice::GPU) {
     if (src_device.Type() == OrtDevice::CPU) {
-      // copy from pinned memory to GPU, this is non-blocking
+      // copy from pinned or non-pinned CPU memory to GPU
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast<cudaStream_t>(stream.GetHandle())));
     } else if (src_device.Type() == OrtDevice::GPU) {
       // copying between GPU, this is non-blocking
@@ -69,7 +71,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     if (dst_device.Type() == OrtDevice::CPU) {
-      // copying from GPU to pinned memory, this is non-blocking
+      // copy from GPU to pinned or non-pinned CPU memory.
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle())));
     }
   } else {
@@ -77,6 +79,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
       // sync the stream first to make sure the data arrived
       CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(stream.GetHandle())));
     }
+
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.h b/onnxruntime/core/providers/cuda/gpu_data_transfer.h
index 68846e68079f3..11e21e91936fc 100644
--- a/onnxruntime/core/providers/cuda/gpu_data_transfer.h
+++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.h
@@ -10,8 +10,8 @@ namespace onnxruntime {
 
 class GPUDataTransfer : public IDataTransfer {
  public:
-  GPUDataTransfer();
-  ~GPUDataTransfer();
+  GPUDataTransfer() = default;
+  ~GPUDataTransfer() = default;
 
   bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
 
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
index 51625b83b8f61..77c5e18a5878e 100644
--- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -2,12 +2,16 @@
 // Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h" -#include "gpu_data_transfer.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" + +// If you make change below, please also update onnxruntime/core/providers/rocm/gpu_data_transfer.cc namespace onnxruntime { + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -23,17 +27,24 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const if (src_device.Type() == OrtDevice::GPU) { // Copy only if the two addresses are different. if (dst_data != src_data) { - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -49,23 +60,28 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { - // copy from pinned memory to GPU, this is non-blocking - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); + if (src_device.Type() == OrtDevice::CPU) { + // If source are not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. 
+      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
     } else if (src_device.Type() == OrtDevice::GPU) {
       // copying between GPU, this is non-blocking
       HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
-    } else {
-      // copy from other CPU memory to GPU, this is blocking
-      HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
-    HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
+    // If the destination is not pinned, the memory copy will be performed synchronously.
+    // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
+    HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
   } else {
-    // copying between cpu memory
+    if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) {
+      // sync the stream first to make sure the data arrived
+      HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
+    }
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
   return Status::OK();
 }
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h
index f6a95cebf34b5..6d514e01aea96 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_call.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_call.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "migraphx_inc.h"
+#include "core/common/common.h"
 
 namespace onnxruntime {
 
@@ -16,5 +17,6 @@ std::conditional_t<THRW, void, Status> RocmCall(
 
 #define HIP_CALL(expr) (RocmCall<hipError_t, false>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
 #define HIP_CALL_THROW(expr) (RocmCall<hipError_t, true>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
+#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr))
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index ed0dd7e6ae364..3134e80f3021a 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -49,6 +49,8 @@ class Memcpy final : public OpKernel {
     const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
     if (!gpu_data_transfer) return Status(common::ONNXRUNTIME, common::EP_FAIL, "gpu data transfer is missing in Migraphx EP.");
+    // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory.
+    // For non-pinned CPU memory, the copy is synchronous.
     return gpu_data_transfer->CopyTensorAsync(*X, *Y, *(ctx->GetComputeStream()));
   }
 };
 
diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
index 03a7c1607e3ad..85b0aff87a436 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
@@ -6,8 +6,6 @@
 #include "migraphx_inc.h"
 #include "migraphx_call.h"
 
-#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr))
-
 namespace onnxruntime {
 
 void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
index 635a25480b646..281a6f35a2808 100644
--- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
@@ -6,10 +6,8 @@
 #include "core/providers/rocm/gpu_data_transfer.h"
 #include "core/providers/rocm/rocm_common.h"
 
+// If you make changes below, please also update onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
 namespace onnxruntime {
-GPUDataTransfer::GPUDataTransfer() {}
-
-GPUDataTransfer::~GPUDataTransfer() {}
 
 bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
   return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED ||
@@ -30,19 +28,23 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
+        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
         HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
-      HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
+      if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) {
+        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
+      }
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
     HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
-    HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
   } else {
     // copying between cpu memory
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
@@ -59,7 +61,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
 
   if (dst_device.Type() == OrtDevice::GPU) {
     if (src_device.Type() == OrtDevice::CPU) {
-      // copy from pinned memory to GPU, this is non-blocking
+      // If the source is not pinned, the memory copy will be performed synchronously.
+      // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
       HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
     } else if (src_device.Type() == OrtDevice::GPU) {
       // copying between GPU, this is non-blocking
@@ -68,15 +71,15 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
-    if (dst_device.Type() == OrtDevice::CPU) {
-      // copying from GPU to pinned memory, this is non-blocking
-      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
-    }
+    // If the destination is not pinned, the memory copy will be performed synchronously.
+    // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
+    HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
   } else {
     if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) {
       // sync the stream first to make sure the data arrived
       HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
     }
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h
index 3d297bdce4a93..3d35ed52fff5c 100644
--- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h
+++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h
@@ -10,8 +10,8 @@ namespace onnxruntime {
 
 class GPUDataTransfer : public IDataTransfer {
  public:
-  GPUDataTransfer();
-  ~GPUDataTransfer();
+  GPUDataTransfer() = default;
+  ~GPUDataTransfer() = default;
 
   bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
 
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 02a21c033e988..2bd803f596acc 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -41,10 +41,9 @@ class Memcpy final : public OpKernel {
     ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
     Tensor* Y = ctx->Output(0, X->Shape());
     ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
-    // do we support async copy?
-    // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory,
-    // so we don't need the check here.
     auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
+    // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory.
+    // For non-pinned CPU memory, the copy is synchronous.
     ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()));
     return Status::OK();
   } else {