diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu index 972d043ea98ae..32370930a5d90 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -9,7 +9,7 @@ #ifdef CPU_TESTING #undef __global__ #define __global__ -using IdType = int64_t; +using IdType = int; #define FUNC_DEF __host__ #else using IdType = int; @@ -42,7 +42,8 @@ static std::tuple ComputeBilinearScaleBufferSize(int64_t outpu std::cout << "\toutput_height: " << output_height << " output_width: " << output_width << " height_rscale: " << height_rscale << " width_rscale: " << width_rscale << " support_value: " << support_value - << " scaled_support_height: " << scaled_support_height << " scaled_support_width: \n" << scaled_support_width; + << " scaled_support_height: " << scaled_support_height << " scaled_support_width: \n" + << scaled_support_width; std::cout << "\tHeight Buffer size: " << static_cast(height_buffer_size) << " Width Buffer size: " << static_cast(width_buffer_size); @@ -132,7 +133,7 @@ struct AccumTypeCaster { }; template -FUNC_DEF void ComputeInterpolationAtLevel1(int id, +FUNC_DEF void ComputeInterpolationAtLevel1(IdType id, int64_t input_height, int64_t input_width, int64_t output_height, int64_t output_width, const fast_divmod& div_output_width, const fast_divmod& div_output_image, @@ -149,7 +150,7 @@ FUNC_DEF void ComputeInterpolationAtLevel1(int id, int bxc, output_image_index; div_output_image.divmod(id, bxc, output_image_index); - CUDA_LONG input_index = bxc * input_height * input_width; + int64_t input_index = bxc * input_height * input_width; int output_y, output_x; div_output_width.divmod(output_image_index, output_y, output_x); @@ -168,7 +169,8 @@ FUNC_DEF void ComputeInterpolationAtLevel1(int id, const auto* Xdata_offset = Xdata + input_index + xmin; for (; xmin < xmax; ++xmin) { - output += (*Xdata_offset++) * (*weight_coeff++); + // This cast is needed when we deal with half + output += static_cast((*Xdata_offset++)) * (*weight_coeff++); } if constexpr (onnxruntime::is_8bit_v) { @@ -176,12 +178,12 @@ FUNC_DEF void ComputeInterpolationAtLevel1(int id, } else if constexpr (std::is_same::value) { *Ydata_offset = static_cast(std::round(output)); } else { - *Ydata_offset = output; + *Ydata_offset = static_cast(output); } } template -FUNC_DEF void ComputeInterpolationAtLevel2(int id, +FUNC_DEF void ComputeInterpolationAtLevel2(IdType id, int64_t input_height, int64_t input_width, int64_t output_height, int64_t output_width, const fast_divmod& div_output_width, const fast_divmod& div_output_image, @@ -192,14 +194,14 @@ FUNC_DEF void ComputeInterpolationAtLevel2(int id, const InputType* Xdata, InputType* Ydata) { // No need to do scale if (output_height == input_height) { - Ydata[i] = Xdata[i]; + Ydata[id] = Xdata[id]; return; } // input image index and output image are the same. int bxc, output_image_index; div_output_image.divmod(id, bxc, output_image_index); - CUDA_LONG input_index = bxc * input_height * input_width; + int64_t input_index = bxc * input_height * input_width; int output_y, output_x; div_output_width.divmod(output_image_index, output_y, output_x); @@ -216,7 +218,7 @@ FUNC_DEF void ComputeInterpolationAtLevel2(int id, const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x; for (; ymin < ymax; ++ymin) { - output += (*Xdata_offset) * (*weight_coeff); + output += static_cast((*Xdata_offset)) * (*weight_coeff); Xdata_offset += input_width; weight_coeff++; } @@ -230,8 +232,9 @@ FUNC_DEF void ComputeInterpolationAtLevel2(int id, } } -template -FUNC_DEF void HandleExtrapolation(int32_t id, int64_t output_depth, int64_t output_height, int64_t output_width, +template +FUNC_DEF void HandleExtrapolation(IdType id, int64_t input_height, int64_t input_width, + int64_t output_depth, int64_t output_height, int64_t output_width, const float extrapolation_value, InputType* Ydata, const fast_divmod& div_output_height, const fast_divmod& div_output_width, const fast_divmod& div_output_image, @@ -239,7 +242,7 @@ FUNC_DEF void HandleExtrapolation(int32_t id, int64_t output_depth, int64_t outp const int64_t* w_outof_bounds) { int bxc, output_image_index; div_output_image.divmod(id, bxc, output_image_index); - CUDA_LONG input_index = bxc * input_height * input_width; + // CUDA_LONG input_index = bxc * input_height * input_width; InputType* Ydata_base = Ydata + output_image_index * (output_depth * output_height * output_width); @@ -266,7 +269,7 @@ FUNC_DEF void HandleExtrapolation(int32_t id, int64_t output_depth, int64_t outp } // Extrapolate along the y dimension - if (z_outof_bounds[static_cast(output_z)] != -1) { + if (z_outof_bounds != nullptr && z_outof_bounds[static_cast(output_z)] != -1) { #ifdef CPU_TESTING assert(z_outof_bounds[static_cast(output_z)] == output_z); #endif @@ -280,25 +283,23 @@ __global__ void _UpsampleBilinearAntiAlias( #ifdef CPU_TESTING IdType id, #endif - const int64_t batch_size, - const int64_t num_channels, const int64_t input_height, const int64_t input_width, const int64_t output_depth, const int64_t output_height, const int64_t output_width, fast_divmod div_output_height, fast_divmod div_output_width, fast_divmod div_output_image, - int64_t window_size, + std::tuple window_sizes, // h, w const bool use_extrapolation, const float extrapolation_value, - std::tuple bounds_buffers, // z, h, w - std::tuple outof_bounds_buffers, // z, h, w - std::tuple weight_buffers, // z, h, w + std::tuple bounds_buffers, // h, w + std::tuple outof_bounds_buffers, // z, h, w + std::tuple weight_buffers, // h, w const uint8_t* clip8_lookups, T* image_temp_buffer, // We expect this to be input_height * output_width * num_channels const T* Xdata, T* Ydata, - const int N) { + const size_t N) { #ifndef CPU_TESTING CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); @@ -307,20 +308,21 @@ __global__ void _UpsampleBilinearAntiAlias( // horizon interpolate // This computes only the width direction.Thus height keeps unchanged. ComputeInterpolationAtLevel1(id, input_height, input_width, input_height, output_width, - div_output_width, div_output_image, window_size, - clip8_lookups, std::get<2>(bounds_buffers), - std::get<2>(weight_buffers), - Xdata, Ydata); + div_output_width, div_output_image, std::get<1>(window_sizes), + clip8_lookups, std::get<1>(bounds_buffers), + std::get<1>(weight_buffers), + Xdata, image_temp_buffer); // vertical interpolate ComputeInterpolationAtLevel2(id, input_height, output_width, output_height, output_width, div_output_width, div_output_image, - window_size, clip8_lookups, std::get<1>(bounds_buffers), - std::get<1>(weight_buffers), - Xdata, Ydata); + std::get<0>(window_sizes), + clip8_lookups, std::get<0>(bounds_buffers), + std::get<0>(weight_buffers), + image_temp_buffer, Ydata); if (use_extrapolation) { - HandleExtrapolation(id, output_depth, output_height, output_width, + HandleExtrapolation(id, input_height, input_width, output_depth, output_height, output_width, extrapolation_value, Ydata, div_output_height, div_output_width, div_output_image, std::get<0>(outof_bounds_buffers), @@ -382,7 +384,7 @@ __global__ void _UpsampleTrilinearAntiAlias( Xdata, Ydata); if (use_extrapolation) { - HandleExtrapolation(id, output_depth, output_height, output_width, + HandleExtrapolation(id, input_height, input_width, output_depth, output_height, output_width, extrapolation_value, Ydata, div_output_height, div_output_width, div_output_image, std::get<0>(outof_bounds_buffers), @@ -586,12 +588,12 @@ __global__ void _SetupBilinearUpsampleFilterAntiAlias( } #endif - id = id - y_output_size; + auto i = id - y_output_size; bounds += (y_output_size * 2); out_of_bounds += y_output_size; SetupUpsampleFilterAnitAliasImpl( - id, + i, input_size, output_size, inv_scale, roi_start, roi_end, @@ -696,12 +698,12 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias( } #endif - id = id - d_output_size; + auto i = id - d_output_size; bounds += d_output_size * 2; out_of_bounds += d_output_size; SetupUpsampleFilterAnitAliasImpl( - id, + i, input_size, output_size, inv_scale, roi_start, roi_end, @@ -734,12 +736,12 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias( } #endif - id = id - d_y_output_size; + auto i = id - d_y_output_size; bounds += (d_y_output_size * 2); out_of_bounds += d_y_output_size; SetupUpsampleFilterAnitAliasImpl( - id, + i, input_size, output_size, inv_scale, roi_start, roi_end, @@ -793,21 +795,25 @@ void ResizeAntiAliasImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, gsl::span input_shape, gsl::span output_shape, + int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, std::tuple inferred_dim_rscales, // const TArray& input_strides, const TArray& output_div_pitches, gsl::span roi_vals, - const std::optional& extrapolation_value, + const std::optional& extrapolation, bool exclude_outside, TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, const T* input_data, T* output_data, const size_t N) { - using AccumType = typename onnxruntime::AccumulateType::type; + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + // We support a special case of bilinear or bicubic if the input data is 4D with the outer 2 scales being 1.0 // We would have validated the outer scale values by the time execution reaches this const bool is_2D = (rank == 2 || rank == 4); @@ -865,17 +871,23 @@ void ResizeAntiAliasImpl( weighted_w_size; auto weighted_buffer = AllocateTyped(allocate_temp_space, weighted_buffer_size); + int64_t* y_bounds_buffer = GetTyped(weighted_buffer); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + AccumType* y_weighted_buffer = GetTyped(weighted_buffer); AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; #ifdef CPU_TESTING - for (int64_t id = 0, lim = output_height + output_width; id < lim; ++id) { + for (IdType id = 0, lim = narrow(output_height + output_width); id < lim; ++id) { DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { _SetupBilinearUpsampleFilterAntiAlias( - narrow(id), + id, std::make_tuple(input_height, input_width), std::make_tuple(output_height, output_width), std::make_tuple(h_scale, w_scale), @@ -895,6 +907,24 @@ void ResizeAntiAliasImpl( auto weighted_buffer_span = gsl::make_span(GetTyped(weighted_buffer), weighted_buffer_size); PrintAntiAliasBuffers(std::cout, bounds_buffer_span, out_of_bounds_buffer_span, weighted_buffer_span); + auto image_temp_buffer = AllocateTyped(allocate_temp_space, + narrow(input_height * output_width * num_channels)); + + for (IdType id = 0, lim = narrow(batch_size * input_height * num_channels * input_width); id < lim; ++id) { + _UpsampleBilinearAntiAlias( + id, + input_height, input_width, output_depth, output_height, output_width, + output_div_pitches[rank - 2], output_div_pitches[rank - 1], div_output_image, + std::make_tuple(h_window_size, w_window_size), + use_extrapolation, extrapolation_value, + std::make_tuple(y_bounds_buffer, w_bounds_buffer), + std::make_tuple(static_cast(nullptr) , y_outof_bounds_buffer, w_outof_bounds_buffer), + std::make_tuple(y_weighted_buffer, w_weighted_buffer), + clip8_lookups, + GetTyped(image_temp_buffer), + input_data, output_data, N); + } + #else DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { // Data is d, h, w in tuples @@ -946,12 +976,12 @@ void ResizeAntiAliasImpl( AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size; #ifdef CPU_TESTING - for (int64_t id = 0, lim = output_depth + output_height + output_width; id < lim; ++id) { + for (IdType id = 0, lim = narrow(output_depth + output_height + output_width); id < lim; ++id) { DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { _SetupTrilinerarUpsampleFilterAntiAlias( - narrow(id), + id, inferred_input_dims, inferred_output_dims, inferred_dim_rscales, @@ -1026,12 +1056,12 @@ void ResizeAntiAliasImpl( #ifdef CPU_TESTING - for (int64_t id = 0, lim = output_height + output_width; id < lim; ++id) { + for (IdType id = 0, lim = narrow(output_height + output_width); id < lim; ++id) { DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { _SetupBilinearUpsampleFilterAntiAlias( - narrow(id), + id, std::make_tuple(input_height, input_width), std::make_tuple(output_height, output_width), std::make_tuple(h_scale, w_scale), @@ -1086,14 +1116,16 @@ void ResizeAntiAliasImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, \ gsl::span input_shape, \ gsl::span output_shape, \ + int64_t batch_size, int64_t num_channels, \ std::tuple inferred_input_dims, \ std::tuple inferred_output_dims, \ std::tuple inferred_dim_rscales, /* const TArray& input_strides, */ \ const TArray& output_div_pitches, \ gsl::span roi_vals, \ - const std::optional& extrapolation_value, \ + const std::optional& extrapolation_value, \ bool exclude_outside, \ TempSpaceAllocateFunc allocate_temp_space, \ + const uint8_t* clip8_lookups, \ const T* input_data, \ T* output_data, \ const size_t N); diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h index b1bd156e954e1..e55261964c9f0 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h @@ -93,15 +93,17 @@ void ResizeAntiAliasImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, gsl::span input_shape, gsl::span output_shape, + int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, std::tuple inferred_dim_rscales, // const TArray& input_strides, const TArray& output_div_pitches, gsl::span roi_vals, // CPU - const std::optional& extrapolation_value, + const std::optional& extrapolation_value, bool exclude_outside, TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, const T* input_data, T* output_data, const size_t N); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 7fc6f10517790..308a627380240 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -110,9 +110,9 @@ Status Upsample::BaseCompute(OpKernelContext* context, return IAllocator::MakeUniquePtr(allocator_ptr, bytes_size); }; - std::optional extrapolation_value; + std::optional extrapolation_value; if (use_extrapolation_) - extrapolation_value.emplace(ToCudaType::FromFloat(extrapolation_value_)); + extrapolation_value.emplace(extrapolation_value_); switch (mode_) { case UpsampleMode::LINEAR: { @@ -176,6 +176,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, mode_, coordinate_transform_mode_, X_dims, output_dims, + batch_size, num_channels, std::make_tuple(0, input_height, input_width), std::make_tuple(0, output_height, output_width), std::make_tuple(0.f, height_scale, width_scale), @@ -184,6 +185,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, + shared_lookup_table_ondevice_.get(), reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -191,8 +193,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, } else if (X_dims.size() == 3 || X_dims.size() == 5) { const bool is_3D = X_dims.size() == 3; - // const int64_t batch_size = is_3D ? 1 : X_dims[0]; - // const int64_t num_channels = is_3D ? 1 : X_dims[1]; + const int64_t batch_size = is_3D ? 1 : X_dims[0]; + const int64_t num_channels = is_3D ? 1 : X_dims[1]; const int64_t input_depth = is_3D ? X_dims[0] : X_dims[2]; const int64_t input_height = is_3D ? X_dims[1] : X_dims[3]; const int64_t input_width = is_3D ? X_dims[2] : X_dims[4]; @@ -210,6 +212,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, mode_, coordinate_transform_mode_, X_dims, output_dims, + batch_size, num_channels, std::make_tuple(input_depth, input_height, input_width), std::make_tuple(output_depth, output_height, output_width), std::make_tuple(depth_scale, height_scale, width_scale), @@ -218,6 +221,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, + shared_lookup_table_ondevice_.get(), reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -239,8 +243,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, const bool is_nchw = is_2D ? true : (scales[1] == 1.0f); assert(is_nchw); // We are not implementing it yet. - // const int64_t batch_size = is_2D ? 1 : X_dims[0]; - // const int64_t num_channels = is_2D ? 1 : (is_nchw ? X_dims[1] : X_dims[3]); + const int64_t batch_size = is_2D ? 1 : X_dims[0]; + const int64_t num_channels = is_2D ? 1 : (is_nchw ? X_dims[1] : X_dims[3]); const int64_t input_height = is_2D ? X_dims[0] : (is_nchw ? X_dims[2] : X_dims[1]); const int64_t input_width = is_2D ? X_dims[1] : (is_nchw ? X_dims[3] : X_dims[2]); const int64_t output_height = is_2D ? output_dims[0] : (is_nchw ? output_dims[2] : output_dims[1]); @@ -250,6 +254,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_, X_dims, output_dims, + batch_size, num_channels, std::make_tuple(0, input_height, input_width), std::make_tuple(0, output_height, output_width), std::make_tuple(0.f, height_scale, width_scale), @@ -258,6 +263,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, + shared_lookup_table_ondevice_.get(), reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count);