Commit 9f766d2
Address lint and build issues
yuslepukhin committed Feb 22, 2024
1 parent f6db0ba commit 9f766d2
Showing 12 changed files with 81 additions and 47 deletions.
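The changes below wrap long statements to satisfy cpplint's line-length check, add missing standard-library includes (<limits>, <algorithm>, <utility>), remove stray blank lines and a C-style cast, and fix a std::max call that broke the build on platforms where int64_t is not long long.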
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -2007,8 +2007,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int32_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int64_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
float, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
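Each BuildKernelCreateInfo entry above contributes one typed kernel to the CPU provider's registration table, keyed by provider, domain, opset version, and element type; the re-wrapped lines are the opset-13 int32_t and int64_t registrations of Less. A sketch of how such a table is typically consumed, assuming the function-table pattern used elsewhere in this file (the loop itself is not part of this hunk):

static const BuildKernelCreateInfoFn function_table[] = {
    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
        kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Less)>,
    // ... remaining entries ...
};

for (auto& entry : function_table) {
  KernelCreateInfo info = entry();
  if (info.kernel_def != nullptr) {  // null entries act as placeholders
    ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
  }
}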
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -63,6 +63,7 @@
#endif

#include "cpu_provider_shared.h"
#include <limits>

namespace onnxruntime {
// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
@@ -293,7 +294,8 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }

void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const override {
p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
}
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
@@ -7,6 +7,9 @@
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
#include "core/providers/cpu/tensor/upsample_antialias.h"

#include <limits>

Check warning on line 11 in onnxruntime/core/providers/cpu/tensor/upsample.cc (GitHub Actions / Lint C++):

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: upsample.h, c system, c++ system, other. [build/include_order] [4]

using namespace onnxruntime::common;
using namespace std;
using onnxruntime::narrow;
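The warning above fires because the new <limits> include was placed after the project headers. cpplint's build/include_order rule wants the file's own header first, then C system headers, then C++ system headers, then all other headers, so an ordering it would accept looks roughly like this (a sketch using only the headers visible in this hunk):

#include "core/providers/cpu/tensor/upsample.h"  // the file's own header first

#include <limits>  // C++ system headers next

#include "core/common/safeint.h"  // other project headers last
#include "core/platform/threadpool.h"
#include "core/providers/cpu/tensor/upsample_antialias.h"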
13 changes: 7 additions & 6 deletions onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
@@ -116,12 +116,13 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
bool exclude_outside, const bool is_nchw) {
auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias<T>& p,
const int64_t input_size,
const int64_t output_size,
size_t rindex,
FilterParamsBaseAntiAlias<T>& param_base,
const float rscale) -> int64_t {
auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside](
const FilterParamsAntiAlias<T>& p,
const int64_t input_size,
const int64_t output_size,
size_t rindex,
FilterParamsBaseAntiAlias<T>& param_base,
const float rscale) -> int64_t {
param_base.bound.reserve(static_cast<size_t>(output_size) * 2);
param_base.out_of_bound_idx.reserve(static_cast<size_t>(output_size));

1 change: 1 addition & 0 deletions onnxruntime/core/providers/cpu/tensor/upsamplebase.h
@@ -3,6 +3,7 @@

#pragma once

#include <algorithm>
#include <string>
#include <string_view>
#include <unordered_map>
49 changes: 29 additions & 20 deletions onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
@@ -15,11 +15,12 @@ using onnxruntime::UpsampleMode;
/// <summary>
/// Compute a buffer for bilinear data for CUDA antialias resizing.
/// </summary>
static std::tuple<int64_t, int64_t> ComputeBilinearScaleBufferSize(int64_t output_height, int64_t output_width,
float height_rscale, float width_rscale,
float support_value,
float& scaled_support_height, float& scaled_support_width,
int32_t& window_size_height, int32_t& window_size_width) {
static std::tuple<int64_t, int64_t> ComputeBilinearScaleBufferSize(
int64_t output_height, int64_t output_width,
float height_rscale, float width_rscale,
float support_value,
float& scaled_support_height, float& scaled_support_width,
int32_t& window_size_height, int32_t& window_size_width) {
scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale);
scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale);
window_size_height = ComputeWindowSize(scaled_support_height);
@@ -45,8 +46,11 @@ static std::tuple<int64_t, int64_t, int64_t> ComputeTrilinearScaleBufferSize(
window_size_depth = ComputeWindowSize(scaled_support_depth);
auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth);

const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height, output_width, height_rscale,
width_rscale, support_value, scaled_support_height, scaled_support_width,
const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height,
output_width, height_rscale,
width_rscale, support_value,
scaled_support_height,
scaled_support_width,
window_size_height, window_size_width);
return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size);
}
@@ -121,7 +125,6 @@ __global__ void _ComputeInterpolationAtLevel1(
const AccumType* weight_coefficients,
const T* Xdata, T* Ydata,
const int N) {

CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

// No need to do scale
@@ -185,7 +188,6 @@ __global__ void _ComputeInterpolationAtLevel2(
std::tuple<int64_t*, int64_t*> outof_bounds_buffers,
const AccumType* weight_coefficients,
const T* Xdata, T* Ydata, int N) {

CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

// No need to do scale
@@ -201,8 +203,10 @@ __global__ void _ComputeInterpolationAtLevel2(
div_output_height.divmod(output_image_index, output_z, temp);
div_output_width.divmod(temp, output_y, output_x);

CUDA_LONG input_index = static_cast<CUDA_LONG>(bxc * num_channels * input_height * input_width + output_z * input_height * input_width);
CUDA_LONG output_index = static_cast<CUDA_LONG>(bxc * num_channels * output_height * output_width + output_z * output_height * output_width);
CUDA_LONG input_index = static_cast<CUDA_LONG>(bxc * num_channels * input_height * input_width +
output_z * input_height * input_width);
CUDA_LONG output_index = static_cast<CUDA_LONG>(bxc * num_channels * output_height * output_width +
output_z * output_height * output_width);

auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x;

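The divmod calls above split a flattened output index into coordinates without a hardware divide: fast_divmod stores a precomputed multiplier and shift for a fixed divisor, which is much cheaper inside a CUDA kernel than / and %. In plain integer arithmetic the same decomposition would read as follows (a sketch; the divisors are inferred from how the quotients and remainders are used, not spelled out in this hunk):

// Plain-arithmetic equivalent of the two divmod calls above:
int output_z = output_image_index / (output_height * output_width);
int temp     = output_image_index % (output_height * output_width);
int output_y = temp / output_width;
int output_x = temp % output_width;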
@@ -268,7 +272,6 @@ __global__ void _ComputeInterpolationAtLevel3(
std::tuple<int64_t*, int64_t*, int64_t*> outof_bounds_buffers,
const AccumType* weight_coefficients,
const T* Xdata, T* Ydata, int N) {

CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

// No need to do scale
@@ -391,7 +394,7 @@ FUNC_DEF void SetupUpsampleFilterAnitAliasImpl(

int64_t min_real = static_cast<int64_t>(fmin);
int64_t max_real = static_cast<int64_t>(fmax);
int64_t min_cut = std::max(min_real, 0LL);
int64_t min_cut = std::max<int64_t>(min_real, 0);
int64_t max_cut = std::min(max_real, input_size);

int64_t min_val = exclude_outside ? min_cut : min_real;
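This small change is the build fix: std::max deduces a single type T from both arguments, and 0LL has type long long, so std::max(min_real, 0LL) fails template deduction on platforms where int64_t is a different type (for example long on LP64 Linux, while MSVC defines int64_t as long long and accepted the old code). Naming the template argument converts both operands to int64_t. A minimal standalone illustration (not code from this file):

#include <algorithm>
#include <cstdint>

int64_t clamp_to_non_negative(int64_t v) {
  // std::max(v, 0LL) mixes int64_t and long long and fails to deduce T
  // where those two types differ; spelling out T sidesteps the mismatch.
  return std::max<int64_t>(v, 0);
}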
@@ -547,7 +550,6 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
int64_t* bounds,
int64_t* out_of_bounds,
std::tuple<AccumType*, AccumType*, AccumType*> weighted_coefficients) {

const auto N = std::get<0>(output_dims) + std::get<1>(output_dims) + std::get<2>(output_dims);

CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
@@ -639,6 +641,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
break; \
}

// cpplint off
#define DISPATCH_ANTIALIAS_FILTER_SETUP(coord_enum, ...) \
[&] { \
const auto the_type = coord_enum; \
@@ -653,6 +656,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
ORT_THROW("unknown ResizeCoordinateTransformationMode"); \
} \
}()
// cpplint on

namespace {
template <typename T>
@@ -739,10 +743,13 @@ void ResizeTrilinearUpsample(
AccumType* y_weighted_buffer = z_weighted_buffer + z_buffer_size;
AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size;

const auto h_w_interpolate_temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels * input_depth * input_height * output_width;
auto h_w_interpolate_temp_buffer_ptr = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(h_w_interpolate_temp_buf_size));
const auto h_w_interpolate_temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels *
input_depth * input_height * output_width;
auto h_w_interpolate_temp_buffer_ptr = AllocateTyped<T>(allocate_temp_space,
narrow<size_t>(h_w_interpolate_temp_buf_size));

const auto h_w_interpolate_result_buffer_size = SafeInt<int64_t>(batch_size) * num_channels * input_depth * output_height * output_width;
const auto h_w_interpolate_result_buffer_size = SafeInt<int64_t>(batch_size) * num_channels *
input_depth * output_height * output_width;
auto h_w_interpolate_result_buffer_ptr = AllocateTyped<T>(allocate_temp_space, h_w_interpolate_result_buffer_size);

// clang-format off
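Starting each size expression with SafeInt<int64_t> makes the whole five-factor product overflow-checked: the SafeInt propagates through the chained multiplications and throws on overflow instead of silently wrapping before the value is narrowed to size_t. The pattern from the lines above, annotated (a sketch reusing the helpers visible in this hunk, assuming SafeInt's default throwing policy):

// Overflow-checked buffer sizing: the SafeInt at the left end of the
// chain promotes every subsequent multiplication.
const auto elem_count = SafeInt<int64_t>(batch_size) * num_channels *
                        input_depth * input_height * output_width;
auto buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(elem_count));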
@@ -852,7 +859,8 @@ void ResizeBiLinearUpsample(cudaStream_t stream,
static_cast<int>(ceil((output_depth + output_height + output_width) / 32.0));

// rank 2 or 4
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] : fast_divmod(gsl::narrow_cast<int>(N));
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
: fast_divmod(gsl::narrow_cast<int>(N));
const fast_divmod& div_output_width = output_div_pitches[rank - 2];

constexpr float support_value = kSupportSize;
@@ -972,7 +980,8 @@ void ResizeBicubicUpsample(cudaStream_t stream,
const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;

int blocksPerGrid = static_cast<int>(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] : fast_divmod(gsl::narrow_cast<int>(N));
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
: fast_divmod(gsl::narrow_cast<int>(N));
const fast_divmod& div_output_width = output_div_pitches[rank - 2];

constexpr float support_value = kBiCubicSupportSize;
@@ -1170,4 +1179,4 @@ SPECIALIZED_ANTIALIAS_IMPL(int32_t)
SPECIALIZED_ANTIALIAS_IMPL(uint8_t)

} // namespace cuda
} // namespace onnxruntime
} // namespace onnxruntime
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.cu
@@ -56,6 +56,7 @@ struct NearestPixel_CEIL {
#define CASE_TYPE_COORD(enum_type, type, ...) \
CASE_TYPE_USING_HINT(enum_type, type, coord_t, __VA_ARGS__)

// cpplint off
#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) \
[&] { \
const auto& the_type = TYPE; \
@@ -71,6 +72,7 @@ struct NearestPixel_CEIL {
ORT_THROW("unknown ResizeCoordinateTransformationMode"); \
} \
}()
// cpplint on

#define CASE_TYPE_NEAREST(enum_type, type, ...) \
CASE_TYPE_USING_HINT(enum_type, type, nearest_t, __VA_ARGS__)
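Both DISPATCH_* macros use the same idiom: an immediately-invoked lambda whose switch maps a runtime enum value onto a compile-time type alias (coord_t in the coordinate-mode cases, nearest_t in the nearest-mode cases just defined), with the caller's body pasted into each case so it can use the alias directly. A self-contained sketch of the idiom with invented names:

#include <stdexcept>

enum class Mode { kHalfPixel, kAsymmetric };

struct HalfPixel  { float operator()(float x, float s) const { return (x + 0.5f) / s - 0.5f; } };
struct Asymmetric { float operator()(float x, float s) const { return x / s; } };

// Runtime enum -> compile-time functor type.
#define DISPATCH_MODE(MODE, ...)                      \
  [&] {                                               \
    switch (MODE) {                                   \
      case Mode::kHalfPixel: {                        \
        using coord_t = HalfPixel;                    \
        return __VA_ARGS__();                         \
      }                                               \
      case Mode::kAsymmetric: {                       \
        using coord_t = Asymmetric;                   \
        return __VA_ARGS__();                         \
      }                                               \
      default:                                        \
        throw std::runtime_error("unknown mode");     \
    }                                                 \
  }()

float remap(Mode m, float x, float scale) {
  return DISPATCH_MODE(m, [&] { return coord_t{}(x, scale); });
}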
@@ -171,9 +173,12 @@ __global__ void _ResizeNearestMappingKernel(
if (scales[axis] == 1.0f) {
dims_mapping[id].extrapolate_ = 0;
} else {
float orig_coord = transform_coordinate(static_cast<float>(dim), scales[axis], static_cast<float>(output_shape[axis]),
float orig_coord = transform_coordinate(static_cast<float>(dim), scales[axis],
static_cast<float>(output_shape[axis]),
static_cast<float>(input_shape[axis]), roi[axis], roi[axis + rank]);
dims_mapping[id].extrapolate_ = static_cast<int>(extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast<float>(input_shape[axis] - 1)));
dims_mapping[id].extrapolate_ = static_cast<int>(extrapolation_enabled &&
(orig_coord < 0.f ||
orig_coord > static_cast<float>(input_shape[axis] - 1)));
dim = calc_nearest_pixel(orig_coord, scales[axis] < 1);
if (dim >= input_shape[axis]) dim = input_shape[axis] - 1;
if (dim < 0) dim = 0;
17 changes: 10 additions & 7 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.h
@@ -16,13 +16,15 @@ struct AccumulateType<half> {
namespace cuda {

struct TransformCoordinate_ASYMMETRIC {
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float, float, float, float) const {
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale,
float, float, float, float) const {
return x_resized / x_scale;
}
};

struct TransformCoordinate_HALF_PIXEL {
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float, float, float, float) const {
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale,
float, float, float, float) const {
return ((x_resized + 0.5f) / x_scale) - 0.5f;
}
};
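These functors implement the ONNX Resize coordinate_transformation_mode formulas; the diff only re-wraps their parameter lists. Because they are __host__ as well as __device__, the mapping is easy to sanity-check on the CPU, e.g. for half_pixel (a usage sketch, not code from this diff):

TransformCoordinate_HALF_PIXEL half_pixel{};
// Upscaling a length-4 axis to length 8 gives x_scale = 2; output index 3
// maps back to (3 + 0.5) / 2 - 0.5 = 1.25 in the input.
float orig = half_pixel(3.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f);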
@@ -35,21 +37,22 @@ struct TransformCoordinate_PYTORCH_HALF_PIXEL {
};

struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN {
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float, float, float, float) const {
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale,
float, float, float, float) const {
return (x_resized + 0.5f) / x_scale;
}
};

struct TransformCoordinate_ALIGN_CORNERS {
__device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, float length_original,
float, float) const {
__device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized,
float length_original, float, float) const {
return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1);
}
};

struct TransformCoordinate_TF_CROP_AND_RESIZE {
__device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, float length_original,
float roi_start, float roi_end) const {
__device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized,
float length_original, float roi_start, float roi_end) const {
auto orig = length_resized > 1
? roi_start * (length_original - 1) +
(x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1)
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/upsample.cc
@@ -2,6 +2,9 @@
// Licensed under the MIT License.

#include "upsample.h"

#include <utility>

#include "upsample_impl.h"
#include "core/providers/cuda/tensor/resize_impl.h"
#include "core/providers/cpu/tensor/utils.h"
@@ -247,7 +250,6 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
output_count);

} break;
default:
return Status(ONNXRUNTIME, FAIL, "Resize: unexpected mode");
@@ -261,7 +263,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims);
auto dims_mapping_buffer = GetScratchBuffer<unsigned char>(temp_buffer_size, context->GetComputeStream());
void* dims_mapping = reinterpret_cast<void*>(dims_mapping_buffer.get());
ResizeImpl(Stream(context), mode_, (int)rank, input_shape, output_shape,
ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape,
input_strides, output_div_pitches, scales_vals, roi_vals,
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
(The remaining 3 changed files are not shown.)
