diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b0ed68d595c42..1eaf0fb6dad76 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -734,7 +734,8 @@ Do not modify directly.* |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 48e4617b33b4d..37e7e42150413 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2008,8 +2008,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Greater)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo namespace onnxruntime { // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor." @@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const override { + p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + } + #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index f33eec4b93e98..c0e674827e4d1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareF class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +class UpsampleBase; using PadsVector = InlinedVector; @@ -202,6 +203,10 @@ struct ProviderHostCPU { virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const = 0; + #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index fa69e144be554..babbac0b7be17 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -1,10 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/cpu/tensor/upsample.h" + +#include + +#include "core/common/inlined_containers.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/tensor/upsample.h" #include "core/providers/cpu/tensor/upsample_antialias.h" + using namespace onnxruntime::common; using namespace std; using onnxruntime::narrow; @@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + // AspectRatioPolicy::STRETCH is default policy when opset < 18 + if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) { + return; + } + + InlinedHashSet axes_set(axes_.begin(), axes_.end()); + + float scale_in_policy = 0.0f; + if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { + scale_in_policy = std::numeric_limits::max(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::min(scale_in_policy, scales[i]); + } + } + } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { + scale_in_policy = std::numeric_limits::min(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::max(scale_in_policy, scales[i]); + } + } + } + + for (size_t i = 0; i < scales.size(); i++) { + // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes + if (axes_set.empty() || axes_set.count(i) > 0) { + scales[i] = scale_in_policy; + output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); + } else { + scales[i] = 1.0f; + output_dims[i] = input_dims[i]; + } + } +} + template void UpsampleNearest2x(int64_t batch_size, int64_t num_channels, @@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim, const TensorShape& input_shape, const TensorShape& output_shape, const std::vector& input_dim_factor, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const GetOriginalCoordinateFunc& get_original_coordinate, const GetNearestPixelFunc& get_nearest_pixel) { @@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const T extrapolation_value, const GetOriginalCoordinateFunc& get_original_coordinate, @@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool is_resize, bool extrapolation_enabled, T extrapolation_value, @@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw) { @@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw) { @@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate) { TrilinearParams p; @@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, const T* XdataBase, @@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const T* Xdata, T* Ydata, const GetOriginalCoordinateFunc& get_original_coordinate) { @@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size, template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const auto* X = context->Input(0); auto dims = X->Shape().GetDims(); ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same."); @@ -1327,7 +1372,7 @@ Status Upsample::Compute(OpKernelContext* context) const { // Initialize the roi array to all zeros as this will be the most common case // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize // for all other cases we need a 0 initialized roi array - std::vector roi_array(roi_); + InlinedVector roi_array(roi_); if (!roi_cached_) { bool use_default_roi = true; @@ -1353,7 +1398,7 @@ Status Upsample::Compute(OpKernelContext* context) const { ComputeROIWithAxes(roi_array, input_dims.size()); // Get scales data - std::vector scales_array(input_dims.size()); + InlinedVector scales_array(input_dims.size()); if (OpKernel::Node().InputDefs().size() == 1) { // Compute output shape from scales and input dims diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h index 3046ee4b8260d..8ff04781f6ad0 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample.h @@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel { Status Compute(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; }; BilinearParams SetupUpsampleBilinear(const int32_t input_height, @@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, const T* const XdataBase, @@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, const T* const XdataBase, T* const YdataBase, @@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, const T* const XdataBase, T* const YdataBase, diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index e1dcaf500a325..1e32b7e874b1a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -21,32 +21,6 @@ namespace onnxruntime { -namespace ConstValue { -constexpr int32_t mag_factor = 1 << (22 - 1); -} - -namespace { -const uint8_t* GetLookupTableShared() { - // initialized once - static const auto* lookup_table = []() { - // if we have already initialized the lookup table, just return - // ideally we could have a global lookup table, but that account for too much space. - /* Handles values form -640 to 639. */ - static uint8_t table[1280] = {0}; - - // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 - // we need to handle negative values - // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] - // we will accept a negative x for (&table[640])[x] means table +640 -x - for (int i = 0; i < 1280; ++i) { - table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); - } - return table; - }(); - return lookup_table; -} -} // namespace - template struct FilterParamsBaseAntiAlias { std::vector bound; @@ -57,15 +31,15 @@ struct FilterParamsBaseAntiAlias { template struct FilterParamsAntiAlias { - float support_size = 2.0f; - float cubic_coeff_a = -0.75f; + float support_size = antialias_constants::kSupportSize; + float cubic_coeff_a = antialias_constants::kCubicCoeffA; FilterParamsBaseAntiAlias dim_x; FilterParamsBaseAntiAlias dim_y; FilterParamsBaseAntiAlias dim_z; const uint8_t* GetClip8LookupTable() const { - return GetLookupTableShared(); + return UpsampleBase::GetLookupTableShared(); } virtual ~FilterParamsAntiAlias() = default; virtual float Filter(float x) const = 0; @@ -89,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias { template struct BiCubicParamsAntiAlias : FilterParamsAntiAlias { BiCubicParamsAntiAlias() { - this->support_size = 4.0f; + this->support_size = antialias_constants::kBiCubicSupportSize; } // taken from @@ -124,27 +98,6 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias { } }; -template -struct AccumulateType { - using type = int32_t; - using Dtype = T; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = double; -}; - // The following method supports a 3/4/5-D input in 'Linear mode, cubic mode' // that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes // A N-D tensor has @@ -156,19 +109,20 @@ struct AccumulateType { // - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0] template void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, - const gsl::span input_h_w_c, - const gsl::span output_h_w_c, - const gsl::span scale_h_w_c, - const std::vector& roi, + gsl::span input_h_w_c, + gsl::span output_h_w_c, + gsl::span scale_h_w_c, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, bool exclude_outside, const bool is_nchw) { - auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias& p, - const int64_t input_size, - const int64_t output_size, - size_t rindex, - FilterParamsBaseAntiAlias& param_base, - const float rscale) -> int64_t { + auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside]( + const FilterParamsAntiAlias& p, + const int64_t input_size, + const int64_t output_size, + size_t rindex, + FilterParamsBaseAntiAlias& param_base, + const float rscale) -> int64_t { param_base.bound.reserve(static_cast(output_size) * 2); param_base.out_of_bound_idx.reserve(static_cast(output_size)); @@ -245,13 +199,14 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, // normalize the scale to 1 << 22 for int8/uint8 if constexpr (std::is_same::value) { - scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f)); + scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2)); } } /*for (; x < window_size; x++) { scale_buffer[x] = 0; }*/ } + return window_size; }; @@ -269,9 +224,6 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, } } -template -inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; - /** * @brief To compute interpolation along with the last axis. * For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim. @@ -398,6 +350,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -444,6 +397,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -515,6 +469,7 @@ void UpsampleBaseAntiAlias(FilterParamsAntiAlias& p, narrow(input_height * num_channels * input_width)); auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow(input_height * num_channels * output_width)); + // This computes only the width direction.Thus height keeps unchanged. ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width, xdata_span, ydata_span, p, p.dim_x, tp); } @@ -546,7 +501,7 @@ void UpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -575,7 +530,7 @@ void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -608,7 +563,7 @@ void NhwcResizeBiCubicAntiAlias(const int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -688,7 +643,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -719,7 +674,7 @@ void UpsampleTrilinearAntiAlias(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, bool exclude_outside, diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index a0e7ca1084fef..b768fedd8513a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -3,11 +3,13 @@ #pragma once +#include #include #include #include #include -#include + +#include #include "core/common/status.h" #include #include @@ -58,7 +60,73 @@ enum class AspectRatioPolicy { NOT_SMALLER, }; +// Antialias types +template +struct AccumulateType { + using type = int32_t; + using Dtype = T; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = double; +}; + +namespace antialias_constants { +constexpr float kCubicCoeffA = -0.75f; +constexpr float kSupportSize = 2.0f; +constexpr float kBiCubicSupportSize = 4.0f; +} // namespace antialias_constants + +namespace ConstValue { +constexpr int32_t mag_factor = 1 << (22 - 1); +// We use to multiply by 2, let's make a constant which is twice as big +constexpr int32_t mag_factor_x_2 = 1 << 22; +} // namespace ConstValue + +template +inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; + +template +void PrintAntiAliasBuffers(std::ostream& os, gsl::span bounds, gsl::span out_of_bounds, + gsl::span weight_coefficients) { + os << "#### Bounds: "; + std::copy(bounds.begin(), bounds.end(), std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Out of Bounds: "; + std::copy(out_of_bounds.begin(), out_of_bounds.end(), + std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Scale Buffer: "; + std::copy(weight_coefficients.begin(), weight_coefficients.end(), + std::ostream_iterator(os, " ")); + os << std::endl; +} + class UpsampleBase { + public: + // Make this available in other EP via provider bridge + // it works iff output_shape is specified + void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const; + protected: explicit UpsampleBase(const OpKernelInfo& info) : scales_cached_(false), roi_cached_(false), use_extrapolation_(false) { @@ -69,23 +137,32 @@ class UpsampleBase { std::string mode; ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); mode_ = StringToUpsampleMode(mode); - antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; - if (antialias_) { - ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), - "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); - } auto input_count = info.GetInputCount(); if (input_count == 1) { // opset < 10 - ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales_)); - ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_)); + std::vector scales; + ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales)); + ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_)); + scales_.assign(scales.cbegin(), scales.cend()); scales_cached_ = true; } - std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); - keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + if (opset >= 18) { + antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; + + if (antialias_) { + ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), + "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); + } - axes_ = info.GetAttrsOrDefault("axes"); + // The attribute is absent in opset < 18, but the default value as if stretch. + std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); + keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + + // guard against unit tests that can add an attribute + auto axes = info.GetAttrsOrDefault("axes"); + axes_.assign(axes.cbegin(), axes.cend()); + } extrapolation_value_ = info.GetAttrOrDefault("extrapolation_value", 0.0f); @@ -112,7 +189,7 @@ class UpsampleBase { nearest_mode_ = StringToNearestMode(nearest_mode_name); get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_); - cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", -0.75f); + cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA); exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true; if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) { @@ -166,7 +243,7 @@ class UpsampleBase { ResizeCoordinateTransformationMode coordinate_transform_mode_; GetOriginalCoordinateFunc get_original_coordinate_; ResizeNearestMode nearest_mode_; - AspectRatioPolicy keep_aspect_ratio_policy_; + AspectRatioPolicy keep_aspect_ratio_policy_{AspectRatioPolicy::STRETCH}; GetNearestPixelFunc get_nearest_pixel_; float cubic_coeff_a_; bool exclude_outside_; @@ -174,9 +251,9 @@ class UpsampleBase { float extrapolation_value_; bool use_nearest2x_optimization_ = false; - std::vector scales_; - std::vector roi_; - std::vector axes_; + InlinedVector scales_; + InlinedVector roi_; + TensorShapeVector axes_; bool scales_cached_; bool roi_cached_; @@ -335,7 +412,7 @@ class UpsampleBase { } } - [[nodiscard]] Status ScalesValidation(const std::vector& scales, const UpsampleMode mode) const { + [[nodiscard]] Status ScalesValidation(gsl::span scales, const UpsampleMode mode) const { if (!is_resize_) { for (auto& scale : scales) { ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1."); @@ -372,7 +449,7 @@ class UpsampleBase { } [[nodiscard]] Status - ParseScalesData(const Tensor* scale, std::vector& scales, int64_t rank) const { + ParseScalesData(const Tensor* scale, InlinedVector& scales, int64_t rank) const { const auto* scale_data = scale->Data(); int64_t scales_size = scale->Shape().Size(); ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0."); @@ -387,19 +464,19 @@ class UpsampleBase { // in which case the other axes is ignored and use default scale of 1 // scales_size == axes_.size() should be guaranteed if axes is not empty if (rank > 0 && (scales_size != rank || axes_.size())) { - std::vector new_scales(size_t(rank), 1.0f); + InlinedVector new_scales(size_t(rank), 1.0f); ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size), "all values in axes should be less than rank of the data"); for (size_t i = 0; i < axes_.size(); i++) { new_scales[static_cast(axes_[i])] = scales[i]; } - scales = new_scales; + scales.swap(new_scales); } return ScalesValidation(scales, mode_); } - void ParseRoiData(const Tensor* roi, std::vector& roi_array) const { + void ParseRoiData(const Tensor* roi, InlinedVector& roi_array) const { int64_t roi_size = roi->Shape().Size(); if (roi_size > 0) { roi_array.resize(onnxruntime::narrow(roi_size)); @@ -429,52 +506,11 @@ class UpsampleBase { return Status::OK(); } - // it works iff output_shape is specified - void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { - std::unordered_set axes_set(axes_.begin(), axes_.end()); - - // AspectRatioPolicy::STRETCH is default policy when opset < 18 - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) { - return; - } - - float scale_in_policy = 0.0f; - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { - scale_in_policy = std::numeric_limits::max(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::min(scale_in_policy, scales[i]); - } - } - } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { - scale_in_policy = std::numeric_limits::min(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::max(scale_in_policy, scales[i]); - } - } - } - - for (size_t i = 0; i < scales.size(); i++) { - // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes - if (axes_set.empty() || axes_set.count(i) > 0) { - scales[i] = scale_in_policy; - output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); - } else { - scales[i] = 1.0f; - output_dims[i] = input_dims[i]; - } - } - } - // It's different in Opset 18 and before. // we will modify output_shape by sorts of policy even if it's specified [[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { + InlinedVector& scales) const { for (size_t i = 0, end = input_dims.size(); i < end; ++i) { // Handle corner case to avoid dividing by zero in the next step if (input_dims[i] == 0) { @@ -507,9 +543,9 @@ class UpsampleBase { // Roi is redefined in Opset-18, we have a concept of axes. // So we need to update it accordingly. - void ComputeROIWithAxes(std::vector& roi_array, size_t rank) const { + void ComputeROIWithAxes(InlinedVector& roi_array, size_t rank) const { if (axes_.size()) { - std::vector roi_tmp(rank * 2, 0); + InlinedVector roi_tmp(rank * 2, 0); for (size_t i = rank; i < rank * 2; ++i) { roi_tmp[i] = 1; } @@ -518,9 +554,32 @@ class UpsampleBase { roi_tmp[v_in_axes] = (roi_array[i]); roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]); } - roi_array = roi_tmp; + roi_array.swap(roi_tmp); } } + + public: + static constexpr size_t kLookupTableSize = 1280; + + static const uint8_t* GetLookupTableShared() { + // initialized once + static const auto* lookup_table = []() { + // if we have already initialized the lookup table, just return + // ideally we could have a global lookup table, but that account for too much space. + /* Handles values form -640 to 639. */ + static uint8_t table[kLookupTableSize] = {0}; + + // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 + // we need to handle negative values + // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] + // we will accept a negative x for (&table[640])[x] means table +640 -x + for (int i = 0; i < static_cast(kLookupTableSize); ++i) { + table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); + } + return table; + }(); + return lookup_table; + } }; // UpsampleBase } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 0d9928baa86e0..66794f88d8670 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -194,13 +194,13 @@ template <> __device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); } template -__device__ __inline__ T _Floor(T a); +__device__ __host__ __inline__ T _Floor(T a); template <> -__device__ __inline__ float _Floor(float a) { return floorf(a); } +__device__ __host__ __inline__ float _Floor(float a) { return floorf(a); } template <> -__device__ __inline__ double _Floor(double a) { return floor(a); } +__device__ __host__ __inline__ double _Floor(double a) { return floor(a); } template <> __device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); } @@ -230,13 +230,13 @@ template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } template -__device__ __inline__ T _Round(T a); +__device__ __host__ __inline__ T _Round(T a); template <> -__device__ __inline__ float _Round(float a) { return rintf(a); } +__device__ __host__ __inline__ float _Round(float a) { return rintf(a); } template <> -__device__ __inline__ double _Round(double a) { return rint(a); } +__device__ __host__ __inline__ double _Round(double a) { return rint(a); } template <> __device__ __inline__ half _Round(half a) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 00783bcbc2665..1ce089fd93044 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1109,11 +1109,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1277,6 +1277,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -2009,11 +2014,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2176,6 +2181,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc index 764172a8d1fac..97d4eb71e970a 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize.cc +++ b/onnxruntime/core/providers/cuda/tensor/resize.cc @@ -28,10 +28,22 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 3) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ Resize); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Resize, \ + kOnnxDomain, \ + 13, 17, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + Resize); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Resize, \ kOnnxDomain, \ - 13, \ + 18, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu new file mode 100644 index 0000000000000..56b7c3f499303 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -0,0 +1,1179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/tensor/resize_impl.h" + +#define FUNC_DEF __device__ + +namespace onnxruntime { +namespace cuda { + +using onnxruntime::ResizeCoordinateTransformationMode; +using onnxruntime::UpsampleMode; + +/// +/// Compute a buffer for bilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeBilinearScaleBufferSize( + int64_t output_height, int64_t output_width, + float height_rscale, float width_rscale, + float support_value, + float& scaled_support_height, float& scaled_support_width, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale); + scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale); + window_size_height = ComputeWindowSize(scaled_support_height); + window_size_width = ComputeWindowSize(scaled_support_width); + + auto height_buffer_size = ComputeWeightedCoeffBufferSize(output_height, window_size_height); + auto width_buffer_size = ComputeWeightedCoeffBufferSize(output_width, window_size_width); + + return std::make_tuple(height_buffer_size, width_buffer_size); +} + +/// +/// Compute a buffer for btrilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeTrilinearScaleBufferSize( + int64_t output_depth, int64_t output_height, int64_t output_width, + float depth_rscale, float height_rscale, float width_rscale, + float support_value, + float& scaled_support_depth, float& scaled_support_height, + float& scaled_support_width, int32_t& window_size_depth, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_depth = ComputeScaledSupportValue(support_value, depth_rscale); + window_size_depth = ComputeWindowSize(scaled_support_depth); + auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth); + + const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height, + output_width, height_rscale, + width_rscale, support_value, + scaled_support_height, + scaled_support_width, + window_size_height, window_size_width); + return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size); +} + +// Antialiasing filters +struct BilinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +struct BiCubicFilter { + __device__ __host__ float operator()(float x, float cubic_coeff_a) const { + /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + */ + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return ((cubic_coeff_a + 2.0f) * x - (cubic_coeff_a + 3.0f)) * x * x + 1; + } + if (x < 2.0f) { + return (((x - 5.0f) * x + 8.f) * x - 4.f) * cubic_coeff_a; + } + return 0.0f; + } +}; + +struct TriLinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +template +struct AccumTypeCaster { + static __device__ __host__ AccumType* cast(AccumType* p) { + return p; + } +}; + +template <> +struct AccumTypeCaster { + static __device__ __host__ float* cast(int32_t* p) { + return reinterpret_cast(p); + } +}; + +template +__global__ void _ComputeInterpolationAtLevel1( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, + const int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_width == input_width) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_y, output_x; + div_output_width.divmod(output_image_index, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_x; + int64_t xmin = bound[static_cast(output_x) * 2]; + int64_t xmax = bound[static_cast(output_x) * 2 + 1]; + + // Input window + const auto* Xdata_offset = Xdata + input_index + input_width * output_y + xmin; + + for (; xmin < xmax; ++xmin) { + if constexpr (std::is_same::value) { + // This cast is needed when we deal with half + output += static_cast((*Xdata_offset++)) * (*weight_coeff++); + } else { + output += (*Xdata_offset++) * (*weight_coeff++); + } + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = static_cast(output); + } +} + +template +__global__ void _ComputeInterpolationAtLevel2( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_height == input_height) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width + + output_z * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width + + output_z * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_y; + int64_t ymin = bound[static_cast(output_y) * 2]; + int64_t ymax = bound[static_cast(output_y) * 2 + 1]; + + const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x; + + for (; ymin < ymax; ++ymin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += input_width; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +template +__global__ void _ComputeInterpolationAtLevel3( + int64_t input_depth, + int64_t input_height, int64_t input_width, + int64_t output_depth, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (input_depth == output_depth) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * input_depth * input_height * input_width); + + auto* Ydata_offset = Ydata + id; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<2>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<1>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the z dimension + const int64_t* z_outof_bounds = std::get<0>(outof_bounds_buffers); + if (z_outof_bounds != nullptr && z_outof_bounds[static_cast(output_z)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_z; + int64_t zmin = bound[static_cast(output_z) * 2]; + int64_t zmax = bound[static_cast(output_z) * 2 + 1]; + + const auto z_step = input_height * input_width; + const auto* Xdata_offset = Xdata + input_index + zmin * z_step + output_y * output_width + output_x; + + for (; zmin < zmax; ++zmin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += z_step; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +/// +/// This function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] +/// 2. out_of_bounds: int64_t[output_size] +/// 3. scale_data: T[output_size * window_size] +/// +/// Template parameter AccumType +/// +template +FUNC_DEF void SetupUpsampleFilterAnitAliasImpl( + int64_t i, + int64_t input_size, int64_t output_size, + float rscale, + float roi_start, float roi_end, + float scaled_support, int32_t window_size, bool exclude_outside, + float cubic_coeff_a, + int64_t* bounds, + int64_t* out_of_bounds, + AccumType* scale_data) { + Filter filter{}; + CudaFunctionOriginalCoordinate get_original_coordinate{}; + + const auto scale = 1.f / rscale; + const float inv_scale = (scale >= 1.0f) ? 1.0f / scale : 1.0f; + + const float id = static_cast(i); + float center = 0.5f; + if (scale == 1.0f) { + center += id; + } else { + center += get_original_coordinate(id, rscale, + static_cast(output_size), + static_cast(input_size), + roi_start, roi_end); + } + + if (center - 0.5f < 0 || center - 0.5f > static_cast(input_size - 1)) { + out_of_bounds[i] = i; + } else { + out_of_bounds[i] = -1; + } + + float total_weight{0}; + + auto fmin = _Floor(center - scaled_support + 0.5f); + auto fmax = _Floor(center + scaled_support + 0.5f); + + int64_t min_real = static_cast(fmin); + int64_t max_real = static_cast(fmax); + int64_t min_cut = std::max(min_real, 0); + int64_t max_cut = std::min(max_real, input_size); + + int64_t min_val = exclude_outside ? min_cut : min_real; + int64_t max_val = exclude_outside ? max_cut : max_real; + bounds[i * 2] = min_cut; + bounds[i * 2 + 1] = max_cut; + + // This is done for int32_t case, when the final result is in int32_t, but + // we perform calculations in float. All other types as is. + auto* scale_buffer = AccumTypeCaster::cast(&scale_data[i * window_size]); + + max_val -= min_val; + for (int64_t x = 0; x < max_val; x++) { + const float arg = (x + min_val - center + 0.5f) * inv_scale; + const auto w = filter(arg, cubic_coeff_a); + scale_buffer[x] = w; + total_weight += w; + } + + if (!exclude_outside) { + int64_t neg_xsize = min_val < 0 ? -min_val : 0; + for (int64_t x = 0; x < neg_xsize; x++) { + scale_buffer[neg_xsize] += scale_buffer[x]; + } + + int64_t bound_size = + max_val + min_val > input_size ? max_val + min_val - input_size : 0; + for (int64_t x = max_val - bound_size; x < max_val; x++) { + scale_buffer[max_val - bound_size - 1] += + scale_buffer[x]; + } + + for (int64_t x = 0; (neg_xsize | bound_size) > 0 && x < max_cut - min_cut; x++) { + scale_buffer[x] = scale_buffer[x + neg_xsize]; + } + } + + const float total_weight_inv = (total_weight == 0) ? 1.f : (1.f / total_weight); + if constexpr (std::is_same::value) { + auto* scale_buffer_int = reinterpret_cast(scale_buffer); + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + // normalize the scale to 1 << 22 for int8/uint8 + scale_buffer_int[x] = static_cast(_Round(scale_buffer[x] * ConstValue::mag_factor_x_2)); + } + } else { + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + } + } +} + +/// This kernel computes antialias filter for bilinear or bicubic upsampling. +/// The function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the two dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the two dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the two dimensions +/// Buffers layout [h_data, w_data] +template +__global__ void _SetupBilinearUpsampleFilterAntiAlias( + std::tuple input_dims, // h, w + std::tuple output_dims, // h, w + std::tuple inv_scale_vals, // h, w + std::tuple roi_start_vals, // h, w + std::tuple roi_end_vals, // h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values h, w + std::tuple dim_window_size, // Pre-computed windows sizes h, w + float cubic_coeff_a, + bool exclude_outside, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients // y, h buffers +) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for y + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else { + // Setup for w + // w = id - output_height + + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto y_output_size = std::get<0>(output_dims); + + auto i = id - y_output_size; + bounds += (y_output_size * 2); + out_of_bounds += y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } +} + +/// +/// Compute AntiAlias filter for trilinear upsampling, all in one go +/// The function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the three dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the three dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the three dimensions +/// Each kind of buffer contains data for all 3 dims. +/// Buffers layout [d_data, h_data, w_data] +/// +template +__global__ void _SetupTrilinerarUpsampleFilterAntiAlias( + std::tuple input_dims, // d, h, w + std::tuple output_dims, // d, h, w + std::tuple inv_scale_vals, // d, h, w + std::tuple roi_start_vals, // d, h, w + std::tuple roi_end_vals, // d, h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values d, h, w + std::tuple dim_window_size, // Pre-computed windows sizes d, h, w + bool exclude_outisde, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims) + std::get<2>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for d by default (id < output_depth) + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else if (id >= std::get<0>(output_dims) && id < (std::get<0>(output_dims) + std::get<1>(output_dims))) { + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto d_output_size = std::get<0>(output_dims); + + auto i = id - d_output_size; + bounds += d_output_size * 2; + out_of_bounds += d_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } else { + int64_t input_size = std::get<2>(input_dims); + int64_t output_size = std::get<2>(output_dims); + float inv_scale = std::get<2>(inv_scale_vals); + float roi_start = std::get<2>(roi_start_vals); + float roi_end = std::get<2>(roi_end_vals); + float scaled_support = std::get<2>(dim_scaled_support); + int32_t window_size = std::get<2>(dim_window_size); + + // Adjust buffer positions + const auto d_y_output_size = std::get<0>(output_dims) + std::get<1>(output_dims); + + auto i = id - d_y_output_size; + bounds += (d_y_output_size * 2); + out_of_bounds += d_y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<2>(weighted_coefficients)); + } +} + +#define CASEA_COORD_ANTIALIAS(coordinate_mode, TransformCoordType, ...) \ + case coordinate_mode: { \ + using coord_t = TransformCoordType; \ + return __VA_ARGS__(); \ + break; \ + } + +#define DISPATCH_ANTIALIAS_FILTER_SETUP(coord_enum, ...) \ + [&] { \ + const auto the_type = coord_enum; \ + switch (the_type) { \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL, \ + TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ASYMMETRIC, \ + TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ + }() + +namespace { +template +IAllocatorUniquePtr AllocateTyped( + const TempSpaceAllocateFunc& alloc, + size_t elements) { + return alloc(elements * sizeof(T)); +} + +template +T* GetTyped(IAllocatorUniquePtr& bytes) { + return reinterpret_cast(bytes.get()); +} +} // namespace + +template +void ResizeTrilinearUpsample( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + static_cast(ceil((output_depth + output_height + output_width) / 32.0)); + + int blocksPerGrid = static_cast(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + + constexpr float support_value = antialias_constants::kSupportSize; + float z_scale, h_scale, w_scale; + std::tie(z_scale, h_scale, w_scale) = inferred_dim_rscales; + + const auto& div_output_width = output_div_pitches[rank - 2]; + + SafeInt bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* z_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* y_bounds_buffer = z_bounds_buffer + output_depth * 2; + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* z_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* y_outof_bounds_buffer = z_outof_bounds_buffer + output_depth; + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + float z_scaled_support, h_scaled_support, w_scaled_support; + int32_t z_window_size, h_window_size, w_window_size; + const auto [z_buffer_size, y_buffer_size, w_buffer_size] = ComputeTrilinearScaleBufferSize( + output_depth, output_height, output_width, + z_scale, h_scale, w_scale, support_value, + z_scaled_support, h_scaled_support, w_scaled_support, + z_window_size, h_window_size, w_window_size); + + const int64_t weighted_buffer_size = SafeInt(z_buffer_size) + y_buffer_size + w_buffer_size; + + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + AccumType* z_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* y_weighted_buffer = z_weighted_buffer + z_buffer_size; + AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size; + + const auto h_w_interpolate_temp_buf_size = SafeInt(batch_size) * num_channels * + input_depth * input_height * output_width; + auto h_w_interpolate_temp_buffer_ptr = AllocateTyped(allocate_temp_space, + narrow(h_w_interpolate_temp_buf_size)); + + const auto h_w_interpolate_result_buffer_size = SafeInt(batch_size) * num_channels * + input_depth * output_height * output_width; + auto h_w_interpolate_result_buffer_ptr = AllocateTyped(allocate_temp_space, h_w_interpolate_result_buffer_size); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupTrilinerarUpsampleFilterAntiAlias<<>>( + inferred_input_dims, + inferred_output_dims, + inferred_dim_rscales, + std::make_tuple(roi_vals[rank - 3], roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts d, h, w + std::make_tuple(roi_vals[rank - 3 + rank], roi_vals[rank - 2 + rank], // roi ends d, h, w + roi_vals[rank - 1 + rank]), + std::make_tuple(z_scaled_support, h_scaled_support, w_scaled_support), + std::make_tuple(z_window_size, h_window_size, w_window_size), + exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(z_weighted_buffer, y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const fast_divmod div_w_image(narrow(num_channels * input_depth * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels * input_depth, input_height, input_width, input_height, output_width, + div_output_width, + div_w_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, + GetTyped(h_w_interpolate_temp_buffer_ptr), + narrow(h_w_interpolate_temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + const fast_divmod div_h_w_image(narrow(num_channels * input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels * input_depth, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_h_w_image, + h_window_size, + false, 0.f, // No extrapolation + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(h_w_interpolate_temp_buffer_ptr), + GetTyped(h_w_interpolate_result_buffer_ptr), + narrow(h_w_interpolate_result_buffer_size)); + + // clang-format on + const fast_divmod div_z_h_w_image(narrow(input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel3<<>>( + input_depth, output_height, output_width, + output_depth, output_height, output_width, + div_output_height, + div_output_width, + div_z_h_w_image, + z_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + z_bounds_buffer, + std::make_tuple(z_outof_bounds_buffer, y_outof_bounds_buffer, w_outof_bounds_buffer), + z_weighted_buffer, GetTyped(h_w_interpolate_result_buffer_ptr), + output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeBiLinearUpsample(cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + // rank 2 or 4 + const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kSupportSize; + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, narrow(weighted_buffer_size)); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + // Data is d, h, w in tuples + + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const fast_divmod div_step_image{narrow(num_channels * input_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); + + // clang-format on +} + +template +void ResizeBicubicUpsample(cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + // const TArray& input_strides, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kBiCubicSupportSize; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = SafeInt(batch_size) * num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + // clang-format on + const fast_divmod div_step_image(narrow(num_channels * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + // clang-format on + + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + // We support a special case of bilinear or bicubic if the input data is 4D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_2D = (rank == 2 || rank == 4); + + // We support a special case of trilinear or tricubic if the input data is 5D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_3D = (rank == 3 || rank == 5); + + // Should not hit this as we have already validated input rank/scales and we provide verbose error messages + // to the user. + ORT_ENFORCE(is_2D || is_3D, "Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + + switch (upsample_mode) { + case UpsampleMode::LINEAR: { + if (is_2D) { + ResizeBiLinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else if (is_3D) { + ResizeTrilinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D or 3-D in LINEAR mode."); + } + } break; + case CUBIC: { + if (is_2D) { + ResizeBicubicUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D in CUBIC mode."); + } + } break; + default: + ORT_NOT_IMPLEMENTED("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + break; + } +} + +#define SPECIALIZED_ANTIALIAS_IMPL(T) \ + template void ResizeAntiAliasImpl( \ + cudaStream_t stream, \ + int rank, \ + const UpsampleMode upsample_mode, \ + ResizeCoordinateTransformationMode coordinate_transform_mode, \ + gsl::span input_shape, \ + gsl::span output_shape, \ + int64_t batch_size, int64_t num_channels, \ + std::tuple inferred_input_dims, \ + std::tuple inferred_output_dims, \ + std::tuple inferred_dim_rscales, \ + const TArray& output_div_pitches, \ + gsl::span roi_vals, \ + const std::optional& extrapolation_value, \ + bool exclude_outside, \ + TempSpaceAllocateFunc allocate_temp_space, \ + const uint8_t* clip8_lookups, \ + const T* input_data, \ + T* output_data, \ + const size_t N); + +SPECIALIZED_ANTIALIAS_IMPL(float) +SPECIALIZED_ANTIALIAS_IMPL(double) +SPECIALIZED_ANTIALIAS_IMPL(half) +SPECIALIZED_ANTIALIAS_IMPL(int32_t) +SPECIALIZED_ANTIALIAS_IMPL(uint8_t) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 1a94c7705e913..0cde0ed8e8681 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -12,7 +12,7 @@ using onnxruntime::ResizeNearestMode; using onnxruntime::UpsampleMode; struct NearestPixel_SIMPLE { - __device__ __forceinline__ int operator() (float x_original, bool is_down_sampling) const { + __device__ __forceinline__ int operator()(float x_original, bool is_down_sampling) const { if (is_down_sampling) { return static_cast(_Ceil(x_original)); } @@ -21,7 +21,7 @@ struct NearestPixel_SIMPLE { }; struct NearestPixel_ROUND_PREFER_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { if (x_original == static_cast(x_original) + 0.5f) { return static_cast(_Floor(x_original)); } @@ -30,62 +30,23 @@ struct NearestPixel_ROUND_PREFER_FLOOR { }; struct NearestPixel_ROUND_PREFER_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(roundf(x_original)); } }; struct NearestPixel_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Floor(x_original)); } }; struct NearestPixel_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Ceil(x_original)); } }; -struct TransformCoordinate_ASYMMETRIC { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return x_resized / x_scale; - } -}; - -struct TransformCoordinate_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return ((x_resized + 0.5f) / x_scale) - 0.5f; - } -}; - -struct TransformCoordinate_PYTORCH_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float length_resized, float, float, float) const { - return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f; - } -}; - -struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return (x_resized + 0.5f) / x_scale; - } -}; - -struct TransformCoordinate_ALIGN_CORNERS { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float, float) const { - return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); - } -}; - -struct TransformCoordinate_TF_CROP_AND_RESIZE { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float roi_start, float roi_end) const { - auto orig = length_resized > 1 - ? roi_start * (length_original - 1) + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) - : 0.5 * (roi_start + roi_end) * (length_original - 1); - return static_cast(orig); - } -}; - #define CASE_TYPE_USING_HINT(enum_type, type, HINT, ...) \ case enum_type: { \ using HINT = type; \ @@ -95,20 +56,24 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { #define CASE_TYPE_COORD(enum_type, type, ...) \ CASE_TYPE_USING_HINT(enum_type, type, coord_t, __VA_ARGS__) -#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - switch (the_type) { \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ - default: \ - ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ - } \ +#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + switch (the_type) { \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ }() #define CASE_TYPE_NEAREST(enum_type, type, ...) \ @@ -119,11 +84,11 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ switch (the_type) { \ - CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_FLOOR, NearestPixel_ROUND_PREFER_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ default: \ ORT_THROW("unknown ResizeNearestMode"); \ } \ @@ -151,10 +116,12 @@ __global__ void _ResizeNearestMappingKernel2D( // only apply co-ordinate transformation if scale != 1.0 if (scales_height == 1.0f) { - dims_mapping[id].extrapolate_ = 0; + dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales_height, static_cast(output_height), - static_cast(input_height), roi_start_height, roi_end_height); + float orig_coord = transform_coordinate(static_cast(dim), scales_height, + static_cast(output_height), + static_cast(input_height), + roi_start_height, roi_end_height); dims_mapping[id].extrapolate_ = static_cast( extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_height - 1))); dim = calc_nearest_pixel(orig_coord, scales_height < 1); @@ -210,9 +177,12 @@ __global__ void _ResizeNearestMappingKernel( if (scales[axis] == 1.0f) { dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales[axis], static_cast(output_shape[axis]), + float orig_coord = transform_coordinate(static_cast(dim), scales[axis], + static_cast(output_shape[axis]), static_cast(input_shape[axis]), roi[axis], roi[axis + rank]); - dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_shape[axis] - 1))); + dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && + (orig_coord < 0.f || + orig_coord > static_cast(input_shape[axis] - 1))); dim = calc_nearest_pixel(orig_coord, scales[axis] < 1); if (dim >= input_shape[axis]) dim = input_shape[axis] - 1; if (dim < 0) dim = 0; @@ -293,21 +263,27 @@ __global__ void _ResizeBilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumHW); if (id < output_height) { // y = id - float input_y = scale_height == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? static_cast(id) + : transform_coordinate(static_cast(id), scale_height, + static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int; - } else { //x = id - output_height - float input_x = scale_width == 1 ? static_cast(id - output_height) : - transform_coordinate(static_cast(id - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_height + float input_x = scale_width == 1 ? static_cast(id - output_height) + : transform_coordinate(static_cast(id - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), roi_width_start, + roi_width_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_x < 0 || + input_x > static_cast(input_width - 1)))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -371,32 +347,40 @@ __global__ void _ResizeTrilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumDHW); if (id < output_depth) { // z = id - float input_z = scale_depth == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_depth, - static_cast(output_depth), static_cast(input_depth), - roi_depth_start, roi_depth_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_z < 0 || input_z > static_cast(input_depth - 1))); + float input_z = scale_depth == 1 ? static_cast(id) + : transform_coordinate(static_cast(id), scale_depth, + static_cast(output_depth), + static_cast(input_depth), + roi_depth_start, roi_depth_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_z < 0 || + input_z > static_cast(input_depth - 1)))); input_z = max(0.0f, min(input_z, static_cast(input_depth - 1))); int z_int = static_cast(input_z); dims_mapping[id].origin_ = z_int; dims_mapping[id].weight_ = (z_int >= input_depth - 1) ? 0.5f : input_z - z_int; } else if (id >= output_depth && id < (output_depth + output_height)) { // y = id - output_depth - float input_y = scale_height == 1 ? static_cast(id - output_depth) : - transform_coordinate(static_cast(id - output_depth), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? static_cast(id - output_depth) + : transform_coordinate(static_cast(id - output_depth), + scale_height, static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int; - } else { //x = id - output_depth - output_height - float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) : - transform_coordinate(static_cast(id - output_depth - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_depth - output_height + float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) + : transform_coordinate(static_cast(id - output_depth - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), + roi_width_start, roi_width_end); + dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || + input_x > static_cast(input_width - 1))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -513,21 +497,33 @@ __global__ void _ResizeCubicCoordinateMapping( int max_input_coord = static_cast(is_y_axis ? input_height : input_width); float scale = is_y_axis ? scale_height : scale_width; - float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) : - transform_coordinate( - static_cast(is_y_axis ? id : id - output_height), - scale, - static_cast(is_y_axis ? output_height : output_width), - static_cast(max_input_coord), - (is_y_axis ? roi_height_start : roi_width_start), - (is_y_axis ? roi_height_end : roi_width_end)); + float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) + : transform_coordinate( + static_cast(is_y_axis ? id : id - output_height), + scale, + static_cast(is_y_axis ? output_height : output_width), + static_cast(max_input_coord), + (is_y_axis ? roi_height_start : roi_width_start), + (is_y_axis ? roi_height_end : roi_width_end)); int coord_int = static_cast(_Floor(input_coordinat)); float s_coord = abs(input_coordinat - coord_int); float coeff_sum = 1.0f; - float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * (s_coord + 1) + 8 * cubic_coeff_a) * (s_coord + 1) - 4 * cubic_coeff_a); - float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * s_coord * s_coord + 1); - float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * (1 - s_coord) * (1 - s_coord) + 1); - float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * (2 - s_coord) + 8 * cubic_coeff_a) * (2 - s_coord) - 4 * cubic_coeff_a); + float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * + (s_coord + 1) + + 8 * cubic_coeff_a) * + (s_coord + 1) - + 4 * cubic_coeff_a); + float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * + s_coord * s_coord + + 1); + float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * + (1 - s_coord) * (1 - s_coord) + + 1); + float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * + (2 - s_coord) + + 8 * cubic_coeff_a) * + (2 - s_coord) - + 4 * cubic_coeff_a); if (exclude_outside) { coeff_0 = (coord_int - 1 < 0 || coord_int - 1 >= max_input_coord) ? 0.0 : coeff_0; coeff_1 = (coord_int + 0 < 0 || coord_int + 0 >= max_input_coord) ? 0.0 : coeff_1; @@ -540,7 +536,8 @@ __global__ void _ResizeCubicCoordinateMapping( dm.coeff1_ = coeff_1 / coeff_sum; dm.coeff2_ = coeff_2 / coeff_sum; dm.coeff3_ = coeff_3 / coeff_sum; - dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || input_coordinat > static_cast(max_input_coord - 1))); + dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || + input_coordinat > static_cast(max_input_coord - 1))); } template @@ -569,21 +566,30 @@ __global__ void _ResizeBiCubicKernel( int x_int = x_info.origin_; int y_int = y_info.origin_; const T* image = input_data + input_index; - output_data[id] = y_info.coeff0_ * CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff1_ * CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff2_ * CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff3_ * CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); + output_data[id] = y_info.coeff0_ * + CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff1_ * + CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff2_ * + CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff3_ * + CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); } size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims) { switch (upsample_mode) { case UpsampleMode::NN: - return sizeof(int64_t) * output_dims.size() + sizeof(NearestMappingInfo) * static_cast(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0)); + return sizeof(int64_t) * output_dims.size() + + sizeof(NearestMappingInfo) * + static_cast(std::accumulate(output_dims.begin(), + output_dims.end(), (int64_t)0)); case UpsampleMode::LINEAR: - return sizeof(LinearMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(LinearMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); case UpsampleMode::CUBIC: - return sizeof(CubicMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(CubicMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); } return 0; } @@ -616,7 +622,8 @@ void ResizeNearestImpl( if (could2d) { int64_t output_height = output_shape[rank - 2]; int64_t output_width = output_shape[rank - 1]; - fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(static_cast(output_height * output_width)); + fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] + : fast_divmod(static_cast(output_height * output_width)); int blocksPerDimsMappingGrid = static_cast(ceil((output_height + output_width) / 32.0)); DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() { @@ -694,13 +701,6 @@ void ResizeImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, ResizeNearestMode nearest_mode, void* dims_mapping) { - bool isSame = std::all_of(scales_vals.Data(), scales_vals.Data() + rank, [](float v) { return v == 1.0f; }) && - (coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); - if (isSame) { - CUDA_CALL_THROW(cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - return; - } - if (upsample_mode == UpsampleMode::NN) { ResizeNearestImpl( stream, rank, input_shape, output_shape, input_strides, output_div_pitches, @@ -761,7 +761,7 @@ void ResizeImpl( } else if (is_3D) { DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(coordinate_transform_mode, [&]() { _ResizeTrilinearCoordinateMapping<<>>( - input_shape[rank - 3] , input_shape[rank - 2], input_shape[rank - 1], + input_shape[rank - 3], input_shape[rank - 2], input_shape[rank - 1], output_depth, output_height, output_width, scales_vals[rank - 3], scales_vals[rank - 2], scales_vals[rank - 1], roi_vals[rank - 3], roi_vals[rank - 3 + rank], @@ -778,7 +778,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize support 2-D and 3-D dimensions in LINEAR mode."); break; case UpsampleMode::CUBIC: if (is_2D) { @@ -801,7 +801,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize supports only 2-D in CUBIC mode."); case UpsampleMode::NN: ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); } @@ -809,7 +809,7 @@ void ResizeImpl( #define SPECIALIZED_IMPL(T) \ template void ResizeImpl( \ - cudaStream_t stream, \ + cudaStream_t stream, \ const UpsampleMode upsample_mode, \ const int rank, \ TArray& input_shape, \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h index d459dbff18d3e..ad06eebb9efb1 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h @@ -2,15 +2,69 @@ // Licensed under the MIT License. #pragma once + #include + +#include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/common/common.h" #include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { +template <> +struct AccumulateType { + using type = float; +}; namespace cuda { +struct TransformCoordinate_ASYMMETRIC { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return x_resized / x_scale; + } +}; + +struct TransformCoordinate_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return ((x_resized + 0.5f) / x_scale) - 0.5f; + } +}; + +struct TransformCoordinate_PYTORCH_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized, float, + float, float) const { + return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f; + } +}; + +struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return (x_resized + 0.5f) / x_scale; + } +}; + +struct TransformCoordinate_ALIGN_CORNERS { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float, float) const { + return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); + } +}; + +struct TransformCoordinate_TF_CROP_AND_RESIZE { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float roi_start, float roi_end) const { + auto orig = length_resized > 1 + ? roi_start * (length_original - 1) + + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) + : 0.5 * (roi_start + roi_end) * (length_original - 1); + return static_cast(orig); + } +}; + size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims); @@ -36,5 +90,62 @@ void ResizeImpl( onnxruntime::ResizeNearestMode nearest_mode, void* dims_mapping); +using TempSpaceAllocateFunc = std::function(size_t buffer_size)>; + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, // CPU + const std::optional& extrapolation_value, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N); + +/// +/// Compute scaled support value for a given dimension inverse scale +/// +/// Support value from parameters +/// inverse scale value comes from input/attr for +/// +inline float ComputeScaledSupportValue(float support_value, float rscale) { + const float scale = 1.0f / rscale; + float scaled_support = (scale >= 1.0f) ? (support_value * 0.5f) * scale : support_value * 0.5f; + return scaled_support; +} + +/// +/// Compute window size for a given dimension scaled support value. +/// +/// +/// +inline int32_t ComputeWindowSize(float scaled_support) { + SafeInt window_size(ceilf(scaled_support)); + return window_size * 2 + 1; +} + +/// +/// Computes scale buffer size in number of elements for allocation purposes. +/// +/// +/// +/// Number of elements to fit in the buffer +inline SafeInt ComputeWeightedCoeffBufferSize(int64_t output_size, int32_t window_size) { + SafeInt buffer_size(output_size); + return buffer_size * window_size; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index ae12ca328bc7c..17533eb3d9a72 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -2,6 +2,9 @@ // Licensed under the MIT License. #include "upsample.h" + +#include + #include "upsample_impl.h" #include "core/providers/cuda/tensor/resize_impl.h" #include "core/providers/cpu/tensor/utils.h" @@ -37,11 +40,23 @@ REGISTER_VERSIONED_TYPED_KERNEL(MLFloat16, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +template +Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + if (UpsampleBase::antialias_) { + // Copy the table on DEVICE + const uint8_t* lookup_table = GetLookupTableShared(); + auto alloc = info.GetAllocator(OrtMemTypeDefault); + shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, nullptr)); + } +} + template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const Tensor* X = context->Input(0); auto X_dims = X->Shape().GetDims(); int32_t rank = static_cast(X_dims.size()); @@ -52,7 +67,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, is_resize_ ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar."); if (rank != static_cast(scales.size())) return Status(ONNXRUNTIME, INVALID_ARGUMENT, - is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales."); + is_resize_ ? "Resize: input tensor's dimension does not match the scales." + : "Upsample: input tensor's dimension does not match the scales."); if (roi.size() != 2 * X_dims.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: size of roi array should be 2 * N where N is the rank of input tensor X."); @@ -79,22 +95,194 @@ Status Upsample::BaseCompute(OpKernelContext* context, size_t output_count = Y->Shape().Size(); if (is_resize_) { - TArray input_shape(X_dims); - TArray output_shape(output_dims); - TArray roi_vals(roi); - TArray scales_vals(scales); - - size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); - auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); - void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); - ResizeImpl(Stream(context), mode_, (int)rank, input_shape, output_shape, - input_strides, output_div_pitches, scales_vals, roi_vals, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), - cubic_coeff_a_, exclude_outside_, - coordinate_transform_mode_, nearest_mode_, - dims_mapping); + const bool is_same = std::all_of(scales.begin(), scales.end(), [](float v) { return v == 1.0f; }) && + (coordinate_transform_mode_ != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); + if (is_same) { + CUDA_CALL_THROW(cudaMemcpyAsync(Y->MutableData(), X->Data(), + output_count * sizeof(T), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + if (antialias_) { + TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { + return GetScratchBuffer(bytes_size, context->GetComputeStream()); + }; + + std::optional extrapolation_value; + if (use_extrapolation_) + extrapolation_value.emplace(extrapolation_value_); + + switch (mode_) { + case UpsampleMode::LINEAR: { + if (X_dims.size() == 2 || X_dims.size() == 4) { + const bool is_2D = X_dims.size() == 2; + + int64_t batch_size = 1; + int64_t num_channels = 1; + + int64_t input_height; + int64_t input_width; + + int64_t output_height; + int64_t output_width; + + float height_scale; + float width_scale; + + if (is_2D) { + input_height = X_dims[0]; + input_width = X_dims[1]; + + output_height = output_dims[0]; + output_width = output_dims[1]; + + height_scale = scales[0]; + width_scale = scales[1]; + } else { + if (scales[0] == 1.0f && scales[1] == 1.0f) { + batch_size = X_dims[Channels::N]; + num_channels = X_dims[Channels::C]; + input_height = X_dims[Channels::H]; + input_width = X_dims[Channels::W]; + + output_height = output_dims[Channels::H]; + output_width = output_dims[Channels::W]; + + height_scale = scales[2]; + width_scale = scales[3]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NHWC is not supported yet"); + } + } + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + + } else if (X_dims.size() == 3 || X_dims.size() == 5) { + const bool is_3D = X_dims.size() == 3; + + if (!is_3D) { + if (!(scales[0] == 1.0f && scales[1] == 1.0f)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NDHWC is not supported yet"); + } + } + + const int64_t batch_size = is_3D ? 1 : X_dims[0]; + const int64_t num_channels = is_3D ? 1 : X_dims[1]; + const int64_t input_depth = is_3D ? X_dims[0] : X_dims[2]; + const int64_t input_height = is_3D ? X_dims[1] : X_dims[3]; + const int64_t input_width = is_3D ? X_dims[2] : X_dims[4]; + + const int64_t output_depth = is_3D ? output_dims[0] : output_dims[2]; + const int64_t output_height = is_3D ? output_dims[1] : output_dims[3]; + const int64_t output_width = is_3D ? output_dims[2] : output_dims[4]; + + const float depth_scale = is_3D ? scales[0] : scales[2]; + const float height_scale = is_3D ? scales[1] : scales[3]; + const float width_scale = is_3D ? scales[2] : scales[4]; + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(input_depth, input_height, input_width), + std::make_tuple(output_depth, output_height, output_width), + std::make_tuple(depth_scale, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Linear' mode only support 2-D inputs or 3-D inputs ('Bilinear', 'Trilinear') " + "or 4-D inputs or 5-D inputs with the corresponding outermost 2 scale values " + "being 1."); + } + } break; + case UpsampleMode::CUBIC: { + if (X_dims.size() != 2 && X_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1."); + } + + const bool is_2D = X_dims.size() == 2; + const bool is_nchw = is_2D ? true : (scales[1] == 1.0f && scales[1] == 1.0f); + + ORT_RETURN_IF_NOT(is_nchw, + "Resize 'Cubic' mode only supports NCWH layout " + " with 2-D or 4-D with leading dims equal to 1"); + + const int64_t batch_size = is_2D ? 1 : X_dims[Channels::N]; + const int64_t num_channels = is_2D ? 1 : X_dims[Channels::C]; + const int64_t input_height = is_2D ? X_dims[0] : X_dims[Channels::H]; + const int64_t input_width = is_2D ? X_dims[1] : X_dims[Channels::W]; + + const int64_t output_height = is_2D ? output_dims[0] : output_dims[Channels::H]; + const int64_t output_width = is_2D ? output_dims[1] : output_dims[Channels::W]; + const float height_scale = is_2D ? scales[0] : scales[2]; + const float width_scale = is_2D ? scales[1] : scales[3]; + + ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } break; + default: + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: unexpected mode"); + } + } else { + TArray input_shape(X_dims); + TArray output_shape(output_dims); + TArray roi_vals(roi); + TArray scales_vals(scales); + + size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); + auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); + void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); + ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape, + input_strides, output_div_pitches, scales_vals, roi_vals, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), + cubic_coeff_a_, exclude_outside_, + coordinate_transform_mode_, nearest_mode_, + dims_mapping); + } } else { TArray scales_div(rank); @@ -124,7 +312,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { auto input_dims = X->Shape().GetDims(); TensorShapeVector output_dims(input_dims.size()); - std::vector roi_array(input_dims.size() * 2, 0.0f); + InlinedVector roi_array(input_dims.size() * 2, 0.0f); if (!roi_cached_) { bool use_default_roi = true; if (need_roi_input_) { @@ -147,29 +335,37 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { } } - const std::vector& roi = roi_cached_ ? roi_ : roi_array; - std::vector scales_array = scales_; + ComputeROIWithAxes(roi_array, input_dims.size()); + InlinedVector scales_array(input_dims.size()); + // opset < 10 if (OpKernel::Node().InputDefs().size() == 1) { - // Compute output shape from scales and input dims + // Compute output shape from scales attributes and input dims + scales_array = scales_; + ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_, output_dims); } const Tensor* scales = context->Input(scales_input_idx_); const Tensor* sizes = context->Input(sizes_input_idx_); + // This is when scales are obtained and cached from a constant initializer if (scales_cached_) { - ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + scales_array = scales_; + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } - scales_array.resize((input_dims.size())); + // Scales and sizes are input to the node if (scales != nullptr && scales->Shape().Size() != 0) { // use scales input data ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size())); + + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); } else { // When sizes input is available directly populate it into the output_dims array. @@ -179,7 +375,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array)); } - return BaseCompute(context, roi, scales_array, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index 7bf2a23ede399..50597e0fba1b9 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -13,12 +13,14 @@ namespace cuda { template class Upsample : public UpsampleBase, public CudaKernel { public: - Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { - } + explicit Upsample(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; + + private: + IAllocatorUniquePtr shared_lookup_table_ondevice_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 3fd5423681b81..0265c06b9a938 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1145,11 +1145,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1304,6 +1304,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2081,11 +2086,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2240,6 +2250,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index da17135878fe5..7b73ab36b3742 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -24,6 +24,7 @@ #include "core/providers/cpu/tensor/size.h" #include "core/providers/cpu/tensor/scatter_nd.h" #include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cpu/tensor/tile.h" #ifndef DISABLE_CONTRIB_OPS @@ -572,6 +573,11 @@ std::unique_ptr> EinsumTypedComputeProcessor template <> std::unique_ptr> EinsumTypedComputeProcessor::Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) { return g_host_cpu.EinsumTypedComputeProcessor_MLFloat16__Create(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + g_host_cpu.UpsampleBase__AdjustOutputSizeAsPolicy(this, output_dims, input_dims, scales); +} + #ifndef DISABLE_CONTRIB_OPS namespace contrib { Status embed_layer_norm::CheckInputs(const OpKernelContext* context, bool quantizedVersion) { @@ -648,7 +654,6 @@ Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, c const SessionState& subgraph_session_state) { return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } - } // namespace transformers #ifdef ENABLE_ATEN diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 0c9e2e9fc17a2..09666c8039402 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -288,7 +288,7 @@ Status Resize::Compute(OpKernelContext* ctx) const { // Get scales data const auto* scales = ctx->Input(scales_input_idx_); - std::vector scales_array(X->Shape().GetDims().size()); + InlinedVector scales_array(X->Shape().GetDims().size()); if (scales != nullptr && scales->Shape().Size() != 0) { ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, output_shape.size())); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 10f02349a24d5..1d31f3fdb4eb4 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -11,7 +11,8 @@ namespace test { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.20000028610229492, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] " + << "is 0.20000028610229492, which exceeds threshold"; } OpTester test("Resize", 13); @@ -32,7 +33,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { test.AddInput("X", {H, W}, X); test.AddInput("roi", {4}, roi); - test.AddInput("", {0}, scales); // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them + // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them + test.AddInput("", {0}, scales); test.AddInput("sizes", {2}, sizes); std::vector Y = {7.600004f, 7.9f, 8.2f, @@ -188,7 +190,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -317,7 +321,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { // The output size is [1,1,2,4].*[1,1,0.6,0.6]=[1,1,1,2] // NNAPI will recaluclate the scales as the output size divided by input size // scales = [1,1,1,2]./[1,1,2,4] = [1,1,0.5,0.5] -// See, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h +// See:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h // So the result of the above example will be different than CPU EP // Add the following 2 tests to test with scales valid to NNAPI TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { @@ -475,7 +479,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_pytorch_half_pixel) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << " The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; } OpTester test("Resize", 13); @@ -533,7 +538,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -721,7 +727,8 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_align_corners) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_3DTrilinear_pytorch_half_pixel) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; } OpTester test("Resize", 13); @@ -1088,7 +1095,8 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 3, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 3, which exceeds threshold"; } OpTester test("Resize", 12); // tf_half_pixel_for_nn is deprecated since opset 13 @@ -1480,7 +1488,8 @@ TEST(ResizeOpTest, ResizeOpCubicUpSampleTest_tf_half_pixel_for_nn) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1505,7 +1514,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold "; } OpTester test("Resize", 10); @@ -1530,7 +1540,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1565,7 +1576,8 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1676,7 +1688,8 @@ TEST(UpsampleOpTest, ResizeOpNearestNoScaleTest_Ver10) { TEST(ResizeOpTest, ResizeOp_MissingRoiAndMissingScalesOptionalInputs) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect."; + GTEST_SKIP() << "Skipping because of the following error: " + << "MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect."; } OpTester test("Resize", 13); @@ -1827,7 +1840,8 @@ template void TestAntialiasing(std::map attributes, std::vector input_shape, std::vector input_data, - std::vector output_shape_or_scale, std::vector output_data) { + std::vector output_shape_or_scale, std::vector output_data, + gsl::span excluded_ep = {}) { auto parse_attr = [](const std::string& str, auto typed_v) { using Tdata = decltype(typed_v); std::vector vect; @@ -1891,13 +1905,22 @@ void TestAntialiasing(std::map attributes, } test.AddOutput("Y", output_shape, output_data); - // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + + std::unordered_set excluded_eps; + std::transform(excluded_ep.begin(), excluded_ep.end(), + std::inserter(excluded_eps, excluded_eps.end()), [](std::string_view ep) { + return std::string(ep); + }); + // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accuracy issue. + excluded_eps.insert(kTensorrtExecutionProvider); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_eps); } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias " + << "is slightly different and doesn't match in all cases."; } std::vector X(16); std::iota(X.begin(), X.end(), 1.f); @@ -1939,7 +1962,8 @@ TEST(ResizeOpTest, Antialias_Bilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + InlinedVector excluded_eps = {kCudaExecutionProvider}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y, excluded_eps); } { std::vector X(16); @@ -1982,17 +2006,21 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear) { 33.5f, 73.5f, 113.5f, 35.074074f, 75.07407f, 115.07407f, 36.590908f, 76.59091f, 116.59091f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y); + + // Nchw is not supported by CUDA Resize implementation + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; { std::vector X(16); std::iota(X.begin(), X.end(), uint8_t(0)); std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } { std::vector X(16); @@ -2000,7 +2028,7 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } { std::vector X(16); @@ -2008,13 +2036,14 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } } TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of " + << "antialias is slightly different and doesn't match in all cases."; } std::vector X(16 * 4); std::iota(X.begin(), X.end(), 0.f); @@ -2038,13 +2067,17 @@ TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) { TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias" + << " is slightly different and doesn't match in all cases."; } + + InlinedVector excluded_eps = {kCudaExecutionProvider}; std::vector X(16 * 4 * 4); std::iota(X.begin(), X.end(), 0.f); { std::vector Y = X; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y, + excluded_eps); } { std::vector Y = {0.625f, 2.375f, 4.625f, 6.375f, 8.625f, 10.375f, 12.625f, @@ -2066,7 +2099,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { 224.625f, 226.375f, 228.625f, 230.375f, 232.625f, 234.375f, 236.625f, 238.375f, 240.625f, 242.375f, 244.625f, 246.375f, 248.625f, 250.375f, 252.625f, 254.375f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y, + excluded_eps); } { std::vector Y = {2.5f, 3.5f, 4.5f, 5.5f, 9.5f, 10.5f, 11.5f, 12.5f, 18.5f, @@ -2084,7 +2118,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { 217.5f, 218.5f, 219.5f, 220.5f, 226.5f, 227.5f, 228.5f, 229.5f, 233.5f, 234.5f, 235.5f, 236.5f, 242.5f, 243.5f, 244.5f, 245.5f, 249.5f, 250.5f, 251.5f, 252.5f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y, + excluded_eps); } } @@ -2124,12 +2159,15 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { 19.576872f, 43.57687f, 21.126253f, 45.126255f, 22.606192f, 46.606194f, 19.878183f, 43.87818f, 21.358122f, 45.35812f, 22.907503f, 46.907505f, 24.387442f, 48.387444f}; - TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y); + + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias" + << "is slightly different and doesn't match in all cases."; } std::vector X(256); std::iota(X.begin(), X.end(), 0.0f); @@ -2145,9 +2183,40 @@ TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { 187.08333f, 195.91667f, 198.41667f, 205.91667f, 208.41667f, 217.25f, 219.75f, 227.25f, 229.75f, 238.58333f, 241.08333f, 248.58333f, 251.08333f}; + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; TestAntialiasing( {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, - {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y); + {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y, excluded_eps); +} + +TEST(ResizeOpTest, Antialias_Linear_AlignCorners_3D) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly " + << "different and doesn't match in all cases."; + } + std::vector X(256); + std::iota(X.begin(), X.end(), 0.0f); + std::vector Y{ + 1.25f, 3.75f, 11.25f, 13.75f, + 17.25f, 19.75f, 27.25f, 29.75f, + 33.25f, 35.75f, 43.25f, 45.75f, + 49.25f, 51.75f, 59.25f, 61.75f, + 65.25f, 67.75f, 75.25f, 77.75f, + 81.25f, 83.75f, 91.25f, 93.75f, + 97.25f, 99.75f, 107.25f, 109.75f, + 113.25f, 115.75f, 123.25f, 125.75f, + 129.25f, 131.75f, 139.25f, 141.75f, + 145.25f, 147.75f, 155.25f, 157.75f, + 161.25f, 163.75f, 171.25f, 173.75f, + 177.25f, 179.75f, 187.25f, 189.75f, + 193.25f, 195.75f, 203.25f, 205.75f, + 209.25f, 211.75f, 219.25f, 221.75f, + 225.25f, 227.75f, 235.25f, 237.75f, + 241.25f, 243.75f, 251.25f, 253.75f}; + + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, + {16, 4, 4}, X, {16, 2, 2}, Y); } TEST(ResizeOpTest, Antialias_Bicubic_ExcludeOutside) { @@ -2166,19 +2235,23 @@ TEST(ResizeOpTest, Antialias_Bicubic_Dtype) { std::vector X(36); std::iota(X.begin(), X.end(), uint8_t(0)); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y); } { std::vector X(36); std::iota(X.begin(), X.end(), int8_t(0)); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + InlinedVector excluded_eps = {kCudaExecutionProvider}; + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y, excluded_eps); } { std::vector X(36); std::iota(X.begin(), X.end(), 0); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y); } } @@ -2189,8 +2262,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Scale) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, - std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, + {1, 1, 4, 4, 4}, X, + std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); } TEST(ResizeOpTest, Antialias_Axes_and_Size) { @@ -2199,8 +2274,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Size) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, - {3, 3, 3}, Y); + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, + {1, 1, 4, 4, 4}, X, + {3, 3, 3}, Y); } TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { @@ -2209,9 +2286,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_larger"}}, - {1, 1, 4, 4, 4}, X, - {3, 4, 5}, Y); + // clang-format off + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, + {"policy", "not_larger"}}, + {1, 1, 4, 4, 4}, X, + {3, 4, 5}, Y); + // clang-format on } TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { @@ -2220,9 +2301,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_smaller"}}, - {1, 1, 4, 4, 4}, X, - {1, 2, 3}, Y); + // clang-format off + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, + {"policy", "not_smaller"}}, + {1, 1, 4, 4, 4}, X, + {1, 2, 3}, Y); + // clang-format on } TEST(ResizeOpTest, Antialias_Use_Extrapolation) {