diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 31cca232fde34..9d9b266355335 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -682,7 +682,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
index 9c55d37f550f4..bf73c59fb78ca 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -87,7 +87,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
const TensorShape& indice_shape,
const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); }
// From cpu/tensor/padbase.h (direct)
- Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }
+ Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }
+
+  void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+ PadsVector& pads) override {
+ PadBase::ComputePads(ctx, data_rank, pads_data, pads);
+ }
+
// From cpu/tensor/split.h (direct)
Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims,
int& after_dims_including_split_axis, int& after_dims_excluding_split,
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index 8dee1cd620282..f33eec4b93e98 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -25,6 +25,8 @@ class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Pr
class contrib__AdamWOptimizerBase__Prepare;
class contrib__SGDOptimizerV2Base__Prepare;
+using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;
+
struct ProviderHostCPU {
// From cpu/tensor/gatherbase.h
virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0;
@@ -44,7 +46,11 @@ struct ProviderHostCPU {
const TensorShape& indice_shape,
const TensorShape& update_shape) = 0;
// From cpu/tensor/padbase.h
- virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0;
+ virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0;
+
+  virtual void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+ PadsVector& pads) = 0;
+
// From cpu/tensor/split.h
virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims,
int& after_dims_including_split_axis, int& after_dims_excluding_split,
diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc
index fe5267f20712b..912280687e229 100644
--- a/onnxruntime/core/providers/cpu/tensor/pad.cc
+++ b/onnxruntime/core/providers/cpu/tensor/pad.cc
@@ -9,6 +9,8 @@
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math.h"
+#include <functional>
+
// there's no way to use a raw pointer as the copy destination with std::copy_n
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
@@ -167,47 +169,7 @@ ONNX_CPU_OPERATOR_KERNEL(
using PadsVector = PadBase::PadsVector;
-// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values)
-template <typename T>
-static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch,
- size_t block_size, size_t block_count) {
- for (size_t block_index = 0; block_index < block_count; block_index++) {
- for (size_t i = 0; i < block_size; i++) {
- *output++ = *input;
- input += input_delta;
- }
- input += input_pitch;
- }
-}
-
-// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1,
-// and inputPitch and inputDelta are just a single value added each iteration.
-template <typename T>
-static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) {
- for (size_t block_index = 0; block_index < block_count; block_index++) {
- *output++ = *input;
- input += input_delta;
- }
-}
-
-// For constant padding, there is no input, just a size to write the constant to
-template <typename T>
-static void PadAxisConstant(T* output, T constant, size_t size) {
- if (size == 1) {
- *output = constant;
- } else if (size == 2) {
- *output = constant;
- *(output + 1) = constant;
- } else {
- // This would be faster with SSE instructions.
- // That would mean to have an implementation for each type (uint8, uint32, uint64).
- T* end = output + size;
- for (; output != end;)
- *output++ = constant;
- }
-}
-
-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) {
+Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
switch (mode) {
case Mode::Constant: {
// default behavior is fine
@@ -242,34 +204,66 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh
return Status::OK();
}
-// special handling for edge case where the input has one or more dims with value of 0
-template <typename T>
-static Status PadInputWithDimValueOfZero(OpKernelContext* ctx,
- const Mode& mode,
- const TensorShape& input_shape,
- TensorShapeVector& output_dims,
- T value) {
- TensorShape output_shape(output_dims);
- ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape));
-
- auto& output_tensor = *ctx->Output(0, output_shape);
-
- // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty
- if (mode == Mode::Constant) {
- // we add pads with the default value to all dims including those with a value of 0
-    auto* output = reinterpret_cast<T*>(output_tensor.MutableDataRaw());
- std::fill_n(output, output_shape.Size(), value);
+static void ComputePadWithAxes(
+    gsl::span<const int64_t> pads_tensor_raw_data,
+    std::function<int64_t(size_t)> get_axis,
+ size_t axes_size,
+ size_t data_rank,
+ PadsVector& pads) {
+ for (size_t i = 0; i < axes_size; ++i) {
+    const size_t axis = onnxruntime::narrow<size_t>(HandleNegativeAxis(get_axis(i), data_rank));
+ pads[axis] = pads_tensor_raw_data[i]; // xi_begin
+ pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end
}
+}
- return Status::OK();
+void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+ PadsVector& pads) {
+ pads.reserve(2 * data_rank);
+  const Tensor* axes_tensor = ctx.Input<Tensor>(3);
+ if (axes_tensor) {
+ const size_t num_axes_dims = axes_tensor->Shape().NumDimensions();
+ ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor ");
+
+ const int64_t num_axes = axes_tensor->Shape().Size();
+    ORT_ENFORCE(pads_data.size() == narrow<size_t>(2 * num_axes),
+ "Pads tensor size should be equal to twice the number of explicitly provided axes.");
+
+ pads.resize(2 * data_rank, 0);
+    if (axes_tensor->IsDataType<int32_t>()) {
+      auto axes_data = axes_tensor->DataAsSpan<int32_t>();
+ ComputePadWithAxes(
+ pads_data,
+ [axes_data](size_t idx) -> int64_t {
+ return axes_data[idx];
+ },
+ axes_data.size(),
+ data_rank,
+ pads);
+    } else if (axes_tensor->IsDataType<int64_t>()) {
+      auto axes_data = axes_tensor->DataAsSpan<int64_t>();
+ ComputePadWithAxes(
+ pads_data,
+ [axes_data](size_t idx) {
+ return axes_data[idx];
+ },
+ axes_data.size(),
+ data_rank,
+ pads);
+ }
+ } else {
+ ORT_ENFORCE(pads_data.size() == 2 * data_rank,
+ "Pads tensor size should be equal to twice the input dimension count ");
+ pads.assign(pads_data.begin(), pads_data.end());
+ }
}
// Flatten no padding inner most Axis, so one memcpy cover multiple Axis.
// For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as
// [1,224,224*3] with padding [0,3,3*3,0,3,3*3].
-static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVector& pads,
- const PadsVector& slices, TensorShapeVector& reshaped_dims) {
- size_t dims_count = input_dims.size();
+void PadBase::FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
+                                gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims) {
+ const size_t dims_count = input_dims.size();
size_t inner_axis = dims_count - 1;
size_t inner_size = 1;
@@ -288,14 +282,14 @@ static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVec
} while (inner_axis-- > 0);
reshaped_dims.reserve(inner_axis + 1);
- std::copy(input_dims.cbegin(), input_dims.cbegin() + inner_axis + 1, std::back_inserter(reshaped_dims));
+ std::copy(input_dims.begin(), input_dims.begin() + inner_axis + 1, std::back_inserter(reshaped_dims));
// Flatten inner axis.
reshaped_dims[inner_axis] = inner_size;
}
-static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t new_dim_count,
- size_t inner_no_pad_size, PadsVector& reshaped_pad) {
+void PadBase::ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
+ size_t inner_no_pad_size, PadsVector& reshaped_pad) {
size_t inner_axis = new_dim_count - 1;
std::copy(src_pad.begin(), src_pad.begin() + inner_axis, reshaped_pad.begin());
std::copy(src_pad.begin() + src_dim_count, src_pad.begin() + src_dim_count + inner_axis,
@@ -306,6 +300,68 @@ static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t
reshaped_pad[inner_axis + new_dim_count] = src_pad[inner_axis + src_dim_count] * inner_no_pad_size;
}
+// special handling for edge case where the input has one or more dims with value of 0
+template <typename T>
+static Status PadInputWithDimValueOfZero(OpKernelContext* ctx,
+ const Mode& mode,
+ const TensorShape& input_shape,
+ TensorShapeVector& output_dims,
+ T value) {
+ TensorShape output_shape(output_dims);
+ ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape));
+
+ auto& output_tensor = *ctx->Output(0, output_shape);
+
+ // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty
+ if (mode == Mode::Constant) {
+ // we add pads with the default value to all dims including those with a value of 0
+    auto* output = reinterpret_cast<T*>(output_tensor.MutableDataRaw());
+ std::fill_n(output, output_shape.Size(), value);
+ }
+
+ return Status::OK();
+}
+
+// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values)
+template <typename T>
+static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch,
+ size_t block_size, size_t block_count) {
+ for (size_t block_index = 0; block_index < block_count; block_index++) {
+ for (size_t i = 0; i < block_size; i++) {
+ *output++ = *input;
+ input += input_delta;
+ }
+ input += input_pitch;
+ }
+}
+
+// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1,
+// and inputPitch and inputDelta are just a single value added each iteration.
+template <typename T>
+static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) {
+ for (size_t block_index = 0; block_index < block_count; block_index++) {
+ *output++ = *input;
+ input += input_delta;
+ }
+}
+
+// For constant padding, there is no input, just a size to write the constant to
+template <typename T>
+static void PadAxisConstant(T* output, T constant, size_t size) {
+ if (size == 1) {
+ *output = constant;
+ } else if (size == 2) {
+ *output = constant;
+ *(output + 1) = constant;
+ } else {
+ // This would be faster with SSE instructions.
+ // That would mean to have an implementation for each type (uint8, uint32, uint64).
+ T* end = output + size;
+ for (; output != end;)
+ *output++ = constant;
+ }
+}
+
template <typename T>
static Status PadImpl(OpKernelContext* ctx,
const PadsVector& pads,
@@ -327,7 +383,7 @@ static Status PadImpl(OpKernelContext* ctx,
// Reshape input dims
TensorShapeVector reshaped_input_dims;
- FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims);
+ PadBase::FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims);
// Reshape padding
size_t new_dims_count = reshaped_input_dims.size();
@@ -336,8 +392,8 @@ static Status PadImpl(OpKernelContext* ctx,
? reshaped_input_dims[inner_axis] / output_dims[inner_axis]
: 0);
PadsVector reshaped_pad(2 * new_dims_count), reshaped_slice(2 * new_dims_count);
- ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad);
- ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice);
+ PadBase::ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad);
+ PadBase::ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice);
TensorShapeVector reshaped_output_dims = reshaped_input_dims;
TensorShapeVector input_starts;
@@ -575,20 +631,6 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) {
return result;
}
-template <typename T>
-void ComputePadWithAxes(
-    gsl::span<const int64_t> pads_tensor_raw_data,
-    gsl::span<const T> axes_tensor_raw_data,
- size_t data_rank,
- PadsVector& pads) {
- size_t axes_size = axes_tensor_raw_data.size();
- for (size_t i = 0; i < axes_size; ++i) {
-    int64_t axis = HandleNegativeAxis(onnxruntime::narrow<int64_t>(axes_tensor_raw_data[i]), data_rank);
-    pads[onnxruntime::narrow<size_t>(axis)] = pads_tensor_raw_data[i];                          // xi_begin
-    pads[data_rank + onnxruntime::narrow<size_t>(axis)] = pads_tensor_raw_data[axes_size + i];  // xi_end
- }
-}
-
Status Pad::Compute(OpKernelContext* ctx) const {
  const Tensor& input_tensor = *ctx->Input<Tensor>(0);
MLDataType data_type = input_tensor.DataType();
@@ -608,48 +650,14 @@ Status Pad::Compute(OpKernelContext* ctx) const {
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
"Pads tensor should be a 1D tensor of shape [2 * num_axes] "
"or a 2D tensor of shape [1, 2 * num_axes]");
-  const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
-  size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
- pads.reserve(2 * data_rank);
-
-  const Tensor* axes_tensor = ctx->Input<Tensor>(3);
- if (axes_tensor) {
- const auto& axes_tensor_dims = axes_tensor->Shape().GetDims();
- ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor ");
- int64_t axes_size = axes_tensor_dims[0];
-
- pads.resize(2 * data_rank, 0);
-    if (axes_tensor->IsDataType<int32_t>()) {
-      const int32_t* axes_tensor_raw_data = axes_tensor->Data<int32_t>();
-      ComputePadWithAxes<int32_t>(
-          {pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
-          {axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
- data_rank,
- pads);
-    } else if (axes_tensor->IsDataType<int64_t>()) {
-      const int64_t* axes_tensor_raw_data = axes_tensor->Data<int64_t>();
-      ComputePadWithAxes<int64_t>(
-          {pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
-          {axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
- data_rank,
- pads);
- }
- } else {
- ORT_ENFORCE(pads_size == 2 * data_rank,
- "Pads tensor size should be equal to twice the input dimension count ");
- for (size_t i = 0; i < pads_size; ++i) {
- pads.push_back(pads_tensor_raw_data[i]);
- }
- }
+
+  const auto pads_data = pads_tensor.DataAsSpan<int64_t>();
+
+ // Compute Pads by applying axes if specified otherwise copy the supplied pads.
+ PadBase::ComputePads(*ctx, data_rank, pads_data, pads);
// Separate out any negative pads into the slices array
- slices.assign(pads.size(), 0);
- for (size_t index = 0; index < pads.size(); index++) {
- if (pads[index] < 0) {
- slices[index] = pads[index];
- pads[index] = 0;
- }
- }
+ PadBase::SeparateNegativeToSlices(pads, slices);
value.u64 = 0U;
  const Tensor* value_tensor = ctx->Input<Tensor>(2);
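
Note on the hunk above: opset-18 `Pad` adds an optional `axes` input, and the shared `PadBase::ComputePads` treats it as a mask — pads are supplied only for the listed axes and are expanded to a full-rank `[x1_begin, x2_begin, ..., x1_end, x2_end, ...]` vector, with zeros everywhere else. Below is a minimal standalone sketch of that expansion, written against plain `std::vector` rather than the ORT types (`ExpandPadsForAxes` is an illustrative name, not part of the ORT API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Expand pads given as [x1_begin, x2_begin, ..., x1_end, x2_end, ...] for the listed
// axes into a full-rank pads vector of size 2 * data_rank, zero for all other axes.
static std::vector<int64_t> ExpandPadsForAxes(const std::vector<int64_t>& pads_data,
                                              const std::vector<int64_t>& axes,
                                              size_t data_rank) {
  assert(pads_data.size() == 2 * axes.size());
  std::vector<int64_t> pads(2 * data_rank, 0);
  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t axis = axes[i];
    if (axis < 0) axis += static_cast<int64_t>(data_rank);  // negative axes count from the back
    pads[static_cast<size_t>(axis)] = pads_data[i];                            // xi_begin
    pads[data_rank + static_cast<size_t>(axis)] = pads_data[axes.size() + i];  // xi_end
  }
  return pads;
}

int main() {
  // Rank-4 input, padding only axes 1 and 2 (e.g. H and W of an NHWC tensor) by 3 on each side.
  const auto pads = ExpandPadsForAxes(/*pads_data=*/{3, 3, 3, 3}, /*axes=*/{1, 2}, /*data_rank=*/4);
  assert(pads == (std::vector<int64_t>{0, 3, 3, 0, 0, 3, 3, 0}));
  return 0;
}
```

For a rank-4 input padded only on axes 1 and 2, the four supplied values expand to the familiar eight-element pads vector that the existing kernels already consume; without an `axes` input, the pads are copied through unchanged.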
diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h
index d869ed1a6dda2..43f9cbfc9f9a4 100644
--- a/onnxruntime/core/providers/cpu/tensor/padbase.h
+++ b/onnxruntime/core/providers/cpu/tensor/padbase.h
@@ -19,9 +19,80 @@ class PadBase {
// Pads and slices are usually about twice the shapes involved
  using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;
- // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions
- // in the input_shape with a value of zero.
- static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape);
+ // The following several functions are shared among the providers
+
+  /// <summary>
+  /// Handle the case when the input shape has zero dim values.
+  /// Depending on the mode, the input dim with zero value must match the output dim value.
+  /// </summary>
+
+  /// <param name="mode">Padding mode enum value</param>
+  /// <param name="input_shape">actual input shape</param>
+  /// <param name="output_shape">output_shape</param>
+  /// <returns>Error if current mode padding can not be achieved with zero dim values</returns>
+ static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape);
+
+  /// <summary>
+  /// Compute Pads by applying axes if specified otherwise copy the supplied pads.
+  ///
+  /// The function queries optional axes input (since version 18) and if present,
+  /// applies it as a mask to the pads. If axes is not present, the pads are copied as is.
+  /// If axes are present, they are used as a mask over pads, so only those axes are being padded.
+  /// </summary>
+  /// <param name="ctx">kernel context to query axes input</param>
+  /// <param name="data_rank">input rank</param>
+  /// <param name="pads_data">pads data from pads input</param>
+  /// <param name="pads">resulting pads</param>
+  static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+ PadsVector& pads);
+
+  /// <summary>
+  /// Separates negative pad values to slices and zeros them out in original pads.
+  /// Leaving the rest of slices values as zero.
+  ///
+  /// This function is used inline in the Pad CUDA implementation and is not exposed via the provider
+  /// interfaces.
+  /// </summary>
+  /// <param name="pads">pad values</param>
+  /// <param name="slices">slices output</param>
+  static void SeparateNegativeToSlices(gsl::span<int64_t> pads, PadsVector& slices) {
+ slices.assign(pads.size(), 0);
+ for (size_t index = 0, lim = pads.size(); index < lim; index++) {
+ if (pads[index] < 0) {
+ slices[index] = pads[index];
+ pads[index] = 0;
+ }
+ }
+ }
+
+ // End provider shared
+
+  /// <summary>
+  /// Flattens the innermost axes that carry no padding, so one memcpy can cover multiple axes.
+  /// For example, a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0] can be flattened as
+  /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3].
+  ///
+  /// This is a helper function; pads are expected to be twice the rank.
+  /// </summary>
+  /// <param name="input_dims">original input dims</param>
+  /// <param name="pads">pad values</param>
+  /// <param name="slices">slices</param>
+  /// <param name="reshaped_dims">result dims</param>
+  static void FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
+                                gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims);
+
+  /// <summary>
+  /// Used after the inner shape is flattened, so we can apply this function to pads and slices
+  /// to reshape them as well.
+  /// </summary>
+  /// <param name="src_pad">pads</param>
+  /// <param name="src_dim_count">original dim count</param>
+  /// <param name="new_dim_count">expected flattened dim count</param>
+  /// <param name="inner_no_pad_size">size of the no-padding portion of the flattened dimension;
+  /// in the example above that is 3, reverse computed as (224*3)/224</param>
+  /// <param name="reshaped_pad">resulting reshaped pads or slices</param>
+  static void ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
+                          size_t inner_no_pad_size, PadsVector& reshaped_pad);
protected:
PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) {
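
The `FlattenInnerShape`/`ReshapePads` pair documented above folds the innermost axes that carry no padding into a single dimension so one memcpy can cover them, then rescales the pads of the folded axis. The following self-contained sketch works through the header's [1,224,224,3] example with plain `std::vector`; the helper names and the `inner_no_pad_size` value of 3 (= 672/224) are worked out here for illustration and mirror the documented behavior rather than the ORT implementation itself:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

// Fold trailing axes that have neither pads nor slices into a single dimension.
static void FlattenInner(const Vec& dims, const Vec& pads, const Vec& slices, Vec& reshaped) {
  const size_t rank = dims.size();
  size_t inner_axis = rank - 1;
  int64_t inner_size = 1;
  do {
    inner_size *= dims[inner_axis];
    if (inner_axis == 0) break;
    // Stop folding at the first axis (from the right) that is padded or sliced.
    if (pads[inner_axis] != 0 || pads[rank + inner_axis] != 0 ||
        slices[inner_axis] != 0 || slices[rank + inner_axis] != 0)
      break;
  } while (inner_axis-- > 0);
  reshaped.assign(dims.begin(), dims.begin() + inner_axis + 1);
  reshaped[inner_axis] = inner_size;  // folded innermost size, e.g. 224 * 3
}

// Rescale the pads (or slices) of the folded axis by the size of its no-pad part.
static void ReshapePadsLike(const Vec& src, size_t src_rank, size_t new_rank,
                            int64_t inner_no_pad_size, Vec& out) {
  const size_t inner_axis = new_rank - 1;
  out.assign(2 * new_rank, 0);
  std::copy(src.begin(), src.begin() + inner_axis, out.begin());
  std::copy(src.begin() + src_rank, src.begin() + src_rank + inner_axis, out.begin() + new_rank);
  out[inner_axis] = src[inner_axis] * inner_no_pad_size;
  out[inner_axis + new_rank] = src[inner_axis + src_rank] * inner_no_pad_size;
}

int main() {
  const Vec dims{1, 224, 224, 3};
  const Vec pads{0, 3, 3, 0, 0, 3, 3, 0};
  const Vec slices(8, 0);

  Vec reshaped_dims;
  FlattenInner(dims, pads, slices, reshaped_dims);
  assert(reshaped_dims == (Vec{1, 224, 224 * 3}));  // [1, 224, 672]

  // inner_no_pad_size = 672 / 224 = 3: the part of the folded axis that carries no padding.
  Vec reshaped_pads;
  ReshapePadsLike(pads, dims.size(), reshaped_dims.size(), 3, reshaped_pads);
  assert(reshaped_pads == (Vec{0, 3, 3 * 3, 0, 3, 3 * 3}));
  return 0;
}
```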
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 644bcaaa24cd4..3fc4ed355a12b 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1121,10 +1121,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -1269,6 +1269,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
// Opset 19
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
@@ -2008,10 +2012,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2091,13 +2095,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2150,11 +2147,22 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
// Opset 18
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
// Opset 19
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc
index 4584e5fd8272c..bdd6567d2ef34 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad.cc
+++ b/onnxruntime/core/providers/cuda/tensor/pad.cc
@@ -29,15 +29,27 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 2) \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      Pad<T>); \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Pad, \
+ kOnnxDomain, \
+ 13, 17, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      Pad<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Pad, \
kOnnxDomain, \
- 13, \
+ 18, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      Pad<T>);
@@ -94,28 +106,15 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
if (is_dynamic_) {
    const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
const auto pads_tensor_dims = pads_tensor.Shape().GetDims();
-    ORT_ENFORCE(utils::IsPrimitiveDataType<int64_t>(pads_tensor.DataType()),
- "Pads tensor should be an INT64 tensor");
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
- "Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]");
+ "Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]");
-    const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
-    size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
-    ORT_ENFORCE(pads_size == 2 * static_cast<size_t>(dimension_count),
- "Pads tensor size should be equal to twice the input dimension count ");
+    const auto pads_data = pads_tensor.DataAsSpan<int64_t>();
+
+ PadBase::ComputePads(*ctx, input_shape.NumDimensions(), pads_data, pads);
- pads.reserve(2LL * dimension_count);
- for (size_t i = 0; i < pads_size; ++i) {
- pads.push_back(pads_tensor_raw_data[i]);
- }
// Separate out any negative pads into the slices array
- slices.resize(pads.size(), 0);
- for (size_t index = 0; index < pads.size(); index++) {
- if (pads[index] < 0) {
- slices[index] = pads[index];
- pads[index] = 0;
- }
- }
+ PadBase::SeparateNegativeToSlices(pads, slices);
T raw_value{};
  const Tensor* value_tensor = ctx->Input<Tensor>(2);
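
As in the CPU kernel, the CUDA `ComputeInternal` above now delegates the negative-pad handling to `PadBase::SeparateNegativeToSlices`, which moves negative entries into a `slices` vector and zeroes them in `pads`. A minimal sketch of that behavior with plain `std::vector` (illustrative names only, not the ORT API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Move negative pad values (which mean "trim") into slices and zero them in pads,
// leaving all other slice entries at zero.
static void SeparateNegative(std::vector<int64_t>& pads, std::vector<int64_t>& slices) {
  slices.assign(pads.size(), 0);
  for (size_t i = 0; i < pads.size(); ++i) {
    if (pads[i] < 0) {
      slices[i] = pads[i];
      pads[i] = 0;
    }
  }
}

int main() {
  // Rank-2 example: begins {-1, 2}, ends {0, 3}.
  std::vector<int64_t> pads{-1, 2, 0, 3};
  std::vector<int64_t> slices;
  SeparateNegative(pads, slices);
  assert(pads == (std::vector<int64_t>{0, 2, 0, 3}));
  assert(slices == (std::vector<int64_t>{-1, 0, 0, 0}));
  return 0;
}
```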
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index d7bec337a6be4..fff3d14b763d5 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1158,10 +1158,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -1298,6 +1298,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);
// Opset 18
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
// Opset 19
@@ -2088,10 +2093,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2228,6 +2233,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 18
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
+
BuildKernelCreateInfo,
// Opset 19
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index a3155fe6b86cf..e1d0e310425c5 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -547,7 +547,14 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape,
const TensorShape& indice_shape,
const TensorShape& update_shape) { return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape); }
-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); }
+Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
+ return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape);
+}
+
+void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+ PadsVector& pads) {
+ g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads);
+}
Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors,
Prepare& p) const {
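
The `PadBase__ComputePads` plumbing added in `cpu_provider_shared.*` and forwarded here follows the existing provider-bridge pattern: shared-library providers (CUDA/ROCm) call the static `PadBase` method, which dispatches through a virtual on the host interface implemented inside the CPU provider. A rough sketch of that indirection with made-up types (the real interfaces are `ProviderHostCPU`, `ProviderHostCPUImpl`, and `g_host_cpu`; the compute body below is a stub, not the actual pad logic):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Host-side interface exposed to dynamically loaded providers (simplified stand-in).
struct HostInterface {
  virtual ~HostInterface() = default;
  virtual void PadBase__ComputePads(size_t data_rank, const std::vector<int64_t>& pads_data,
                                    std::vector<int64_t>& pads) = 0;
};

// Host-side implementation living in the CPU provider; the real call into
// PadBase::ComputePads is replaced by a trivial stub here.
struct HostImpl : HostInterface {
  void PadBase__ComputePads(size_t data_rank, const std::vector<int64_t>& pads_data,
                            std::vector<int64_t>& pads) override {
    pads.assign(pads_data.begin(), pads_data.end());
    pads.resize(2 * data_rank, 0);
  }
};

HostInterface* g_host = nullptr;  // set by the host process before providers run

// Provider-side shim: what the shared-library build of PadBase::ComputePads reduces to.
struct PadBaseShim {
  static void ComputePads(size_t data_rank, const std::vector<int64_t>& pads_data,
                          std::vector<int64_t>& pads) {
    g_host->PadBase__ComputePads(data_rank, pads_data, pads);
  }
};

int main() {
  HostImpl impl;
  g_host = &impl;

  std::vector<int64_t> pads;
  PadBaseShim::ComputePads(2, {1, 2, 3, 4}, pads);  // provider code calls through the bridge
  for (int64_t p : pads) std::cout << p << ' ';
  std::cout << '\n';  // prints: 1 2 3 4
  return 0;
}
```

This keeps a single implementation of the pad computation in the CPU provider while letting the GPU providers, which are built as separate shared libraries, reuse it without linking against the CPU kernels directly.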