diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index f2f18f51c8..449a9769b4 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -403,6 +403,8 @@ - op: unsqueeze_copy.out +- op: upsample_bilinear2d.vec_out + - op: upsample_nearest2d.out - op: upsample_nearest2d.vec_out diff --git a/kernels/portable/cpu/op_upsample_bilinear2d.cpp b/kernels/portable/cpu/op_upsample_bilinear2d.cpp new file mode 100644 index 0000000000..a6bb666cff --- /dev/null +++ b/kernels/portable/cpu/op_upsample_bilinear2d.cpp @@ -0,0 +1,135 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using exec_aten::ArrayRef; +using exec_aten::optional; +using exec_aten::SizesType; + +namespace { +template +void upsample_bilinear2d_kernel_impl( + const Tensor& in, + bool align_corners, + const float scale_h, + const float scale_w, + Tensor& out) { + const auto in_data = in.const_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + auto in_plane = in_data; + for (auto n = 0; n < out.size(0); n++) { + for (auto c = 0; c < out.size(1); c++) { + for (auto h = 0; h < out.size(2); h++) { + for (auto w = 0; w < out.size(3); w++) { + // Compute source index. + // See area_pixel_compute_source_index in + // pytorch/aten/src/ATen/native/UpSample.h + int64_t in_h1, in_h2, in_w1, in_w2; + float weight_h, inv_weight_h, weight_w, inv_weight_w; + + compute_source_index_and_lambda( + in_h1, + in_h2, + weight_h, + inv_weight_h, + scale_h, + h, + in.sizes()[2], + out.sizes()[2], + align_corners); + + compute_source_index_and_lambda( + in_w1, + in_w2, + weight_w, + inv_weight_w, + scale_w, + w, + in.sizes()[3], + out.sizes()[3], + align_corners); + + const auto top_left = + in_plane[in_h1 * in.strides()[2] + in_w1 * in.strides()[3]]; + const auto top_right = + in_plane[in_h1 * in.strides()[2] + in_w2 * in.strides()[3]]; + const auto bottom_left = + in_plane[in_h2 * in.strides()[2] + in_w1 * in.strides()[3]]; + const auto bottom_right = + in_plane[in_h2 * in.strides()[2] + in_w2 * in.strides()[3]]; + + const auto top = top_left * weight_w + top_right * inv_weight_w; + const auto bottom = + bottom_left * weight_w + bottom_right * inv_weight_w; + const auto val = top * weight_h + bottom * inv_weight_h; + + *out_data = val; + out_data++; + } + } + + in_plane += in.strides()[1]; + } + } +} +} // namespace + +// Signatures are auto-generated, so disable pass-by-value lint. +// NOLINTBEGIN(facebook-hte-ConstantArgumentPassByValue, facebook-hte-ParameterMightThrowOnCopy) +Tensor& upsample_bilinear2d_vec_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + bool align_corners, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out) { + // Preconditions (checked in check_..._args): + // In and out tensors have same dtype. + // In and out tensors are rank 4 and have same dim[0] and dim[1]. + // In and out tensors are default dim order (NCHW). + ET_KERNEL_CHECK( + ctx, + check_upsample_bilinear2d_args( + in, output_size, align_corners, scale_factors, out), + InvalidArgument, + out); + + double scale_h, scale_w; + + ET_KERNEL_CHECK_MSG( + ctx, + resize_upsample_2d(in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor"); + + const auto kernel_scale_h = area_pixel_compute_scale( + in.sizes()[2], out.sizes()[2], align_corners, scale_h); + const auto kernel_scale_w = area_pixel_compute_scale( + in.sizes()[3], out.sizes()[3], align_corners, scale_w); + + ET_SWITCH_REAL_TYPES( + in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() { + upsample_bilinear2d_kernel_impl( + in, align_corners, kernel_scale_h, kernel_scale_w, out); + }); + + return out; +} +// NOLINTEND(facebook-hte-ConstantArgumentPassByValue, facebook-hte-ParameterMightThrowOnCopy) + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 1dc36afce2..3e8072fef6 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -31,6 +31,7 @@ def define_common_targets(): "//executorch/kernels/portable/cpu/util:advanced_index_util", "//executorch/kernels/portable/cpu/util:slice_util", "//executorch/kernels/portable/cpu/util:elementwise_util", + "//executorch/kernels/portable/cpu/util:upsample_util", ], visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], ) @@ -266,6 +267,16 @@ def define_common_targets(): visibility = ["//executorch/kernels/portable/cpu/..."], ) + runtime.cxx_library( + name = "upsample_util", + srcs = ["upsample_util.cpp"], + exported_headers = ["upsample_util.h"], + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = ["//executorch/kernels/portable/cpu/..."], + ) + # Utility functions that can be used by operators that perform reduction for aten_mode in [True, False]: suffix = "_aten" if aten_mode else "" diff --git a/kernels/portable/cpu/util/upsample_util.cpp b/kernels/portable/cpu/util/upsample_util.cpp new file mode 100644 index 0000000000..11074fd31f --- /dev/null +++ b/kernels/portable/cpu/util/upsample_util.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace torch { +namespace executor { + +bool check_upsample_2d_common_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out) { + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out)); + ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4); + ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4); + ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in)); + ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out)); + ET_LOG_AND_RETURN_IF_FALSE(output_size.has_value() ^ scale_factors.has_value()); + if (scale_factors.has_value()) { + ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value().size() == 2); + ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[0] > 0); + ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[1] > 0); + } + else if (output_size.has_value()) { + ET_LOG_AND_RETURN_IF_FALSE(output_size.value().size() == 2); + ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[0] > 0); + ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[1] > 0); + } + + return true; +} + +bool check_upsample_bilinear2d_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + ET_UNUSED const bool align_corners, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out) { + return check_upsample_2d_common_args(in, output_size, scale_factors, out); +} + +Error resize_upsample_2d( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + double& scale_h_out, + double& scale_w_out, + Tensor& out) { + // Either output_size or scale_factors are provided, not both. This + // is checked in check_..._args. + // Scales are transformed according to align_corners. + std::array target_size; + + const auto dim = in.dim(); + std::copy(in.sizes().cbegin(), in.sizes().cend(), target_size.begin()); + + if (scale_factors.has_value()) { + scale_h_out = scale_factors.value()[0]; + scale_w_out = scale_factors.value()[1]; + + target_size[dim - 2] = + static_cast(in.sizes()[dim - 2] * scale_h_out); + target_size[dim - 1] = + static_cast(in.sizes()[dim - 1] * scale_w_out); + } else if (output_size.has_value()) { + scale_h_out = static_cast(output_size.value()[0]) / in.sizes()[dim - 2]; + scale_w_out = static_cast(output_size.value()[1]) / in.sizes()[dim - 1]; + + target_size[dim - 2] = output_size.value()[0]; + target_size[dim - 1] = output_size.value()[1]; + } else { + ET_LOG(Error, "Invalid output_size or scale_factors"); + return Error::InvalidArgument; + } + + ET_CHECK_OR_RETURN_ERROR( + target_size[dim - 2] > 0 && target_size[dim - 1] > 0, + InvalidArgument, + "Upsampled output size must be non-empty, but was %ld x %ld.", + static_cast(target_size[dim - 2]), + static_cast(target_size[dim - 1])); + + return resize_tensor(out, {target_size.data(), static_cast(dim)}); +} + +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/upsample_util.h b/kernels/portable/cpu/util/upsample_util.h new file mode 100644 index 0000000000..86ae8e77ec --- /dev/null +++ b/kernels/portable/cpu/util/upsample_util.h @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace torch { +namespace executor { + +bool check_upsample_2d_common_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out); + +bool check_upsample_bilinear2d_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const bool align_corners, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out); + +Error resize_upsample_2d( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + double& scale_h_out, + double& scale_w_out, + Tensor& out); + +// Ported from aten/src/ATen/native/UpSample.h +template +inline scalar_t compute_scales_value( + const exec_aten::optional& scale, + int64_t input_size, + int64_t output_size) { + return scale.has_value() ? static_cast(1.0 / scale.value()) + : (static_cast(input_size) / output_size); +} + +// Ported from aten/src/ATen/native/UpSample.h +template +inline scalar_t area_pixel_compute_scale( + int64_t input_size, + int64_t output_size, + bool align_corners, + const exec_aten::optional& scale) { + // see Note [area_pixel_compute_scale] + if (align_corners) { + if (output_size > 1) { + return static_cast(input_size - 1) / (output_size - 1); + } else { + return static_cast(0); + } + } else { + return compute_scales_value(scale, input_size, output_size); + } +} + +// Ported from aten/src/ATen/native/UpSample.h +template +inline scalar_t area_pixel_compute_source_index( + scalar_t scale, + int64_t dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + scalar_t src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + return (!cubic && src_idx < static_cast(0)) ? scalar_t(0) + : src_idx; + } +} + +// Ported from aten/src/ATen/native/UpSample.h +// when `real_input_index` becomes larger than the range the floating point +// type can accurately represent, the type casting to `int64_t` might exceed +// `input_size`, causing overflow. So we guard it with `std::min` below. +template +inline void guard_index_and_lambda( + const opmath_t& real_input_index, + const int64_t& input_size, + int64_t& input_index, + scalar_t& lambda) { + input_index = + std::min(static_cast(floorf(real_input_index)), input_size - 1); + lambda = std::min( + std::max(real_input_index - input_index, static_cast(0)), + static_cast(1)); +} + +// Ported from aten/src/ATen/native/UpSample.h +template +inline void compute_source_index_and_lambda( + int64_t& input_index0, + int64_t& input_index1, + scalar_t& lambda0, + scalar_t& lambda1, + opmath_t ratio, + int64_t output_index, + int64_t input_size, + int64_t output_size, + bool align_corners) { + if (output_size == input_size) { + // scale_factor = 1, simply copy + input_index0 = output_index; + input_index1 = output_index; + lambda0 = static_cast(1); + lambda1 = static_cast(0); + } else { + const auto real_input_index = area_pixel_compute_source_index( + ratio, output_index, align_corners, /*cubic=*/false); + guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1); + int64_t offset = (input_index0 < input_size - 1) ? 1 : 0; + input_index1 = input_index0 + offset; + lambda0 = static_cast(1.) - lambda1; + } +} + +} // namespace executor +} // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index a5d60eb59e..2e0a6a3218 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -912,6 +912,11 @@ - arg_meta: null kernel_name: torch::executor::unsqueeze_copy_out +- op: upsample_bilinear2d.vec_out + kernels: + - arg_meta: null + kernel_name: torch::executor::upsample_bilinear2d_vec_out + - op: var.correction_out kernels: - arg_meta: null diff --git a/kernels/test/op_upsample_bilinear2d_test.cpp b/kernels/test/op_upsample_bilinear2d_test.cpp new file mode 100644 index 0000000000..7b3fa96f62 --- /dev/null +++ b/kernels/test/op_upsample_bilinear2d_test.cpp @@ -0,0 +1,554 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +using exec_aten::optional; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::SupportedFeatures; +using torch::executor::testing::TensorFactory; + + +// TODO Fix this +#ifdef USE_ATEN_LIB +template +using OptionalArrayRef = std::optional>; +#else +using exec_aten::OptionalArrayRef; +#endif + + +class OpUpsampleBilinear2dTest : public OperatorTest { + protected: + Tensor& op_upsample_bilinear2d_vec_out( + const Tensor& in, + const OptionalArrayRef& output_size, + bool align_corners, + const OptionalArrayRef& scale_factors, + Tensor& out) { + return torch::executor::aten::upsample_bilinear2d_outf( + context_, in, output_size, align_corners, scale_factors, out); + } + + template + void test_upsample_bilinear2d_dtype() { + TensorFactory tf; + + const auto input = tf.make({1, 1, 1, 2}, {1, 4}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 4}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = tf.make({1, 1, 1, 4}, {1, 2, 3, 4}); + + EXPECT_TENSOR_CLOSE(out, expected); + } +}; + +TEST_F(OpUpsampleBilinear2dTest, Simple1x2To1x4) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 1, 2}, {1.0, 4.0}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 4}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out); + + const auto expected = tf.make({1, 1, 1, 4}, {1.0, 1.75, 3.25, 4.0}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, Simple1x2To1x4AlignCorners) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 1}, {1.0, 4.0}); + std::array output_size = {4, 1}; + auto out = tf.zeros({1, 1, 4, 1}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = tf.make({1, 1, 4, 1}, {1.0, 2.0, 3.0, 4.0}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, Simple2x1To4x1) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 1}, {1.0, 4.0}); + std::array output_size = {4, 1}; + auto out = tf.zeros({1, 1, 4, 1}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out); + + const auto expected = tf.make({1, 1, 4, 1}, {1.0, 1.75, 3.25, 4.0}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, Simple2x1To4x1AlignCorners) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 1}, {1.0, 4.0}); + std::array output_size = {4, 1}; + auto out = tf.zeros({1, 1, 4, 1}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = tf.make({1, 1, 4, 1}, {1.0, 2.0, 3.0, 4.0}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, SmokeTest) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 3}, + { + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + }); + std::array output_size = {3, 4}; + auto out = tf.zeros({1, 1, 3, 4}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out); + + const auto expected = tf.make( + {1, 1, 3, 4}, + {1.0000, + 1.6250, + 2.3750, + 3.0000, + 2.5000, + 3.1250, + 3.8750, + 4.5000, + 4.0000, + 4.6250, + 5.3750, + 6.0000}); + + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, SmokeTestAlignCorners) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 3}, + { + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + }); + std::array output_size = {3, 4}; + auto out = tf.zeros({1, 1, 3, 4}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = tf.make( + {1, 1, 3, 4}, + {1.0000, + 1.6667, + 2.3333, + 3.0000, + 2.5000, + 3.1667, + 3.8333, + 4.5000, + 4.0000, + 4.6667, + 5.3333, + 6.0000}); + + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 0, 0.0001); +} + +TEST_F(OpUpsampleBilinear2dTest, SmokeTestScales) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 3}, + { + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + }); + auto out = tf.zeros({1, 1, 3, 4}); + + const std::array scale_factors = {3.0 / 2, 4.0 / 3}; + op_upsample_bilinear2d_vec_out(input, {}, false, OptionalArrayRef({ scale_factors.data(), scale_factors.size() }), out); + + const auto expected = tf.make( + {1, 1, 3, 4}, + {1.0000, + 1.6250, + 2.3750, + 3.0000, + 2.5000, + 3.1250, + 3.8750, + 4.5000, + 4.0000, + 4.6250, + 5.3750, + 6.0000}); + + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, SmokeTestAlignCornersScales) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 3}, + { + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + }); + auto out = tf.zeros({1, 1, 3, 4}); + + const std::array scale_factors = {3.0 / 2, 4.0 / 3}; + op_upsample_bilinear2d_vec_out(input, {}, true, OptionalArrayRef({ scale_factors.data(), scale_factors.size() }), out); + + const auto expected = tf.make( + {1, 1, 3, 4}, + {1.0000, + 1.6667, + 2.3333, + 3.0000, + 2.5000, + 3.1667, + 3.8333, + 4.5000, + 4.0000, + 4.6667, + 5.3333, + 6.0000}); + + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 0, 0.0001); +} + +TEST_F(OpUpsampleBilinear2dTest, DType) { +#define TEST_ENTRY(ctype, dtype) \ + test_upsample_bilinear2d_dtype(); \ + ET_FORALL_REAL_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpUpsampleBilinear2dTest, MismatchedOutputSizeDies) { + if (SupportedFeatures::get()->output_resize) { + GTEST_SKIP() + << "The current kernel supports implicitly resizing output tensor"; + } + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 5}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out)); +} + +TEST_F(OpUpsampleBilinear2dTest, InvalidInputRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out)); +} + +TEST_F(OpUpsampleBilinear2dTest, InvalidOutputRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out)); +} + +TEST_F(OpUpsampleBilinear2dTest, MissingOutputSizeOrScaleDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_upsample_bilinear2d_vec_out(input, {}, false, {}, out)); +} + +TEST_F(OpUpsampleBilinear2dTest, BothOutputSizeAndScaleDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + const std::array output_size = {1, 4}; + const std::array scale_factors = {2, 1}; + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out)); +} + +TEST_F(OpUpsampleBilinear2dTest, MismatchedDTypeDies) { + TensorFactory tf; + TensorFactory tf2; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf2.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out)); +} + +TEST_F(OpUpsampleBilinear2dTest, ComputedOutputSizeMatchesExpected) { + // Computed output sizes (from input size * scales) must match PyTorch + // eager-mode - multiplied as double and cast (truncated) to an integral type. + // See + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/UpSample.cpp + TensorFactory tf; + + // Test case format: { in_h, in_w, scale_h, scale_w, out_h, out_w } + std::vector> + test_cases = { + {10, 10, 9.99999, 9.55, 99, 95}, + {10, 10, 9.99999999, 0.1, 99, 1}, + }; + + for (const auto& test_case : test_cases) { + const auto [in_h, in_w, scale_h, scale_w, out_h, out_w] = test_case; + + const auto input = tf.ones({1, 1, in_h, in_w}); + auto out = tf.zeros({1, 1, out_h, out_w}); + std::array scale_factors = { scale_h, scale_w }; + + op_upsample_bilinear2d_vec_out(input, {}, false, OptionalArrayRef({scale_factors.data(), scale_factors.size()}), out); + + const auto expected = tf.ones({1, 1, out_h, out_w}); + + + EXPECT_TENSOR_CLOSE(out, expected); + } +} + +TEST_F(OpUpsampleBilinear2dTest, ZeroComputedOutputSizeDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + auto out = tf.zeros({1, 1, 1, 4}); + std::array scale_factors = { 1, 0.25 }; + + ET_EXPECT_KERNEL_FAILURE( + context_, op_upsample_bilinear2d_vec_out(input, {}, false, OptionalArrayRef({scale_factors.data(), scale_factors.size()}), out)); +} + +TEST_F(OpUpsampleBilinear2dTest, NumericsCheck) { + TensorFactory tf; + + const auto input = tf.ones({3, 7, 47, 99}); + auto out = tf.zeros({3, 7, 291, 512}); + std::array output_size = {291, 512}; + + auto input_ptr = static_cast(input.mutable_data_ptr()); + for (auto i = 0ul; i < input.numel(); i++) { + input_ptr[i] = static_cast(i); + } + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out); + + // Indices and expected values to evaluate. + std::vector> test_values = { + {0, 2, 60, 200, 10262.14453125}, + {1, 6, 5, 503, 60624.30078125}, + {2, 0, 111, 300, 66932.953125}, + }; + + const auto output_data = static_cast(out.const_data_ptr()); + for (const auto& test_case : test_values) { + const auto [n, c, h, w, expected] = test_case; + const auto actual = output_data + [n * out.strides()[0] + c * out.strides()[1] + h * out.strides()[2] + + w * out.strides()[3]]; + EXPECT_FLOAT_EQ(expected, actual); + } +} + +TEST_F(OpUpsampleBilinear2dTest, NumericsCheckAlignCorners) { + TensorFactory tf; + + const auto input = tf.ones({3, 7, 47, 99}); + auto out = tf.zeros({3, 7, 291, 512}); + std::array output_size = {291, 512}; + + auto input_ptr = static_cast(input.mutable_data_ptr()); + for (auto i = 0ul; i < input.numel(); i++) { + input_ptr[i] = static_cast(i); + } + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + // Indices and expected values to evaluate. + std::vector> test_values = { + {0, 2, 60, 200, 10286.5634765625}, + {1, 6, 5, 503, 60663.98046875}, + {2, 0, 111, 300, 66942.625}, + }; + + const auto output_data = static_cast(out.const_data_ptr()); + for (const auto& test_case : test_values) { + const auto [n, c, h, w, expected] = test_case; + const auto actual = output_data + [n * out.strides()[0] + c * out.strides()[1] + h * out.strides()[2] + + w * out.strides()[3]]; + EXPECT_FLOAT_EQ(expected, actual); + } +} + +TEST_F(OpUpsampleBilinear2dTest, Simple5x1To4x1) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 5, 1}, {1.0, 2.0, 3.0, 4.0, 5.0}); + std::array output_size = {4, 1}; + auto out = tf.zeros({1, 1, 4, 1}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + false, + {}, + out); + + const auto expected = tf.make({1, 1, 4, 1}, {1.1250, 2.3750, 3.6250, 4.8750}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleBilinear2dTest, Simple5x1To4x1AlignCorners) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 5, 1}, {1.0, 2.0, 3.0, 4.0, 5.0}); + std::array output_size = {4, 1}; + auto out = tf.zeros({1, 1, 4, 1}); + + op_upsample_bilinear2d_vec_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = tf.make({1, 1, 4, 1}, {1.0, 2.333333, 3.666667, 5.0}); + + EXPECT_TENSOR_CLOSE(out, expected); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 77b18a4814..9c104cc0a4 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -321,6 +321,7 @@ def define_common_targets(): _common_op_test("op_trunc_test", ["aten", "portable"]) _common_op_test("op_unbind_copy_test", ["aten", "portable"]) _common_op_test("op_unsqueeze_copy_test", ["aten", "portable"]) + _common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"]) _common_op_test("op_var_test", ["aten", "portable"]) _common_op_test("op_view_copy_test", ["aten", "portable"]) _common_op_test("op_where_test", ["aten", "portable"]) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 53698e7f21..521e2b2078 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -1226,6 +1226,12 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:copy_ops_util", ], ), + op_target( + name = "op_upsample_bilinear2d", + deps = [ + "//executorch/kernels/portable/cpu/util:upsample_util", + ], + ), op_target( name = "op_var", deps = [