diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc index 7bd759e8976c1..f3feef4391bb5 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc @@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor { typedef typename ToCudaType::MappedType CudaT; const CudaT* input_data = reinterpret_cast(input_tensor.Data()); - CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT))); + CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT), + stream)); PadAndUnflattenImpl(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound, input_data, indices_tensor.Data(), reinterpret_cast(output_tensor.MutableData())); @@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { const Tensor* input_tensor = context->Input(0); const Tensor* indices_tensor = context->Input(1); const Tensor* unflatten_dims_tensor = context->Input(2); // Parse the 1-D shape tensor. + ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1, "unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions()); ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2,