diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 31cca232fde34..9d9b266355335 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -682,7 +682,8 @@ Do not modify directly.*
 |PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
 |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
 |ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
index 9c55d37f550f4..bf73c59fb78ca 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -87,7 +87,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
                                         const TensorShape& indice_shape, const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); }
   // From cpu/tensor/padbase.h (direct)
-  Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }
+  Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }
+
+  void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                            PadsVector& pads) override {
+    PadBase::ComputePads(ctx, data_rank, pads_data, pads);
+  }
+
   // From cpu/tensor/split.h (direct)
   Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims,
                                       int& after_dims_including_split_axis, int& after_dims_excluding_split,
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index 8dee1cd620282..f33eec4b93e98 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -25,6 +25,8 @@ class UnsqueezeBase__Prepare;  // Directly maps to UnsqueezeBase::Pr
 class contrib__AdamWOptimizerBase__Prepare;
 class contrib__SGDOptimizerV2Base__Prepare;
 
+using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;
+
 struct ProviderHostCPU {
   // From cpu/tensor/gatherbase.h
   virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0;
@@ -44,7 +46,11 @@ struct ProviderHostCPU {
                                          const TensorShape& indice_shape, const TensorShape& update_shape) = 0;
   // From cpu/tensor/padbase.h
-  virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0;
+  virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0;
+
+  virtual void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                                    PadsVector& pads) = 0;
+
   // From cpu/tensor/split.h
   virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims,
                                               int& after_dims_including_split_axis, int& after_dims_excluding_split,
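A note on the two files above: shared-library EPs never call `PadBase::ComputePads` directly. The virtual `PadBase__ComputePads` on `ProviderHostCPU` forwards to the real static method, and a provider-side stub with the original name (see provider_bridge_provider.cc at the end of this diff) routes calls through it. A minimal runnable sketch of that bridge pattern follows; all names here are illustrative, not the real onnxruntime declarations:

```cpp
#include <cstdio>

// Host-side interface visible to shared-library providers.
struct ProviderHost {
  virtual ~ProviderHost() = default;
  virtual void PadBase__ComputePads(int data_rank) = 0;
};

// "CPU side": owns the actual implementation.
struct RealPadBase {
  static void ComputePads(int data_rank) { std::printf("rank=%d\n", data_rank); }
};

struct ProviderHostImpl final : ProviderHost {
  void PadBase__ComputePads(int data_rank) override { RealPadBase::ComputePads(data_rank); }
};

ProviderHost* g_host = nullptr;  // set by the host process at load time

// "Shared-provider side": a stub with the familiar name calls through the
// interface, so EP code is source-compatible with the directly linked version.
struct PadBase {
  static void ComputePads(int data_rank) { g_host->PadBase__ComputePads(data_rank); }
};

int main() {
  ProviderHostImpl impl;
  g_host = &impl;
  PadBase::ComputePads(4);  // routed through the virtual interface
}
```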
diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc
index fe5267f20712b..912280687e229 100644
--- a/onnxruntime/core/providers/cpu/tensor/pad.cc
+++ b/onnxruntime/core/providers/cpu/tensor/pad.cc
@@ -9,6 +9,8 @@
 #include "core/providers/op_kernel_type_control.h"
 #include "core/util/math.h"
 
+#include <functional>
+
 // there's no way to use a raw pointer as the copy destination with std::copy_n
 // (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
 // without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
@@ -167,47 +169,7 @@ ONNX_CPU_OPERATOR_KERNEL(
 
 using PadsVector = PadBase::PadsVector;
 
-// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values)
-template <typename T>
-static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch,
-                    size_t block_size, size_t block_count) {
-  for (size_t block_index = 0; block_index < block_count; block_index++) {
-    for (size_t i = 0; i < block_size; i++) {
-      *output++ = *input;
-      input += input_delta;
-    }
-    input += input_pitch;
-  }
-}
-
-// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1,
-// and inputPitch and inputDelta are just a single value added each iteration.
-template <typename T>
-static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) {
-  for (size_t block_index = 0; block_index < block_count; block_index++) {
-    *output++ = *input;
-    input += input_delta;
-  }
-}
-
-// For constant padding, there is no input, just a size to write the constant to
-template <typename T>
-static void PadAxisConstant(T* output, T constant, size_t size) {
-  if (size == 1) {
-    *output = constant;
-  } else if (size == 2) {
-    *output = constant;
-    *(output + 1) = constant;
-  } else {
-    // This would be faster with SSE instructions.
-    // That would mean to have an implementation for each type (uint8, uint32, uint64).
-    T* end = output + size;
-    for (; output != end;)
-      *output++ = constant;
-  }
-}
-
-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) {
+Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
   switch (mode) {
     case Mode::Constant: {
       // default behavior is fine
@@ -242,34 +204,66 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh
   return Status::OK();
 }
 
-// special handling for edge case where the input has one or more dims with value of 0
-template <typename T>
-static Status PadInputWithDimValueOfZero(OpKernelContext* ctx,
-                                         const Mode& mode,
-                                         const TensorShape& input_shape,
-                                         TensorShapeVector& output_dims,
-                                         T value) {
-  TensorShape output_shape(output_dims);
-  ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape));
-
-  auto& output_tensor = *ctx->Output(0, output_shape);
-
-  // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty
-  if (mode == Mode::Constant) {
-    // we add pads with the default value to all dims including those with a value of 0
-    auto* output = reinterpret_cast<T*>(output_tensor.MutableDataRaw());
-    std::fill_n(output, output_shape.Size(), value);
-  }
-
-  return Status::OK();
-}
-
+static void ComputePadWithAxes(
+    gsl::span<const int64_t> pads_tensor_raw_data,
+    std::function<int64_t(size_t)> get_axis,
+    size_t axes_size,
+    size_t data_rank,
+    PadsVector& pads) {
+  for (size_t i = 0; i < axes_size; ++i) {
+    const size_t axis = onnxruntime::narrow<size_t>(HandleNegativeAxis(get_axis(i), data_rank));
+    pads[axis] = pads_tensor_raw_data[i];                          // xi_begin
+    pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i];  // xi_end
+  }
+}
+
+void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                          PadsVector& pads) {
+  pads.reserve(2 * data_rank);
+  const Tensor* axes_tensor = ctx.Input<Tensor>(3);
+  if (axes_tensor) {
+    const size_t num_axes_dims = axes_tensor->Shape().NumDimensions();
+    ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor ");
+
+    const int64_t num_axes = axes_tensor->Shape().Size();
+    ORT_ENFORCE(pads_data.size() == narrow<size_t>(2 * num_axes),
+                "Pads tensor size should be equal to twice the number of explicitly provided axes.");
+
+    pads.resize(2 * data_rank, 0);
+    if (axes_tensor->IsDataType<int32_t>()) {
+      auto axes_data = axes_tensor->DataAsSpan<int32_t>();
+      ComputePadWithAxes(
+          pads_data,
+          [axes_data](size_t idx) -> int64_t {
+            return axes_data[idx];
+          },
+          axes_data.size(),
+          data_rank,
+          pads);
+    } else if (axes_tensor->IsDataType<int64_t>()) {
+      auto axes_data = axes_tensor->DataAsSpan<int64_t>();
+      ComputePadWithAxes(
+          pads_data,
+          [axes_data](size_t idx) {
+            return axes_data[idx];
+          },
+          axes_data.size(),
+          data_rank,
+          pads);
+    }
+  } else {
+    ORT_ENFORCE(pads_data.size() == 2 * data_rank,
+                "Pads tensor size should be equal to twice the input dimension count ");
+    pads.assign(pads_data.begin(), pads_data.end());
+  }
+}
 
 // Flatten no padding inner most Axis, so one memcpy cover multiple Axis.
 // For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as
 // [1,224,224*3] with padding [0,3,3*3,0,3,3*3].
-static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVector& pads,
-                              const PadsVector& slices, TensorShapeVector& reshaped_dims) {
-  size_t dims_count = input_dims.size();
+void PadBase::FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
+                                gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims) {
+  const size_t dims_count = input_dims.size();
   size_t inner_axis = dims_count - 1;
   size_t inner_size = 1;
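To make the axes handling in `ComputePads` above concrete: with the opset-18 `Pad`, the `pads` input holds begin/end values only for the listed axes, and they are scattered into a full-rank `[x0_begin, x1_begin, ..., x0_end, x1_end, ...]` vector, with unlisted axes left at 0. A standalone mirror of that scatter (illustration only; `ScatterPads` is a hypothetical name):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the logic of ComputePadWithAxes above: pads_data holds
// [x_begin for each listed axis..., x_end for each listed axis...].
std::vector<int64_t> ScatterPads(const std::vector<int64_t>& axes,
                                 const std::vector<int64_t>& pads_data,
                                 std::size_t data_rank) {
  std::vector<int64_t> pads(2 * data_rank, 0);  // full-rank [begins..., ends...]
  const std::size_t num_axes = axes.size();
  for (std::size_t i = 0; i < num_axes; ++i) {
    int64_t axis = axes[i];
    if (axis < 0) axis += static_cast<int64_t>(data_rank);  // HandleNegativeAxis
    pads[static_cast<std::size_t>(axis)] = pads_data[i];                         // xi_begin
    pads[data_rank + static_cast<std::size_t>(axis)] = pads_data[num_axes + i];  // xi_end
  }
  return pads;
}

int main() {
  // Rank-4 data, padding only axes 1 and 3 (opset-18 style axes input).
  auto pads = ScatterPads({1, 3}, {2, 3, 4, 5}, 4);
  // pads == {0, 2, 0, 3, 0, 4, 0, 5}: begins for axes 0-3, then ends for axes 0-3.
  for (int64_t p : pads) std::printf("%lld ", static_cast<long long>(p));
}
```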
@@ -288,14 +282,14 @@ static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVec
   } while (inner_axis-- > 0);
 
   reshaped_dims.reserve(inner_axis + 1);
-  std::copy(input_dims.cbegin(), input_dims.cbegin() + inner_axis + 1, std::back_inserter(reshaped_dims));
+  std::copy(input_dims.begin(), input_dims.begin() + inner_axis + 1, std::back_inserter(reshaped_dims));
 
   // Flatten inner axis.
   reshaped_dims[inner_axis] = inner_size;
 }
 
-static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t new_dim_count,
-                        size_t inner_no_pad_size, PadsVector& reshaped_pad) {
+void PadBase::ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
+                          size_t inner_no_pad_size, PadsVector& reshaped_pad) {
   size_t inner_axis = new_dim_count - 1;
   std::copy(src_pad.begin(), src_pad.begin() + inner_axis, reshaped_pad.begin());
   std::copy(src_pad.begin() + src_dim_count, src_pad.begin() + src_dim_count + inner_axis,
@@ -306,6 +300,68 @@ static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t
   reshaped_pad[inner_axis + new_dim_count] = src_pad[inner_axis + src_dim_count] * inner_no_pad_size;
 }
 
+// special handling for edge case where the input has one or more dims with value of 0
+template <typename T>
+static Status PadInputWithDimValueOfZero(OpKernelContext* ctx,
+                                         const Mode& mode,
+                                         const TensorShape& input_shape,
+                                         TensorShapeVector& output_dims,
+                                         T value) {
+  TensorShape output_shape(output_dims);
+  ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape));
+
+  auto& output_tensor = *ctx->Output(0, output_shape);
+
+  // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty
+  if (mode == Mode::Constant) {
+    // we add pads with the default value to all dims including those with a value of 0
+    auto* output = reinterpret_cast<T*>(output_tensor.MutableDataRaw());
+    std::fill_n(output, output_shape.Size(), value);
+  }
+
+  return Status::OK();
+}
+
+// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values)
+template <typename T>
+static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch,
+                    size_t block_size, size_t block_count) {
+  for (size_t block_index = 0; block_index < block_count; block_index++) {
+    for (size_t i = 0; i < block_size; i++) {
+      *output++ = *input;
+      input += input_delta;
+    }
+    input += input_pitch;
+  }
+}
+
+// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1,
+// and inputPitch and inputDelta are just a single value added each iteration.
+template <typename T>
+static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) {
+  for (size_t block_index = 0; block_index < block_count; block_index++) {
+    *output++ = *input;
+    input += input_delta;
+  }
+}
+
+// For constant padding, there is no input, just a size to write the constant to
+template <typename T>
+static void PadAxisConstant(T* output, T constant, size_t size) {
+  if (size == 1) {
+    *output = constant;
+  } else if (size == 2) {
+    *output = constant;
+    *(output + 1) = constant;
+  } else {
+    // This would be faster with SSE instructions.
+    // That would mean to have an implementation for each type (uint8, uint32, uint64).
+    T* end = output + size;
+    for (; output != end;)
+      *output++ = constant;
+  }
+}
+
 template <typename T>
 static Status PadImpl(OpKernelContext* ctx,
                       const PadsVector& pads,
@@ -327,7 +383,7 @@ static Status PadImpl(OpKernelContext* ctx,
 
   // Reshape input dims
   TensorShapeVector reshaped_input_dims;
-  FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims);
+  PadBase::FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims);
 
   // Reshape padding
   size_t new_dims_count = reshaped_input_dims.size();
@@ -336,8 +392,8 @@ static Status PadImpl(OpKernelContext* ctx,
                               ? reshaped_input_dims[inner_axis] / output_dims[inner_axis]
                               : 0);
   PadsVector reshaped_pad(2 * new_dims_count), reshaped_slice(2 * new_dims_count);
-  ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad);
-  ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice);
+  PadBase::ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad);
+  PadBase::ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice);
 
   TensorShapeVector reshaped_output_dims = reshaped_input_dims;
   TensorShapeVector input_starts;
@@ -575,20 +631,6 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) {
   return result;
 }
 
-template <typename T>
-void ComputePadWithAxes(
-    gsl::span<const int64_t> pads_tensor_raw_data,
-    gsl::span<const T> axes_tensor_raw_data,
-    size_t data_rank,
-    PadsVector& pads) {
-  size_t axes_size = axes_tensor_raw_data.size();
-  for (size_t i = 0; i < axes_size; ++i) {
-    int64_t axis = HandleNegativeAxis(onnxruntime::narrow<int64_t>(axes_tensor_raw_data[i]), data_rank);
-    pads[onnxruntime::narrow<size_t>(axis)] = pads_tensor_raw_data[i];                          // xi_begin
-    pads[data_rank + onnxruntime::narrow<size_t>(axis)] = pads_tensor_raw_data[axes_size + i];  // xi_end
-  }
-}
-
 Status Pad::Compute(OpKernelContext* ctx) const {
   const Tensor& input_tensor = *ctx->Input<Tensor>(0);
   MLDataType data_type = input_tensor.DataType();
@@ -608,48 +650,14 @@ Status Pad::Compute(OpKernelContext* ctx) const {
     ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
                 "Pads tensor should be a 1D tensor of shape [2 * num_axes] "
                 "or a 2D tensor of shape [1, 2 * num_axes]");
-    const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
-    size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
-    pads.reserve(2 * data_rank);
-
-    const Tensor* axes_tensor = ctx->Input<Tensor>(3);
-    if (axes_tensor) {
-      const auto& axes_tensor_dims = axes_tensor->Shape().GetDims();
-      ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor ");
-      int64_t axes_size = axes_tensor_dims[0];
-
-      pads.resize(2 * data_rank, 0);
-      if (axes_tensor->IsDataType<int32_t>()) {
-        const int32_t* axes_tensor_raw_data = axes_tensor->Data<int32_t>();
-        ComputePadWithAxes<int32_t>(
-            {pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
-            {axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
-            data_rank,
-            pads);
-      } else if (axes_tensor->IsDataType<int64_t>()) {
-        const int64_t* axes_tensor_raw_data = axes_tensor->Data<int64_t>();
-        ComputePadWithAxes<int64_t>(
-            {pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
-            {axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
-            data_rank,
-            pads);
-      }
-    } else {
-      ORT_ENFORCE(pads_size == 2 * data_rank,
-                  "Pads tensor size should be equal to twice the input dimension count ");
-      for (size_t i = 0; i < pads_size; ++i) {
-        pads.push_back(pads_tensor_raw_data[i]);
-      }
-    }
+
+    const auto pads_data = pads_tensor.DataAsSpan<int64_t>();
+
+    // Compute pads by applying axes if specified; otherwise copy the supplied pads.
+    PadBase::ComputePads(*ctx, data_rank, pads_data, pads);
 
     // Separate out any negative pads into the slices array
-    slices.assign(pads.size(), 0);
-    for (size_t index = 0; index < pads.size(); index++) {
-      if (pads[index] < 0) {
-        slices[index] = pads[index];
-        pads[index] = 0;
-      }
-    }
+    PadBase::SeparateNegativeToSlices(pads, slices);
 
     value.u64 = 0U;
     const Tensor* value_tensor = ctx->Input<Tensor>(2);
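The `PadBase::SeparateNegativeToSlices` call above replaces the open-coded loop that the CPU and CUDA kernels each carried. A standalone mirror of its behavior (illustration only; the real inline definition is in padbase.h below), showing how negative pads migrate into the slices vector:

```cpp
#include <cstdint>
#include <vector>

// Mirrors PadBase::SeparateNegativeToSlices: negative pad values mean "slice",
// so they move into `slices` and are zeroed out of `pads` before padding runs.
void SeparateNegativeToSlices(std::vector<int64_t>& pads, std::vector<int64_t>& slices) {
  slices.assign(pads.size(), 0);
  for (std::size_t i = 0, lim = pads.size(); i < lim; ++i) {
    if (pads[i] < 0) {
      slices[i] = pads[i];
      pads[i] = 0;
    }
  }
}

int main() {
  std::vector<int64_t> pads{-1, 2, 3, -4}, slices;
  SeparateNegativeToSlices(pads, slices);
  // pads   == { 0, 2, 3,  0}  -- negatives zeroed out
  // slices == {-1, 0, 0, -4}  -- negatives moved here
}
```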
diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h
index d869ed1a6dda2..43f9cbfc9f9a4 100644
--- a/onnxruntime/core/providers/cpu/tensor/padbase.h
+++ b/onnxruntime/core/providers/cpu/tensor/padbase.h
@@ -19,9 +19,80 @@ class PadBase {
   // Pads and slices are usually about twice the shapes involved
   using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;
 
-  // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions
-  // in the input_shape with a value of zero.
-  static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape);
+  // The following several functions are shared among the providers
+
+  /// <summary>
+  /// Handle the case when the input shape has zero dim values.
+  /// Depending on the mode, an input dim with value zero must match the corresponding output dim value.
+  /// </summary>
+  /// <param name="mode">padding mode enum value</param>
+  /// <param name="input_shape">actual input shape</param>
+  /// <param name="output_shape">output shape</param>
+  /// <returns>Error if padding in the current mode cannot be achieved with zero dim values</returns>
+  static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape);
+
+  /// <summary>
+  /// Compute pads by applying axes if specified; otherwise copy the supplied pads.
+  ///
+  /// The function queries the optional axes input (introduced in opset 18) and, if present,
+  /// applies it as a mask over pads, so only the listed axes are padded.
+  /// If axes is not present, the pads are copied as is.
+  /// </summary>
+  /// <param name="ctx">kernel context to query axes input</param>
+  /// <param name="data_rank">input rank</param>
+  /// <param name="pads_data">pads data from pads input</param>
+  /// <param name="pads">resulting pads</param>
+  static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                          PadsVector& pads);
+
+  /// <summary>
+  /// Separates negative pad values into slices and zeros them out in the original pads,
+  /// leaving the rest of the slices values as zero.
+  ///
+  /// This function is used inline in the Pad CUDA implementation and is not exposed via the provider
+  /// interface.
+  /// </summary>
+  /// <param name="pads">pad values</param>
+  /// <param name="slices">slices output</param>
+  static void SeparateNegativeToSlices(gsl::span<int64_t> pads, PadsVector& slices) {
+    slices.assign(pads.size(), 0);
+    for (size_t index = 0, lim = pads.size(); index < lim; index++) {
+      if (pads[index] < 0) {
+        slices[index] = pads[index];
+        pads[index] = 0;
+      }
+    }
+  }
+
+  // End provider shared
+
+  /// <summary>
+  /// Flattens the innermost axes that have no padding, so one memcpy covers multiple axes.
+  /// For example, a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0] can be flattened to
+  /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3].
+  ///
+  /// This is a helper function; pads are expected to be twice the rank.
+  /// </summary>
+  /// <param name="input_dims">original input dims</param>
+  /// <param name="pads">pad values</param>
+  /// <param name="slices">slices</param>
+  /// <param name="reshaped_dims">result dims</param>
+  static void FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
+                                gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims);
+
+  /// <summary>
+  /// Used after the inner shape is flattened, so we can apply this function to pads and slices
+  /// to reshape them as well.
+  /// </summary>
+  /// <param name="src_pad">pads</param>
+  /// <param name="src_dim_count">original dim count</param>
+  /// <param name="new_dim_count">expected flattened dim count</param>
+  /// <param name="inner_no_pad_size">the leftmost dimension that was flattened;
+  /// in the example above, that would be 224, reverse computed from 224*3</param>
+  /// <param name="reshaped_pad">resulting reshaped pads or slices</param>
+  static void ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
+                          size_t inner_no_pad_size, PadsVector& reshaped_pad);
 
  protected:
   PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) {
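To illustrate the `FlattenInnerShape`/`ReshapePads` doc comments above, here is a simplified standalone reimplementation of the flattening idea (illustration only; the real implementation stays in pad.cc and handles more cases):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified version of the idea behind PadBase::FlattenInnerShape: innermost
// axes whose begin/end pads and slices are all zero are merged into one axis,
// so a single memcpy can cover what used to be many small copies.
std::vector<int64_t> FlattenInnerShape(const std::vector<int64_t>& dims,
                                       const std::vector<int64_t>& pads,
                                       const std::vector<int64_t>& slices) {
  const std::size_t rank = dims.size();
  std::size_t axis = rank - 1;
  int64_t inner_size = 1;
  // Absorb trailing axes that have no padding/slicing on either end.
  while (axis > 0 &&
         pads[axis] == 0 && pads[rank + axis] == 0 &&
         slices[axis] == 0 && slices[rank + axis] == 0) {
    inner_size *= dims[axis];
    --axis;
  }
  inner_size *= dims[axis];  // the first padded axis absorbs the product
  std::vector<int64_t> reshaped(dims.begin(), dims.begin() + axis + 1);
  reshaped.back() = inner_size;
  return reshaped;
}

int main() {
  // The shape from the doc comment: [1,224,224,3], padding only on axes 1 and 2.
  std::vector<int64_t> dims{1, 224, 224, 3};
  std::vector<int64_t> pads{0, 3, 3, 0, 0, 3, 3, 0};  // [begins..., ends...]
  std::vector<int64_t> slices(8, 0);
  auto reshaped = FlattenInnerShape(dims, pads, slices);
  // reshaped == {1, 224, 672}; ReshapePads then scales the pads of the merged
  // axis, giving [0, 3, 3*3, 0, 3, 3*3] as in the doc comment above.
  for (int64_t d : reshaped) std::printf("%lld ", static_cast<long long>(d));
}
```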
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 644bcaaa24cd4..3fc4ed355a12b 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1121,10 +1121,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -1269,6 +1269,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
 
 // Opset 19
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
@@ -2008,10 +2012,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -2091,13 +2095,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -2150,11 +2147,22 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
 
       // Opset 18
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
 
       // Opset 19
       BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc
index 4584e5fd8272c..bdd6567d2ef34 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad.cc
+++ b/onnxruntime/core/providers/cuda/tensor/pad.cc
@@ -29,15 +29,27 @@ namespace cuda {
           .InputMemoryType(OrtMemTypeCPUInput, 2)                          \
           .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),          \
       Pad<T>);                                                             \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                                 \
+      Pad,                                                                 \
+      kOnnxDomain,                                                         \
+      13, 17,                                                              \
+      T,                                                                   \
+      kCudaExecutionProvider,                                              \
+      (*KernelDefBuilder::Create())                                        \
+          .InputMemoryType(OrtMemTypeCPUInput, 1)                          \
+          .InputMemoryType(OrtMemTypeCPUInput, 2)                          \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),          \
+      Pad<T>);                                                             \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                           \
       Pad,                                                                 \
       kOnnxDomain,                                                         \
-      13,                                                                  \
+      18,                                                                  \
       T,                                                                   \
      kCudaExecutionProvider,                                              \
      (*KernelDefBuilder::Create())                                        \
           .InputMemoryType(OrtMemTypeCPUInput, 1)                          \
           .InputMemoryType(OrtMemTypeCPUInput, 2)                          \
+          .InputMemoryType(OrtMemTypeCPUInput, 3)                          \
           .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),          \
       Pad<T>);
@@ -94,28 +106,15 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
   if (is_dynamic_) {
     const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
     const auto pads_tensor_dims = pads_tensor.Shape().GetDims();
-    ORT_ENFORCE(utils::IsPrimitiveDataType<int64_t>(pads_tensor.DataType()),
-                "Pads tensor should be an INT64 tensor");
     ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
-                "Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]");
+                "Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]");
 
-    const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
-    size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
-    ORT_ENFORCE(pads_size == 2 * static_cast<size_t>(dimension_count),
-                "Pads tensor size should be equal to twice the input dimension count ");
+    const auto pads_data = pads_tensor.DataAsSpan<int64_t>();
+
+    PadBase::ComputePads(*ctx, input_shape.NumDimensions(), pads_data, pads);
 
-    pads.reserve(2LL * dimension_count);
-    for (size_t i = 0; i < pads_size; ++i) {
-      pads.push_back(pads_tensor_raw_data[i]);
-    }
     // Separate out any negative pads into the slices array
-    slices.resize(pads.size(), 0);
-    for (size_t index = 0; index < pads.size(); index++) {
-      if (pads[index] < 0) {
-        slices[index] = pads[index];
-        pads[index] = 0;
-      }
-    }
+    PadBase::SeparateNegativeToSlices(pads, slices);
 
     T raw_value{};
     const Tensor* value_tensor = ctx->Input<Tensor>(2);
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index d7bec337a6be4..fff3d14b763d5 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1158,10 +1158,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -1298,6 +1298,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);
 
 // Opset 18
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
+
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
 
 // Opset 19
@@ -2088,10 +2093,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -2228,6 +2233,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
 
       // Opset 18
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
+
       BuildKernelCreateInfo,
 
       // Opset 19
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index a3155fe6b86cf..e1d0e310425c5 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -547,7 +547,14 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape,
                                  const TensorShape& indice_shape, const TensorShape& update_shape) {
   return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape);
 }
 
-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); }
+Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
+  return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape);
+}
+
+void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                          PadsVector& pads) {
+  g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads);
+}
 
 Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors,
                                      Prepare& p) const {