From 7368654d4e6722f4fe23a19f73ca1903a85a8864 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 25 Oct 2024 18:21:52 +0000
Subject: [PATCH 1/3] consolidate gpu data transfer

---
 .../providers/cuda/cuda_execution_provider.cc |  1 -
 .../core/providers/cuda/gpu_data_transfer.cc  | 14 +++----
 .../core/providers/cuda/gpu_data_transfer.h   |  4 +-
 .../providers/migraphx/gpu_data_transfer.cc   | 38 ++++++++++++-------
 .../migraphx/migraphx_execution_provider.cc   |  2 +
 .../core/providers/rocm/gpu_data_transfer.cc  | 19 ++++------
 .../core/providers/rocm/gpu_data_transfer.h   |  4 +-
 .../providers/rocm/rocm_execution_provider.cc |  5 +--
 8 files changed, 45 insertions(+), 42 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index d3f01c1f7adc1..497d0014795ec 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -50,7 +50,6 @@ class Memcpy final : public OpKernel {
       ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
       Tensor* Y = ctx->Output(0, X->Shape());
       ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
-      // do we support async copy?
       // The cudaMemCpyAsync will handle the pinned memory and non-pinned memory,
       // so we don't need the check here.
       auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
index 71610634577ca..fb53da5101b56 100644
--- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
@@ -7,10 +7,6 @@
 #include "cuda_common.h"
 
 namespace onnxruntime {
-GPUDataTransfer::GPUDataTransfer() {}
-
-GPUDataTransfer::~GPUDataTransfer() {}
-
 bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
   return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED ||
          dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
@@ -30,19 +26,17 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice));
-        CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
-      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
     CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost));
-    CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
   } else {
     // copying between cpu memory
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
@@ -59,7 +53,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
 
   if (dst_device.Type() == OrtDevice::GPU) {
     if (src_device.Type() == OrtDevice::CPU) {
-      // copy from pinned memory to GPU, this is non-blocking
+      // copy from pinned or non-pinned CPU memory to GPU
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast<cudaStream_t>(stream.GetHandle())));
     } else if (src_device.Type() == OrtDevice::GPU) {
       // copying between GPU, this is non-blocking
@@ -69,7 +63,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     if (dst_device.Type() == OrtDevice::CPU) {
-      // copying from GPU to pinned memory, this is non-blocking
+      // copy from GPU to pinned or non-pinned CPU memory.
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle())));
     }
   } else {
@@ -77,6 +71,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
       // sync the stream first to make sure the data arrived
       CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(stream.GetHandle())));
     }
