CUDA Resize-18 implementation #19595

Merged: 14 commits, Feb 29, 2024
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
@@ -734,7 +734,8 @@ Do not modify directly.*
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
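These table rows record the user-visible effect of the PR: the CUDA Resize kernel that previously claimed opset 13+ is split into a versioned [13, 17] registration plus a new 18+ registration covering the attributes opset 18 added (antialias, axes, keep_aspect_ratio_policy). As a toy illustration of how start-version dispatch resolves a model's opset to a kernel — the map contents mirror the table above, but none of this is ORT code:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

int main() {
  // Keyed by the first opset a kernel supports, sorted descending; a range
  // ends where the next registration begins (e.g. 13 runs through 17).
  const std::map<int, std::string, std::greater<int>> resize_kernels{
      {18, "Resize-18 (antialias, axes, keep_aspect_ratio_policy)"},
      {13, "Resize-13"},
      {11, "Resize-11"},
      {10, "Resize-10"}};
  for (int opset : {10, 13, 17, 18, 21}) {
    // First start-version <= opset under the descending order.
    // (Opsets below 10 would need an end() check; this table starts at 10.)
    const auto it = resize_kernels.lower_bound(opset);
    std::cout << "opset " << opset << " -> " << it->second << "\n";
  }
}
```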
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -2008,8 +2008,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Less)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Less)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
+                                                            int32_t, Less)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
+                                                            int64_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
float, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -25,6 +25,7 @@
#include "core/providers/cpu/tensor/tile.h"
#include "core/providers/cpu/tensor/gather_elements.h"
#include "core/providers/cpu/tensor/unsqueeze.h"
#include "core/providers/cpu/tensor/upsamplebase.h"

#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/cpu/bert/attention_base.h"
@@ -62,6 +63,7 @@
#endif

#include "cpu_provider_shared.h"
#include <limits>

namespace onnxruntime {
// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
@@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }

void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const override {
p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
}
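This forwarding method follows the provider-bridge pattern used throughout the file: ProviderHostCPU declares a pure-virtual hook named Class__Method, and ProviderHostCPUImpl forwards it to the real CPU implementation, so shared-library providers such as the CUDA EP can reuse UpsampleBase::AdjustOutputSizeAsPolicy without linking against CPU-internal symbols. A minimal self-contained sketch of the idea, with hypothetical names:

```cpp
#include <iostream>

// Plays the role of ProviderHostCPU: the only type shared providers see.
struct HostInterface {
  virtual ~HostInterface() = default;
  virtual int Helper__Twice(int v) const = 0;  // Class__Method naming flattens a member function
};

// Plays the role of ProviderHostCPUImpl: forwards to the real code.
struct HostImpl : HostInterface {
  int Helper__Twice(int v) const override { return v * 2; }
};

int main() {
  HostImpl impl;
  const HostInterface& host = impl;  // providers hold only the interface
  std::cout << host.Helper__Twice(21) << "\n";  // prints 42
}
```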

#ifdef ENABLE_ATEN
Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
#endif
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareForComputeMetadata
class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare
class contrib__AdamWOptimizerBase__Prepare;
class contrib__SGDOptimizerV2Base__Prepare;
class UpsampleBase;

using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;

@@ -202,6 +203,10 @@ struct ProviderHostCPU {
virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;

virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const = 0;

#ifdef ENABLE_ATEN
virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
#endif
79 changes: 62 additions & 17 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
@@ -1,10 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/tensor/upsample.h"

#include <limits>

#include "core/common/inlined_containers.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
#include "core/providers/cpu/tensor/upsample.h"
#include "core/providers/cpu/tensor/upsample_antialias.h"

using namespace onnxruntime::common;
using namespace std;
using onnxruntime::narrow;
@@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);

void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const {
// AspectRatioPolicy::STRETCH is the default policy when opset < 18
if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) {
return;
}

InlinedHashSet<int64_t> axes_set(axes_.begin(), axes_.end());

float scale_in_policy = 0.0f;
if (keep_aspect_ratio_policy_ == AspectRatioPolicy::NOT_LARGER) {
scale_in_policy = std::numeric_limits<float>::max();

for (size_t i = 0; i < scales.size(); i++) {
if (axes_set.empty() || axes_set.count(i) > 0) {
scale_in_policy = std::min(scale_in_policy, scales[i]);
}
}
} else if (keep_aspect_ratio_policy_ == AspectRatioPolicy::NOT_SMALLER) {
scale_in_policy = std::numeric_limits<float>::min();

for (size_t i = 0; i < scales.size(); i++) {
if (axes_set.empty() || axes_set.count(i) > 0) {
scale_in_policy = std::max(scale_in_policy, scales[i]);
}
}
}

for (size_t i = 0; i < scales.size(); i++) {
// if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
if (axes_set.empty() || axes_set.count(i) > 0) {
scales[i] = scale_in_policy;
output_dims[i] = static_cast<int64_t>(std::round(scales[i] * input_dims[i]));
} else {
scales[i] = 1.0f;
output_dims[i] = input_dims[i];
}
}
}
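A worked example of the policy above (a standalone sketch, not ORT code; it assumes no axes attribute, so every axis participates): for input dims {4, 6} and scales {2.0, 3.0}, not_larger adopts the smallest scale (2.0) so that no output dimension ends up larger than its requested size, giving {8, 12}; not_smaller adopts the largest (3.0), giving {12, 18}.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> input_dims{4, 6};
  const std::vector<float> scales{2.0f, 3.0f};
  const bool not_larger = true;  // flip to false for "not_smaller"

  // Pick one uniform scale, exactly as the loops above do when axes_ is empty.
  const float uniform = not_larger
                            ? *std::min_element(scales.begin(), scales.end())
                            : *std::max_element(scales.begin(), scales.end());
  for (size_t i = 0; i < input_dims.size(); ++i) {
    const auto out = static_cast<int64_t>(
        std::round(uniform * static_cast<float>(input_dims[i])));
    std::cout << "dim " << i << ": " << input_dims[i] << " -> " << out << "\n";
  }
}
```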

template <typename T>
void UpsampleNearest2x(int64_t batch_size,
int64_t num_channels,
@@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim,
const TensorShape& input_shape,
const TensorShape& output_shape,
const std::vector<int64_t>& input_dim_factor,
-const vector<float>& scales,
-const vector<float>& roi,
+gsl::span<const float> scales,
+gsl::span<const float> roi,
bool extrapolation_enabled,
const GetOriginalCoordinateFunc& get_original_coordinate,
const GetNearestPixelFunc& get_nearest_pixel) {
@@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
-const vector<float>& scales,
-const vector<float>& roi,
+gsl::span<const float> scales,
+gsl::span<const float> roi,
bool extrapolation_enabled,
const T extrapolation_value,
const GetOriginalCoordinateFunc& get_original_coordinate,
@@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
-const vector<float>& scales,
-const vector<float>& roi,
+gsl::span<const float> scales,
+gsl::span<const float> roi,
bool is_resize,
bool extrapolation_enabled,
T extrapolation_value,
@@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw) {
@@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw) {
@@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
float depth_scale,
float height_scale,
float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate) {
TrilinearParams p;
@@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size,
float depth_scale,
float height_scale,
float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
bool use_extrapolation,
float extrapolation_value,
const T* XdataBase,
@@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size,
bool use_extrapolation,
float extrapolation_value,
bool exclude_outside,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
const T* Xdata,
T* Ydata,
const GetOriginalCoordinateFunc& get_original_coordinate) {
@@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size,

template <typename T>
Status Upsample<T>::BaseCompute(OpKernelContext* context,
-const std::vector<float>& roi,
-const std::vector<float>& scales,
-const gsl::span<const int64_t>& output_dims) const {
+gsl::span<const float> roi,
+gsl::span<const float> scales,
+gsl::span<const int64_t> output_dims) const {
const auto* X = context->Input<Tensor>(0);
auto dims = X->Shape().GetDims();
ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same.");
@@ -1327,7 +1372,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
// Initialize the roi array to all zeros as this will be the most common case
// Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize
// for all other cases we need a 0 initialized roi array
-std::vector<float> roi_array(roi_);
+InlinedVector<float> roi_array(roi_);

if (!roi_cached_) {
bool use_default_roi = true;
@@ -1353,7 +1398,7 @@

ComputeROIWithAxes(roi_array, input_dims.size());
// Get scales data
-std::vector<float> scales_array(input_dims.size());
+InlinedVector<float> scales_array(input_dims.size());

if (OpKernel::Node().InputDefs().size() == 1) {
// Compute output shape from scales and input dims
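The signature changes in this file share one theme: taking gsl::span<const float> instead of const std::vector<float>& lets the helpers accept any contiguous buffer — std::vector, std::array, or the stack-friendly InlinedVector now used for roi_array and scales_array — without forcing a copy into a vector first. A minimal sketch of why that works (assumes the Microsoft GSL header that ORT already vendors; C++20 std::span behaves the same way):

```cpp
#include <array>
#include <iostream>
#include <vector>

#include <gsl/gsl>  // Microsoft GSL: gsl::span

// A span parameter is a non-owning view: any contiguous float buffer binds.
float first_scale(gsl::span<const float> scales) {
  return scales.empty() ? 1.0f : scales[0];
}

int main() {
  const std::vector<float> v{0.5f, 0.5f};
  const std::array<float, 2> a{2.0f, 2.0f};
  std::cout << first_scale(v) << " " << first_scale(a) << "\n";  // prints "0.5 2"
}
```

Since spans do not own their storage, the buffer must outlive the call — which holds here because BaseCompute consumes roi and scales synchronously.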
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cpu/tensor/upsample.h
@@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel {

Status Compute(OpKernelContext* context) const override;

-Status BaseCompute(OpKernelContext* context, const std::vector<float>& roi, const std::vector<float>& scales,
-const gsl::span<const int64_t>& output_dims) const;
+Status BaseCompute(OpKernelContext* context, gsl::span<const float> roi, gsl::span<const float> scales,
+gsl::span<const int64_t> output_dims) const;
};

BilinearParams SetupUpsampleBilinear(const int32_t input_height,
@@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw);
Expand All @@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
const bool use_extrapolation,
const float extrapolation_value,
const T* const XdataBase,
@@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
const float extrapolation_value,
const T* const XdataBase,
T* const YdataBase,
@@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw);
@@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
-const std::vector<float>& roi,
+gsl::span<const float> roi,
const float extrapolation_value,
const T* const XdataBase,
T* const YdataBase,