Skip to content

Commit

Permalink
Working on Bilinear Upsample
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Feb 13, 2024
1 parent 435cb94 commit 2cfe280
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 52 deletions.
122 changes: 77 additions & 45 deletions onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifdef CPU_TESTING
#undef __global__
#define __global__
using IdType = int64_t;
using IdType = int;
#define FUNC_DEF __host__
#else
using IdType = int;
Expand Down Expand Up @@ -42,7 +42,8 @@ static std::tuple<int64_t, int64_t> ComputeBilinearScaleBufferSize(int64_t outpu
std::cout << "\toutput_height: " << output_height << " output_width: " << output_width
<< " height_rscale: " << height_rscale << " width_rscale: " << width_rscale
<< " support_value: " << support_value
<< " scaled_support_height: " << scaled_support_height << " scaled_support_width: \n" << scaled_support_width;
<< " scaled_support_height: " << scaled_support_height << " scaled_support_width: \n"
<< scaled_support_width;
std::cout << "\tHeight Buffer size: " << static_cast<size_t>(height_buffer_size)
<< " Width Buffer size: " << static_cast<size_t>(width_buffer_size);

Expand Down Expand Up @@ -132,7 +133,7 @@ struct AccumTypeCaster<int32_t> {
};

template <typename InputType, typename AccumulateType>
FUNC_DEF void ComputeInterpolationAtLevel1(int id,
FUNC_DEF void ComputeInterpolationAtLevel1(IdType id,
int64_t input_height, int64_t input_width,
int64_t output_height, int64_t output_width,
const fast_divmod& div_output_width, const fast_divmod& div_output_image,
Expand All @@ -149,7 +150,7 @@ FUNC_DEF void ComputeInterpolationAtLevel1(int id,

int bxc, output_image_index;
div_output_image.divmod(id, bxc, output_image_index);
CUDA_LONG input_index = bxc * input_height * input_width;
int64_t input_index = bxc * input_height * input_width;

int output_y, output_x;
div_output_width.divmod(output_image_index, output_y, output_x);
Expand All @@ -168,20 +169,21 @@ FUNC_DEF void ComputeInterpolationAtLevel1(int id,
const auto* Xdata_offset = Xdata + input_index + xmin;

for (; xmin < xmax; ++xmin) {
output += (*Xdata_offset++) * (*weight_coeff++);
// This cast is needed when dealing with half-precision types
output += static_cast<AccumulateType>((*Xdata_offset++)) * (*weight_coeff++);
}

if constexpr (onnxruntime::is_8bit_v<InputType>) {
*Ydata_offset = static_cast<InputType>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<InputType, int32_t>::value) {
*Ydata_offset = static_cast<int32_t>(std::round(output));
} else {
*Ydata_offset = output;
*Ydata_offset = static_cast<InputType>(output);
}
}

template <typename InputType, typename AccumulateType>
FUNC_DEF void ComputeInterpolationAtLevel2(int id,
FUNC_DEF void ComputeInterpolationAtLevel2(IdType id,
int64_t input_height, int64_t input_width,
int64_t output_height, int64_t output_width,
const fast_divmod& div_output_width, const fast_divmod& div_output_image,
Expand All @@ -192,14 +194,14 @@ FUNC_DEF void ComputeInterpolationAtLevel2(int id,
const InputType* Xdata, InputType* Ydata) {
// No need to do scale
if (output_height == input_height) {
Ydata[i] = Xdata[i];
Ydata[id] = Xdata[id];
return;
}

// The input image index and the output image index are the same.
int bxc, output_image_index;
div_output_image.divmod(id, bxc, output_image_index);
CUDA_LONG input_index = bxc * input_height * input_width;
int64_t input_index = bxc * input_height * input_width;

int output_y, output_x;
div_output_width.divmod(output_image_index, output_y, output_x);
Expand All @@ -216,7 +218,7 @@ FUNC_DEF void ComputeInterpolationAtLevel2(int id,
const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x;

for (; ymin < ymax; ++ymin) {
output += (*Xdata_offset) * (*weight_coeff);
output += static_cast<AccumulateType>((*Xdata_offset)) * (*weight_coeff);
Xdata_offset += input_width;
weight_coeff++;
}
Expand All @@ -230,16 +232,17 @@ FUNC_DEF void ComputeInterpolationAtLevel2(int id,
}
}

template <typename InputType, typename AccumulateType>
FUNC_DEF void HandleExtrapolation(int32_t id, int64_t output_depth, int64_t output_height, int64_t output_width,
template <typename InputType>
FUNC_DEF void HandleExtrapolation(IdType id, int64_t input_height, int64_t input_width,
int64_t output_depth, int64_t output_height, int64_t output_width,
const float extrapolation_value, InputType* Ydata,
const fast_divmod& div_output_height, const fast_divmod& div_output_width,
const fast_divmod& div_output_image,
const int64_t* z_outof_bounds, const int64_t* y_outof_bounds,
const int64_t* w_outof_bounds) {
int bxc, output_image_index;
div_output_image.divmod(id, bxc, output_image_index);
CUDA_LONG input_index = bxc * input_height * input_width;
// CUDA_LONG input_index = bxc * input_height * input_width;

InputType* Ydata_base = Ydata + output_image_index * (output_depth * output_height * output_width);

Expand All @@ -266,7 +269,7 @@ FUNC_DEF void HandleExtrapolation(int32_t id, int64_t output_depth, int64_t outp
}

// Extrapolate along the z dimension (the code below checks z_outof_bounds)
if (z_outof_bounds[static_cast<ptrdiff_t>(output_z)] != -1) {
if (z_outof_bounds != nullptr && z_outof_bounds[static_cast<ptrdiff_t>(output_z)] != -1) {
#ifdef CPU_TESTING
assert(z_outof_bounds[static_cast<ptrdiff_t>(output_z)] == output_z);
#endif
Expand All @@ -280,25 +283,23 @@ __global__ void _UpsampleBilinearAntiAlias(
#ifdef CPU_TESTING
IdType id,
#endif
const int64_t batch_size,
const int64_t num_channels,
const int64_t input_height,
const int64_t input_width,
const int64_t output_depth,
const int64_t output_height,
const int64_t output_width,
fast_divmod div_output_height, fast_divmod div_output_width, fast_divmod div_output_image,
int64_t window_size,
std::tuple<int32_t, int32_t> window_sizes, // h, w
const bool use_extrapolation,
const float extrapolation_value,
std::tuple<const int64_t*, const int64_t*, const int64_t*> bounds_buffers, // z, h, w
std::tuple<const int64_t*, const int64_t*, const int64_t*> outof_bounds_buffers, // z, h, w
std::tuple<const AccumType*, const AccumType*, const AccumType*> weight_buffers, // z, h, w
std::tuple<int64_t*, int64_t*> bounds_buffers, // h, w
std::tuple<int64_t*, int64_t*, int64_t*> outof_bounds_buffers, // z, h, w
std::tuple<AccumType*, AccumType*> weight_buffers, // h, w
const uint8_t* clip8_lookups,
T* image_temp_buffer, // We expect this to be input_height * output_width * num_channels
const T* Xdata,
T* Ydata,
const int N) {
const size_t N) {

#ifndef CPU_TESTING
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
Expand All @@ -307,20 +308,21 @@ __global__ void _UpsampleBilinearAntiAlias(
// horizontal interpolation
// This computes only the width direction; the height is kept unchanged.
ComputeInterpolationAtLevel1(id, input_height, input_width, input_height, output_width,
div_output_width, div_output_image, window_size,
clip8_lookups, std::get<2>(bounds_buffers),
std::get<2>(weight_buffers),
Xdata, Ydata);
div_output_width, div_output_image, std::get<1>(window_sizes),
clip8_lookups, std::get<1>(bounds_buffers),
std::get<1>(weight_buffers),
Xdata, image_temp_buffer);

// vertical interpolate
ComputeInterpolationAtLevel2(id, input_height, output_width, output_height, output_width,
div_output_width, div_output_image,
window_size, clip8_lookups, std::get<1>(bounds_buffers),
std::get<1>(weight_buffers),
Xdata, Ydata);
std::get<0>(window_sizes),
clip8_lookups, std::get<0>(bounds_buffers),
std::get<0>(weight_buffers),
image_temp_buffer, Ydata);

if (use_extrapolation) {
HandleExtrapolation(id, output_depth, output_height, output_width,
HandleExtrapolation(id, input_height, input_width, output_depth, output_height, output_width,
extrapolation_value, Ydata,
div_output_height, div_output_width, div_output_image,
std::get<0>(outof_bounds_buffers),
Expand Down Expand Up @@ -382,7 +384,7 @@ __global__ void _UpsampleTrilinearAntiAlias(
Xdata, Ydata);

if (use_extrapolation) {
HandleExtrapolation(id, output_depth, output_height, output_width,
HandleExtrapolation(id, input_height, input_width, output_depth, output_height, output_width,
extrapolation_value, Ydata,
div_output_height, div_output_width, div_output_image,
std::get<0>(outof_bounds_buffers),
Expand Down Expand Up @@ -586,12 +588,12 @@ __global__ void _SetupBilinearUpsampleFilterAntiAlias(
}
#endif

id = id - y_output_size;
auto i = id - y_output_size;
bounds += (y_output_size * 2);
out_of_bounds += y_output_size;

SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
id,
i,
input_size, output_size,
inv_scale,
roi_start, roi_end,
Expand Down Expand Up @@ -696,12 +698,12 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
}
#endif

id = id - d_output_size;
auto i = id - d_output_size;
bounds += d_output_size * 2;
out_of_bounds += d_output_size;

SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
id,
i,
input_size, output_size,
inv_scale,
roi_start, roi_end,
Expand Down Expand Up @@ -734,12 +736,12 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
}
#endif

id = id - d_y_output_size;
auto i = id - d_y_output_size;
bounds += (d_y_output_size * 2);
out_of_bounds += d_y_output_size;

SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
id,
i,
input_size, output_size,
inv_scale,
roi_start, roi_end,
Expand Down Expand Up @@ -793,21 +795,25 @@ void ResizeAntiAliasImpl(
ResizeCoordinateTransformationMode coordinate_transform_mode,
gsl::span<const int64_t> input_shape,
gsl::span<const int64_t> output_shape,
int64_t batch_size, int64_t num_channels,
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
std::tuple<float, float, float> inferred_dim_rscales,
// const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_div_pitches,
gsl::span<const float> roi_vals,
const std::optional<T>& extrapolation_value,
const std::optional<float>& extrapolation,
bool exclude_outside,
TempSpaceAllocateFunc allocate_temp_space,
const uint8_t* clip8_lookups,
const T* input_data,
T* output_data,
const size_t N) {

using AccumType = typename onnxruntime::AccumulateType<T>::type;

const bool use_extrapolation = extrapolation.has_value();
const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;

// We support a special case of bilinear or bicubic if the input data is 4D with the outer 2 scales being 1.0
// We would have validated the outer scale values by the time execution reaches this
const bool is_2D = (rank == 2 || rank == 4);
Expand Down Expand Up @@ -865,17 +871,23 @@ void ResizeAntiAliasImpl(
weighted_w_size;
auto weighted_buffer = AllocateTyped<AccumType>(allocate_temp_space, weighted_buffer_size);

int64_t* y_bounds_buffer = GetTyped<int64_t>(weighted_buffer);
int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2;

int64_t* y_outof_bounds_buffer = GetTyped<int64_t>(out_of_bounds_buffer_ptr);
int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height;

AccumType* y_weighted_buffer = GetTyped<AccumType>(weighted_buffer);
AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;

#ifdef CPU_TESTING

for (int64_t id = 0, lim = output_height + output_width; id < lim; ++id) {
for (IdType id = 0, lim = narrow<IdType>(output_height + output_width); id < lim; ++id) {
DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
_SetupBilinearUpsampleFilterAntiAlias<AccumType,
BilinearFilter,
coord_t>(
narrow<int32_t>(id),
id,
std::make_tuple(input_height, input_width),
std::make_tuple(output_height, output_width),
std::make_tuple(h_scale, w_scale),
Expand All @@ -895,6 +907,24 @@ void ResizeAntiAliasImpl(
auto weighted_buffer_span = gsl::make_span(GetTyped<AccumType>(weighted_buffer), weighted_buffer_size);
PrintAntiAliasBuffers(std::cout, bounds_buffer_span, out_of_bounds_buffer_span, weighted_buffer_span);

auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space,
narrow<size_t>(input_height * output_width * num_channels));

for (IdType id = 0, lim = narrow<IdType>(batch_size * input_height * num_channels * input_width); id < lim; ++id) {
_UpsampleBilinearAntiAlias<T, AccumType>(
id,
input_height, input_width, output_depth, output_height, output_width,
output_div_pitches[rank - 2], output_div_pitches[rank - 1], div_output_image,
std::make_tuple(h_window_size, w_window_size),
use_extrapolation, extrapolation_value,
std::make_tuple(y_bounds_buffer, w_bounds_buffer),
std::make_tuple(static_cast<int64_t*>(nullptr) , y_outof_bounds_buffer, w_outof_bounds_buffer),
std::make_tuple(y_weighted_buffer, w_weighted_buffer),
clip8_lookups,
GetTyped<T>(image_temp_buffer),
input_data, output_data, N);
}

#else
DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
// Data is d, h, w in tuples
Expand Down Expand Up @@ -946,12 +976,12 @@ void ResizeAntiAliasImpl(
AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size;
#ifdef CPU_TESTING
for (int64_t id = 0, lim = output_depth + output_height + output_width; id < lim; ++id) {
for (IdType id = 0, lim = narrow<IdType>(output_depth + output_height + output_width); id < lim; ++id) {
DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
_SetupTrilinerarUpsampleFilterAntiAlias<AccumType,
TriLinearFilter,
coord_t>(
narrow<int32_t>(id),
id,
inferred_input_dims,
inferred_output_dims,
inferred_dim_rscales,
Expand Down Expand Up @@ -1026,12 +1056,12 @@ void ResizeAntiAliasImpl(
#ifdef CPU_TESTING
for (int64_t id = 0, lim = output_height + output_width; id < lim; ++id) {
for (IdType id = 0, lim = narrow<IdType>(output_height + output_width); id < lim; ++id) {
DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
_SetupBilinearUpsampleFilterAntiAlias<AccumType,
BiCubicFilter,
coord_t>(
narrow<int32_t>(id),
id,
std::make_tuple(input_height, input_width),
std::make_tuple(output_height, output_width),
std::make_tuple(h_scale, w_scale),
Expand Down Expand Up @@ -1086,14 +1116,16 @@ void ResizeAntiAliasImpl(
ResizeCoordinateTransformationMode coordinate_transform_mode, \
gsl::span<const int64_t> input_shape, \
gsl::span<const int64_t> output_shape, \
int64_t batch_size, int64_t num_channels, \
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims, \
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims, \
std::tuple<float, float, float> inferred_dim_rscales, /* const TArray<int64_t>& input_strides, */ \
const TArray<fast_divmod>& output_div_pitches, \
gsl::span<const float> roi_vals, \
const std::optional<T>& extrapolation_value, \
const std::optional<float>& extrapolation_value, \
bool exclude_outside, \
TempSpaceAllocateFunc allocate_temp_space, \
const uint8_t* clip8_lookups, \
const T* input_data, \
T* output_data, \
const size_t N);
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/resize_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,17 @@ void ResizeAntiAliasImpl(
ResizeCoordinateTransformationMode coordinate_transform_mode,
gsl::span<const int64_t> input_shape,
gsl::span<const int64_t> output_shape,
int64_t batch_size, int64_t num_channels,
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
std::tuple<float, float, float> inferred_dim_rscales,
// const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_div_pitches,
gsl::span<const float> roi_vals, // CPU
const std::optional<T>& extrapolation_value,
const std::optional<float>& extrapolation_value,
bool exclude_outside,
TempSpaceAllocateFunc allocate_temp_space,
const uint8_t* clip8_lookups,
const T* input_data,
T* output_data,
const size_t N);
Expand Down
Loading

0 comments on commit 2cfe280

Please sign in to comment.