diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp index 40b80bf5d0..80f01ffd7d 100644 --- a/examples/models/llama2/custom_ops/op_sdpa.cpp +++ b/examples/models/llama2/custom_ops/op_sdpa.cpp @@ -526,22 +526,22 @@ bool validate_flash_attention_args( "Attention mask must be a 2D tensor"); ET_LOG_MSG_AND_RETURN_IF_FALSE( - is_default_dim_order(query.dim_order().data(), query.dim()), - "key cache must be in default dim order"); + is_contiguous_dim_order(query.dim_order().data(), query.dim()), + "key cache must be in contiguous dim order"); ET_LOG_MSG_AND_RETURN_IF_FALSE( - is_default_dim_order(key.dim_order().data(), key.dim()), - "value cache must be in default dim order"); + is_contiguous_dim_order(key.dim_order().data(), key.dim()), + "value cache must be in contiguous dim order"); ET_LOG_MSG_AND_RETURN_IF_FALSE( - is_default_dim_order(value.dim_order().data(), value.dim()), - "value cache must be in default dim order"); + is_contiguous_dim_order(value.dim_order().data(), value.dim()), + "value cache must be in contiguous dim order"); if (attn_mask.has_value()) { ET_LOG_MSG_AND_RETURN_IF_FALSE( - is_default_dim_order( + is_contiguous_dim_order( attn_mask.value().dim_order().data(), attn_mask.value().dim()), - "value cache must be in default dim order"); + "value cache must be in contiguous dim order"); } return true; @@ -593,14 +593,14 @@ bool validate_cache_params( seq_length, v_cache.size(2)); - // Make sure they are in default dim order + // Make sure they are in contiguous dim order ET_LOG_MSG_AND_RETURN_IF_FALSE( - is_default_dim_order(k_cache.dim_order().data(), k_cache.dim()), - "key cache must be in default dim order"); + is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()), + "key cache must be in contiguous dim order"); ET_LOG_MSG_AND_RETURN_IF_FALSE( - is_default_dim_order(v_cache.dim_order().data(), v_cache.dim()), - "value cache must be in default dim order"); + is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()), + "value cache must be in contiguous dim order"); return true; } @@ -618,9 +618,9 @@ void update_cache( "projected_value must have batch size of 1"); ET_CHECK_MSG(cache.size(1) == 1, "cache must have batch size of 1"); ET_CHECK_MSG( - is_default_dim_order( + is_contiguous_dim_order( projected_value.dim_order().data(), projected_value.dim()), - "projected value must be in default dim order"); + "projected value must be in contiguous dim order"); const void* projected_value_data = projected_value.const_data_ptr(); void* cache_data = cache.mutable_data_ptr(); diff --git a/exir/dim_order_utils.py b/exir/dim_order_utils.py index 0aae6e9230..a0551c6f4d 100644 --- a/exir/dim_order_utils.py +++ b/exir/dim_order_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List +from typing import List, Optional import torch @@ -27,11 +27,13 @@ def _get_channels_last_dim_order(ndim: int) -> List[int]: raise AssertionError(f"Unsupported rank: {ndim}") -def get_memory_format(dim_order: List[int]) -> torch.memory_format: +def get_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format: """ Given a dim_order try to map it to torch.memory_format """ - if dim_order == _get_contiguous_dim_order(len(dim_order)): + if dim_order is None: + return torch.preserve_format + elif dim_order == _get_contiguous_dim_order(len(dim_order)): return torch.contiguous_format elif len(dim_order) == 4 and dim_order == _get_channels_last_dim_order( len(dim_order) @@ -43,11 +45,15 @@ def get_memory_format(dim_order: List[int]) -> torch.memory_format: ) -def get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]: +def get_dim_order( + memory_format: Optional[torch.memory_format], ndim: int +) -> Optional[List[int]]: """ Given a memory_format and a tensor rank, generate a dim_order """ - if memory_format == torch.contiguous_format: + if memory_format in [None, torch.preserve_format]: + return None + elif memory_format == torch.contiguous_format: return _get_contiguous_dim_order(ndim) elif memory_format == torch.channels_last: return _get_channels_last_dim_order(ndim) @@ -55,3 +61,21 @@ def get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]: raise AssertionError( f"Failed to generate dim_order for a given memory format: {memory_format}" ) + + +def is_channel_last_dim_order(tensor: torch.Tensor) -> bool: + """ + Check if a tensor has channels last dim order + """ + if tensor.dim() != 4: + # Only support 4D tensors for channel list memory format. + return False + + return tensor.dim_order() == tuple(_get_channels_last_dim_order(tensor.dim())) + + +def is_contiguous_dim_order(tensor: torch.Tensor) -> bool: + """ + Check if a tensor has contiguous dim order + """ + return tensor.dim_order() == tuple(_get_contiguous_dim_order(tensor.dim())) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 2f251ec8bf..15e73dd413 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -10,6 +10,11 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge + +from executorch.exir.dim_order_utils import ( + is_channel_last_dim_order, + is_contiguous_dim_order, +) from torch.export import export from torch.testing import FileCheck @@ -22,15 +27,6 @@ class MemoryFormatTestSet: class TestMemoryFormatOpsPass(unittest.TestCase): - def is_channel_last(self, x: torch.Tensor): - # This is a heuristic to determine if the input tensor is in NHWC (channel last) - # due to we do not have a good way to infer the dimension order or the memory format - # of the input tensor. Please not this function is specific for contiguous tensors - # whose dim(1) is channel one only, other types of tensors may not work well - # due to different channel configuration and memory arrangement. - - return x.stride(1) == 1 - def memory_format_test_runner(self, test_set: MemoryFormatTestSet): aten_op_str = "torch.ops.aten._to_copy.default" edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" @@ -60,13 +56,13 @@ def memory_format_test_runner(self, test_set: MemoryFormatTestSet): actual = epm.exported_program().module()(*test_set.sample_input) self.assertTrue(torch.allclose(actual, expected)) self.assertEqual( - self.is_channel_last(actual), - self.is_channel_last(expected), + is_channel_last_dim_order(actual), + is_channel_last_dim_order(expected), ) if test_set.target_memory_format == torch.channels_last: - self.assertTrue(self.is_channel_last(actual)) + self.assertTrue(is_channel_last_dim_order(actual)) elif test_set.target_memory_format == torch.contiguous_format: - self.assertFalse(self.is_channel_last(actual)) + self.assertTrue(is_contiguous_dim_order(actual)) else: raise RuntimeError("Unknown memory format") diff --git a/kernels/portable/cpu/op_native_batch_norm.cpp b/kernels/portable/cpu/op_native_batch_norm.cpp index 26eb5d90a7..2e613c0a63 100644 --- a/kernels/portable/cpu/op_native_batch_norm.cpp +++ b/kernels/portable/cpu/op_native_batch_norm.cpp @@ -66,10 +66,10 @@ std::tuple _native_batch_norm_legit_no_training_out( InvalidArgument, ret_val); - // For now, only support the default dim order + // For now, only support the contiguous dim order ET_KERNEL_CHECK( ctx, - is_default_dim_order(in.dim_order().data(), in.dim_order().size()), + is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size()), InvalidArgument, ret_val); diff --git a/kernels/portable/test/targets.bzl b/kernels/portable/test/targets.bzl index ae0dbaef40..261ec50d76 100644 --- a/kernels/portable/test/targets.bzl +++ b/kernels/portable/test/targets.bzl @@ -8,7 +8,7 @@ def define_common_targets(): """ define_supported_features_lib() - op_test(name = "op_allclose_test", aten_compatible = False) + op_test(name = "op_allclose_test") op_test(name = "op_div_test") op_test(name = "op_gelu_test") op_test(name = "op_mul_test") diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 789179c4ca..941e42ba1d 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -1,14 +1,14 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper", "generated_op_test", "op_test") -def _common_op_test(name, kernels, aten_compatible = True): +def _common_op_test(name, kernels): """ Defines test targets in format of _op__test For ATen kernel testing, let's use portable functions.yaml for tested ops. """ for kernel in kernels: deps = [":function_header_wrapper_{}".format(kernel)] - op_test(name, aten_compatible = aten_compatible, kernel_name = kernel, use_kernel_prefix = True, deps = deps) + op_test(name, kernel_name = kernel, use_kernel_prefix = True, deps = deps) def make_example_generated_op_test_target(): """ diff --git a/kernels/test/util.bzl b/kernels/test/util.bzl index 0efeb49774..c2158bfab5 100644 --- a/kernels/test/util.bzl +++ b/kernels/test/util.bzl @@ -1,7 +1,7 @@ load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_xplat") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -def op_test(name, deps = [], aten_compatible = True, kernel_name = "portable", use_kernel_prefix = False): +def op_test(name, deps = [], kernel_name = "portable", use_kernel_prefix = False): """Defines a cxx_test() for an "op_*_test.cpp" file. Args: @@ -11,8 +11,6 @@ def op_test(name, deps = [], aten_compatible = True, kernel_name = "portable", u under //kernels//...; e.g., "op_add_test" will depend on "//kernels/portable/cpu:op_add". deps: Optional extra deps to add to the cxx_test(). - aten_compatible: If True, the operator under test is ATen-compatible - (i.e., appears in `functions.yaml`). kernel_name: The name string as in //executorch/kernels/. use_kernel_prefix: If True, the target name is _op__test. Used by common kernel testing. diff --git a/runtime/core/exec_aten/testing_util/tensor_factory.h b/runtime/core/exec_aten/testing_util/tensor_factory.h index 7ec4d5dc73..993bd8f6bd 100644 --- a/runtime/core/exec_aten/testing_util/tensor_factory.h +++ b/runtime/core/exec_aten/testing_util/tensor_factory.h @@ -292,7 +292,7 @@ class TensorFactory { * size of this vector must be equal to the product of the elements of * `sizes`. * @param[in] dim_order The dim order describing how tensor memory is laid - * out. If empty or not specificed, the function will use a default dim order + * out. If empty or not specificed, the function will use a contiguous dim order * of {0, 1, 2, 3, ...} * * @return A new Tensor with the specified shape and data. @@ -706,7 +706,7 @@ class TensorFactory { * size of this vector must be equal to the product of the elements of * `sizes`. * @param[in] dim_order The dim order describing how tensor memory is laid - * out. If empty or not specificed, the function will use a default dim order + * out. If empty or not specificed, the function will use a contiguous dim order * of {0, 1, 2, 3, ...} * * @return A new Tensor with the specified shape and data. diff --git a/runtime/core/exec_aten/util/dim_order_util.h b/runtime/core/exec_aten/util/dim_order_util.h index 31175d1e6c..33aa4f86a8 100644 --- a/runtime/core/exec_aten/util/dim_order_util.h +++ b/runtime/core/exec_aten/util/dim_order_util.h @@ -29,14 +29,14 @@ bool validate_dim_order(const DimOrderType* dim_order, const size_t dims) { } // namespace /** - * Check if a given dim_order array is equivalent to the default dim order of + * Check if a given dim_order array is equivalent to the contiguous dim order of * {0, 1, 2, 3, ...} * * @param[in] dim_order pointer to dim_order array * @param[in] dims length of the dim_order array */ template -inline bool is_default_dim_order( +inline bool is_contiguous_dim_order( const DimOrderType* dim_order, const size_t dims) { for (int i = 0; i < dims; ++i) { diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index f7a4a8d2a9..89593233d6 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -315,7 +315,7 @@ #define ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(t__) \ ({ \ ET_CHECK_MSG( \ - is_default_dim_order( \ + is_contiguous_dim_order( \ t__.dim_order().data(), t__.dim_order().size()) || \ is_channels_last_dim_order( \ t__.dim_order().data(), t__.dim_order().size()), \ diff --git a/runtime/core/exec_aten/util/tensor_util_aten.cpp b/runtime/core/exec_aten/util/tensor_util_aten.cpp index f08189cb8b..d3d5417f96 100644 --- a/runtime/core/exec_aten/util/tensor_util_aten.cpp +++ b/runtime/core/exec_aten/util/tensor_util_aten.cpp @@ -59,7 +59,7 @@ inline bool tensor_is_default_or_channels_last_dim_order(at::Tensor t) { get_dim_order(t, dim_order, t.dim()) == Error::Ok, "Failed to retrieve dim order from tensor!"); - bool ret_val = is_default_dim_order(dim_order, t.dim()) || + bool ret_val = is_contiguous_dim_order(dim_order, t.dim()) || is_channels_last_dim_order(dim_order, t.dim()); if (!ret_val) { diff --git a/runtime/core/exec_aten/util/tensor_util_portable.cpp b/runtime/core/exec_aten/util/tensor_util_portable.cpp index 8795833c37..ad7c93f0a3 100644 --- a/runtime/core/exec_aten/util/tensor_util_portable.cpp +++ b/runtime/core/exec_aten/util/tensor_util_portable.cpp @@ -55,7 +55,7 @@ bool tensor_has_valid_dim_order(torch::executor::Tensor t) { bool tensor_is_default_or_channels_last_dim_order(torch::executor::Tensor t) { bool ret_val = - is_default_dim_order(t.dim_order().data(), t.dim_order().size()) || + is_contiguous_dim_order(t.dim_order().data(), t.dim_order().size()) || is_channels_last_dim_order(t.dim_order().data(), t.dim_order().size()); if (!ret_val) { diff --git a/runtime/core/exec_aten/util/test/dim_order_util_test.cpp b/runtime/core/exec_aten/util/test/dim_order_util_test.cpp index f1e9309710..28e768be65 100644 --- a/runtime/core/exec_aten/util/test/dim_order_util_test.cpp +++ b/runtime/core/exec_aten/util/test/dim_order_util_test.cpp @@ -236,7 +236,7 @@ TEST(TensorUtilTest, IsDefaultDimOrderTest) { std::vector dim_order(i); std::iota(dim_order.begin(), dim_order.end(), 0); - EXPECT_TRUE(torch::executor::is_default_dim_order( + EXPECT_TRUE(torch::executor::is_contiguous_dim_order( dim_order.data(), dim_order.size())); // As a bonus, check that is_channels_last returns false @@ -252,7 +252,7 @@ TEST(TensorUtilTest, IsDefaultDimOrderFailCasesTest) { std::iota(dim_order.begin(), dim_order.end(), 0); std::swap(dim_order[0], dim_order[1]); - EXPECT_FALSE(torch::executor::is_default_dim_order( + EXPECT_FALSE(torch::executor::is_contiguous_dim_order( dim_order.data(), dim_order.size())); } @@ -263,7 +263,7 @@ TEST(TensorUtilTest, IsDefaultDimOrderFailCasesTest) { dim_order[d] = (d + 1) % i; } - EXPECT_FALSE(torch::executor::is_default_dim_order( + EXPECT_FALSE(torch::executor::is_contiguous_dim_order( dim_order.data(), dim_order.size())); } } @@ -276,8 +276,8 @@ TEST(TensorUtilTest, IsChannelsLastDimOrderTest) { EXPECT_TRUE(torch::executor::is_channels_last_dim_order(dim_order_5d, 5)); // As a bonus, check that is_default returns false - EXPECT_FALSE(torch::executor::is_default_dim_order(dim_order_4d, 4)); - EXPECT_FALSE(torch::executor::is_default_dim_order(dim_order_5d, 5)); + EXPECT_FALSE(torch::executor::is_contiguous_dim_order(dim_order_4d, 4)); + EXPECT_FALSE(torch::executor::is_contiguous_dim_order(dim_order_5d, 5)); } TEST(TensorUtilTest, IsChannelsLastDimOrderFailCasesTest) {