diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index bd42710d7b..05f6095c37 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -110,3 +110,14 @@ python_library( "//executorch/backends/arm/operators:node_visitor", ], ) + +python_library( + name = "arm_model_evaluator", + src = [ + "util/arm_model_evaluator.py", + ], + typing = True, + deps = [ + "//caffe2:torch", + ] +) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a72cdfd1a0..1e2b26ef64 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -29,8 +29,8 @@ DecomposeSoftmaxesPass, ) from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass -from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import ( - InsertSqueezeAfterSumPass, +from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( + KeepDimsFalseToSqueezePass, ) from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( @@ -71,7 +71,7 @@ def transform_to_backend_pipeline( self.add_pass(DecomposeMeanDimPass()) self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeDivPass()) - self.add_pass(InsertSqueezeAfterSumPass()) + self.add_pass(KeepDimsFalseToSqueezePass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 3fcf724e5b..78ee6e265c 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -7,6 +7,7 @@ # pyre-unsafe +from inspect import isclass from typing import Optional import torch @@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: fake_tensor, FakeTensor ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.' return fake_tensor + + +def get_node_arg(args: list | dict, key: int | str | type, default_value=None): + """ + Help-function for getting a value from node.args/ kwargs, three cases: + 1. By position in node.args - Returns arg at given position or default_value if index is one out of bounds + 2. By key in node.kwargs - Returns kwarg with given key or default_value if it deos not exist + 3. By type in node.args - Returns first arg of args of given type. Useful for cases where arg postions may differ but types are unique. + """ + if isinstance(key, int): + if 0 <= key < len(args): + return args[key] + elif key == len(args): + if default_value is not None: + return default_value + else: + raise RuntimeError(f"No defult value given for index {key}") + else: + raise RuntimeError( + f"Out of bounds index {key} for getting value in args (of size {len(args)})" + ) + elif isinstance(key, str): + return args.get(key, default_value) + elif isclass(key): + for arg in args: + if isinstance(arg, key): + return arg + if default_value is not None: + return default_value + else: + raise RuntimeError(f"No arg of type {key}") + else: + raise RuntimeError("Invalid type") + + +def set_node_arg(node: torch.fx.Node, i: int | str, value): + """ + Help-function for setting a value in node.args/ kwargs. If the index is one larger than the list size, the value is instead appended to the list. + """ + if isinstance(i, int): + if 0 <= i < len(node.args): + args = list(node.args) + args[i] = value + node.args = tuple(args) + return + elif i == len(node.args): + node.args = node.args + (value,) + else: + raise RuntimeError( + f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})" + ) + elif isinstance(i, str): + kwargs = dict(node.kwargs) + kwargs[i] = value + node.kwargs = kwargs + else: + raise RuntimeError("Invalid type") diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index d927fd613c..abf5c8f363 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -7,6 +7,7 @@ # pyre-unsafe import torch +from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta): if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim): return super().call_operator(op, args, kwargs, meta) - x = args[0] - dim = args[1] - keepdim = args[2] if len(args) > 2 else False - if not keepdim: - return super().call_operator(op, args, kwargs, meta) - # if keepdim == True and dim == [-1, -2], mean.dim can be + x = get_node_arg(args, 0) + dim = get_node_arg(args, 1) + keepdim = get_node_arg(args, 2, False) + + # if dim == [-1, -2], mean.dim can be # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool. if dim == [-1, -2]: # Simply return the mean.dim operator for future decomposition. return super().call_operator(op, args, kwargs, meta) + shape = meta["val"].size() dtype = meta["val"].dtype input_shape = x.data.size() diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index cc8f0eb6da..283760e423 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -8,6 +8,7 @@ import torch +from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta): torch.ops.aten.var.dim, ): return super().call_operator(op, args, kwargs, meta) - shape = meta["val"].size() + + x = args[0] + input_shape = x.data.size() + shape = list(meta["val"].size()) + if shape == []: + shape = [1 for _ in input_shape] + dtype = meta["val"].dtype - dim = args[1] if len(args) > 1 else list(range(len(shape))) + # Get dim from args based on argument type + dim = get_node_arg(args, key=list, default_value=list(range(len(shape)))) + if op == torch.ops.aten.var.dim: - correction = args[-2] - keepdim = args[-1] + keepdim = get_node_arg(args, bool, False) + correction = get_node_arg(args, int, 1) else: - correction = kwargs["correction"] - keepdim = kwargs.get("keepdim", False) - if not keepdim: - return super().call_operator(op, args, kwargs, meta) + correction = get_node_arg(kwargs, "correction", 1) + keepdim = get_node_arg(kwargs, "keepdim", False) - x = args[0] - input_shape = x.data.size() N = 1 for d in dim: N *= input_shape[d] mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op) - mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta) + mean = super().call_operator(mean_op, (x, dim, True), {}, meta) diff = super().call_operator(diff_op, (x, mean), {}, meta) squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta) sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta) diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py similarity index 58% rename from backends/arm/_passes/insert_squeeze_after_sum_pass.py rename to backends/arm/_passes/keep_dims_false_to_squeeze_pass.py index e088c2e35a..736c627d91 100644 --- a/backends/arm/_passes/insert_squeeze_after_sum_pass.py +++ b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py @@ -10,14 +10,18 @@ import torch import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_node_arg, + set_node_arg, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -class InsertSqueezeAfterSumPass(ExportPass): +class KeepDimsFalseToSqueezePass(ExportPass): """ - In Pytorch, the default behaviour of Tensor.sum is to squeeze + In Pytorch, the default behaviour of for example Tensor.sum is to squeeze the dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM always preserves the rank of the input (keep_dim = True). @@ -31,21 +35,44 @@ class InsertSqueezeAfterSumPass(ExportPass): squeeze(dim = dims) """ + # CURRENTLY NOT HANDLED OPS + # exir_ops.edge.aten.amax, + # exir_ops.edge.aten.amin, + # exir_ops.edge.aten.any.dim, + # exir_ops.edge.aten.any.dims, + # exir_ops.edge.aten.argmax, + # exir_ops.edge.aten.argmin, + # exir_ops.edge.aten.max.dim, + # exir_ops.edge.aten.min.dim, + # exir_ops.edge.aten.prod.dim_int, + + # HANDLED OPS + # exir_ops.edge.aten.sum.dim_IntList + # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass) + # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass) + # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass) + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: + keep_dim_index = None + if node.op != "call_function": continue - if node.target != exir_ops.edge.aten.sum.dim_IntList: + if node.target == exir_ops.edge.aten.sum.dim_IntList: + keep_dim_index = 2 + else: continue + sum_node = cast(torch.fx.Node, node) - keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False) + keep_dim = get_node_arg(sum_node.args, keep_dim_index, False) + if keep_dim: continue - dim_list = cast(list[int], sum_node.args[1]) + dim_list = get_node_arg(sum_node.args, 1, [0]) # Add keep_dim = True arg to sum node. - sum_node.args = sum_node.args[0:2] + (True,) + set_node_arg(sum_node, 2, True) with graph_module.graph.inserting_after(sum_node): squeeze_node = create_node( @@ -53,6 +80,7 @@ def call(self, graph_module: torch.fx.GraphModule): ) sum_node.replace_all_uses_with(squeeze_node) squeeze_node.args = (sum_node, dim_list) + graph_module.graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index c133ce8003..08f58b1e43 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -5,9 +5,4 @@ # pyre-unsafe -from . import ( # noqa - mean_dim_support, - right_shift_support, - tosa_supported_operators, - var_correction_support, -) +from . import right_shift_support, to_copy_support, tosa_supported_operators # noqa diff --git a/backends/arm/operator_support/mean_dim_support.py b/backends/arm/operator_support/mean_dim_support.py deleted file mode 100644 index 67a7c20406..0000000000 --- a/backends/arm/operator_support/mean_dim_support.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import cast - -import torch.fx as fx - -from executorch.backends.arm.operator_support.tosa_supported_operators import ( - register_tosa_support_check, - SupportedTOSAOperatorCheck, -) -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops - - -@register_tosa_support_check -class MeanDimSupported(SupportedTOSAOperatorCheck): - targets = [exir_ops.edge.aten.mean.dim] - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+BI"), - TosaSpecification.create_from_string("TOSA-0.80.0+MI"), - ] - - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: - assert node.target in self.targets - - keep_dim = node.args[2] if len(node.args) > 2 else False - return cast(bool, keep_dim) diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py new file mode 100644 index 0000000000..9bba274804 --- /dev/null +++ b/backends/arm/operator_support/to_copy_support.py @@ -0,0 +1,120 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import torch + +import torch.fx as fx + +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + +logger = logging.getLogger(__name__) + + +@register_tosa_support_check +class ToCopySupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten._to_copy.default] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + SupportedTypeDict = dict[torch.dtype, list[torch.dtype]] + + @staticmethod + def _merge_supported_types( + dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict + ) -> SupportedTypeDict: + merged_dtypes = dtypes1 + for k, v in dtypes2.items(): + merged_dtypes[k] = merged_dtypes.get(k, []) + v + return merged_dtypes + + SUPPORTED_INT_TYPES: SupportedTypeDict = { + torch.bool: [torch.int8, torch.int16, torch.int32], + torch.int8: [torch.bool, torch.int16, torch.int32], + torch.int16: [torch.bool, torch.int8, torch.int32], + torch.int32: [torch.bool, torch.int8, torch.int16], + } + SUPPORTED_FLOAT_TYPES: SupportedTypeDict = { + torch.int8: [torch.float16, torch.bfloat16, torch.float32], + torch.int16: [torch.float16, torch.bfloat16, torch.float32], + torch.int32: [torch.float16, torch.bfloat16, torch.float32], + torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32], + torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32], + torch.float32: [ + torch.int8, + torch.int16, + torch.int32, + torch.bfloat16, + torch.float16, + ], + } + ALL_SUPPORTED_TYPES = _merge_supported_types( + SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES + ) + POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32} + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + assert node.target in self.targets + + if tosa_spec not in self.tosa_specs: + return False + + assert tosa_spec.support_integer() + supported_dtypes = ( + self.ALL_SUPPORTED_TYPES + if tosa_spec.support_float() + else self.SUPPORTED_INT_TYPES + ) + # Take into account possible type conversions + supported_dtypes.update( + (k, supported_dtypes[v]) + for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items() + if v in supported_dtypes + ) + + # Check input type + assert len(node.all_input_nodes) == 1 + input_val = node.all_input_nodes[0].meta["val"] + assert isinstance(input_val, torch._subclasses.FakeTensor) + input_dtype = input_val.dtype + if input_dtype not in supported_dtypes: + logger.info( + f"Input dtype {input_val.dtype} is not supported in " + f"{node.target.name()}." + ) + return False + + # Check output type + output_val = node.meta["val"] + assert isinstance(output_val, torch._subclasses.FakeTensor) + if output_val.dtype not in supported_dtypes[input_dtype]: + logger.info( + f"Output dtype {output_val.dtype} is not supported in " + f"{node.target.name()} for input dtype {input_dtype}. " + f"Supported output types: " + f"{''.join(str(t) for t in supported_dtypes[input_dtype])}" + ) + return False + + # Check memory format + if "memory_format" in node.kwargs: + if node.kwargs["memory_format"] in (torch.preserve_format,): + logger.info( + f"Argument 'memory_format' is not supported for " + f"{node.target.name()} right now." + ) + return False + + return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 3563ee9c51..7072ba6a82 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -92,6 +92,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, @@ -105,6 +106,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.var.correction, + exir_ops.edge.aten.var.dim, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.unsqueeze_copy.default, diff --git a/backends/arm/operator_support/var_correction_support.py b/backends/arm/operator_support/var_correction_support.py deleted file mode 100644 index 4aa2ae5e97..0000000000 --- a/backends/arm/operator_support/var_correction_support.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import cast - -import torch.fx as fx - -from executorch.backends.arm.operator_support.tosa_supported_operators import ( - register_tosa_support_check, - SupportedTOSAOperatorCheck, -) -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops - - -@register_tosa_support_check -class VarCorrectionSupported(SupportedTOSAOperatorCheck): - targets = [exir_ops.edge.aten.var.correction] - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+BI"), - TosaSpecification.create_from_string("TOSA-0.80.0+MI"), - ] - - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: - assert node.target in self.targets - - keep_dim = node.kwargs.get("keepdim", False) - return cast(bool, keep_dim) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a5c2dd8dc5..8c4aa85e57 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -36,6 +36,7 @@ op_sub, op_sum, op_tanh, + op_to_copy, op_transpose, op_unsqueeze, op_upsample_nearest2d, diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py new file mode 100644 index 0000000000..15077d6df7 --- /dev/null +++ b/backends/arm/operators/op_to_copy.py @@ -0,0 +1,43 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +import tosa.Op as TosaOp + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg + + +@register_node_visitor +class ToCopyVisitor(NodeVisitor): + """ + Implement the type cast functionality of _to_copy. + + Other features like setting of the memory_format or moving a tensor to a + different device are not supported. + + Also note that the node should not be quantized. + """ + + target = "aten._to_copy.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert not is_quant_node, "Casting of quantized values is not supported." + assert inputs + tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/runtime/ArmBackendEthosU.cpp b/backends/arm/runtime/ArmBackendEthosU.cpp index 99ce0a9df2..a14c42140e 100644 --- a/backends/arm/runtime/ArmBackendEthosU.cpp +++ b/backends/arm/runtime/ArmBackendEthosU.cpp @@ -138,6 +138,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { // TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM // or DRAM output for compatible data layouts. for (int i = 0; i < handles.inputs->count; i++) { + auto tensor_count = 1, io_count = 1; auto tensor_in = args[i]->toTensor(); char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset; @@ -202,6 +203,19 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { ET_LOG(Error, "No matching input copy routine"); return Error::InvalidProgram; } + if (!permuted_input_shape) { + calculate_dimensions( + tensor_in, &handles.inputs->io[i], &tensor_count, &io_count); + if (tensor_count != io_count) { + ET_LOG(Error, "Input tensor sizes do not match"); + ET_LOG( + Error, + "Program expects %d elements but got %d", + io_count, + tensor_count); + return Error::InvalidProgram; + } + } } // Allocate driver handle and synchronously invoke driver @@ -236,14 +250,24 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { result); return Error::InvalidProgram; } - + int tensor_dim = 0, io_dim = 0; // Write outputs from scratch into EValue pointers for (int i = 0; i < handles.outputs->count; i++) { + int tensor_count = 1, io_count = 1; const char* output_addr = handles.scratch_data + handles.outputs->io[i].offset; // Process input EValue into scratch // Outputs are in the index immediately after inputs auto tensor_out = args[handles.inputs->count + i]->toTensor(); + + calculate_dimensions( + tensor_out, &handles.outputs->io[i], &tensor_count, &io_count); + + // At times the topological order of the outputs may change. + // Lets instead ensure that the sum of dimensions match. + tensor_dim = tensor_dim + tensor_count; + io_dim = io_dim + io_count; + bool permuted_output_shape; ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute( i, @@ -272,6 +296,12 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { } } } + if (tensor_dim != io_dim) { + ET_LOG(Error, "Total output tensor sizes do not match"); + ET_LOG( + Error, "Program expects size of %d but got %d", tensor_dim, io_dim); + return Error::InvalidProgram; + } return Error::Ok; } @@ -280,6 +310,21 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { } private: + void calculate_dimensions( + const executorch::aten::Tensor tensor, + VelaIO* io, + int* tensor_count, + int* io_count) const { + for (int i = 0; i < tensor.dim(); i++) { + *tensor_count = *tensor_count * tensor.size(i); + } + + // The VelaIO type has a shape of fixed size 4 + for (int i = 0; i < 4; i++) { + *io_count = *io_count * io->shape[i]; + } + } + Error check_requires_permute( int index, const executorch::aten::Tensor tensor, @@ -287,6 +332,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { bool permuted_io_flag, bool* is_permuted) const { bool permuted_shape = false; + if (tensor.dim() == 4) { // special case for NHWC workaround in AOT; as the compilation has // permuted to channel last in an undetectable way, we assume here @@ -304,30 +350,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { return Error::InvalidProgram; } } - if (!permuted_shape) { - // Check the number of elements in each tensor match - int tensor_count = 1; - int io_count = 1; - - for (int i = 0; i < tensor.dim(); i++) { - tensor_count = tensor_count * tensor.size(i); - } - - // The VelaIO type has a shape of fixed size 4 - for (int i = 0; i < 4; i++) { - io_count = io_count * io->shape[i]; - } - - if (tensor_count != io_count) { - ET_LOG(Error, "Input tensor sizes do not match"); - ET_LOG( - Error, - "Program expects %d elements but got %d", - io_count, - tensor_count); - return Error::InvalidProgram; - } - } *is_permuted = permuted_shape; return Error::Ok; } diff --git a/backends/arm/test/TARGETS b/backends/arm/test/TARGETS new file mode 100644 index 0000000000..ef092c5503 --- /dev/null +++ b/backends/arm/test/TARGETS @@ -0,0 +1,23 @@ +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") + +python_library( + name = "common", + srcs = ["common.py"], + deps = [ + "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/backends/arm:arm_backend", + "//executorch/exir:lib", + "//executorch/exir/backend:compile_spec_schema", + ] +) + +python_library( + name = "runner_utils", + srcs = ["runner_utils.py"], + deps = [ + "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/backends/arm:arm_backend", + "//executorch/exir:lib", + "//executorch/exir/backend:compile_spec_schema", + ] +) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 17353cab31..48214a48a7 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -4,156 +4,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging + import os -import platform -import shutil -import subprocess -import sys + import tempfile from datetime import datetime -from enum import auto, Enum from pathlib import Path -from typing import Any - -import pytest -import torch +from conftest import is_option_enabled from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.exir.backend.compile_spec_schema import CompileSpec -class arm_test_options(Enum): - quantize_io = auto() - corstone300 = auto() - dump_path = auto() - date_format = auto() - fast_fvp = auto() - - -_test_options: dict[arm_test_options, Any] = {} - -# ==== Pytest hooks ==== - - -def pytest_addoption(parser): - parser.addoption("--arm_quantize_io", action="store_true") - parser.addoption("--arm_run_corstone300", action="store_true") - parser.addoption("--default_dump_path", default=None) - parser.addoption("--date_format", default="%d-%b-%H:%M:%S") - parser.addoption("--fast_fvp", action="store_true") - - -def pytest_configure(config): - if config.option.arm_quantize_io: - load_libquantized_ops_aot_lib() - _test_options[arm_test_options.quantize_io] = True - if config.option.arm_run_corstone300: - corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55") - if not corstone300_exists: - raise RuntimeError( - "Tests are run with --arm_run_corstone300 but corstone300 FVP is not installed." - ) - _test_options[arm_test_options.corstone300] = True - if config.option.default_dump_path: - dump_path = Path(config.option.default_dump_path).expanduser() - if dump_path.exists() and os.path.isdir(dump_path): - _test_options[arm_test_options.dump_path] = dump_path - else: - raise RuntimeError( - f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory." - ) - _test_options[arm_test_options.date_format] = config.option.date_format - _test_options[arm_test_options.fast_fvp] = config.option.fast_fvp - logging.basicConfig(level=logging.INFO, stream=sys.stdout) - - -def pytest_collection_modifyitems(config, items): - if not config.option.arm_quantize_io: - skip_if_aot_lib_not_loaded = pytest.mark.skip( - "u55 tests can only run with quantize_io=True." - ) - - for item in items: - if "u55" in item.name: - item.add_marker(skip_if_aot_lib_not_loaded) - - -def pytest_sessionstart(session): - pass - - -def pytest_sessionfinish(session, exitstatus): - if get_option(arm_test_options.dump_path): - _clean_dir( - get_option(arm_test_options.dump_path), - f"ArmTester_{get_option(arm_test_options.date_format)}.log", - ) - - -# ==== End of Pytest hooks ===== - -# ==== Custom Pytest decorators ===== - - -def expectedFailureOnFVP(test_item): - if is_option_enabled("corstone300"): - test_item.__unittest_expecting_failure__ = True - return test_item - - -# ==== End of Custom Pytest decorators ===== - - -def load_libquantized_ops_aot_lib(): - so_ext = { - "Darwin": "dylib", - "Linux": "so", - "Windows": "dll", - }.get(platform.system(), None) - - find_lib_cmd = [ - "find", - "cmake-out-aot-lib", - "-name", - f"libquantized_ops_aot_lib.{so_ext}", - ] - res = subprocess.run(find_lib_cmd, capture_output=True) - if res.returncode == 0: - library_path = res.stdout.decode().strip() - torch.ops.load_library(library_path) - - -def is_option_enabled( - option: str | arm_test_options, fail_if_not_enabled: bool = False -) -> bool: - """ - Returns whether an option is successfully enabled, i.e. if the flag was - given to pytest and the necessary requirements are available. - Implemented options are: - - corstone300. - - quantize_io. - - The optional parameter 'fail_if_not_enabled' makes the function raise - a RuntimeError instead of returning False. +def get_time_formatted_path(path: str, log_prefix: str) -> str: """ - if isinstance(option, str): - option = arm_test_options[option.lower()] - - if option in _test_options and _test_options[option]: - return True - else: - if fail_if_not_enabled: - raise RuntimeError(f"Required option '{option}' for test is not enabled") - else: - return False + Returns the log path with the current time appended to it. Used for debugging. + Args: + path: The path to the folder where the log file will be stored. + log_prefix: The name of the test. -def get_option(option: arm_test_options) -> Any | None: - if option in _test_options: - return _test_options[option] - return None + Example output: + './my_log_folder/test_BI_artifact_28-Nov-14:14:38.log' + """ + return str( + Path(path) / f"{log_prefix}_{datetime.now().strftime('%d-%b-%H:%M:%S')}.log" + ) def maybe_get_tosa_collate_path() -> str | None: @@ -303,35 +180,6 @@ def get_u85_compile_spec_unbuilt( return compile_spec -def current_time_formated() -> str: - """Return current time as a formated string""" - return datetime.now().strftime(get_option(arm_test_options.date_format)) - - -def _clean_dir(dir: Path, filter: str, num_save=10): - sorted_files: list[tuple[datetime, Path]] = [] - for file in dir.iterdir(): - try: - creation_time = datetime.strptime(file.name, filter) - insert_index = -1 - for i, to_compare in enumerate(sorted_files): - compare_time = to_compare[0] - if creation_time < compare_time: - insert_index = i - break - if insert_index == -1 and len(sorted_files) < num_save: - sorted_files.append((creation_time, file)) - else: - sorted_files.insert(insert_index, (creation_time, file)) - except ValueError: - continue - - if len(sorted_files) > num_save: - for remove in sorted_files[0 : len(sorted_files) - num_save]: - file = remove[1] - file.unlink() - - def get_target_board(compile_spec: list[CompileSpec]) -> str | None: for spec in compile_spec: if spec.key == "compile_flags": diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py new file mode 100644 index 0000000000..a94adb9a89 --- /dev/null +++ b/backends/arm/test/conftest.py @@ -0,0 +1,196 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import platform +import random +import re +import shutil +import subprocess +import sys +from enum import auto, Enum +from typing import Any + +import pytest +import torch + +""" +This file contains the pytest hooks, fixtures etc. for the Arm test suite. +""" + + +class arm_test_options(Enum): + quantize_io = auto() + corstone_fvp = auto() + fast_fvp = auto() + + +_test_options: dict[arm_test_options, Any] = {} + +# ==== Pytest hooks ==== + + +def pytest_configure(config): + if config.option.arm_quantize_io: + _load_libquantized_ops_aot_lib() + _test_options[arm_test_options.quantize_io] = True + if config.option.arm_run_corstoneFVP: + corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55") + corstone320_exists = shutil.which("FVP_Corstone_SSE-320") + if not (corstone300_exists and corstone320_exists): + raise RuntimeError( + "Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed." + ) + _test_options[arm_test_options.corstone_fvp] = True + _test_options[arm_test_options.fast_fvp] = config.option.fast_fvp + logging.basicConfig(level=logging.INFO, stream=sys.stdout) + + +def pytest_collection_modifyitems(config, items): + """ + Skip all tests that require run on Ethos-U if the option arm_quantize_io is + not set. + """ + if not config.option.arm_quantize_io: + skip_if_aot_lib_not_loaded = pytest.mark.skip( + "Ethos-U tests can only run on FVP with quantize_io=True." + ) + + for item in items: + if re.search(r"u55|u65|u85", item.name, re.IGNORECASE): + item.add_marker(skip_if_aot_lib_not_loaded) + + +def pytest_addoption(parser): + parser.addoption("--arm_quantize_io", action="store_true") + parser.addoption("--arm_run_corstoneFVP", action="store_true") + parser.addoption("--fast_fvp", action="store_true") + + +def pytest_sessionstart(session): + pass + + +def pytest_sessionfinish(session, exitstatus): + pass + + +# ==== End of Pytest hooks ===== + + +# ==== Pytest fixtures ===== + + +@pytest.fixture(autouse=True) +def set_random_seed(): + """ + Control random numbers in Arm test suite. Default behavior is random seed, + which is set before each test. Use the env variable ARM_TEST_SEED to set the + seed you want to use to overrride the default behavior. Or set it to RANDOM + if you want to be explicit. + + Examples: + As default use random seed for each test + ARM_TEST_SEED=RANDOM pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k + Rerun with a specific seed found under a random seed test + ARM_TEST_SEED=3478246 pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k + """ + if os.environ.get("ARM_TEST_SEED", "RANDOM") == "RANDOM": + random.seed() # reset seed, in case any other test has fiddled with it + seed = random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + else: + seed_str = os.environ.get("ARM_TEST_SEED", "0") + if str.isdigit(seed_str): + seed = int(seed_str) + random.seed(seed) + torch.manual_seed(seed) + else: + raise TypeError( + "ARM_TEST_SEED env variable must be integers or the string RANDOM" + ) + + print(f" ARM_TEST_SEED={seed} ", end=" ") + + +# ==== End of Pytest fixtures ===== + + +# ==== Custom Pytest decorators ===== + + +def expectedFailureOnFVP(test_item): + if is_option_enabled("corstone_fvp"): + test_item.__unittest_expecting_failure__ = True + return test_item + + +# ==== End of Custom Pytest decorators ===== + + +def is_option_enabled( + option: str | arm_test_options, fail_if_not_enabled: bool = False +) -> bool: + """ + Returns whether an option is successfully enabled, i.e. if the flag was + given to pytest and the necessary requirements are available. + Implemented options are: + - corstone_fvp. + - quantize_io. + + The optional parameter 'fail_if_not_enabled' makes the function raise + a RuntimeError instead of returning False. + """ + if isinstance(option, str): + option = arm_test_options[option.lower()] + + if option in _test_options and _test_options[option]: + return True + else: + if fail_if_not_enabled: + raise RuntimeError(f"Required option '{option}' for test is not enabled") + else: + return False + + +def get_option(option: arm_test_options) -> Any | None: + """ + Returns the value of an pytest option if it is set, otherwise None. + + Args: + option (arm_test_options): The option to check for. + """ + if option in _test_options: + return _test_options[option] + return None + + +def _load_libquantized_ops_aot_lib(): + """ + Load the libquantized_ops_aot_lib shared library. It's required when + arm_quantize_io is set. + """ + so_ext = { + "Darwin": "dylib", + "Linux": "so", + "Windows": "dll", + }.get(platform.system(), None) + + find_lib_cmd = [ + "find", + "cmake-out-aot-lib", + "-name", + f"libquantized_ops_aot_lib.{so_ext}", + ] + + res = subprocess.run(find_lib_cmd, capture_output=True) + if res.returncode == 0: + library_path = res.stdout.decode().strip() + torch.ops.load_library(library_path) + else: + raise RuntimeError( + f"Failed to load libquantized_ops_aot_lib.{so_ext}. Did you build it?" + ) diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 4cac39af70..3343ae748c 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -80,7 +80,9 @@ def _is_tosa_marker_in_file(self, tmp_file): def test_MI_artifact(self): model = Linear(20, 30) - tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_MI.txt") + tmp_file = common.get_time_formatted_path( + tempfile.mkdtemp(), self._testMethodName + ) self._tosa_MI_pipeline(model, dump_file=tmp_file) assert os.path.exists(tmp_file), f"File {tmp_file} was not created" if self._is_tosa_marker_in_file(tmp_file): @@ -89,7 +91,9 @@ def test_MI_artifact(self): def test_BI_artifact(self): model = Linear(20, 30) - tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_BI.txt") + tmp_file = common.get_time_formatted_path( + tempfile.mkdtemp(), self._testMethodName + ) self._tosa_BI_pipeline(model, dump_file=tmp_file) assert os.path.exists(tmp_file), f"File {tmp_file} was not created" if self._is_tosa_marker_in_file(tmp_file): diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index 19b4254575..24af9cf41a 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -9,7 +9,7 @@ import unittest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir import EdgeCompileConfig @@ -96,7 +96,7 @@ def test_mv2_u55_BI(self): .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs( atol=1.0, qtol=1, inputs=self.model_inputs, target_board="corstone-300" ) @@ -114,7 +114,7 @@ def test_mv2_u85_BI(self): .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs( atol=1.0, qtol=1, inputs=self.model_inputs, target_board="corstone-320" ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 6676a38add..f40037f62f 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -10,7 +10,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir import EdgeCompileConfig from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -115,7 +115,7 @@ def _test_add_ethos_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) return tester diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index ad3ddf8c0a..4801849949 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -14,7 +14,7 @@ ArmQuantizer, get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.backend_details import CompileSpec @@ -118,7 +118,7 @@ def _test_avgpool2d_tosa_ethos_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index 824ec46372..0952d2595f 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -9,7 +9,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -22,8 +22,8 @@ class TestBMM(unittest.TestCase): class BMM(torch.nn.Module): test_parameters = [ - (torch.rand(5, 3, 5), torch.rand(5, 5, 2)), (torch.rand(2, 1, 1), torch.rand(2, 1, 1)), + (torch.rand(5, 3, 5), torch.rand(5, 5, 2)), (torch.ones(1, 55, 3), torch.ones(1, 3, 44)), (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)), (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)), @@ -112,7 +112,7 @@ def _test_bmm_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(inputs=test_data, qtol=1) @parameterized.expand(BMM.test_parameters) @@ -147,32 +147,37 @@ def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor): @parameterized.expand(BMM.test_parameters) @unittest.expectedFailure - def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) self._test_bmm_ethosu_BI_pipeline( self.BMM(), common.get_u55_compile_spec(), test_data ) - @parameterized.expand(BMM.test_parameters) - @common.expectedFailureOnFVP + @parameterized.expand(BMM.test_parameters[:1]) def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) self._test_bmm_ethosu_BI_pipeline( self.BMM(), common.get_u85_compile_spec(), test_data ) + @parameterized.expand(BMM.test_parameters[1:]) + @conftest.expectedFailureOnFVP + def test_bmm_u85_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_ethosu_BI_pipeline( + self.BMM(), common.get_u85_compile_spec(), test_data + ) + # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy @parameterized.expand(BMMSingleInput.test_parameters) @unittest.expectedFailure - def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor): + def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_ethosu_BI_pipeline( self.BMMSingleInput(), common.get_u55_compile_spec(), test_data ) - # Numerical issues on FVP, MLETORCH 534 @parameterized.expand(BMMSingleInput.test_parameters) - @common.expectedFailureOnFVP def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_ethosu_BI_pipeline( diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index 88846369d0..bf436a8c18 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -10,8 +10,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common - +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -114,7 +113,7 @@ def _test_cat_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(inputs=test_data) @parameterized.expand(Cat.test_parameters) @@ -135,7 +134,7 @@ def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int): # Mismatch in provided number of inputs and model signature, MLETORCH 519 @parameterized.expand(Cat.test_parameters) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int): test_data = (operands, dim) self._test_cat_ethosu_BI_pipeline( @@ -144,7 +143,7 @@ def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int): # Mismatch in provided number of inputs and model signature, MLETORCH 519 @parameterized.expand(Cat.test_parameters) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_cat_u85_BI(self, operands: tuple[torch.Tensor, ...], dim: int): test_data = (operands, dim) self._test_cat_ethosu_BI_pipeline( diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 6b5216a8e1..2e7726a0bc 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -17,7 +17,7 @@ ArmQuantizer, get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize @@ -96,7 +96,7 @@ def _test_clone_tosa_ethos_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) def _test_clone_tosa_u55_pipeline( diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index f00c7984a1..e6e027ed6e 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -9,8 +9,7 @@ from typing import List, Optional, Tuple, Union import torch -from executorch.backends.arm.test import common - +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized @@ -279,7 +278,7 @@ def _test_conv1d_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(testsuite) diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 21df4bf0d5..222945cd16 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -9,8 +9,7 @@ from typing import List, Optional, Tuple, Union import torch -from executorch.backends.arm.test import common - +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -310,7 +309,7 @@ def _test_conv2d_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(testsuite) diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 7555fff720..86bf9cb632 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -12,7 +12,7 @@ import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized @@ -253,7 +253,7 @@ def _test_conv_combo_ethos_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) #################### @@ -275,8 +275,6 @@ def test_conv_meandim_u55_BI(self): model.get_inputs(), ) - # Numerical Issues on FVP, MLETORCH-520 - @common.expectedFailureOnFVP def test_conv_meandim_u85_BI(self): model = ComboConv2dMeandim() self._test_conv_combo_ethos_BI_pipeline( diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 28cb9ac844..083e9aaf68 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -9,7 +9,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.ops.test_conv1d import Conv1d from executorch.backends.arm.test.ops.test_conv2d import Conv2d @@ -156,6 +156,19 @@ ("two_dw_conv2d", two_dw_conv2d), ] +testsuite_conv2d_u85 = [ + ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1), + ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1), + ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1), + ("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias), +] + +testsuite_conv2d_u85_xfails = [ + ("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3), + ("two_dw_conv2d", two_dw_conv2d), +] + + testsuite_conv1d = [ ("2_1x6x4_gp6_st1", dw_conv1d_2_1x6x4_gp6_st1), ("two_dw_conv1d", two_dw_conv1d), @@ -230,7 +243,7 @@ def _test_dw_conv_ethos_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(testsuite_conv1d + testsuite_conv2d) @@ -247,7 +260,7 @@ def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module): ) # Works @parameterized.expand(testsuite_conv2d, skip_on_empty=True) - @common.expectedFailureOnFVP + @unittest.expectedFailure def test_dw_conv2d_u55_BI( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False ): @@ -274,10 +287,8 @@ def test_dw_conv1d_u55_BI( model.get_inputs(), ) - # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520 - @parameterized.expand(testsuite_conv1d[:-2] + testsuite_conv2d) - @common.expectedFailureOnFVP - def test_dw_conv_u85_BI_xfails( + @parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85) + def test_dw_conv_u85_BI( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False ): self._test_dw_conv_ethos_BI_pipeline( @@ -288,8 +299,10 @@ def test_dw_conv_u85_BI_xfails( model.get_inputs(), ) - @parameterized.expand(testsuite_conv1d[-2:]) - def test_dw_conv_u85_BI( + # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520 + @parameterized.expand(testsuite_conv2d_u85_xfails) + @conftest.expectedFailureOnFVP + def test_dw_conv_u85_BI_xfails( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False ): self._test_dw_conv_ethos_BI_pipeline( diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index b3815f3e7c..eaf6a21023 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -11,7 +11,7 @@ from typing import Optional, Tuple, Union import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized @@ -26,18 +26,18 @@ torch.ones(5), None, ), - ( - "op_div_rank1_rand", - torch.rand(5) * 5, - torch.rand(5) * 5, - None, - ), ( "op_div_rank1_negative_ones", torch.ones(5) * (-1), torch.ones(5) * (-1), None, ), + ( + "op_div_rank1_rand", + torch.rand(5) * 5, + torch.rand(5) * 5, + None, + ), ( "op_div_rank4_ones", torch.ones(5, 10, 25, 20), @@ -157,7 +157,7 @@ def _test_div_ethos_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) @@ -183,9 +183,7 @@ def test_div_tosa_BI( test_data = (input_, other_) self._test_div_tosa_BI_pipeline(self.Div(), test_data) - # Numerical issues on FVP likely due to mul op, MLETORCH-521 - @parameterized.expand(test_data_suite) - @common.expectedFailureOnFVP + @parameterized.expand(test_data_suite[:2]) def test_div_u55_BI( self, test_name: str, @@ -199,8 +197,21 @@ def test_div_u55_BI( ) # Numerical issues on FVP likely due to mul op, MLETORCH-521 - @parameterized.expand(test_data_suite) - @common.expectedFailureOnFVP + @parameterized.expand(test_data_suite[2:]) + @conftest.expectedFailureOnFVP + def test_div_u55_BI_xfails( + self, + test_name: str, + input_: Union[torch.Tensor, torch.types.Number], + other_: Union[torch.Tensor, torch.types.Number], + rounding_mode: Optional[str] = None, + ): + test_data = (input_, other_) + self._test_div_ethos_BI_pipeline( + self.Div(), common.get_u55_compile_spec(), test_data + ) + + @parameterized.expand(test_data_suite[:2]) def test_div_u85_BI( self, test_name: str, @@ -212,3 +223,18 @@ def test_div_u85_BI( self._test_div_ethos_BI_pipeline( self.Div(), common.get_u85_compile_spec(), test_data ) + + # Numerical issues on FVP likely due to mul op, MLETORCH-521 + @parameterized.expand(test_data_suite[2:]) + @conftest.expectedFailureOnFVP + def test_div_u85_BI_xfails( + self, + test_name: str, + input_: Union[torch.Tensor, torch.types.Number], + other_: Union[torch.Tensor, torch.types.Number], + rounding_mode: Optional[str] = None, + ): + test_data = (input_, other_) + self._test_div_ethos_BI_pipeline( + self.Div(), common.get_u85_compile_spec(), test_data + ) diff --git a/backends/arm/test/ops/test_exp.py b/backends/arm/test/ops/test_exp.py index f33e0a9058..57cd23bb14 100644 --- a/backends/arm/test/ops/test_exp.py +++ b/backends/arm/test/ops/test_exp.py @@ -10,7 +10,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized @@ -95,7 +95,7 @@ def _test_exp_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 27f311b546..05f72aa379 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -17,7 +17,7 @@ ArmQuantizer, get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize @@ -97,7 +97,7 @@ def _test_expand_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(Expand.test_parameters) @@ -110,7 +110,7 @@ def test_expand_tosa_BI(self, test_input, multiples): # Mismatch in provided number of inputs and model signature, MLETORCH 519 @parameterized.expand(Expand.test_parameters) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_expand_u55_BI(self, test_input, multiples): self._test_expand_ethosu_BI_pipeline( common.get_u55_compile_spec(), self.Expand(), (test_input, multiples) @@ -118,7 +118,7 @@ def test_expand_u55_BI(self, test_input, multiples): # Mismatch in provided number of inputs and model signature, MLETORCH 519 @parameterized.expand(Expand.test_parameters) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_expand_u85_BI(self, test_input, multiples): self._test_expand_ethosu_BI_pipeline( common.get_u85_compile_spec(), self.Expand(), (test_input, multiples) diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 9857a7b87b..2ee41f8bc1 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -13,7 +13,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -109,7 +109,7 @@ def _test_full_tosa_ethos_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple): @@ -145,7 +145,7 @@ def test_full_tosa_BI(self, test_tensor: Tuple): # Mismatch in provided number of inputs and model signature, MLETORCH 519 @parameterized.expand(AddVariableFull.test_parameters) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_full_u55_BI(self, test_tensor: Tuple): self._test_full_tosa_u55_pipeline( self.AddVariableFull(), @@ -154,7 +154,7 @@ def test_full_u55_BI(self, test_tensor: Tuple): # Mismatch in provided number of inputs and model signature, MLETORCH 519 @parameterized.expand(AddVariableFull.test_parameters) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_full_u85_BI(self, test_tensor: Tuple): self._test_full_tosa_u85_pipeline( self.AddVariableFull(), diff --git a/backends/arm/test/ops/test_hardtanh.py b/backends/arm/test/ops/test_hardtanh.py index 10073c5095..1c763e8167 100644 --- a/backends/arm/test/ops/test_hardtanh.py +++ b/backends/arm/test/ops/test_hardtanh.py @@ -15,7 +15,7 @@ get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize from parameterized import parameterized @@ -108,7 +108,7 @@ def _test_hardtanh_tosa_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 0b06044a59..a4d3bc5adf 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -8,7 +8,7 @@ from typing import List, Tuple, Union import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized @@ -130,7 +130,7 @@ def _test_layernorm_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) @@ -158,7 +158,7 @@ def test_layer_norm_tosa_BI( # Numerical issues on FVP likely due to mul op, MLETORCH-521 # Skip tests that require transposes. @parameterized.expand(test_data_suite[:-2]) - @common.expectedFailureOnFVP + @unittest.expectedFailure def test_layer_norm_u55_BI( self, test_name: str, @@ -170,9 +170,8 @@ def test_layer_norm_u55_BI( ) # Numerical issues on FVP likely due to mul op, MLETORCH-521 - @parameterized.expand(test_data_suite[:-1]) - @common.expectedFailureOnFVP - def test_layer_norm_u85_BI_fvp_xfails( + @parameterized.expand(test_data_suite[:-2]) + def test_layer_norm_u85_BI_fvp( self, test_name: str, test_data: torch.Tensor, @@ -182,7 +181,7 @@ def test_layer_norm_u85_BI_fvp_xfails( self.LayerNorm(*model_params), common.get_u85_compile_spec(), (test_data,) ) - @parameterized.expand(test_data_suite[-1:]) + @parameterized.expand(test_data_suite[-2:]) @unittest.skip # Flaky def test_layer_norm_u85_BI( self, diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 30d4b2890a..8aabd365af 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -11,7 +11,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir import EdgeCompileConfig @@ -247,7 +247,7 @@ def test_linear_tosa_u55_BI( test_data, ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4) diff --git a/backends/arm/test/ops/test_log.py b/backends/arm/test/ops/test_log.py index 10175d27fb..4dd1fc97c7 100644 --- a/backends/arm/test/ops/test_log.py +++ b/backends/arm/test/ops/test_log.py @@ -10,7 +10,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized @@ -95,7 +95,7 @@ def _test_log_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py index 5d84fa127f..910384e0a0 100644 --- a/backends/arm/test/ops/test_logsoftmax.py +++ b/backends/arm/test/ops/test_logsoftmax.py @@ -17,14 +17,29 @@ test_data_suite = [ # (test_name, test_data, dim) - ("zeros", torch.zeros(10, 10, 10, 10), 0), - ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4), + ("zeros", torch.zeros(10, 8, 5, 2), 0), + ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4), ("ones", torch.ones(10, 10), 1), - ("rand_neg_dim", torch.rand(10, 10, 10), -1), - ("rand", torch.rand(10, 10, 10, 10), 2), - ("rand_neg_dim", torch.rand(10, 10, 2, 3), -2), - ("randn", torch.randn(10, 10, 5, 10), 3), - ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3), + ("ones_neg_dim", torch.ones(10, 3, 4), -1), + ("rand", torch.rand(1, 2, 5, 8), 2), + ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2), + ("randn", torch.randn(10, 10, 10, 10), 3), + ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3), +] +test_data_suite_u55 = [ + # (test_name, test_data, dim) + ("ones", torch.ones(10, 10), 1), + ("ones_neg_dim", torch.ones(10, 3, 4), -1), + ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3), +] + +test_data_suite_u55_xfails = [ + # (test_name, test_data, dim) + ("zeros", torch.zeros(10, 8, 5, 2), 0), + ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4), + ("rand", torch.rand(1, 2, 5, 8), 2), + ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2), + ("randn", torch.randn(10, 10, 10, 10), 3), ] @@ -135,7 +150,7 @@ def test_logsoftmax_tosa_BI( ): self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,)) - @parameterized.expand(test_data_suite) + @parameterized.expand(test_data_suite_u55) def test_logsoftmax_tosa_u55_BI( self, test_name: str, @@ -146,6 +161,19 @@ def test_logsoftmax_tosa_u55_BI( self.LogSoftmax(dim=dim), (test_data,) ) + # Expected to fail as this is not supported on u55. + @parameterized.expand(test_data_suite_u55_xfails) + @unittest.expectedFailure + def test_logsoftmax_tosa_u55_BI_xfails( + self, + test_name: str, + test_data: torch.Tensor, + dim: int, + ): + self._test_logsoftmax_tosa_u55_BI_pipeline( + self.LogSoftmax(dim=dim), (test_data,) + ) + @parameterized.expand(test_data_suite) def test_logsoftmax_tosa_u85_BI( self, @@ -153,6 +181,6 @@ def test_logsoftmax_tosa_u85_BI( test_data: torch.Tensor, dim: int, ): - self._test_logsoftmax_tosa_u55_BI_pipeline( + self._test_logsoftmax_tosa_u85_BI_pipeline( self.LogSoftmax(dim=dim), (test_data,) ) diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 41526b1c77..3a12616df6 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -15,7 +15,7 @@ ArmQuantizer, get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize @@ -171,7 +171,7 @@ def test_maxpool2d_tosa_u55_BI( common.get_u55_compile_spec(permute_memory_to_nhwc=True), (test_data,), ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs( qtol=1, inputs=(test_data,), target_board="corstone-300" ) @@ -188,7 +188,7 @@ def test_maxpool2d_tosa_u85_BI( common.get_u85_compile_spec(permute_memory_to_nhwc=True), (test_data,), ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs( qtol=1, inputs=(test_data,), target_board="corstone-320" ) @@ -216,7 +216,7 @@ def test_maxpool2d_tosa_BI_mult_batches( ) @parameterized.expand(test_data_suite_mult_batches) - @common.expectedFailureOnFVP # TODO: MLETORCH-433 + @conftest.expectedFailureOnFVP # TODO: MLETORCH-433 def test_maxpool2d_tosa_u55_BI_mult_batches( self, test_name: str, @@ -228,13 +228,13 @@ def test_maxpool2d_tosa_u55_BI_mult_batches( common.get_u55_compile_spec(permute_memory_to_nhwc=True), (test_data,), ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs( qtol=1, inputs=(test_data,), target_board="corstone-300" ) @parameterized.expand(test_data_suite_mult_batches) - @common.expectedFailureOnFVP # TODO: MLETORCH-433 + @conftest.expectedFailureOnFVP # TODO: MLETORCH-433 def test_maxpool2d_tosa_u85_BI_mult_batches( self, test_name: str, @@ -246,7 +246,7 @@ def test_maxpool2d_tosa_u85_BI_mult_batches( common.get_u85_compile_spec(permute_memory_to_nhwc=True), (test_data,), ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs( qtol=1, inputs=(test_data,), target_board="corstone-320" ) diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index e8320cf1df..e725eb1ef4 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -51,7 +51,7 @@ class MeanDim(torch.nn.Module): test_data_suite = [ # (test_name, test_data) ("zeros", torch.zeros(1, 1280, 7, 7), -1, True), - ("ones", torch.ones(1, 1280, 7, 7), (-1, 2), True), + ("ones", torch.ones(1, 1280, 7, 7), (-1, 2), False), ( "rand", torch.rand(1, 1280, 7, 7), @@ -62,7 +62,7 @@ class MeanDim(torch.nn.Module): "randn", torch.randn(1, 1280, 7, 7), (-1, -2, -3), - True, + False, ), ] @@ -269,8 +269,10 @@ def test_meandim_tosa_BI( ): self._test_meandim_tosa_BI_pipeline(self.MeanDim(dim, keepdim), (test_data,)) + # Expected to fail as this is not supported on u55. @parameterized.expand(MeanDim.test_data_suite) - def test_meandim_tosa_u55_BI( + @unittest.expectedFailure + def test_meandim_tosa_u55_BI_xfails( self, test_name: str, test_data: torch.Tensor, diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index 8f0321ea5f..ced71b0072 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -8,7 +8,7 @@ import unittest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized @@ -128,7 +128,7 @@ def _test_mul_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_sute) @@ -152,9 +152,7 @@ def test_mul_tosa_BI( test_data = (input_, other_) self._test_mul_tosa_BI_pipeline(self.Mul(), test_data) - # Numerical issues on FVP, MLETORCH-521 @parameterized.expand(test_data_sute) - @common.expectedFailureOnFVP def test_mul_u55_BI( self, test_name: str, @@ -166,10 +164,7 @@ def test_mul_u55_BI( common.get_u55_compile_spec(), self.Mul(), test_data ) - # Numerical issues on FVP, MLETORCH-521 - # test_data_sute[0] works on U85 - @parameterized.expand(test_data_sute[1:]) - @common.expectedFailureOnFVP + @parameterized.expand(test_data_sute) def test_mul_u85_BI( self, test_name: str, diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 92400215b7..581cd3cfbc 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -15,7 +15,7 @@ get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -117,7 +117,7 @@ def _test_permute_ethos_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) @@ -155,7 +155,7 @@ def test_permute_u85_BI( # Fails since on FVP since N > 1 is not supported. MLETORCH-517 @parameterized.expand(test_data_suite[-2:]) - @common.expectedFailureOnFVP + @conftest.expectedFailureOnFVP def test_permute_u85_BI_xfails( self, test_name: str, test_data: torch.Tensor, dims: list[int] ): diff --git a/backends/arm/test/ops/test_reciprocal.py b/backends/arm/test/ops/test_reciprocal.py index 876f063c76..a71396caf3 100644 --- a/backends/arm/test/ops/test_reciprocal.py +++ b/backends/arm/test/ops/test_reciprocal.py @@ -7,7 +7,7 @@ import unittest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized @@ -97,7 +97,7 @@ def _test_reciprocal_u55_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(test_data_suite) diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index cd3dd72f60..455b484b94 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -153,9 +153,21 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): .run_method_and_compare_outputs(inputs=test_data) ) - # Most MI tests fail, just show one working for now. - @parameterized.expand((tensor_scalar_tests[6],)) + @parameterized.expand(tensor_scalar_tests) def test_MI(self, test_name: str, op: torch.nn.Module, x, y): + expected_exception = None + if any(token in test_name for token in ("Sub_int", "Sub__int")): + expected_exception = RuntimeError + elif test_name.endswith("_st"): + expected_exception = AttributeError + + if expected_exception: + with self.assertRaises( + expected_exception, msg=f"Test {test_name} is expected to fail." + ): + self._test_add_tosa_MI_pipeline(op, (x, y)) + return + self._test_add_tosa_MI_pipeline(op, (x, y)) # op(Scalar float, tensor) works if the scalar is constant. diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index f883d6b8de..30215b47f3 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -28,6 +28,22 @@ ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3), ] +test_data_suite_u55 = [ + # (test_name, test_data, dim) + ("ones", torch.ones(10, 10), 1), + ("ones_neg_dim", torch.ones(10, 3, 4), -1), + ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3), +] + +test_data_suite_u55_xfails = [ + # (test_name, test_data, dim) + ("zeros", torch.zeros(10, 8, 5, 2), 0), + ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4), + ("rand", torch.rand(1, 2, 5, 8), 2), + ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2), + ("randn", torch.randn(10, 10, 10, 10), 3), +] + class TestSoftmax(unittest.TestCase): """Tests softmax.""" @@ -136,7 +152,7 @@ def test_softmax_tosa_BI( ): self._test_softmax_tosa_BI_pipeline(self.Softmax(dim=dim), (test_data,)) - @parameterized.expand(test_data_suite) + @parameterized.expand(test_data_suite_u55) def test_softmax_tosa_u55_BI( self, test_name: str, @@ -145,6 +161,17 @@ def test_softmax_tosa_u55_BI( ): self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,)) + # Expected to fail as this is not supported on u55. + @parameterized.expand(test_data_suite_u55_xfails) + @unittest.expectedFailure + def test_softmax_tosa_u55_BI_xfails( + self, + test_name: str, + test_data: torch.Tensor, + dim: int, + ): + self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,)) + @parameterized.expand(test_data_suite) def test_softmax_tosa_u85_BI( self, diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index 327a8de994..0592141028 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -10,8 +10,7 @@ from typing import Tuple import torch -from executorch.backends.arm.test import common - +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -98,7 +97,7 @@ def _test_sub_ethosu_BI_pipeline( .to_executorch() .serialize() ) - if common.is_option_enabled("corstone300"): + if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) @parameterized.expand(Sub.test_parameters) diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 9cd63b0a22..111517afbb 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -35,6 +35,18 @@ class Sum(torch.nn.Module): ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),), ] + test_parameters_u55: list[Tuple[exampledata_t]] = [ + ((torch.rand(10), 0, True),), + ((torch.rand(10, 10), 1, False),), + ((torch.rand(1, 2, 3, 4), 3, True),), + ] + + test_parameters_u55_xfails: list[Tuple[exampledata_t]] = [ + ((torch.rand(10, 10, 10), [-3, 1], True),), + ((torch.rand(2, 1, 5, 8), 1, False),), + ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),), + ] + def forward(self, x: torch.Tensor, dim: int, keepdim: bool): return x.sum(dim=dim, keepdim=keepdim) @@ -112,7 +124,7 @@ def test_sum_tosa_MI(self, test_data: tuple[exampledata_t]): def test_sum_tosa_BI(self, test_data: tuple[exampledata_t]): self._test_sum_tosa_BI_pipeline(self.Sum(), test_data) - @parameterized.expand(Sum.test_parameters) + @parameterized.expand(Sum.test_parameters_u55) def test_sum_u55_BI(self, test_data: tuple[exampledata_t]): self._test_sum_ethosu_BI_pipeline( self.Sum(), @@ -120,6 +132,16 @@ def test_sum_u55_BI(self, test_data: tuple[exampledata_t]): common.get_u55_compile_spec(permute_memory_to_nhwc=False), ) + # Expected to fail as this is not supported on u55. + @parameterized.expand(Sum.test_parameters_u55_xfails) + @unittest.expectedFailure + def test_sum_u55_BI_xfails(self, test_data: tuple[exampledata_t]): + self._test_sum_ethosu_BI_pipeline( + self.Sum(), + test_data, + common.get_u55_compile_spec(permute_memory_to_nhwc=False), + ) + @parameterized.expand(Sum.test_parameters) def test_sum_u85_BI(self, test_data: tuple[exampledata_t]): self._test_sum_ethosu_BI_pipeline( diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py new file mode 100644 index 0000000000..8499512e10 --- /dev/null +++ b/backends/arm/test/ops/test_to_copy.py @@ -0,0 +1,70 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Tests the _to_copy op which is interpreted as a cast for our purposes. +# + +import unittest + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from parameterized import parameterized + + +class Cast(torch.nn.Module): + def __init__(self, target_dtype): + super().__init__() + self.target_dtype = target_dtype + + def forward(self, x: torch.Tensor): + return x.to(dtype=self.target_dtype) + + +class TestToCopy(unittest.TestCase): + """ + Tests the _to_copy operation. + + Only test unquantized graphs as explicit casting of dtypes messes with the + quantization. + + Note: This is also covered by test_scalars.py. + """ + + _TO_COPY_TEST_DATA = ( + (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.float32), + (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.float16), + (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.float32), + (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.int32), + (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int8), + ) + + def _test_to_copy_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: torch.Tensor + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .dump_artifact() + .check_count({"torch.ops.aten._to_copy.default": 1}) + .to_edge() + .dump_artifact() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + @parameterized.expand(_TO_COPY_TEST_DATA) + def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_dtype): + self._test_to_copy_tosa_MI_pipeline(Cast(new_dtype), (test_tensor,)) diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 3a1285e6da..727cd05393 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -29,9 +29,9 @@ class TestVar(unittest.TestCase): class Var(torch.nn.Module): test_parameters = [ (torch.randn(1, 50, 10, 20), True, 0), - (torch.rand(1, 50, 10), True, 0), + (torch.rand(1, 50, 10), False, 0), (torch.randn(1, 30, 15, 20), True, 1), - (torch.rand(1, 50, 10, 20), True, 0.5), + (torch.rand(1, 50, 10, 20), False, 0.5), ] def forward( @@ -45,8 +45,18 @@ def forward( class VarDim(torch.nn.Module): test_parameters = [ (torch.randn(1, 50, 10, 20), 1, True, False), - (torch.rand(1, 50, 10), -2, True, False), + (torch.rand(1, 50, 10), -2, False, False), + (torch.randn(1, 30, 15, 20), -3, True, True), + (torch.rand(1, 50, 10, 20), -1, False, True), + ] + + test_parameters_u55 = [ + (torch.randn(1, 50, 10, 20), 1, True, False), (torch.randn(1, 30, 15, 20), -3, True, True), + ] + + test_parameters_u55_xfails = [ + (torch.rand(1, 50, 10), -2, True, False), (torch.rand(1, 50, 10, 20), -1, True, True), ] @@ -148,8 +158,10 @@ def test_var_tosa_MI(self, test_tensor: torch.Tensor, keepdim, correction): def test_var_tosa_BI(self, test_tensor: torch.Tensor, keepdim, correction): self._test_var_tosa_BI_pipeline(self.Var(), (test_tensor, keepdim, correction)) + # Expected to fail as this is not supported on u55. @parameterized.expand(Var.test_parameters) - def test_var_u55_BI(self, test_tensor: torch.Tensor, keepdim, correction): + @unittest.expectedFailure + def test_var_u55_BI_xfails(self, test_tensor: torch.Tensor, keepdim, correction): self._test_var_ethosu_BI_pipeline( self.Var(), common.get_u55_compile_spec(), @@ -176,7 +188,7 @@ def test_var_dim_tosa_BI(self, test_tensor: torch.Tensor, dim, keepdim, correcti self.VarDim(), (test_tensor, dim, keepdim, correction) ) - @parameterized.expand(VarDim.test_parameters) + @parameterized.expand(VarDim.test_parameters_u55) def test_var_dim_u55_BI(self, test_tensor: torch.Tensor, dim, keepdim, correction): self._test_var_ethosu_BI_pipeline( self.VarDim(), @@ -184,6 +196,18 @@ def test_var_dim_u55_BI(self, test_tensor: torch.Tensor, dim, keepdim, correctio (test_tensor, dim, keepdim, correction), ) + # Expected to fail as this is not supported on u55. + @parameterized.expand(VarDim.test_parameters_u55_xfails) + @unittest.expectedFailure + def test_var_dim_u55_BI_xfails( + self, test_tensor: torch.Tensor, dim, keepdim, correction + ): + self._test_var_ethosu_BI_pipeline( + self.VarDim(), + common.get_u55_compile_spec(), + (test_tensor, dim, keepdim, correction), + ) + @parameterized.expand(VarDim.test_parameters) def test_var_dim_u85_BI(self, test_tensor: torch.Tensor, dim, keepdim, correction): self._test_var_ethosu_BI_pipeline( @@ -208,8 +232,10 @@ def test_var_correction_tosa_BI( self.VarCorrection(), (test_tensor, dim, keepdim, correction) ) + # Expected to fail as this is not supported on u55. @parameterized.expand(VarCorrection.test_parameters) - def test_var_correction_u55_BI( + @unittest.expectedFailure + def test_var_correction_u55_BI_xfails( self, test_tensor: torch.Tensor, dim, keepdim, correction ): self._test_var_ethosu_BI_pipeline( diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py index 615187fb65..978a4c6fe5 100644 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -68,8 +68,12 @@ def test_tosa_BI_meandim_no_modification(self): .quantize() .export() .to_edge() - .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) + .check(["aten_sum_dim_int_list"]) + .check(["aten_full_default"]) + .check(["aten_mul_tensor"]) .run_passes(test_pass_stage) - .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) + .check(["aten_sum_dim_int_list"]) + .check(["aten_full_default"]) + .check(["aten_mul_tensor"]) .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index b61c1b465f..a8a113cf93 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,7 +17,7 @@ import numpy as np import torch -from executorch.backends.arm.test.common import arm_test_options, is_option_enabled +from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled from torch.export import ExportedProgram from torch.fx.node import Node @@ -218,7 +218,7 @@ def run_corstone( assert ( self._has_init_run - ), "RunnerUtil needs to be initialized using init_run() before running Corstone300." + ), "RunnerUtil needs to be initialized using init_run() before running Corstone FVP." if self.target_board not in ["corstone-300", "corstone-320"]: raise RuntimeError(f"Unknown target board: {self.target_board}") diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 3564a3325a..6784605bb4 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -22,12 +22,7 @@ ArmQuantizer, get_symmetric_quantization_config, ) -from executorch.backends.arm.test.common import ( - arm_test_options, - current_time_formated, - get_option, - get_target_board, -) +from executorch.backends.arm.test.common import get_target_board from executorch.backends.arm.test.runner_utils import ( _get_input_quantization_params, @@ -626,9 +621,6 @@ def _get_tosa_operator_distribution( def _dump_str(to_print: str, path_to_dump: Optional[str] = None): - default_dump_path = get_option(arm_test_options.dump_path) - if not path_to_dump and default_dump_path: - path_to_dump = default_dump_path / f"ArmTester_{current_time_formated()}.log" if path_to_dump: with open(path_to_dump, "a") as fp: fp.write(to_print) diff --git a/examples/arm/README.md b/examples/arm/README.md index 717a96c13e..bb68ef537b 100644 --- a/examples/arm/README.md +++ b/examples/arm/README.md @@ -24,7 +24,7 @@ To run these scripts. On a Linux system, in a terminal, with a working internet $ ./setup.sh --i-agree-to-the-contained-eula [optional-scratch-dir] # Step [2] - build + run ExecuTorch and executor_runner baremetal application -# suited for Corstone300 to run a simple PyTorch model. +# suited for Corstone FVP's to run a simple PyTorch model. $ ./run.sh [--scratch-dir=same-optional-scratch-dir-as-before] ``` ### Online Tutorial diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index ddd5fd6b0b..6d899c2146 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -172,11 +172,21 @@ def forward(self, x): can_delegate = False +class MultipleOutputsModule(torch.nn.Module): + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return (x * y, x.sum(dim=-1, keepdim=True)) + + example_input = (torch.randn(10, 4, 5), torch.randn(10, 4, 5)) + can_delegate = True + + models = { "add": AddModule, "add2": AddModule2, "add3": AddModule3, "softmax": SoftmaxModule, + "MultipleOutputsModule": MultipleOutputsModule, } calibration_data = { @@ -263,7 +273,7 @@ def get_compile_spec( target, system_config="Ethos_U55_High_End_Embedded", memory_mode="Shared_Sram", - extra_flags="--debug-force-regor --output-format=raw", + extra_flags="--debug-force-regor --output-format=raw --verbose-operators --verbose-cycle-estimate", ) .set_permute_memory_format(True) .set_quantize_io(True) @@ -276,7 +286,7 @@ def get_compile_spec( target, system_config="Ethos_U85_SYS_DRAM_Mid", memory_mode="Shared_Sram", - extra_flags="--output-format=raw", + extra_flags="--output-format=raw --verbose-operators --verbose-cycle-estimate", ) .set_permute_memory_format(True) .set_quantize_io(True) diff --git a/examples/arm/run.sh b/examples/arm/run.sh index 0e5fa9db34..cbc96c4b11 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -213,9 +213,9 @@ function build_executorch_runner() { cmake --build ${executor_runner_path}/cmake-out --parallel -- arm_executor_runner echo "[${FUNCNAME[0]}] Generated baremetal elf file:" find ${executor_runner_path}/cmake-out -name "arm_executor_runner" - echo "executable_text: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec size {} \; | grep -v filename | awk '{print $1}') bytes" - echo "executable_data: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec size {} \; | grep -v filename | awk '{print $2}') bytes" - echo "executable_bss: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec size {} \; | grep -v filename | awk '{print $3}') bytes" + echo "executable_text: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec arm-none-eabi-size {} \; | grep -v filename | awk '{print $1}') bytes" + echo "executable_data: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec arm-none-eabi-size {} \; | grep -v filename | awk '{print $2}') bytes" + echo "executable_bss: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec arm-none-eabi-size {} \; | grep -v filename | awk '{print $3}') bytes" } # Execute the executor_runner on FVP Simulator diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 583237729d..6f619ef058 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -55,9 +55,9 @@ if [[ "${ARCH}" == "x86_64" ]]; then corstone320_md5_checksum="3deb3c68f9b2d145833f15374203514d" # toochain - toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi.tar.xz" - toolchain_dir="arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi" - toolchain_md5_checksum="00ebb1b70b1f88906c61206457eacb61" + toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi.tar.xz" + toolchain_dir="arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi" + toolchain_md5_checksum="0601a9588bc5b9c99ad2b56133b7f118" elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then # FVPs corstone300_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64_armv8l.tgz?rev=9cc6e9a32bb947ca9b21fa162144cb01&hash=7657A4CF27D42E892E3F08D452AAB073" @@ -70,13 +70,13 @@ elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then # toochain if [[ "${OS}" == "Darwin" ]]; then - toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-darwin-arm64-arm-none-eabi.tar.xz" - toolchain_dir="arm-gnu-toolchain-12.3.rel1-darwin-arm64-arm-none-eabi" - toolchain_md5_checksum="53d034e9423e7f470acc5ed2a066758e" + toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-darwin-arm64-arm-none-eabi.tar.xz" + toolchain_dir="arm-gnu-toolchain-13.3.rel1-darwin-arm64-arm-none-eabi" + toolchain_md5_checksum="f1c18320bb3121fa89dca11399273f4e" elif [[ "${OS}" == "Linux" ]]; then - toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-aarch64-arm-none-eabi.tar.xz" - toolchain_dir="arm-gnu-toolchain-12.3.rel1-aarch64-arm-none-eabi" - toolchain_md5_checksum="02c9b0d3bb1110575877d8eee1f223f2" + toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-aarch64-arm-none-eabi.tar.xz" + toolchain_dir="arm-gnu-toolchain-13.3.rel1-aarch64-arm-none-eabi" + toolchain_md5_checksum="303102d97b877ebbeb36b3158994b218" fi else echo "[main] Error: only x86-64 & aarch64/arm64 architecture is supported for now!"; exit 1; @@ -89,7 +89,11 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5" - + +# vela +vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela" +vela_rev="a08fc18780827b5fefc814dd0162ee6317ce0ae7" + ######## ### Mandatory user args ######## @@ -174,15 +178,15 @@ function setup_fvp() { function setup_toolchain() { # Download and install the arm-none-eabi toolchain cd "${root_dir}" - if [[ ! -e gcc.tar.xz ]]; then + if [[ ! -e "${toolchain_dir}.tar.xz" ]]; then echo "[${FUNCNAME[0]}] Downloading toolchain ..." - curl --output gcc.tar.xz "${toolchain_url}" - verify_md5 ${toolchain_md5_checksum} gcc.tar.xz + curl --output "${toolchain_dir}.tar.xz" "${toolchain_url}" + verify_md5 ${toolchain_md5_checksum} "${toolchain_dir}.tar.xz" fi echo "[${FUNCNAME[0]}] Installing toolchain ..." rm -rf "${toolchain_dir}" - tar xf gcc.tar.xz + tar xf "${toolchain_dir}.tar.xz" toolchain_bin_path="$(cd ${toolchain_dir}/bin && pwd)" export PATH=${PATH}:${toolchain_bin_path} hash arm-none-eabi-gcc @@ -198,6 +202,7 @@ function setup_ethos_u() { cd ethos-u git reset --hard ${ethos_u_base_rev} python3 ./fetch_externals.py -c ${ethos_u_base_rev}.json fetch + pip install pyelftools echo "[${FUNCNAME[0]}] Done @ $(git describe --all --long 3> /dev/null) in ${root_dir}/ethos-u dir." } @@ -259,9 +264,9 @@ function setup_vela() { # cd "${root_dir}" if [[ ! -e ethos-u-vela ]]; then - git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela + git clone ${vela_repo_url} repo_dir="${root_dir}/ethos-u-vela" - base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f + base_rev=${vela_rev} patch_repo fi cd "${root_dir}/ethos-u-vela" diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 8f4fd1ebd2..8450600d2b 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -11,7 +11,7 @@ import torch -from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope +from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope from executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, @@ -87,3 +87,122 @@ def rerotate_k( ) return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) + + +class KVCacheWithAttentionSink(KVCache): + """ + KV cache that supports attention sink. It keeps the initial few tokens as attention sink. + For other tokens, it uses a sliding window to keep the most recent tokens. + + Parameters: + window_size: the size of the sliding window + sink_size: the number of initial tokens to keep as attention sink + eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + """ + + def __init__( + self, + n_heads: int, + head_dim: int, + transpose_cache: bool, + enable_dynamic_shape: bool, + rope: RopeWithAttentionSink, + window_size: int, + sink_size: int, + eviction_batch_size: int, + max_batch_size: int = 1, + dtype=torch.float32, + ): + super().__init__( + max_batch_size=max_batch_size, + max_seq_length=window_size + sink_size, + n_heads=n_heads, + head_dim=head_dim, + transpose_cache=transpose_cache, + enable_dynamic_shape=enable_dynamic_shape, + dtype=dtype, + ) + self.rope = rope + self.window_size = window_size + self.sink_size = sink_size + self.eviction_batch_size = eviction_batch_size + self.position_shift = 0 + + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + Evict old tokens from the cache to make rooms for new tokens. + + Parameters: + input_pos: the start position of the incoming token in the actual sequence + seq_len: the length of the incoming sequence + rope: the rope object to use for rerotating k + + Returns: + the number of tokens to evict from the cache which is also the number of + positions to shift for incoming tokens + """ + input_pos_item = input_pos.item() + torch._check_is_size(input_pos_item) + if input_pos_item + self.position_shift + seq_len > self.max_seq_length: + # There are not enough spaces in the cache to store the new tokens. + # We need to evict some old tokens and shift some recent tokens. + num_to_evict = max( + input_pos_item + self.position_shift - self.max_seq_length + seq_len, + self.eviction_batch_size, + ) + num_to_keep = ( + input_pos_item + self.position_shift - self.sink_size - num_to_evict + ) + num_empty_space = self.window_size - num_to_keep + dim_to_slice = 2 if self.transpose_cache else 1 + k_to_keep = self.k_cache.narrow( + dim_to_slice, + self.sink_size + num_to_evict, # pyre-ignore [6] + num_to_keep, # pyre-ignore [6] + ) + if self.transpose_cache: + k_to_keep = self.rope.rerotate_k( + k=k_to_keep.transpose(1, 2), + original_position=( # pyre-ignore [6] + self.sink_size + num_to_evict + ), + new_position=self.sink_size, + ).transpose(1, 2) + else: + k_to_keep = self.rope.rerotate_k( + k=k_to_keep, + original_position=( # pyre-ignore [6] + self.sink_size + num_to_evict + ), + new_position=self.sink_size, + ) + self.k_cache = torch.cat( + [ + self.k_cache.narrow(dim_to_slice, 0, self.sink_size), + k_to_keep, + torch.zeros_like( + self.k_cache.narrow( + dim_to_slice, 0, num_empty_space # pyre-ignore [6] + ) + ), + ], + dim=dim_to_slice, + ) + self.v_cache = torch.cat( + [ + self.v_cache.narrow(dim_to_slice, 0, self.sink_size), + self.v_cache.narrow( + dim_to_slice, + self.sink_size + num_to_evict, # pyre-ignore [6] + num_to_keep, # pyre-ignore [6] + ), + torch.zeros_like( + self.v_cache.narrow( + dim_to_slice, 0, num_empty_space # pyre-ignore [6] + ) + ), + ], + dim=dim_to_slice, + ) + self.position_shift -= num_to_evict # pyre-ignore [8] + return self.position_shift diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py index 8eaa992dc3..4ffecf1e9c 100644 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -10,6 +10,7 @@ from executorch.examples.models.llama.llama_transformer import ModelArgs from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, RopeWithAttentionSink, ) from parameterized import parameterized @@ -79,14 +80,10 @@ def test_get_freqs( def test_rotate(self, original_position, new_position): seq_len = 32 - q = torch.rand( - 1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32 - ) + size = (1, seq_len, self.params.n_heads, self.params.head_dim) + q = torch.rand(*size, dtype=torch.float32) k = torch.rand( - 1, - seq_len, - self.params.n_heads, - self.params.head_dim, + *size, dtype=torch.float32, ) freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( @@ -118,3 +115,465 @@ def test_rotate(self, original_position, new_position): ) torch.testing.assert_close(rerotated_k, expected_k) + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + + _single_evict_test_cases = [ + [False, 4, 1], + [True, 4, 1], + ] + + _batch_evict_test_cases = [ + [False, 4, 8], + [True, 4, 8], + ] + + _sliding_window_test_cases = [ + [False, 0, 1], + [True, 0, 1], + ] + + def _init_cache(self, transpose_cache, sink_size, eviction_batch_size): + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_seq_len=self.window_size + sink_size, + ) + self.rope_with_attention_sink = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + ) + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.params.n_heads, + head_dim=self.params.head_dim, + transpose_cache=transpose_cache, + enable_dynamic_shape=self.params.enable_dynamic_shape, + rope=self.rope_with_attention_sink, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=self.dtype, + ) + + def _rand_kv_with_length(self, transpose_cache, seq_len): + size = ( + ( + self.max_batch_size, + seq_len, + self.params.n_heads, + self.params.head_dim, + ) + if not transpose_cache + else ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + ) + if not transpose_cache: + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) + else: + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) + return k, v + + def _zero_kv_with_length(self, transpose_cache, seq_len): + size = ( + ( + self.max_batch_size, + seq_len, + self.params.n_heads, + self.params.head_dim, + ) + if not transpose_cache + else ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + ) + if not transpose_cache: + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) + else: + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) + return k, v + + def _get_dim_to_slice(self, transpose_cache): + return 2 if transpose_cache else 1 + + def _get_expected_rotated_k( + self, transpose_cache, k, original_position, new_position + ): + if transpose_cache: + return self.rope_with_attention_sink.rerotate_k( + k=k.transpose(1, 2), + original_position=original_position, + new_position=new_position, + ).transpose(1, 2) + else: + return self.rope_with_attention_sink.rerotate_k( + k=k, original_position=original_position, new_position=new_position + ) + + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.window_size = 28 + self.dtype = torch.float32 + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_empty_cache(self, transpose_cache, sink_size, eviction_batch_size): + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache is empty, evict does nothing + input_pos = torch.tensor([0], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + expected_k, expected_v = self._zero_kv_with_length( + transpose_cache, self.window_size + sink_size + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_without_shift(self, transpose_cache, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has enough spaces for new tokens, no shift + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 10) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + zero_k, zero_v = self._zero_kv_with_length( + transpose_cache, self.window_size + sink_size - 10 + ) + + expected_k = torch.cat( + [ + k, + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_some_shift( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 24) == -2 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 24) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 1, 4), 6, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 1, 4), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_all_shift( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 27) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([32], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 5, 22), 10, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 5, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_some_shift_for_sliding_window( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 20) == -2 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 20) + expected_k = torch.cat( + [ + self._get_expected_rotated_k( + transpose_cache, k.narrow(dimension_to_slice, 2, 3), 2, 0 + ), + self._get_expected_rotated_k(transpose_cache, k1, 5, 3), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 2, 3), + v1, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_all_shift_for_sliding_window( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 23) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([28], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6) + expected_k = torch.cat( + [ + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 1, 22), 6, 0 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v1.narrow(dimension_to_slice, 1, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_seq_len( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 12) == -10 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 12) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 9, 16), 14, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 9, 16), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_batch_size( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -8 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 10) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 7, 18), 12, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 7, 18), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 847f764b0e..f07592fbfb 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -11,6 +11,9 @@ #include #include #include +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#endif /** * For an input tensor, use the scale and zero_point arguments to quantize it. @@ -22,6 +25,8 @@ namespace native { using Tensor = exec_aten::Tensor; using Scalar = exec_aten::Scalar; using ScalarType = exec_aten::ScalarType; +using StridesType = exec_aten::StridesType; +using SizesType = exec_aten::SizesType; namespace { @@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args( quant_max); } +/** + * Useful to reduce a tensor `in` over a given dimension `dim` using the + * reduce function `fn`, which should have the following signature: + * void fn(const size_t size, const size_t stride, const size_t base_ix) + * where `size` and `stride` are the size and stride of the dimension being + * reduced and `base_ix` is the index of the first element of the reduction. + */ +template +void apply_over_unpacked_dim( + const Fn& fn, + const exec_aten::Tensor& in, + const int64_t& dim) { + if (in.numel() == 0) { + return; + } + + ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension"); + ET_CHECK_VALID_DIM(dim, in.dim()); + + const size_t d = ET_NORMALIZE_IX(dim, in.dim()); + const size_t dim_size = in.size(d); + const size_t outer_size = getLeadingDims(in, d); + const size_t inner_size = getTrailingDims(in, d); + // Loop through all outer dimensions + for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + // Loop through dim + for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size; + ++unpacked_dim_idx) { + fn(inner_size, outer_idx, unpacked_dim_idx); + } + } +} + +void dequantize_optimized( + const int8_t* in, + const double scale, + const int64_t zero_point, + float* out, + int64_t quant_min, + int64_t quant_max, + size_t numel) { + ET_CHECK_MSG( + zero_point >= quant_min, + "zero_point must be %" PRId64 " <= quant_min %" PRId64, + zero_point, + quant_min); + ET_CHECK_MSG( + zero_point <= quant_max, + "zero_point must be %" PRId64 " >= quant_max %" PRId64, + zero_point, + quant_max); + size_t i = 0; +#if defined(__aarch64__) || defined(__ARM_NEON) + int8x8_t zero_point_vec = vdup_n_s8(zero_point); + float32x4_t scales = vdupq_n_f32(static_cast(scale)); + constexpr int32_t kVecSize = 16; + const size_t num_vecs = numel / kVecSize; + const int8_t* in_copy = in; + float* out_copy = out; + for (; i < num_vecs; i++) { + int8x16_t in_vec = vld1q_s8(in_copy); + int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec); + int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7)); + int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7)); + float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales); + float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales); + + int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec); + int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15)); + int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15)); + float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales); + float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales); + vst1q_f32(out_copy + 0, out_vec_0_3); + vst1q_f32(out_copy + 4, out_vec_4_7); + vst1q_f32(out_copy + 8, out_vec_8_11); + vst1q_f32(out_copy + 12, out_vec_12_15); + in_copy += kVecSize; + out_copy += kVecSize; + } + i = i * kVecSize; +#endif + for (; i < numel; i++) { + out[i] = (in[i] - zero_point) * scale; + } +} + +float get_scale(const Tensor& scale, size_t channel_ix) { + ET_CHECK_MSG( + (scale.scalar_type() == ScalarType::Double) || + (scale.scalar_type() == ScalarType::Float), + "scale.scalar_type() %" PRId8 " is not double or float type", + static_cast(scale.scalar_type())); + if (scale.scalar_type() == ScalarType::Double) { + return static_cast(scale.const_data_ptr()[channel_ix]); + } else { + return scale.const_data_ptr()[channel_ix]; + } +} + +bool can_use_optimized_dequantize_per_channel( + const Tensor& in, + const ScalarType in_dtype, + exec_aten::optional& out_dtype) { + bool is_contiguous = false; +#ifdef USE_ATEN_LIB + is_contiguous = in.is_contiguous(); +#else + is_contiguous = executorch::runtime::is_contiguous_dim_order( + in.dim_order().data(), in.dim()); +#endif + if (!is_contiguous || (in_dtype != ScalarType::Char) || + (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) { + return false; + } + return true; +} + +void dequantize_per_channel_optimized( + const Tensor& in, + const Tensor& scales, + const optional& opt_zero_points, + Tensor& out, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType in_dtype, + exec_aten::optional& out_dtype) { + check_dequantize_per_tensor_args( + in, quant_min, quant_max, in_dtype, out_dtype, out); + ET_CHECK_MSG( + in_dtype == ScalarType::Char, + "in.scalar_type() %" PRId8 " is not supported:", + static_cast(in.scalar_type())); + if (out_dtype.has_value()) { + ET_CHECK_MSG( + out_dtype.value() == ScalarType::Float, + "Only float output is supported"); + } + const int8_t* in_data = in.const_data_ptr(); + float* out_data = out.mutable_data_ptr(); + const int64_t* zero_points_data = nullptr; + if (opt_zero_points.has_value()) { + zero_points_data = opt_zero_points.value().const_data_ptr(); + } + const StridesType axis_stride = in.strides()[axis]; + const StridesType outer_stride = in.size(axis) * axis_stride; + apply_over_unpacked_dim( + [in_data, + out_data, + &scales, + zero_points_data, + axis_stride, + outer_stride, + quant_min, + quant_max]( + SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) { + const int8_t* in_data_local = + in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride; + const double scale = get_scale(scales, unpacked_dim_idx); + const int64_t zero_point = zero_points_data != nullptr + ? zero_points_data[unpacked_dim_idx] + : 0; + float* out_data_local = out_data + outer_idx * outer_stride + + unpacked_dim_idx * axis_stride; + dequantize_optimized( + in_data_local, + scale, + zero_point, + out_data_local, + quant_min, + quant_max, + numel); + }, + in, + axis); +} + } // namespace /** @@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out( return out; } -float get_scale(const Tensor& scale, size_t channel_ix) { - ET_CHECK_MSG( - (scale.scalar_type() == ScalarType::Double) || - (scale.scalar_type() == ScalarType::Float), - "scale.scalar_type() %" PRId8 " is not double or float type", - static_cast(scale.scalar_type())); - if (scale.scalar_type() == ScalarType::Double) { - return static_cast(scale.const_data_ptr()[channel_ix]); - } else { - return scale.const_data_ptr()[channel_ix]; - } -} - Tensor& dequantize_per_channel_out( const Tensor& input, const Tensor& scale, @@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out( check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); + if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) { + dequantize_per_channel_optimized( + input, + scale, + opt_zero_points, + out, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + return out; + } + // a list contains all dimensions except axis int64_t dims[kTensorDimensionLimit]; for (int64_t i = 0; i < input.dim() - 1; i++) { diff --git a/kernels/quantized/targets.bzl b/kernels/quantized/targets.bzl index 13ef166ece..5440400612 100644 --- a/kernels/quantized/targets.bzl +++ b/kernels/quantized/targets.bzl @@ -69,6 +69,8 @@ def define_common_targets(): "quantized_decomposed::dequantize_per_tensor.Tensor_out", "quantized_decomposed::quantize_per_tensor.out", "quantized_decomposed::quantize_per_tensor.Tensor_out", + "quantized_decomposed::dequantize_per_channel.out", + "quantized_decomposed::quantize_per_channel.out", ], ) diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index 8d23e74e41..676aa32690 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -123,13 +123,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpDequantizeOutTest, DequantizePerChannel) { - et_pal_init(); - TensorFactory tf_byte; +template +void test_per_channel_dtype() { + TensorFactory tf; TensorFactory tf_double; TensorFactory tf_long; - Tensor input = tf_byte.full({3, 2}, 100); + Tensor input = tf.full({3, 2}, 100); Tensor scale = tf_double.make({2}, {0.5, 1}); Tensor zero_point = tf_long.make({2}, {30, 60}); int64_t quant_min = 0; @@ -147,7 +147,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { /*axis=*/1, quant_min, quant_max, - ScalarType::Byte, + DTYPE, optional(), out); @@ -168,7 +168,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { /*axis=*/0, quant_min, quant_max, - ScalarType::Byte, + DTYPE, optional(), out); @@ -176,7 +176,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { // Test with a different axis out = tfo.zeros({3}); - input = tf_byte.make({3}, {100, 100, 100}); + input = tf.make({3}, {100, 100, 100}); scale = tf_double.make({3}, {0.5, 0.75, 1}); zero_point = tf_long.make({3}, {30, 50, 60}); // (100 - 30) * 0.5 @@ -190,8 +190,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { /*axis=*/0, quant_min, quant_max, - ScalarType::Byte, + DTYPE, + optional(), + out); + EXPECT_TENSOR_EQ(out, expected); + + // Test with a different axis + input = tf.full({3, 19}, 100); + out = tfo.zeros({3, 19}); + scale = tf_double.make({3}, {0.5, 0.75, 1}); + zero_point = tf_long.make({3}, {30, 50, 60}); + // (100 - 30) * 0.5 + // (100 - 50) * 0.75 + // (100 - 60) * 1 + expected = tfo.make( + {3, 19}, + {35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, + 35, 35, 35, 35, 35, 35, 35, 37.5, 37.5, 37.5, 37.5, 37.5, + 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, + 37.5, 37.5, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, + 40, 40, 40, 40, 40, 40, 40, 40, 40}); + dequantize_per_channel_out( + input, + scale, + zero_point, + /*axis=*/0, + quant_min, + quant_max, + DTYPE, optional(), out); + EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpDequantizeOutTest, DequantizePerChannel) { + et_pal_init(); + test_per_channel_dtype(); + test_per_channel_dtype(); +}