Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Feb 28, 2024
1 parent 250f54f commit b54e7df
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 86 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);

void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const {
InlinedHashSet<int64_t> axes_set(axes_.begin(), axes_.end());

// AspectRatioPolicy::STRETCH is default policy when opset < 18
if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) {
return;
}

InlinedHashSet<int64_t> axes_set(axes_.begin(), axes_.end());

float scale_in_policy = 0.0f;
if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) {
scale_in_policy = std::numeric_limits<float>::max();
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct FilterParamsBaseAntiAlias {

template <typename T>
struct FilterParamsAntiAlias {
float support_size = kSupportSize;
float cubic_coeff_a = kCubicCoeffA;
float support_size = antialias_constants::kSupportSize;
float cubic_coeff_a = antialias_constants::kCubicCoeffA;

FilterParamsBaseAntiAlias<T> dim_x;
FilterParamsBaseAntiAlias<T> dim_y;
Expand Down Expand Up @@ -63,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias<T> {
template <typename T>
struct BiCubicParamsAntiAlias : FilterParamsAntiAlias<T> {
BiCubicParamsAntiAlias() {
this->support_size = kBiCubicSupportSize;
this->support_size = antialias_constants::kBiCubicSupportSize;
}

// taken from
Expand Down Expand Up @@ -109,9 +109,9 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias<T> {
// - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0]
template <class T>
void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
gsl::span<int64_t> input_h_w_c,
gsl::span<int64_t> output_h_w_c,
gsl::span<float> scale_h_w_c,
gsl::span<const int64_t> input_h_w_c,
gsl::span<const int64_t> output_h_w_c,
gsl::span<const float> scale_h_w_c,
gsl::span<const float> roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
Expand Down Expand Up @@ -199,7 +199,7 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,

// normalize the scale to 1 << 22 for int8/uint8
if constexpr (std::is_same<T, int32_t>::value) {
scale_buffer_int[x] = static_cast<int32_t>(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f));
scale_buffer_int[x] = static_cast<int32_t>(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2));
}
}
/*for (; x < window_size; x++) {
Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/core/providers/cpu/tensor/upsamplebase.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,17 @@ struct AccumulateType<double> {
using type = double;
};

// Antialiasing constants
namespace antialias_constants {
constexpr float kCubicCoeffA = -0.75f;
constexpr float kSupportSize = 2.0f;
constexpr float kBiCubicSupportSize = 4.0f;
} // namespace antialias_constants

namespace ConstValue {
constexpr int32_t mag_factor = 1 << (22 - 1);
}
// We use to multiply by 2, let's make a constant which is twice as big
constexpr int32_t mag_factor_x_2 = 1 << 22;
} // namespace ConstValue

template <class T>
inline constexpr bool is_8bit_v = std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
Expand Down Expand Up @@ -157,7 +160,7 @@ class UpsampleBase {
keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy);

// guard against unit tests that can add an attribute
auto axes = info.GetAttrsOrDefault("axes");
auto axes = info.GetAttrsOrDefault<int64_t>("axes");
axes_.assign(axes.cbegin(), axes.cend());
}

Expand Down Expand Up @@ -186,7 +189,7 @@ class UpsampleBase {
nearest_mode_ = StringToNearestMode(nearest_mode_name);
get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_);

cubic_coeff_a_ = info.GetAttrOrDefault<float>("cubic_coeff_a", kCubicCoeffA);
cubic_coeff_a_ = info.GetAttrOrDefault<float>("cubic_coeff_a", antialias_constants::kCubicCoeffA);
exclude_outside_ = info.GetAttrOrDefault<int64_t>("exclude_outside", 0) == 0 ? false : true;

if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) {
Expand Down
88 changes: 41 additions & 47 deletions onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ FUNC_DEF void SetupUpsampleFilterAnitAliasImpl(
for (int64_t x = 0; x < max_cut - min_cut; x++) {
scale_buffer[x] *= total_weight_inv;
// normalize the scale to 1 << 22 for int8/uint8
scale_buffer_int[x] = static_cast<int32_t>(_Round(scale_buffer[x] * ConstValue::mag_factor * 2.f));
scale_buffer_int[x] = static_cast<int32_t>(_Round(scale_buffer[x] * ConstValue::mag_factor_x_2));
}
} else {
for (int64_t x = 0; x < max_cut - min_cut; x++) {
Expand Down Expand Up @@ -571,7 +571,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
roi_start, roi_end,
scaled_support, window_size,
exclude_outisde,
onnxruntime::kCubicCoeffA, // Default value for trilinear
onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear
bounds,
out_of_bounds,
std::get<0>(weighted_coefficients));
Expand Down Expand Up @@ -600,7 +600,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
roi_start, roi_end,
scaled_support, window_size,
exclude_outisde,
onnxruntime::kCubicCoeffA, // Default value for trilinear
onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear
bounds,
out_of_bounds,
std::get<1>(weighted_coefficients));
Expand All @@ -627,7 +627,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
roi_start, roi_end,
scaled_support, window_size,
exclude_outisde,
onnxruntime::kCubicCoeffA, // Default value for trilinear
onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear
bounds,
out_of_bounds,
std::get<2>(weighted_coefficients));
Expand Down Expand Up @@ -665,7 +665,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
namespace {
template <typename T>
IAllocatorUniquePtr<uint8_t> AllocateTyped(
const std::function<onnxruntime::IAllocatorUniquePtr<uint8_t>(size_t)>& alloc,
const TempSpaceAllocateFunc& alloc,
size_t elements) {
return alloc(elements * sizeof(T));
}
Expand Down Expand Up @@ -713,7 +713,7 @@ void ResizeTrilinearUpsample(

int blocksPerGrid = static_cast<int>(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));

constexpr float support_value = kSupportSize;
constexpr float support_value = antialias_constants::kSupportSize;
float z_scale, h_scale, w_scale;
std::tie(z_scale, h_scale, w_scale) = inferred_dim_rscales;

Expand Down Expand Up @@ -861,19 +861,19 @@ void ResizeBiLinearUpsample(cudaStream_t stream,
std::tie(output_depth, output_height, output_width) = inferred_output_dims;

int blocksPerDimsMappingGrid =
static_cast<int>(ceil((output_depth + output_height + output_width) / 32.0));
narrow<int>(CeilDiv((output_depth + output_height + output_width), 32));

// rank 2 or 4
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
: fast_divmod(gsl::narrow_cast<int>(N));
const fast_divmod& div_output_width = output_div_pitches[rank - 2];

constexpr float support_value = kSupportSize;
constexpr float support_value = antialias_constants::kSupportSize;

float h_scale, w_scale;
std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales;

int blocksPerGrid = static_cast<int>(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
int blocksPerGrid = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));

SafeInt<int64_t> bounds_buffer_size = (SafeInt<int64_t>(output_height) + output_width) * 2;
SafeInt<int64_t> out_of_bounds_buffer_size = (SafeInt<int64_t>(output_height) + output_width);
Expand All @@ -894,16 +894,14 @@ void ResizeBiLinearUpsample(cudaStream_t stream,
int64_t* y_outof_bounds_buffer = GetTyped<int64_t>(out_of_bounds_buffer_ptr);
int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height;

const int64_t weighted_buffer_size = SafeInt<int64_t>(weighted_y_size) +
weighted_w_size;
auto weighted_buffer_ptr = AllocateTyped<AccumType>(allocate_temp_space, weighted_buffer_size);
const int64_t weighted_buffer_size = SafeInt<int64_t>(weighted_y_size) + weighted_w_size;
auto weighted_buffer_ptr = AllocateTyped<AccumType>(allocate_temp_space, narrow<size_t>(weighted_buffer_size));

AccumType* y_weighted_buffer = GetTyped<AccumType>(weighted_buffer_ptr);
AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;

const auto temp_buf_size = num_channels * input_height * output_width;
auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space,
narrow<size_t>(temp_buf_size));
auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(temp_buf_size));

// clang-format off
DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
Expand All @@ -919,7 +917,7 @@ void ResizeBiLinearUpsample(cudaStream_t stream,
std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w
std::make_tuple(h_scaled_support, w_scaled_support),
std::make_tuple(h_window_size, w_window_size),
onnxruntime::kCubicCoeffA, exclude_outside,
onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside,
GetTyped<int64_t>(bounds_buffer_ptr),
GetTyped<int64_t>(out_of_bounds_buffer_ptr),
std::make_tuple(y_weighted_buffer, w_weighted_buffer));
Expand Down Expand Up @@ -984,12 +982,12 @@ void ResizeBicubicUpsample(cudaStream_t stream,
const bool use_extrapolation = extrapolation.has_value();
const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;

int blocksPerGrid = static_cast<int>(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
int blocksPerGrid = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
: fast_divmod(gsl::narrow_cast<int>(N));
const fast_divmod& div_output_width = output_div_pitches[rank - 2];

constexpr float support_value = kBiCubicSupportSize;
constexpr float support_value = antialias_constants::kBiCubicSupportSize;

int64_t input_depth, input_height, input_width;
std::tie(input_depth, input_height, input_width) = inferred_input_dims;
Expand All @@ -998,7 +996,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
std::tie(output_depth, output_height, output_width) = inferred_output_dims;

int blocksPerDimsMappingGrid =
static_cast<int>(ceil((output_depth + output_height + output_width) / 32.0));
narrow<int>(CeilDiv((output_depth + output_height + output_width), 32));

float h_scale, w_scale;
std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales;
Expand All @@ -1013,8 +1011,6 @@ void ResizeBicubicUpsample(cudaStream_t stream,
h_scale, w_scale, support_value,
h_scaled_support, w_scaled_support, h_window_size, w_window_size);

std::cout << std::endl;

auto bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, bounds_buffer_size);
auto out_of_bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, out_of_bounds_buffer_size);

Expand All @@ -1032,8 +1028,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;

const auto temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels * input_height * output_width;
auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space,
narrow<size_t>(temp_buf_size));
auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(temp_buf_size));

// clang-format off
DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
Expand All @@ -1047,7 +1042,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w
std::make_tuple(h_scaled_support, w_scaled_support),
std::make_tuple(h_window_size, w_window_size),
onnxruntime::kCubicCoeffA, exclude_outside,
onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside,
GetTyped<int64_t>(bounds_buffer_ptr),
GetTyped<int64_t>(out_of_bounds_buffer_ptr),
std::make_tuple(y_weighted_buffer, w_weighted_buffer));
Expand Down Expand Up @@ -1096,7 +1091,6 @@ void ResizeAntiAliasImpl(
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
std::tuple<float, float, float> inferred_dim_rscales,
// const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_div_pitches,
gsl::span<const float> roi_vals,
const std::optional<float>& extrapolation,
Expand Down Expand Up @@ -1133,7 +1127,7 @@ void ResizeAntiAliasImpl(
output_div_pitches, roi_vals, extrapolation, exclude_outside,
allocate_temp_space, clip8_lookups, input_data, output_data, N);
} else {
ORT_THROW("Resize supports only 2-D or 3-D in LINEAR mode.");
ORT_NOT_IMPLEMENTED("Resize supports only 2-D or 3-D in LINEAR mode.");
}
} break;
case CUBIC: {
Expand All @@ -1144,35 +1138,35 @@ void ResizeAntiAliasImpl(
output_div_pitches, roi_vals, extrapolation, exclude_outside,
allocate_temp_space, clip8_lookups, input_data, output_data, N);
} else {
ORT_THROW("Resize supports only 2-D in CUBIC mode.");
ORT_NOT_IMPLEMENTED("Resize supports only 2-D in CUBIC mode.");
}
} break;
default:
ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode");
ORT_NOT_IMPLEMENTED("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode");
break;
}
}

#define SPECIALIZED_ANTIALIAS_IMPL(T) \
template void ResizeAntiAliasImpl<T>( \
cudaStream_t stream, \
int rank, \
const UpsampleMode upsample_mode, \
ResizeCoordinateTransformationMode coordinate_transform_mode, \
gsl::span<const int64_t> input_shape, \
gsl::span<const int64_t> output_shape, \
int64_t batch_size, int64_t num_channels, \
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims, \
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims, \
std::tuple<float, float, float> inferred_dim_rscales, /* const TArray<int64_t>& input_strides, */ \
const TArray<fast_divmod>& output_div_pitches, \
gsl::span<const float> roi_vals, \
const std::optional<float>& extrapolation_value, \
bool exclude_outside, \
TempSpaceAllocateFunc allocate_temp_space, \
const uint8_t* clip8_lookups, \
const T* input_data, \
T* output_data, \
#define SPECIALIZED_ANTIALIAS_IMPL(T) \
template void ResizeAntiAliasImpl<T>( \
cudaStream_t stream, \
int rank, \
const UpsampleMode upsample_mode, \
ResizeCoordinateTransformationMode coordinate_transform_mode, \
gsl::span<const int64_t> input_shape, \
gsl::span<const int64_t> output_shape, \
int64_t batch_size, int64_t num_channels, \
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims, \
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims, \
std::tuple<float, float, float> inferred_dim_rscales, \
const TArray<fast_divmod>& output_div_pitches, \
gsl::span<const float> roi_vals, \
const std::optional<float>& extrapolation_value, \
bool exclude_outside, \
TempSpaceAllocateFunc allocate_temp_space, \
const uint8_t* clip8_lookups, \
const T* input_data, \
T* output_data, \
const size_t N);

SPECIALIZED_ANTIALIAS_IMPL(float)
Expand Down
2 changes: 0 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ void ResizeAntiAliasImpl(
std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
std::tuple<float, float, float> inferred_dim_rscales,
// const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_div_pitches,
gsl::span<const float> roi_vals, // CPU
const std::optional<float>& extrapolation_value,
Expand Down Expand Up @@ -140,7 +139,6 @@ inline int32_t ComputeWindowSize(float scaled_support) {
/// <summary>
/// Computes scale buffer size in number of elements for allocation purposes.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="output_size"></param>
/// <param name="window_size"></param>
/// <returns>Number of elements to fit in the buffer</returns>
Expand Down
Loading

0 comments on commit b54e7df

Please sign in to comment.