+
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.h b/onnxruntime/core/providers/cuda/gpu_data_transfer.h
index 68846e68079f3..11e21e91936fc 100644
--- a/onnxruntime/core/providers/cuda/gpu_data_transfer.h
+++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.h
@@ -10,8 +10,8 @@ namespace onnxruntime {
 
 class GPUDataTransfer : public IDataTransfer {
  public:
-  GPUDataTransfer();
-  ~GPUDataTransfer();
+  GPUDataTransfer() = default;
+  ~GPUDataTransfer() = default;
 
   bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
 
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
index 51625b83b8f61..7d381e95ed5e1 100644
--- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -2,12 +2,16 @@
 // Licensed under the MIT License.
 
 #include "core/providers/shared_library/provider_api.h"
-#include "gpu_data_transfer.h"
-#include "migraphx_call.h"
+#include "core/providers/migraphx/gpu_data_transfer.h"
+#include "core/providers/migraphx/migraphx_call.h"
+
+// If you make changes below, please also update onnxruntime/core/providers/rocm/gpu_data_transfer.cc
 
 namespace onnxruntime {
+
 bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
-  return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED;
+  return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED ||
+         dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED;
 }
 
 common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
@@ -23,17 +27,18 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
     if (src_device.Type() == OrtDevice::GPU) {
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
-        HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
+        HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
-      HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
+      HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
-    HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
+    HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
   } else {
     // copying between cpu memory
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
@@ -49,23 +54,28 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
   auto& dst_device = dst.Location().device;
 
   if (dst_device.Type() == OrtDevice::GPU) {
-    if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
-      // copy from pinned memory to GPU, this is non-blocking
-      HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
+    if (src_device.Type() == OrtDevice::CPU) {
+      // If the source is not pinned, the memory copy will be performed synchronously.
+      // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
+      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
     } else if (src_device.Type() == OrtDevice::GPU) {
       // copying between GPU, this is non-blocking
       HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
-    } else {
-      // copy from other CPU memory to GPU, this is blocking
-      HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
-    HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
+    // If the destination is not pinned, the memory copy will be performed synchronously.
+    // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
+    HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
   } else {
-    // copying between cpu memory
+    if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) {
+      // sync the stream first to make sure the data arrived
+      HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
+    }
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
   return Status::OK();
 }
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index dca38480434fe..fd76e72d373cd 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -49,6 +49,8 @@ class Memcpy final : public OpKernel {
     const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
     if (!gpu_data_transfer)
       return Status(common::ONNXRUNTIME, common::EP_FAIL, "gpu data transfer is missing in Migraphx EP.");
+    // CopyTensorAsync can handle both pinned memory and non-pinned CPU memory.
+    // For non-pinned CPU memory, the copy is synchronous.
     return gpu_data_transfer->CopyTensorAsync(*X, *Y, *(ctx->GetComputeStream()));
   }
 };
diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
index 635a25480b646..3df1c007e2e52 100644
--- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
@@ -6,10 +6,8 @@
 #include "core/providers/rocm/gpu_data_transfer.h"
 #include "core/providers/rocm/rocm_common.h"
 
+// If you make changes below, please also update onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
 namespace onnxruntime {
-GPUDataTransfer::GPUDataTransfer() {}
-
-GPUDataTransfer::~GPUDataTransfer() {}
 
 bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
   return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED ||
@@ -30,19 +28,17 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
-        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
-      HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
     HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
-    HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
   } else {
     // copying between cpu memory
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
@@ -59,7 +55,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
 
   if (dst_device.Type() == OrtDevice::GPU) {
     if (src_device.Type() == OrtDevice::CPU) {
-      // copy from pinned memory to GPU, this is non-blocking
+      // If the source is not pinned, the memory copy will be performed synchronously.
+      // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
       HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
     } else if (src_device.Type() == OrtDevice::GPU) {
       // copying between GPU, this is non-blocking
@@ -68,15 +65,15 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
       }
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
-    if (dst_device.Type() == OrtDevice::CPU) {
-      // copying from GPU to pinned memory, this is non-blocking
-      HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
-    }
+    // If the destination is not pinned, the memory copy will be performed synchronously.
+    // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
+    HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
   } else {
     if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) {
       // sync the stream first to make sure the data arrived
       HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
     }
+    ORT_ENFORCE(dst_data != src_data);
     memcpy(dst_data, src_data, bytes);
   }
 
diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h
index 3d297bdce4a93..3d35ed52fff5c 100644
--- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h
+++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h
@@ -10,8 +10,8 @@ namespace onnxruntime {
 
 class GPUDataTransfer : public IDataTransfer {
  public:
-  GPUDataTransfer();
-  ~GPUDataTransfer();
+  GPUDataTransfer() = default;
+  ~GPUDataTransfer() = default;
 
   bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
 
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 02a21c033e988..2bd803f596acc 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -41,10 +41,9 @@ class Memcpy final : public OpKernel {
       ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
       Tensor* Y = ctx->Output(0, X->Shape());
       ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
-      // do we support async copy?
-      // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory,
-      // so we don't need the check here.
       auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
+      // CopyTensorAsync can handle both pinned memory and non-pinned CPU memory.
+      // For non-pinned CPU memory, the copy is synchronous.
       ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()));
       return Status::OK();
     } else {

From 255d289bfcb6fae0a070c32e080ea80bb1f3c71c Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 25 Oct 2024 19:47:36 +0000
Subject: [PATCH 2/3] Add HIP_RETURN_IF_ERROR macro

---
 onnxruntime/core/providers/migraphx/migraphx_call.h          | 2 ++
 onnxruntime/core/providers/migraphx/migraphx_stream_handle.h | 2 --
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h
index f6a95cebf34b5..6d514e01aea96 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_call.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_call.h
@@ -3,6 +3,7 @@
 
 #pragma once
 #include "migraphx_inc.h"
+#include "core/common/common.h"
 
 namespace onnxruntime {
 
@@ -16,5 +17,6 @@ std::conditional_t<THRW, void, Status> RocmCall(
 
 #define HIP_CALL(expr) (RocmCall<hipError_t, false>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
 #define HIP_CALL_THROW(expr) (RocmCall<hipError_t, true>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
+#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr))
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
index 03a7c1607e3ad..85b0aff87a436 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
@@ -6,8 +6,6 @@
 #include "migraphx_inc.h"
 #include "migraphx_call.h"
 
-#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr))
-
 namespace onnxruntime {
 void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 

From ba02da44a779277488867342a1a9dbbba284a5f0 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 25 Oct 2024 21:18:31 +0000
Subject: [PATCH 3/3] add back some sync

---
 onnxruntime/core/providers/cuda/gpu_data_transfer.cc     | 8 ++++++++
 onnxruntime/core/providers/migraphx/gpu_data_transfer.cc | 6 ++++++
 onnxruntime/core/providers/rocm/gpu_data_transfer.cc     | 6 ++++++
 3 files changed, 20 insertions(+)

diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
index fb53da5101b56..4dafbda409cd3 100644
--- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc
@@ -26,10 +26,18 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice));
+        // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy.
+        // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
+        CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
+      if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) {
+        // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed.
+        // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
+        CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
+      }
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
index 7d381e95ed5e1..77c5e18a5878e 100644
--- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -28,10 +28,16 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
+        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
+      if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) {
+        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
+      }
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking
diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
index 3df1c007e2e52..281a6f35a2808 100644
--- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
@@ -28,10 +28,16 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
       // Copy only if the two addresses are different.
       if (dst_data != src_data) {
         HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice));
+        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
       }
     } else {
       // copy from other CPU memory to GPU, this is blocking
       HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
+      if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) {
+        // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here.
+        HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
+      }
     }
   } else if (src_device.Type() == OrtDevice::GPU) {
     // copying from GPU to CPU memory, this is blocking