diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index b0ed68d595c42..1eaf0fb6dad76 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -734,7 +734,8 @@ Do not modify directly.*
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Resize|*in* X:**T**<br/> *in* scales:**tensor(float)**<br/> *out* Y:**T**<br/><br/>or<br/><br/>*in* X:**T1**<br/> *in* roi:**T2**<br/> *in* scales:**tensor(float)**<br/> *in* sizes:**tensor(int64)**<br/> *out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|Resize|*in* X:**T**<br/> *in* scales:**tensor(float)**<br/> *out* Y:**T**<br/><br/>or<br/><br/>*in* X:**T1**<br/> *in* roi:**T2**<br/> *in* scales:**tensor(float)**<br/> *in* sizes:**tensor(int64)**<br/> *out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br/> *in* sequence_lens:**tensor(int64)**<br/> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 48e4617b33b4d..37e7e42150413 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -2008,8 +2008,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Greater)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
namespace onnxruntime {
// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
@@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }
+ void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+ gsl::span<const int64_t> input_dims,
+ InlinedVector<float>& scales) const override {
+ p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
+ }
+
#ifdef ENABLE_ATEN
Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
#endif
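
The two hunks above extend the CPU provider bridge: ProviderHostCPU gains a virtual UpsampleBase__AdjustOutputSizeAsPolicy hook, and ProviderHostCPUImpl forwards it to UpsampleBase::AdjustOutputSizeAsPolicy so non-CPU execution providers (here, CUDA) can reuse the CPU logic. A minimal standalone sketch of that forwarding pattern, using hypothetical names rather than the real ONNX Runtime interfaces:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the CPU-side class whose logic other EPs want to reuse.
struct UpsampleLike {
  void AdjustDims(std::vector<int64_t>& dims) const {
    for (auto& d : dims) d *= 2;  // placeholder for the real policy logic
  }
};

// Host interface: only virtual forwarding hooks, no implementation details leak across the boundary.
struct ProviderHostLike {
  virtual ~ProviderHostLike() = default;
  virtual void UpsampleLike__AdjustDims(const UpsampleLike* p,
                                        std::vector<int64_t>& dims) const = 0;
};

// CPU-side implementation forwards to the concrete method.
struct ProviderHostImpl : ProviderHostLike {
  void UpsampleLike__AdjustDims(const UpsampleLike* p,
                                std::vector<int64_t>& dims) const override {
    p->AdjustDims(dims);
  }
};

int main() {
  UpsampleLike kernel;
  std::vector<int64_t> dims{1, 3, 8, 8};
  ProviderHostImpl host;
  host.UpsampleLike__AdjustDims(&kernel, dims);  // other EPs call through the host interface
  for (auto d : dims) std::cout << d << ' ';     // 2 6 16 16
  std::cout << '\n';
}

The real bridge adds one such ClassName__Method hook per shared routine; the CUDA Resize kernel calls through the host interface instead of linking the CPU kernel directly.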
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index f33eec4b93e98..c0e674827e4d1 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareF
class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare
class contrib__AdamWOptimizerBase__Prepare;
class contrib__SGDOptimizerV2Base__Prepare;
+class UpsampleBase;
using PadsVector = InlinedVector;
@@ -202,6 +203,10 @@ struct ProviderHostCPU {
virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
+ virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+ gsl::span<const int64_t> input_dims,
+ InlinedVector<float>& scales) const = 0;
+
#ifdef ENABLE_ATEN
virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
#endif
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc
index fa69e144be554..babbac0b7be17 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample.cc
+++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc
@@ -1,10 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include "core/providers/cpu/tensor/upsample.h"
+
+#include
+
+#include "core/common/inlined_containers.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
-#include "core/providers/cpu/tensor/upsample.h"
#include "core/providers/cpu/tensor/upsample_antialias.h"
+
using namespace onnxruntime::common;
using namespace std;
using onnxruntime::narrow;
@@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);
+void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
+ InlinedVector<float>& scales) const {
+ // AspectRatioPolicy::STRETCH is default policy when opset < 18
+ if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) {
+ return;
+ }
+
+ InlinedHashSet<int64_t> axes_set(axes_.begin(), axes_.end());
+
+ float scale_in_policy = 0.0f;
+ if (keep_aspect_ratio_policy_ == AspectRatioPolicy::NOT_LARGER) {
+ scale_in_policy = std::numeric_limits<float>::max();
+
+ for (size_t i = 0; i < scales.size(); i++) {
+ if (axes_set.empty() || axes_set.count(i) > 0) {
+ scale_in_policy = std::min(scale_in_policy, scales[i]);
+ }
+ }
+ } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy::NOT_SMALLER) {
+ scale_in_policy = std::numeric_limits<float>::min();
+
+ for (size_t i = 0; i < scales.size(); i++) {
+ if (axes_set.empty() || axes_set.count(i) > 0) {
+ scale_in_policy = std::max(scale_in_policy, scales[i]);
+ }
+ }
+ }
+
+ for (size_t i = 0; i < scales.size(); i++) {
+ // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
+ if (axes_set.empty() || axes_set.count(i) > 0) {
+ scales[i] = scale_in_policy;
+ output_dims[i] = static_cast<int64_t>(std::round(scales[i] * input_dims[i]));
+ } else {
+ scales[i] = 1.0f;
+ output_dims[i] = input_dims[i];
+ }
+ }
+}
+
template <typename T>
void UpsampleNearest2x(int64_t batch_size,
int64_t num_channels,
@@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim,
const TensorShape& input_shape,
const TensorShape& output_shape,
const std::vector<int64_t>& input_dim_factor,
- const vector<float>& scales,
- const vector<float>& roi,
+ gsl::span<const float> scales,
+ gsl::span<const float> roi,
bool extrapolation_enabled,
const GetOriginalCoordinateFunc& get_original_coordinate,
const GetNearestPixelFunc& get_nearest_pixel) {
@@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
- const vector<float>& scales,
- const vector<float>& roi,
+ gsl::span<const float> scales,
+ gsl::span<const float> roi,
bool extrapolation_enabled,
const T extrapolation_value,
const GetOriginalCoordinateFunc& get_original_coordinate,
@@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
- const vector<float>& scales,
- const vector<float>& roi,
+ gsl::span<const float> scales,
+ gsl::span<const float> roi,
bool is_resize,
bool extrapolation_enabled,
T extrapolation_value,
@@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw) {
@@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw) {
@@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
float depth_scale,
float height_scale,
float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate) {
TrilinearParams p;
@@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size,
float depth_scale,
float height_scale,
float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
bool use_extrapolation,
float extrapolation_value,
const T* XdataBase,
@@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size,
bool use_extrapolation,
float extrapolation_value,
bool exclude_outside,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const T* Xdata,
T* Ydata,
const GetOriginalCoordinateFunc& get_original_coordinate) {
@@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size,
template <typename T>
Status Upsample<T>::BaseCompute(OpKernelContext* context,
- const std::vector<float>& roi,
- const std::vector<float>& scales,
- const gsl::span<const int64_t>& output_dims) const {
+ gsl::span<const float> roi,
+ gsl::span<const float> scales,
+ gsl::span<const int64_t> output_dims) const {
const auto* X = context->Input<Tensor>(0);
auto dims = X->Shape().GetDims();
ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same.");
@@ -1327,7 +1372,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
// Initialize the roi array to all zeros as this will be the most common case
// Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize
// for all other cases we need a 0 initialized roi array
- std::vector<float> roi_array(roi_);
+ InlinedVector<float> roi_array(roi_);
if (!roi_cached_) {
bool use_default_roi = true;
@@ -1353,7 +1398,7 @@ Status Upsample::Compute(OpKernelContext* context) const {
ComputeROIWithAxes(roi_array, input_dims.size());
// Get scales data
- std::vector<float> scales_array(input_dims.size());
+ InlinedVector<float> scales_array(input_dims.size());
if (OpKernel::Node().InputDefs().size() == 1) {
// Compute output shape from scales and input dims
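
For reference, AdjustOutputSizeAsPolicy implements the opset-18 keep_aspect_ratio_policy attribute: with not_larger the smallest requested scale is applied to every selected axis, with not_smaller the largest one is, and stretch leaves the per-axis scales alone. A simplified standalone sketch of that policy (plain std types, axes-subset handling omitted, so not the ONNX Runtime implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

enum class Policy { Stretch, NotLarger, NotSmaller };

// Pick one common scale across all axes and re-derive the output dims from it.
void AdjustOutputSize(std::vector<int64_t>& out_dims,
                      const std::vector<int64_t>& in_dims,
                      std::vector<float>& scales, Policy policy) {
  if (policy == Policy::Stretch) return;  // output dims already follow per-axis scales

  const float chosen = (policy == Policy::NotLarger)
                           ? *std::min_element(scales.begin(), scales.end())
                           : *std::max_element(scales.begin(), scales.end());

  for (size_t i = 0; i < scales.size(); ++i) {
    scales[i] = chosen;
    out_dims[i] = static_cast<int64_t>(std::round(chosen * in_dims[i]));
  }
}

int main() {
  std::vector<int64_t> in{4, 8}, out{8, 12};
  std::vector<float> scales{2.0f, 1.5f};
  AdjustOutputSize(out, in, scales, Policy::NotLarger);
  std::cout << out[0] << "x" << out[1] << '\n';  // 6x12: both axes use scale 1.5
}

With scales {2.0, 1.5} and not_larger, both axes end up using 1.5, so a 4x8 input maps to a 6x12 output; stretch would keep the per-axis scales and produce 8x12.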
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h
index 3046ee4b8260d..8ff04781f6ad0 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsample.h
@@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel {
Status Compute(OpKernelContext* context) const override;
- Status BaseCompute(OpKernelContext* context, const std::vector<float>& roi, const std::vector<float>& scales,
- const gsl::span<const int64_t>& output_dims) const;
+ Status BaseCompute(OpKernelContext* context, gsl::span<const float> roi, gsl::span<const float> scales,
+ gsl::span<const int64_t> output_dims) const;
};
BilinearParams SetupUpsampleBilinear(const int32_t input_height,
@@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw);
@@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const bool use_extrapolation,
const float extrapolation_value,
const T* const XdataBase,
@@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const float extrapolation_value,
const T* const XdataBase,
T* const YdataBase,
@@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
const bool is_nchw);
@@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size,
const int32_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const float extrapolation_value,
const T* const XdataBase,
T* const YdataBase,
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
index e1dcaf500a325..1e32b7e874b1a 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
@@ -21,32 +21,6 @@
namespace onnxruntime {
-namespace ConstValue {
-constexpr int32_t mag_factor = 1 << (22 - 1);
-}
-
-namespace {
-const uint8_t* GetLookupTableShared() {
- // initialized once
- static const auto* lookup_table = []() {
- // if we have already initialized the lookup table, just return
- // ideally we could have a global lookup table, but that account for too much space.
- /* Handles values form -640 to 639. */
- static uint8_t table[1280] = {0};
-
- // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94
- // we need to handle negative values
- // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639]
- // we will accept a negative x for (&table[640])[x] means table +640 -x
- for (int i = 0; i < 1280; ++i) {
- table[i] = static_cast<uint8_t>(std::min(std::max(i - 640, 0), 255));
- }
- return table;
- }();
- return lookup_table;
-}
-} // namespace
-
template <typename T>
struct FilterParamsBaseAntiAlias {
std::vector<int64_t> bound;
@@ -57,15 +31,15 @@ struct FilterParamsBaseAntiAlias {
template <typename T>
struct FilterParamsAntiAlias {
- float support_size = 2.0f;
- float cubic_coeff_a = -0.75f;
+ float support_size = antialias_constants::kSupportSize;
+ float cubic_coeff_a = antialias_constants::kCubicCoeffA;
FilterParamsBaseAntiAlias<T> dim_x;
FilterParamsBaseAntiAlias<T> dim_y;
FilterParamsBaseAntiAlias<T> dim_z;
const uint8_t* GetClip8LookupTable() const {
- return GetLookupTableShared();
+ return UpsampleBase::GetLookupTableShared();
}
virtual ~FilterParamsAntiAlias() = default;
virtual float Filter(float x) const = 0;
@@ -89,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias<T> {
template <typename T>
struct BiCubicParamsAntiAlias : FilterParamsAntiAlias<T> {
BiCubicParamsAntiAlias() {
- this->support_size = 4.0f;
+ this->support_size = antialias_constants::kBiCubicSupportSize;
}
// taken from
@@ -124,27 +98,6 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias<T> {
}
};
-template <typename T>
-struct AccumulateType {
- using type = int32_t;
- using Dtype = T;
-};
-
-template <>
-struct AccumulateType<int32_t> {
- using type = float;
-};
-
-template <>
-struct AccumulateType<float> {
- using type = float;
-};
-
-template <>
-struct AccumulateType<double> {
- using type = double;
-};
-
// The following method supports a 3/4/5-D input in 'Linear mode, cubic mode'
// that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes
// A N-D tensor has
@@ -156,19 +109,20 @@ struct AccumulateType {
// - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0]
template <typename T>
void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
- const gsl::span<int64_t> input_h_w_c,
- const gsl::span<int64_t> output_h_w_c,
- const gsl::span<float> scale_h_w_c,
- const std::vector<float>& roi,
+ gsl::span<int64_t> input_h_w_c,
+ gsl::span<int64_t> output_h_w_c,
+ gsl::span<float> scale_h_w_c,
+ gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
bool exclude_outside, const bool is_nchw) {
- auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias& p,
- const int64_t input_size,
- const int64_t output_size,
- size_t rindex,
- FilterParamsBaseAntiAlias& param_base,
- const float rscale) -> int64_t {
+ auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside](
+ const FilterParamsAntiAlias& p,
+ const int64_t input_size,
+ const int64_t output_size,
+ size_t rindex,
+ FilterParamsBaseAntiAlias& param_base,
+ const float rscale) -> int64_t {
param_base.bound.reserve(static_cast<size_t>(output_size) * 2);
param_base.out_of_bound_idx.reserve(static_cast<size_t>(output_size));
@@ -245,13 +199,14 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
// normalize the scale to 1 << 22 for int8/uint8
if constexpr (std::is_same<T, int32_t>::value) {
- scale_buffer_int[x] = static_cast<int32_t>(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f));
+ scale_buffer_int[x] = static_cast<int32_t>(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2));
}
}
/*for (; x < window_size; x++) {
scale_buffer[x] = 0;
}*/
}
+
return window_size;
};
@@ -269,9 +224,6 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
}
}
-template <typename T>
-inline constexpr bool is_8bit_v = std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-
/**
* @brief To compute interpolation along with the last axis.
* For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim.
@@ -398,6 +350,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
output += *Xdata_offset * (*weight_coeff_start++);
Xdata_offset += output_width;
}
+
if constexpr (is_8bit_v<T>) {
*Ydata_offset++ = static_cast<T>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<T, int32_t>::value) {
@@ -444,6 +397,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
output += *Xdata_offset * (*weight_coeff_start++);
Xdata_offset += output_width;
}
+
if constexpr (is_8bit_v<T>) {
*Ydata_offset++ = static_cast<T>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<T, int32_t>::value) {
@@ -515,6 +469,7 @@ void UpsampleBaseAntiAlias(FilterParamsAntiAlias<T>& p,
narrow<size_t>(input_height * num_channels * input_width));
auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow<size_t>(input_height * num_channels * output_width));
+ // This computes only the width direction, so the height stays unchanged.
ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width,
xdata_span, ydata_span, p, p.dim_x, tp);
}
@@ -546,7 +501,7 @@ void UpsampleBilinearAntiAlias(const int64_t batch_size,
const int64_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const bool use_extrapolation,
const float extrapolation_value,
bool exclude_outside,
@@ -575,7 +530,7 @@ void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size,
const int64_t output_width,
const float height_scale,
const float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const bool use_extrapolation,
const float extrapolation_value,
bool exclude_outside,
@@ -608,7 +563,7 @@ void NhwcResizeBiCubicAntiAlias(const int64_t batch_size,
bool use_extrapolation,
float extrapolation_value,
bool exclude_outside,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const Tensor* X,
T* Ydata_base,
AllocatorPtr& alloc,
@@ -688,7 +643,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size,
bool use_extrapolation,
float extrapolation_value,
bool exclude_outside,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
const Tensor* X,
T* Ydata_base,
AllocatorPtr& alloc,
@@ -719,7 +674,7 @@ void UpsampleTrilinearAntiAlias(int64_t batch_size,
float depth_scale,
float height_scale,
float width_scale,
- const std::vector<float>& roi,
+ gsl::span<const float> roi,
bool use_extrapolation,
float extrapolation_value,
bool exclude_outside,
diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
index a0e7ca1084fef..b768fedd8513a 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
@@ -3,11 +3,13 @@
#pragma once
+#include
#include
#include
#include
#include
-#include
+
+#include
#include "core/common/status.h"
#include
#include
@@ -58,7 +60,73 @@ enum class AspectRatioPolicy {
NOT_SMALLER,
};
+// Antialias types
+template <typename T>
+struct AccumulateType {
+ using type = int32_t;
+ using Dtype = T;
+};
+
+template <>
+struct AccumulateType<int32_t> {
+ using type = float;
+};
+
+template <>
+struct AccumulateType<float> {
+ using type = float;
+};
+
+template <>
+struct AccumulateType<MLFloat16> {
+ using type = float;
+};
+
+template <>
+struct AccumulateType<double> {
+ using type = double;
+};
+
+namespace antialias_constants {
+constexpr float kCubicCoeffA = -0.75f;
+constexpr float kSupportSize = 2.0f;
+constexpr float kBiCubicSupportSize = 4.0f;
+} // namespace antialias_constants
+
+namespace ConstValue {
+constexpr int32_t mag_factor = 1 << (22 - 1);
+// We used to multiply by 2, so make a constant which is twice as big
+constexpr int32_t mag_factor_x_2 = 1 << 22;
+} // namespace ConstValue
+
+template <typename T>
+inline constexpr bool is_8bit_v = std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+
+template <typename T>
+void PrintAntiAliasBuffers(std::ostream& os, gsl::span<int64_t> bounds, gsl::span<int64_t> out_of_bounds,
+ gsl::span<T> weight_coefficients) {
+ os << "#### Bounds: ";
+ std::copy(bounds.begin(), bounds.end(), std::ostream_iterator<int64_t>(os, " "));
+ os << std::endl;
+
+ os << "#### Out of Bounds: ";
+ std::copy(out_of_bounds.begin(), out_of_bounds.end(),
+ std::ostream_iterator<int64_t>(os, " "));
+ os << std::endl;
+
+ os << "#### Scale Buffer: ";
+ std::copy(weight_coefficients.begin(), weight_coefficients.end(),
+ std::ostream_iterator<T>(os, " "));
+ os << std::endl;
+}
+
class UpsampleBase {
+ public:
+ // Make this available in other EP via provider bridge
+ // it works iff output_shape is specified
+ void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
+ InlinedVector<float>& scales) const;
+
protected:
explicit UpsampleBase(const OpKernelInfo& info)
: scales_cached_(false), roi_cached_(false), use_extrapolation_(false) {
@@ -69,23 +137,32 @@ class UpsampleBase {
std::string mode;
ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK());
mode_ = StringToUpsampleMode(mode);
- antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true;
- if (antialias_) {
- ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_),
- "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`.");
- }
auto input_count = info.GetInputCount();
if (input_count == 1) { // opset < 10
- ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales_));
- ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_));
+ std::vector scales;
+ ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales));
+ ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_));
+ scales_.assign(scales.cbegin(), scales.cend());
scales_cached_ = true;
}
- std::string keep_aspect_ratio_policy = info.GetAttrOrDefault<std::string>("keep_aspect_ratio_policy", "stretch");
- keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy);
+ if (opset >= 18) {
+ antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true;
+
+ if (antialias_) {
+ ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_),
+ "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`.");
+ }
- axes_ = info.GetAttrsOrDefault("axes");
+ // The attribute is absent in opset < 18, but the default value as if stretch.
+ std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch");
+ keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy);
+
+ // guard against unit tests that can add an attribute
+ auto axes = info.GetAttrsOrDefault("axes");
+ axes_.assign(axes.cbegin(), axes.cend());
+ }
extrapolation_value_ = info.GetAttrOrDefault<float>("extrapolation_value", 0.0f);
@@ -112,7 +189,7 @@ class UpsampleBase {
nearest_mode_ = StringToNearestMode(nearest_mode_name);
get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_);
- cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", -0.75f);
+ cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA);
exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true;
if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) {
@@ -166,7 +243,7 @@ class UpsampleBase {
ResizeCoordinateTransformationMode coordinate_transform_mode_;
GetOriginalCoordinateFunc get_original_coordinate_;
ResizeNearestMode nearest_mode_;
- AspectRatioPolicy keep_aspect_ratio_policy_;
+ AspectRatioPolicy keep_aspect_ratio_policy_{AspectRatioPolicy::STRETCH};
GetNearestPixelFunc get_nearest_pixel_;
float cubic_coeff_a_;
bool exclude_outside_;
@@ -174,9 +251,9 @@ class UpsampleBase {
float extrapolation_value_;
bool use_nearest2x_optimization_ = false;
- std::vector<float> scales_;
- std::vector<float> roi_;
- std::vector<int64_t> axes_;
+ InlinedVector<float> scales_;
+ InlinedVector<float> roi_;
+ TensorShapeVector axes_;
bool scales_cached_;
bool roi_cached_;
@@ -335,7 +412,7 @@ class UpsampleBase {
}
}
- [[nodiscard]] Status ScalesValidation(const std::vector<float>& scales, const UpsampleMode mode) const {
+ [[nodiscard]] Status ScalesValidation(gsl::span<const float> scales, const UpsampleMode mode) const {
if (!is_resize_) {
for (auto& scale : scales) {
ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1.");
@@ -372,7 +449,7 @@ class UpsampleBase {
}
[[nodiscard]] Status
- ParseScalesData(const Tensor* scale, std::vector<float>& scales, int64_t rank) const {
+ ParseScalesData(const Tensor* scale, InlinedVector<float>& scales, int64_t rank) const {
const auto* scale_data = scale->Data<float>();
int64_t scales_size = scale->Shape().Size();
ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0.");
@@ -387,19 +464,19 @@ class UpsampleBase {
// in which case the other axes is ignored and use default scale of 1
// scales_size == axes_.size() should be guaranteed if axes is not empty
if (rank > 0 && (scales_size != rank || axes_.size())) {
- std::vector<float> new_scales(size_t(rank), 1.0f);
+ InlinedVector<float> new_scales(size_t(rank), 1.0f);
ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size),
"all values in axes should be less than rank of the data");
for (size_t i = 0; i < axes_.size(); i++) {
new_scales[static_cast(axes_[i])] = scales[i];
}
- scales = new_scales;
+ scales.swap(new_scales);
}
return ScalesValidation(scales, mode_);
}
- void ParseRoiData(const Tensor* roi, std::vector<float>& roi_array) const {
+ void ParseRoiData(const Tensor* roi, InlinedVector<float>& roi_array) const {
int64_t roi_size = roi->Shape().Size();
if (roi_size > 0) {
roi_array.resize(onnxruntime::narrow<size_t>(roi_size));
@@ -429,52 +506,11 @@ class UpsampleBase {
return Status::OK();
}
- // it works iff output_shape is specified
- void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
- std::vector<float>& scales) const {
- std::unordered_set<int64_t> axes_set(axes_.begin(), axes_.end());
-
- // AspectRatioPolicy::STRETCH is default policy when opset < 18
- if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) {
- return;
- }
-
- float scale_in_policy = 0.0f;
- if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) {
- scale_in_policy = std::numeric_limits<float>::max();
-
- for (size_t i = 0; i < scales.size(); i++) {
- if (axes_set.empty() || axes_set.count(i) > 0) {
- scale_in_policy = std::min(scale_in_policy, scales[i]);
- }
- }
- } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) {
- scale_in_policy = std::numeric_limits<float>::min();
-
- for (size_t i = 0; i < scales.size(); i++) {
- if (axes_set.empty() || axes_set.count(i) > 0) {
- scale_in_policy = std::max(scale_in_policy, scales[i]);
- }
- }
- }
-
- for (size_t i = 0; i < scales.size(); i++) {
- // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
- if (axes_set.empty() || axes_set.count(i) > 0) {
- scales[i] = scale_in_policy;
- output_dims[i] = static_cast<int64_t>(std::round(scales[i] * input_dims[i]));
- } else {
- scales[i] = 1.0f;
- output_dims[i] = input_dims[i];
- }
- }
- }
-
// It's different in Opset 18 and before.
// we will modify output_shape by sorts of policy even if it's specified
[[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims,
gsl::span input_dims,
- std::vector<float>& scales) const {
+ InlinedVector<float>& scales) const {
for (size_t i = 0, end = input_dims.size(); i < end; ++i) {
// Handle corner case to avoid dividing by zero in the next step
if (input_dims[i] == 0) {
@@ -507,9 +543,9 @@ class UpsampleBase {
// Roi is redefined in Opset-18, we have a concept of axes.
// So we need to update it accordingly.
- void ComputeROIWithAxes(std::vector<float>& roi_array, size_t rank) const {
+ void ComputeROIWithAxes(InlinedVector<float>& roi_array, size_t rank) const {
if (axes_.size()) {
- std::vector<float> roi_tmp(rank * 2, 0);
+ InlinedVector<float> roi_tmp(rank * 2, 0);
for (size_t i = rank; i < rank * 2; ++i) {
roi_tmp[i] = 1;
}
@@ -518,9 +554,32 @@ class UpsampleBase {
roi_tmp[v_in_axes] = (roi_array[i]);
roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]);
}
- roi_array = roi_tmp;
+ roi_array.swap(roi_tmp);
}
}
+
+ public:
+ static constexpr size_t kLookupTableSize = 1280;
+
+ static const uint8_t* GetLookupTableShared() {
+ // initialized once
+ static const auto* lookup_table = []() {
+ // if we have already initialized the lookup table, just return
+ // ideally we could have a global lookup table, but that account for too much space.
+ /* Handles values form -640 to 639. */
+ static uint8_t table[kLookupTableSize] = {0};
+
+ // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94
+ // we need to handle negative values
+ // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639]
+ // we will accept a negative x for (&table[640])[x] means table +640 -x
+ for (int i = 0; i < static_cast<int>(kLookupTableSize); ++i) {
+ table[i] = static_cast<uint8_t>(std::min(std::max(i - 640, 0), 255));
+ }
+ return table;
+ }();
+ return lookup_table;
+ }
}; // UpsampleBase
} // namespace onnxruntime
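
The constants and lookup table moved into UpsampleBase above implement the fixed-point path used for uint8/int8 antialiasing: filter weights are quantized to 22-bit fixed point (mag_factor_x_2), the accumulator is seeded with mag_factor (half of that) so the final right shift rounds to nearest, and the clip8 table clamps the shifted value to [0, 255] while tolerating the small overshoot a sharp filter can produce. A small self-contained sketch of the arithmetic (simplified, not the ONNX Runtime code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int32_t kMagFactor = 1 << 21;    // rounding bias, added once to the accumulator
constexpr int32_t kMagFactorX2 = 1 << 22;  // weight scale

// Clamp a post-shift value to [0, 255]; the real code does this with a
// 1280-entry table indexed at an offset of 640 so negative overshoot is handled without a branch.
static uint8_t Clip8(int32_t v) {
  return static_cast<uint8_t>(std::min(std::max(v, 0), 255));
}

int main() {
  // Interpolate three uint8 pixels with float weights that sum to ~1.
  std::vector<uint8_t> px{10, 200, 30};
  std::vector<float> w{0.25f, 0.5f, 0.25f};

  int32_t acc = kMagFactor;  // start with the rounding bias
  for (size_t i = 0; i < px.size(); ++i) {
    int32_t wi = static_cast<int32_t>(std::round(w[i] * kMagFactorX2));  // quantize weight
    acc += px[i] * wi;
  }
  std::cout << static_cast<int>(Clip8(acc >> 22)) << '\n';  // 110, matching 0.25*10 + 0.5*200 + 0.25*30
}

The CUDA kernels added later in this patch reuse exactly this scheme through GetLookupTableShared() and ConstValue::mag_factor.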
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index 0d9928baa86e0..66794f88d8670 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -194,13 +194,13 @@ template <>
__device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); }
template <typename T>
-__device__ __inline__ T _Floor(T a);
+__device__ __host__ __inline__ T _Floor(T a);
template <>
-__device__ __inline__ float _Floor(float a) { return floorf(a); }
+__device__ __host__ __inline__ float _Floor(float a) { return floorf(a); }
template <>
-__device__ __inline__ double _Floor(double a) { return floor(a); }
+__device__ __host__ __inline__ double _Floor(double a) { return floor(a); }
template <>
__device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); }
@@ -230,13 +230,13 @@ template <>
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
template <typename T>
-__device__ __inline__ T _Round(T a);
+__device__ __host__ __inline__ T _Round(T a);
template <>
-__device__ __inline__ float _Round(float a) { return rintf(a); }
+__device__ __host__ __inline__ float _Round(float a) { return rintf(a); }
template <>
-__device__ __inline__ double _Round(double a) { return rint(a); }
+__device__ __host__ __inline__ double _Round(double a) { return rint(a); }
template <>
__device__ __inline__ half _Round(half a) {
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 00783bcbc2665..1ce089fd93044 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1109,11 +1109,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten);
@@ -1277,6 +1277,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize);
// Opset 19
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
@@ -2009,11 +2014,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2176,6 +2181,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize)>,
// Opset 19
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc
index 764172a8d1fac..97d4eb71e970a 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize.cc
+++ b/onnxruntime/core/providers/cuda/tensor/resize.cc
@@ -28,10 +28,22 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType()), \
Resize); \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Resize, \
+ kOnnxDomain, \
+ 13, 17, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \
+ Resize); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Resize, \
kOnnxDomain, \
- 13, \
+ 18, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
new file mode 100644
index 0000000000000..56b7c3f499303
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
@@ -0,0 +1,1179 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/tensor/resize_impl.h"
+
+#define FUNC_DEF __device__
+
+namespace onnxruntime {
+namespace cuda {
+
+using onnxruntime::ResizeCoordinateTransformationMode;
+using onnxruntime::UpsampleMode;
+
+/// <summary>
+/// Compute a buffer for bilinear data for CUDA antialias resizing.
+/// </summary>
+static std::tuple<int64_t, int64_t> ComputeBilinearScaleBufferSize(
+ int64_t output_height, int64_t output_width,
+ float height_rscale, float width_rscale,
+ float support_value,
+ float& scaled_support_height, float& scaled_support_width,
+ int32_t& window_size_height, int32_t& window_size_width) {
+ scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale);
+ scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale);
+ window_size_height = ComputeWindowSize(scaled_support_height);
+ window_size_width = ComputeWindowSize(scaled_support_width);
+
+ auto height_buffer_size = ComputeWeightedCoeffBufferSize(output_height, window_size_height);
+ auto width_buffer_size = ComputeWeightedCoeffBufferSize(output_width, window_size_width);
+
+ return std::make_tuple(height_buffer_size, width_buffer_size);
+}
+
+/// <summary>
+/// Compute a buffer for trilinear data for CUDA antialias resizing.
+/// </summary>
+static std::tuple<int64_t, int64_t, int64_t> ComputeTrilinearScaleBufferSize(
+ int64_t output_depth, int64_t output_height, int64_t output_width,
+ float depth_rscale, float height_rscale, float width_rscale,
+ float support_value,
+ float& scaled_support_depth, float& scaled_support_height,
+ float& scaled_support_width, int32_t& window_size_depth,
+ int32_t& window_size_height, int32_t& window_size_width) {
+ scaled_support_depth = ComputeScaledSupportValue(support_value, depth_rscale);
+ window_size_depth = ComputeWindowSize(scaled_support_depth);
+ auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth);
+
+ const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height,
+ output_width, height_rscale,
+ width_rscale, support_value,
+ scaled_support_height,
+ scaled_support_width,
+ window_size_height, window_size_width);
+ return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size);
+}
+
+// Antialiasing filters
+struct BilinearFilter {
+ __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const {
+ if (x < 0.0f) {
+ x = -x;
+ }
+ if (x < 1.0f) {
+ return 1.0f - x;
+ }
+ return 0.0f;
+ }
+};
+
+struct BiCubicFilter {
+ __device__ __host__ float operator()(float x, float cubic_coeff_a) const {
+ /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+ */
+ if (x < 0.0f) {
+ x = -x;
+ }
+ if (x < 1.0f) {
+ return ((cubic_coeff_a + 2.0f) * x - (cubic_coeff_a + 3.0f)) * x * x + 1;
+ }
+ if (x < 2.0f) {
+ return (((x - 5.0f) * x + 8.f) * x - 4.f) * cubic_coeff_a;
+ }
+ return 0.0f;
+ }
+};
+
+struct TriLinearFilter {
+ __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const {
+ if (x < 0.0f) {
+ x = -x;
+ }
+ if (x < 1.0f) {
+ return 1.0f - x;
+ }
+ return 0.0f;
+ }
+};
+
+template <typename AccumType>
+struct AccumTypeCaster {
+ static __device__ __host__ AccumType* cast(AccumType* p) {
+ return p;
+ }
+};
+
+template <>
+struct AccumTypeCaster<int32_t> {
+ static __device__ __host__ float* cast(int32_t* p) {
+ return reinterpret_cast<float*>(p);
+ }
+};
+
+template <typename T, typename AccumType>
+__global__ void _ComputeInterpolationAtLevel1(
+ int64_t num_channels,
+ int64_t input_height, int64_t input_width,
+ int64_t output_height, int64_t output_width,
+ const fast_divmod div_output_width,
+ const fast_divmod div_output_image,
+ int32_t window_size,
+ const uint8_t* clip8_table,
+ const int64_t* bound_data,
+ std::tuple outof_bounds_buffers,
+ const AccumType* weight_coefficients,
+ const T* Xdata, T* Ydata,
+ const int N) {
+ CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+
+ // No need to do scale
+ if (output_width == input_width) {
+ Ydata[id] = Xdata[id];
+ return;
+ }
+
+ int bxc, output_image_index;
+ div_output_image.divmod(id, bxc, output_image_index);
+
+ int output_y, output_x;
+ div_output_width.divmod(output_image_index, output_y, output_x);
+
+ CUDA_LONG input_index = static_cast<CUDA_LONG>(bxc * num_channels * input_height * input_width);
+ CUDA_LONG output_index = static_cast<CUDA_LONG>(bxc * num_channels * output_height * output_width);
+
+ auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x;
+ const auto* bound = bound_data;
+
+ AccumType output = onnxruntime::is_8bit_v<T> ? ConstValue::mag_factor : 0;
+
+ const auto* weight_coeff = weight_coefficients + window_size * output_x;
+ int64_t xmin = bound[static_cast(output_x) * 2];
+ int64_t xmax = bound[static_cast(output_x) * 2 + 1];
+
+ // Input window
+ const auto* Xdata_offset = Xdata + input_index + input_width * output_y + xmin;
+
+ for (; xmin < xmax; ++xmin) {
+ if constexpr (std::is_same<T, half>::value) {
+ // This cast is needed when we deal with half
+ output += static_cast<AccumType>((*Xdata_offset++)) * (*weight_coeff++);
+ } else {
+ output += (*Xdata_offset++) * (*weight_coeff++);
+ }
+ }
+
+ if constexpr (onnxruntime::is_8bit_v<T>) {
+ const uint8_t* clip8_lookups = &clip8_table[640];
+ *Ydata_offset = static_cast<T>(clip8_lookups[output >> 22]);
+ } else if constexpr (std::is_same<T, int32_t>::value) {
+ *Ydata_offset = static_cast<int32_t>(std::round(output));
+ } else {
+ *Ydata_offset = static_cast