Implement Pad-18 on Cuda.
yuslepukhin committed Jan 19, 2024
1 parent edb568a commit 1417af3
Showing 7 changed files with 112 additions and 32 deletions.
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -87,7 +87,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
const TensorShape& indice_shape,
const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); }
// From cpu/tensor/padbase.h (direct)
- Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }
+ Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }

[cpplint] onnxruntime/core/providers/cpu/cpu_provider_shared.cc:90: Lines should be <= 120 characters long [whitespace/line_length] [2]

void PadBase__ComputePads(OpKernelContext* ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
PadsVector& pads) override {
PadBase::ComputePads(ctx, data_rank, pads_data, pads);
}

// From cpu/tensor/split.h (direct)
Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims,
int& after_dims_including_split_axis, int& after_dims_excluding_split,
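For readers unfamiliar with this layer: ProviderHostCPU is the virtual interface through which shared-library execution providers such as CUDA reach CPU-side helpers like PadBase without linking against them directly, which is why this commit touches both the interface and its CPU-side implementation. A minimal sketch of the pattern, with hypothetical names (HostIface, HostImpl, GetHost) standing in for ProviderHostCPU, ProviderHostCPUImpl, and g_host_cpu:

```cpp
#include <iostream>

// Hypothetical stand-in for ProviderHostCPU: every shared CPU helper gets a
// pure-virtual entry point with a Class__Method naming convention.
struct HostIface {
  virtual ~HostIface() = default;
  virtual int PadHelper__Twice(int v) = 0;
};

// Hypothetical stand-in for ProviderHostCPUImpl: forwards to the real helper.
struct HostImpl : HostIface {
  int PadHelper__Twice(int v) override { return v * 2; }
};

// Hypothetical stand-in for g_host_cpu, the single bridge instance.
HostIface& GetHost() {
  static HostImpl impl;
  return impl;
}

int main() {
  // A non-CPU provider only ever calls through the virtual interface.
  std::cout << GetHost().PadHelper__Twice(21) << "\n";  // prints 42
  return 0;
}
```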
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -25,6 +25,8 @@ class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Pr
class contrib__AdamWOptimizerBase__Prepare;
class contrib__SGDOptimizerV2Base__Prepare;

using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;

struct ProviderHostCPU {
// From cpu/tensor/gatherbase.h
virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0;
@@ -44,7 +46,11 @@ struct ProviderHostCPU {
const TensorShape& indice_shape,
const TensorShape& update_shape) = 0;
// From cpu/tensor/padbase.h
- virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0;
+ virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0;

[cpplint] onnxruntime/core/providers/cpu/cpu_provider_shared.h:49: Lines should be <= 120 characters long [whitespace/line_length] [2]

virtual void PadBase__ComputePads(OpKernelContext* ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
PadsVector& pads) = 0;

// From cpu/tensor/split.h
virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims,
int& after_dims_including_split_axis, int& after_dims_excluding_split,
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/cpu/tensor/pad.cc
@@ -167,7 +167,7 @@ ONNX_CPU_OPERATOR_KERNEL(

using PadsVector = PadBase::PadsVector;

- Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) {
+ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
switch (mode) {
case Mode::Constant: {
// default behavior is fine
@@ -202,7 +202,7 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh
return Status::OK();
}

- void PadBase::ComputePadWithAxes(
+ static void ComputePadWithAxes(
gsl::span<const int64_t> pads_tensor_raw_data,
std::function<int64_t(size_t)> get_axis,
size_t axes_size,
@@ -223,6 +223,10 @@ void PadBase::ComputePads(OpKernelContext* ctx, size_t data_rank, gsl::span<cons
const size_t num_axes_dims = axes_tensor->Shape().NumDimensions();
ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor ");

const int64_t num_axes = axes_tensor->Shape().Size();
ORT_ENFORCE(pads_data.size() == 2 * num_axes,
"Pads tensor size should be equal to twice the number of explicitly provided axes.");

pads.resize(2 * data_rank, 0);
if (axes_tensor->IsDataType<int32_t>()) {
auto axes_data = axes_tensor->DataAsSpan<int32_t>();
@@ -644,6 +648,7 @@ Status Pad::Compute(OpKernelContext* ctx) const {
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
"Pads tensor should be a 1D tensor of shape [2 * num_axes] "
"or a 2D tensor of shape [1, 2 * num_axes]");

const auto pads_data = pads_tensor.DataAsSpan<int64_t>();

// Compute Pads by applying axes if specified otherwise copy the supplied pads.
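To make the new axes handling concrete: per the ONNX Pad-18 spec, when axes is supplied, pads has length 2 * num_axes laid out as [starts..., ends...] for the listed axes, and it must be scattered into a full vector of length 2 * rank. A standalone sketch of that expansion (ExpandPadsWithAxes is a hypothetical name, not the ONNX Runtime function):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring what PadBase::ComputePads does when the
// optional axes input is present: scatter pads_data, laid out as
// [starts..., ends...] per listed axis, into a zero-initialized vector
// of length 2 * data_rank.
std::vector<int64_t> ExpandPadsWithAxes(const std::vector<int64_t>& pads_data,
                                        const std::vector<int64_t>& axes,
                                        size_t data_rank) {
  assert(pads_data.size() == 2 * axes.size());
  std::vector<int64_t> pads(2 * data_rank, 0);
  for (size_t i = 0; i < axes.size(); ++i) {
    // Negative axes count from the back, as in the ONNX spec.
    const int64_t raw = axes[i];
    const size_t axis = static_cast<size_t>(raw < 0 ? raw + static_cast<int64_t>(data_rank) : raw);
    pads[axis] = pads_data[i];                            // begin pad for this axis
    pads[data_rank + axis] = pads_data[axes.size() + i];  // end pad for this axis
  }
  return pads;
}

int main() {
  // Rank-4 input, padding only axes 1 and 3: begins {1, 3}, ends {2, 4}.
  const auto pads = ExpandPadsWithAxes({1, 3, 2, 4}, {1, 3}, 4);
  assert(pads == (std::vector<int64_t>{0, 1, 0, 3, 0, 2, 0, 4}));
  return 0;
}
```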
80 changes: 61 additions & 19 deletions onnxruntime/core/providers/cpu/tensor/padbase.h
@@ -20,28 +20,41 @@ class PadBase {
// Pads and slices are usually about twice the shapes involved
using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;

- // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions
- // in the input_shape with a value of zero.
- static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape);
-
- static void FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
- gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims);
-
- static void ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
- size_t inner_no_pad_size, PadsVector& reshaped_pad);
-
- // Compute Pads by applying axes if specified otherwise copy the supplied pads.
// The following several functions are shared among the providers

/// <summary>
/// Update the output_shape to make it consistent with numpy handling where there are one or more dimensions
/// in the input_shape with a value of zero.
/// </summary>
/// <param name="mode">Padding mode enum value</param>
/// <param name="input_shape">actual input shape</param>
/// <param name="output_shape">output_shape</param>
/// <returns>Error if the current mode's padding cannot be achieved with zero-valued dims</returns>
static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape);

/// <summary>
/// Compute Pads by applying axes if specified otherwise copy the supplied pads.
///
/// The function queries the optional axes input (available since opset 18). If axes is
/// not present, the supplied pads are copied as-is. If axes is present, it acts as a mask
/// over the pads, so only the explicitly listed axes are padded.
/// </summary>
/// <param name="ctx">kernel context to query axes input</param>
/// <param name="data_rank">input rank</param>
/// <param name="pads_data">pads data from pads input</param>
/// <param name="pads">resulting pads</param>
static void ComputePads(OpKernelContext* ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
PadsVector& pads);

- static void ComputePadWithAxes(
- gsl::span<const int64_t> pads_tensor_raw_data,
- std::function<int64_t(size_t)> get_axis,
- size_t axes_size,
- size_t data_rank,
- PadsVector& pads);

- // Separate out any negative pads into the slices array
/// <summary>
/// Moves negative pad values into slices and zeros them out in the original pads,
/// leaving the remaining slices values at zero.
///
/// This function is used inline in the Pad CUDA implementation and is not exposed via the
/// provider interface.
/// </summary>
/// <param name="pads">pad values</param>
/// <param name="slices">slices output</param>
static void SeparateNegativeToSlices(gsl::span<int64_t> pads, PadsVector& slices) {
slices.assign(pads.size(), 0);
for (size_t index = 0, lim = pads.size(); index < lim; index++) {
@@ -52,6 +65,35 @@
}
}
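// Illustrative example (not part of the original source, derived from the
// summary above): with pads {-1, 2, 0, -3}, SeparateNegativeToSlices is
// expected to leave pads as {0, 2, 0, 0} and produce slices {-1, 0, 0, -3};
// the negative amounts are then applied as slicing rather than padding.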

// End provider shared

/// <summary>
/// Flattens the innermost axes that have no padding, so a single memcpy can cover
/// multiple axes. For example, a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0]
/// can be flattened to [1,224,224*3] with padding [0,3,3*3,0,3,3*3].
///
/// This is a helper function; pads are expected to be twice the rank.
/// </summary>
/// <param name="input_dims">original input dims</param>
/// <param name="pads">pad values</param>
/// <param name="slices">slices</param>
/// <param name="reshaped_dims">result dims</param>
static void FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims);

/// <summary>
/// Used after the inner shape is flattened, so we can apply this function to pads and slices
/// to reshape them as well.
/// </summary>
/// <param name="src_pad">pads</param>
/// <param name="src_dim_count">original dim count</param>
/// <param name="new_dim_count">expected flattended dim count</param>
/// <param name="inner_no_pad_size">is the left most dimension that was flattened.
/// In the example above, that would be 224, reverse computed from 224*3</param>
/// <param name="reshaped_pad">resulting reshaped pads or slices</param>
static void ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
size_t inner_no_pad_size, PadsVector& reshaped_pad);

protected:
PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) {
std::string mode;
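The flattening described in the header comments can be sanity-checked with a standalone sketch (FlattenNoPadInnerDims is a hypothetical re-derivation, not the real helper). Trailing axes whose begin and end pads are all zero are folded into the last padded axis; ReshapePads then scales that axis's pad values by the folded size (inner_no_pad_size), which is how pad 3 becomes 3*3 = 9 in the [1,224,224,3] example:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical re-derivation of the FlattenInnerShape idea: fold trailing axes
// whose begin/end pads are all zero into the last padded axis, so one
// contiguous memcpy can cover the folded region. Pads use the ONNX layout
// [x1_begin, ..., xr_begin, x1_end, ..., xr_end].
std::vector<int64_t> FlattenNoPadInnerDims(const std::vector<int64_t>& dims,
                                           const std::vector<int64_t>& pads) {
  const size_t rank = dims.size();
  size_t inner = rank;  // first axis of the all-zero-pad tail
  while (inner > 1 && pads[inner - 1] == 0 && pads[rank + inner - 1] == 0) --inner;
  std::vector<int64_t> out(dims.begin(), dims.begin() + inner);
  for (size_t i = inner; i < rank; ++i) out.back() *= dims[i];  // fold the tail
  return out;
}

int main() {
  // The example from the header comments: [1,224,224,3], pads [0,3,3,0, 0,3,3,0].
  const auto out = FlattenNoPadInnerDims({1, 224, 224, 3}, {0, 3, 3, 0, 0, 3, 3, 0});
  assert(out == (std::vector<int64_t>{1, 224, 672}));  // 224*3 = 672
  return 0;
}
```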
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1125,6 +1125,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -2012,6 +2016,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,

[cpplint] onnxruntime/core/providers/cuda/cuda_execution_provider.cc:2021: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
22 changes: 14 additions & 8 deletions onnxruntime/core/providers/cuda/tensor/pad.cc
@@ -39,6 +39,18 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Pad<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Pad, \
kOnnxDomain, \
18, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Pad<T>);

using PadsVector = PadBase::PadsVector;
@@ -94,18 +106,12 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
if (is_dynamic_) {
const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
const auto pads_tensor_dims = pads_tensor.Shape().GetDims();
ORT_ENFORCE(utils::IsPrimitiveDataType<int64_t>(pads_tensor.DataType()),
"Pads tensor should be an INT64 tensor");
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
"Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]");
"Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]");

const auto pads_data = pads_tensor.DataAsSpan<int64_t>();
const size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
ORT_ENFORCE(pads_size == 2 * static_cast<size_t>(dimension_count),
"Pads tensor size should be equal to twice the input dimension count ");

pads.reserve(pads_size);
pads.assign(pads_data.begin(), pads_data.end());
PadBase::ComputePads(ctx, input_shape.NumDimensions(), pads_data, pads);

// Separate out any negative pads into the slices array
PadBase::SeparateNegativeToSlices(pads, slices);
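A note on the kernel definition above: the InputMemoryType(OrtMemTypeCPUInput, n) entries pin pads (input 1), the constant value (input 2), and, for opset 18, axes (input 3) to host memory, because the kernel must read them on the CPU to size the output and set up the copy before anything is launched on the device. A rough sketch of that host-side arithmetic (names are illustrative, not the ONNX Runtime code):

```cpp
#include <cstdint>
#include <vector>

// Illustrative host-side step: the output shape is computed entirely on the
// CPU from the (already expanded) pads, which is why the small integer inputs
// must live in CPU memory rather than on the device.
std::vector<int64_t> ComputePaddedShape(const std::vector<int64_t>& input_dims,
                                        const std::vector<int64_t>& full_pads) {
  const size_t rank = input_dims.size();
  std::vector<int64_t> out(rank);
  for (size_t i = 0; i < rank; ++i) {
    out[i] = input_dims[i] + full_pads[i] + full_pads[rank + i];  // begin + end pad
  }
  return out;
}

int main() {
  // A [2, 3] input padded with begins {1, 0} and ends {0, 2} yields [3, 5].
  const auto out = ComputePaddedShape({2, 3}, {1, 0, 0, 2});
  return out == std::vector<int64_t>{3, 5} ? 0 : 1;
}
```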
@@ -547,7 +547,14 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape,
const TensorShape& indice_shape,
const TensorShape& update_shape) { return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape); }

- Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); }
+ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
+ return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape);
+ }

void PadBase::ComputePads(OpKernelContext* ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
PadsVector& pads) {
g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads);
}

Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors,
Prepare& p) const {
