Skip to content

Commit

Permalink
CUDA Resize-18 implementation (#19595)
Browse files Browse the repository at this point in the history
### Description
Implement Resize-18 on CUDA.

### Motivation and Context
Performance
  • Loading branch information
yuslepukhin authored Feb 29, 2024
1 parent d5606cd commit 5ee62a6
Show file tree
Hide file tree
Showing 20 changed files with 2,090 additions and 395 deletions.
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,8 @@ Do not modify directly.*
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2008,8 +2008,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int32_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int64_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
float, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "core/providers/cpu/tensor/tile.h"
#include "core/providers/cpu/tensor/gather_elements.h"
#include "core/providers/cpu/tensor/unsqueeze.h"
#include "core/providers/cpu/tensor/upsamplebase.h"

#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/cpu/bert/attention_base.h"
Expand Down Expand Up @@ -62,6 +63,7 @@
#endif

#include "cpu_provider_shared.h"
#include <limits>

namespace onnxruntime {
// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
Expand Down Expand Up @@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }

void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const override {
p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
}

#ifdef ENABLE_ATEN
Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
#endif
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareF
class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare
class contrib__AdamWOptimizerBase__Prepare;
class contrib__SGDOptimizerV2Base__Prepare;
class UpsampleBase;

using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;

Expand Down Expand Up @@ -202,6 +203,10 @@ struct ProviderHostCPU {
virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;

virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const = 0;

#ifdef ENABLE_ATEN
virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
#endif
Expand Down
79 changes: 62 additions & 17 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/tensor/upsample.h"

#include <limits>

#include "core/common/inlined_containers.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
#include "core/providers/cpu/tensor/upsample.h"
#include "core/providers/cpu/tensor/upsample_antialias.h"

using namespace onnxruntime::common;
using namespace std;
using onnxruntime::narrow;
Expand All @@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);

void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const {
// AspectRatioPolicy::STRETCH is default policy when opset < 18
if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) {
return;
}

InlinedHashSet<int64_t> axes_set(axes_.begin(), axes_.end());

float scale_in_policy = 0.0f;
if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) {
scale_in_policy = std::numeric_limits<float>::max();

for (size_t i = 0; i < scales.size(); i++) {
if (axes_set.empty() || axes_set.count(i) > 0) {
scale_in_policy = std::min(scale_in_policy, scales[i]);
}
}
} else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) {
scale_in_policy = std::numeric_limits<float>::min();

for (size_t i = 0; i < scales.size(); i++) {
if (axes_set.empty() || axes_set.count(i) > 0) {
scale_in_policy = std::max(scale_in_policy, scales[i]);
}
}
}

for (size_t i = 0; i < scales.size(); i++) {
// if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
if (axes_set.empty() || axes_set.count(i) > 0) {
scales[i] = scale_in_policy;
output_dims[i] = static_cast<int64_t>(std::round(scales[i] * input_dims[i]));
} else {
scales[i] = 1.0f;
output_dims[i] = input_dims[i];
}
}
}

template <typename T>
void UpsampleNearest2x(int64_t batch_size,
int64_t num_channels,
Expand Down Expand Up @@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim,
const TensorShape& input_shape,
const TensorShape& output_shape,
const std::vector<int64_t>& input_dim_factor,
const vector<float>& scales,
const vector<float>& roi,
gsl::span<const float> scales,
gsl::span<const float> roi,
bool extrapolation_enabled,
const GetOriginalCoordinateFunc& get_original_coordinate,
const GetNearestPixelFunc& get_nearest_pixel) {
Expand Down Expand Up @@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
const vector<float>& scales,
const vector<float>& roi,
gsl::span<const float> scales,
gsl::span<const float> roi,
bool extrapolation_enabled,
const T extrapolation_value,
const GetOriginalCoordinateFunc& get_original_coordinate,
Expand Down Expand Up @@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
const vector<float>& scales,
const vector<float>& roi,
gsl::span<const float> scales,
gsl::span<const float> roi,
bool is_resize,
bool extrapolation_enabled,
T extrapolation_value,
Expand Down Expand Up @@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw) {
Expand Down Expand Up @@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw) {
Expand Down Expand Up @@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
float depth_scale,
float height_scale,
float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate) {
TrilinearParams p;
Expand Down Expand Up @@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size,
float depth_scale,
float height_scale,
float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
bool use_extrapolation,
float extrapolation_value,
const T* XdataBase,
Expand Down Expand Up @@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size,
bool use_extrapolation,
float extrapolation_value,
bool exclude_outside,
const std::vector<float>& roi,
gsl::span<const float> roi,
const T* Xdata,
T* Ydata,
const GetOriginalCoordinateFunc& get_original_coordinate) {
Expand Down Expand Up @@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size,

template <typename T>
Status Upsample<T>::BaseCompute(OpKernelContext* context,
const std::vector<float>& roi,
const std::vector<float>& scales,
const gsl::span<const int64_t>& output_dims) const {
gsl::span<const float> roi,
gsl::span<const float> scales,
gsl::span<const int64_t> output_dims) const {
const auto* X = context->Input<Tensor>(0);
auto dims = X->Shape().GetDims();
ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same.");
Expand Down Expand Up @@ -1327,7 +1372,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
// Initialize the roi array to all zeros as this will be the most common case
// Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize
// for all other cases we need a 0 initialized roi array
std::vector<float> roi_array(roi_);
InlinedVector<float> roi_array(roi_);

if (!roi_cached_) {
bool use_default_roi = true;
Expand All @@ -1353,7 +1398,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {

ComputeROIWithAxes(roi_array, input_dims.size());
// Get scales data
std::vector<float> scales_array(input_dims.size());
InlinedVector<float> scales_array(input_dims.size());

if (OpKernel::Node().InputDefs().size() == 1) {
// Compute output shape from scales and input dims
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cpu/tensor/upsample.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel {

Status Compute(OpKernelContext* context) const override;

Status BaseCompute(OpKernelContext* context, const std::vector<float>& roi, const std::vector<float>& scales,
const gsl::span<const int64_t>& output_dims) const;
Status BaseCompute(OpKernelContext* context, gsl::span<const float> roi, gsl::span<const float> scales,
gsl::span<const int64_t> output_dims) const;
};

BilinearParams SetupUpsampleBilinear(const int32_t input_height,
Expand All @@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw);
Expand All @@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
const bool use_extrapolation,
const float extrapolation_value,
const T* const XdataBase,
Expand Down Expand Up @@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
const float extrapolation_value,
const T* const XdataBase,
T* const YdataBase,
Expand Down Expand Up @@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw);
Expand All @@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
gsl::span<const float> roi,
const float extrapolation_value,
const T* const XdataBase,
T* const YdataBase,
Expand Down
Loading

0 comments on commit 5ee62a6

Please sign in to comment.