diff --git a/CMakeLists.txt b/CMakeLists.txt index 156fb24e6b..6b76f27eb0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -721,10 +721,15 @@ if(EXECUTORCH_BUILD_PYBIND) -fPIC -frtti -fexceptions - # libtorch is built with the old ABI, so we need to do the same for any - # .cpp files that include torch, c10, or ATen targets. - -D_GLIBCXX_USE_CXX11_ABI=0 ) + if(EXECUTORCH_DO_NOT_USE_CXX11_ABI) + # libtorch is built with the old ABI, so we need to do the same for any + # .cpp files that include torch, c10, or ATen targets. Note that PyTorch + # nightly binary is built with _GLIBCXX_USE_CXX11_ABI set to 0 while its + # CI build sets this to 1 (default) + list(APPEND _pybind_compile_options -D_GLIBCXX_USE_CXX11_ABI=0) + endif() + # util lib add_library( util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp diff --git a/README.md b/README.md index da2cb82ef9..aded66bf40 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,11 @@ We recommend using the latest release tag from the See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code style, CI jobs, and other development topics. +To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can: +* Head to the `#executorch-general` channel for general questions, discussion, and community support. +* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development. + + ## Directory Structure ``` diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index ca20b03fcc..6ca59cfee2 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -7,6 +7,7 @@ python_library( deps = [ "//executorch/backends/arm:tosa_quant_utils", "//executorch/backends/arm:tosa_utils", + "//executorch/backends/xnnpack/_passes:xnnpack_passes", "//executorch/exir:lib", ], ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a6c9cf1d06..a72cdfd1a0 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -23,6 +23,7 @@ from executorch.backends.arm._passes.decompose_layernorm_pass import ( DecomposeLayerNormPass, ) +from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_softmaxes_pass import ( DecomposeSoftmaxesPass, @@ -74,6 +75,7 @@ def transform_to_backend_pipeline( self.add_pass(ConvertSplitToSlicePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) + self.add_pass(DecomposeLinearPass()) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py new file mode 100644 index 0000000000..30767b354d --- /dev/null +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -0,0 +1,112 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
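The new EXECUTORCH_DO_NOT_USE_CXX11_ABI option above only matters when the installed libtorch was built with the pre-cxx11 ABI. A minimal sketch (not taken from this patch) for checking which ABI the local torch wheel uses before configuring CMake:

```python
# Report whether the installed libtorch uses the old (pre-cxx11) C++ ABI, the case
# where -D_GLIBCXX_USE_CXX11_ABI=0 (i.e. EXECUTORCH_DO_NOT_USE_CXX11_ABI=ON) is
# needed for the pybind targets.
import torch

if torch.compiled_with_cxx11_abi():
    print("libtorch uses the cxx11 ABI; leave EXECUTORCH_DO_NOT_USE_CXX11_ABI off")
else:
    print("libtorch uses the old ABI; configure with -DEXECUTORCH_DO_NOT_USE_CXX11_ABI=ON")
```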
+ +import numpy as np +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeLinearPass(ExportPass): + """ + This pass decomposes linear into a Conv2D with the required view operations. + linear(x, weights, bias) becomes: + x_reshaped = view(x) + weights_reshaped = view(weights) + conv2d = conv2d(x_reshaped, weights_reshaped, bias) + output = view(conv2d) + It also inserts q/dq pairs if the linear node was quantized. + """ + + def call(self, graph_module): + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.linear.default: + continue + args = node.args + input = args[0] + weights = args[1] + bias = args[2] if len(args) > 2 else None + output_shape = get_first_fake_tensor(node).shape + input_shape = get_first_fake_tensor(input).shape + weights_shape = get_first_fake_tensor(weights).shape + batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1 + # input has shape (..., Ci) + input_reshaped_shape = [batches, input_shape[-1], 1, 1] + # weights have shape (Co, Ci) + weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1] + + with graph_module.graph.inserting_before(node): + quantize = input.op == "call_function" and input.target == dq_op + q_params = input.args[1:] if quantize else None + # Reshape input to 4D with shape (N, Ci, 1, 1) + input_reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(input, input_reshaped_shape), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + quantize = weights.op == "call_function" and weights.target == dq_op + q_params = weights.args[1:] if quantize else None + # Reshape weights to 4D with shape (Co, Ci, 1, 1) + weights_reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(weights, weights_reshaped_shape), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + consumer_node = list(node.users)[0] + quantize = ( + consumer_node.op == "call_function" and consumer_node.target == q_op + ) + q_params = consumer_node.args[1:] if quantize else None + conv = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.convolution.default, + args=( + input_reshaped, + weights_reshaped, + bias, + [1, 1], # strides + [0, 0], # padding + [1, 1], # dilation + False, # transposed + [0, 0], # output padding + 1, # groups + ), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + with graph_module.graph.inserting_after(conv): + # Reshape output to same rank as original input with shape (..., Co) + # No need to insert q/dq pair as Conv2D node above has inserted them if + # required. 
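The pass above rewrites linear(x, W, b) as a 1x1 convolution over a (batches, Ci, 1, 1) view of the input, with the result viewed back to (..., Co). A small numeric check of that identity in plain torch (a sketch, not edge-dialect code):

```python
# Sanity check: a 1x1 conv2d over reshaped operands matches torch.nn.functional.linear.
import torch
import torch.nn.functional as F

x = torch.randn(5, 10, 25, 20)   # (..., Ci) with Ci = 20
weight = torch.randn(30, 20)     # (Co, Ci)
bias = torch.randn(30)

batches = x.shape[:-1].numel()                                # product of all but the last dim
x4d = x.reshape(batches, x.shape[-1], 1, 1)                   # (N, Ci, 1, 1)
w4d = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)  # (Co, Ci, 1, 1)

conv = F.conv2d(x4d, w4d, bias)                               # 1x1 conv == per-sample matmul
out = conv.reshape(*x.shape[:-1], weight.shape[0])            # back to (..., Co)

assert torch.allclose(out, F.linear(x, weight, bias), atol=1e-4)
```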
+ output = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(conv, list(output_shape)), + kwargs={}, + ) + + node.replace_all_uses_with(output) + graph_module.graph.erase_node(node) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index db3b368115..b55f237543 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -13,13 +13,15 @@ import logging import os -from typing import cast, final, List, Optional +from typing import final, List, Optional import serializer.tosa_serializer as ts from executorch.backends.arm.arm_vela import vela_compile from executorch.backends.arm.operators.node_visitor import get_node_visitors from executorch.backends.arm.operators.op_output import process_output from executorch.backends.arm.operators.op_placeholder import process_placeholder + +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm._passes.arm_pass_manager import ( ArmPassManager, ) # usort: skip @@ -31,7 +33,6 @@ from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram -from torch.fx import Node # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -87,9 +88,15 @@ def ethosu_compile_spec( if extra_flags is not None: self.compiler_flags.append(extra_flags) + base_tosa_version = "TOSA-0.80.0+BI" + if "U55" in config: + # Add the Ethos-U55 extension marker + base_tosa_version += "+u55" + self.tosa_version = TosaSpecification.create_from_string(base_tosa_version) + return self - def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": + def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder": """ Generate compile spec for TOSA flatbuffer output """ @@ -97,6 +104,7 @@ def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": self.output_format is None ), f"Output format already set: {self.output_format}" self.output_format = "tosa" + self.tosa_version = TosaSpecification.create_from_string(tosa_version) return self def dump_intermediate_artifacts_to( @@ -130,6 +138,13 @@ def build(self) -> List[CompileSpec]: """ Generate a list of compile spec objects from the builder """ + assert self.tosa_version + + # Always supply a TOSA version + self.compile_spec = [ + CompileSpec("tosa_version", str(self.tosa_version).encode()) + ] + if self.output_format == "vela": self.compile_spec += [ CompileSpec("output_format", "vela".encode()), @@ -211,11 +226,18 @@ def preprocess( # noqa: C901 if not output_format: raise RuntimeError("output format is required") + tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec) + assert ( + tosa_spec is not None + ), "TOSA backend needs a TOSA version specified in the CompileSpec!" + if output_format == "vela" and len(compile_flags) == 0: # Not testing for compile_flags correctness here, just that they are # present. The compiler will give errors if they are not valid. raise RuntimeError("compile flags are required for vela output format") + logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") + # Converted output for this subgraph, serializer needs path early as it emits # const data directly. Path created and data written only in debug builds. 
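With this change a TOSA version is mandatory: build() always prepends a "tosa_version" CompileSpec, and preprocess() refuses to run without one. A minimal sketch of constructing a compile spec under the updated builder (method names as they appear in this patch; the dump path is illustrative):

```python
# Build a TOSA compile spec with an explicit version string, as required by the
# updated ArmCompileSpecBuilder.
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

compile_spec = (
    ArmCompileSpecBuilder()
    .tosa_compile_spec("TOSA-0.80.0+MI")              # version string is now required
    .set_permute_memory_format(True)
    .dump_intermediate_artifacts_to("/tmp/tosa_dump")  # optional debug output path
    .build()                                           # always emits a "tosa_version" entry first
)
```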
tosa_graph = ts.TosaSerializer(artifact_path) @@ -223,14 +245,13 @@ def preprocess( # noqa: C901 exported_program=edge_program, compile_spec=compile_spec ) - node_visitors = get_node_visitors(edge_program) + node_visitors = get_node_visitors(edge_program, tosa_spec) for node in graph_module.graph.nodes: - node = cast(Node, node) if node.op == "call_function": - process_call_function(node, tosa_graph, node_visitors) + process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": - process_placeholder(node, tosa_graph, edge_program) + process_placeholder(node, tosa_graph, edge_program, tosa_spec) elif node.op == "output": process_output(node, tosa_graph) else: @@ -238,6 +259,9 @@ def preprocess( # noqa: C901 # any checking of compatibility. dbg_fail(node, tosa_graph, artifact_path) + # TODO: It would be awesome if this dump could somehow be done on top level and not here. + # Problem is that the desc.json has to be created on the tosa_graph object, which we can't + # access from top level. if artifact_path: tag = _get_first_delegation_tag(graph_module) dbg_tosa_dump( @@ -258,4 +282,6 @@ def preprocess( # noqa: C901 else: raise RuntimeError(f"Unknown format {output_format}") + # Continueing from above. Can I put tosa_graph into this function? + # debug_handle_map = ... return PreprocessResult(processed_bytes=binary) diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index bdd4b80f29..ef924fa434 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -8,7 +8,7 @@ import logging import operator import os -from typing import cast, final, List +from typing import Callable, cast, final, List, Optional, Tuple import torch from executorch.backends.arm.arm_backend import ArmBackend # usort: skip @@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: supported = node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.bmm.default, @@ -49,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, @@ -137,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags ) + + def ops_to_not_decompose( + self, + ep: ExportedProgram, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + ops_to_not_decompose = [ + torch.ops.aten.linear.default, + ] + return (ops_to_not_decompose, None) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 5e188aea77..6e51c2c141 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -8,7 +8,6 @@ from . 
import ( # noqa node_visitor, op_add, - op_addmm, op_avg_pool2d, op_batch_norm, op_bmm, diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 99fd0388e4..9e98ebcab9 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.export import ExportedProgram @@ -18,8 +19,19 @@ class NodeVisitor: Node Visitor pattern for lowering edge IR to TOSA """ - def __init__(self, exported_program: ExportedProgram): + # Add the currently supported node_visitor specs as default. + # This should be overriden in the NodeVisitor subclasses to target + # a specific TOSA version. + # When all node_visitors has been refactored to target a specific + # version, this list should be removed. + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification): self._exported_program = exported_program or None + self.tosa_spec = tosa_spec def define_node( self, @@ -33,16 +45,30 @@ def define_node( # container for all node visitors -_node_visitor_dict = {} +_node_visitor_dicts = { + TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {}, + TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {}, +} def register_node_visitor(visitor): - _node_visitor_dict[visitor.target] = visitor + for tosa_spec in visitor.tosa_specs: + _node_visitor_dicts[tosa_spec][visitor.target] = visitor + return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: node_visitors = {} - for target, visitor in _node_visitor_dict.items(): + tosa_spec = None + for arg in args: + if isinstance(arg, TosaSpecification): + tosa_spec = arg + break + + if tosa_spec is None: + raise RuntimeError("No TOSA specification supplied.") + + for target, visitor in _node_visitor_dicts[tosa_spec].items(): node_visitors[target] = visitor(*args) return node_visitors diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index ec2ade9e8a..7a71a0d2bd 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,19 +11,25 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class AddVisitor(NodeVisitor): +class AddVisitor_080_BI(NodeVisitor): target = "aten.add.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -35,33 +41,72 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: - input_nodes = tutils.get_two_inputs(node) + input_nodes = tutils.get_two_inputs(node) + + if not is_quant_node and not all( + 
tensor.meta["val"].dtype in (torch.int8, torch.int32) + for tensor in input_nodes + ): + raise RuntimeError( + f"Unexpected non quantized {AddVisitor_080_BI.target} node." + ) + needs_rescale = not ( + all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes) + and node.meta["val"].dtype == torch.int32 + ) + + if needs_rescale: # Rescale inputs to 32 bit rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( input_nodes, tosa_graph ) - # Preapre sub output tensor - broadcasted_shape = tutils.broadcast_shapes( - rescaled_inputs[0].shape, rescaled_inputs[0].shape - ) + # Prepare add output tensor + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + add_output = output + rescaled_inputs = inputs - # Do the INT32 Add - tosa_graph.addOperator( - TosaOp.Op().ADD, - [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, - ], - [add_output.name], - None, - ) + # Do the INT32 Add + tosa_graph.addOperator( + TosaOp.Op().ADD, + [ + rescaled_inputs[0].name, + rescaled_inputs[1].name, + ], + [add_output.name], + None, + ) + if needs_rescale: # Scale output back to 8 bit tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph) + + +@register_node_visitor +class AddVisitor_080_MI(AddVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + if is_quant_node: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Add lowering tosa_graph.addOperator( diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py deleted file mode 100644 index b4f782db4a..0000000000 --- a/backends/arm/operators/op_addmm.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args - -from executorch.backends.arm.tosa_utils import build_reshape -from executorch.exir.dialects._ops import ops as exir_ops -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class AddmmVisitor(NodeVisitor): - target = "aten.addmm.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - bias, input, weight = inputs - - N = input.shape[0] - input_channels = input.shape[1] - output_channels = weight.shape[1] - - input_new_shape = (N, 1, 1, input_channels) - input_reshaped = tosa_graph.addIntermediate( - input_new_shape, - ts.DType.INT8 if is_quant_node else input.dtype, - ) - - build_reshape(tosa_graph, input.name, input_new_shape, input_reshaped.name) - - weight_new_shape = (output_channels, 1, 1, input_channels) - weight_reshaped = tosa_graph.addIntermediate( - weight_new_shape, - ts.DType.INT8 if is_quant_node else weight.dtype, - ) - - build_reshape(tosa_graph, weight.name, weight_new_shape, weight_reshaped.name) - - # Get the attributes of convolution. - attr = ts.TosaSerializerAttribute() - pad_attr = [0, 0, 0, 0] - stride_attr = [1, 1] - dilation_attr = [1, 1] - - input_zp = 0 - if is_quant_node: - input_node = node.all_input_nodes[1] - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - else: - quant_node = input_node - input_zp = get_quant_node_args(quant_node).zp - attr.ConvAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - input_zp=input_zp, - weight_zp=0, - local_bound=False, - ) - - conv2d_output_shape = (N, 1, 1, output_channels) - conv2d_res = tosa_graph.addIntermediate( - conv2d_output_shape, - ts.DType.INT32 if is_quant_node else output.dtype, - ) - - # U55 doesn't support tosa.matmul and tosa.fully_connected will be deprecated - # TOSA Conv2d input is NHWC and weights are in OHWI - tosa_graph.addOperator( - TosaOp.Op().CONV2D, - [ - input_reshaped.name, - weight_reshaped.name, - bias.name, - ], - [conv2d_res.name], - attr, - ) - - result_shape = (N, output_channels) - - if is_quant_node: - # Read inputs' parent nodes - _, input_node, weight_node = node.all_input_nodes - - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - input_scale = get_quant_node_args(quant_node).scale - consumer_node = list(node.users)[0] - consumer_consumer_node = list(consumer_node.users)[0] - quant_args = get_quant_node_args(consumer_consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - else: - input_scale = get_quant_node_args(input_node).scale - consumer_node = list(node.users)[0] - quant_args = get_quant_node_args(consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - - weight_node_q_node = weight_node.all_input_nodes[0] - weight_scale = get_quant_node_args(weight_node_q_node).scale - - output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale - - reshaped_res = 
tosa_graph.addIntermediate(result_shape, ts.DType.INT32) - build_reshape(tosa_graph, conv2d_res.name, result_shape, reshaped_res.name) - - build_rescale( - tosa_fb=tosa_graph, - scale=output_rescale_scale, - input_node=reshaped_res, - output_name=output.name, - output_type=ts.DType.INT8, - output_shape=reshaped_res.shape, - input_zp=0, - output_zp=consumer_node_node_zp, - is_double_round=False, - ) - - else: - # non-quantized case - build_reshape(tosa_graph, conv2d_res.name, result_shape, output.name) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 0752d8242f..a0b868f684 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import List import serializer.tosa_serializer as ts import torch @@ -54,9 +54,7 @@ def define_node( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args( - cast(torch.fx.Node, node.all_input_nodes[0]) - ).zp + input_zp = get_quant_node_args(node.all_input_nodes[0]).zp output_zp = get_quant_node_args(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 69f6f6506c..8142d6d654 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -14,7 +14,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm from serializer.tosa_serializer import TosaOp @@ -81,13 +80,6 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_permute_node_before_addmm(node): - ## Simply add an identityOp - tosa_graph.addOperator( - TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] - ) - return - # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 2618c9e71d..950d4636d2 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -14,18 +14,18 @@ get_quant_node_args, is_quant_arg, ) +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( - is_bias_node_for_quantized_addmm, is_bias_node_for_quantized_conv, tosa_shape, ) -from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram def process_inputs( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, + tosa_spec: TosaSpecification, ): """Serialize an input node""" # inputs need to be in default dim_order (contiguous memory format) @@ -55,25 +55,13 @@ def process_quantized_bias( ): """ Serialize bias node that needs to be quantized. - This can be either an addmm or conv bias node. 
""" consumer_node = list(node.users)[0] - if is_bias_node_for_quantized_addmm(node): - ( - _, - input_node, - weight_node_permuted, - ) = consumer_node.all_input_nodes - - weight_node = weight_node_permuted.all_input_nodes[0] - if input_node.target == exir_ops.edge.aten.view_copy.default: - input_node = input_node.all_input_nodes[0] - else: - ( - input_node, - weight_node, - _, - ) = consumer_node.all_input_nodes + ( + input_node, + weight_node, + _, + ) = consumer_node.all_input_nodes input_node_scale = get_quant_node_args(input_node).scale weight_node_scale = get_quant_node_args(weight_node).scale @@ -95,6 +83,7 @@ def process_inputs_to_parameters( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): """Serialize bias and non-quantized weights""" inputs = [TosaArg(node)] @@ -104,11 +93,15 @@ def process_inputs_to_parameters( assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor" parameter_values = parameter_data.detach().numpy() - if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node): + if is_bias_node_for_quantized_conv(node): # BI bias + assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer" process_quantized_bias(node, tosa_graph, parameter_values) else: # MI weights or bias + if inputs[0].dtype == torch.float32: + assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" + parameter_values = np.transpose(parameter_values, inputs[0].dim_order) tosa_graph.addConst( @@ -158,15 +151,16 @@ def process_placeholder( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): """Wrapper for processing and serializing all types of placeholders""" assert node.name == node.target, "Expect placeholder name and target to match" assert 0 == len(node.args), "Can't handle default input values" if node.name in edge_program.graph_signature.user_inputs: - process_inputs(node, tosa_graph) + process_inputs(node, tosa_graph, tosa_spec) elif node.name in edge_program.graph_signature.inputs_to_parameters: - process_inputs_to_parameters(node, tosa_graph, edge_program) + process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec) elif node.name in edge_program.graph_signature.inputs_to_buffers: process_inputs_to_buffers(node, tosa_graph, edge_program) elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants: diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 2089b6e9e9..b86a5ea3ad 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -43,10 +43,8 @@ def define_node( input_nodes, tosa_graph ) - # Preapre sub output tensor - broadcasted_shape = tutils.broadcast_shapes( - rescaled_inputs[0].shape, rescaled_inputs[0].shape - ) + # Prepare sub output tensor + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) # Do the INT32 Sub diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 1a155c0323..3a9818929b 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -177,30 +177,36 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( - permute_memory_to_nhwc=True, custom_path=None + tosa_version: str, permute_memory_to_nhwc=True, custom_path=None ) -> list[CompileSpec]: """ Default compile spec for TOSA tests. 
""" - return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build() + return get_tosa_compile_spec_unbuilt( + tosa_version, permute_memory_to_nhwc, custom_path + ).build() def get_tosa_compile_spec_unbuilt( - permute_memory_to_nhwc=False, custom_path=None + tosa_version: str, permute_memory_to_nhwc=False, custom_path=None ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. """ if not custom_path: - custom_path = maybe_get_tosa_collate_path() + intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp( + prefix="arm_tosa_" + ) + else: + intermediate_path = custom_path - if custom_path is not None and not os.path.exists(custom_path): - os.makedirs(custom_path, exist_ok=True) + if not os.path.exists(intermediate_path): + os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = ( ArmCompileSpecBuilder() - .tosa_compile_spec() + .tosa_compile_spec(tosa_version) .set_permute_memory_format(permute_memory_to_nhwc) - .dump_intermediate_artifacts_to(custom_path) + .dump_intermediate_artifacts_to(intermediate_path) ) return compile_spec_builder diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 1aa3e82c76..4cac39af70 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -49,7 +49,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -63,12 +63,11 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_artifact(dump_file) .dump_artifact() ) @@ -108,13 +107,11 @@ def test_numerical_diff_prints(self): model, example_inputs=model.get_inputs(), compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=True, - custom_path=tempfile.mkdtemp("diff_print_test"), + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True ), ) .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .to_executorch() ) # We expect an assertion error here. 
Any other issues will cause the @@ -135,7 +132,7 @@ def test_dump_ops_and_dtypes(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .dump_dtype_distribution() @@ -143,10 +140,7 @@ def test_dump_ops_and_dtypes(): .export() .dump_dtype_distribution() .dump_operator_distribution() - .to_edge() - .dump_dtype_distribution() - .dump_operator_distribution() - .partition() + .to_edge_transform_and_lower() .dump_dtype_distribution() .dump_operator_distribution() ) @@ -159,7 +153,7 @@ def test_dump_ops_and_dtypes_parseable(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .dump_dtype_distribution(print_table=False) @@ -167,10 +161,7 @@ def test_dump_ops_and_dtypes_parseable(): .export() .dump_dtype_distribution(print_table=False) .dump_operator_distribution(print_table=False) - .to_edge() - .dump_dtype_distribution(print_table=False) - .dump_operator_distribution(print_table=False) - .partition() + .to_edge_transform_and_lower() .dump_dtype_distribution(print_table=False) .dump_operator_distribution(print_table=False) ) @@ -190,12 +181,11 @@ def test_collate_tosa_BI_tests(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .to_executorch() ) # test that the output directory is created and contains the expected files @@ -203,10 +193,10 @@ def test_collate_tosa_BI_tests(self): "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests" ) assert os.path.exists( - "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa" + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag5.tosa" ) assert os.path.exists( - "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json" + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag5.json" ) os.environ.pop("TOSA_TESTCASES_BASE_PATH") @@ -220,12 +210,11 @@ def test_dump_tosa_ops(caplog): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_operator_distribution() ) assert "TOSA operators:" in caplog.text @@ -244,8 +233,7 @@ def forward(self, x): ArmTester(model, example_inputs=(torch.ones(5),), compile_spec=compile_spec) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_operator_distribution() ) assert "Can not get operator distribution for Vela command stream." 
in caplog.text diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py index 8bad1493b1..d7406afe95 100644 --- a/backends/arm/test/misc/test_dim_order_guards.py +++ b/backends/arm/test/misc/test_dim_order_guards.py @@ -34,7 +34,7 @@ def test_tosa_MI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -48,7 +48,7 @@ def test_tosa_BI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 29b2887431..12b8d0665b 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -60,7 +60,7 @@ def test_partition_lifted_tensor_tosa_MI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -77,7 +77,7 @@ def test_partition_lifted_tensor_tosa_BI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -95,7 +95,7 @@ def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -110,7 +110,7 @@ def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py new file mode 100644 index 0000000000..5cbad140b7 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -0,0 +1,105 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
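The test updates above select a TOSA profile through an explicit version string such as "TOSA-0.80.0+MI" or "TOSA-0.80.0+BI". A short sketch of what such a string resolves to, using the specification API exercised by the new test file below (BI is the integer-only profile, MI adds floating point):

```python
# Parse a TOSA version string and query the resulting specification
# (API names as used elsewhere in this patch).
from executorch.backends.arm.tosa_specification import TosaSpecification

bi_spec = TosaSpecification.create_from_string("TOSA-0.80.0+BI")
mi_spec = TosaSpecification.create_from_string("TOSA-0.80.0+MI")

assert bi_spec.support_integer()   # BI targets integer (quantized) networks
assert mi_spec.support_float()     # MI additionally supports floating point
```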
+ +import unittest + +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + +test_valid_0_80_strings = [ + "TOSA-0.80.0+BI", + "TOSA-0.80.0+MI+8k", + "TOSA-0.80.0+BI+u55", +] +test_valid_1_00_strings = [ + "TOSA-1.00.0+INT+FP+fft", + "TOSA-1.00.0+FP+bf16+fft", + "TOSA-1.00.0+INT+int4+cf", + "TOSA-1.00.0+FP+cf+bf16+8k", + "TOSA-1.00.0+FP+INT+bf16+fft+int4+cf", + "TOSA-1.00.0+FP+INT+fft+int4+cf+8k", +] + +test_valid_1_00_extensions = { + "INT": ["int16", "int4", "var", "cf"], + "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], +} + +test_invalid_strings = [ + "TOSA-0.80.0+bi", + "TOSA-0.80.0", + "TOSA-0.80.0+8k", + "TOSA-0.80.0+BI+MI", + "TOSA-0.80.0+BI+U55", + "TOSA-1.00.0+fft", + "TOSA-1.00.0+fp+bf16+fft", + "TOSA-1.00.0+INT+INT4+cf", + "TOSA-1.00.0+BI", + "TOSA-1.00.0+FP+FP+INT", + "TOSA-1.00.0+FP+CF+bf16", + "TOSA-1.00.0+BF16+fft+int4+cf+INT", +] + +test_compile_specs = [ + ([CompileSpec("tosa_version", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())],), + ([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],), +] + +test_compile_specs_no_version = [ + ([CompileSpec("other_key", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("other_key", "some_value".encode())],), +] + + +class TestTosaSpecification(unittest.TestCase): + """Tests the TOSA specification class""" + + @parameterized.expand(test_valid_0_80_strings) + def test_version_string_0_80(self, version_string: str): + tosa_spec = TosaSpecification.create_from_string(version_string) + assert isinstance(tosa_spec, Tosa_0_80) + assert tosa_spec.profile in ["BI", "MI"] + + @parameterized.expand(test_valid_1_00_strings) + def test_version_string_1_00(self, version_string: str): + tosa_spec = TosaSpecification.create_from_string(version_string) + assert isinstance(tosa_spec, Tosa_1_00) + assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count( + True + ) > 0 + + for profile in tosa_spec.profiles: + assert [ + e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions + ] + + @parameterized.expand(test_invalid_strings) + def test_invalid_version_strings(self, version_string: str): + tosa_spec = None + with self.assertRaises(ValueError): + tosa_spec = TosaSpecification.create_from_string(version_string) + + assert tosa_spec is None + + @parameterized.expand(test_compile_specs) + def test_create_from_compilespec(self, compile_specs: list[CompileSpec]): + tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs) + assert isinstance(tosa_spec, TosaSpecification) + + @parameterized.expand(test_compile_specs_no_version) + def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]): + tosa_spec = None + with self.assertRaises(ValueError): + tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs) + + assert tosa_spec is None diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index a50e2732f1..19b4254575 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -54,12 +54,12 @@ def test_mv2_tosa_MI(self): ArmTester( self.mv2, example_inputs=self.model_inputs, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + 
), ) .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.all_operators)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .run_method_and_compare_outputs(inputs=self.model_inputs) ) @@ -69,13 +69,13 @@ def test_mv2_tosa_BI(self): ArmTester( self.mv2, example_inputs=self.model_inputs, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() # atol=1.0 is a defensive upper limit # TODO MLETROCH-72 @@ -92,9 +92,7 @@ def test_mv2_u55_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .serialize() ) @@ -112,9 +110,7 @@ def test_mv2_u85_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .serialize() ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index e3eeb187da..66e278ee0f 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -61,7 +61,7 @@ def _test_add_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.add.Tensor": 1}) @@ -80,7 +80,7 @@ def _test_add_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index 344a80a79b..afd079fb95 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -55,7 +55,9 @@ def _test_avgpool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.avg_pool2d.default"]) @@ -76,7 +78,9 @@ def _test_avgpool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index bfe1146a90..297ac0af1c 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -533,7 +533,7 @@ def _test_batchnorm2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -561,7 +561,7 @@ def 
_test_batchnorm2d_no_stats_tosa_MI_pipeline( ArmTester( module, example_example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten._native_batch_norm_legit.no_stats": 1}) @@ -590,7 +590,7 @@ def _test_batchnorm2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e4e6abb7bb..e5e9508e25 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -50,7 +50,7 @@ def _test_bmm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.bmm.default": 1}) @@ -70,7 +70,7 @@ def _test_bmm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index b0a38ce198..b380c44d52 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -56,7 +56,7 @@ def _test_cat_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.cat.default": 1}) @@ -76,7 +76,7 @@ def _test_cat_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int): def test_cat_4d_tosa_MI(self): square = torch.ones((2, 2, 2, 2)) for dim in range(-3, 3): - test_data = ((square, square.clone()), dim) + test_data = ((square, square), dim) self._test_cat_tosa_MI_pipeline(self.Cat(), test_data) @parameterized.expand(Cat.test_parameters) diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 9852c5c452..4721f257b0 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -47,7 +47,7 @@ def _test_clone_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.clone.default": 1}) @@ -66,7 +66,7 @@ def _test_clone_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index 3b27554221..133148faef 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -226,7 +226,9 @@ def _test_conv1d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ 
-246,7 +248,9 @@ def _test_conv1d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 46adfc8a01..43c3e85139 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -253,7 +253,9 @@ def _test_conv2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -273,7 +275,9 @@ def _test_conv2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 4b45b67126..3e9bdef958 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -192,7 +192,9 @@ def _test_conv_combo_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -214,7 +216,9 @@ def _test_conv_combo_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 01ffbc1054..4bfa863c49 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -177,7 +177,9 @@ def _test_dw_conv_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -195,7 +197,9 @@ def _test_dw_conv_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index 84a8d53f9d..28cc686690 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -102,7 +102,7 @@ def _test_div_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.div.Tensor": 1}) @@ -121,7 +121,7 @@ def _test_div_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_exp.py b/backends/arm/test/ops/test_exp.py index 
6e85d8fe49..c706b7b206 100644 --- a/backends/arm/test/ops/test_exp.py +++ b/backends/arm/test/ops/test_exp.py @@ -40,7 +40,7 @@ def _test_exp_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.exp.default"]) @@ -58,7 +58,7 @@ def _test_exp_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index aa13a6475c..effa7ce713 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -46,7 +46,7 @@ def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.expand.default": 1}) @@ -64,7 +64,7 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 2722edef32..d4cfc5c369 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -57,7 +57,7 @@ def _test_full_tosa_MI_pipeline( ArmTester( module, example_inputs=example_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.full.default": 1}) @@ -80,7 +80,7 @@ def _test_full_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .quantize() diff --git a/backends/arm/test/ops/test_hardtanh.py b/backends/arm/test/ops/test_hardtanh.py index c7c3736e37..a9f12abdf0 100644 --- a/backends/arm/test/ops/test_hardtanh.py +++ b/backends/arm/test/ops/test_hardtanh.py @@ -52,7 +52,7 @@ def _test_hardtanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.hardtanh.default"]) @@ -73,7 +73,7 @@ def _test_hardtanh_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 0150c20524..f059d71eba 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -74,7 +74,9 @@ def _test_layernorm_tosa_MI_pipeline( ArmTester( model=module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() 
.check(["torch.ops.aten.layer_norm.default"]) @@ -93,7 +95,9 @@ def _test_layernorm_tosa_BI_pipeline( ArmTester( model=module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .check_not(["torch.ops.aten.layer_norm.default"]) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 3f68ab0251..c7a475035d 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -23,70 +23,82 @@ test_data_suite_rank1 = [ - # (test_name, test_data, out_features) + # (test_name, test_data, out_features, has_bias) ( "model_linear_rank1_zeros", torch.zeros(10), 15, + True, ), ( "model_linear_rank1_ones", torch.ones(10), 15, + False, ), ( "model_linear_rank1_negative_ones", torch.ones(10) * (-1), 20, + True, ), ( "model_linear_rank1_rand", torch.rand(10), 10, + True, ), ( "model_linear_rank1_negative_large_rand", torch.rand(10) * (-100), 30, + False, ), ( "model_linear_rank1_large_randn", torch.randn(15) * 100, 20, + True, ), ] test_data_suite_rank4 = [ - # (test_name, test_data, out_features) + # (test_name, test_data, out_features, has_bias) ( "model_linear_rank4_zeros", torch.zeros(5, 10, 25, 20), 30, + True, ), ( "model_linear_rank4_ones", torch.ones(5, 10, 25, 20), 30, + False, ), ( "model_linear_rank4_negative_ones", torch.ones(5, 10, 25, 20) * (-1), 30, + True, ), ( "model_linear_rank4_rand", torch.rand(5, 10, 25, 20), 30, + False, ), ( "model_linear_rank4_negative_large_rand", torch.rand(5, 10, 25, 20) * (-100), 30, + True, ), ( "model_linear_rank4_large_randn", torch.randn(5, 10, 25, 20) * 100, 30, + False, ), ] @@ -122,13 +134,14 @@ def _test_linear_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check_count({"torch.ops.aten.linear.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data) @@ -141,14 +154,15 @@ def _test_linear_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() .check_count({"torch.ops.aten.linear.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data, qtol=True) @@ -170,8 +184,7 @@ def _test_linear_tosa_ethosu_BI_pipeline( .export() .check_count({"torch.ops.aten.linear.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() @@ -184,6 +197,7 @@ def test_linear_tosa_MI( 
test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -191,6 +205,7 @@ def test_linear_tosa_MI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), test_data, ) @@ -201,11 +216,15 @@ def test_linear_tosa_BI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) self._test_linear_tosa_BI_pipeline( - self.Linear(in_features=in_features, out_features=out_features), test_data + self.Linear( + in_features=in_features, out_features=out_features, bias=has_bias + ), + test_data, ) @parameterized.expand(test_data_suite_rank1) @@ -214,6 +233,7 @@ def test_linear_tosa_u55_BI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -221,20 +241,22 @@ def test_linear_tosa_u55_BI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), - common.get_u55_compile_spec(permute_memory_to_nhwc=False), + common.get_u55_compile_spec(), test_data, ) if common.is_option_enabled("corstone300"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - @parameterized.expand(test_data_suite_rank1) + @parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4) def test_linear_tosa_u85_BI( self, test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -242,7 +264,8 @@ def test_linear_tosa_u85_BI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), - common.get_u85_compile_spec(permute_memory_to_nhwc=False), + common.get_u85_compile_spec(), test_data, ) diff --git a/backends/arm/test/ops/test_log.py b/backends/arm/test/ops/test_log.py index 269b7be25f..847635ea36 100644 --- a/backends/arm/test/ops/test_log.py +++ b/backends/arm/test/ops/test_log.py @@ -40,7 +40,7 @@ def _test_log_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.log.default"]) @@ -58,7 +58,7 @@ def _test_log_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py index 2d51588bb3..5d84fa127f 100644 --- a/backends/arm/test/ops/test_logsoftmax.py +++ b/backends/arm/test/ops/test_logsoftmax.py @@ -46,7 +46,7 @@ def _test_logsoftmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.log_softmax.int"]) @@ -66,7 +66,7 @@ def _test_logsoftmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 5c48afa3ce..41526b1c77 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -62,7 +62,9 @@ def _test_maxpool2d_tosa_MI_pipeline( ArmTester( 
module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.max_pool2d.default"]) @@ -87,7 +89,9 @@ def _test_maxpool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 68307bcdf1..e8320cf1df 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -81,7 +81,7 @@ def _test_adaptive_avg_pool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.adaptive_avg_pool2d.default"]) @@ -101,7 +101,7 @@ def _test_adaptive_avg_pool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -150,7 +150,7 @@ def _test_meandim_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -169,7 +169,7 @@ def _test_meandim_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index 4271496eaa..21b02bbd10 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -54,7 +54,7 @@ def _test_mm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.mm.default": 1}) @@ -74,7 +74,7 @@ def _test_mm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index a1c2dba5fe..7fa20c2566 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -70,7 +70,9 @@ def _test_mul_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check_count({"torch.ops.aten.mul.Tensor": 1}) @@ -89,7 +91,9 @@ def _test_mul_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 6346e847c9..62b6b823de 100644 --- a/backends/arm/test/ops/test_permute.py +++ 
b/backends/arm/test/ops/test_permute.py @@ -57,7 +57,7 @@ def _test_permute_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .export() @@ -79,7 +79,7 @@ def _test_permute_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_reciprocal.py b/backends/arm/test/ops/test_reciprocal.py index cb4971bf8c..7745a614e6 100644 --- a/backends/arm/test/ops/test_reciprocal.py +++ b/backends/arm/test/ops/test_reciprocal.py @@ -46,7 +46,7 @@ def _test_reciprocal_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.reciprocal.default": 1}) @@ -65,7 +65,7 @@ def _test_reciprocal_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py index effbccc74d..595c907b32 100644 --- a/backends/arm/test/ops/test_relu.py +++ b/backends/arm/test/ops/test_relu.py @@ -48,7 +48,7 @@ def _test_relu_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.relu.default"]) @@ -69,7 +69,7 @@ def _test_relu_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 1efac9f974..20c57ba749 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -47,7 +47,7 @@ def _test_repeat_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.repeat.default": 1}) @@ -65,7 +65,7 @@ def _test_repeat_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index 2ccb7ec991..2cddc8da26 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -35,7 +35,7 @@ def _test_rsqrt_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.rsqrt.default": 1}) @@ -53,7 +53,7 @@ def _test_rsqrt_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - 
compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 0305ef58c0..86433745a6 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -123,7 +123,7 @@ def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -137,7 +137,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index 6a47c2e66b..85bfc15d2d 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -58,7 +58,7 @@ def _test_select_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute ), ) .export() @@ -84,7 +84,7 @@ def _test_select_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute ), ) .quantize() @@ -93,6 +93,8 @@ def _test_select_tosa_BI_pipeline( .check(["torch.ops.quantized_decomposed"]) .to_edge() .partition() + .dump_artifact() + .dump_operator_distribution() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data) @@ -160,14 +162,12 @@ def test_select_int_tosa_MI(self, test_data: test_data_t): ) @parameterized.expand(test_data_suite) - @unittest.skip def test_select_copy_tosa_BI(self, test_data: test_data_t): self._test_select_tosa_BI_pipeline( self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" ) @parameterized.expand(test_data_suite) - @unittest.skip def test_select_int_tosa_BI(self, test_data: test_data_t): self._test_select_tosa_BI_pipeline( self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index 4d126b68e5..f12658c985 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -71,7 +71,7 @@ def _test_sigmoid_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.sigmoid.default"]) @@ -89,7 +89,7 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 18db358fdf..0fc92b011a 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -39,7 +39,7 @@ def _test_slice_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.slice.Tensor"]) @@ -60,7 +60,7 @@ def _test_slice_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute ), ) .quantize() diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index 954dd201a9..f883d6b8de 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -47,7 +47,7 @@ def _test_softmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.softmax.int"]) @@ -67,7 +67,7 @@ def _test_softmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 8ed0e723f1..42395c4c2d 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -56,7 +56,7 @@ def _test_split_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -79,7 +79,7 @@ def _test_split_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py index c3f1edf37b..7e915da645 100644 --- a/backends/arm/test/ops/test_squeeze.py +++ b/backends/arm/test/ops/test_squeeze.py @@ -61,7 +61,7 @@ def _test_squeeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({export_target: 1}) @@ -82,7 +82,7 @@ def _test_squeeze_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index e80c043698..5c67240e52 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -43,7 +43,7 @@ def _test_sub_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.sub.Tensor": 1}) @@ -63,7 +63,7 @@ def _test_sub_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 73860dfa4a..9cd63b0a22 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -49,7 +49,7 @@ def _test_sum_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() 
.check_count({"torch.ops.aten.sum.dim_IntList": 1}) @@ -68,7 +68,7 @@ def _test_sum_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py index 6f5cf17cf3..5f3859eadd 100644 --- a/backends/arm/test/ops/test_tanh.py +++ b/backends/arm/test/ops/test_tanh.py @@ -44,7 +44,7 @@ def _test_tanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.tanh.default"]) @@ -62,7 +62,7 @@ def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple) ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py index 36bb93b796..8936d55f8b 100644 --- a/backends/arm/test/ops/test_unsqueeze.py +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -35,7 +35,7 @@ def _test_unsqueeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) @@ -53,7 +53,7 @@ def _test_unsqueeze_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 56b7c5fbb4..3a1285e6da 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -86,7 +86,7 @@ def _test_var_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -107,7 +107,7 @@ def _test_var_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 54e80702e3..09a8f57bd3 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -55,7 +55,7 @@ def _test_view_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.view.default": 1}) @@ -73,7 +73,7 @@ def _test_view_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py index c8fa0f4b7a..615187fb65 100644 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -46,7 +46,7 @@ def test_tosa_BI_meandim_to_averagepool(self): ArmTester( module, 
example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -63,7 +63,7 @@ def test_tosa_BI_meandim_no_modification(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index b859757df4..3d39463a42 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -30,7 +30,9 @@ def example_inputs(self): class TestGenericAnnotator(unittest.TestCase): def check_annotation(self, model): tester = ArmTester( - model, model.example_inputs(), common.get_tosa_compile_spec() + model, + model.example_inputs(), + common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) quant_model = tester.quantize().get_artifact() partitions = get_source_partitions(quant_model.graph, [model.op]) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index f3c90eda83..d2ee113a5d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,14 +17,11 @@ import numpy as np import torch -import tosa_reference_model - from torch.export import ExportedProgram from torch.fx.node import Node -from tosa import TosaGraph logger = logging.getLogger(__name__) -logger.setLevel(logging.CRITICAL) +logger.setLevel(logging.WARNING) class QuantizationParams: @@ -170,7 +167,7 @@ def __init__( ): self.intermediate_path = intermediate_path self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" - assert self.intermediate_path is None or os.path.exists( + assert os.path.exists( self.intermediate_path ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}" @@ -326,46 +323,7 @@ def run_corstone( tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) output_shape = self.output_node.args[0][0].meta["val"].shape tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape) - return tosa_ref_output - - def run_tosa_graph( - self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor] - ) -> torch.Tensor: - """Runs the TOSA reference model with inputs and returns the result.""" - data_np = [ - prep_data_for_save( - input, self.is_quantized, self.input_names[i], self.qp_input[i] - ) - for i, input in enumerate(inputs) - ] - # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training. - tosa_profile = 0 if self.is_quantized else 1 - debug_mode = "ALL" if logger.level <= logging.DEBUG else None - outputs, status = tosa_reference_model.run( - graph, - data_np, - verbosity=_tosa_refmodel_loglevel(logger.level), - tosa_profile=tosa_profile, - initialize_variable_tensor_from_numpy=1, # True - debug_mode=debug_mode, - ) - - assert ( - status == tosa_reference_model.GraphStatus.TOSA_VALID - ), "Non-valid TOSA given to reference model." 
- - outputs_torch = [] - for output in outputs: - output = output.astype(np.float32) - if self.is_quantized: - # Need to dequant back to FP32 for comparison with torch output - quant_param = self.qp_output - assert ( - quant_param is not None - ), "There are no quantization parameters, check output parameters" - output = (output - quant_param.zp) * quant_param.scale - outputs_torch.append(torch.from_numpy(output)) - return tuple(outputs_torch) + return [tosa_ref_output] def run_tosa_ref_model( self, @@ -450,13 +408,21 @@ def run_tosa_ref_model( assert ( shutil.which(self.tosa_ref_model_path) is not None ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}" - + loglevel_map = { + logging.INFO: "INFO", + logging.CRITICAL: "LOW", + logging.ERROR: "LOW", + logging.WARNING: "MED", + logging.DEBUG: "HIGH", + logging.NOTSET: "MED", + } + clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0) cmd_ref_model = [ self.tosa_ref_model_path, "--test_desc", desc_file_path, "-l", - _tosa_refmodel_loglevel(logger.level), + loglevel_map[clamped_logging_level], ] _run_cmd(cmd_ref_model) @@ -492,10 +458,7 @@ def run_tosa_ref_model( def prep_data_for_save( - data: torch.Tensor, - is_quantized: bool, - input_name: str, - quant_param: QuantizationParams, + data, is_quantized: bool, input_name: str, quant_param: QuantizationParams ): data_np = np.array(data.detach(), order="C").astype( f"{data.dtype}".replace("torch.", "") @@ -639,19 +602,3 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: pass return json_out - - -def _tosa_refmodel_loglevel(loglevel: int) -> str: - """Converts a logging loglevel to tosa_reference_model logginglevel, - returned as string. - """ - loglevel_map = { - logging.INFO: "INFO", - logging.CRITICAL: "LOW", - logging.ERROR: "LOW", - logging.WARNING: "MED", - logging.DEBUG: "HIGH", - logging.NOTSET: "MED", - } - clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) - return loglevel_map[clamped_logging_level] diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 834e177b7d..14a9d1df41 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -39,10 +39,11 @@ from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager from executorch.exir.backend.compile_spec_schema import CompileSpec - +from executorch.exir.backend.partitioner import Partitioner from executorch.exir.lowered_backend_module import LoweredBackendModule + from tabulate import tabulate from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec from torch.fx import Graph @@ -50,50 +51,61 @@ logger = logging.getLogger(__name__) +def _dump_lowered_modules_artifact( + path_to_dump: Optional[str], + artifact: ExecutorchProgramManager, + graph_module: torch.fx.GraphModule, +): + output = "Formated Graph Signature:\n" + output += _format_export_graph_signature( + artifact.exported_program().graph_signature + ) + + def get_output_format(lowered_module) -> str | None: + for spec in lowered_module.compile_specs: + if spec.key == "output_format": + return spec.value.decode() + return None + + for node in graph_module.graph.nodes: + if node.op == "get_attr" and node.name.startswith("lowered_module_"): + lowered_module = getattr(graph_module, node.name) + 
assert isinstance( + lowered_module, LoweredBackendModule + ), f"Attribute {node.name} must be of type LoweredBackendModule." + + output_format = get_output_format(lowered_module) + if output_format == "tosa": + tosa_fb = lowered_module.processed_bytes + to_print = dbg_tosa_fb_to_json(tosa_fb) + to_print = pformat(to_print, compact=True, indent=1) + output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" + elif output_format == "vela": + vela_cmd_stream = lowered_module.processed_bytes + output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" + else: + logger.warning( + f"No TOSA nor Vela compile spec found in compile specs of {node.name}." + ) + continue + + if not output: + logger.warning("No output to print generated from artifact.") + return + + _dump_str(output, path_to_dump) + + class Partition(tester.Partition): def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) + _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) - output = "Formated Graph Signature:\n" - output += _format_export_graph_signature( - self.artifact.exported_program().graph_signature - ) - - def get_output_format(lowered_module) -> str | None: - for spec in lowered_module.compile_specs: - if spec.key == "output_format": - return spec.value.decode() - return None - - for node in self.graph_module.graph.nodes: - if node.op == "get_attr" and node.name.startswith("lowered_module_"): - lowered_module = getattr(self.graph_module, node.name) - assert isinstance( - lowered_module, LoweredBackendModule - ), f"Attribute {node.name} must be of type LoweredBackendModule." - - output_format = get_output_format(lowered_module) - if output_format == "tosa": - tosa_fb = lowered_module.processed_bytes - to_print = dbg_tosa_fb_to_json(tosa_fb) - to_print = pformat(to_print, compact=True, indent=1) - output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" - elif output_format == "vela": - vela_cmd_stream = lowered_module.processed_bytes - output += ( - f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" - ) - else: - logger.warning( - f"No TOSA nor Vela compile spec found in compile specs of {node.name}." 
- ) - continue - if not output: - logger.warning("No output to print generated from artifact.") - return - - _dump_str(output, path_to_dump) +class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower): + def dump_artifact(self, path_to_dump: Optional[str]): + super().dump_artifact(path_to_dump) + _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) class Serialize(tester.Serialize): @@ -120,15 +132,10 @@ def __init__( super().__init__(dynamic_shapes) self.tosa_test_util = tosa_test_util - def run(self, artifact: EdgeProgramManager, inputs=None): - self.executorch_program = artifact.to_executorch(self.config) - if module := getattr( - artifact.exported_program().graph_module, "lowered_module_0", None - ): - self.buffer = module.processed_bytes - def run_artifact(self, inputs): - tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs) + tosa_output = self.tosa_test_util.run_tosa_ref_model( + inputs=inputs, + ) return tosa_output @@ -216,6 +223,26 @@ def partition(self, partition_stage: Optional[Partition] = None): partition_stage = Partition(arm_partitioner) return super().partition(partition_stage) + def to_edge_transform_and_lower( + self, + to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + ): + if to_edge_and_lower_stage is None: + if partitioners is None: + partitioners = [ArmPartitioner(compile_spec=self.compile_spec)] + to_edge_and_lower_stage = ToEdgeTransformAndLower( + partitioners, edge_compile_config + ) + else: + if partitioners is not None: + to_edge_and_lower_stage.partitioners = partitioners + if edge_compile_config is not None: + to_edge_and_lower_stage.edge_compile_conf = edge_compile_config + to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True + return super().to_edge_transform_and_lower(to_edge_and_lower_stage) + def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None): if to_executorch_stage is None: to_executorch_stage = ToExecutorch(self.runner_util) @@ -260,21 +287,23 @@ def run_method_and_compare_outputs( inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. The default is random data. """ + + edge_stage = self.stages[self.stage_name(tester.ToEdge)] + if edge_stage is None: + edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)] assert ( self.runner_util is not None ), "self.tosa_test_util is not initialized, cannot use run_method()" assert ( - self.stages[self.stage_name(tester.ToEdge)] is not None - ), "To compare outputs, at least the ToEdge stage needs to be run." + edge_stage is not None + ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run." 
stage = stage or self.cur test_stage = self.stages[stage] is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None exported_program = self.stages[self.stage_name(tester.Export)].artifact - edge_program = self.stages[ - self.stage_name(tester.ToEdge) - ].artifact.exported_program() + edge_program = edge_stage.artifact.exported_program() self.runner_util.init_run( exported_program, edge_program, @@ -321,7 +350,7 @@ def run_method_and_compare_outputs( logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}") reference_output = reference_stage.run_artifact(reference_input) - test_output = test_stage.run_artifact(test_input) + test_output = tuple(test_stage.run_artifact(test_input)) if ( is_nhwc and test_stage == self.stages[self.stage_name(tester.ToExecutorch)] @@ -338,8 +367,10 @@ def get_graph(self, stage: str | None = None) -> Graph: if stage is None: stage = self.cur artifact = self.get_artifact(stage) - if self.cur == self.stage_name(tester.ToEdge) or self.cur == self.stage_name( - Partition + if ( + self.cur == self.stage_name(tester.ToEdge) + or self.cur == self.stage_name(Partition) + or self.cur == self.stage_name(ToEdgeTransformAndLower) ): graph = artifact.exported_program().graph elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name( @@ -367,7 +398,14 @@ def dump_operator_distribution( line = "#" * 10 to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n" - if self.cur == self.stage_name(tester.Partition) and print_table: + if ( + self.cur + in ( + self.stage_name(tester.Partition), + self.stage_name(ToEdgeTransformAndLower), + ) + and print_table + ): graph_module = self.get_artifact().exported_program().graph_module if print_table: delegation_info = get_delegation_info(graph_module) diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py new file mode 100644 index 0000000000..716e8daee2 --- /dev/null +++ b/backends/arm/tosa_specification.py @@ -0,0 +1,226 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. +# + +import re +from typing import List + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from packaging.version import Version + + +class TosaSpecification: + """ + This class implements a representation of the TOSA specification + (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile + (with extension) and a level (8k). + For 0.80 releases the profile is BI or MI, with u55 handled as an unofficial extension + For 1.00 releases the profile is INT or FP, and the extensions are for + INT: int16, int4, var, cf + FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf + + The TOSA specification is encoded in the string representation + TOSA-major.minor.patch+profile[+level][+extensions] + + For 0.80, MI implies BI, while for 1.0 the profiles have to be specified explicitly. + + Profiles are uppercase letters, while extensions and the level are lowercase. + """ + + version: Version + + def support_integer(self) -> bool: + """ + Returns true if any integer operations are supported for the specification.
+ """ + raise NotImplementedError + + def support_float(self) -> bool: + """ + Returns true if any float operations are supported for the specification. + """ + raise NotImplementedError + + def __init__(self, version: Version): + self.version = version + + @staticmethod + def create_from_compilespecs( + compile_specs: List[CompileSpec], + ) -> "TosaSpecification": + """ + Search the CompileSpec list for 'tosa_version' and instantiate a + class from the found value or return None on failure. + """ + for spec in compile_specs: + if spec.key == "tosa_version": + return TosaSpecification.create_from_string(spec.value.decode()) + raise ValueError( + "No TOSA version key found in any of the supplied CompileSpecs" + ) + + @staticmethod + def create_from_string(repr: str) -> "TosaSpecification": + """ + Creates a TOSA specification class from a string representation: + TOSA-0.80.0+MI + TOSA-0.80.0+BI+8k + TOSA-0.80.0+BI+u55 # Ethos-U55 extension to handle TOSA subset + TOSA-0.90.0+MI + TOSA-1.00.0+INT+FP+int4+cf + """ + + pattern = r"^(TOSA)-([\d.]+)\+(.+)$" + match = re.match(pattern, repr) + if match: + name = match.group(1) + version = Version(match.group(2)) + extras = match.group(3).split("+") + if name != "TOSA": + raise ValueError(f"Malformed TOSA specification representation: {repr}") + match version: + case _ if version.major == 0 and version.minor == 80: + return Tosa_0_80(version, extras) + case _ if version.major == 1 and version.minor == 0: + return Tosa_1_00(version, extras) + case _: + raise ValueError(f"Wrong TOSA version: {version} from {repr}") + + raise ValueError(f"Failed to parse TOSA specification representation: {repr}") + + +class Tosa_0_80(TosaSpecification): + profile: str + level_8k: bool + is_U55_subset: bool + available_profiles = ["BI", "MI"] # MT is not defined + + def __init__(self, version: Version, extras: List[str]): + super().__init__(version) + assert version >= Version("0.80") and version < Version("0.90") + + # Check that we only have one profile in the extensions list + if [e in Tosa_0_80.available_profiles for e in extras].count(True) != 1: + raise ValueError( + f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found." 
+ ) + + # The list contains one profile at most, so pick it + self.profile = [e for e in extras if e in Tosa_0_80.available_profiles][0] + extras.remove(self.profile) + + self.level_8k = "8k" in extras + if self.level_8k: + extras.remove("8k") + self.is_U55_subset = "u55" in extras + if self.is_U55_subset: + extras.remove("u55") + + if len(extras) > 0: + raise ValueError(f"Unhandled extras found: {extras}") + + def __repr__(self): + extensions = "" + if self.level_8k: + extensions += "+8K" + if self.is_U55_subset: + extensions += "+u55" + return f"TOSA-{str(self.version)}+{self.profile}{extensions}" + + def __hash__(self) -> int: + return hash(str(self.version) + self.profile) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Tosa_0_80): + return (self.version == other.version) and (self.profile == other.profile) + return False + + def support_integer(self): + return True + + def support_float(self): + return self.profile == "MI" + + +class Tosa_1_00(TosaSpecification): + profiles: List[str] + level_8k: bool + extensions: List[str] + + available_profiles = ["INT", "FP"] + valid_extensions = { + "INT": ["int16", "int4", "var", "cf"], + "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], + } + + def __init__(self, version: Version, extras: List[str]): + super().__init__(version) + + # Check that we have at least one profile in the extensions list + if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0: + raise ValueError( + f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}." + ) + + # and not more than number of available profiles + if [e in Tosa_1_00.available_profiles for e in extras].count(True) > len( + Tosa_1_00.available_profiles + ): + raise ValueError( + f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}." 
+ ) + + # The list contains at least one profile, so pick them + self.profiles = [e for e in extras if e in Tosa_1_00.available_profiles] + for p in self.profiles: + extras.remove(p) + + self.level_8k = "8k" in extras + if self.level_8k: + extras.remove("8k") + + combined_extensions = [] + for p in self.profiles: + combined_extensions += Tosa_1_00.valid_extensions[p] + + if not all(e in combined_extensions for e in extras): + raise ValueError( + f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {extras}" + ) + + # all the remaining extras are treated as extensions + self.extensions = extras + + def _get_profiles_string(self) -> str: + return "".join(["+" + p for p in self.profiles]) + + def _get_extensions_string(self) -> str: + return "".join(["+" + e for e in self.extensions]) + + def __repr__(self): + return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_extensions_string()}" + + def __hash__(self) -> int: + return hash(str(self.version) + self._get_profiles_string()) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Tosa_1_00): + return (self.version == other.version) and ( + self._get_profiles_string() == other._get_profiles_string() + ) + return False + + def support_integer(self): + return "INT" in self.profiles + + def support_float(self): + return "FP" in self.profiles diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index cfafac1676..c91d89b1b9 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -21,6 +21,7 @@ is_quant_node, q_op, ) +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -130,31 +131,6 @@ def get_output_node(node: Node) -> Node: return list(node.users)[0] -# Helper function to do broadcasting -# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_broadcasting -def broadcast_shapes(shape1, shape2): - assert len(shape1) == len(shape2), "broadcast_shapes::shapes must have same ranks" - - need_broadcasting = False - for val1, val2 in zip(shape1, shape2): - if val1 != val2: - need_broadcasting = True - if not need_broadcasting: - return shape1 - - broadcasted_shape = list(shape1) - shape2 = list(shape2) - for idx, _ in enumerate(broadcasted_shape): - if broadcasted_shape[idx] == 1: - broadcasted_shape[idx] = shape2[idx] - else: - assert not ( - shape2[idx] != 1 and shape2[idx] != broadcasted_shape[idx] - ), "broadcast_shapes::broadcast shape mismatch" - - return broadcasted_shape - - """ TOSA reshape returns a tensor with the same type/values as the input. No data conversion happens during a reshape operation.
""" @@ -165,36 +141,6 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr) -def is_permute_node_before_addmm(node): - return ( - node.target == exir_ops.edge.aten.permute_copy.default - and list(node.users)[0].target == exir_ops.edge.aten.addmm.default - ) - - -def is_bias_node_for_quantized_addmm(node): - consumer_node = list(node.users)[0] - # consumer node is addmm - is_rank2_linear_bias = ( - consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == q_op - ) - - # rank>2 linear layers - # consumer_consumer node is view_copy - is_rank_greater_than_2_linear_bias = False - if ( - consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - is_rank_greater_than_2_linear_bias = ( - list(consumer_consumer_node.users)[0].target == q_op - ) - - return is_rank2_linear_bias or is_rank_greater_than_2_linear_bias - - def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] return ( @@ -290,6 +236,7 @@ def process_call_function( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, node_visitors: Dict[str, NodeVisitor], + tosa_spec: TosaSpecification, ): # Unpack arguments and convert inputs = getNodeArgs(node) @@ -299,11 +246,7 @@ def process_call_function( tosa_graph.currRegion.currBasicBlock.addTensor( output.name, - ( - tosa_shape(inputs[0].shape, inputs[0].dim_order) - if is_permute_node_before_addmm(node) - else tosa_shape(output.shape, output.dim_order) - ), + (tosa_shape(output.shape, output.dim_order)), map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, ) @@ -319,7 +262,7 @@ def process_call_function( is_quant_node(node), ) else: - raise RuntimeError(f"Unknown operator {node.target}") + raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") def expand_dims( diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 8456c50f6c..9876e59dbf 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -62,6 +62,21 @@ python_library( ], ) +python_library( + name = "pass_utils", + srcs = [ + "pass_utils.py", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/exir/passes:lib", + "//executorch/exir/passes:spec_prop_pass", + ], +) + python_library( name = "ops_registrations", srcs = [ diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index fce6ce5736..5e852b369d 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -50,7 +50,11 @@ "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) 
out) -> Tensor(a!)" +) +lib.define( + "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " + "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" ) lib.define( @@ -129,6 +133,28 @@ def quantized_linear_meta( return src.new_empty(out_size, dtype=src.dtype) +@register_fake("cadence::quantized_linear.per_tensor") +def quantized_linear_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: torch.SymInt, + weight_zero_point: torch.SymInt, + out_multiplier: torch.SymInt, + out_shift: torch.SymInt, + out_zero_point: torch.SymInt, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + @register_fake("cadence::quantized_conv") def quantized_conv_meta( input: torch.Tensor, diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py new file mode 100644 index 0000000000..3aa6f48a31 --- /dev/null +++ b/backends/cadence/aot/pass_utils.py @@ -0,0 +1,91 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from dataclasses import dataclass +from typing import Callable, Optional, Set, Union + +import torch +from executorch.backends.cadence.aot.utils import get_edge_overload_packet + +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket + +from executorch.exir.pass_base import ExportPass +from torch._ops import OpOverloadPacket + + +# Is an overlap in tensor lifetime and storage allowed at the current opt level? +# We allow overlap at opt level >= 2. +def allow_lifetime_and_storage_overlap(opt_level: int) -> bool: + return opt_level >= 2 + + +# A dataclass that stores the attributes of an ExportPass. +@dataclass +class CadencePassAttribute: + opt_level: Optional[int] = None + debug_pass: bool = False + + +# A dictionary that maps an ExportPass to its attributes. +_ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} + + +def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute: + return _ALL_CADENCE_PASSES[p] + + +# A decorator that registers a pass. +def register_cadence_pass( + pass_attribute: CadencePassAttribute, +) -> Callable[[ExportPass], ExportPass]: + def wrapper(cls: ExportPass) -> ExportPass: + _ALL_CADENCE_PASSES[cls] = pass_attribute + return cls + + return wrapper + + +def get_all_available_cadence_passes() -> Set[ExportPass]: + return set(_ALL_CADENCE_PASSES.keys()) + + +# Create a new filter to filter out relevant passes from all Jarvis passes. +def create_cadence_pass_filter( + opt_level: int, debug: bool = False +) -> Callable[[ExportPass], bool]: + def _filter(p: ExportPass) -> bool: + pass_attribute = get_cadence_pass_attribute(p) + return ( + pass_attribute.opt_level is not None + and pass_attribute.opt_level <= opt_level + and (not pass_attribute.debug_pass or debug) + ) + + return _filter + + +# Return the overload packet for the edge or torch op. 
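For context, a minimal sketch of how the registration decorator and filter defined above in pass_utils.py might be used; the import path follows this new file's location, and _IllustrativePass plus the chosen opt level are illustrative assumptions rather than part of this change:

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    create_cadence_pass_filter,
    get_all_available_cadence_passes,
    register_cadence_pass,
)
from executorch.exir.pass_base import ExportPass


# Register a hypothetical pass so it becomes discoverable by the filter.
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class _IllustrativePass(ExportPass):
    pass


# Keep only passes enabled at opt level 2; debug-only passes are skipped unless debug=True.
cadence_filter = create_cadence_pass_filter(opt_level=2, debug=False)
enabled_passes = [p for p in get_all_available_cadence_passes() if cadence_filter(p)]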
+def get_overload_packet( + op: Union[Callable[..., str], str], +) -> Union[OpOverloadPacket, EdgeOpOverloadPacket, None]: + return ( + get_edge_overload_packet(op) + if isinstance(op, EdgeOpOverload) + else getattr(op, "overloadpacket", None) + ) + + +# Get the list of node names in a graph module (only for "call_function" ops and +# EdgeOpOverload targets). This should be used only after to_edge is called. +def get_node_names_list_from_gm( + graph_module: torch.fx.GraphModule, +) -> list[torch.fx.Node]: + graph_nodes = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if not isinstance(node.target, EdgeOpOverload): + continue + graph_nodes.append(node.name) + return graph_nodes diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 10e5fb176e..1b335c846b 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -165,6 +165,7 @@ void requantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); #undef typed_quantize_val #define typed_quantize_vec(dtype) \ @@ -177,6 +178,7 @@ typed_quantize_val(int16_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -186,6 +188,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); #undef typed_dequantize_val #define typed_dequantize_vec(dtype) \ @@ -198,6 +201,7 @@ typed_dequantize_val(int16_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index 18381a26e0..996d753c59 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -41,6 +41,9 @@ void dequantize_per_tensor_out( } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index c65d62968f..1078b5716c 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -44,6 +44,10 @@ void quantize_per_tensor_out( int16_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + cadence::impl::HiFi::kernels::quantize( + out_data, input_data, 1. 
/ scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp index 4d4ff26c3f..faac3d7cb2 100644 --- a/backends/cadence/reference/kernels/kernels.cpp +++ b/backends/cadence/reference/kernels/kernels.cpp @@ -65,6 +65,7 @@ void dequantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); typed_quantize_val(int32_t); #undef typed_quantize_val @@ -78,6 +79,7 @@ typed_quantize_val(int32_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -86,6 +88,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); typed_dequantize_val(int32_t); #undef typed_dequantize_val @@ -99,6 +102,7 @@ typed_dequantize_val(int32_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp index aef730bfd1..b49c045b94 100644 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -37,6 +37,14 @@ void dequantize_per_tensor_out( const int8_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp index 0d7ff0bc7e..ad5fa791b5 100644 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -39,6 +39,14 @@ void quantize_per_tensor_out( int8_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. 
/ scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 3075d992d5..e718c52fdc 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -180,7 +180,9 @@ def get_compile_spec( spec_builder = None if target == "TOSA": spec_builder = ( - ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True) + ArmCompileSpecBuilder() + .tosa_compile_spec("TOSA-0.80.0+BI") + .set_permute_memory_format(True) ) elif "ethos-u55" in target: spec_builder = ( diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 43f7d48b83..583237729d 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -88,7 +88,7 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" -tosa_reference_model_rev="ef31e7222e99cb1c24b2aff9fc52b2d609612283" +tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5" ######## ### Mandatory user args @@ -227,13 +227,30 @@ function setup_tosa_reference_model() { cd reference_model git checkout ${tosa_reference_model_rev} git submodule update --init --recursive + cd .. + fi + cd reference_model + mkdir -p build + cd build + cmake .. + + # make use of half the cores for building + if [[ "${OS}" == "Linux" ]]; then + n=$(( $(nproc) / 2 )) + elif [[ "${OS}" == "Darwin" ]]; then + n=$(( $(sysctl -n hw.logicalcpu) / 2 )) + else + n=1 fi - echo "pip installing reference_model..." - repo_dir="${root_dir}/reference_model" - cd $repo_dir - pip install . + if [[ "$n" -lt 1 ]]; then + n=1 + fi + make -j"${n}" + cd reference_model + tosa_bin_path=`pwd` + echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}" } function setup_vela() { diff --git a/exir/pass_base.py b/exir/pass_base.py index db6bef8e3f..9c97921f51 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -318,7 +318,11 @@ def call_function( if target == operator.getitem: value, key = args return self.callback.call_getitem(value, key, meta) - elif getattr(target, "__module__", None) in {"_operator", "math"}: + elif getattr(target, "__module__", None) in { + "_operator", + "builtins", + "math", + }: assert callable(target) return self.callback.call_sym(target, args, meta) elif target in _TORCH_SYM_OPS: diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 7a0623040f..fdb954010c 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -339,7 +339,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: self.call(get_submodule(node.args[0])) self.call(get_submodule(node.args[1])) continue - elif getattr(target, "__module__", None) == "_operator": + elif getattr(target, "__module__", None) in ("builtins", "_operator"): continue elif target in to_out_var_skiplist: continue diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 4af233aaa6..fa1c2e6913 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
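The additions below register builtins.round and math.ceil (alongside the existing math.trunc) as ExecuTorch scalar prim ops, so Python rounding applied to symbolic sizes can be captured and lowered. A minimal sketch of the kind of model that would exercise this path; the module name, shapes, and dim name are illustrative assumptions:

import math

import torch


class CeilHalf(torch.nn.Module):
    def forward(self, x):
        # With a dynamic batch dimension, x.size(0) is symbolic, so math.ceil
        # stays in the exported graph as a call on a symbolic value rather
        # than folding into a constant.
        n = math.ceil(x.size(0) / 2)
        return x[:n]


ep = torch.export.export(
    CeilHalf(),
    (torch.randn(8, 4),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)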
 
+import builtins
 import math
 import operator
 
-from typing import Dict, Set, Union
+from typing import Any, Dict, Set, Union
 
 # necessary to ensure the ops are registered
 import torch
 
@@ -94,12 +95,24 @@ def neg(a: _SymScalar) -> _SymScalar:
     return -a  # pyre-ignore
 
 
+@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar")
+def ceil(a: _SymScalar) -> _SymScalar:
+    return math.ceil(a)  # pyre-ignore
+
+
+@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar")
+def builtin_round(a: _SymScalar) -> _SymScalar:
+    return round(a)  # pyre-ignore
+
+
 @bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
 def trunc(a: _SymScalar) -> _SymScalar:
     return math.trunc(a)  # pyre-ignore
 
 
-_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
+_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
+    builtins.round: ops.backend.executorch_prim.round.Scalar,
+    math.ceil: ops.backend.executorch_prim.ceil.Scalar,
     math.trunc: ops.backend.executorch_prim.trunc.Scalar,
     operator.sub: ops.backend.executorch_prim.sub.Scalar,
     operator.mul: ops.backend.executorch_prim.mul.Scalar,
diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt
index c96cfeb5d7..70f21f2751 100644
--- a/extension/android/CMakeLists.txt
+++ b/extension/android/CMakeLists.txt
@@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
 find_package(executorch CONFIG REQUIRED)
 target_link_options_shared_lib(executorch)
 
-add_library(executorch_jni SHARED jni/jni_layer.cpp)
+add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp)
 
 set(link_libraries)
 list(
@@ -146,7 +146,7 @@ if(EXECUTORCH_JNI_CUSTOM_LIBRARY)
 endif()
 
 if(EXECUTORCH_BUILD_LLAMA_JNI)
-  target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp)
+  target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
   list(APPEND link_libraries llama_runner llava_runner)
   target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)
   add_subdirectory(
diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK
index 6f269739c0..e1bf26fef2 100644
--- a/extension/android/jni/BUCK
+++ b/extension/android/jni/BUCK
@@ -1,5 +1,6 @@
 load("@fbsource//tools/build_defs/android:fb_android_cxx_library.bzl", "fb_android_cxx_library")
 load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
 
 oncall("executorch")
@@ -25,7 +26,7 @@ executorch_generated_lib(
 
 fb_android_cxx_library(
     name = "executorch_jni",
-    srcs = ["jni_layer.cpp"],
+    srcs = ["jni_layer.cpp", "log.cpp"],
     headers = ["jni_layer_constants.h"],
     allow_jni_merging = False,
     compiler_flags = [
@@ -36,6 +37,7 @@ fb_android_cxx_library(
     soname = "libexecutorch.$(ext)",
     visibility = ["PUBLIC"],
     deps = [
+        ":log_provider_static",
         "//fbandroid/libraries/fbjni:fbjni",
         "//fbandroid/native/fb:fb",
         "//third-party/glog:glog",
@@ -49,7 +51,7 @@ fb_android_cxx_library(
 
 fb_android_cxx_library(
     name = "executorch_jni_full",
-    srcs = ["jni_layer.cpp"],
+    srcs = ["jni_layer.cpp", "log.cpp"],
     headers = ["jni_layer_constants.h"],
     allow_jni_merging = False,
     compiler_flags = [
@@ -60,6 +62,7 @@ fb_android_cxx_library(
     soname = "libexecutorch.$(ext)",
     visibility = ["PUBLIC"],
     deps = [
+        ":log_provider_static",
         ":generated_op_lib_optimized_static",
         "//fbandroid/libraries/fbjni:fbjni",
"//fbandroid/native/fb:fb", @@ -88,6 +91,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", @@ -101,3 +105,18 @@ fb_android_cxx_library( "//xplat/executorch/extension/threadpool:threadpool_static", ], ) + +runtime.cxx_library( + name = "log_provider", + srcs = ["log.cpp"], + exported_headers = ["log.h"], + compiler_flags = [ + "-frtti", + "-fexceptions", + "-Wno-unused-variable", + ], + deps = [ + "//executorch/runtime/core:core", + ], + visibility = ["@EXECUTORCH_CLIENTS"], +) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 479da28806..ddba8462b9 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -17,6 +17,7 @@ #include "jni_layer_constants.h" +#include #include #include #include @@ -36,76 +37,6 @@ using namespace executorch::extension; using namespace torch::executor; -#ifdef __ANDROID__ -#include -#include -#include - -// Number of entries to store in the in-memory log buffer. -const size_t log_buffer_length = 16; - -struct log_entry { - et_timestamp_t timestamp; - et_pal_log_level_t level; - std::string filename; - std::string function; - size_t line; - std::string message; - - log_entry( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) - : timestamp(timestamp), - level(level), - filename(filename), - function(function), - line(line), - message(message, length) {} -}; - -namespace { -std::vector log_buffer_; -std::mutex log_buffer_mutex_; -} // namespace - -// For Android, write to logcat -void et_pal_emit_log_message( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) { - std::lock_guard guard(log_buffer_mutex_); - - while (log_buffer_.size() >= log_buffer_length) { - log_buffer_.erase(log_buffer_.begin()); - } - - log_buffer_.emplace_back( - timestamp, level, filename, function, line, message, length); - - int android_log_level = ANDROID_LOG_UNKNOWN; - if (level == 'D') { - android_log_level = ANDROID_LOG_DEBUG; - } else if (level == 'I') { - android_log_level = ANDROID_LOG_INFO; - } else if (level == 'E') { - android_log_level = ANDROID_LOG_ERROR; - } else if (level == 'F') { - android_log_level = ANDROID_LOG_FATAL; - } - - __android_log_print(android_log_level, "ExecuTorch", "%s", message); -} -#endif - namespace executorch::extension { class TensorHybrid : public facebook::jni::HybridClass { public: @@ -437,24 +368,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref> readLogBuffer() { #ifdef __ANDROID__ - std::lock_guard guard(log_buffer_mutex_); - - const auto size = log_buffer_.size(); - facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray(size); - - for (auto i = 0u; i < size; i++) { - const auto& entry = log_buffer_[i]; - // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE". 
-      std::stringstream ss;
-      ss << "[" << entry.timestamp << " " << entry.function << " "
-         << entry.filename << ":" << entry.line << "] "
-         << static_cast<char>(entry.level) << " " << entry.message;
-
-      facebook::jni::local_ref<jstring> jstr_message =
-          facebook::jni::make_jstring(ss.str().c_str());
-      (*ret)[i] = jstr_message;
-    }
+
+    facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret;
+
+    access_log_buffer([&](std::vector<log_entry>& buffer) {
+      const auto size = buffer.size();
+      ret = facebook::jni::JArrayClass<jstring>::newArray(size);
+      for (auto i = 0u; i < size; i++) {
+        const auto& entry = buffer[i];
+        // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
+        // MESSAGE".
+        std::stringstream ss;
+        ss << "[" << entry.timestamp << " " << entry.function << " "
+           << entry.filename << ":" << entry.line << "] "
+           << static_cast<char>(entry.level) << " " << entry.message;
+
+        facebook::jni::local_ref<jstring> jstr_message =
+            facebook::jni::make_jstring(ss.str().c_str());
+        (*ret)[i] = jstr_message;
+      }
+    });
 
     return ret;
 #else
@@ -468,10 +401,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
       makeNativeMethod("forward", ExecuTorchJni::forward),
       makeNativeMethod("execute", ExecuTorchJni::execute),
       makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
-
-#ifdef __ANDROID__
       makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
-#endif
   });
 }
 };
diff --git a/extension/android/jni/log.cpp b/extension/android/jni/log.cpp
new file mode 100644
index 0000000000..663198e127
--- /dev/null
+++ b/extension/android/jni/log.cpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "log.h"
+
+#ifdef __ANDROID__
+
+#include
+#include
+#include
+#include
+
+using executorch::extension::log_entry;
+
+// Number of entries to store in the in-memory log buffer.
+const size_t log_buffer_length = 16;
+
+namespace {
+std::vector<log_entry> log_buffer_;
+std::mutex log_buffer_mutex_;
+} // namespace
+
+// For Android, write to logcat
+void et_pal_emit_log_message(
+    et_timestamp_t timestamp,
+    et_pal_log_level_t level,
+    const char* filename,
+    const char* function,
+    size_t line,
+    const char* message,
+    size_t length) {
+  std::lock_guard<std::mutex> guard(log_buffer_mutex_);
+
+  while (log_buffer_.size() >= log_buffer_length) {
+    log_buffer_.erase(log_buffer_.begin());
+  }
+
+  log_buffer_.emplace_back(
+      timestamp, level, filename, function, line, message, length);
+
+  int android_log_level = ANDROID_LOG_UNKNOWN;
+  if (level == 'D') {
+    android_log_level = ANDROID_LOG_DEBUG;
+  } else if (level == 'I') {
+    android_log_level = ANDROID_LOG_INFO;
+  } else if (level == 'E') {
+    android_log_level = ANDROID_LOG_ERROR;
+  } else if (level == 'F') {
+    android_log_level = ANDROID_LOG_FATAL;
+  }
+
+  __android_log_print(android_log_level, "ExecuTorch", "%s", message);
+}
+
+namespace executorch::extension {
+
+void access_log_buffer(std::function<void(std::vector<log_entry>&)> accessor) {
+  std::lock_guard<std::mutex> guard(log_buffer_mutex_);
+  accessor(log_buffer_);
+}
+
+} // namespace executorch::extension
+
+#endif
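The new `log_provider` target exposes the same mutex-guarded ring buffer (only the 16 most recent entries are kept) to other native code via `access_log_buffer`. A small consumer sketch, Android-only like the provider itself (standalone illustration, not part of this patch; `dump_recent_logs` is a made-up name):

```cpp
#include <cstdio>
#include <vector>

#include "log.h"  // executorch::extension::log_entry / access_log_buffer

#ifdef __ANDROID__
// Illustrative helper: print whatever is currently held in the in-memory
// buffer, using the same "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE"
// shape as readLogBuffer() above.
void dump_recent_logs() {
  executorch::extension::access_log_buffer(
      [](std::vector<executorch::extension::log_entry>& buffer) {
        for (const auto& entry : buffer) {
          std::printf(
              "[%llu %s %s:%zu] %c %s\n",
              static_cast<unsigned long long>(entry.timestamp),
              entry.function.c_str(),
              entry.filename.c_str(),
              entry.line,
              static_cast<char>(entry.level),
              entry.message.c_str());
        }
      });
}
#endif
```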
diff --git a/extension/android/jni/log.h b/extension/android/jni/log.h
new file mode 100644
index 0000000000..4389b1d61a
--- /dev/null
+++ b/extension/android/jni/log.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace executorch::extension {
+struct log_entry {
+  et_timestamp_t timestamp;
+  et_pal_log_level_t level;
+  std::string filename;
+  std::string function;
+  size_t line;
+  std::string message;
+
+  log_entry(
+      et_timestamp_t timestamp,
+      et_pal_log_level_t level,
+      const char* filename,
+      const char* function,
+      size_t line,
+      const char* message,
+      size_t length)
+      : timestamp(timestamp),
+        level(level),
+        filename(filename),
+        function(function),
+        line(line),
+        message(message, length) {}
+};
+
+void access_log_buffer(std::function<void(std::vector<log_entry>&)> accessor);
+} // namespace executorch::extension
diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp
index 5755ab8d66..38901bb840 100644
--- a/kernels/prim_ops/register_prim_ops.cpp
+++ b/kernels/prim_ops/register_prim_ops.cpp
@@ -303,6 +303,51 @@ static Kernel prim_ops[] = {
           }
         }),
 
+    // ceil.Scalar(Scalar a) -> Scalar
+    Kernel(
+        "executorch_prim::ceil.Scalar",
+        [](KernelRuntimeContext& context, EValue** stack) {
+          (void)context;
+          EValue& a = *stack[0];
+          EValue& out = *stack[1];
+          if (a.isDouble()) {
+            out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
+          } else {
+            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
+          }
+        }),
+
+    // round.Scalar(Scalar a) -> Scalar
+    Kernel(
+        "executorch_prim::round.Scalar",
+        [](KernelRuntimeContext& context, EValue** stack) {
+          (void)context;
+          EValue& a = *stack[0];
+          EValue& out = *stack[1];
+          if (a.isDouble()) {
+            // Round half to even to match Python round(). Need an explicit
+            // implementation as not all platforms support fenv rounding modes.
+            // See
+            // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html
+            const auto val = a.toDouble();
+            const auto r = round(val);
+            const auto d = r - val;
+            auto res = 0.0;
+
+            if (std::abs(d) != 0.5) {
+              res = r;
+            } else if (fmod(r, 2.0) == 0.0) {
+              res = r;
+            } else {
+              res = val - d;
+            }
+
+            out = EValue(static_cast<int64_t>(res));
+          } else {
+            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
+          }
+        }),
+
     // trunc.Scalar(Scalar a) -> Scalar
     Kernel(
         "executorch_prim::trunc.Scalar",
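The `round.Scalar` kernel above intentionally diverges from plain `std::round` on halfway values so the runtime matches Python's banker's rounding. A standalone spot-check of the same three-way branch (illustrative only, not part of this patch):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Same branch structure as the round.Scalar kernel above: fall back to
// round-half-away-from-zero unless the value is exactly halfway, in which
// case pick the even neighbour.
int64_t round_half_to_even(double val) {
  const double r = std::round(val);  // halfway cases go away from zero
  const double d = r - val;
  double res = 0.0;
  if (std::abs(d) != 0.5) {
    res = r;  // not a halfway case
  } else if (std::fmod(r, 2.0) == 0.0) {
    res = r;  // away-from-zero result is already even
  } else {
    res = val - d;  // step back to the even neighbour
  }
  return static_cast<int64_t>(res);
}

int main() {
  // Matches Python: round(0.5) == 0, round(1.5) == 2, round(-1.5) == -2.
  assert(round_half_to_even(0.5) == 0);
  assert(round_half_to_even(1.5) == 2);
  assert(round_half_to_even(-1.5) == -2);
  assert(round_half_to_even(2.5) == 2);
  assert(round_half_to_even(9.999999) == 10);
  return 0;
}
```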
diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp
index 3581a470da..ab6bd28e6c 100644
--- a/kernels/prim_ops/test/prim_ops_test.cpp
+++ b/kernels/prim_ops/test/prim_ops_test.cpp
@@ -503,6 +503,47 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
       getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
 }
 
+TEST_F(RegisterPrimOpsTest, TestCeil) {
+  std::array<double, 10> inputs = {
+      0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
+  std::array<int64_t, 10> expected = {0, 1, 1, 1, 1, 2, 0, -1, -1, 10};
+
+  for (auto i = 0; i < inputs.size(); i++) {
+    EValue values[2];
+    values[0] = EValue(inputs[i]);
+    values[1] = EValue(0.0);
+
+    EValue* stack[2];
+    for (size_t j = 0; j < 2; j++) {
+      stack[j] = &values[j];
+    }
+
+    getOpsFn("executorch_prim::ceil.Scalar")(context, stack);
+    EXPECT_EQ(stack[1]->toInt(), expected[i]);
+  }
+}
+
+TEST_F(RegisterPrimOpsTest, TestRound) {
+  // Note that Python uses round-to-even for halfway values.
+  std::array<double, 10> inputs = {
+      0.0, 0.25, 0.5, 0.75, 1.0, 1.5, -0.5, -1.0, -1.5, 9.999999};
+  std::array<int64_t, 10> expected = {0, 0, 0, 1, 1, 2, 0, -1, -2, 10};
+
+  for (auto i = 0; i < inputs.size(); i++) {
+    EValue values[2];
+    values[0] = EValue(inputs[i]);
+    values[1] = EValue(0.0);
+
+    EValue* stack[2];
+    for (size_t j = 0; j < 2; j++) {
+      stack[j] = &values[j];
+    }
+
+    getOpsFn("executorch_prim::round.Scalar")(context, stack);
+    EXPECT_EQ(stack[1]->toInt(), expected[i]);
+  }
+}
+
 TEST_F(RegisterPrimOpsTest, TestTrunc) {
   std::array<double, 10> inputs = {
       0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};