Skip to content

Commit

Permalink
Use cuda memset async (#21216)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored Jul 5, 2024
1 parent 0bbd061 commit 3f6b743
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor {
typedef typename ToCudaType<T>::MappedType CudaT;
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());

CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT)));
CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT),
stream));
PadAndUnflattenImpl<CudaT>(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound,
input_data, indices_tensor.Data<int64_t>(),
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
Expand All @@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* indices_tensor = context->Input<Tensor>(1);
const Tensor* unflatten_dims_tensor = context->Input<Tensor>(2); // Parse the 1-D shape tensor.

ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1,
"unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions());
ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2,
Expand Down

0 comments on commit 3f6b743

Please sign in to comment.