Add portable upsample_bilinear2d kernel (#6923)
Summary:
Add an upsample_bilinear2d kernel to the portable kernel library. This implementation reuses some of the inner logic from the ATen implementation (see UpSample.h and UpsampleKernel.cpp); however, I have not ported the outer kernel structure, as it relies on TensorIterator and runtime allocation.

It may be worth revisiting this in the future, either by pulling in more of the ATen implementation or by adding an optimized variant.
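
For orientation, here is a minimal standalone sketch of the bilinear math the kernel implements, reduced to a single float plane. The helper names (`upsample_bilinear2d_ref`, `source_index`) are illustrative only and not part of this commit, and the output-size-of-1 align_corners corner case is elided:
```
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Map an output index to a fractional source coordinate.
static float source_index(float scale, int64_t dst, bool align_corners) {
  if (align_corners) {
    return scale * dst;
  }
  const float src = scale * (dst + 0.5f) - 0.5f;
  return std::max(src, 0.0f);
}

// Upsample one H x W float plane to out_h x out_w.
static std::vector<float> upsample_bilinear2d_ref(
    const std::vector<float>& in,
    int64_t h,
    int64_t w,
    int64_t out_h,
    int64_t out_w,
    bool align_corners) {
  // align_corners maps (out - 1) intervals onto (in - 1); otherwise the
  // pixel-area model uses in / out. (Simplified: assumes out_h, out_w > 1.)
  const float scale_h =
      align_corners ? float(h - 1) / (out_h - 1) : float(h) / out_h;
  const float scale_w =
      align_corners ? float(w - 1) / (out_w - 1) : float(w) / out_w;
  std::vector<float> out(out_h * out_w);
  for (int64_t oh = 0; oh < out_h; ++oh) {
    for (int64_t ow = 0; ow < out_w; ++ow) {
      const float ih = source_index(scale_h, oh, align_corners);
      const float iw = source_index(scale_w, ow, align_corners);
      // Integer neighbors, clamped to the plane, and fractional weights.
      const int64_t h0 = std::min(static_cast<int64_t>(ih), h - 1);
      const int64_t w0 = std::min(static_cast<int64_t>(iw), w - 1);
      const int64_t h1 = std::min(h0 + 1, h - 1);
      const int64_t w1 = std::min(w0 + 1, w - 1);
      const float lh = ih - h0;
      const float lw = iw - w0;
      // Blend along W in the two neighbor rows, then along H.
      const float top = in[h0 * w + w0] * (1 - lw) + in[h0 * w + w1] * lw;
      const float bot = in[h1 * w + w0] * (1 - lw) + in[h1 * w + w1] * lw;
      out[oh * out_w + ow] = top * (1 - lh) + bot * lh;
    }
  }
  return out;
}

int main() {
  const std::vector<float> in = {0.f, 1.f, 2.f, 3.f}; // 2 x 2 plane
  const auto out =
      upsample_bilinear2d_ref(in, 2, 2, 4, 4, /*align_corners=*/false);
  for (int64_t r = 0; r < 4; ++r) {
    std::printf(
        "%.3f %.3f %.3f %.3f\n",
        out[r * 4 + 0],
        out[r * 4 + 1],
        out[r * 4 + 2],
        out[r * 4 + 3]);
  }
  return 0;
}
```
The committed kernel applies the same per-pixel blend over NCHW tensors, using compute_source_index_and_lambda (below) for the index and weight computation.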


Test Plan:
Added comprehensive operator-level test coverage for upsample_bilinear2d.
```
buck test //executorch/kernels/test:portable_op_upsample_bilinear2d_test
buck test //executorch/kernels/test:aten_op_upsample_bilinear2d_test
```

Differential Revision: D65756150

Pulled By: GregoryComer
GregoryComer authored and facebook-github-bot committed Dec 3, 2024
1 parent 1cf9482 commit 21b4bca
Showing 9 changed files with 939 additions and 0 deletions.
2 changes: 2 additions & 0 deletions kernels/aten/functions.yaml
@@ -403,6 +403,8 @@

- op: unsqueeze_copy.out

- op: upsample_bilinear2d.vec_out

- op: upsample_nearest2d.out

- op: upsample_nearest2d.vec_out
135 changes: 135 additions & 0 deletions kernels/portable/cpu/op_upsample_bilinear2d.cpp
@@ -0,0 +1,135 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl(
const Tensor& in,
bool align_corners,
const float scale_h,
const float scale_w,
Tensor& out) {
const auto in_data = in.const_data_ptr<CTYPE>();
auto out_data = out.mutable_data_ptr<CTYPE>();

auto in_plane = in_data;
for (auto n = 0; n < out.size(0); n++) {
for (auto c = 0; c < out.size(1); c++) {
for (auto h = 0; h < out.size(2); h++) {
for (auto w = 0; w < out.size(3); w++) {
// Compute source index.
// See area_pixel_compute_source_index in
// pytorch/aten/src/ATen/native/UpSample.h
int64_t in_h1, in_h2, in_w1, in_w2;
float weight_h, inv_weight_h, weight_w, inv_weight_w;

compute_source_index_and_lambda(
in_h1,
in_h2,
weight_h,
inv_weight_h,
scale_h,
h,
in.sizes()[2],
out.sizes()[2],
align_corners);

compute_source_index_and_lambda(
in_w1,
in_w2,
weight_w,
inv_weight_w,
scale_w,
w,
in.sizes()[3],
out.sizes()[3],
align_corners);

const auto top_left =
in_plane[in_h1 * in.strides()[2] + in_w1 * in.strides()[3]];
const auto top_right =
in_plane[in_h1 * in.strides()[2] + in_w2 * in.strides()[3]];
const auto bottom_left =
in_plane[in_h2 * in.strides()[2] + in_w1 * in.strides()[3]];
const auto bottom_right =
in_plane[in_h2 * in.strides()[2] + in_w2 * in.strides()[3]];

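          // Bilinear blend: interpolate along W within the top and bottom
          // neighbor rows, then along H between the two results. weight_*
          // applies to the first index (in_h1/in_w1) and inv_weight_* to
          // the second (in_h2/in_w2).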
const auto top = top_left * weight_w + top_right * inv_weight_w;
const auto bottom =
bottom_left * weight_w + bottom_right * inv_weight_w;
const auto val = top * weight_h + bottom * inv_weight_h;

*out_data = val;
out_data++;
}
}

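      // Step to the next channel plane. Because the input is in default
      // (contiguous NCHW) dim order, advancing by strides()[1] once per
      // channel also carries correctly across batch boundaries.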
in_plane += in.strides()[1];
}
}
}
} // namespace

// Signatures are auto-generated, so disable pass-by-value lint.
// NOLINTBEGIN(facebook-hte-ConstantArgumentPassByValue, facebook-hte-ParameterMightThrowOnCopy)
Tensor& upsample_bilinear2d_vec_out(
KernelRuntimeContext& ctx,
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t> output_size,
bool align_corners,
const exec_aten::OptionalArrayRef<double> scale_factors,
Tensor& out) {
// Preconditions (checked in check_..._args):
// In and out tensors have same dtype.
// In and out tensors are rank 4 and have same dim[0] and dim[1].
// In and out tensors are default dim order (NCHW).
ET_KERNEL_CHECK(
ctx,
check_upsample_bilinear2d_args(
in, output_size, align_corners, scale_factors, out),
InvalidArgument,
out);

double scale_h, scale_w;

ET_KERNEL_CHECK_MSG(
ctx,
      resize_upsample_2d(
          in, output_size, scale_factors, scale_h, scale_w, out) ==
          Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor");

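  // Convert the user-facing scales into the source-coordinate mapping used
  // by the kernel; area_pixel_compute_scale accounts for align_corners.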
const auto kernel_scale_h = area_pixel_compute_scale<double>(
in.sizes()[2], out.sizes()[2], align_corners, scale_h);
const auto kernel_scale_w = area_pixel_compute_scale<double>(
in.sizes()[3], out.sizes()[3], align_corners, scale_w);

ET_SWITCH_REAL_TYPES(
      in.scalar_type(), ctx, "upsample_bilinear2d.vec_out", CTYPE, [&]() {
upsample_bilinear2d_kernel_impl<CTYPE>(
in, align_corners, kernel_scale_h, kernel_scale_w, out);
});

return out;
}
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue, facebook-hte-ParameterMightThrowOnCopy)

} // namespace native
} // namespace executor
} // namespace torch
11 changes: 11 additions & 0 deletions kernels/portable/cpu/util/targets.bzl
@@ -31,6 +31,7 @@ def define_common_targets():
"//executorch/kernels/portable/cpu/util:advanced_index_util",
"//executorch/kernels/portable/cpu/util:slice_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
"//executorch/kernels/portable/cpu/util:upsample_util",
],
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
)
@@ -266,6 +267,16 @@ def define_common_targets():
visibility = ["//executorch/kernels/portable/cpu/..."],
)

runtime.cxx_library(
name = "upsample_util",
srcs = ["upsample_util.cpp"],
exported_headers = ["upsample_util.h"],
deps = [
"//executorch/runtime/kernel:kernel_includes",
],
visibility = ["//executorch/kernels/portable/cpu/..."],
)

# Utility functions that can be used by operators that perform reduction
for aten_mode in [True, False]:
suffix = "_aten" if aten_mode else ""
94 changes: 94 additions & 0 deletions kernels/portable/cpu/util/upsample_util.cpp
@@ -0,0 +1,94 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <algorithm>
#include <array>

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out) {
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4);
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in));
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out));
  ET_LOG_AND_RETURN_IF_FALSE(
      output_size.has_value() ^ scale_factors.has_value());
if (scale_factors.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value().size() == 2);
ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[0] > 0);
ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[1] > 0);
  } else if (output_size.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(output_size.value().size() == 2);
ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[0] > 0);
ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[1] > 0);
}

return true;
}

bool check_upsample_bilinear2d_args(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
ET_UNUSED const bool align_corners,
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out) {
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
}

Error resize_upsample_2d(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const exec_aten::OptionalArrayRef<double>& scale_factors,
double& scale_h_out,
double& scale_w_out,
Tensor& out) {
// Either output_size or scale_factors are provided, not both. This
// is checked in check_..._args.
// Scales are transformed according to align_corners.
std::array<Tensor::SizesType, kTensorDimensionLimit> target_size;

const auto dim = in.dim();
std::copy(in.sizes().cbegin(), in.sizes().cend(), target_size.begin());

if (scale_factors.has_value()) {
scale_h_out = scale_factors.value()[0];
scale_w_out = scale_factors.value()[1];
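    // Scale the input size; the cast truncates the result toward zero.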

target_size[dim - 2] =
static_cast<Tensor::SizesType>(in.sizes()[dim - 2] * scale_h_out);
target_size[dim - 1] =
static_cast<Tensor::SizesType>(in.sizes()[dim - 1] * scale_w_out);
} else if (output_size.has_value()) {
    scale_h_out =
        static_cast<double>(output_size.value()[0]) / in.sizes()[dim - 2];
    scale_w_out =
        static_cast<double>(output_size.value()[1]) / in.sizes()[dim - 1];

target_size[dim - 2] = output_size.value()[0];
target_size[dim - 1] = output_size.value()[1];
} else {
ET_LOG(Error, "Invalid output_size or scale_factors");
return Error::InvalidArgument;
}

ET_CHECK_OR_RETURN_ERROR(
target_size[dim - 2] > 0 && target_size[dim - 1] > 0,
InvalidArgument,
"Upsampled output size must be non-empty, but was %ld x %ld.",
static_cast<long>(target_size[dim - 2]),
static_cast<long>(target_size[dim - 1]));

return resize_tensor(out, {target_size.data(), static_cast<size_t>(dim)});
}

} // namespace executor
} // namespace torch
131 changes: 131 additions & 0 deletions kernels/portable/cpu/util/upsample_util.h
@@ -0,0 +1,131 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <algorithm>
#include <cmath>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out);

bool check_upsample_bilinear2d_args(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const bool align_corners,
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out);

Error resize_upsample_2d(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const exec_aten::OptionalArrayRef<double>& scale_factors,
double& scale_h_out,
double& scale_w_out,
Tensor& out);

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t compute_scales_value(
const exec_aten::optional<double>& scale,
int64_t input_size,
int64_t output_size) {
return scale.has_value() ? static_cast<scalar_t>(1.0 / scale.value())
: (static_cast<scalar_t>(input_size) / output_size);
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
int64_t input_size,
int64_t output_size,
bool align_corners,
const exec_aten::optional<double>& scale) {
// see Note [area_pixel_compute_scale]
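  // With align_corners, the corner pixels of input and output coincide, so
  // the (output_size - 1) intervals map onto (input_size - 1). Otherwise
  // the pixel-area model applies, preferring a user-provided scale if any.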
if (align_corners) {
if (output_size > 1) {
return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
} else {
return static_cast<scalar_t>(0);
}
} else {
return compute_scales_value<scalar_t>(scale, input_size, output_size);
}
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
scalar_t scale,
int64_t dst_index,
bool align_corners,
bool cubic) {
if (align_corners) {
return scale * dst_index;
} else {
scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
static_cast<scalar_t>(0.5);
return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
: src_idx;
}
}

// Ported from aten/src/ATen/native/UpSample.h
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template <typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(
const opmath_t& real_input_index,
const int64_t& input_size,
int64_t& input_index,
scalar_t& lambda) {
input_index =
std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
lambda = std::min(
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1));
}

// Ported from aten/src/ATen/native/UpSample.h
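// Example (align_corners=false, upsampling H=2 -> H=4, so ratio=0.5):
// output_index=1 yields real_input_index=0.25, hence input_index0=0,
// input_index1=1, lambda1=0.25, and lambda0=0.75.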
template <typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
int64_t& input_index0,
int64_t& input_index1,
scalar_t& lambda0,
scalar_t& lambda1,
opmath_t ratio,
int64_t output_index,
int64_t input_size,
int64_t output_size,
bool align_corners) {
if (output_size == input_size) {
// scale_factor = 1, simply copy
input_index0 = output_index;
input_index1 = output_index;
lambda0 = static_cast<scalar_t>(1);
lambda1 = static_cast<scalar_t>(0);
} else {
const auto real_input_index = area_pixel_compute_source_index<opmath_t>(
ratio, output_index, align_corners, /*cubic=*/false);
guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
input_index1 = input_index0 + offset;
lambda0 = static_cast<scalar_t>(1.) - lambda1;
}
}

} // namespace executor
} // namespace torch