diff --git a/.ci/scripts/test_llama_runner_eager.sh b/.ci/scripts/test_llama_runner_eager.sh index 537d835ba1..0f2cb7b376 100644 --- a/.ci/scripts/test_llama_runner_eager.sh +++ b/.ci/scripts/test_llama_runner_eager.sh @@ -42,11 +42,12 @@ run_and_verify() { -d fp32 \ --max_seq_length 32 \ --temperature 0 \ + --show_tokens \ --prompt "Once upon a time," > result.txt # Verify result.txt RESULT=$(cat result.txt) - EXPECTED_RESULT="there was a little girl" + EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263" if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then echo "Actual result: ${RESULT}" echo "Success" diff --git a/.github/workflows/ghstack_land.yml b/.github/workflows/ghstack_land.yml index 12782c66dd..e3b02d2a94 100644 --- a/.github/workflows/ghstack_land.yml +++ b/.github/workflows/ghstack_land.yml @@ -5,6 +5,7 @@ on: branches: - 'gh/cccclai/[0-9]+/base' - 'gh/dbort/[0-9]+/base' + - 'gh/dvorjackz/[0-9]+/base' - 'gh/guangy10/[0-9]+/base' - 'gh/helunwencser/[0-9]+/base' - 'gh/jorgep31415/[0-9]+/base' diff --git a/.gitmodules b/.gitmodules index 844cd91789..6844743d73 100644 --- a/.gitmodules +++ b/.gitmodules @@ -64,3 +64,6 @@ [submodule "third-party/pybind11"] path = third-party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "third-party/ao"] + path = third-party/ao + url = https://github.com/pytorch/ao.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 156fb24e6b..6b76f27eb0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -721,10 +721,15 @@ if(EXECUTORCH_BUILD_PYBIND) -fPIC -frtti -fexceptions - # libtorch is built with the old ABI, so we need to do the same for any - # .cpp files that include torch, c10, or ATen targets. - -D_GLIBCXX_USE_CXX11_ABI=0 ) + if(EXECUTORCH_DO_NOT_USE_CXX11_ABI) + # libtorch is built with the old ABI, so we need to do the same for any + # .cpp files that include torch, c10, or ATen targets. Note that PyTorch + # nightly binary is built with _GLIBCXX_USE_CXX11_ABI set to 0 while its + # CI build sets this to 1 (default) + list(APPEND _pybind_compile_options -D_GLIBCXX_USE_CXX11_ABI=0) + endif() + # util lib add_library( util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp diff --git a/README.md b/README.md index da2cb82ef9..aded66bf40 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,11 @@ We recommend using the latest release tag from the See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code style, CI jobs, and other development topics. +To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can: +* Head to the `#executorch-general` channel for general questions, discussion, and community support. +* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development. 
+ + ## Directory Structure ``` diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 39910f0150..a73973ad04 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -1,3 +1,4 @@ +# @noautodeps load("@fbcode_macros//build_defs:python_library.bzl", "python_library") python_library( @@ -69,6 +70,18 @@ python_library( ], ) +python_library( + name = "tosa_specification", + srcs = [ + "tosa_specification.py", + ], + typing = True, + deps = [ + "fbsource//third-party/pypi/packaging:packaging", + "//executorch/exir/backend:compile_spec_schema", + ], +) + python_library( name = "tosa_utils", srcs = [ diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index ca20b03fcc..6ca59cfee2 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -7,6 +7,7 @@ python_library( deps = [ "//executorch/backends/arm:tosa_quant_utils", "//executorch/backends/arm:tosa_utils", + "//executorch/backends/xnnpack/_passes:xnnpack_passes", "//executorch/exir:lib", ], ) diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 77def9e7cd..786117e645 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -14,7 +14,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs): return args[0] +register_passable_op(torch.ops.passthrough_to_tosa._transpose) + + class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. 
tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b3ddecbc29..a72cdfd1a0 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -23,6 +23,7 @@ from executorch.backends.arm._passes.decompose_layernorm_pass import ( DecomposeLayerNormPass, ) +from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_softmaxes_pass import ( DecomposeSoftmaxesPass, @@ -43,6 +44,7 @@ from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import ( UnsqueezeScalarPlaceholdersPass, ) +from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.pass_manager import PassManager @@ -58,6 +60,7 @@ def transform_to_backend_pipeline( ): """Apply passes before transforming program to backend""" self.add_pass(CastInt64ToInt32Pass(exported_program)) + self.add_pass(RemoveGetItemPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(SizeAdjustConv2DPass()) self.add_pass(RemoveClonePass()) @@ -72,6 +75,7 @@ def transform_to_backend_pipeline( self.add_pass(ConvertSplitToSlicePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) + self.add_pass(DecomposeLinearPass()) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py new file mode 100644 index 0000000000..30767b354d --- /dev/null +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -0,0 +1,112 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeLinearPass(ExportPass): + """ + This pass decomposes linear into a Conv2D with the required view operations. + linear(x, weights, bias) becomes: + x_reshaped = view(x) + weights_reshaped = view(weights) + conv2d = conv2d(x_reshaped, weights_reshaped, bias) + output = view(conv2d) + It also inserts q/dq pairs if the linear node was quantized. 
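For reference (not part of the patch), the decomposition described in this docstring can be sanity-checked with plain torch ops. The shapes below are illustrative; the pass itself rewrites FX nodes rather than calling these functions directly:

```python
import torch
import torch.nn.functional as F

# linear(x, W, b) with x: (..., Ci) and W: (Co, Ci) is equivalent to a 1x1 conv2d
# on a (N, Ci, 1, 1) view of x with W viewed as (Co, Ci, 1, 1) kernels.
x = torch.randn(2, 3, 8)    # (..., Ci), Ci = 8
w = torch.randn(16, 8)      # (Co, Ci), Co = 16
b = torch.randn(16)

batches = x.shape[:-1].numel()        # prod of leading dims
x4d = x.reshape(batches, 8, 1, 1)     # (N, Ci, 1, 1)
w4d = w.reshape(16, 8, 1, 1)          # (Co, Ci, 1, 1)
out = F.conv2d(x4d, w4d, b)           # (N, Co, 1, 1)
out = out.reshape(*x.shape[:-1], 16)  # back to (..., Co)

torch.testing.assert_close(out, F.linear(x, w, b))
```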
+ """ + + def call(self, graph_module): + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.linear.default: + continue + args = node.args + input = args[0] + weights = args[1] + bias = args[2] if len(args) > 2 else None + output_shape = get_first_fake_tensor(node).shape + input_shape = get_first_fake_tensor(input).shape + weights_shape = get_first_fake_tensor(weights).shape + batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1 + # input has shape (..., Ci) + input_reshaped_shape = [batches, input_shape[-1], 1, 1] + # weights have shape (Co, Ci) + weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1] + + with graph_module.graph.inserting_before(node): + quantize = input.op == "call_function" and input.target == dq_op + q_params = input.args[1:] if quantize else None + # Reshape input to 4D with shape (N, Ci, 1, 1) + input_reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(input, input_reshaped_shape), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + quantize = weights.op == "call_function" and weights.target == dq_op + q_params = weights.args[1:] if quantize else None + # Reshape weights to 4D with shape (Co, Ci, 1, 1) + weights_reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(weights, weights_reshaped_shape), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + consumer_node = list(node.users)[0] + quantize = ( + consumer_node.op == "call_function" and consumer_node.target == q_op + ) + q_params = consumer_node.args[1:] if quantize else None + conv = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.convolution.default, + args=( + input_reshaped, + weights_reshaped, + bias, + [1, 1], # strides + [0, 0], # padding + [1, 1], # dilation + False, # transposed + [0, 0], # output padding + 1, # groups + ), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + with graph_module.graph.inserting_after(conv): + # Reshape output to same rank as original input with shape (..., Co) + # No need to insert q/dq pair as Conv2D node above has inserted them if + # required. 
+ output = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(conv, list(output_shape)), + kwargs={}, + ) + + node.replace_all_uses_with(output) + graph_module.graph.erase_node(node) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py index 152d5c95f6..adf2b4f491 100644 --- a/backends/arm/_passes/insert_squeeze_after_sum_pass.py +++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py @@ -8,9 +8,7 @@ import torch import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair - -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node +from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass): sum(dims, keep_dim = False) After pass: sum(dims, keep_dim = True) - (q) - (dq) squeeze(dim = dims) """ @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule): continue dim_list = cast(list[int], sum_node.args[1]) - quantized = is_quant_node(sum_node) - if quantized: - qparams = get_quant_node_args(sum_node.all_input_nodes[0]) - qparams = qparams + (torch.int8,) - else: - qparams = None # Add keep_dim = True arg to sum node. sum_node.args = sum_node.args[0:2] + (True,) @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) sum_node.replace_all_uses_with(squeeze_node) squeeze_node.args = (sum_node, dim_list) - if quantized: - sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams) graph_module.graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_conv2d_pass.py index 980ab09e59..c7bd27dcce 100644 --- a/backends/arm/_passes/size_adjust_conv2d_pass.py +++ b/backends/arm/_passes/size_adjust_conv2d_pass.py @@ -9,7 +9,7 @@ from typing import cast, Optional import torch.fx -from executorch.backends.arm.tosa_quant_utils import is_quant_node +from executorch.backends.arm.tosa_quant_utils import is_node_quantized from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch._ops import OpOverload @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule): slice_node = graph.create_node( "call_function", self.slice_op, (last_node,) + args ) - if is_quant_node(last_node): + if is_node_quantized(last_node): q_params = last_node.args[1:] dq_node = insert_q_dq_pair( graph_module.graph, slice_node, q_params diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 28af583106..47c3c2d5e5 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -20,6 +20,8 @@ from executorch.backends.arm.operators.node_visitor import get_node_visitors from executorch.backends.arm.operators.op_output import process_output from executorch.backends.arm.operators.op_placeholder import process_placeholder + +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm._passes.arm_pass_manager import ( ArmPassManager, ) # usort: skip @@ -50,6 +52,7 
@@ def __init__(self): # TODO MLETORCH-265 Remove permute_nhwc flag self.permute_nhwc = False self.quantize_io = False + self.tosa_version = None def ethosu_compile_spec( self, @@ -86,9 +89,15 @@ def ethosu_compile_spec( if extra_flags is not None: self.compiler_flags.append(extra_flags) + base_tosa_version = "TOSA-0.80.0+BI" + if "U55" in config: + # Add the Ethos-U55 extension marker + base_tosa_version += "+u55" + self.tosa_version = TosaSpecification.create_from_string(base_tosa_version) + return self - def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": + def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder": """ Generate compile spec for TOSA flatbuffer output """ @@ -96,6 +105,7 @@ def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": self.output_format is None ), f"Output format already set: {self.output_format}" self.output_format = "tosa" + self.tosa_version = TosaSpecification.create_from_string(tosa_version) return self def dump_intermediate_artifacts_to( @@ -129,6 +139,13 @@ def build(self) -> List[CompileSpec]: """ Generate a list of compile spec objects from the builder """ + assert self.tosa_version + + # Always supply a TOSA version + self.compile_spec = [ + CompileSpec("tosa_version", str(self.tosa_version).encode()) + ] + if self.output_format == "vela": self.compile_spec += [ CompileSpec("output_format", "vela".encode()), @@ -210,11 +227,18 @@ def preprocess( # noqa: C901 if not output_format: raise RuntimeError("output format is required") + tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec) + assert ( + tosa_spec is not None + ), "TOSA backend needs a TOSA version specified in the CompileSpec!" + if output_format == "vela" and len(compile_flags) == 0: # Not testing for compile_flags correctness here, just that they are # present. The compiler will give errors if they are not valid. raise RuntimeError("compile flags are required for vela output format") + logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") + # Converted output for this subgraph, serializer needs path early as it emits # const data directly. Path created and data written only in debug builds. 
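A minimal usage sketch of the builder after this change (not part of the patch): plain TOSA output now requires an explicit version string, and `build()` always prepends a `tosa_version` entry to the compile spec list:

```python
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

# Version strings follow the "TOSA-<major>.<minor>.<patch>+<profile>" form,
# e.g. "TOSA-0.80.0+BI" or "TOSA-0.80.0+MI".
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80.0+BI").build()

# build() asserts a version was set and emits CompileSpec("tosa_version", ...),
# which preprocess() later recovers via TosaSpecification.create_from_compilespecs.
assert any(spec.key == "tosa_version" for spec in compile_spec)
```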
tosa_graph = ts.TosaSerializer(artifact_path) @@ -222,13 +246,13 @@ def preprocess( # noqa: C901 exported_program=edge_program, compile_spec=compile_spec ) - node_visitors = get_node_visitors(edge_program) + node_visitors = get_node_visitors(edge_program, tosa_spec) for node in graph_module.graph.nodes: if node.op == "call_function": - process_call_function(node, tosa_graph, node_visitors) + process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": - process_placeholder(node, tosa_graph, edge_program) + process_placeholder(node, tosa_graph, edge_program, tosa_spec) elif node.op == "output": process_output(node, tosa_graph) else: diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 7309287998..ef924fa434 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -8,7 +8,7 @@ import logging import operator import os -from typing import cast, final, List +from typing import Callable, cast, final, List, Optional, Tuple import torch from executorch.backends.arm.arm_backend import ArmBackend # usort: skip @@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: supported = node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.bmm.default, @@ -49,12 +48,14 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.repeat.default, @@ -136,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags ) + + def ops_to_not_decompose( + self, + ep: ExportedProgram, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + ops_to_not_decompose = [ + torch.ops.aten.linear.default, + ] + return (ops_to_not_decompose, None) diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index fd04d5fb84..d12cc7e4df 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -1,3 +1,4 @@ +# @noautodeps load("@fbcode_macros//build_defs:python_library.bzl", "python_library") python_library( @@ -6,6 +7,7 @@ python_library( typing = True, deps = [ "//executorch/backends/arm:tosa_mapping", + "//executorch/backends/arm:tosa_specification", ], ) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a8ddf1c8f0..6e51c2c141 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -8,7 +8,6 @@ from . 
import ( # noqa node_visitor, op_add, - op_addmm, op_avg_pool2d, op_batch_norm, op_bmm, @@ -20,6 +19,7 @@ op_get_item, op_hardtanh, op_log, + op_max_pool2d, op_mm, op_mul, op_permute, diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 99fd0388e4..9e98ebcab9 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.export import ExportedProgram @@ -18,8 +19,19 @@ class NodeVisitor: Node Visitor pattern for lowering edge IR to TOSA """ - def __init__(self, exported_program: ExportedProgram): + # Add the currently supported node_visitor specs as default. + # This should be overriden in the NodeVisitor subclasses to target + # a specific TOSA version. + # When all node_visitors has been refactored to target a specific + # version, this list should be removed. + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification): self._exported_program = exported_program or None + self.tosa_spec = tosa_spec def define_node( self, @@ -33,16 +45,30 @@ def define_node( # container for all node visitors -_node_visitor_dict = {} +_node_visitor_dicts = { + TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {}, + TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {}, +} def register_node_visitor(visitor): - _node_visitor_dict[visitor.target] = visitor + for tosa_spec in visitor.tosa_specs: + _node_visitor_dicts[tosa_spec][visitor.target] = visitor + return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: node_visitors = {} - for target, visitor in _node_visitor_dict.items(): + tosa_spec = None + for arg in args: + if isinstance(arg, TosaSpecification): + tosa_spec = arg + break + + if tosa_spec is None: + raise RuntimeError("No TOSA specification supplied.") + + for target, visitor in _node_visitor_dicts[tosa_spec].items(): node_visitors[target] = visitor(*args) return node_visitors diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index ec2ade9e8a..7a71a0d2bd 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,19 +11,25 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class AddVisitor(NodeVisitor): +class AddVisitor_080_BI(NodeVisitor): target = "aten.add.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -35,33 +41,72 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: - input_nodes = 
tutils.get_two_inputs(node) + input_nodes = tutils.get_two_inputs(node) + + if not is_quant_node and not all( + tensor.meta["val"].dtype in (torch.int8, torch.int32) + for tensor in input_nodes + ): + raise RuntimeError( + f"Unexpected non quantized {AddVisitor_080_BI.target} node." + ) + needs_rescale = not ( + all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes) + and node.meta["val"].dtype == torch.int32 + ) + + if needs_rescale: # Rescale inputs to 32 bit rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( input_nodes, tosa_graph ) - # Preapre sub output tensor - broadcasted_shape = tutils.broadcast_shapes( - rescaled_inputs[0].shape, rescaled_inputs[0].shape - ) + # Prepare add output tensor + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + add_output = output + rescaled_inputs = inputs - # Do the INT32 Add - tosa_graph.addOperator( - TosaOp.Op().ADD, - [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, - ], - [add_output.name], - None, - ) + # Do the INT32 Add + tosa_graph.addOperator( + TosaOp.Op().ADD, + [ + rescaled_inputs[0].name, + rescaled_inputs[1].name, + ], + [add_output.name], + None, + ) + if needs_rescale: # Scale output back to 8 bit tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph) + + +@register_node_visitor +class AddVisitor_080_MI(AddVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + if is_quant_node: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Add lowering tosa_graph.addOperator( diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py deleted file mode 100644 index b4f782db4a..0000000000 --- a/backends/arm/operators/op_addmm.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
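With the registry now keyed by TOSA specification (see node_visitor.py above), each visitor declares the versions it supports via `tosa_specs`. A hypothetical visitor limited to the MI profile would be registered roughly like this (class and target names are made up for illustration only):

```python
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_specification import TosaSpecification


@register_node_visitor
class ExampleVisitor_080_MI(NodeVisitor):  # hypothetical, for illustration only
    target = "aten.example.default"        # hypothetical target
    # Registered only in the TOSA-0.80.0+MI dict; a BI lowering would list the
    # BI spec here (or subclass a BI visitor, as op_add.py does).
    tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80.0+MI")]

    def define_node(self, node, tosa_graph, inputs, output, is_quant_node) -> None:
        raise NotImplementedError("sketch only")
```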
- -# pyre-unsafe - -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args - -from executorch.backends.arm.tosa_utils import build_reshape -from executorch.exir.dialects._ops import ops as exir_ops -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class AddmmVisitor(NodeVisitor): - target = "aten.addmm.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - bias, input, weight = inputs - - N = input.shape[0] - input_channels = input.shape[1] - output_channels = weight.shape[1] - - input_new_shape = (N, 1, 1, input_channels) - input_reshaped = tosa_graph.addIntermediate( - input_new_shape, - ts.DType.INT8 if is_quant_node else input.dtype, - ) - - build_reshape(tosa_graph, input.name, input_new_shape, input_reshaped.name) - - weight_new_shape = (output_channels, 1, 1, input_channels) - weight_reshaped = tosa_graph.addIntermediate( - weight_new_shape, - ts.DType.INT8 if is_quant_node else weight.dtype, - ) - - build_reshape(tosa_graph, weight.name, weight_new_shape, weight_reshaped.name) - - # Get the attributes of convolution. - attr = ts.TosaSerializerAttribute() - pad_attr = [0, 0, 0, 0] - stride_attr = [1, 1] - dilation_attr = [1, 1] - - input_zp = 0 - if is_quant_node: - input_node = node.all_input_nodes[1] - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - else: - quant_node = input_node - input_zp = get_quant_node_args(quant_node).zp - attr.ConvAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - input_zp=input_zp, - weight_zp=0, - local_bound=False, - ) - - conv2d_output_shape = (N, 1, 1, output_channels) - conv2d_res = tosa_graph.addIntermediate( - conv2d_output_shape, - ts.DType.INT32 if is_quant_node else output.dtype, - ) - - # U55 doesn't support tosa.matmul and tosa.fully_connected will be deprecated - # TOSA Conv2d input is NHWC and weights are in OHWI - tosa_graph.addOperator( - TosaOp.Op().CONV2D, - [ - input_reshaped.name, - weight_reshaped.name, - bias.name, - ], - [conv2d_res.name], - attr, - ) - - result_shape = (N, output_channels) - - if is_quant_node: - # Read inputs' parent nodes - _, input_node, weight_node = node.all_input_nodes - - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - input_scale = get_quant_node_args(quant_node).scale - consumer_node = list(node.users)[0] - consumer_consumer_node = list(consumer_node.users)[0] - quant_args = get_quant_node_args(consumer_consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - else: - input_scale = get_quant_node_args(input_node).scale - consumer_node = list(node.users)[0] - quant_args = get_quant_node_args(consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - - weight_node_q_node = weight_node.all_input_nodes[0] - weight_scale = get_quant_node_args(weight_node_q_node).scale - - output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale - - reshaped_res = 
tosa_graph.addIntermediate(result_shape, ts.DType.INT32) - build_reshape(tosa_graph, conv2d_res.name, result_shape, reshaped_res.name) - - build_rescale( - tosa_fb=tosa_graph, - scale=output_rescale_scale, - input_node=reshaped_res, - output_name=output.name, - output_type=ts.DType.INT8, - output_shape=reshaped_res.shape, - input_zp=0, - output_zp=consumer_node_node_zp, - is_double_round=False, - ) - - else: - # non-quantized case - build_reshape(tosa_graph, conv2d_res.name, result_shape, output.name) diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 161b5d2239..8c9bd7ac2a 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import get_two_inputs from serializer.tosa_serializer import TosaOp @@ -42,8 +46,10 @@ def define_node( # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_q_params = get_quant_arg_upstream(input0) + input1_q_params = get_quant_arg_upstream(input1) + input0_zp = input0_q_params.zp + input1_zp = input1_q_params.zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -63,9 +69,7 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + output_q_params = get_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 64cde0724f..ffbeee7306 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import List import serializer.tosa_serializer as ts import torch @@ -15,9 +15,10 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( build_rescale_conv_output, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, ) -from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -82,7 +83,7 @@ def define_node( ) input_zp = ( - get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0 + get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 ) attr.ConvAttribute( @@ -158,9 +159,10 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if is_quant_node: # Get scale_factor from input, weight, and output. 
- _, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0])) - _, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1])) - _, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0]) + input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale + weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale + output_qargs = get_quant_arg_downstream(list(node.users)[0]) + build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. @@ -169,6 +171,6 @@ def define_node( actual_out_type, input_scale, weight_scale, - output_scale, - output_zp, + output_qargs.scale, + output_qargs.zp, ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 0e0a75dcc4..7a0b4e104f 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -48,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = exp_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index cf67975e0d..d2bc1377ce 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,7 +14,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_downstream, + quantize_value, +) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -39,10 +42,8 @@ def define_node( value = inputs[1].number if is_quant_node: - qargs = get_quant_node_args(list(node.users)[0]) - qvalue = np.clip( - np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax - ) + qargs = get_quant_arg_downstream(list(node.users)[0]) + qvalue = quantize_value(value, qargs) dtype = ts.DType.INT8 data = np.full(shape, qvalue, dtype=np.int8) else: diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 62c0a27f05..e726028206 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -14,7 +14,10 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_upstream, + quantize_value, +) from serializer.tosa_serializer import TosaOp @@ -37,12 +40,10 @@ def define_node( if is_quant_node: # Get quant parameters - scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0]) + qargs = get_quant_arg_upstream(node.all_input_nodes[0]) # Convert to quantized representation - clamp_min_qs = round((inputs[1].number / scale) + zp) - clamp_min_qs = max(clamp_min_qs, qmin) - clamp_max_qs = round((inputs[2].number / scale) + zp) - clamp_max_qs = min(clamp_max_qs, qmax) + clamp_min_qs = quantize_value(inputs[1].number, qargs) + 
clamp_max_qs = quantize_value(inputs[2].number, qargs) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 5276173efa..76adc2325e 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = log_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py new file mode 100644 index 0000000000..74e33ddb02 --- /dev/null +++ b/backends/arm/operators/op_max_pool2d.py @@ -0,0 +1,78 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import ( + get_quant_arg_downstream, + get_quant_arg_upstream, +) + +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class MaxPool2dVisitor(NodeVisitor): + target = "aten.max_pool2d.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + input_tensor = inputs[0] + kernel_size = inputs[1].special + stride = inputs[2].special + + try: + padding = [*inputs[3].special, *inputs[3].special] + except IndexError: + padding = [0, 0, 0, 0] + + accumulator_type = input_tensor.dtype + + if is_quant_node: + # Accumulator type always is int8 when input tensor is an integer type. + accumulator_type = ts.DType.INT8 + + # Initilize zero point to zero. 
+ input_zp = 0 + output_zp = 0 + + if is_quant_node: + input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp + + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size, + stride=stride, + pad=padding, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + TosaOp.Op().MAX_POOL2D, + [input_tensor.name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index ebddb3a40e..81334de16c 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import ( build_reshape, expand_dims, @@ -54,8 +58,8 @@ def define_node( # For INT8, we need to get the zero point, otherwise it is 0 input0_zp, input1_zp = 0, 0 if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_zp = get_quant_arg_upstream(input0).zp + input1_zp = get_quant_arg_upstream(input1).zp mat_mul_result = tosa_graph.addIntermediate( output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype @@ -86,9 +90,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + input0_q_params = get_quant_arg_upstream(input0) + input1_q_params = get_quant_arg_upstream(input1) + output_q_params = get_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index c152e8759e..ad578aa1f0 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -37,10 +37,10 @@ def define_node( if is_quant_node: input_A = inputs[0] input_B = inputs[1] - input_A_qargs = tqutils.get_quant_node_args( + input_A_qargs = tqutils.get_quant_arg_upstream( cast(torch.fx.Node, node.args[0]) ) - input_B_qargs = tqutils.get_quant_node_args( + input_B_qargs = tqutils.get_quant_arg_upstream( cast(torch.fx.Node, node.args[1]) ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 69f6f6506c..8142d6d654 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -14,7 +14,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm from serializer.tosa_serializer import TosaOp @@ -81,13 +80,6 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_permute_node_before_addmm(node): - ## Simply add an identityOp - tosa_graph.addOperator( - TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] - ) - return - # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. 
(2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 2618c9e71d..d466a13e38 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -10,22 +10,23 @@ import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_dtype, - get_quant_node_args, - is_quant_arg, + get_quant_arg_upstream, + get_quantized_node_output_dtype, + is_node_quantized, ) +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( - is_bias_node_for_quantized_addmm, is_bias_node_for_quantized_conv, + map_dtype, tosa_shape, ) -from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram def process_inputs( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, + tosa_spec: TosaSpecification, ): """Serialize an input node""" # inputs need to be in default dim_order (contiguous memory format) @@ -41,7 +42,11 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype, + ( + map_dtype(get_quantized_node_output_dtype(node)) + if is_node_quantized(node) + else inputs[0].dtype + ), data=None, placeholderFilename=inputs[0].name + ".npy", ) @@ -55,28 +60,16 @@ def process_quantized_bias( ): """ Serialize bias node that needs to be quantized. - This can be either an addmm or conv bias node. """ consumer_node = list(node.users)[0] - if is_bias_node_for_quantized_addmm(node): - ( - _, - input_node, - weight_node_permuted, - ) = consumer_node.all_input_nodes - - weight_node = weight_node_permuted.all_input_nodes[0] - if input_node.target == exir_ops.edge.aten.view_copy.default: - input_node = input_node.all_input_nodes[0] - else: - ( - input_node, - weight_node, - _, - ) = consumer_node.all_input_nodes - - input_node_scale = get_quant_node_args(input_node).scale - weight_node_scale = get_quant_node_args(weight_node).scale + ( + input_node, + weight_node, + _, + ) = consumer_node.all_input_nodes + + input_node_scale = get_quant_arg_upstream(input_node).scale + weight_node_scale = get_quant_arg_upstream(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() @@ -95,6 +88,7 @@ def process_inputs_to_parameters( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): """Serialize bias and non-quantized weights""" inputs = [TosaArg(node)] @@ -104,11 +98,15 @@ def process_inputs_to_parameters( assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor" parameter_values = parameter_data.detach().numpy() - if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node): + if is_bias_node_for_quantized_conv(node): # BI bias + assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer" process_quantized_bias(node, tosa_graph, parameter_values) else: # MI weights or bias + if inputs[0].dtype == torch.float32: + assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" + parameter_values = np.transpose(parameter_values, inputs[0].dim_order) tosa_graph.addConst( @@ -158,15 +156,16 @@ def process_placeholder( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, + tosa_spec: 
TosaSpecification, ): """Wrapper for processing and serializing all types of placeholders""" assert node.name == node.target, "Expect placeholder name and target to match" assert 0 == len(node.args), "Can't handle default input values" if node.name in edge_program.graph_signature.user_inputs: - process_inputs(node, tosa_graph) + process_inputs(node, tosa_graph, tosa_spec) elif node.name in edge_program.graph_signature.inputs_to_parameters: - process_inputs_to_parameters(node, tosa_graph, edge_program) + process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec) elif node.name in edge_program.graph_signature.inputs_to_buffers: process_inputs_to_buffers(node, tosa_graph, edge_program) elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants: diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 3d43fd8f7d..774c4d94b1 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -15,7 +15,8 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -41,8 +42,8 @@ def define_node( if is_quant_node: input = inputs[0] - input_qargs = get_quant_node_args(node.all_input_nodes[0]) - output_qargs = get_quant_node_args(list(node.users)[0]) + input_qargs = get_quant_arg_upstream(node.all_input_nodes[0]) + output_qargs = get_quant_arg_downstream(list(node.users)[0]) div_table = div_table_8bit(input_qargs, output_qargs) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 20bba3f654..a3a7c82ab8 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -38,7 +38,7 @@ def define_node( clamp_min_qs = 0 clamp_max_qs = 0 if is_quant_node: - out_qargs = tqutils.get_quant_node_args(list(node.users)[0]) + out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0]) clamp_min_qs = tqutils.quantize_value(0, out_qargs) clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 9225c7d938..b503a323b1 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -16,7 +16,8 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -39,9 +40,9 @@ def define_node( # Assume quantized input is 8 bit. # Create attribute for 8 bit table lookup. 
input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = rsqrt_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() table_attr.TableAttribute(table) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 0087b1f7a8..e299e99b43 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = sigmoid_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 2089b6e9e9..b86a5ea3ad 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -43,10 +43,8 @@ def define_node( input_nodes, tosa_graph ) - # Preapre sub output tensor - broadcasted_shape = tutils.broadcast_shapes( - rescaled_inputs[0].shape, rescaled_inputs[0].shape - ) + # Prepare sub output tensor + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) # Do the INT32 Sub diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 20f343a7f1..2c84580edc 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. 
input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = tanh_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 6a68eb2eb9..511aeda1ac 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -75,7 +75,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern [torch.nn.AdaptiveAvgPool2d], [F.adaptive_avg_pool2d], ], - "mul": [torch.mul], + "mul": [[torch.mul]], "sub": [[torch.sub]], } return copy.deepcopy(supported_operators) @@ -268,7 +268,6 @@ class ArmQuantizer(Quantizer): "sub", "mul", "mm", - "cat", "one_to_one", "generic", "sum", diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 4a910611bc..4d52b7ddf1 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -144,21 +144,11 @@ def is_share_obs_or_fq_op(op: Callable) -> bool: torch.ops.aten.mean.dim, torch.ops.aten.permute.default, torch.ops.aten.permute_copy.default, - torch.ops.aten.squeeze.dim, - torch.ops.aten.squeeze.dims, - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze_copy.dim, - torch.ops.aten.unsqueeze.default, - torch.ops.aten.unsqueeze_copy.default, # TODO: remove? torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.avg_pool2d.default, - torch.ops.aten.view_copy.default, - torch.ops.aten.view.default, + torch.ops.aten.max_pool2d.default, torch.ops.aten.full.default, - torch.ops.aten.slice.Tensor, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes.default, torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, operator.getitem, diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index bc3184298f..7eaa837c5b 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -51,7 +51,6 @@ def decorator(annotator: AnnotatorType): from . import ( # noqa adaptive_ang_pool2d_annotator, add_annotator, - cat_annotator, conv_annotator, generic_annotator, linear_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/cat_annotator.py b/backends/arm/quantizer/quantization_annotation/cat_annotator.py deleted file mode 100644 index 6e138cd9de..0000000000 --- a/backends/arm/quantizer/quantization_annotation/cat_annotator.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, cast, List, Optional - -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("cat") -def _annotate_cat( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) - cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) - annotated_partitions = [] - for cat_partition in cat_partitions: - annotated_partitions.append(cat_partition.nodes) - cat_node = cat_partition.output_nodes[0] - if arm_quantizer_utils.is_annotated(cat_node): - continue - - input_acts = cast(list[torch.fx.Node], cat_node.args[0]) - input_act0 = input_acts[0] - - input_act_qspec = quantization_config.get_input_act_qspec() - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) - - input_qspec_map = {} - - # First input is set to input qspec from the quantization config. - if isinstance(input_act0, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm): - continue - input_qspec_map[input_act0] = input_act_qspec - - # For the rest of the inputs, share qspec with first. - # If we can't quantize any of the inputs, abort annotation. - for input_act in input_acts[1:]: - if isinstance(input_act, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm): - continue - if input_act is not input_act0: - input_qspec_map[input_act] = shared_with_input0_qspec - - if input_qspec_map is not None: - cat_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=shared_with_input0_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py index f91df1398e..b093eec808 100644 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
# pyre-unsafe - from typing import Callable, List, Optional import torch @@ -24,28 +23,36 @@ # DATA LAYOUT OPS torch.ops.aten.squeeze.default, torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default, torch.ops.aten.reshape.default, + torch.ops.aten.repeat.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, # Disabling these as there seems to be an issue with support for complex # datatypes in torch: # torch.ops.aten.view_as_complex.default, # torch.ops.aten.view_as_complex_copy.default, # torch.ops.aten.view_as_real.default, # torch.ops.aten.view_as_real_copy.default, + torch.ops.aten.view.default, torch.ops.aten.view_copy.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor, - # 'concat' should be handled separately as it has a sequence of inputs and - # makes the implementation unnecessary complicated. - # torch.ops.aten.concat.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, torch.ops.aten.transpose.Dimname, torch.ops.aten.transpose.int, torch.ops.aten.transpose_copy.int, torch.ops.aten.tile.default, torch.ops.aten.flip.default, + torch.ops.aten.cat.default, + torch.ops.aten.stack.default, ] @@ -66,15 +73,31 @@ def _annotate_generic( if arm_quantizer_utils.is_annotated(node): continue - input_node = node.args[0] + input_acts = node.args[0] + + # Check to see if there are multiple inputs. + # this allows for stack/cat ops to be annotated + # in a similar way. + has_multi_inputs = isinstance(input_acts, list) + + input_act0 = input_acts[0] if has_multi_inputs else input_acts # Using a non-shared quantization spec here as a SharedQuantizationSpec # can lead to a recursion. _annotate_input_qspec_map( - node, input_node, quantization_config.get_input_act_qspec() + node, input_act0, quantization_config.get_input_act_qspec() ) - _annotate_output_qspec(node, SharedQuantizationSpec((input_node, node))) + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node)) + + if has_multi_inputs: + # For the rest of the inputs, share qspec with first. 
+ for input_act in input_acts[1:]: + if input_act is not input_act0: + node.meta["quantization_annotation"].input_qspec_map[ + input_act + ] = shared_with_input0_qspec + _annotate_output_qspec(node, shared_with_input0_qspec) arm_quantizer_utils.mark_nodes_as_annotated([node]) annotated_partitions.append([node]) diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py index b48c6d5990..60d9adb1c3 100644 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mm_annotator.py @@ -24,7 +24,9 @@ def _annotate_mm( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions(gm.graph, [torch.mm, torch.bmm], filter_fn) + mm_partitions = get_source_partitions( + gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn + ) mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) annotated_partitions = [] for mm_partition in mm_partitions: diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 2ae86b1d1e..3a9818929b 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -11,6 +11,10 @@ import subprocess import sys import tempfile +from datetime import datetime +from enum import auto, Enum +from pathlib import Path +from typing import Any import pytest @@ -19,7 +23,15 @@ from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.exir.backend.compile_spec_schema import CompileSpec -_enabled_options: list[str] = [] + +class arm_test_options(Enum): + quantize_io = auto() + corstone300 = auto() + dump_path = auto() + date_format = auto() + + +_test_options: dict[arm_test_options, Any] = {} # ==== Pytest hooks ==== @@ -27,19 +39,30 @@ def pytest_addoption(parser): parser.addoption("--arm_quantize_io", action="store_true") parser.addoption("--arm_run_corstone300", action="store_true") + parser.addoption("--default_dump_path", default=None) + parser.addoption("--date_format", default="%d-%b-%H:%M:%S") def pytest_configure(config): if config.option.arm_quantize_io: load_libquantized_ops_aot_lib() - _enabled_options.append("quantize_io") + _test_options[arm_test_options.quantize_io] = True if config.option.arm_run_corstone300: corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55") if not corstone300_exists: raise RuntimeError( "Tests are run with --arm_run_corstone300 but corstone300 FVP is not installed." ) - _enabled_options.append("corstone300") + _test_options[arm_test_options.corstone300] = True + if config.option.default_dump_path: + dump_path = Path(config.option.default_dump_path).expanduser() + if dump_path.exists() and os.path.isdir(dump_path): + _test_options[arm_test_options.dump_path] = dump_path + else: + raise RuntimeError( + f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory." 
+ ) + _test_options[arm_test_options.date_format] = config.option.date_format logging.basicConfig(level=logging.INFO, stream=sys.stdout) @@ -54,8 +77,31 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_if_aot_lib_not_loaded) +def pytest_sessionstart(session): + pass + + +def pytest_sessionfinish(session, exitstatus): + if get_option(arm_test_options.dump_path): + _clean_dir( + get_option(arm_test_options.dump_path), + f"ArmTester_{get_option(arm_test_options.date_format)}.log", + ) + + # ==== End of Pytest hooks ===== +# ==== Custom Pytest decorators ===== + + +def expectedFailureOnFVP(test_item): + if is_option_enabled("corstone300"): + test_item.__unittest_expecting_failure__ = True + return test_item + + +# ==== End of Custom Pytest decorators ===== + def load_libquantized_ops_aot_lib(): so_ext = { @@ -76,7 +122,9 @@ def load_libquantized_ops_aot_lib(): torch.ops.load_library(library_path) -def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: +def is_option_enabled( + option: str | arm_test_options, fail_if_not_enabled: bool = False +) -> bool: """ Returns whether an option is successfully enabled, i.e. if the flag was given to pytest and the necessary requirements are available. @@ -87,7 +135,10 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: The optional parameter 'fail_if_not_enabled' makes the function raise a RuntimeError instead of returning False. """ - if option.lower() in _enabled_options: + if isinstance(option, str): + option = arm_test_options[option.lower()] + + if option in _test_options and _test_options[option]: return True else: if fail_if_not_enabled: @@ -96,6 +147,12 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: return False +def get_option(option: arm_test_options) -> Any | None: + if option in _test_options: + return _test_options[option] + return None + + def maybe_get_tosa_collate_path() -> str | None: """ Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the @@ -120,16 +177,18 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( - permute_memory_to_nhwc=True, custom_path=None + tosa_version: str, permute_memory_to_nhwc=True, custom_path=None ) -> list[CompileSpec]: """ Default compile spec for TOSA tests. """ - return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build() + return get_tosa_compile_spec_unbuilt( + tosa_version, permute_memory_to_nhwc, custom_path + ).build() def get_tosa_compile_spec_unbuilt( - permute_memory_to_nhwc=False, custom_path=None + tosa_version: str, permute_memory_to_nhwc=False, custom_path=None ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. 
@@ -145,7 +204,7 @@ def get_tosa_compile_spec_unbuilt( os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = ( ArmCompileSpecBuilder() - .tosa_compile_spec() + .tosa_compile_spec(tosa_version) .set_permute_memory_format(permute_memory_to_nhwc) .dump_intermediate_artifacts_to(intermediate_path) ) @@ -219,3 +278,32 @@ def get_u85_compile_spec_unbuilt( .dump_intermediate_artifacts_to(artifact_path) ) return compile_spec + + +def current_time_formated() -> str: + """Return current time as a formated string""" + return datetime.now().strftime(get_option(arm_test_options.date_format)) + + +def _clean_dir(dir: Path, filter: str, num_save=10): + sorted_files: list[tuple[datetime, Path]] = [] + for file in dir.iterdir(): + try: + creation_time = datetime.strptime(file.name, filter) + insert_index = -1 + for i, to_compare in enumerate(sorted_files): + compare_time = to_compare[0] + if creation_time < compare_time: + insert_index = i + break + if insert_index == -1 and len(sorted_files) < num_save: + sorted_files.append((creation_time, file)) + else: + sorted_files.insert(insert_index, (creation_time, file)) + except ValueError: + continue + + if len(sorted_files) > num_save: + for remove in sorted_files[0 : len(sorted_files) - num_save]: + file = remove[1] + file.unlink() diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 7d9a18a80e..4cac39af70 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -49,7 +49,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -63,12 +63,11 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_artifact(dump_file) .dump_artifact() ) @@ -107,11 +106,12 @@ def test_numerical_diff_prints(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .to_executorch() ) # We expect an assertion error here. 
Any other issues will cause the @@ -132,7 +132,7 @@ def test_dump_ops_and_dtypes(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .dump_dtype_distribution() @@ -140,10 +140,7 @@ def test_dump_ops_and_dtypes(): .export() .dump_dtype_distribution() .dump_operator_distribution() - .to_edge() - .dump_dtype_distribution() - .dump_operator_distribution() - .partition() + .to_edge_transform_and_lower() .dump_dtype_distribution() .dump_operator_distribution() ) @@ -156,7 +153,7 @@ def test_dump_ops_and_dtypes_parseable(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .dump_dtype_distribution(print_table=False) @@ -164,10 +161,7 @@ def test_dump_ops_and_dtypes_parseable(): .export() .dump_dtype_distribution(print_table=False) .dump_operator_distribution(print_table=False) - .to_edge() - .dump_dtype_distribution(print_table=False) - .dump_operator_distribution(print_table=False) - .partition() + .to_edge_transform_and_lower() .dump_dtype_distribution(print_table=False) .dump_operator_distribution(print_table=False) ) @@ -187,12 +181,11 @@ def test_collate_tosa_BI_tests(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .to_executorch() ) # test that the output directory is created and contains the expected files @@ -200,10 +193,10 @@ def test_collate_tosa_BI_tests(self): "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests" ) assert os.path.exists( - "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa" + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag5.tosa" ) assert os.path.exists( - "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json" + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag5.json" ) os.environ.pop("TOSA_TESTCASES_BASE_PATH") @@ -217,12 +210,11 @@ def test_dump_tosa_ops(caplog): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_operator_distribution() ) assert "TOSA operators:" in caplog.text @@ -241,8 +233,7 @@ def forward(self, x): ArmTester(model, example_inputs=(torch.ones(5),), compile_spec=compile_spec) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_operator_distribution() ) assert "Can not get operator distribution for Vela command stream." 
in caplog.text diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py index 8bad1493b1..d7406afe95 100644 --- a/backends/arm/test/misc/test_dim_order_guards.py +++ b/backends/arm/test/misc/test_dim_order_guards.py @@ -34,7 +34,7 @@ def test_tosa_MI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -48,7 +48,7 @@ def test_tosa_BI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 29b2887431..12b8d0665b 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -60,7 +60,7 @@ def test_partition_lifted_tensor_tosa_MI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -77,7 +77,7 @@ def test_partition_lifted_tensor_tosa_BI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -95,7 +95,7 @@ def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -110,7 +110,7 @@ def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py new file mode 100644 index 0000000000..5cbad140b7 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -0,0 +1,105 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
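The new test file below exercises the tosa_specification module. TOSA target strings follow the pattern TOSA-<major>.<minor>.<patch>+<PROFILE>[+<extension>...], e.g. "TOSA-0.80.0+BI+u55" or "TOSA-1.00.0+INT+FP+fft", and the same string can be carried in a CompileSpec under the "tosa_version" key. A short usage sketch based on the API the tests call:

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec

# Parse a 0.80 base-inference (BI) spec with the u55 extension.
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80.0+BI+u55")

# The version string can also be picked out of a list of compile specs.
compile_specs = [CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())]
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)

# Malformed strings (missing or lower-case profile, duplicated profiles, ...)
# raise ValueError, as the negative tests below assert.
try:
    TosaSpecification.create_from_string("TOSA-0.80.0")  # no profile
except ValueError:
    pass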
+ +import unittest + +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + +test_valid_0_80_strings = [ + "TOSA-0.80.0+BI", + "TOSA-0.80.0+MI+8k", + "TOSA-0.80.0+BI+u55", +] +test_valid_1_00_strings = [ + "TOSA-1.00.0+INT+FP+fft", + "TOSA-1.00.0+FP+bf16+fft", + "TOSA-1.00.0+INT+int4+cf", + "TOSA-1.00.0+FP+cf+bf16+8k", + "TOSA-1.00.0+FP+INT+bf16+fft+int4+cf", + "TOSA-1.00.0+FP+INT+fft+int4+cf+8k", +] + +test_valid_1_00_extensions = { + "INT": ["int16", "int4", "var", "cf"], + "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], +} + +test_invalid_strings = [ + "TOSA-0.80.0+bi", + "TOSA-0.80.0", + "TOSA-0.80.0+8k", + "TOSA-0.80.0+BI+MI", + "TOSA-0.80.0+BI+U55", + "TOSA-1.00.0+fft", + "TOSA-1.00.0+fp+bf16+fft", + "TOSA-1.00.0+INT+INT4+cf", + "TOSA-1.00.0+BI", + "TOSA-1.00.0+FP+FP+INT", + "TOSA-1.00.0+FP+CF+bf16", + "TOSA-1.00.0+BF16+fft+int4+cf+INT", +] + +test_compile_specs = [ + ([CompileSpec("tosa_version", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())],), + ([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],), +] + +test_compile_specs_no_version = [ + ([CompileSpec("other_key", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("other_key", "some_value".encode())],), +] + + +class TestTosaSpecification(unittest.TestCase): + """Tests the TOSA specification class""" + + @parameterized.expand(test_valid_0_80_strings) + def test_version_string_0_80(self, version_string: str): + tosa_spec = TosaSpecification.create_from_string(version_string) + assert isinstance(tosa_spec, Tosa_0_80) + assert tosa_spec.profile in ["BI", "MI"] + + @parameterized.expand(test_valid_1_00_strings) + def test_version_string_1_00(self, version_string: str): + tosa_spec = TosaSpecification.create_from_string(version_string) + assert isinstance(tosa_spec, Tosa_1_00) + assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count( + True + ) > 0 + + for profile in tosa_spec.profiles: + assert [ + e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions + ] + + @parameterized.expand(test_invalid_strings) + def test_invalid_version_strings(self, version_string: str): + tosa_spec = None + with self.assertRaises(ValueError): + tosa_spec = TosaSpecification.create_from_string(version_string) + + assert tosa_spec is None + + @parameterized.expand(test_compile_specs) + def test_create_from_compilespec(self, compile_specs: list[CompileSpec]): + tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs) + assert isinstance(tosa_spec, TosaSpecification) + + @parameterized.expand(test_compile_specs_no_version) + def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]): + tosa_spec = None + with self.assertRaises(ValueError): + tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs) + + assert tosa_spec is None diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index a50e2732f1..19b4254575 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -54,12 +54,12 @@ def test_mv2_tosa_MI(self): ArmTester( self.mv2, example_inputs=self.model_inputs, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + 
), ) .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.all_operators)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .run_method_and_compare_outputs(inputs=self.model_inputs) ) @@ -69,13 +69,13 @@ def test_mv2_tosa_BI(self): ArmTester( self.mv2, example_inputs=self.model_inputs, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() # atol=1.0 is a defensive upper limit # TODO MLETROCH-72 @@ -92,9 +92,7 @@ def test_mv2_u55_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .serialize() ) @@ -112,9 +110,7 @@ def test_mv2_u85_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .serialize() ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index e3eeb187da..66e278ee0f 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -61,7 +61,7 @@ def _test_add_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.add.Tensor": 1}) @@ -80,7 +80,7 @@ def _test_add_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index 344a80a79b..afd079fb95 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -55,7 +55,9 @@ def _test_avgpool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.avg_pool2d.default"]) @@ -76,7 +78,9 @@ def _test_avgpool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index 4935e910d6..297ac0af1c 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -533,12 +533,9 @@ def _test_batchnorm2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() - .check_count( - {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1} - ) 
.check_not(["torch.ops.quantized_decomposed"]) .to_edge() .check_count( @@ -564,7 +561,7 @@ def _test_batchnorm2d_no_stats_tosa_MI_pipeline( ArmTester( module, example_example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten._native_batch_norm_legit.no_stats": 1}) @@ -593,7 +590,7 @@ def _test_batchnorm2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e4e6abb7bb..6246657120 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -32,6 +32,12 @@ class BMM(torch.nn.Module): def forward(self, x, y): return torch.bmm(x, y) + class MatMul(torch.nn.Module): + test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))] + + def forward(self, x, y): + return torch.matmul(x, y) + class BMMSingleInput(torch.nn.Module): test_parameters = [ (torch.rand(20, 3, 3),), @@ -50,12 +56,12 @@ def _test_bmm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -70,13 +76,13 @@ def _test_bmm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -116,6 +122,16 @@ def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) + + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) + @parameterized.expand(BMM.test_parameters) def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index 9723ba0f0c..b380c44d52 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -56,7 +56,7 @@ def _test_cat_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.cat.default": 1}) @@ -76,7 +76,7 @@ def 
_test_cat_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 9852c5c452..4721f257b0 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -47,7 +47,7 @@ def _test_clone_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.clone.default": 1}) @@ -66,7 +66,7 @@ def _test_clone_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index 3b27554221..133148faef 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -226,7 +226,9 @@ def _test_conv1d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -246,7 +248,9 @@ def _test_conv1d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 46adfc8a01..43c3e85139 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -253,7 +253,9 @@ def _test_conv2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -273,7 +275,9 @@ def _test_conv2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 4b45b67126..3e9bdef958 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -192,7 +192,9 @@ def _test_conv_combo_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -214,7 +216,9 @@ def _test_conv_combo_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 01ffbc1054..4bfa863c49 100644 --- 
a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -177,7 +177,9 @@ def _test_dw_conv_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -195,7 +197,9 @@ def _test_dw_conv_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index 84a8d53f9d..28cc686690 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -102,7 +102,7 @@ def _test_div_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.div.Tensor": 1}) @@ -121,7 +121,7 @@ def _test_div_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_exp.py b/backends/arm/test/ops/test_exp.py index 6e85d8fe49..c706b7b206 100644 --- a/backends/arm/test/ops/test_exp.py +++ b/backends/arm/test/ops/test_exp.py @@ -40,7 +40,7 @@ def _test_exp_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.exp.default"]) @@ -58,7 +58,7 @@ def _test_exp_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index aa13a6475c..effa7ce713 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -46,7 +46,7 @@ def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.expand.default": 1}) @@ -64,7 +64,7 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 2722edef32..d4cfc5c369 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -57,7 +57,7 @@ def _test_full_tosa_MI_pipeline( ArmTester( module, example_inputs=example_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.full.default": 1}) @@ -80,7 +80,7 @@ def _test_full_tosa_BI_pipeline( module, example_inputs=test_data, 
compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .quantize() diff --git a/backends/arm/test/ops/test_hardtanh.py b/backends/arm/test/ops/test_hardtanh.py index c7c3736e37..a9f12abdf0 100644 --- a/backends/arm/test/ops/test_hardtanh.py +++ b/backends/arm/test/ops/test_hardtanh.py @@ -52,7 +52,7 @@ def _test_hardtanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.hardtanh.default"]) @@ -73,7 +73,7 @@ def _test_hardtanh_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 0150c20524..f059d71eba 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -74,7 +74,9 @@ def _test_layernorm_tosa_MI_pipeline( ArmTester( model=module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.layer_norm.default"]) @@ -93,7 +95,9 @@ def _test_layernorm_tosa_BI_pipeline( ArmTester( model=module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .check_not(["torch.ops.aten.layer_norm.default"]) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 3f68ab0251..30d4b2890a 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -23,70 +23,82 @@ test_data_suite_rank1 = [ - # (test_name, test_data, out_features) + # (test_name, test_data, out_features, has_bias) ( "model_linear_rank1_zeros", torch.zeros(10), 15, + True, ), ( "model_linear_rank1_ones", torch.ones(10), 15, + False, ), ( "model_linear_rank1_negative_ones", torch.ones(10) * (-1), 20, + True, ), ( "model_linear_rank1_rand", torch.rand(10), 10, + True, ), ( "model_linear_rank1_negative_large_rand", torch.rand(10) * (-100), 30, + False, ), ( "model_linear_rank1_large_randn", torch.randn(15) * 100, 20, + True, ), ] test_data_suite_rank4 = [ - # (test_name, test_data, out_features) + # (test_name, test_data, out_features, has_bias) ( "model_linear_rank4_zeros", torch.zeros(5, 10, 25, 20), 30, + True, ), ( "model_linear_rank4_ones", torch.ones(5, 10, 25, 20), 30, + False, ), ( "model_linear_rank4_negative_ones", torch.ones(5, 10, 25, 20) * (-1), 30, + True, ), ( "model_linear_rank4_rand", torch.rand(5, 10, 25, 20), 30, + False, ), ( "model_linear_rank4_negative_large_rand", torch.rand(5, 10, 25, 20) * (-100), 30, + True, ), ( "model_linear_rank4_large_randn", torch.randn(5, 10, 25, 20) * 100, 30, + False, ), ] @@ -122,13 +134,14 @@ def _test_linear_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() 
.check_count({"torch.ops.aten.linear.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data) @@ -141,17 +154,18 @@ def _test_linear_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() .check_count({"torch.ops.aten.linear.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=True) + .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) def _test_linear_tosa_ethosu_BI_pipeline( @@ -170,8 +184,7 @@ def _test_linear_tosa_ethosu_BI_pipeline( .export() .check_count({"torch.ops.aten.linear.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() @@ -184,6 +197,7 @@ def test_linear_tosa_MI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -191,6 +205,7 @@ def test_linear_tosa_MI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), test_data, ) @@ -201,11 +216,15 @@ def test_linear_tosa_BI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) self._test_linear_tosa_BI_pipeline( - self.Linear(in_features=in_features, out_features=out_features), test_data + self.Linear( + in_features=in_features, out_features=out_features, bias=has_bias + ), + test_data, ) @parameterized.expand(test_data_suite_rank1) @@ -214,6 +233,7 @@ def test_linear_tosa_u55_BI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -221,20 +241,22 @@ def test_linear_tosa_u55_BI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), - common.get_u55_compile_spec(permute_memory_to_nhwc=False), + common.get_u55_compile_spec(), test_data, ) if common.is_option_enabled("corstone300"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - @parameterized.expand(test_data_suite_rank1) + @parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4) def test_linear_tosa_u85_BI( self, test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -242,7 +264,8 @@ def test_linear_tosa_u85_BI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), - common.get_u85_compile_spec(permute_memory_to_nhwc=False), + common.get_u85_compile_spec(), test_data, ) diff --git a/backends/arm/test/ops/test_log.py b/backends/arm/test/ops/test_log.py index 269b7be25f..847635ea36 100644 --- 
a/backends/arm/test/ops/test_log.py +++ b/backends/arm/test/ops/test_log.py @@ -40,7 +40,7 @@ def _test_log_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.log.default"]) @@ -58,7 +58,7 @@ def _test_log_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py index 2d51588bb3..5d84fa127f 100644 --- a/backends/arm/test/ops/test_logsoftmax.py +++ b/backends/arm/test/ops/test_logsoftmax.py @@ -46,7 +46,7 @@ def _test_logsoftmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.log_softmax.int"]) @@ -66,7 +66,7 @@ def _test_logsoftmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py new file mode 100644 index 0000000000..41526b1c77 --- /dev/null +++ b/backends/arm/test/ops/test_max_pool.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
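The new MaxPool2d test file below gates its Corstone FVP runs on the pytest options added to backends/arm/test/common.py earlier in this diff, via is_option_enabled("corstone300") and the expectedFailureOnFVP decorator. A simplified sketch of how those helpers resolve options, condensed from the common.py changes above:

from enum import Enum, auto
from typing import Any


class arm_test_options(Enum):
    quantize_io = auto()
    corstone300 = auto()
    dump_path = auto()
    date_format = auto()


# Populated by pytest_configure() from the --arm_* command line flags.
_test_options: dict[arm_test_options, Any] = {}


def is_option_enabled(option: str | arm_test_options) -> bool:
    # String names such as "corstone300" are mapped onto the enum first.
    if isinstance(option, str):
        option = arm_test_options[option.lower()]
    return bool(_test_options.get(option))


def expectedFailureOnFVP(test_item):
    # Marks a test as an expected failure only when it will actually run on
    # the Corstone FVP; otherwise the test is left untouched.
    if is_option_enabled("corstone300"):
        test_item.__unittest_expecting_failure__ = True
    return test_item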
+ +import logging +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import Quantize +from executorch.exir.backend.backend_details import CompileSpec +from parameterized import parameterized + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +test_data_suite = [ + # (test_name, test_data, [kernel_size, stride, padding]) + ("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1]), + ("ones", torch.ones(1, 16, 50, 32), [4, 2, 0]), + ("rand", torch.rand(1, 16, 52, 16), [4, 3, 0]), +] + +test_data_suite_mult_batches = [ + ("randn", torch.randn(5, 16, 50, 32), [4, 2, 0]), +] + + +class TestMaxPool2d(unittest.TestCase): + """Tests MaxPool2d.""" + + class MaxPool2d(torch.nn.Module): + def __init__( + self, + kernel_size: int | Tuple[int, int], + stride: int | Tuple[int, int], + padding: int | Tuple[int, int], + ): + super().__init__() + self.max_pool_2d = torch.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x): + return self.max_pool_2d(x) + + def _test_maxpool2d_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), + ) + .export() + .check(["torch.ops.aten.max_pool2d.default"]) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"]) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + def _test_maxpool2d_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.max_pool2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"]) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_maxpool2d_tosa_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.tensor], + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.max_pool2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 
1}) + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_MI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_BI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_u55_BI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u55_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-300" + ) + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_u85_BI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u85_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-320" + ) + + @parameterized.expand(test_data_suite_mult_batches) + def test_maxpool2d_tosa_MI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_MI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite_mult_batches) + def test_maxpool2d_tosa_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_BI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite_mult_batches) + @common.expectedFailureOnFVP # TODO: MLETORCH-433 + def test_maxpool2d_tosa_u55_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u55_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-300" + ) + + @parameterized.expand(test_data_suite_mult_batches) + @common.expectedFailureOnFVP # TODO: MLETORCH-433 + def test_maxpool2d_tosa_u85_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u85_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-320" + ) diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 68307bcdf1..e8320cf1df 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -81,7 +81,7 @@ def 
_test_adaptive_avg_pool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.adaptive_avg_pool2d.default"]) @@ -101,7 +101,7 @@ def _test_adaptive_avg_pool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -150,7 +150,7 @@ def _test_meandim_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -169,7 +169,7 @@ def _test_meandim_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index 4271496eaa..21b02bbd10 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -54,7 +54,7 @@ def _test_mm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.mm.default": 1}) @@ -74,7 +74,7 @@ def _test_mm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index a1c2dba5fe..7fa20c2566 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -70,7 +70,9 @@ def _test_mul_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check_count({"torch.ops.aten.mul.Tensor": 1}) @@ -89,7 +91,9 @@ def _test_mul_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 6346e847c9..62b6b823de 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -57,7 +57,7 @@ def _test_permute_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .export() @@ -79,7 +79,7 @@ def _test_permute_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_reciprocal.py b/backends/arm/test/ops/test_reciprocal.py index cb4971bf8c..7745a614e6 100644 --- a/backends/arm/test/ops/test_reciprocal.py +++ b/backends/arm/test/ops/test_reciprocal.py @@ -46,7 +46,7 @@ def 
_test_reciprocal_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.reciprocal.default": 1}) @@ -65,7 +65,7 @@ def _test_reciprocal_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py index effbccc74d..595c907b32 100644 --- a/backends/arm/test/ops/test_relu.py +++ b/backends/arm/test/ops/test_relu.py @@ -48,7 +48,7 @@ def _test_relu_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.relu.default"]) @@ -69,7 +69,7 @@ def _test_relu_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 1efac9f974..20c57ba749 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -47,7 +47,7 @@ def _test_repeat_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.repeat.default": 1}) @@ -65,7 +65,7 @@ def _test_repeat_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index 2ccb7ec991..2cddc8da26 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -35,7 +35,7 @@ def _test_rsqrt_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.rsqrt.default": 1}) @@ -53,7 +53,7 @@ def _test_rsqrt_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 0305ef58c0..86433745a6 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -123,7 +123,7 @@ def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -137,7 +137,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index fdb2fa1463..85bfc15d2d 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -58,7 +58,7 @@ def _test_select_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute ), ) .export() @@ -84,7 +84,7 @@ def _test_select_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute ), ) .quantize() diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index 4d126b68e5..f12658c985 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -71,7 +71,7 @@ def _test_sigmoid_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.sigmoid.default"]) @@ -89,7 +89,7 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 0bab21f907..0fc92b011a 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -8,13 +8,9 @@ from typing import Tuple import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) + from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -43,7 +39,7 @@ def _test_slice_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.slice.Tensor"]) @@ -59,16 +55,15 @@ def _test_slice_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor], permute: bool ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute ), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check(["torch.ops.aten.slice.Tensor"]) .to_edge() @@ -84,14 +79,13 @@ def _test_slice_ethos_BI_pipeline( module: torch.nn.Module, test_data: Tuple[torch.Tensor], ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_u55_compile_spec(), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check(["torch.ops.aten.slice.Tensor"]) .to_edge() diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index 954dd201a9..f883d6b8de 100644 --- 
a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -47,7 +47,7 @@ def _test_softmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.softmax.int"]) @@ -67,7 +67,7 @@ def _test_softmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 3f6edc0c2b..42395c4c2d 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -7,13 +7,9 @@ import unittest import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) + from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -60,7 +56,7 @@ def _test_split_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -79,14 +75,13 @@ def _test_split_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: test_data_t ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .to_edge() .partition() @@ -98,14 +93,13 @@ def _test_split_tosa_BI_pipeline( def _test_split_ethosu_BI_pipeline( self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: test_data_t ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=compile_spec, ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check(["torch.ops.aten.split.Tensor"]) .to_edge() diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py index c9d7d42195..7e915da645 100644 --- a/backends/arm/test/ops/test_squeeze.py +++ b/backends/arm/test/ops/test_squeeze.py @@ -13,14 +13,9 @@ import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -66,7 +61,7 @@ def _test_squeeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({export_target: 1}) @@ -83,14 +78,13 @@ def _test_squeeze_tosa_BI_pipeline( test_data: Tuple[torch.Tensor, Optional[tuple[int]]], export_target: str, ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, 
example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({export_target: 1}) .to_edge() @@ -107,10 +101,9 @@ def _test_squeeze_ethosu_BI_pipeline( test_data: Tuple[torch.Tensor, Optional[tuple[int]]], export_target: str, ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({export_target: 1}) .to_edge() diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index e80c043698..5c67240e52 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -43,7 +43,7 @@ def _test_sub_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.sub.Tensor": 1}) @@ -63,7 +63,7 @@ def _test_sub_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 73860dfa4a..9cd63b0a22 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -49,7 +49,7 @@ def _test_sum_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.sum.dim_IntList": 1}) @@ -68,7 +68,7 @@ def _test_sum_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py index 6f5cf17cf3..5f3859eadd 100644 --- a/backends/arm/test/ops/test_tanh.py +++ b/backends/arm/test/ops/test_tanh.py @@ -44,7 +44,7 @@ def _test_tanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.tanh.default"]) @@ -62,7 +62,7 @@ def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple) ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py index 1cc597c066..8936d55f8b 100644 --- a/backends/arm/test/ops/test_unsqueeze.py +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -13,14 +13,9 @@ import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -40,7 +35,7 @@ def _test_unsqueeze_tosa_MI_pipeline( 
ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) @@ -54,14 +49,13 @@ def _test_unsqueeze_tosa_MI_pipeline( def _test_unsqueeze_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int] ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) .to_edge() @@ -77,14 +71,13 @@ def _test_unsqueeze_ethosu_BI_pipeline( module: torch.nn.Module, test_data: Tuple[torch.Tensor, int], ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=compile_spec, ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) .to_edge() diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 56b7c5fbb4..3a1285e6da 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -86,7 +86,7 @@ def _test_var_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -107,7 +107,7 @@ def _test_var_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index fe1f2981da..09a8f57bd3 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -13,14 +13,9 @@ import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -60,7 +55,7 @@ def _test_view_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.view.default": 1}) @@ -74,14 +69,13 @@ def _test_view_tosa_MI_pipeline( def _test_view_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({"torch.ops.aten.view.default": 1}) .to_edge() @@ -97,10 +91,13 @@ def _test_view_ethos_BI_pipeline( module: torch.nn.Module, test_data: Tuple[torch.Tensor], ): - quantizer = 
ArmQuantizer().set_io(get_symmetric_quantization_config()) ( - ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() .export() .check_count({"torch.ops.aten.view.default": 1}) .to_edge() diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py index c8fa0f4b7a..615187fb65 100644 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -46,7 +46,7 @@ def test_tosa_BI_meandim_to_averagepool(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -63,7 +63,7 @@ def test_tosa_BI_meandim_no_modification(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index b859757df4..3d39463a42 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -30,7 +30,9 @@ def example_inputs(self): class TestGenericAnnotator(unittest.TestCase): def check_annotation(self, model): tester = ArmTester( - model, model.example_inputs(), common.get_tosa_compile_spec() + model, + model.example_inputs(), + common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) quant_model = tester.quantize().get_artifact() partitions = get_source_partitions(quant_model.graph, [model.op]) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 3e9d3620cc..d2ee113a5d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -448,8 +448,11 @@ def run_tosa_ref_model( ), "There are no quantization parameters, check output parameters" tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale + if tosa_ref_output.dtype == np.double: + tosa_ref_output = tosa_ref_output.astype("float32") + # tosa_output is a numpy array, convert to torch tensor for comparison - tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32"))) + tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) return tosa_ref_outputs @@ -457,7 +460,9 @@ def run_tosa_ref_model( def prep_data_for_save( data, is_quantized: bool, input_name: str, quant_param: QuantizationParams ): - data_np = np.array(data.detach(), order="C").astype(np.float32) + data_np = np.array(data.detach(), order="C").astype( + f"{data.dtype}".replace("torch.", "") + ) if is_quantized: assert quant_param.node_name in input_name, ( diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 59d326109d..14a9d1df41 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -22,6 +22,11 @@ ArmQuantizer, get_symmetric_quantization_config, ) +from executorch.backends.arm.test.common import ( + arm_test_options, + current_time_formated, + get_option, +) from executorch.backends.arm.test.runner_utils import ( _get_input_quantization_params, @@ -34,10 +39,11 @@ from executorch.backends.xnnpack.test.tester import Tester from 
executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager from executorch.exir.backend.compile_spec_schema import CompileSpec - +from executorch.exir.backend.partitioner import Partitioner from executorch.exir.lowered_backend_module import LoweredBackendModule + from tabulate import tabulate from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec from torch.fx import Graph @@ -45,50 +51,61 @@ logger = logging.getLogger(__name__) +def _dump_lowered_modules_artifact( + path_to_dump: Optional[str], + artifact: ExecutorchProgramManager, + graph_module: torch.fx.GraphModule, +): + output = "Formated Graph Signature:\n" + output += _format_export_graph_signature( + artifact.exported_program().graph_signature + ) + + def get_output_format(lowered_module) -> str | None: + for spec in lowered_module.compile_specs: + if spec.key == "output_format": + return spec.value.decode() + return None + + for node in graph_module.graph.nodes: + if node.op == "get_attr" and node.name.startswith("lowered_module_"): + lowered_module = getattr(graph_module, node.name) + assert isinstance( + lowered_module, LoweredBackendModule + ), f"Attribute {node.name} must be of type LoweredBackendModule." + + output_format = get_output_format(lowered_module) + if output_format == "tosa": + tosa_fb = lowered_module.processed_bytes + to_print = dbg_tosa_fb_to_json(tosa_fb) + to_print = pformat(to_print, compact=True, indent=1) + output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" + elif output_format == "vela": + vela_cmd_stream = lowered_module.processed_bytes + output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" + else: + logger.warning( + f"No TOSA nor Vela compile spec found in compile specs of {node.name}." + ) + continue + + if not output: + logger.warning("No output to print generated from artifact.") + return + + _dump_str(output, path_to_dump) + + class Partition(tester.Partition): def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) + _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) - output = "Formated Graph Signature:\n" - output += _format_export_graph_signature( - self.artifact.exported_program().graph_signature - ) - - def get_output_format(lowered_module) -> str | None: - for spec in lowered_module.compile_specs: - if spec.key == "output_format": - return spec.value.decode() - return None - - for node in self.graph_module.graph.nodes: - if node.op == "get_attr" and node.name.startswith("lowered_module_"): - lowered_module = getattr(self.graph_module, node.name) - assert isinstance( - lowered_module, LoweredBackendModule - ), f"Attribute {node.name} must be of type LoweredBackendModule." - - output_format = get_output_format(lowered_module) - if output_format == "tosa": - tosa_fb = lowered_module.processed_bytes - to_print = dbg_tosa_fb_to_json(tosa_fb) - to_print = pformat(to_print, compact=True, indent=1) - output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" - elif output_format == "vela": - vela_cmd_stream = lowered_module.processed_bytes - output += ( - f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" - ) - else: - logger.warning( - f"No TOSA nor Vela compile spec found in compile specs of {node.name}." 
- ) - continue - - if not output: - logger.warning("No output to print generated from artifact.") - return - _dump_str(output, path_to_dump) +class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower): + def dump_artifact(self, path_to_dump: Optional[str]): + super().dump_artifact(path_to_dump) + _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) class Serialize(tester.Serialize): @@ -206,6 +223,26 @@ def partition(self, partition_stage: Optional[Partition] = None): partition_stage = Partition(arm_partitioner) return super().partition(partition_stage) + def to_edge_transform_and_lower( + self, + to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + ): + if to_edge_and_lower_stage is None: + if partitioners is None: + partitioners = [ArmPartitioner(compile_spec=self.compile_spec)] + to_edge_and_lower_stage = ToEdgeTransformAndLower( + partitioners, edge_compile_config + ) + else: + if partitioners is not None: + to_edge_and_lower_stage.partitioners = partitioners + if edge_compile_config is not None: + to_edge_and_lower_stage.edge_compile_conf = edge_compile_config + to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True + return super().to_edge_transform_and_lower(to_edge_and_lower_stage) + def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None): if to_executorch_stage is None: to_executorch_stage = ToExecutorch(self.runner_util) @@ -250,21 +287,23 @@ def run_method_and_compare_outputs( inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. The default is random data. """ + + edge_stage = self.stages[self.stage_name(tester.ToEdge)] + if edge_stage is None: + edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)] assert ( self.runner_util is not None ), "self.tosa_test_util is not initialized, cannot use run_method()" assert ( - self.stages[self.stage_name(tester.ToEdge)] is not None - ), "To compare outputs, at least the ToEdge stage needs to be run." + edge_stage is not None + ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run." 
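A minimal usage sketch of the new combined stage, assuming a trivial add module and the compile-spec helpers used throughout these tests (whether a given module is fully delegated still depends on the partitioner):

    import torch

    from executorch.backends.arm.test import common
    from executorch.backends.arm.test.tester.arm_tester import ArmTester


    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y


    test_data = (torch.ones(1, 4), torch.ones(1, 4))

    (
        ArmTester(
            Add(),
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        .quantize()
        .export()
        # One stage in place of separate .to_edge() and .partition() calls; it
        # defaults to an ArmPartitioner built from the tester's compile spec.
        .to_edge_transform_and_lower()
        .to_executorch()
        # Compares against whichever edge stage was run, per the fallback above.
        .run_method_and_compare_outputs(inputs=test_data)
    )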
stage = stage or self.cur test_stage = self.stages[stage] is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None exported_program = self.stages[self.stage_name(tester.Export)].artifact - edge_program = self.stages[ - self.stage_name(tester.ToEdge) - ].artifact.exported_program() + edge_program = edge_stage.artifact.exported_program() self.runner_util.init_run( exported_program, edge_program, @@ -328,8 +367,10 @@ def get_graph(self, stage: str | None = None) -> Graph: if stage is None: stage = self.cur artifact = self.get_artifact(stage) - if self.cur == self.stage_name(tester.ToEdge) or self.cur == self.stage_name( - Partition + if ( + self.cur == self.stage_name(tester.ToEdge) + or self.cur == self.stage_name(Partition) + or self.cur == self.stage_name(ToEdgeTransformAndLower) ): graph = artifact.exported_program().graph elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name( @@ -357,7 +398,14 @@ def dump_operator_distribution( line = "#" * 10 to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n" - if self.cur == self.stage_name(tester.Partition) and print_table: + if ( + self.cur + in ( + self.stage_name(tester.Partition), + self.stage_name(ToEdgeTransformAndLower), + ) + and print_table + ): graph_module = self.get_artifact().exported_program().graph_module if print_table: delegation_info = get_delegation_info(graph_module) @@ -575,6 +623,9 @@ def _get_tosa_operator_distribution( def _dump_str(to_print: str, path_to_dump: Optional[str] = None): + default_dump_path = get_option(arm_test_options.dump_path) + if not path_to_dump and default_dump_path: + path_to_dump = default_dump_path / f"ArmTester_{current_time_formated()}.log" if path_to_dump: with open(path_to_dump, "a") as fp: fp.write(to_print) diff --git a/backends/arm/third-party/ethos-u-core-driver b/backends/arm/third-party/ethos-u-core-driver index 90f9df900a..78df0006c5 160000 --- a/backends/arm/third-party/ethos-u-core-driver +++ b/backends/arm/third-party/ethos-u-core-driver @@ -1 +1 @@ -Subproject commit 90f9df900acdc0718ecd2dfdc53780664758dec5 +Subproject commit 78df0006c5fa667150d3ee35db7bde1d3f6f58c7 diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index fe408e41b3..19397fe6b2 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -8,21 +8,38 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import NamedTuple, Sequence +from typing import Callable, cast, NamedTuple, Sequence import numpy as np import serializer.tosa_serializer as ts import torch.fx import tosa.Op as TosaOp -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaSerializerTensor from torch.fx import Node + q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -dq_q_ops = [q_op, dq_op] +dq_q_ops = (q_op, dq_op) +passable_ops = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, +] + + +def register_passable_op(op): + """We need to be able 
to add custom ops such as tosa_transpose to the passable_op list after they have been created""" + passable_ops.append(op) class QuantArgs(NamedTuple): @@ -30,6 +47,19 @@ class QuantArgs(NamedTuple): zp: int qmin: int qmax: int + dtype: torch.dtype + + def quantize_value(self, x): + if not isinstance(x, torch.Tensor): + x = torch.Tensor([x]) + return torch.clip( + torch.round(x / self.scale) + self.zp, + self.qmin, + self.qmax, + ).to(self.dtype) + + def dequantize_value(self, qx: int) -> float: + return (qx - self.zp) * self.scale def quantize_value(x, qargs: QuantArgs, dtype=np.int8): @@ -44,81 +74,159 @@ def dequantize_value(qx, qargs: QuantArgs): return (qx - qargs.zp) * qargs.scale -def is_quant_node(node: torch.fx.Node): +def qargs_from_qnode(node: torch.fx.Node): + assert node.target in dq_q_ops, f"Op {node} is not a quant node." - consumer_node_condition = False - if len(list(node.users)) > 0: - consumer_node = list(node.users)[0] + return QuantArgs( + scale=cast(float, node.args[1]), + zp=cast(int, node.args[2]), + qmin=cast(int, node.args[3]), + qmax=cast(int, node.args[4]), + dtype=cast(torch.dtype, node.args[5]), + ) - # For Rank > 2 Linear layers, the quant node is after the view_copy - if ( - node.target == exir_ops.edge.aten.addmm.default - and consumer_node.target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - return True if consumer_consumer_node.target == q_op else False - consumer_node_condition = consumer_node.target == q_op - input_node_condition = False - if len(node.all_input_nodes) > 0: - input = node.all_input_nodes[0] - input_node_condition = input.target in dq_q_ops +def get_neighbour_quant_args( + node: torch.fx.Node, +) -> tuple[list[QuantArgs], list[QuantArgs]]: + user_q_args = [] - return node.target in dq_q_ops or consumer_node_condition or input_node_condition + for user in node.users: + q_args = search_quant_arg_downstream(user) + if q_args: + user_q_args.append(q_args) + input_q_nodes = [] + for input_node in node.all_input_nodes: + q_args = search_quant_arg_upstream(input_node) + if q_args: + input_q_nodes.append(q_args) + return user_q_args, input_q_nodes -def get_quant_node_dtype(node: torch.fx.Node): - # pyre-ignore[16]: Undefined attribute. - if "tosa" in node.target.__name__: - return node.meta["val"].dtype - if node.target in dq_q_ops: - return node.args[5] +def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: + first_q_arg = q_arg_list[0] + for q_arg in q_arg_list: + if q_arg != first_q_arg: + return False + return True - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - consumer_node = list(node.users)[0] - while True: - if consumer_node.target in dq_q_ops: - return consumer_node.args[5] - # Try to move on to the next node - if len(consumer_node.users) == 0: - raise RuntimeError(f"No quantized node found in graph for node {node}") - consumer_node = list(consumer_node.users)[0] +def is_node_quantized(node: torch.fx.Node) -> bool: + if node.target in dq_q_ops: + return True + user_q_args, input_q_args = get_neighbour_quant_args(node) -def is_quant_arg(arg): - consumer_node = list(arg.users)[0] - return consumer_node.target == q_op + # If we did not find any neighbouring quant nodes, we are not quantized. + if len(input_q_args) == 0 and len(user_q_args) == 0: + return False + if node.target in passable_ops: + assert all_q_args_equal( + user_q_args + input_q_args + ), f"Node {node} needs same quantization parameters on all inputs and outputs." 
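The dtype field added to QuantArgs above makes the tuple self-contained, so a value can be quantized and dequantized straight from the recorded parameters. A small worked sketch with made-up int8 parameters:

    import torch

    from executorch.backends.arm.tosa_quant_utils import QuantArgs

    # Hypothetical parameters, purely for illustration.
    qargs = QuantArgs(scale=0.05, zp=0, qmin=-128, qmax=127, dtype=torch.int8)

    q = qargs.quantize_value(1.0)   # round(1.0 / 0.05) + 0 = 20, clipped to [-128, 127]
    x = qargs.dequantize_value(q)   # (20 - 0) * 0.05 = 1.0

    assert q.item() == 20
    assert torch.isclose(x, torch.tensor(1.0)).item()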
-def get_quant_arg_dtype(node: torch.fx.Node): - consumer_node = list(node.users)[0] + return True - # Get type of quant node, args differ from per_tensor and per_channel. - if consumer_node.target == q_op: - if is_quant_arg(node): - return map_dtype(consumer_node.args[5]) - else: - raise RuntimeError("Quantization argument not found") + +def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple consumers is encountered, + find QuantArgs for all consumers and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without consumers is encountered, return None. + """ + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + consumer_nodes = list(node.users) + if len(consumer_nodes) == 0: + return None + elif len(consumer_nodes) == 1: + return search_quant_arg_downstream(consumer_nodes[0]) + else: + consumer_qargs: list[QuantArgs] = [] + for input in consumer_nodes: + quant_args = search_quant_arg_downstream(input) + if quant_args: + consumer_qargs.append(quant_args) + if len(consumer_qargs) == 0: + return None + assert all_q_args_equal( + consumer_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." + return consumer_qargs[0] + + +def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_downstream and asserts that QuantArgs are found, + meaning return value can't be None. + """ + qargs = search_quant_arg_downstream(node) + assert qargs, f"Did not find QuantArgs downstream for node {node}" + return qargs -def get_quant_node_args(node: torch.fx.Node): +def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple inputs is encountered, + find QuantArgs for all inputs and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without inputs is encountered, return None. """ - Get the quantization parameters from a quant node. - Args: - node: The quant node. - Returns: - QuantArgs: scale, zp, qmin, qmax + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + input_nodes = list(node.all_input_nodes) + if len(input_nodes) == 0: + return None + elif len(input_nodes) == 1: + return search_quant_arg_upstream(input_nodes[0]) + else: + input_qargs: list[QuantArgs] = [] + for input in input_nodes: + quant_args = search_quant_arg_upstream(input) + if quant_args: + input_qargs.append(quant_args) + if len(input_qargs) == 0: + return None + assert all_q_args_equal( + input_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." + return input_qargs[0] + + +def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_upstream and asserts that QuantArgs are found, + meaning return value can't be None. 
""" - quant_args = [TosaArg(arg) for arg in node.args] - return QuantArgs( - quant_args[1].number, - quant_args[2].number, - quant_args[3].number, - quant_args[4].number, - ) + qargs = search_quant_arg_upstream(node) + assert qargs, f"Did not find QuantArgs upstream for node {node}" + return qargs + + +def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: + if isinstance(node.target, Callable) and "tosa" in node.target.__name__: + return node.meta["val"].dtype + if node.target in dq_q_ops: + return cast(torch.dtype, node.args[5]) + + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + user_q_args, input_q_args = get_neighbour_quant_args(node) + if len(user_q_args) > 0: + return user_q_args[0].dtype + elif node.target in passable_ops and len(input_q_args) > 0: + return input_q_args[0].dtype + else: + raise RuntimeError("No quantized node found in graph") # Check if scale32 mode is used for given output element type @@ -267,14 +375,14 @@ def rescale_nodes_to_int32( needed by rescale_node_back_to_int8. """ - tensors = [TosaArg(node.args[0]) for node in nodes] + tensors = [TosaArg(node) for node in nodes] # Reshape tensor according to tosa dim order for tensor in tensors: dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = [get_quant_node_args(node) for node in nodes] + qargs = [get_quant_arg_upstream(node) for node in nodes] # Scale the int8 quantized input to a common scale in the integer # domain @@ -307,7 +415,7 @@ def rescale_node_back_to_int8( scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. """ - qargs_out = get_quant_node_args(list(node.users)[0]) + qargs_out = get_quant_arg_downstream(list(node.users)[0]) output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -334,7 +442,7 @@ def build_rescale_conv_output( output_zp, ): # TODO add check to verify if this is a Per-channel quantization. - post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number + post_conv2d_scale = (input_scale * weight_scale) / output_scale # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. build_rescale( @@ -345,6 +453,6 @@ def build_rescale_conv_output( output_type, op.shape, 0, - output_zp.number, + output_zp, ) return diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py new file mode 100644 index 0000000000..716e8daee2 --- /dev/null +++ b/backends/arm/tosa_specification.py @@ -0,0 +1,226 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. +# + +import re +from typing import List + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from packaging.version import Version + + +class TosaSpecification: + """ + This class implements a representation of TOSA specification + (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile + (with extension) and a level (8k). 
+    For 0.80 releases the profile is BI or MI, with u55 handled as an unofficial extension
+    For 1.00 releases the profile is INT or FP, and the extensions are for
+        INT: int16, int4, var, cf
+        FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
+
+    The TOSA specification is encoded in the string representation
+        TOSA-major.minor.patch+profile[+level][+extensions]
+
+    For 0.80 MI implies BI, while for 1.0 the profile has to be explicitly specified.
+
+    Profiles are uppercase letters, while extensions and level are lowercase.
+    """
+
+    version: Version
+
+    def support_integer(self) -> bool:
+        """
+        Returns true if any integer operations are supported for the specification.
+        """
+        raise NotImplementedError
+
+    def support_float(self) -> bool:
+        """
+        Returns true if any float operations are supported for the specification.
+        """
+        raise NotImplementedError
+
+    def __init__(self, version: Version):
+        self.version = version
+
+    @staticmethod
+    def create_from_compilespecs(
+        compile_specs: List[CompileSpec],
+    ) -> "TosaSpecification":
+        """
+        Search the CompileSpec list for 'tosa_version' and instantiate a
+        class from the found value, or raise ValueError on failure.
+        """
+        for spec in compile_specs:
+            if spec.key == "tosa_version":
+                return TosaSpecification.create_from_string(spec.value.decode())
+        raise ValueError(
+            "No TOSA version key found in any of the supplied CompileSpecs"
+        )
+
+    @staticmethod
+    def create_from_string(repr: str) -> "TosaSpecification":
+        """
+        Creates a TOSA specification class from a string representation:
+        TOSA-0.80.0+MI
+        TOSA-0.80.0+BI+8k
+        TOSA-0.80.0+BI+u55   # Ethos-U55 extension to handle TOSA subset
+        TOSA-0.90.0+MI
+        TOSA-1.00.0+INT+FP+int4+cf
+        """
+
+        pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
+        match = re.match(pattern, repr)
+        if match:
+            name = match.group(1)
+            version = Version(match.group(2))
+            extras = match.group(3).split("+")
+            if name != "TOSA":
+                raise ValueError(f"Malformed TOSA specification representation: {repr}")
+            match version:
+                case _ if version.major == 0 and version.minor == 80:
+                    return Tosa_0_80(version, extras)
+                case _ if version.major == 1 and version.minor == 0:
+                    return Tosa_1_00(version, extras)
+                case _:
+                    raise ValueError(f"Wrong TOSA version: {version} from {repr}")
+
+        raise ValueError(f"Failed to parse TOSA specification representation: {repr}")
+
+
+class Tosa_0_80(TosaSpecification):
+    profile: str
+    level_8k: bool
+    is_U55_subset: bool
+    available_profiles = ["BI", "MI"]  # MT is not defined
+
+    def __init__(self, version: Version, extras: List[str]):
+        super().__init__(version)
+        assert version >= Version("0.80") and version < Version("0.90")
+
+        # Check that we only have one profile in the extensions list
+        if [e in Tosa_0_80.available_profiles for e in extras].count(True) != 1:
+            raise ValueError(
+                f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found."
+ ) + + # The list contains one profile at most, so pick it + self.profile = [e for e in extras if e in Tosa_0_80.available_profiles][0] + extras.remove(self.profile) + + self.level_8k = "8k" in extras + if self.level_8k: + extras.remove("8k") + self.is_U55_subset = "u55" in extras + if self.is_U55_subset: + extras.remove("u55") + + if len(extras) > 0: + raise ValueError(f"Unhandled extras found: {extras}") + + def __repr__(self): + extensions = "" + if self.level_8k: + extensions += "+8K" + if self.is_U55_subset: + extensions += "+u55" + return f"TOSA-{str(self.version)}+{self.profile}{extensions}" + + def __hash__(self) -> int: + return hash(str(self.version) + self.profile) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Tosa_0_80): + return (self.version == other.version) and (self.profile == other.profile) + return False + + def support_integer(self): + return True + + def support_float(self): + return self.profile == "MI" + + +class Tosa_1_00(TosaSpecification): + profiles: List[str] + level_8k: bool + extensions: List[str] + + available_profiles = ["INT", "FP"] + valid_extensions = { + "INT": ["int16", "int4", "var", "cf"], + "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], + } + + def __init__(self, version: Version, extras: List[str]): + super().__init__(version) + + # Check that we have at least one profile in the extensions list + if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0: + raise ValueError( + f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}." + ) + + # and not more than number of available profiles + if [e in Tosa_1_00.available_profiles for e in extras].count(True) > len( + Tosa_1_00.available_profiles + ): + raise ValueError( + f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}." 
+        )
+
+        # The list contains at least one profile, so pick them
+        self.profiles = [e for e in extras if e in Tosa_1_00.available_profiles]
+        for p in self.profiles:
+            extras.remove(p)
+
+        self.level_8k = "8k" in extras
+        if self.level_8k:
+            extras.remove("8k")
+
+        combined_extensions = []
+        for p in self.profiles:
+            combined_extensions += Tosa_1_00.valid_extensions[p]
+
+        if not all(e in combined_extensions for e in extras):
+            raise ValueError(
+                f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {extras}"
+            )
+
+        # all the rest of the extras are handled extensions
+        self.extensions = extras
+
+    def _get_profiles_string(self) -> str:
+        return "".join(["+" + p for p in self.profiles])
+
+    def _get_extensions_string(self) -> str:
+        return "".join(["+" + e for e in self.extensions])
+
+    def __repr__(self):
+        return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_extensions_string()}"
+
+    def __hash__(self) -> int:
+        return hash(str(self.version) + self._get_profiles_string())
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Tosa_1_00):
+            return (self.version == other.version) and (
+                self._get_profiles_string() == other._get_profiles_string()
+            )
+        return False
+
+    def support_integer(self):
+        return "INT" in self.profiles
+
+    def support_float(self):
+        return "FP" in self.profiles
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index cfafac1676..b61b27853a 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -16,11 +16,13 @@
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.backends.arm.tosa_quant_utils import (
-    get_quant_node_args,
-    get_quant_node_dtype,
-    is_quant_node,
+    get_quant_arg_downstream,
+    get_quant_arg_upstream,
+    get_quantized_node_output_dtype,
+    is_node_quantized,
     q_op,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir.dialects._ops import ops as exir_ops
 from serializer.tosa_serializer import TosaOp
 from torch.fx import Node
@@ -130,31 +132,6 @@ def get_output_node(node: Node) -> Node:
     return list(node.users)[0]
 
-# Helper function to do broadcasting
-# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_broadcasting
-def broadcast_shapes(shape1, shape2):
-    assert len(shape1) == len(shape2), "broadcast_shapes::shapes must have same ranks"
-
-    need_broadcasting = False
-    for val1, val2 in zip(shape1, shape2):
-        if val1 != val2:
-            need_broadcasting = True
-    if not need_broadcasting:
-        return shape1
-
-    broadcasted_shape = list(shape1)
-    shape2 = list(shape2)
-    for idx, _ in enumerate(broadcasted_shape):
-        if broadcasted_shape[idx] == 1:
-            broadcasted_shape[idx] = shape2[idx]
-        else:
-            assert not (
-                shape2[idx] != 1 and shape2[idx] != broadcasted_shape[idx]
-            ), "broadcast_shapes::broadcast shape mismatch"
-
-    return broadcasted_shape
-
-
 """
 TOSA reshape returns a tensor with the same type/values as the input.
 No data conversion happens during a reshape operation.
""" @@ -165,36 +142,6 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr) -def is_permute_node_before_addmm(node): - return ( - node.target == exir_ops.edge.aten.permute_copy.default - and list(node.users)[0].target == exir_ops.edge.aten.addmm.default - ) - - -def is_bias_node_for_quantized_addmm(node): - consumer_node = list(node.users)[0] - # consumer node is addmm - is_rank2_linear_bias = ( - consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == q_op - ) - - # rank>2 linear layers - # consumer_consumer node is view_copy - is_rank_greater_than_2_linear_bias = False - if ( - consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - is_rank_greater_than_2_linear_bias = ( - list(consumer_consumer_node.users)[0].target == q_op - ) - - return is_rank2_linear_bias or is_rank_greater_than_2_linear_bias - - def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] return ( @@ -237,8 +184,8 @@ def build_avg_pool_2d_common( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp + input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( @@ -290,6 +237,7 @@ def process_call_function( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, node_visitors: Dict[str, NodeVisitor], + tosa_spec: TosaSpecification, ): # Unpack arguments and convert inputs = getNodeArgs(node) @@ -297,14 +245,15 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) + is_quant_node = is_node_quantized(node) + if is_quant_node: + output_dtype = map_dtype(get_quantized_node_output_dtype(node)) + else: + output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( output.name, - ( - tosa_shape(inputs[0].shape, inputs[0].dim_order) - if is_permute_node_before_addmm(node) - else tosa_shape(output.shape, output.dim_order) - ), - map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, + (tosa_shape(output.shape, output.dim_order)), + output_dtype, ) # Visiting each Node @@ -316,10 +265,10 @@ def process_call_function( tosa_graph, inputs, output, - is_quant_node(node), + is_quant_node, ) else: - raise RuntimeError(f"Unknown operator {node.target}") + raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") def expand_dims( diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index 4ffb80c2f0..b348f10722 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -7,7 +7,7 @@ import os import tempfile import zipfile -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple import torch @@ -32,7 +32,7 @@ def __init__( else: self.tosa_output_path = None - def get_model_error(self) -> Union[float, float, float, float]: + def get_model_error(self) -> tuple[float, float, float, float]: """ Returns the following metrics between the outputs of the FP32 and INT8 model: - Maximum error @@ -51,7 +51,12 @@ def get_model_error(self) -> Union[float, float, float, float]: max_percentage_error = 
torch.max(percentage_error).item() mean_absolute_error = torch.mean(torch.abs(difference).float()).item() - return max_error, max_absolute_error, max_percentage_error, mean_absolute_error + return ( + float(max_error), + float(max_absolute_error), + float(max_percentage_error), + float(mean_absolute_error), + ) def get_compression_ratio(self) -> float: """Compute the compression ratio of the outputted TOSA flatbuffer.""" @@ -67,7 +72,7 @@ def get_compression_ratio(self) -> float: return compression_ratio - def evaluate(self) -> dict[any]: + def evaluate(self) -> dict[str, Any]: max_error, max_absolute_error, max_percent_error, mean_absolute_error = ( self.get_model_error() ) @@ -82,6 +87,8 @@ def evaluate(self) -> dict[any]: } if self.tosa_output_path: + # We know output_metrics["metrics"] is list since we just defined it, safe to ignore. + # pyre-ignore[16] output_metrics["metrics"][ "compression_ratio" ] = self.get_compression_ratio() diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 8456c50f6c..74deed0628 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -11,6 +11,7 @@ load( "CXX", ) load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") oncall("odai_jarvis") @@ -62,6 +63,21 @@ python_library( ], ) +python_library( + name = "pass_utils", + srcs = [ + "pass_utils.py", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/exir/passes:lib", + "//executorch/exir/passes:spec_prop_pass", + ], +) + python_library( name = "ops_registrations", srcs = [ @@ -88,3 +104,15 @@ executorch_generated_lib( "//executorch/kernels/portable:operators", ], ) + +python_unittest( + name = "test_pass_filter", + srcs = [ + "tests/test_pass_filter.py", + ], + typing = True, + deps = [ + ":pass_utils", + "//executorch/exir:pass_base", + ], +) diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 84c07be78c..52390e1918 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -62,6 +62,11 @@ - arg_meta: null kernel_name: torch::executor::full_out +- op: mean.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::mean_dim_out + - op: mul.out kernels: - arg_meta: null @@ -105,7 +110,7 @@ - op: where.self_out kernels: - arg_meta: null - kernel_name: torch::executor::where_out + kernel_name: cadence::impl::HiFi::where_out # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index d47ea3f21a..5e852b369d 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -50,7 +50,11 @@ "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) 
out) -> Tensor(a!)" + "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " + "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" ) lib.define( @@ -66,6 +70,12 @@ lib.define( "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)" +) +lib.define( + "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" @@ -123,6 +133,28 @@ def quantized_linear_meta( return src.new_empty(out_size, dtype=src.dtype) +@register_fake("cadence::quantized_linear.per_tensor") +def quantized_linear_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: torch.SymInt, + weight_zero_point: torch.SymInt, + out_multiplier: torch.SymInt, + out_shift: torch.SymInt, + out_zero_point: torch.SymInt, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + @register_fake("cadence::quantized_conv") def quantized_conv_meta( input: torch.Tensor, @@ -171,6 +203,54 @@ def quantized_conv_meta( return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_conv.per_tensor") +def quantized_conv_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, + channel_last: bool = False, +) -> torch.Tensor: + if channel_last: + out_channels, *kernel_size, _ = weight.shape + else: + out_channels, _, *kernel_size = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + 
stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            channel_last,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py
new file mode 100644
index 0000000000..12a2f62238
--- /dev/null
+++ b/backends/cadence/aot/pass_utils.py
@@ -0,0 +1,91 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Set, Union
+
+import torch
+from executorch.backends.cadence.aot.utils import get_edge_overload_packet
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
+
+from executorch.exir.pass_base import ExportPass
+from torch._ops import OpOverloadPacket
+
+
+# Is an overlap in tensor lifetime and storage allowed at the current opt level?
+# We allow overlap at opt level >= 2.
+def allow_lifetime_and_storage_overlap(opt_level: int) -> bool:
+    return opt_level >= 2
+
+
+# A dataclass that stores the attributes of an ExportPass.
+@dataclass
+class CadencePassAttribute:
+    opt_level: Optional[int] = None
+    debug_pass: bool = False
+
+
+# A dictionary that maps an ExportPass to its attributes.
+ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
+
+
+def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
+    return ALL_CADENCE_PASSES[p]
+
+
+# A decorator that registers a pass.
+def register_cadence_pass(
+    pass_attribute: CadencePassAttribute,
+) -> Callable[[ExportPass], ExportPass]:
+    def wrapper(cls: ExportPass) -> ExportPass:
+        ALL_CADENCE_PASSES[cls] = pass_attribute
+        return cls
+
+    return wrapper
+
+
+def get_all_available_cadence_passes() -> Set[ExportPass]:
+    return set(ALL_CADENCE_PASSES.keys())
+
+
+# Create a new filter to filter out relevant passes from all Jarvis passes.
+def create_cadence_pass_filter(
+    opt_level: int, debug: bool = False
+) -> Callable[[ExportPass], bool]:
+    def _filter(p: ExportPass) -> bool:
+        pass_attribute = get_cadence_pass_attribute(p)
+        return (
+            pass_attribute.opt_level is not None
+            and pass_attribute.opt_level <= opt_level
+            and (not pass_attribute.debug_pass or debug)
+        )
+
+    return _filter
+
+
+# Return the overload packet for the edge or torch op.
+def get_overload_packet(
+    op: Union[Callable[..., str], str],
+) -> Union[OpOverloadPacket, EdgeOpOverloadPacket, None]:
+    return (
+        get_edge_overload_packet(op)
+        if isinstance(op, EdgeOpOverload)
+        else getattr(op, "overloadpacket", None)
+    )
+
+
+# Get the list of node names in a graph module (only for "call_function" ops and
+# EdgeOpOverload targets). This should be used only after to_edge is called.
+def get_node_names_list_from_gm(
+    graph_module: torch.fx.GraphModule,
+) -> list[str]:
+    graph_nodes = []
+    for node in graph_module.graph.nodes:
+        if node.op != "call_function":
+            continue
+        if not isinstance(node.target, EdgeOpOverload):
+            continue
+        graph_nodes.append(node.name)
+    return graph_nodes
diff --git a/backends/cadence/aot/tests/test_pass_filter.py b/backends/cadence/aot/tests/test_pass_filter.py
new file mode 100644
index 0000000000..7b49ef5c32
--- /dev/null
+++ b/backends/cadence/aot/tests/test_pass_filter.py
@@ -0,0 +1,160 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-unsafe
+
+
+import unittest
+
+from copy import deepcopy
+
+from executorch.backends.cadence.aot import pass_utils
+from executorch.backends.cadence.aot.pass_utils import (
+    ALL_CADENCE_PASSES,
+    CadencePassAttribute,
+    create_cadence_pass_filter,
+    register_cadence_pass,
+)
+
+from executorch.exir.pass_base import ExportPass
+
+
+class TestBase(unittest.TestCase):
+    def setUp(self):
+        # Before running each test, create a copy of _all_passes to later restore it after test.
+        # This avoids messing up the original _all_passes when running tests.
+        self._all_passes_original = deepcopy(ALL_CADENCE_PASSES)
+        # Clear _all_passes to do a clean test. It'll be restored after each test in tearDown().
+        pass_utils.ALL_CADENCE_PASSES.clear()
+
+    def tearDown(self):
+        # Restore _all_passes to original state before test.
+        pass_utils.ALL_CADENCE_PASSES = self._all_passes_original
+
+    def get_filtered_passes(self, filter_):
+        return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)}
+
+
+# Test pass registration
+class TestPassRegistration(TestBase):
+    def test_register_cadence_pass(self):
+        pass_attr_O0 = CadencePassAttribute(opt_level=0)
+        pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True)
+        pass_attr_O1_all_backends = CadencePassAttribute(
+            opt_level=1,
+        )
+
+        # Register 1st pass with opt_level=0
+        @register_cadence_pass(pass_attr_O0)
+        class DummyPass_O0(ExportPass):
+            pass
+
+        # Register 2nd pass with opt_level=1, all backends.
+        @register_cadence_pass(pass_attr_O1_all_backends)
+        class DummyPass_O1_All_Backends(ExportPass):
+            pass
+
+        # Register 3rd pass with opt_level=None, debug=True
+        @register_cadence_pass(pass_attr_debug)
+        class DummyPass_Debug(ExportPass):
+            pass
+
+        # Check if the three passes are indeed added into _all_passes
+        expected_all_passes = {
+            DummyPass_O0: pass_attr_O0,
+            DummyPass_Debug: pass_attr_debug,
+            DummyPass_O1_All_Backends: pass_attr_O1_all_backends,
+        }
+        self.assertEqual(pass_utils.ALL_CADENCE_PASSES, expected_all_passes)
+
+
+# Test pass filtering
+class TestPassFiltering(TestBase):
+    def test_filter_none(self):
+        pass_attr_O0 = CadencePassAttribute(opt_level=0)
+        pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
+        pass_attr_O1_all_backends = CadencePassAttribute(
+            opt_level=1,
+        )
+
+        @register_cadence_pass(pass_attr_O0)
+        class DummyPass_O0(ExportPass):
+            pass
+
+        @register_cadence_pass(pass_attr_O1_debug)
+        class DummyPass_O1_Debug(ExportPass):
+            pass
+
+        @register_cadence_pass(pass_attr_O1_all_backends)
+        class DummyPass_O1_All_Backends(ExportPass):
+            pass
+
+        O1_filter = create_cadence_pass_filter(opt_level=1, debug=True)
+        O1_filter_passes = self.get_filtered_passes(O1_filter)
+
+        # Assert that no passes are filtered out.
+        expected_passes = {
+            DummyPass_O0: pass_attr_O0,
+            DummyPass_O1_Debug: pass_attr_O1_debug,
+            DummyPass_O1_All_Backends: pass_attr_O1_all_backends,
+        }
+        self.assertEqual(O1_filter_passes, expected_passes)
+
+    def test_filter_debug(self):
+        pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
+        pass_attr_O2 = CadencePassAttribute(opt_level=2)
+
+        @register_cadence_pass(pass_attr_O1_debug)
+        class DummyPass_O1_Debug(ExportPass):
+            pass
+
+        @register_cadence_pass(pass_attr_O2)
+        class DummyPass_O2(ExportPass):
+            pass
+
+        debug_filter = create_cadence_pass_filter(opt_level=2, debug=False)
+        debug_filter_passes = self.get_filtered_passes(debug_filter)
+
+        # Assert that debug passes are filtered out, since the filter explicitly
+        # chooses debug=False.
+        self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2})
+
+    def test_filter_all(self):
+        @register_cadence_pass(CadencePassAttribute(opt_level=1))
+        class DummyPass_O1(ExportPass):
+            pass
+
+        @register_cadence_pass(CadencePassAttribute(opt_level=2))
+        class DummyPass_O2(ExportPass):
+            pass
+
+        debug_filter = create_cadence_pass_filter(opt_level=0)
+        debug_filter_passes = self.get_filtered_passes(debug_filter)
+
+        # Assert that all the passes are filtered out, since the filter only selects
+        # passes with opt_level <= 0
+        self.assertEqual(debug_filter_passes, {})
+
+    def test_filter_opt_level_None(self):
+        pass_attr_O1 = CadencePassAttribute(opt_level=1)
+        pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True)
+
+        @register_cadence_pass(CadencePassAttribute(opt_level=None))
+        class DummyPass_None(ExportPass):
+            pass
+
+        @register_cadence_pass(pass_attr_O1)
+        class DummyPass_O1(ExportPass):
+            pass
+
+        @register_cadence_pass(pass_attr_O2_debug)
+        class DummyPass_O2_Debug(ExportPass):
+            pass
+
+        O2_filter = create_cadence_pass_filter(opt_level=2, debug=True)
+        filtered_passes = self.get_filtered_passes(O2_filter)
+        # Passes with opt_level=None should never be retained.
+        expected_passes = {
+            DummyPass_O1: pass_attr_O1,
+            DummyPass_O2_Debug: pass_attr_O2_debug,
+        }
+        self.assertEqual(filtered_passes, expected_passes)
diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt
index 8fee7e8536..9321cc544e 100644
--- a/backends/cadence/hifi/kernels/CMakeLists.txt
+++ b/backends/cadence/hifi/kernels/CMakeLists.txt
@@ -13,6 +13,8 @@ add_library(
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c
 )
 # Let files say "include <executorch/path/to/header.h>".
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
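
The Cadence pass registry introduced in pass_utils.py above is decorator-driven. Below is a minimal usage sketch; the DummyReorderPass name is hypothetical and only register_cadence_pass, CadencePassAttribute, create_cadence_pass_filter, and ALL_CADENCE_PASSES come from this diff:

from executorch.backends.cadence.aot.pass_utils import (
    ALL_CADENCE_PASSES,
    CadencePassAttribute,
    create_cadence_pass_filter,
    register_cadence_pass,
)
from executorch.exir.pass_base import ExportPass


# Registering a pass records it in ALL_CADENCE_PASSES together with its attributes.
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class DummyReorderPass(ExportPass):  # hypothetical pass, for illustration only
    pass


# Build a filter that keeps passes whose opt_level is set and <= 2, dropping debug passes.
keep = create_cadence_pass_filter(opt_level=2, debug=False)
selected = [p for p in ALL_CADENCE_PASSES if keep(p)]
assert DummyReorderPass in selected

This mirrors what the tests above assert: debug passes are dropped unless debug=True, and passes whose opt_level is None are never retained.
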
diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 10e5fb176e..1b335c846b 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -165,6 +165,7 @@ void requantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); #undef typed_quantize_val #define typed_quantize_vec(dtype) \ @@ -177,6 +178,7 @@ typed_quantize_val(int16_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -186,6 +188,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); #undef typed_dequantize_val #define typed_dequantize_vec(dtype) \ @@ -198,6 +201,7 @@ typed_dequantize_val(int16_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 70d5e39fad..2c915661f8 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -55,6 +55,34 @@ extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32( const FLOAT32* __restrict__ p_inp2, const WORD32* const p_inp2_shape); +extern "C" WORD32 xa_nn_elm_where_f32xf32_f32( + FLOAT32* __restrict__ p_out, + const FLOAT32* __restrict__ p_inp1, + const FLOAT32* __restrict__ p_inp2, + const unsigned char* __restrict__ p_condition, + WORD32 num_elm); + +extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32( + FLOAT32* __restrict__ p_out, + const WORD32* const p_out_shape, + const FLOAT32* __restrict__ p_inp1, + const WORD32* const p_inp1_shape, + const FLOAT32* __restrict__ p_inp2, + const WORD32* const p_inp2_shape, + const unsigned char* __restrict__ p_condition, + const WORD32* const p_condition_shape); + +extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32( + FLOAT32* __restrict__ p_out, + const WORD32* const p_out_shape, + const FLOAT32* __restrict__ p_inp, + const WORD32* const p_inp_shape, + const WORD32* __restrict__ p_axis, + WORD32 num_out_dims, + WORD32 num_inp_dims, + WORD32 num_axis_dims, + void* __restrict__ p_scratch_in); + namespace cadence { namespace impl { namespace HiFi { diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt index cbbb279e5d..dbe5867550 100644 --- a/backends/cadence/hifi/operators/CMakeLists.txt +++ b/backends/cadence/hifi/operators/CMakeLists.txt @@ -22,19 +22,12 @@ endif() set(_aten_ops__srcs "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp" + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" - 
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" @@ -57,6 +50,7 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" ) add_library(aten_ops_cadence ${_aten_ops__srcs}) target_link_libraries(aten_ops_cadence PUBLIC executorch) diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index 18381a26e0..996d753c59 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -41,6 +41,9 @@ void dequantize_per_tensor_out( } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); diff --git a/backends/cadence/hifi/operators/op_mean.cpp b/backends/cadence/hifi/operators/op_mean.cpp new file mode 100644 index 0000000000..478e10da71 --- /dev/null +++ b/backends/cadence/hifi/operators/op_mean.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include + +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::aten::RuntimeContext; +using executorch::runtime::ArrayRef; +using torch::executor::Error; +using torch::executor::optional; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +int prepare_data( + const Tensor& in, + Tensor& out, + optional> dim_list, + int* inp_shape, + int* out_shape, + int* p_axis, + int num_inp_dims, + int num_out_dims) { + for (int i = 0; i < num_inp_dims; i++) { + inp_shape[i] = in.size(i); + } + + for (int i = 0; i < num_out_dims; i++) { + out_shape[i] = out.size(i); + } + + int num_axis_dims = 0; + for (const auto& d : dim_list.value()) { + if (d < 0) { + p_axis[num_axis_dims] = num_inp_dims + d; + num_axis_dims++; + } else { + p_axis[num_axis_dims] = d; + num_axis_dims++; + } + } + + return num_axis_dims; +} + +Tensor& mean_dim_out( + RuntimeContext& ctx, + const Tensor& in, + optional> dim_list, + bool keepdim, + optional dtype, + Tensor& out) { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_reduction_out(in, dim_list, keepdim, out) == + Error::Ok, + InvalidArgument, + out); + + constexpr auto name = "mean.out"; + constexpr int kNnlibMaxDim = 4; + + bool optimized = 1; + + if (out.scalar_type() != ScalarType::Float) + optimized = 0; + + if (in.dim() > kNnlibMaxDim) + optimized = 0; + + if (optimized) { + float* __restrict__ p_out = out.mutable_data_ptr(); + const float* __restrict__ p_inp = + (const float* __restrict__)in.const_data_ptr(); + + int num_elm = in.numel(); + + int num_inp_dims = in.dim(); + int num_out_dims = out.dim(); + + int inp_shape[kNnlibMaxDim]; + int out_shape[kNnlibMaxDim]; + int p_axis[kNnlibMaxDim]; + + for (int i = 0; i < kNnlibMaxDim; i++) { + out_shape[i] = 1; + inp_shape[i] = 1; + p_axis[i] = 1; + } + + int num_axis_dims = prepare_data( + in, + out, + dim_list, + inp_shape, + out_shape, + p_axis, + num_inp_dims, + num_out_dims); + + if (num_axis_dims == num_inp_dims) { + num_out_dims = 1; + out_shape[0] = 1; + } + + int scratch_size = xa_nn_reduce_getsize_nhwc( + -3, inp_shape, num_inp_dims, p_axis, num_axis_dims, 1); + + void* __restrict__ p_scratch_in = (void* __restrict__)malloc(scratch_size); + + xa_nn_reduce_mean_4D_f32_f32( + p_out, + out_shape, + p_inp, + inp_shape, + p_axis, + num_out_dims, + num_inp_dims, + num_axis_dims, + p_scratch_in); + + return out; + } + + ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] { + ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] { + CTYPE_OUT* out_data = out.mutable_data_ptr(); + const size_t num = torch::executor::get_reduced_dim_product(in, dim_list); + + for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { + CTYPE_OUT sum = 0; + if (in.numel() > 0) { + sum = torch::executor::map_reduce_over_dim_list( + [](CTYPE_IN v) { return static_cast(v); }, + [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; }, + in, + dim_list, + out_ix); + } + out_data[out_ix] = sum / static_cast(num); + } + }); + }); + + return out; +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_where.cpp b/backends/cadence/hifi/operators/op_where.cpp new file mode 100644 index 0000000000..06bd0bc3c9 --- /dev/null +++ b/backends/cadence/hifi/operators/op_where.cpp @@ -0,0 +1,176 @@ +/* + * Copyright 
(c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::aten::RuntimeContext; +using torch::executor::Error; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +Tensor& where_out( + RuntimeContext& ctx, + const Tensor& cond, + const Tensor& a, + const Tensor& b, + Tensor& out) { + ScalarType cond_type = cond.scalar_type(); + ScalarType a_type = a.scalar_type(); + ScalarType b_type = b.scalar_type(); + ScalarType common_type = executorch::runtime::promoteTypes(a_type, b_type); + ScalarType out_type = out.scalar_type(); + + ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + + // Determine output size and resize for dynamic shapes + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, cond, out) == + Error::Ok, + InvalidArgument, + out); + + constexpr int kNnlibMaxDim = 4; /*fallback if broadcast and dim > 4 */ + constexpr auto name = "where.self_out"; + + ET_CHECK_MSG( + cond_type == ScalarType::Bool || cond_type == ScalarType::Byte, + "Unhandled dtype %s for where.self_out", + torch::executor::toString(cond_type)); + + int a_dim = a.dim(), b_dim = b.dim(), con_dim = cond.dim(), + out_dim = out.dim(); + bool optimized = 1; + /*find broadcast*/ + const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); + const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); + const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes()); + const bool broadcast = + (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted); + + int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); + max_dim = cond.dim() > max_dim ? cond.dim() : max_dim; + max_dim = out.dim() > max_dim ? 
out.dim() : max_dim; + + if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float)) + optimized = 0; + + if ((a_dim == 0) || (b_dim == 0) || (con_dim == 0)) + optimized = 0; + + if ((broadcast == 1) && (max_dim > kNnlibMaxDim)) + optimized = 0; + + if (optimized) { + const float* a_data = a.const_data_ptr(); + const float* b_data = b.const_data_ptr(); + float* out_data = out.mutable_data_ptr(); + const unsigned char* con = cond.const_data_ptr(); + + if (broadcast == 1) { + int out_shape[kNnlibMaxDim]; + int inp1_shape[kNnlibMaxDim]; + int inp2_shape[kNnlibMaxDim]; + int con_shape[kNnlibMaxDim]; + + for (int i = 0; i < kNnlibMaxDim; i++) { + con_shape[i] = 1; + out_shape[i] = 1; + inp1_shape[i] = 1; + inp2_shape[i] = 1; + } + + int off_o = kNnlibMaxDim - out.dim(); + int off_a = kNnlibMaxDim - a.dim(); + int off_b = kNnlibMaxDim - b.dim(); + int off_c = kNnlibMaxDim - cond.dim(); + + for (int i = 0; i < out.dim(); i++) + out_shape[i + off_o] = out.size(i); + for (int i = 0; i < a.dim(); i++) + inp1_shape[i + off_a] = a.size(i); + for (int i = 0; i < b.dim(); i++) + inp2_shape[i + off_b] = b.size(i); + for (int i = 0; i < cond.dim(); i++) + con_shape[i + off_c] = cond.size(i); + + if (con_shape[0] != out_shape[0] || con_shape[1] != out_shape[1] || + con_shape[2] != out_shape[2] || con_shape[3] != out_shape[3]) { + void* p_scratch = + malloc(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]); + const unsigned char* p_brd_cond = (const unsigned char*)p_scratch; + xa_nn_broadcast_8_8( + (WORD8* __restrict__)p_brd_cond, + out_shape, + (const WORD8* __restrict__)con, + con_shape, + 4); + + for (int i = 0; i < 4; i++) { + con_shape[i] = out_shape[i]; + } + xa_nn_elm_where_broadcast_4D_f32xf32_f32( + out_data, + out_shape, + a_data, + inp1_shape, + b_data, + inp2_shape, + p_brd_cond, + con_shape); + free(p_scratch); + } else { + xa_nn_elm_where_broadcast_4D_f32xf32_f32( + out_data, + out_shape, + a_data, + inp1_shape, + b_data, + inp2_shape, + con, + con_shape); + } + } else { + xa_nn_elm_where_f32xf32_f32(out_data, a_data, b_data, con, out.numel()); + } + return out; + } + ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() { + using CTYPE_OUT = + typename torch::executor::promote_types::type; + torch::executor:: + apply_ternary_elementwise_fn( + [](const CTYPE_A val_a, + const CTYPE_B val_b, + const uint8_t val_c) { + CTYPE_OUT a_casted = static_cast(val_a); + CTYPE_OUT b_casted = static_cast(val_b); + return val_c ? a_casted : b_casted; + }, + a, + b, + cond, + out); + }); + }); + return out; +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index c65d62968f..1078b5716c 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -44,6 +44,10 @@ void quantize_per_tensor_out( int16_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + cadence::impl::HiFi::kernels::quantize( + out_data, input_data, 1. 
/ scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c new file mode 100644 index 0000000000..6a7f6d0f77 --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c @@ -0,0 +1,838 @@ +/******************************************************************************* +* Copyright (c) 2018-2024 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +******************************************************************************/ +#include "xa_type_def.h" +#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_common_fpu.h" +#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nn_common.h" +#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h" +#include "nnlib-hifi4/xa_nnlib/algo/kernels/basic/hifi4/xa_nn_basic_state.h" +#include "xa_nnlib_kernels_api.h" + + +#if !HAVE_VFPU +DISCARD_FUN_FOR_NONVOID_RETURN( + WORD32, xa_nn_elm_where_f32xf32_f32, + ( + FLOAT32 *p_out, + const FLOAT32 *p_inp1, + const FLOAT32 *p_inp2, + const unsigned char *__restrict__ condition, + WORD32 num_elm + ) + ) +#else +WORD32 xa_nn_elm_where_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp1, + const FLOAT32 * __restrict__ p_inp2, + const unsigned char *__restrict__ p_condition, + WORD32 num_elm) +{ + + /* NULL pointer checks */ + XA_NNLIB_ARG_CHK_PTR(p_out, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp1, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp2, -1); + /* Pointer alignment checks */ + XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp1, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1); + /* Basic Parameter checks */ + XA_NNLIB_ARG_CHK_COND((num_elm <= 0), -1); + + int i; + xtfloatx2 *inp1 = (xtfloatx2 *)p_inp1; + xtfloatx2 *inp2 = (xtfloatx2 *)p_inp2; + xtfloatx2 *out = (xtfloatx2 *)p_out; + unsigned char *condition = p_condition; + xtfloatx2 x1, x2, y; + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + if(((((unsigned)p_out)&7) == 0) && ((((unsigned)p_inp1)&7) == 0) && ((((unsigned)p_inp2)&7) == 0)) + { + for(i=0;i < num_elm>>1;i++) + { + XT_LSX2IP(x1, inp1, 2*sizeof(FLOAT32)); + XT_LSX2IP(x2, inp2, 2*sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SSX2IP( y, out, 
2*sizeof(FLOAT32)); + } + } + else + { + ae_valign inp1_a, inp2_a, out_a; + + inp1_a = XT_LASX2PP(inp1); + inp2_a = XT_LASX2PP(inp2); + out_a = AE_ZALIGN64(); + /* Each iteration of loop is independent so safe to use concurrent pragma */ +#pragma concurrent + for(i=0;i < num_elm>>1;i++) + { + XT_LASX2IP(x1, inp1_a, inp1); + XT_LASX2IP(x2, inp2_a, inp2); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SASX2IP(y, out_a, out); + } + XT_SASX2POSFP(out_a, out); + } + // Remainder Loop + if (num_elm & 1) + { + xtfloat a1, a2, a; + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_LSIP(a1, (xtfloat *)inp1, 0); + XT_LSIP(a2, (xtfloat *)inp2, 0); + XT_MOVT_S(a, a1, s); + XT_MOVF_S(a, a2, s); + XT_SSI(a, (xtfloat *)out, 0); + } +} + +static void internal_elm_where_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp1, + const FLOAT32 * __restrict__ p_inp2, + const unsigned char * __restrict__ p_condition, + WORD32 num_elm, + xtbool sign_flag) +{ + int i; + xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1; + xtfloatx2 * __restrict__ p_b = (xtfloatx2 *)p_inp2; + xtfloatx2 *__restrict__ p_c = (xtfloatx2 *)p_out; + unsigned char *condition = p_condition; + + const int num_simd2_ops = num_elm >> 1; + const int num_scalar_ops = num_elm & 1; + + xtfloat a0_7, out; + xtfloatx2 x1, x2, y; + x2 = XT_LSI((xtfloat *)p_b, 0); + + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + /* For out = condition ? inp2 :inp1 */ + if(sign_flag){ + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(i=0; i> 1; + const int num_scalar_ops = num_elm & 1; + + xtfloat a0_7, out; + xtfloatx2 x1, x2, y; + x2 = XT_LSI((xtfloat *)p_b, 0); + x1 = XT_LSI((xtfloat *)p_a, 0); + + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + if((((unsigned)p_c)&7) == 0) + { + for(i=0; i> 1; + num_scalar_ops = in_lc & 1; + } + else + { + num_simd2_ops = (in_lc >> 2) << 1; + num_scalar_ops = in_lc & 3; + } + + xtfloatx2 x1, x2, y; + xtfloat a0, b0, c0; + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + /* For out = condition ? 
inp2 :inp1 */ + if(sign_flag){ + for(i = 0; i < out_lc; i++) + { + p_a = (xtfloatx2 *)&p_inp1[i * in_lc]; + p_b = (xtfloatx2 *)p_inp2; + p_c = (xtfloatx2 *)&p_out[i * in_lc]; + condition = &p_condition[i * in_lc]; + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x2, con); + XT_MOVF_SX2 (y, x1, con); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { + ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); + vinp1 = XT_LASX2PP(p_a); + vinp2 = XT_LASX2PP(p_b); + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x2, con); + XT_MOVF_SX2 (y, x1, con); + XT_SASX2IP(y, out_a, p_c); + } + XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); + } + if(num_scalar_ops !=0) + { + XT_LSIP(a0, (xtfloat *)p_a, 0); + XT_LSIP(b0, (xtfloat *)p_b, 0); + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_MOVT_S(c0, b0, s); + XT_MOVF_S(c0, a0, s); + XT_SSI(c0, (xtfloat *)p_c, 0); + } + } + } + /* For out = condition ? inp1 :inp2 */ + else + { + for(i = 0; i < out_lc; i++) + { + p_a = (xtfloatx2 *)&p_inp1[i * in_lc]; + p_b = (xtfloatx2 *)p_inp2; + p_c = (xtfloatx2 *)&p_out[i * in_lc]; + condition = &p_condition[i * in_lc]; + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { + ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); + vinp1 = XT_LASX2PP(p_a); + vinp2 = XT_LASX2PP(p_b); + + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SASX2IP(y, out_a, p_c); + } + XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); + } + if(num_scalar_ops !=0) + { + XT_LSIP(a0, (xtfloat *)p_a, 0); + XT_LSIP(b0, (xtfloat *)p_b, 0); + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_MOVT_S(c0, a0, s); + XT_MOVF_S(c0, b0, s); + XT_SSI(c0, (xtfloat *)p_c, 0); + } + } + } +} + +static void internal_elm_where_broadcast_both_2D_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp1, + const FLOAT32 * __restrict__ p_inp2, + const unsigned char * __restrict__ p_condition, + WORD32 out_lc, + WORD32 in_lc) +{ + int i, j; + + xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1; + xtfloatx2 * __restrict__ p_b = (xtfloatx2 *)p_inp2; + xtfloatx2 *__restrict__ p_c = (xtfloatx2 *)p_out; + unsigned char *condition = p_condition; + + int num_simd2_ops; + int num_scalar_ops; + + if(out_lc) + { + num_simd2_ops = in_lc >> 1; + num_scalar_ops = in_lc & 1; + } + else + { + num_simd2_ops = (in_lc >> 2) << 1; + num_scalar_ops = in_lc & 3; + } + + xtfloatx2 x1, x2, y; + xtfloat a0, b0, c0; 
+ unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + for(i = 0; i < out_lc; i++) + { + p_a = (xtfloatx2 *)p_inp1; + p_b = (xtfloatx2 *)p_inp2; + p_c = (xtfloatx2 *)&p_out[i * in_lc]; + condition = &p_condition[i * in_lc]; + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { + ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); + vinp1 = XT_LASX2PP(p_a); + vinp2 = XT_LASX2PP(p_b); + + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SASX2IP(y, out_a, p_c); + } + XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); + } + if(num_scalar_ops !=0) + { + XT_LSIP(a0, (xtfloat *)p_a, 0); + XT_LSIP(b0, (xtfloat *)p_b, 0); + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_MOVT_S(c0, a0, s); + XT_MOVF_S(c0, b0, s); + XT_SSI(c0, (xtfloat *)p_c, 0); + } + } +} + +WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const WORD32 *const p_out_shape, + const FLOAT32 * __restrict__ p_inp1, + const WORD32 *const p_inp1_shape, + const FLOAT32 * __restrict__ p_inp2, + const WORD32 *const p_inp2_shape, + const unsigned char *__restrict__ p_condition, + const WORD32 *const p_condition_shape + ) +{ + /* NULL pointer checks */ + XA_NNLIB_ARG_CHK_PTR(p_out, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp1, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp2, -1); + XA_NNLIB_ARG_CHK_PTR(p_condition, -1); + XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp1_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp2_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_condition_shape, -1); + /* Pointer alignment checks */ + XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp1, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_condition, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp1_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp2_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_condition_shape, sizeof(WORD32), -1); + + /* Check shapes */ + int i; + xtbool sign_flag; + for(i = 0; i < 4; i++) + { + if((p_inp1_shape[i] != p_inp2_shape[i]) && ((p_inp1_shape[i] != 1) && (p_inp2_shape[i] != 1))) + { + return -1; + } + } + WORD32 inp1_strides[4], inp2_strides[4]; + inp1_strides[3] = 1; + inp2_strides[3] = 1; + for(i = 2; i >= 0; i--) + { + ae_int32x2 d_str, d_shape; + d_str = AE_MOVDA32X2(inp1_strides[i + 1], inp2_strides[i + 1]); + d_shape = AE_MOVDA32X2(p_inp1_shape[i + 1], p_inp2_shape[i + 1]); + d_str = AE_MULP32X2(d_str, d_shape); + inp1_strides[i] = AE_MOVAD32_H(d_str); + inp2_strides[i] = AE_MOVAD32_L(d_str); + } + + int need_broadcast = 0; + int inp1_const = 1, inp2_const = 1; + for(i = 0; i < 4; i++) + { + if(p_inp1_shape[i] == 1) + { + inp1_strides[i] = 0; + need_broadcast = 1; + } + else + { + inp1_const &= 0; + } + if(p_inp2_shape[i] == 1) + { + inp2_strides[i] = 0; + need_broadcast = 1; + } + else + 
{ + inp2_const &= 0; + } + } + + int itr0, itr1, itr2; + FLOAT32 *p_out_tmp = p_out; + const unsigned char *__restrict p_condition_temp = p_condition; + const FLOAT32 *__restrict__ p_inp1_tmp = p_inp1; + const FLOAT32 *__restrict__ p_inp2_tmp = p_inp2; + + if(need_broadcast == 0) + { + sign_flag = 0; + internal_elm_where_broadcast_2D_f32xf32_f32( + p_out, + p_inp1, + p_inp2, + p_condition, + 1, + p_out_shape[0] * inp1_strides[0], + sign_flag); + } + else if((inp1_strides[3] == 1)&& (inp2_strides[3] == 1)) + { + WORD32 in_lc, out_lc; + sign_flag = 0; + in_lc = p_out_shape[2] * p_out_shape[3]; + out_lc = 1; + if((inp1_strides[2] == 0) && (inp2_strides[2] == 0)) + { + in_lc = p_out_shape[3]; + out_lc = p_out_shape[2]; + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + internal_elm_where_broadcast_both_2D_f32xf32_f32( + p_out_tmp, + p_inp1_tmp0, + p_inp2_tmp0, + p_condition_temp, + out_lc, + in_lc); + p_out_tmp += in_lc * out_lc; + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + p_condition_temp += in_lc * out_lc; + } + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + else + { + if(inp1_strides[2] == 0) + { + const FLOAT32 *tmp; + tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp; + sign_flag = 1; + int tmp_strides[2]; + tmp_strides[0] = inp1_strides[0]; + tmp_strides[1] = inp1_strides[1]; + + inp1_strides[0] = inp2_strides[0]; + inp1_strides[1] = inp2_strides[1]; + + inp2_strides[0] = tmp_strides[0]; + inp2_strides[1] = tmp_strides[1]; + in_lc = p_out_shape[3]; + out_lc = p_out_shape[2]; + } + else if(inp2_strides[2] == 0) + { + in_lc = p_out_shape[3]; + out_lc = p_out_shape[2]; + } + + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + internal_elm_where_broadcast_2D_f32xf32_f32( + p_out_tmp, + p_inp1_tmp0, + p_inp2_tmp0, + p_condition_temp, + out_lc, + in_lc, + sign_flag); + p_out_tmp += in_lc * out_lc; + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + p_condition_temp += in_lc * out_lc; + } + + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + } + else if(inp1_const == 1 || inp2_const == 1) + { + if((inp1_const == 1)&&(inp2_const == 1)) + { + internal_elm_where_broadcast_both_f32xf32_f32( + p_out_tmp, + p_inp1_tmp, + p_inp2_tmp, + p_condition_temp, + p_out_shape[0] * p_out_shape[1] * p_out_shape[2] * p_out_shape[3]); + } + else + { + sign_flag = 0; + if(inp1_strides[3] == 0) + { + sign_flag = 1; + const FLOAT32 *tmp; + tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp; + } + internal_elm_where_broadcast_f32xf32_f32( + p_out_tmp, + p_inp1_tmp, + p_inp2_tmp, + p_condition_temp, + p_out_shape[0] * p_out_shape[1] * p_out_shape[2] * p_out_shape[3], + sign_flag); + } + } + else + { + sign_flag = 0; + if((inp1_strides[3] == 0) && (inp2_strides[3] == 0)) + { + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + const FLOAT32 *__restrict__ p_inp1_tmp1 = p_inp1_tmp0; + const FLOAT32 *__restrict__ p_inp2_tmp1 = p_inp2_tmp0; + for(itr2 = 0; itr2 < p_out_shape[2]; itr2++) + { + { + 
internal_elm_where_broadcast_both_f32xf32_f32( + p_out_tmp, + p_inp1_tmp1, + p_inp2_tmp1, + p_condition_temp, + p_out_shape[3]); + } + p_out_tmp += p_out_shape[3]; + p_inp1_tmp1 += inp1_strides[2]; + p_inp2_tmp1 += inp2_strides[2]; + p_condition_temp += p_out_shape[3]; + } + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + } + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + else + { + if(inp1_strides[3] == 0) + { + const FLOAT32 *tmp; + tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp; + sign_flag = 1; + int tmp_strides[3]; + tmp_strides[0] = inp1_strides[0]; + tmp_strides[1] = inp1_strides[1]; + tmp_strides[2] = inp1_strides[2]; + + inp1_strides[0] = inp2_strides[0]; + inp1_strides[1] = inp2_strides[1]; + inp1_strides[2] = inp2_strides[2]; + + inp2_strides[0] = tmp_strides[0]; + inp2_strides[1] = tmp_strides[1]; + inp2_strides[2] = tmp_strides[2]; + } + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + const FLOAT32 *__restrict__ p_inp1_tmp1 = p_inp1_tmp0; + const FLOAT32 *__restrict__ p_inp2_tmp1 = p_inp2_tmp0; + for(itr2 = 0; itr2 < p_out_shape[2]; itr2++) + { + { + internal_elm_where_broadcast_f32xf32_f32( + p_out_tmp, + p_inp1_tmp1, + p_inp2_tmp1, + p_condition_temp, + p_out_shape[3], + sign_flag); + } + p_out_tmp += p_out_shape[3]; + p_inp1_tmp1 += inp1_strides[2]; + p_inp2_tmp1 += inp2_strides[2]; + p_condition_temp += p_out_shape[3]; + } + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + } + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + } + return 0; +} + +#endif \ No newline at end of file diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c new file mode 100644 index 0000000000..5978a92d26 --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c @@ -0,0 +1,647 @@ +#include "xa_nnlib_common.h" +#include +//#include "xa_nn_basic_state.h" +#include "xa_nnlib_common_macros.h" + +#define ALIGNMENT_8 8 + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x))+(bytes-1))&(~(bytes-1))) + +static void vecmean16_inpx3(const xtfloatx2 *p_src1, const xtfloat* p_src2, const xtfloat* p_src3, xtfloatx2 *p_dst, int N){ + int i = 0; + ae_valign align_src1, align_dst; + ae_valign align_src2, align_src3; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_src2); + align_src3 = AE_LA64_PP(p_src3); + align_dst = AE_ZALIGN64(); + + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1_h, j1_l, j2_h, j2_l; + + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + + XT_LASX2IP(j1_h, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j1_l, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j2_h, align_src3, (xtfloatx2 *)p_src3); + XT_LASX2IP(j2_l, align_src3, (xtfloatx2 *)p_src3); + + j1_h = XT_ADD_SX2(j1_h, j2_h); + j1_l = XT_ADD_SX2(j1_l, j2_l); + wout1 = XT_ADD_SX2(wout1, j1_h); + wout2 = XT_ADD_SX2(wout2, j1_l); + + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1, j2; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat)); + j1 = (xtfloat) *(p_src2 + i); + j2 = (xtfloat) *(p_src3 + i); + + j1 = XT_ADD_S(j1, j2); + wout1 = 
XT_ADD_S(wout1, j1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat)); + } +} + +static void vecmean16_inpx2(const xtfloatx2 *p_src1, const xtfloat* p_src2, xtfloatx2 *p_dst, int N){ + ae_valign align_src1, align_dst; + ae_valign align_src2; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_src2); + align_dst = AE_ZALIGN64(); + + int i = 0; + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1, j2; + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + + XT_LASX2IP(j1, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j2, align_src2, (xtfloatx2 *)p_src2); + + wout1 = XT_ADD_SX2(wout1, j1); + wout2 = XT_ADD_SX2(wout2, j2); + + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat)); + j1 = (xtfloat) *(p_src2 + i); + wout1 = XT_ADD_S(wout1, j1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat)); + } +} + +static void vecmean32_inpx3(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, const xtfloatx2* p_wsrc3, xtfloatx2 *p_dst, int N){ + ae_valign align_src1, align_src2, align_src3, align_dst; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_wsrc2); + align_src3 = AE_LA64_PP(p_wsrc3); + align_dst = AE_ZALIGN64(); + + int i = 0; + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1, j2, j3, j4; + xtfloatx2 wj1, wj2; + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + XT_LASX2IP(j1, align_src2, p_wsrc2); + XT_LASX2IP(j2, align_src3, p_wsrc3); + XT_LASX2IP(j3, align_src2, p_wsrc2); + XT_LASX2IP(j4, align_src3, p_wsrc3); + + wj1 = XT_ADD_SX2(j1, j2); + wj2 = XT_ADD_SX2(j3, j4); + wout1 = XT_ADD_SX2(wout1, wj1); + wout2 = XT_ADD_SX2(wout2, wj2); + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1, j2; + xtfloat wj1; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, 4); + XT_LSXP(j1, (xtfloat *)p_wsrc2, 4); + XT_LSXP(j2, (xtfloat *)p_wsrc3, 4); + wj1 = XT_ADD_S(j1, j2); + wout1 = XT_ADD_S(wout1, wj1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat)); + } +} + +static void vecmean32_inpx2(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, xtfloatx2 *p_dst, int N){ + ae_valign align_src1, align_src2, align_dst; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_wsrc2); + align_dst = AE_ZALIGN64(); + + int i = 0; + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1, j2; + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + XT_LASX2IP(j1, align_src2, p_wsrc2); + XT_LASX2IP(j2, align_src2, p_wsrc2); + wout1 = XT_ADD_SX2(wout1, j1); + wout2 = XT_ADD_SX2(wout2, j2); + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, 4); + XT_LSXP(j1, (xtfloat *)p_wsrc2, 4); + wout1 = XT_ADD_S(wout1, j1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(WORD32)); + } +} + +static inline void xa_nn_reduce_sum_4D_f32_f32(const FLOAT32 * __restrict__ p_inp + ,const WORD32 *const p_4D_inp_shape + ,const WORD32 * __restrict__ p_axis_data + ,WORD32 num_inp_dims + 
,WORD32 num_axis_dims + ,pVOID p_scratch_in) +{ + xtfloat *p_in = (xtfloat *)(p_inp); + xtfloat *p_scratch = (xtfloat *)(p_scratch_in); + + int temp_inp_n = p_4D_inp_shape[0]; + int temp_inp_h = p_4D_inp_shape[1]; + int temp_inp_w = p_4D_inp_shape[2]; + int temp_inp_c = p_4D_inp_shape[3]; + + int itr_axis = 0, itr_n = 0, itr_h = 0, itr_w = 0, itr_c = 0; + xtfloat *p_src2, *p_src3; + xtfloatx2 *p_src1; + xtfloatx2 * p_dst; + ae_valign align_src2; + + int axis_dims_count = num_axis_dims; + if(axis_dims_count) + { + switch(p_axis_data[itr_axis]) + { + case 0: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n & ~(2 - 1)); itr_n += 2) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_src2 = p_in + itr_n * plane_size; + p_src3 = p_in + (itr_n + 1) * plane_size; + p_dst = (xtfloatx2 *)p_scratch; + vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, plane_size); + } + + if(temp_inp_n & 1) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_src2 = (p_in + itr_n * plane_size); + p_dst = (xtfloatx2 *)p_scratch; + vecmean16_inpx2(p_src1, p_src2, p_dst, plane_size); + } + temp_inp_n = 1; + }break; + case 1: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + for(itr_h=0; itr_h < (temp_inp_h & ~(2 - 1)); itr_h += 2) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size); + p_src3 = p_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, wc_plane_size); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + } + + if(temp_inp_h & 1) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean16_inpx2(p_src1, p_src2, p_dst, wc_plane_size); + } + } + temp_inp_h = 1; + }break; + case 2:{ + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hc_plane_size = temp_inp_h * temp_inp_c; + + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hc_plane_size) + itr_h * temp_inp_c))); + for(itr_w=0; itr_w < (temp_inp_w & ~(2 - 1)); itr_w += 2) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c); + p_src3 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, temp_inp_c); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c)); + } + + if(temp_inp_w & 1) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean16_inpx2(p_src1, p_src2, p_dst, temp_inp_c); + } + } + } + temp_inp_w = 1; + }break; + case 3: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hw_plane_size = temp_inp_h * temp_inp_w; + int rem_c = (temp_inp_c & 7); + + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + for(itr_w=0; itr_w < (temp_inp_w); itr_w++) + { + p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hw_plane_size) + (itr_h * 
temp_inp_w) + itr_w))); + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w); + align_src2 = AE_LA64_PP(p_src2); + + for(itr_c=0; itr_c < (temp_inp_c >> 3); itr_c++) + { + xtfloatx2 j11, j12, j21, j22, i1; + i1 = XT_LSX((xtfloat *)p_src1, 0); + XT_LASX2IP(j11, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j12, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j21, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j22, align_src2, (xtfloatx2 *)p_src2); + + j11 = XT_ADD_SX2(j11, j12); + j21 = XT_ADD_SX2(j21, j22); + + xtfloatx2 t1 = XT_SEL32_HH_SX2(j11, j11); + xtfloatx2 t2 = XT_SEL32_HH_SX2(j21, j21); + + j11 = XT_ADD_SX2(j11, t1); + j21 = XT_ADD_SX2(j21, t2); + + j11 = XT_ADD_SX2(j11, j21); + i1 = XT_ADD_SX2(i1, j11); + + XT_SSX(i1, (xtfloat *)p_dst, 0); + + p_src1 = p_dst; + } + //Remainder Loop + for(itr_c=0; itr_c < rem_c ; itr_c++) + { + xtfloat j1; + xtfloat i1; + i1 = XT_LSX((xtfloat *)p_src1, 0); + j1 = *p_src2++; + + i1 = XT_ADD_S(i1, j1); + XT_SSX(i1, (xtfloat *)p_dst, 0); + } + } + } + } + temp_inp_c = 1; + }break; + default: + break; + } + + axis_dims_count--; + itr_axis++; + } + + while(axis_dims_count) + { + ae_valign align_src; + xtfloat *p_scr_in = p_scratch; + xtfloatx2 *p_wsrc2, *p_wsrc3; + switch(p_axis_data[itr_axis]) + { + case 0: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + for(itr_n=1; itr_n < ((temp_inp_n -1) & ~(2 - 1)); itr_n += 2) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size); + p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n + 1) * plane_size); + p_dst = (xtfloatx2 *)p_scratch; + vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, plane_size); + } + + if((temp_inp_n - 1) & 1) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size); + p_dst = (xtfloatx2 *)p_scratch; + vecmean32_inpx2(p_src1, p_wsrc2, p_dst, plane_size); + } + temp_inp_n = 1; + }break; + case 1: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + p_src1 = (xtfloatx2 *)(p_scratch + + (itr_n * plane_size)); + for(itr_h = 1; itr_h < ((temp_inp_h - 1) & ~(2 - 1)); itr_h += 2) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size)); + p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, wc_plane_size); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + } + + if((temp_inp_h - 1) & 1) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean32_inpx2(p_src1, p_wsrc2, p_dst, plane_size); + } + } + temp_inp_h = 1; + }break; + case 2:{ + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hc_plane_size = temp_inp_h * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + p_src1 = (xtfloatx2 *)(p_scratch + ((itr_n * plane_size) + (itr_h * wc_plane_size))); + for(itr_w = 1; itr_w < ((temp_inp_w - 1) & ~(2 - 1)); itr_w += 2) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c)); + p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * 
plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, temp_inp_c); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c)); + } + + if((temp_inp_w - 1) & 1) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean32_inpx2(p_src1, p_wsrc2, p_dst, temp_inp_c); + } + } + } + temp_inp_w = 1; + }break; + case 3: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hw_plane_size = temp_inp_h * temp_inp_w; + int rem_c = ((temp_inp_c) & 3); + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + for(itr_w=0; itr_w < (temp_inp_w); itr_w++) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w); + align_src = AE_LA64_PP(p_wsrc2); + xtfloatx2 i1 = AE_MOVXTFLOATX2_FROMF32X2(AE_MOVDA32(0)); + for(itr_c = 0; itr_c < (temp_inp_c >> 2); itr_c++) + { + xtfloatx2 j1, j2; + XT_LASX2IP(j1, align_src, p_wsrc2); + XT_LASX2IP(j2, align_src, p_wsrc2); + + xtfloatx2 t1 = XT_SEL32_HH_SX2(j1, j1); + xtfloatx2 t2 = XT_SEL32_HH_SX2(j2, j2); + + j1 = XT_ADD_SX2(t1, j1); + j2 = XT_ADD_SX2(t2, j2); + + i1 = XT_ADD_SX2(i1, j1); + i1 = XT_ADD_SX2(i1, j2); + } + + //Remainder Loop + for(itr_c=0; itr_c < rem_c; itr_c++) + { + xtfloat j1; + XT_LSXP(j1, (xtfloat *)p_wsrc2, sizeof(xtfloat)); + i1 = XT_ADD_S(i1, j1); + } + XT_SSX(i1, (xtfloat *)p_dst, 0); + } + } + } + temp_inp_c = 1; + }break; + default: + break; + } + axis_dims_count--; + itr_axis++; + } +} + +WORD32 xa_nn_reduce_mean_4D_f32_f32( + FLOAT32 * __restrict__ p_out, + const WORD32 *const p_out_shape, + const FLOAT32 * __restrict__ p_inp, + const WORD32 *const p_inp_shape, + const WORD32 * __restrict__ p_axis, + WORD32 num_out_dims, + WORD32 num_inp_dims, + WORD32 num_axis_dims, + void * __restrict__ p_scratch_in) +{ + /* NULL pointer checks */ + XA_NNLIB_ARG_CHK_PTR(p_out, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp, -1); + XA_NNLIB_ARG_CHK_PTR(p_axis, -1); + XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1); + + /* Invalid input checks */ + XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 4)), -1); + XA_NNLIB_ARG_CHK_COND(((num_out_dims <= 0) || (num_out_dims > 4)), -1); + XA_NNLIB_ARG_CHK_COND(((num_axis_dims < 0) || (num_axis_dims > 4)), -1); + + int axis_itr = 0, inp_itr = 0, out_itr = 0; + int num_elm_in_axis = 1; + int current, past = -1; + for(axis_itr=0; axis_itr < num_axis_dims; axis_itr++) + { + current = p_axis[axis_itr]; + XA_NNLIB_ARG_CHK_COND(((current < 0) || (current > (num_inp_dims - 1))), -1); + XA_NNLIB_ARG_CHK_COND((p_inp_shape[current] > 1024), -1); + + /* Avoid calculation in case of repeated axis dims*/ + if(current != past) + { + num_elm_in_axis *= p_inp_shape[current]; + past = current; + } + } + + for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++) + { + XA_NNLIB_ARG_CHK_COND((p_inp_shape[inp_itr] <= 0), -1); + } + + int out_length = 1; + for(out_itr=0; out_itr < num_out_dims; out_itr++) + { + XA_NNLIB_ARG_CHK_COND((p_out_shape[out_itr] <= 0), -1); + out_length *= p_out_shape[out_itr]; + } + + /* Pointer alignment checks */ + 
XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_axis, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1); + + FLOAT32 *p_in = (FLOAT32 *)(p_inp); + WORD32 *p_scratch = (WORD32 *)(ALIGN_PTR(p_scratch_in, ALIGNMENT_8)); + + // Changing order of axis data so that reduce max will be first computed + // across largest inp shape dim in axis. This is required to + // minimize the scratch usage. + int inp_length = 1, p_axis_data[4] = {0}, inp_shape_max; + if(num_axis_dims) + { + inp_shape_max = p_inp_shape[p_axis[0]]; + axis_itr = 1; + int max_axis_itr = 0; + int temp_p_axis_0 = p_axis[0]; + for(axis_itr = 0; axis_itr < num_axis_dims; axis_itr++) + { + p_axis_data[axis_itr] = p_axis[axis_itr]; + } + for(axis_itr = 1; axis_itr < num_axis_dims; axis_itr++) + { + if(p_inp_shape[p_axis[axis_itr]] > inp_shape_max) + { + inp_shape_max = p_inp_shape[p_axis[axis_itr]]; + max_axis_itr = axis_itr; + } + } + p_axis_data[0] = p_axis_data[max_axis_itr]; + p_axis_data[max_axis_itr] = temp_p_axis_0; + + inp_itr = 0; + for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++) + { + inp_length *= p_inp_shape[inp_itr]; + } + + memset(p_scratch, 0, ((inp_length / inp_shape_max) * sizeof(WORD32))); //TODO: Alternate approach for memset? + } + + // Promoting lesser dim tensors to 4D tensors. Also modifying axis + // data accordingly. + int p_4D_inp_shape[4] = {1, 1, 1, 1}; + int itr = num_inp_dims - 1; + int count = 3; + while(itr >= 0) + { + p_4D_inp_shape[count] = p_inp_shape[itr]; + itr--; + count--; + } + for(itr = 0; itr < num_axis_dims; itr++) + { + p_axis_data[itr] = p_axis_data[itr] + (4 - num_inp_dims); + } + ae_valign align_out = AE_ZALIGN64(); + + if(num_axis_dims) + { + if(num_elm_in_axis > 1) + { + xa_nn_reduce_sum_4D_f32_f32(p_in, + p_4D_inp_shape, + p_axis_data, + num_inp_dims, + num_axis_dims, + p_scratch); + itr = 0; + xtfloatx2 *p_src1 = (xtfloatx2 *)(p_scratch); + + float div = 1; + + for(int i = 0; i < num_axis_dims; i++) + { + div = div * (float)p_4D_inp_shape[p_axis_data[i]]; + } + + float mul = 1 / div; + + xtfloatx2 multiplier = XT_LSX((xtfloat *)&mul, 0); + + for(itr = 0; itr < (out_length >> 3); itr++) + { + xtfloatx2 temp1, temp2, temp3, temp4; + + temp2 = XT_LSX2X(p_src1, 8); + temp3 = XT_LSX2X(p_src1, 16); + temp4 = XT_LSX2X(p_src1, 24); + XT_LSX2XP(temp1, p_src1, 32); + + temp1 = XT_MUL_SX2(temp1, multiplier); + temp2 = XT_MUL_SX2(temp2, multiplier); + temp3 = XT_MUL_SX2(temp3, multiplier); + temp4 = XT_MUL_SX2(temp4, multiplier); + + XT_SASX2IP(temp1, align_out, (xtfloatx2 *)p_out); + XT_SASX2IP(temp2, align_out, (xtfloatx2 *)p_out); + XT_SASX2IP(temp3, align_out, (xtfloatx2 *)p_out); + XT_SASX2IP(temp4, align_out, (xtfloatx2 *)p_out); + } + AE_SA64POS_FP(align_out, p_out); + + for(itr = 0; itr < (out_length & 7); itr++) + { + xtfloat temp1; + XT_LSXP(temp1, (xtfloat *)p_src1, 4); + temp1 = XT_MUL_S(temp1, multiplier); + XT_SSXP(temp1, (xtfloat *)p_out, 4); + } + } + else + { + + memcpy(p_out, p_inp, inp_length * sizeof(FLOAT32)); + } + } + else + { + memcpy(p_out, p_inp, inp_length * sizeof(FLOAT32)); + } + + return 0; +} diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp index 4d4ff26c3f..faac3d7cb2 100644 --- a/backends/cadence/reference/kernels/kernels.cpp +++ b/backends/cadence/reference/kernels/kernels.cpp @@ -65,6 +65,7 @@ void dequantize( 
typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); typed_quantize_val(int32_t); #undef typed_quantize_val @@ -78,6 +79,7 @@ typed_quantize_val(int32_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -86,6 +88,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); typed_dequantize_val(int32_t); #undef typed_dequantize_val @@ -99,6 +102,7 @@ typed_dequantize_val(int32_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp index aef730bfd1..b49c045b94 100644 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -37,6 +37,14 @@ void dequantize_per_tensor_out( const int8_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp index 0d7ff0bc7e..ad5fa791b5 100644 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -39,6 +39,14 @@ void quantize_per_tensor_out( int8_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. 
/ scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index dc3d2a6841..223b068375 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -229,14 +229,29 @@ def get_default_8bit_qnn_ptq_config( ) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, @@ -409,6 +424,7 @@ def get_ptq_per_channel_quant_config( quant_min=torch.iinfo(act_dtype).min, quant_max=torch.iinfo(act_dtype).max, qscheme=torch.per_tensor_affine, + ch_axis=0, observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), ) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4bfdedcd4b..875cfbf956 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2918,6 +2918,44 @@ def test_ptq_mobilebert(self): for k, v in cpu.items(): self.assertLessEqual(abs(v[0] - htp[k][0]), 5) + def test_wav2letter(self): + if not self.required_envs([self.pretrained_weight]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/scripts/wav2letter.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--pretrained_weight", + self.pretrained_weight, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + if self.shared_buffer: + cmds.extend(["--shared_buffer"]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertLessEqual(msg["wer"], 0.5) + self.assertLessEqual(msg["cer"], 0.25) + def test_export_example(self): if not self.required_envs([self.model_name]): self.skipTest("missing required envs") diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index cf50f170cf..ed3d847933 100644 --- 
a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -16,6 +16,20 @@ runtime.python_library( ], ) +runtime.python_library( + name = "int4_weight_only_quantizer", + srcs = [ + "int4_weight_only_quantizer.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//executorch/backends/vulkan:custom_ops_lib", + "//pytorch/ao:torchao", + ] +) + runtime.python_library( name = "remove_local_scalar_dense", srcs = ["remove_local_scalar_dense_ops.py"], @@ -30,17 +44,18 @@ runtime.python_library( ) runtime.python_library( - name = "int4_weight_only_quantizer", - srcs = [ - "int4_weight_only_quantizer.py", - ], + name = "tag_memory_meta_pass", + srcs = ["tag_memory_meta_pass.py"], visibility = [ "//executorch/backends/...", ], deps = [ - "//executorch/backends/vulkan:custom_ops_lib", - "//pytorch/ao:torchao", - ] + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/backends/vulkan:utils_lib", + "//executorch/backends/vulkan/serialization:lib", + ], ) runtime.python_library( @@ -56,5 +71,6 @@ runtime.python_library( ":insert_prepack_nodes", ":int4_weight_only_quantizer", ":remove_local_scalar_dense", + ":tag_memory_meta_pass" ] ) diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index cfdb7c6eee..8823553ab1 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -5,9 +5,11 @@ from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( RemoveLocalScalarDenseOpsTransform, ) +from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", "RemoveLocalScalarDenseOpsTransform", + "TagMemoryMetaPass", ] diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 37665a6da8..7876806d6d 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool: if not is_param_node(program, node): return True + # Annotate that this node is going to represented as a tensorref in the Vulkan + # compute graph. This will be useful for later graph passes. + node.meta["vkdg_tensorref"] = True + for user in node.users: if user.op == "call_function" and handles_own_prepacking( # pyre-ignore diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py new file mode 100644 index 0000000000..0a6a2d42d4 --- /dev/null +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from copy import deepcopy +from typing import Set + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.op_registry import get_op_features, has_impl + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass, PassResult + +from torch._subclasses.fake_tensor import FakeTensor + +from torch.fx.passes.tools_common import NodeList +from torch.fx.passes.utils.fuser_utils import topo_sort + +logger: logging.Logger = logging.getLogger("") +logger.setLevel(logging.INFO) + + +def set_memory_metadata( + node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout +) -> None: + utils.set_node_spec_attr(node, "vk_storage_type", storage) + utils.set_node_spec_attr(node, "vk_memory_layout", layout) + + +def insert_transition_node( + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + arg: torch.fx.Node, + storage: VkStorageType, + layout: VkMemoryLayout, +) -> None: + """ + Insert a clone node to copy the original tensor to a tensor with the desired storage + type and memory layout. + """ + with graph_module.graph.inserting_before(node): + clone_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten.clone.default, + (arg,), + ) + clone_node.meta["val"] = arg.meta["val"] + clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) + clone_node.meta["spec"].const = False + set_memory_metadata(clone_node, storage, layout) + arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) + + +class TagMemoryMetaPass(ExportPass): + """ + There are a variety of ways that tensors can be represented in Vulkan. The two main + descriptors for how a tensor is laid out in memory is: + + 1. Storage Type (buffer or texture) + 2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.) + + Due to the differences between buffers and textures, and the differences between + different memory layouts, an implementation for an operator may only support a + specific set of (storage type, memory layout) combinations. + + Furthermore, if an operator implementation supports multiple (storage type, memory + layout) combinations, there may be a "preferred" setting which results in optimal + performance. + + This pass is responsible for ensuring that all tensors participating in an operator + call have a valid/optimal (storage type, memory layout) setting, and insert + transition operators to transfer input tensors to the correct memory settings when + necessary. + """ + + def __init__( + self, + texture_limits: utils.ImageExtents, + default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, + default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, + ): + super().__init__() + self.default_storage: VkStorageType = default_storage_type + self.default_layout: VkMemoryLayout = default_memory_layout + self.texture_limits = texture_limits + + def propose_node_storage( + self, + node: torch.fx.Node, + ) -> VkStorageType: + """ + Uses the operator registry to determine the storage type that should be used for + a given node. The storage type is determined with the following priorities: + 1. In some cases, a tensor involved in the computation may be too large to be + represented as a texture. If this is the case, the node is "opinionated" and + buffer representation must be used. + 1. 
If the operator called by the node indicates an optimal storage type, or only + supports a single storage type, use that storage type. If either is true, + then the node is considered to be opinionated as well. If multiple storage + types are supported and no preferred storage type is indicated, then the node + is not opinionated; go to the next step. + 2. If the node's arguments already have memory metadata annotations, then + preserve the settings of the first argument. Otherwise, proceed to the next + step. + 3. Recursively search the node's uses to see if any subsequent uses are + opinionated; inherit the settings of the first opinionated node. If no + opinionated user can be found, then proceed to the last step. + 4. Use the default storage type setting. + """ + # The node may have an input/output tensor that is too big to be stored in a + # texture. In this case, buffer storage must be used. Note that the partitioner + # has already checked for the fact that buffer storage is supported by the + # operator. + if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0: + return VkStorageType.BUFFER + + valid_storage_types: Set[VkStorageType] = utils.all_storage_types + + # pyre-ignore + if has_impl(node.target): + # pyre-ignore + features = get_op_features(node.target) + valid_storage_types = features.supported_storage_types() + storage = features.propose_storage_type() + if storage is not None: + return storage + + for arg in node.args: + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta["val"], FakeTensor + ): + storage = utils.get_node_storage_type(arg) + if storage is not None and storage in valid_storage_types: + return storage + + # If no storage type has been resolved yet, assume the optimal storage type of + # the first opinionated user. This search is recursive. + for user in node.users: + optimal_storage = self.propose_node_storage(user) + if optimal_storage is not None: + return optimal_storage + + if self.default_storage in valid_storage_types: + return self.default_storage + else: + return next(iter(valid_storage_types)) + + def propose_node_layout( + self, + node: torch.fx.Node, + storage: VkStorageType, + ) -> VkMemoryLayout: + """ + Performs the same steps as propose_node_storage, but detects the memory layout + that should be used for the specific storage type. The same prioritization logic + is applied. + """ + valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts + # pyre-ignore + if has_impl(node.target): + # pyre-ignore + features = get_op_features(node.target) + valid_layouts = features.supported_memory_layouts(storage) + layout = features.propose_memory_layout(storage) + if layout is not None: + return layout + + for arg in node.args: + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta["val"], FakeTensor + ): + layout = utils.get_node_memory_layout(arg) + if layout is not None and layout in valid_layouts: + return layout + + # If no memory layout has been resolved yet, assume the optimal memory layout of + # the first opinionated user. This search is recursive. + for user in node.users: + optimal_storage = self.propose_node_layout(user, storage) + if optimal_storage is not None: + return optimal_storage + + # As a last resort, return the default memory layout that should be used.
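+ # If the default layout is not valid for this operator, fall back to any layout + # that the operator does support.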
+ if self.default_layout in valid_layouts: + return self.default_layout + else: + return next(iter(valid_layouts)) + + def should_annotate(self, node) -> bool: + if not isinstance(node, torch.fx.Node): + return False + + if not isinstance(node.meta["val"], FakeTensor): + return False + + # Storage type and memory layout for tensorref will be determined at runtime + # so there's no use in setting those attributes ahead of time. + if node.meta.get("vkdg_tensorref", False): + return False + + return True + + def should_delay_annotation(self, node: torch.fx.Node) -> bool: + # For prepack nodes, delay setting the storage type and memory layout as long as + # possible. This is to minimize the number of transitions, since it can be + # difficult to predict what storage type and memory layout should be used at the + # time the prepack node is observed. + return node.target == exir_ops.edge.et_vk.prepack.default + + # noqa + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes)) + + for node in sorted_nodes: + if not self.should_annotate(node) or self.should_delay_annotation(node): + continue + + storage = self.propose_node_storage(node) + layout = self.propose_node_layout(node, storage) + + set_memory_metadata(node, storage, layout) + + inserting_transitions_for_node = False + for i, arg in enumerate(node.args): + if not self.should_annotate(arg): + continue + + assert isinstance(arg, torch.fx.Node) + + arg_storage = utils.get_node_storage_type(arg) + arg_layout = utils.get_node_memory_layout(arg) + + if arg_storage is None: + utils.set_node_spec_attr(arg, "vk_storage_type", storage) + arg_storage = storage + if arg_layout is None: + utils.set_node_spec_attr(arg, "vk_memory_layout", layout) + arg_layout = layout + + if arg_storage == storage and arg_layout == layout: + continue + + if not inserting_transitions_for_node: + inserting_transitions_for_node = True + logger.info( + f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" + ) + + insert_transition_node(graph_module, node, arg, storage, layout) + + logger.info( + f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})" + ) + + return PassResult(graph_module, True) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index fe67fdb30c..eeec5ab37e 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -8,18 +8,31 @@ import operator -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Optional, Set, Union import executorch.backends.vulkan.custom_ops_lib # noqa import torch -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.backends.vulkan.utils import ( + all_memory_layouts, + all_packed_dims, + PackedDim, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor +###################### +## OpFeatures class ## +###################### + def allow_node(node: torch.fx.Node) -> bool: return True @@ -27,25 +40,37 @@ def allow_node(node: torch.fx.Node) -> bool: class TextureImplFeatures: __slots__ = [ - # Indicates if the compute shader is agnostic to the packed dimension - "uses_packed_dim", - # Indicates if the compute shader is agnostic to the texture axis 
mapping + "valid_packed_dims", "uses_axis_map", - # Specifies a specific set of memory layouts that the shader supports. If it is - # and empty list, then the supported memory layouts can be inferred from the - # `uses_packed_dim` and `uses_axis_map` flags. - "supported_layouts", ] def __init__( self, - uses_packed_dim: bool = False, uses_axis_map: bool = False, - supported_layouts: Optional[List[VkMemoryLayout]] = None, + valid_packed_dims: Optional[Set[PackedDim]] = None, ): - self.uses_packed_dim: bool = uses_packed_dim self.uses_axis_map: bool = uses_axis_map - self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts + self.valid_packed_dims = set() + if valid_packed_dims is not None: + self.valid_packed_dims = valid_packed_dims + + def valid_memory_layouts(self) -> Set[VkMemoryLayout]: + """ + Derive the set of memory layouts supported by the texture implementation based + on the valid packed dimensions. + """ + layouts = set() + + if PackedDim.WIDTH in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED) + + if PackedDim.HEIGHT in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED) + + if PackedDim.CHANNELS in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED) + + return layouts class OpFeatures: @@ -58,10 +83,16 @@ class OpFeatures: # bool indicating if the operator has a resize function, which allows it to # support dynamic shape tensors. "resize_fn", + # Optimal + "optimal_storage", + "optimal_layout", # bool indicating if the operator handles its own prepacking. If this is True, # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. "handles_own_prepacking", + # Optional dictionary to specify a custom function to calculate the required + # image extents for a particular argument index. + "skip_limits_check", # Optional check function used during partitioning to determine if a node's # inputs are supported by the operator implementation. "check_node_fn", @@ -72,17 +103,96 @@ def __init__( texture_impl: Optional[TextureImplFeatures] = None, buffer_impl: bool = False, resize_fn: bool = False, + optimal_storage: Optional[VkStorageType] = None, + optimal_layout: Optional[VkMemoryLayout] = None, handles_own_prepacking: bool = False, + skip_limits_check: Optional[Set[int]] = None, check_node_fn: Optional[Callable] = None, ): self.texture_impl: Optional[TextureImplFeatures] = texture_impl self.buffer_impl: bool = buffer_impl self.resize_fn: bool = resize_fn + self.optimal_storage: Optional[VkStorageType] = optimal_storage + self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout self.handles_own_prepacking: bool = handles_own_prepacking + + self.skip_limits_check: Set[int] = set() + if skip_limits_check is not None: + self.skip_limits_check = skip_limits_check + self.check_node_fn: Callable = allow_node if check_node_fn is not None: self.check_node_fn = check_node_fn + def propose_storage_type(self) -> Optional[VkStorageType]: + """ + Propose a storage type that should be used for this operator. A proposal can be + made if one of the following is true: + 1. The operator specifies an optimal storage type + 2. Only one storage type is supported. + + If both storage types are supported and no optimal storage type is specified, + then None is returned to indicate that there is no preference in storage type. 
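+ + For example, an operator that only registers a TextureImplFeatures and sets no + optimal_storage will be proposed VkStorageType.TEXTURE_3D: + + features = OpFeatures(texture_impl=TextureImplFeatures(valid_packed_dims=all_packed_dims)) + features.propose_storage_type() # -> VkStorageType.TEXTURE_3D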
+ """ + if self.optimal_storage is not None: + return self.optimal_storage + + if self.texture_impl is not None and not self.buffer_impl: + return VkStorageType.TEXTURE_3D + elif self.buffer_impl and self.texture_impl is None: + return VkStorageType.BUFFER + + return None + + def supported_storage_types(self) -> Set[VkStorageType]: + """ + Return the set of storage types supported by this operator. + """ + storage_types = set() + if self.texture_impl is not None: + storage_types.add(VkStorageType.TEXTURE_3D) + if self.buffer_impl: + storage_types.add(VkStorageType.BUFFER) + + return storage_types + + def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]: + """ + Given a storage type as a precondition, propose a memory layout that should be + used for this operator. A proposal can be made if one of the following is true: + 1. The operator specifies an optimal memory layout + 2. Only one memory layout is supported. + + If multiple memory layouts are supported and no optimal memory layout is + specified then return None to indicate that the "best" memory layout for the + operator is ambiguous. + """ + if self.optimal_layout is not None: + return self.optimal_layout + + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + possible_layouts = self.texture_impl.valid_memory_layouts() + if len(possible_layouts) == 1: + return next(iter(possible_layouts)) + + return None + + def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]: + """ + Return the set of memory layouts supported by this operator for a given storage + type. + """ + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + return self.texture_impl.valid_memory_layouts() + else: + return all_memory_layouts + + +####################### +## Operator Registry ## +####################### OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload] @@ -122,8 +232,8 @@ def update_features_impl(op: OpKey): ) def register_ephemeral_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -143,8 +253,8 @@ def register_ephemeral_op(features: OpFeatures): ) def register_binary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -170,8 +280,8 @@ def register_binary_op(features: OpFeatures): ) def register_unary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -181,8 +291,8 @@ def register_unary_op(features: OpFeatures): @update_features(exir_ops.edge.aten._to_copy.default) def register_to_copy_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -220,15 +330,16 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: ) def register_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=True, - supported_layouts=[ - VkMemoryLayout.TENSOR_WIDTH_PACKED, - VkMemoryLayout.TENSOR_CHANNELS_PACKED, - ], + valid_packed_dims={ + PackedDim.WIDTH, + PackedDim.CHANNELS, + }, ) features.buffer_impl = True 
features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -236,12 +347,13 @@ def register_mm_op(features: OpFeatures): @update_features(exir_ops.edge.aten._weight_int8pack_mm.default) def register_int8_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -249,11 +361,12 @@ def register_int8_mm_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.linear_weight_int4.default) def register_int4_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -266,7 +379,7 @@ def register_int4_mm_op(features: OpFeatures): ) def register_softmax_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -282,7 +395,7 @@ def register_softmax_op(features: OpFeatures): ) def register_reduce_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -309,7 +422,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool: ) def register_2d_pool_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True return features @@ -323,19 +436,24 @@ def register_2d_pool_op(features: OpFeatures): ) def register_convolution_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED features.handles_own_prepacking = True + features.skip_limits_check = {1, 2} return features @update_features("llama::sdpa_with_kv_cache") def register_sdpa_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -343,7 +461,7 @@ def register_sdpa_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True return features @@ -352,7 +470,7 @@ def register_rotary_emb_op(features: 
OpFeatures): @update_features(exir_ops.edge.aten.view_copy.default) def register_view_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -393,7 +511,7 @@ def register_view_op(features: OpFeatures): ) def register_ported_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) return features @@ -408,15 +526,24 @@ def register_ported_op(features: OpFeatures): ) def register_ported_ops_with_prepacking(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.handles_own_prepacking = True return features -## -## Utility Functions -## +####################### +## Utility functions ## +####################### + + +def has_impl(target: OpKey) -> bool: + if not isinstance(target, str): + if target not in vulkan_supported_ops: + return target.name() in vulkan_supported_ops + return target in vulkan_supported_ops + else: + return target in vulkan_supported_ops def get_op_features(target: OpKey) -> OpFeatures: diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS index d68a82ade0..1d1d29f6fb 100644 --- a/backends/vulkan/partitioner/TARGETS +++ b/backends/vulkan/partitioner/TARGETS @@ -13,6 +13,7 @@ runtime.python_library( ], deps = [ "//executorch/backends/vulkan:op_registry", + "//executorch/backends/vulkan:utils_lib", "//executorch/backends/vulkan:vulkan_preprocess", "//executorch/exir:delegate", "//executorch/exir:lib", diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 2e916fd581..64e672fd69 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -9,12 +9,23 @@ import logging from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple -import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema +import executorch.backends.vulkan.utils as utils import torch -from executorch.backends.vulkan.op_registry import vulkan_supported_ops +from executorch.backends.vulkan.op_registry import ( + get_op_features, + has_impl, + OpFeatures, + vulkan_supported_ops, +) + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend + from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -24,7 +35,6 @@ from executorch.exir.backend.utils import tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops -from torch._subclasses.fake_tensor import FakeTensor from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner @@ -40,104 +50,147 @@ class VulkanSupportedOperators(OperatorSupportBase): - def __init__(self, require_dynamic_shape: bool = False) -> None: + def __init__( + self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False + ) -> None: super().__init__() self.require_dynamic_shapes = require_dynamic_shape - # The tensor dim limit is to guard against tensors with one or more - # large dimensions, which cannot be represented by an 
image texture due - # to the texture axis limits. - self.tensor_dim_limit = 16384 - - # pyre-ignore - def node_val_is_compatible(self, node_val: Any) -> bool: - # Skip nodes that don't have a value - if node_val is None: - return True + self.texture_limits: utils.ImageExtents = texture_limits - # TODO(ssjia) support symbolic ints - if isinstance(node_val, torch.SymInt): - return False + def op_node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + """ + Check if a given node is compatible with the Vulkan delegate's implementation + of the operator called by the node. Each tensor argument participating in the + operator call must be able to be represented with a (storage type, memory layout) + combination that is supported by the operator implementation. + """ + target = node.target + # Account for custom operators + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + assert isinstance(first_arg, torch._ops.OpOverload) + target = first_arg.name() - if isinstance(node_val, FakeTensor): - # Vulkan currently only supports tensors of up to 4D - if len(node_val.shape) > 4: - return False + # Extract the features for the node's operator, if no override was provided + if features is None: + if not has_impl(target): + return False, "no operator implementation" + features = get_op_features(target) - # bool dtype not currently supported - if node_val.dtype == torch.bool: - return False + valid_texture_layouts = utils.possible_node_memory_layouts( + node, self.texture_limits + ) - for dim in node_val.shape: - if dim > self.tensor_dim_limit: - return False + for i, arg in enumerate(node.args): + if ( + isinstance(arg, torch.fx.Node) + and utils.is_tensor_node(arg) + and i not in features.skip_limits_check + ): + arg_texture_layouts = utils.possible_node_memory_layouts( + arg, self.texture_limits + ) + valid_texture_layouts = valid_texture_layouts.intersection( + arg_texture_layouts + ) + + # If there are no valid texture memory layouts, then buffer storage must be + # supported by the operator implementation. 
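+ # At this point, valid_texture_layouts holds the intersection of the texture memory + # layouts that can represent the output and every non-skipped tensor argument within + # the texture limits.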
+ if len(valid_texture_layouts) == 0: + compatible = VkStorageType.BUFFER in features.supported_storage_types() + reason = "op is compatible" + if not compatible: + reason = "op requires buffers which is not supported by op impl" + return compatible, reason + + op_available_layouts = features.supported_memory_layouts( + VkStorageType.TEXTURE_3D + ) - if isinstance(node_val, (list, tuple)): - for item in node_val: - if not self.node_val_is_compatible(item): - return False + is_compatible = any( + layout in op_available_layouts for layout in valid_texture_layouts + ) + if not is_compatible: + return False, "Required texture memory layout not supported" - return True + return is_compatible, "Op is compatible" - def all_args_compatible(self, node: torch.fx.Node) -> bool: - node_val = node.meta.get("val", None) - if not self.node_val_is_compatible(node_val): - return False + def node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + # TODO(ssjia) support symbolic ints + if utils.is_symint_node(node): + return False, "symint node not supported yet" + elif utils.is_tensor_node(node): + return self.op_node_is_compatible(node, features=features) - for arg in node.args: - if not isinstance(arg, torch.fx.Node): - continue + return False, f"Unsupported node type: {node.format_node()}" - arg_val = arg.meta.get("val", None) - if not self.node_val_is_compatible(arg_val): - return False + def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: + """ + Detect if a node is a permute/transpose that precedes a call to a `mm` or + `addmm` operator. This node can be fused with the `mm` or `addmm` to produce a + `linear` operator. - return True + This function returns two bool values: + 1. The first indicates if this node can be fused into a linear node + 2. The second indicates if the overall linear op can be executed with Vulkan - def is_linear_permute(self, node: torch.fx.Node) -> bool: + The node will be partitioned only if both are true. + """ if node.target not in [ exir_ops.edge.aten.t_copy.default, exir_ops.edge.aten.permute_copy.default, ]: - return False + return False, False if len(node.users) != 1: - return False + return False, False first_user = list(node.users.keys())[0] if first_user.target in [ exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, ]: - # Only mark this node if the overall linear op is valid - if self.all_args_compatible(first_user): - return True + # Only mark this node if the target linear op is valid + if self.node_is_compatible(first_user)[0]: + return True, True + else: + return True, False - return False + return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: + def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: """ Scalar tensors are usually converted to scalar values in the graph via` scalar_tensor[0].item()` in Python, which translates to a chain of `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. This function marks the entire chain as supported by the Vulkan delegate. - Later, within vulkan_preprocess there will be a graph transform which - replaces the chain with passing in the scalar tensor directly. + Later, within vulkan_preprocess there will be a graph transform which replaces + the chain with passing in the scalar tensor directly. + + Similar to the `is_linear_permute` function, this function has 2 return values.
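+ The first indicates whether the node is part of such a chain; the second indicates + whether the ops that consume the resulting scalar value are compatible with the + Vulkan delegate.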
""" if node.target == exir_ops.edge.aten.select_copy.int: if len(node.users) != 1: - return False + return False, False # pyre-ignore if node.args[0].meta["val"].numel() != 1: - return False + return False, False + + local_scalar_dense = list(node.users.keys())[0] + if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: + return False, False - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default + return self.is_in_local_scalar_dense_chain(local_scalar_dense) if node.target == torch.ops.aten._local_scalar_dense.default: - return True + return True, all(self.node_is_compatible(user)[0] for user in node.users) - return False + return False, False def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": @@ -148,26 +201,35 @@ def log_skip(self, node: torch.fx.Node, reason: str) -> None: def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: - r = self._is_node_supported(submodules, node) + r = self._is_node_supported(node) return r - def _is_node_supported( - self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: + def _is_node_supported(self, node: torch.fx.Node) -> bool: target = node.target if node.target == torch.ops.higher_order.auto_functionalized: first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() - if self.is_linear_permute(node): + is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node) + if is_linear_permute and target_linear_is_compatible: return True + elif is_linear_permute: + # Skip so that the permute can be fused into a linear by another backend + self.log_skip(node, "permute node of non compatible linear node") + return False - if self.is_in_local_scalar_dense_chain(node): + is_in_local_scalar_dense_chain, dst_node_is_compatible = ( + self.is_in_local_scalar_dense_chain(node) + ) + if is_in_local_scalar_dense_chain and dst_node_is_compatible: return True + elif is_in_local_scalar_dense_chain: + self.log_skip(node, "local scalar dense of incompatible op node") + return False if target not in vulkan_supported_ops: - self.log_skip(node, "not in vulkan_supported_ops") + self.log_skip(node, "no operator implementation") return False features = vulkan_supported_ops[target] @@ -180,19 +242,42 @@ def _is_node_supported( self.log_skip(node, "no dynamic shape support") return False - return self.all_args_compatible(node) + is_compatible, reason = self.node_is_compatible(node, features=features) + if not is_compatible: + self.log_skip(node, reason) + + return is_compatible def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: compile_specs = [] for key, value in compile_options.items(): - if isinstance( - value, (vk_graph_schema.VkStorageType, vk_graph_schema.VkMemoryLayout) - ): + if isinstance(value, (VkStorageType, VkMemoryLayout)): value_bytes = int(value).to_bytes(4, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) + if isinstance(value, bool): + value_bytes = value.to_bytes(1, byteorder="little") + compile_specs.append(CompileSpec(key, value_bytes)) + + if key == "texture_limits": + compile_specs.append( + CompileSpec( + "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_z", 
int(value[2]).to_bytes(4, byteorder="little") + ) + ) + # Unhandled options are ignored return compile_specs @@ -200,7 +285,10 @@ def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: @final class VulkanPartitioner(Partitioner): - def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + compile_options: Optional[Dict[str, Any]] = None, + ) -> None: self.options: Dict[str, Any] = {} if compile_options is not None: self.options = compile_options @@ -218,9 +306,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # subgraphs containing the nodes with the tags partition_tags = {} + texture_limits: utils.ImageExtents = self.options.get( + "texture_limits", utils.DEFAULT_TEXTURE_LIMITS + ) capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - VulkanSupportedOperators(self.options.get("require_dynamic_shapes", False)), + VulkanSupportedOperators( + texture_limits, + require_dynamic_shape=self.options.get("require_dynamic_shapes", False), + ), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 46db1e3a98..39d023e765 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -540,6 +540,7 @@ def __init__( env: Dict[Any, Any], glslc_path: Optional[str], glslc_flags: str = "", + replace_u16vecn: bool = False, ) -> None: if isinstance(src_dir_paths, str): self.src_dir_paths = [src_dir_paths] @@ -549,6 +550,7 @@ def __init__( self.env = env self.glslc_path = glslc_path self.glslc_flags = glslc_flags + self.replace_u16vecn = replace_u16vecn self.glsl_src_files: Dict[str, str] = {} self.template_yaml_files: List[str] = [] @@ -705,6 +707,22 @@ def constructOutputMap(self) -> None: self.create_shader_params(), ) + def maybe_replace_u16vecn(self, input_text: str) -> str: + """ + There is a latency benefit to using u16vecn variables to store texture position + variables instead of ivecn, likely due to reduced register pressure. However, + SwiftShader does not support 16 bit integer types in shaders, so this is a crude + way to fallback to using ivecn to store texture positions so that testing with + SwiftShader is still possible. 
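+ + Shaders that need to keep 16 bit integer types can opt out of this substitution by + including the "codegen-nosub" token anywhere in their source, as q_8w_linear.glsl + does below.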
+ """ + if not self.replace_u16vecn: + return input_text + if "codegen-nosub" in input_text: + return input_text + + input_text = input_text.replace("u16vec", "ivec") + return input_text + def generateSPV(self, output_dir: str) -> Dict[str, str]: output_file_map = {} @@ -716,6 +734,7 @@ def process_shader(shader_paths_pair): with codecs.open(source_glsl, "r", encoding="utf-8") as input_file: input_text = input_file.read() + input_text = self.maybe_replace_u16vecn(input_text) output_text = preprocess(input_text, shader_params) glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl") @@ -1029,6 +1048,7 @@ def main(argv: List[str]) -> int: parser.add_argument("-c", "--glslc-path", required=True, help="") parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp") parser.add_argument("-o", "--output-path", required=True, help="") + parser.add_argument("--replace-u16vecn", action="store_true", default=False) parser.add_argument("--optimize_size", action="store_true", help="") parser.add_argument("--optimize", action="store_true", help="") parser.add_argument( @@ -1056,7 +1076,11 @@ def main(argv: List[str]) -> int: glslc_flags += "-O" shader_generator = SPVGenerator( - options.glsl_paths, env, options.glslc_path, glslc_flags + options.glsl_paths, + env, + options.glslc_path, + glslc_flags=glslc_flags, + replace_u16vecn=options.replace_u16vecn, ) output_spv_files = shader_generator.generateSPV(options.tmp_dir_path) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index ecfb44d431..f679732ddb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. 
*/ +// codegen-nosub + #version 450 core #define PRECISION ${PRECISION} diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index bc77bc40cf..8144747212 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -12,6 +12,11 @@ import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.utils import ( is_constant, is_get_attr_node, @@ -169,6 +174,15 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: if spec.mem_obj_id is not None: mem_obj_id = spec.mem_obj_id + storage_type = VkStorageType.DEFAULT_STORAGE + memory_layout = VkMemoryLayout.DEFAULT_LAYOUT + if hasattr(spec, "vk_storage_type"): + # pyre-ignore[16] + storage_type = spec.vk_storage_type + if hasattr(spec, "vk_memory_layout"): + # pyre-ignore[16] + memory_layout = spec.vk_memory_layout + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( @@ -177,6 +191,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: dims=spec.shape, constant_id=constant_id, mem_obj_id=mem_obj_id, + storage_type=storage_type, + memory_layout=memory_layout, ) ) ) diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 8197f705b5..35113bc623 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -37,6 +37,9 @@ class VkStorageType(IntEnum): TEXTURE_2D = 2 DEFAULT_STORAGE = 255 + def __str__(self) -> str: + return self.name + class VkMemoryLayout(IntEnum): TENSOR_WIDTH_PACKED = 0 @@ -44,6 +47,9 @@ class VkMemoryLayout(IntEnum): TENSOR_CHANNELS_PACKED = 2 DEFAULT_LAYOUT = 255 + def __str__(self) -> str: + return self.name + @dataclass class VkTensor: diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 0d3b17cccc..c2b46774aa 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -27,6 +27,7 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False): select({ "DEFAULT": "", "ovr_config//os:android": "--optimize", + "ovr_config//os:linux": "--replace-u16vecn", }) ) @@ -117,7 +118,7 @@ def define_common_targets(is_fbcode = False): "fbsource//third-party/toolchains:android" ], "ovr_config//os:macos-arm64": [ - "//third-party/khronos:moltenVK" + "//third-party/khronos:moltenVK_static" ], }) VK_API_PREPROCESSOR_FLAGS += select({ @@ -223,6 +224,8 @@ def define_common_targets(is_fbcode = False): ], deps = [ "//caffe2:torch", + "//executorch/exir:tensor", + "//executorch/backends/vulkan/serialization:lib", ] ) @@ -253,6 +256,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ ":custom_ops_lib", + ":utils_lib", "//caffe2:torch", "//executorch/exir/dialects:lib", "//executorch/backends/vulkan/serialization:lib", diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index ae0b8c6940..2e9fbba01c 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -4,11 +4,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from enum import IntEnum +from typing import Optional, Set, Tuple + import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.exir.tensor import TensorSpec + from torch._export.utils import is_buffer, is_param + +from torch._subclasses.fake_tensor import FakeTensor + from torch.export import ExportedProgram +## +## Node type determination +## + def is_get_attr_node(node: torch.fx.Node) -> bool: return isinstance(node, torch.fx.Node) and node.op == "get_attr" @@ -28,3 +45,171 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: or is_buffer(program, node) or is_constant(program, node) ) + + +def is_symint_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a SymInt value + """ + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], torch.SymInt): + return True + + return False + + +def is_tensor_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a tensor value, or a collection of tensor values + """ + # All nodes with tensor values are tagged by the SpecPropPass transform + if "spec" in node.meta: + return True + + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], FakeTensor): + return True + + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + return all(isinstance(x, FakeTensor) for x in node.meta["val"]) + + return False + + +## +## Memory Layout, Storage Type Determination +## + +ImageExtents = Tuple[int, int, int] + +DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048) + + +class PackedDim(IntEnum): + WIDTH = 0 + HEIGHT = 1 + CHANNELS = 2 + + +all_packed_dims: Set[PackedDim] = { + PackedDim.WIDTH, + PackedDim.HEIGHT, + PackedDim.CHANNELS, +} + +all_storage_types: Set[VkStorageType] = { + VkStorageType.BUFFER, + VkStorageType.TEXTURE_3D, +} + +all_memory_layouts: Set[VkMemoryLayout] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_HEIGHT_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, +} + + +def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents: + """ + Calculate the image extents that will be used to represent a tensor with the given sizes + and memory layout in the Vulkan Delegate. + """ + width = sizes[-1] if len(sizes) >= 1 else 1 + height = sizes[-2] if len(sizes) >= 2 else 1 + channels = sizes[-3] if len(sizes) >= 3 else 1 + batch = sizes[0] if len(sizes) >= 4 else 1 + + if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED: + width = (width + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED: + height = (height + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED: + channels = (channels + 3) // 4 + else: + raise RuntimeError(f"Unsupported memory layout {layout}") + + return width, height, channels * batch + + +def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool: + return all(extents[i] <= limits[i] for i in range(len(extents))) + + +def valid_texture_memory_layouts( + tensor_sizes: torch.Size, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given tensor sizes, determine the set of memory layouts which will produce a texture + that can fit within the specified device limits.
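+ + For example, with the default limits of (16384, 16384, 2048), a tensor of size + (1, 4000, 10, 10) can only be represented with TENSOR_CHANNELS_PACKED: width or + height packing leaves an extent of 4000 in the Z dimension, which exceeds the 2048 + limit, whereas channel packing reduces it to (4000 + 3) // 4 = 1000.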
+ """ + valid_layouts = set() + for layout in list(all_memory_layouts): + extents = required_image_extents(tensor_sizes, layout) + if extents_are_valid(extents, texture_limits): + valid_layouts.add(layout) + + return valid_layouts + + +def possible_node_memory_layouts( + node: torch.fx.Node, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given a node, determine the set of memory layouts which can be used to represent all + tensors involved in the computation. + """ + assert is_tensor_node(node) + if isinstance(node.meta["val"], FakeTensor): + return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits) + valid_layouts = set() + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + for fake_tensor in node.meta["val"]: + valid_layouts = valid_layouts.union( + valid_texture_memory_layouts(fake_tensor.shape, texture_limits) + ) + + return valid_layouts + + +## +## TensorSpec Utils +## + + +def set_node_spec_attr(node: torch.fx.Node, attr: str, value): + assert "spec" in node.meta + spec = node.meta["spec"] + if isinstance(spec, TensorSpec): + setattr(spec, attr, value) + elif isinstance(spec, list) or isinstance(spec, tuple): + for s in spec: + assert isinstance(s, TensorSpec) + setattr(s, attr, value) + else: + raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}") + + +def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True): + assert "spec" in node.meta + spec = node.meta["spec"] + if isinstance(spec, TensorSpec): + return getattr(spec, attr) if hasattr(spec, attr) else None + elif isinstance(spec, list) or isinstance(spec, tuple): + if return_first: + return getattr(spec[0], attr) if hasattr(spec, attr) else None + else: + return [getattr(s, attr) if hasattr(s, attr) else None for s in spec] + else: + raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}") + + +def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]: + return get_node_spec_attr(node, "vk_storage_type") + + +def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]: + return get_node_spec_attr(node, "vk_memory_layout") diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 96eee198f4..c938f9ff42 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,7 +6,9 @@ # pyre-strict -from typing import final, List +from typing import Any, Dict, final, List + +import executorch.backends.vulkan.utils as utils from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform from executorch.backends.transforms.fuse_batch_norm_with_conv import ( @@ -20,9 +22,14 @@ from executorch.backends.vulkan._passes import ( insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, + TagMemoryMetaPass, ) from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( serialize_vulkan_graph, ) @@ -78,6 +85,28 @@ def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: return program +def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: + options = {} + for spec in compile_specs: + if spec.key == "storage_type_override": + options[spec.key] = VkStorageType( + int.from_bytes(spec.value, byteorder="little") + ) + if spec.key == "memory_layout_override": + 
options[spec.key] = VkMemoryLayout( + int.from_bytes(spec.value, byteorder="little") + ) + if spec.key in {"texture_limits_x", "texture_limits_y", "texture_limits_z"}: + options[spec.key] = int.from_bytes(spec.value, byteorder="little") + + if spec.key == "skip_tag_memory_metadata": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + + # Unhandled options are ignored + + return options + + @final class VulkanBackend(BackendDetails): @classmethod @@ -87,6 +116,25 @@ def preprocess( # noqa: C901 program: ExportedProgram, module_compile_spec: List[CompileSpec], ) -> PreprocessResult: + compile_options = parse_compile_spec(module_compile_spec) + limits_x = compile_options.get( + "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0] + ) + limits_y = compile_options.get( + "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1] + ) + limits_z = compile_options.get( + "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2] + ) + texture_limits = (limits_x, limits_y, limits_z) + + default_storage_type = compile_options.get( + "storage_type_override", VkStorageType.TEXTURE_3D + ) + default_memory_layout = compile_options.get( + "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED + ) + program = unsafe_remove_auto_functionalized_pass(program) # First, apply passes that fuse/remove operators to consolidate the graph @@ -122,10 +170,31 @@ def preprocess( # noqa: C901 ], ) + # Optionally apply the memory metadata tagging pass, which will insert storage + # type and memory layout transition nodes to ensure that all tensor arguments + # to an operator are in a supported or optimal configuration. If this pass is not + # applied, there will be a risk that some operators receive arguments with + # memory settings that are not supported by the implementation. + if not compile_options.get("skip_tag_memory_metadata", False): + program = apply_passes( + program, + [ + TagMemoryMetaPass( + texture_limits, + default_storage_type=default_storage_type, + default_memory_layout=default_memory_layout, + ), + ], + ) + # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. program = apply_passes( - program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()] + program, + [ + ConstraintBasedSymShapeEvalPass(), + MemoryPlanningPass(), + ], ) graph_builder = VkGraphBuilder( diff --git a/devtools/bundled_program/schema/common.fbs b/devtools/bundled_program/schema/common.fbs index fc299ac691..0236416951 100644 --- a/devtools/bundled_program/schema/common.fbs +++ b/devtools/bundled_program/schema/common.fbs @@ -35,3 +35,19 @@ enum ScalarType : byte { // BITS4x2 = 20, // BITS8 = 21, } + +// Describes a contiguous piece of data that lives outside of the flatbuffer data, +// typically appended afterwards in the file. +// For .pte and .ptd files, the "extended header" in the file, when present, +// points to the segment base offset. +table DataSegment { + // Segment offsets are relative to the segment base offset provided in the + // extended file header. Segments will typically be aligned in a way to make + // it possible to use mmap() to load them. + offset: uint64; + + // The size in bytes of valid data starting at the offset. The segment + // data may be followed by padding before the segment that follows it, + // to make it easier to use mmap().
+ size: uint64; +} diff --git a/devtools/bundled_program/schema/test/test_schema.py b/devtools/bundled_program/schema/test/test_schema.py index 6a2d244103..af79561a92 100644 --- a/devtools/bundled_program/schema/test/test_schema.py +++ b/devtools/bundled_program/schema/test/test_schema.py @@ -14,9 +14,7 @@ class TestSchema(unittest.TestCase): def test_schema_sync(self) -> None: # make the test work in both internal and oss. - prefix = ( - "executorch/" if os.path.exists("executorch/schema/common.fbs") else "" - ) + prefix = "executorch/" if os.path.exists("executorch/schema/common.fbs") else "" self.assertTrue( filecmp.cmp( diff --git a/devtools/etdump/common.fbs b/devtools/etdump/common.fbs index fc299ac691..0236416951 100644 --- a/devtools/etdump/common.fbs +++ b/devtools/etdump/common.fbs @@ -35,3 +35,19 @@ enum ScalarType : byte { // BITS4x2 = 20, // BITS8 = 21, } + +// Describes a contiguous piece of data that lives outside of the flatbuffer data, +// typically appended afterwards in the file. +// For .pte and .ptd files, the "extended header" in the file, when present, +// points to the segment base offset. +table DataSegment { + // Segment offsets are relative to the segment base offset provided in the + // extended file header. Segments will typically be aligned in a way to make + // it possible to use mmap() to load them. + offset: uint64; + + // The size in bytes of valid data starting at the offset. The segment + // data may be followed by padding before the segment that follows it, + // to make it easier to use mmap(). + size: uint64; +} diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index f712644303..83492f9963 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -112,6 +112,7 @@ def get_scalar_type_size(scalar_type: ScalarType) -> Tuple[torch.dtype, int]: ScalarType.BYTE: (torch.uint8, 1), ScalarType.CHAR: (torch.int8, 1), ScalarType.BOOL: (torch.bool, 1), + ScalarType.BITS16: (torch.uint16, 2), ScalarType.SHORT: (torch.int16, 2), ScalarType.HALF: (torch.float16, 2), ScalarType.INT: (torch.int, 4), @@ -217,6 +218,7 @@ def verify_debug_data_equivalence( if isinstance(output_a, torch.Tensor): assert bool( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `bool`. torch.all(output_a == output_b) ), "Tensors Debug Data is different. Expected to be equal." 
else: diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 3075d992d5..e718c52fdc 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -180,7 +180,9 @@ def get_compile_spec( spec_builder = None if target == "TOSA": spec_builder = ( - ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True) + ArmCompileSpecBuilder() + .tosa_compile_spec("TOSA-0.80.0+BI") + .set_permute_memory_format(True) ) elif "ethos-u55" in target: spec_builder = ( diff --git a/examples/arm/executor_runner/arm_perf_monitor.cpp b/examples/arm/executor_runner/arm_perf_monitor.cpp index 323010bfd7..b75e510d9d 100644 --- a/examples/arm/executor_runner/arm_perf_monitor.cpp +++ b/examples/arm/executor_runner/arm_perf_monitor.cpp @@ -24,7 +24,14 @@ static std::vector ethosu_pmuEventCounts( ETHOSU_PMU_Get_NumEventCounters(), 0); +#if defined(ETHOSU55) || defined(ETHOSU65) static const uint32_t ethosu_pmuCountersUsed = 4; +#elif defined(ETHOSU85) +static const uint32_t ethosu_pmuCountersUsed = 5; +#else +#error No NPU target defined +#endif + // ethosu_pmuCountersUsed should match numbers of counters setup in // ethosu_inference_begin() and not be more then the HW supports static_assert(ETHOSU_PMU_NCOUNTERS >= ethosu_pmuCountersUsed); @@ -44,18 +51,26 @@ void ethosu_inference_begin(struct ethosu_driver* drv, void*) { ETHOSU_PMU_Set_EVTYPER(drv, 1, ETHOSU_PMU_AXI1_RD_DATA_BEAT_RECEIVED); ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_AXI0_WR_DATA_BEAT_WRITTEN); ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_NPU_IDLE); + // Enable the 4 counters + ETHOSU_PMU_CNTR_Enable( + drv, + ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | + ETHOSU_PMU_CNT4_Msk); #elif defined(ETHOSU85) - ETHOSU_PMU_Set_EVTYPER(drv, 0, ETHOSU_PMU_EXT0_RD_DATA_BEAT_RECEIVED); - ETHOSU_PMU_Set_EVTYPER(drv, 1, ETHOSU_PMU_EXT1_RD_DATA_BEAT_RECEIVED); - ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT0_WR_DATA_BEAT_WRITTEN); - ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_NPU_IDLE); + ETHOSU_PMU_Set_EVTYPER(drv, 0, ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED); + ETHOSU_PMU_Set_EVTYPER(drv, 1, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN); + ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED); + ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN); + ETHOSU_PMU_Set_EVTYPER(drv, 4, ETHOSU_PMU_NPU_IDLE); + // Enable the 5 counters + ETHOSU_PMU_CNTR_Enable( + drv, + ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | + ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk); #else #error No NPU target defined #endif - // Enable 4 counters - ETHOSU_PMU_CNTR_Enable(drv, 0xf); - ETHOSU_PMU_CNTR_Enable(drv, ETHOSU_PMU_CCNT_Msk); ETHOSU_PMU_CYCCNT_Reset(drv); @@ -177,7 +192,7 @@ void StopMeasurements() { #elif defined(ETHOSU85) ET_LOG( Info, - "Ethos-U PMU Events:[ETHOSU_PMU_EXT0_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT1_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT0_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); + "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); #else #error No NPU target defined #endif diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index b1401a0bca..6a4aee11d2 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -37,6 +37,8 @@ cmake_dependent_option( "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF ) +option(EXECUTORCH_BUILD_TORCHAO "Build 
the torchao kernels" OFF) + if(NOT PYTHON_EXECUTABLE) set(PYTHON_EXECUTABLE python3) endif() @@ -121,6 +123,13 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM) list(APPEND link_libraries custom_ops) endif() +if(EXECUTORCH_BUILD_TORCHAO) + set(TORCHAO_BUILD_EXECUTORCH_OPS ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental) + target_link_options_shared_lib(torchao_ops_executorch) + list(APPEND link_libraries torchao_ops_executorch) +endif() + set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack) # Extra compile option and include dir for pthreadpool if(EXECUTORCH_BUILD_PTHREADPOOL) diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index d328adffbf..cf387bfab2 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -82,6 +82,7 @@ runtime.python_library( "export_llama_lib.py", "model.py", "source_transformation/apply_spin_quant_r1_r2.py", + "source_transformation/attention.py", "source_transformation/lora.py", "source_transformation/pre_quantization.py", "source_transformation/prune_vocab.py", diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index 3d0d1b7bcf..1899ccf4df 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -7,11 +7,14 @@ # Example script for exporting Llama2 to flatbuffer import logging +import sys import torch from .export_llama_lib import build_args_parser, export_llama +sys.setrecursionlimit(4096) + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index f3822b6866..817f116c92 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -12,6 +12,7 @@ import copy import json import logging +import re import shlex from enum import Enum from json import JSONDecodeError @@ -19,7 +20,6 @@ from typing import Callable, List, Optional, Union import pkg_resources - import torch from executorch.devtools.etrecord import generate_etrecord @@ -50,6 +50,8 @@ fuse_layer_norms, get_model_with_r1_r2, ) + +from .source_transformation.attention import replace_attention_to_attention_sha from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, @@ -153,12 +155,12 @@ def build_args_parser() -> argparse.ArgumentParser: ], help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", ) + parser.add_argument( "-qmode", "--quantization_mode", - type=str, + type=_qmode_type, default=None, - choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"], help="type of quantization", ) @@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser: help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. 
Note, checkpoint_dir takes precedence over checkpoint if both are set.", ) + parser.add_argument( + "--use_qnn_sha", + action="store_true", + help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)", + ) + parser.add_argument( "--calibration_tasks", nargs="+", @@ -568,6 +576,23 @@ def get_quantizer_and_quant_params(args): return pt2e_quant_params, quantizers, quant_dtype +def _qmode_type(value): + choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"] + patterns = [r"torchao:8da(\d+)w"] + + if value in choices: + return value + + for pattern in patterns: + matches = re.findall(pattern, value) + if len(matches) == 1: + return value + + raise argparse.ArgumentTypeError( + f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}." + ) + + def _validate_args(args): """ TODO: Combine all the backends under --backend args @@ -581,6 +606,19 @@ def _validate_args(args): if args.num_sharding > 0 and not args.qnn: raise ValueError("Model shard is only supported with qnn backend now.") + if ( + args.quantization_mode is not None + and args.quantization_mode.startswith("torchao:") + ) or ( + args.embedding_quantize is not None + and args.embedding_quantize.startswith("torchao:") + ): + if args.enable_dynamic_shape: + raise ValueError( + "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape." + "If you need this feature, please file an issue." + ) + def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 _validate_args(args) @@ -622,7 +660,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 partitioners.append( get_vulkan_partitioner( args.dtype_override, - args.quantization_mode, + args.enable_dynamic_shape, ) ) modelname = f"vulkan_{modelname}" @@ -670,15 +708,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 get_custom_quant_ios_dtype, ) + atten = builder_exported_to_edge.model.layers[0].attention + if args.use_qnn_sha: + cache_shape = torch.Size( + (atten.max_batch_size, atten.max_seq_len, atten.head_dim) + ) + else: + cache_shape = torch.Size( + ( + atten.max_batch_size, + atten.max_seq_len, + atten.n_kv_heads, + atten.head_dim, + ) + ) # pyre-ignore tag_quant_io( builder_exported_to_edge.edge_manager.exported_program().graph_module, - partial( - get_custom_quant_ios_dtype, # pyre-ignore - builder_exported_to_edge.model.layers[ - 0 - ].attention.kv_cache.past_k_caches.shape, - ), + partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore ) logging.info("Lowering model using following partitioner(s): ") @@ -947,15 +994,27 @@ def _get_source_transforms( # noqa convert_linear_to_conv2d, ) - transforms.append(replace_kv_cache_with_simple_kv_cache) - transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: - transforms.append(fuse_layer_norms) - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) - # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 
- transforms.append(convert_linear_to_conv2d) + if args.use_qnn_sha: + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(replace_attention_to_attention_sha) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + transforms.append(convert_linear_to_conv2d) + else: + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(convert_linear_to_conv2d) elif args.mps: # Currently mps doesn't support sdpa op, use the simpler decomposition diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh index 3103daeb7d..f794b660bd 100755 --- a/examples/models/llama/install_requirements.sh +++ b/examples/models/llama/install_requirements.sh @@ -10,8 +10,7 @@ pip install snakeviz sentencepiece # Install torchao. -TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt) -pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}" +pip install "$(dirname "$0")/../../../third-party/ao" # Install lm-eval for Model Evaluation with lm-evalution-harness # Install tiktoken for tokenizer diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 76e8730328..20b8b1e30d 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int): self.max_batch_size = args.max_batch_size self.max_seq_len = args.max_seq_len self.dim = args.dim - # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125 self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index b8792151a0..9745fdd542 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument( "--prompt", type=str, - default="Hello", + default=None, ) parser.add_argument( @@ -63,6 +63,20 @@ def build_args_parser() -> argparse.ArgumentParser: default=0, ) + parser.add_argument( + "--show_tokens", + action="store_true", + default=False, + help="Show the tokens that were generated", + ) + + parser.add_argument( + "--chat", + action="store_true", + default=False, + help="Have multi-turn chat with the model", + ) + return parser @@ -71,15 +85,16 @@ def main() -> None: args = parser.parse_args() runner = EagerLlamaRunner(args) - result = runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, - ) - print( - "Response: \n{response}\n Tokens:\n {tokens}".format( - response=result["generation"], tokens=result["tokens"] + generated_tokens = ( + runner.chat_completion(temperature=args.temperature) + if args.chat + else runner.text_completion( + prompt=args.prompt, + temperature=args.temperature, ) ) + if args.show_tokens: + print(f"Tokens: 
{generated_tokens}") if __name__ == "__main__": diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 867c41aabe..ed25d44b6f 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import List, Optional, TypedDict +from typing import List, Optional import torch @@ -13,11 +13,6 @@ from executorch.extension.llm.tokenizer.utils import get_tokenizer -class CompletionPrediction(TypedDict, total=False): - generation: str - tokens: List[int] # not required - - def sample_top_p(probs, p): """ Perform top-p (nucleus) sampling on a probability distribution. @@ -72,18 +67,20 @@ def generate( # noqa: C901 temperature: float = 0.8, top_p: float = 0.9, echo: bool = False, + pos_base: int = 0, ) -> List[int]: # prefill logits = self.forward( tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), input_pos=( - torch.tensor([0], dtype=torch.long, device=self.device) + torch.tensor([pos_base], dtype=torch.long, device=self.device) if self.params.use_kv_cache else None ), ) current_token = next_token(logits, temperature, top_p) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) tokens = prompt_tokens + [current_token] while len(tokens) < self.params.max_seq_len: @@ -93,7 +90,9 @@ def generate( # noqa: C901 [[current_token]], dtype=torch.long, device=self.device ), input_pos=torch.tensor( - [len(tokens) - 1], dtype=torch.long, device=self.device + [pos_base + len(tokens) - 1], + dtype=torch.long, + device=self.device, ), ) else: @@ -101,12 +100,14 @@ def generate( # noqa: C901 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device), ) current_token = next_token(logits, temperature, top_p) + tokens.append(current_token) if current_token == self.tokenizer.eos_id or ( hasattr(self.tokenizer, "stop_tokens") and current_token in self.tokenizer.stop_tokens ): break - tokens.append(current_token) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + print("\n") return tokens if echo else tokens[len(prompt_tokens) :] @@ -116,7 +117,7 @@ def text_completion( temperature: float = 0.6, top_p: float = 0.9, echo: bool = False, - ) -> CompletionPrediction: + ) -> List[int]: """ Perform text completion for a prompt using the language model. @@ -127,19 +128,60 @@ def text_completion( echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. Returns: - CompletionPrediction: Completion prediction, which contains the generated text completion. + Generated list of tokens. Note: This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. """ - prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False) - generation_tokens = self.generate( - prompt_tokens=prompt_tokens, + return self.generate( + prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), temperature=temperature, top_p=top_p, echo=echo, ) - return { - "generation": self.tokenizer.decode(generation_tokens), - "tokens": generation_tokens, - } + + def chat_completion( + self, + temperature: float = 0.6, + top_p: float = 0.9, + ) -> List[int]: + """ + Perform multi-turn chat with the language model. + + Args: + prompt (str): Text prompt for completion. 
+ temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Generated list of tokens. + + Note: + This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. + """ + exit_prompt = "exit" + tokens = [] + prompt = input("Me: ") + while prompt and prompt != exit_prompt: + print("LLM: ", end="", flush=True) + new_tokens = self.generate( + prompt_tokens=self.tokenizer.encode( + self._format_prompt(prompt), bos=True, eos=False + ), + temperature=temperature, + top_p=top_p, + echo=True, + pos_base=len(tokens), + ) + tokens.extend(new_tokens) + prompt = input("Me: ") + return tokens + + def _format_prompt(self, prompt: str) -> str: + return f""" +<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|> + +{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 73005d9333..19e5791598 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -107,15 +107,11 @@ def main() -> None: parser = build_args_parser() args = parser.parse_args() runner = NativeLlamaRunner(args) - result = runner.text_completion( + generated_tokens = runner.text_completion( prompt=args.prompt, temperature=args.temperature, ) - print( - "Response: \n{response}\n Tokens:\n {tokens}".format( - response=result["generation"], tokens=result["tokens"] - ) - ) + print(f"Response: {generated_tokens}") if __name__ == "__main__": diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py new file mode 100644 index 0000000000..c5a028d340 --- /dev/null +++ b/examples/models/llama/source_transformation/attention.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
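+
+# Source transformation that rewrites the Llama multi-head Attention module into
+# per-head ("single head attention", SHA) form for the Qualcomm (QNN) backend:
+# the fused wq/wk/wv projections become one nn.Linear per head (per kv-head for
+# k/v), and the KV cache keeps a separate buffer per head (see KVCacheSHA and
+# SDPASHA below). Requires use_kv_cache; bert (no-cache) mode is not supported.
+#
+# Usage sketch (assuming `model` holds llama_transformer.Attention submodules):
+#     model = replace_attention_to_attention_sha(model)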
+ +# pyre-unsafe + +# Example script for exporting Llama2 to flatbuffer + +import math +from typing import List, Optional, Tuple + +import torch +from executorch.examples.models.llama.llama_transformer import Attention +from torch import nn + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class KVCacheSHA(torch.nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.float32, + ): + super().__init__() + + # a buffer per head + cache_shape = (max_batch_size, max_seq_length, head_dim) + for i in range(n_heads): + self.register_buffer( + f"past_k_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + self.register_buffer( + f"past_v_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + cache_idx: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + new_k = torch.ops.aten.index_put_( + getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val + ) + new_v = torch.ops.aten.index_put_( + getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val + ) + return new_k, new_v + + def get_cache(self, head_idx): + return getattr(self, f"past_k_caches_{head_idx}"), getattr( + self, f"past_v_caches_{head_idx}" + ) + + +class SDPASHA(torch.nn.Module): + + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + n_rep: int, + head_dim: int, + dim: int, + ): + super().__init__() + self.head_dim = head_dim + self.n_rep = n_rep + self.dim = dim + self.kv_cache = KVCacheSHA( + max_batch_size, max_seq_length, n_heads // n_rep, head_dim + ) + self.scale_factor = math.sqrt(head_dim) + + def forward( + self, + input_pos: torch.Tensor, + qs: List[torch.Tensor], + ks: List[torch.Tensor], + vs: List[torch.Tensor], + mask, + ): + + transpose_ks = [] + for i in range(len(ks)): + new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i) + transpose_ks.append(new_k.transpose(-2, -1).contiguous()) + + output = [] + for i, q in enumerate(qs): + cache_idx = i // self.n_rep + _, v = self.kv_cache.get_cache(cache_idx) + + attn_mask = mask[input_pos] + + attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + output.append(attn_weight @ v.contiguous()) + + return torch.cat(output, dim=-1) + + +class AttentionSHA(nn.Module): + def __init__(self, attention_mha: nn.Module): + super().__init__() + if not attention_mha.use_kv_cache: + raise NotImplementedError("bert mode is not support") + + self.n_heads = attention_mha.n_heads + self.n_kv_heads = attention_mha.n_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.dim = attention_mha.dim + self.max_batch_size = attention_mha.max_batch_size + self.max_seq_len = attention_mha.max_seq_len + self.head_dim = attention_mha.dim // self.n_heads + self.SDPA = SDPASHA( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.n_rep, + self.head_dim, + self.dim, + ) + self.wq = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk = nn.ModuleList( + [ + 
nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + self.wv = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + + for i in range(self.n_heads): + self.wq[i].weight.data.copy_( + attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + for i in range(self.n_kv_heads): + self.wk[i].weight.data.copy_( + attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv[i].weight.data.copy_( + attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wo = attention_mha.wo + + causal_mask = torch.tril( + torch.ones( + self.max_seq_len, + self.max_seq_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + ): + # QKV + q = [wq(x) for wq in self.wq] + k = [wk(x) for wk in self.wk] + v = [wv(x) for wv in self.wv] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + for i in range(len(k)): + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin) + + output = self.SDPA(input_pos, q, k, v, self.mask) + return self.wo(output) + + +def replace_attention_to_attention_sha(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, Attention): + setattr( + module, + name, + AttentionSHA(child), + ) + else: + replace_attention_to_attention_sha(child) + return module diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index 162d41d659..d168b7efcd 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import re from functools import partial from pathlib import Path from typing import Any, Dict, Optional @@ -70,6 +72,26 @@ def quantize( # noqa C901 if qmode == "int8": # Add quantization mode options here: group size, bit width, etc. 
return WeightOnlyInt8QuantHandler(model).quantized_model() + elif qmode.startswith("torchao:"): + pattern = r"torchao:8da(\d+)w" + matches = re.findall(pattern, qmode) + assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}" + bitwidth = int(matches[0][0]) + _load_torchao_ops_aten() + from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer + + with torch.no_grad(): + model = Int8DynActIntxWeightLinearQuantizer( + device="cpu", + precision=torch.float32, + groupsize=group_size, + bitwidth=bitwidth, + has_weight_zeros=False, + ).quantize(model) + + if verbose: + print("quantized model:", model) + return model elif qmode == "8da4w": # Check for required args if group_size is None: @@ -79,6 +101,7 @@ def quantize( # noqa C901 model = Int8DynActInt4WeightQuantizer( precision=torch_dtype, groupsize=group_size ).quantize(model) + if verbose: print("quantized model:", model) return model @@ -692,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: def get_quant_embedding_transform(args): + if args.embedding_quantize.startswith("torchao:"): + bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",") + group_size = int(group_size) + bitwidth = int(bitwidth) + _load_torchao_ops_aten() + from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer + + def _torchao_embedding_quantizer(model): + with torch.no_grad(): + model = IntxWeightEmbeddingQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=bitwidth, + groupsize=group_size, + ).quantize(model) + return model + + return _torchao_embedding_quantizer + bitwidth, group_size = args.embedding_quantize.split(",") if group_size == "none" or group_size == "None" or group_size == "0": group_size = None @@ -733,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose): ) +def _load_torchao_ops_aten(): + import glob + import os + + libs = glob.glob( + os.path.abspath( + os.path.join( + os.environ.get("CMAKE_INSTALL_PREFIX", ""), + "lib/libtorchao_ops_aten.*", + ) + ) + ) + assert ( + len(libs) == 1 + ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly." + logging.info(f"Loading custom ops library: {libs[0]}") + torch.ops.load_library(libs[0]) + + ############################ Source Transform End ####################### diff --git a/examples/models/llama/tokenizer/tiktoken.py b/examples/models/llama/tokenizer/tiktoken.py index 1d74e5e3aa..b48cb4dc89 100644 --- a/examples/models/llama/tokenizer/tiktoken.py +++ b/examples/models/llama/tokenizer/tiktoken.py @@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str: # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. return self.model.decode(cast(List[int], t)) + def decode_token(self, t: int) -> str: + """ + Decodes a single token ID into a string. + + Args: + t (int): The token ID to be decoded. + + Returns: + str: The decoded string. 
+ """ + return self.model.decode_single_token_bytes(t).decode("utf-8") + @staticmethod def _split_whitespaces_or_nonwhitespaces( s: str, max_consecutive_slice_len: int diff --git a/examples/models/llama3_2_vision/install_requirements.sh b/examples/models/llama3_2_vision/install_requirements.sh index 44cc399acb..49558952d8 100755 --- a/examples/models/llama3_2_vision/install_requirements.sh +++ b/examples/models/llama3_2_vision/install_requirements.sh @@ -9,5 +9,4 @@ pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir # Install torchao. -TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt) -pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}" +pip install "$(dirname "$0")/../../../third-party/ao" diff --git a/examples/models/phi-3-mini-lora/install_requirements.sh b/examples/models/phi-3-mini-lora/install_requirements.sh index ec6289a126..2cd74d0cd4 100755 --- a/examples/models/phi-3-mini-lora/install_requirements.sh +++ b/examples/models/phi-3-mini-lora/install_requirements.sh @@ -10,5 +10,4 @@ pip install torchtune pip install tiktoken # Install torchao. -TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt) -pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}" +pip install "$(dirname "$0")/../../../third-party/ao" diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index f1a2d3b8f2..93c150c0b9 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -22,9 +22,6 @@ #include -#include -#include - #include #include #include @@ -39,10 +36,6 @@ DEFINE_string( model_path, "model.pte", "Model serialized in flatbuffer format."); -DEFINE_bool( - is_fd_uri, - false, - "True if the model_path passed is a file descriptor with the prefix \"fd:///\"."); using executorch::extension::FileDataLoader; using executorch::runtime::Error; @@ -73,12 +66,7 @@ int main(int argc, char** argv) { // DataLoaders that use mmap() or point to data that's already in memory, and // users can create their own DataLoaders to load from arbitrary sources. const char* model_path = FLAGS_model_path.c_str(); - const bool is_fd_uri = FLAGS_is_fd_uri; - - Result loader = is_fd_uri - ? FileDataLoader::fromFileDescriptorUri(model_path) - : FileDataLoader::from(model_path); - + Result loader = FileDataLoader::from(model_path); ET_CHECK_MSG( loader.ok(), "FileDataLoader::from() failed: 0x%" PRIx32, diff --git a/examples/qualcomm/scripts/install_requirement.sh b/examples/qualcomm/scripts/install_requirement.sh new file mode 100644 index 0000000000..c961467a8a --- /dev/null +++ b/examples/qualcomm/scripts/install_requirement.sh @@ -0,0 +1,2 @@ +pip install soundfile +pip install torchmetrics diff --git a/examples/qualcomm/scripts/wav2letter.py b/examples/qualcomm/scripts/wav2letter.py new file mode 100644 index 0000000000..e377c6d7e9 --- /dev/null +++ b/examples/qualcomm/scripts/wav2letter.py @@ -0,0 +1,226 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
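+
+# End-to-end Qualcomm example: load the Wav2Letter model with pretrained
+# (torchaudio-compatible) weights, swap its Conv1d layers for the Conv2D wrapper
+# defined below, quantize (8a8w) and lower it via build_executorch_binary, run it
+# on device through SimpleADB, and report WER/CER on a LibriSPEECH test-clean
+# subset.
+#
+# Rough invocation (a sketch; -a/--artifact and -p/--pretrained_weight are
+# defined in this script, while the device/build/soc flags come from
+# setup_common_args_and_variables and their exact names may differ):
+#     python wav2letter.py -a ./wav2letter -p ./w2l.pt -s <device serial>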
+ +import json +import os +import sys +from multiprocessing.connection import Client + +import numpy as np + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.examples.models.wav2letter import Wav2LetterModel +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) + + +class Conv2D(torch.nn.Module): + def __init__(self, stride, padding, weight, bias=None): + super().__init__() + use_bias = bias is not None + self.conv = torch.nn.Conv2d( + in_channels=weight.shape[1], + out_channels=weight.shape[0], + kernel_size=[weight.shape[2], 1], + stride=[*stride, 1], + padding=[*padding, 0], + bias=use_bias, + ) + self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1)) + if use_bias: + self.conv.bias = torch.nn.Parameter(bias) + + def forward(self, x): + return self.conv(x) + + +def get_dataset(data_size, artifact_dir): + from torch.utils.data import DataLoader + from torchaudio.datasets import LIBRISPEECH + + def collate_fun(batch): + waves, labels = [], [] + + for wave, _, text, *_ in batch: + waves.append(wave.squeeze(0)) + labels.append(text) + # need padding here for static ouput shape + waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True) + return waves, labels + + dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True) + data_loader = DataLoader( + dataset=dataset, + batch_size=data_size, + shuffle=True, + collate_fn=lambda x: collate_fun(x), + ) + # prepare input data + inputs, targets, input_list = [], [], "" + for wave, label in data_loader: + for index in range(data_size): + # reshape input tensor to NCHW + inputs.append((wave[index].reshape(1, 1, -1, 1),)) + targets.append(label[index]) + input_list += f"input_{index}_0.raw\n" + # here we only take first batch, i.e. 'data_size' tensors + break + + return inputs, targets, input_list + + +def eval_metric(pred, target_str): + from torchmetrics.text import CharErrorRate, WordErrorRate + + def parse(ids): + vocab = " abcdefghijklmnopqrstuvwxyz'*" + return ["".join([vocab[c] for c in id]).replace("*", "").upper() for id in ids] + + pred_str = parse( + [ + torch.unique_consecutive(pred[i, :, :].argmax(0)) + for i in range(pred.shape[0]) + ] + ) + wer, cer = WordErrorRate(), CharErrorRate() + return wer(pred_str, target_str), cer(pred_str, target_str) + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + instance = Wav2LetterModel() + # target labels " abcdefghijklmnopqrstuvwxyz'*" + instance.vocab_size = 29 + model = instance.get_eager_model().eval() + model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True)) + + # convert conv1d to conv2d in nn.Module level will only introduce 2 permute + # nodes around input & output, which is more quantization friendly. 
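+    # Each Conv1d is swapped in place: the Conv2D wrapper above reuses the
+    # original weight by unsqueezing it to (out_ch, in_ch, kernel, 1) and pads
+    # only along the time axis, so inputs are fed as NCHW tensors of shape
+    # (1, 1, T, 1) (see get_dataset).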
+ for i in range(len(model.acoustic_model)): + for j in range(len(model.acoustic_model[i])): + module = model.acoustic_model[i][j] + if isinstance(module, torch.nn.Conv1d): + model.acoustic_model[i][j] = Conv2D( + stride=module.stride, + padding=module.padding, + weight=module.weight, + bias=module.bias, + ) + + # retrieve dataset, will take some time to download + data_num = 100 + inputs, targets, input_list = get_dataset( + data_size=data_num, artifact_dir=args.artifact + ) + pte_filename = "w2l_qnn" + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + sys.exit(0) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + adb.pull(output_path=args.artifact) + + predictions = [] + for i in range(data_num): + predictions.append( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + ) + + # evaluate metrics + wer, cer = 0, 0 + for i, pred in enumerate(predictions): + pred = torch.from_numpy(pred).reshape(1, instance.vocab_size, -1) + wer_eval, cer_eval = eval_metric(pred, targets[i]) + wer += wer_eval + cer += cer_eval + + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num}) + ) + else: + print(f"wer: {wer / data_num}\ncer: {cer / data_num}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. 
" + "Default ./wav2letter", + default="./wav2letter", + type=str, + ) + + parser.add_argument( + "-p", + "--pretrained_weight", + help=( + "Location of pretrained weight, please download via " + "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch" + " for torchaudio.models.Wav2Letter version" + ), + default=None, + type=str, + required=True, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 06225be2d1..ae5444023a 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -348,7 +348,9 @@ def histogram(golden, predict): return (pa, mpa, miou, cls_iou) -def get_imagenet_dataset(dataset_path, data_size, image_shape, crop_size=None): +def get_imagenet_dataset( + dataset_path, data_size, image_shape, crop_size=None, shuffle=True +): from torchvision import datasets, transforms def get_data_loader(): @@ -365,7 +367,7 @@ def get_data_loader(): imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) return torch.utils.data.DataLoader( imagenet_data, - shuffle=True, + shuffle=shuffle, ) # prepare input data diff --git a/exir/pass_base.py b/exir/pass_base.py index db6bef8e3f..9c97921f51 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -318,7 +318,11 @@ def call_function( if target == operator.getitem: value, key = args return self.callback.call_getitem(value, key, meta) - elif getattr(target, "__module__", None) in {"_operator", "math"}: + elif getattr(target, "__module__", None) in { + "_operator", + "builtins", + "math", + }: assert callable(target) return self.callback.call_sym(target, args, meta) elif target in _TORCH_SYM_OPS: diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index eeb1e5265b..a3251589ac 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -16,6 +16,7 @@ python_library( ":normalize_transpose_pass", ":prim_ops_py_registry", ":quant_fusion_pass", + ":quantize_io_pass", ":remove_noop_pass", ":replace_aten_with_edge_pass", ":replace_broken_ops_with_function_ops_pass", @@ -143,6 +144,19 @@ python_library( ], ) +python_library( + name = "quantize_io_pass", + srcs = [ + "quantize_io_pass.py", + ], + deps = [ + "fbsource//third-party/pypi/numpy:numpy", + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "memory_planning_pass", srcs = [ diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 7a0623040f..fdb954010c 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -339,7 +339,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: self.call(get_submodule(node.args[0])) self.call(get_submodule(node.args[1])) continue - elif getattr(target, "__module__", None) == "_operator": + elif getattr(target, "__module__", None) in ("builtins", "_operator"): continue elif target in to_out_var_skiplist: continue diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 6362a47112..fa1c2e6913 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -4,8 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import builtins +import math import operator -from typing import Dict, Set, Union +from typing import Any, Dict, Set, Union # necessary to ensure the ops are registered import torch @@ -14,6 +16,8 @@ from torch._ops import OpOverload from torch.library import Library +# pyre-unsafe + executorch_prims_lib = Library("executorch_prim", "DEF") @@ -91,7 +95,25 @@ def neg(a: _SymScalar) -> _SymScalar: return -a # pyre-ignore -_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = { +@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar") +def ceil(a: _SymScalar) -> _SymScalar: + return math.ceil(a) # pyre-ignore + + +@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar") +def builtin_round(a: _SymScalar) -> _SymScalar: + return round(a) # pyre-ignore + + +@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar") +def trunc(a: _SymScalar) -> _SymScalar: + return math.trunc(a) # pyre-ignore + + +_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = { + builtins.round: ops.backend.executorch_prim.round.Scalar, + math.ceil: ops.backend.executorch_prim.ceil.Scalar, + math.trunc: ops.backend.executorch_prim.trunc.Scalar, operator.sub: ops.backend.executorch_prim.sub.Scalar, operator.mul: ops.backend.executorch_prim.mul.Scalar, operator.add: ops.backend.executorch_prim.add.Scalar, diff --git a/exir/passes/quantize_io_pass.py b/exir/passes/quantize_io_pass.py new file mode 100644 index 0000000000..21ac4c868a --- /dev/null +++ b/exir/passes/quantize_io_pass.py @@ -0,0 +1,259 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import logging +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +import torch + +from executorch.exir import EdgeProgramManager +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass +from executorch.exir.tensor import scalar_type_enum +from torch.fx.passes.infra.pass_base import PassResult + +logger = logging.getLogger(__name__) + + +def quantize_input( + exported_program, input_index, qparams: Optional[Dict[str, Any]] = None +): + """ + Modify the program to expect quantized input at given index. The input is expected + to be quantizing this input as the first step. Must be called before + permute_input_layout. Returns the scale, zero point, qmin, qmax, and dtype of the + expected quantization. 
+ """ + graph = exported_program.graph_module.graph + name = exported_program.graph_signature.user_inputs[input_index] + placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name] + assert placeholders + target_placeholder = placeholders[0] + + if len(target_placeholder.users) != 1: + raise ValueError(f"Input {input_index} has more than one users") + quantize = next(iter(target_placeholder.users)) + if ( + quantize.target + != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ): + raise ValueError(f"Input {input_index} is not used by a quantize op") + + # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op + need_requant = False + if qparams is not None: + assert all( + qparam in qparams for qparam in ["scale", "zp", "dtype"] + ), "dtype/scale/zp must be specified in qparam for input requantization" + if qparams["dtype"] != quantize.args[5]: + if any( + dtype + not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16] + for dtype in [qparams["dtype"], quantize.args[5]] + ): + raise ValueError( + f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}" + ) + + need_requant = True + elif ( + not np.isclose(qparams["scale"], quantize.args[1]) + or qparams["zp"] != quantize.args[2] + ): + need_requant = True + + if need_requant: + assert qparams is not None + dtype = qparams["dtype"] + qmin = torch.iinfo(dtype).min + qmax = torch.iinfo(dtype).max + scale = qparams["scale"] + zero_point = qparams["zp"] + quant_args = (scale, zero_point, qmin, qmax, dtype) + logger.info( + f"Modifying program to requantize quantized input at index {input_index}" + ) + logger.info(f"Quantization parameters: {quant_args}") + + with exported_program.graph_module.graph.inserting_before(quantize): + input_dequant = exported_program.graph_module.graph.call_function( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + args=( + target_placeholder, + *quant_args, + ), + ) + input_dequant.meta["input_qparams"] = [ + { + "scale": scale, + "zero_point": zero_point, + "qmin": qmin, + "qmax": qmax, + "dtype": dtype, + } + ] + input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32) + target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype) + quantize.replace_input_with(target_placeholder, input_dequant) + else: + quant_args = quantize.args[1:] + logger.info(f"Modifying program to take quantized input at index {input_index}") + logger.info(f"Quantization parameters: {quant_args}") + + target_placeholder.meta["val"] = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + target_placeholder.meta["val"], *quant_args + ) + ) + quantize.replace_all_uses_with(quantize.args[0]) + + exported_program.graph_module.graph.eliminate_dead_code() + return quant_args + + +def quantize_output(exported_program, output_index): + """ + Modify the program to produce quantized output at given index. The model is expected + to be dequantizing this output as the last step. Must be called before + permute_output_layout. Returns the scale, zero point, qmin, qmax, and dtype of the + output quantization. 
+ """ + graph = exported_program.graph_module.graph + outputs = [n for n in graph.nodes if n.op == "output"] + if len(outputs) != 1: + raise NotImplementedError("Only 1 output node is supported") + + output_node = outputs[0] + output_list = list(output_node.args[0]) + if output_index >= len(output_list): + raise ValueError( + f"{len(output_list)} outputs available, " + + f"output index out of bounds: {output_index}" + ) + + target_output = output_list[output_index] + if ( + target_output.target + != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ): + raise ValueError("Output {output_index} is not a dequantize op") + + dequant = target_output + output_list[output_index] = dequant.args[0] + output_node.args = (output_list,) + dequant_args = dequant.args[1:] + graph.eliminate_dead_code() + + logger.info( + f"Modifying program to produce quantized output at index {output_index}" + ) + logger.info(f"Dequantization parameters: {dequant_args}") + return dequant_args + + +def get_config_method_name( + prefix: Optional[str] = "forward", + arg_type: str = "input", + index: int = 0, + key: str = "scale", +): + if prefix is None: + prefix = "" + else: + prefix = prefix + "_" + assert arg_type in ["input", "output"], "arg_type must be either input or output" + assert index >= 0, "index must be non-negative" + assert key in [ + "scale", + "zp", + "quant_min", + "quant_max", + "dtype", + ], "key must be one of scale, zp, quant_min, quant_max, dtype" + return f"{prefix}{arg_type}{index}_{key}" + + +class QuantizeInputs(ExportPass): + def __init__( + self, + edge_program_manager: EdgeProgramManager, + quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]], + method_name: Optional[str] = None, + ): + super().__init__() + self.edge_program_manager = edge_program_manager + + self.quantized_inputs_idx_dict = {} + if isinstance(quantized_inputs_idx, dict): + self.quantized_inputs_idx_dict = quantized_inputs_idx + else: + for idx in quantized_inputs_idx: + self.quantized_inputs_idx_dict[idx] = None + self.param_prefix_name = method_name + + def call(self, graph_module: torch.fx.GraphModule): + for i, qparams in self.quantized_inputs_idx_dict.items(): + quant_args = quantize_input( + self.edge_program_manager.exported_program(), i, qparams + ) + + if not self.edge_program_manager._config_methods: + self.edge_program_manager._config_methods = {} + + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "scale") + ] = quant_args[0] + self.edge_program_manager._config_methods[ # pyre-ignore + get_config_method_name(self.param_prefix_name, "input", i, "zp") + ] = quant_args[1] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "quant_min") + ] = quant_args[2] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "quant_max") + ] = quant_args[3] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "dtype") + ] = scalar_type_enum(quant_args[4]) + return PassResult(graph_module, True) + + +class QuantizeOutputs(ExportPass): + def __init__( + self, + edge_program_manager: EdgeProgramManager, + quantized_outputs_idx_list: List[int], + method_name: Optional[str] = None, + ): + super().__init__() + self.edge_program_manager = edge_program_manager + self.quantized_outputs_idx_list = quantized_outputs_idx_list + self.param_prefix_name = method_name + + def call(self, graph_module: 
torch.fx.GraphModule): + for i in self.quantized_outputs_idx_list: + dequant_args = quantize_output( + self.edge_program_manager.exported_program(), i + ) # noqa F841 + + if not self.edge_program_manager._config_methods: + self.edge_program_manager._config_methods = {} + + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "scale") + ] = dequant_args[0] + self.edge_program_manager._config_methods[ # pyre-ignore + get_config_method_name(self.param_prefix_name, "output", i, "zp") + ] = dequant_args[1] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "quant_min") + ] = dequant_args[2] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "quant_max") + ] = dequant_args[3] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "dtype") + ] = scalar_type_enum(dequant_args[4]) + + return PassResult(graph_module, True) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index f8b4d905fb..1995589f80 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -448,3 +448,15 @@ python_unittest( "//executorch/exir:_warnings", ], ) + +python_unittest( + name = "quantize_io_pass", + srcs = [ + "test_quantize_io_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir/passes:quantize_io_pass", + ], +) diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py new file mode 100644 index 0000000000..b3899b008c --- /dev/null +++ b/exir/tests/test_quantize_io_pass.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
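+
+# Tests for the QuantizeInputs/QuantizeOutputs IO passes: a small Add module is
+# quantized with XNNPACKQuantizer (PT2E prepare/convert), lowered with
+# to_edge_transform_and_lower, and the passes are applied through
+# EdgeProgramManager.transform. The tests check that boundary q/dq ops are
+# removed, that quantization parameters are exposed via _config_methods
+# (get_config_method_name), and that manually (de)quantized IO reproduces the
+# reference outputs.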
+ +import copy +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir.passes.quantize_io_pass import ( + get_config_method_name, + QuantizeInputs, + QuantizeOutputs, +) +from executorch.exir.tensor import get_scalar_type +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch.testing import FileCheck + +op_str = { + "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default", + "dq": "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default", +} + + +class TestQuantIOPass(unittest.TestCase): + class Add(torch.nn.Module): + def forward(self, x, y): + return x + y + + def _quantize(self, mod, example_inputs): + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config() + quantizer.set_global(operator_config) + m = torch.export.export_for_training( + mod, copy.deepcopy(example_inputs) + ).module() + m = prepare_pt2e(m, quantizer) + _ = m(*example_inputs) + m = convert_pt2e(m) + exported_program = torch.export.export_for_training(m, example_inputs) + return exported_program + + def _check_count(self, op, count, epm): + code = epm.exported_program().graph_module.code + FileCheck().check_count(op, count, exactly=True).run(code) + + def _get_edge_prog_manager(self, mod, example_inputs): + exported_program = self._quantize(mod, example_inputs) + edge_program_manager = to_edge_transform_and_lower( + exported_program, + transform_passes=[], + partitioner=None, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + self._check_count(op_str["dq"], 3, edge_program_manager) + self._check_count(op_str["q"], 3, edge_program_manager) + return edge_program_manager + + def test_add_drop_q_inputs(self) -> None: + example_inputs = (torch.randn(1, 5), torch.randn(1, 5)) + mod = self.Add().eval() + edge_program_manager = self._get_edge_prog_manager(mod, example_inputs) + reference_outputs = edge_program_manager.exported_program().module()( + *example_inputs + ) + + edge_program_manager_qin = edge_program_manager.transform( + [ + QuantizeInputs( + edge_program_manager=edge_program_manager, + quantized_inputs_idx=[0, 1], + method_name="forward", + ) + ] + ) + self._check_count(op_str["dq"], 3, edge_program_manager) + self._check_count(op_str["q"], 1, edge_program_manager) + + quantized_example_inputs = [] + for i in range(len(example_inputs)): + d = edge_program_manager_qin._config_methods + scale = d[get_config_method_name("forward", "input", i, "scale")] + zp = d[get_config_method_name("forward", "input", i, "zp")] + quant_min = d[get_config_method_name("forward", "input", i, "quant_min")] + quant_max = d[get_config_method_name("forward", "input", i, "quant_max")] + dtype = get_scalar_type( + d[get_config_method_name("forward", "input", i, "dtype")] + ) + + quantized_example_inputs.append( + torch.ops.quantized_decomposed.quantize_per_tensor.default( + example_inputs[i], scale, zp, quant_min, quant_max, dtype + ), + ) + quantized_example_inputs = tuple(quantized_example_inputs) + output = edge_program_manager_qin.exported_program().module()( + *quantized_example_inputs + ) + torch.testing.assert_close( + reference_outputs[0], + output[0], + ) + + def test_add_drop_dq_output(self) -> None: + example_inputs = (torch.randn(1, 5), torch.randn(1, 5)) + mod = self.Add().eval() + 
edge_program_manager = self._get_edge_prog_manager(mod, example_inputs) + reference_outputs = edge_program_manager.exported_program().module()( + *example_inputs + ) + + edge_program_manager_dqout = edge_program_manager.transform( + [ + QuantizeOutputs( + edge_program_manager=edge_program_manager, + quantized_outputs_idx_list=[0], + method_name="forward", + ) + ] + ) + self._check_count(op_str["dq"], 2, edge_program_manager) + self._check_count(op_str["q"], 3, edge_program_manager) + + quantized_outputs = edge_program_manager_dqout.exported_program().module()( + *example_inputs + ) + + dequantized_outputs = [] + for i in range(len(quantized_outputs)): + d = edge_program_manager_dqout._config_methods + scale = d[get_config_method_name("forward", "output", i, "scale")] + zp = d[get_config_method_name("forward", "output", i, "zp")] + q_min = d[get_config_method_name("forward", "output", i, "quant_min")] + q_max = d[get_config_method_name("forward", "output", i, "quant_max")] + dtype = get_scalar_type( + d[get_config_method_name("forward", "output", i, "dtype")] + ) + dequantized_outputs.append( + torch.ops.quantized_decomposed.dequantize_per_tensor.default( + quantized_outputs[i], scale, zp, q_min, q_max, dtype + ) + ) + dequantized_outputs = tuple(dequantized_outputs) + + torch.testing.assert_close( + reference_outputs[0], + dequantized_outputs[0], + ) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 31f24b3979..70f21f2751 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch) find_package(executorch CONFIG REQUIRED) target_link_options_shared_lib(executorch) -add_library(executorch_jni SHARED jni/jni_layer.cpp) +add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp) set(link_libraries) list( @@ -146,7 +146,7 @@ if(EXECUTORCH_JNI_CUSTOM_LIBRARY) endif() if(EXECUTORCH_BUILD_LLAMA_JNI) - target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp) + target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) list(APPEND link_libraries llama_runner llava_runner) target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) add_subdirectory( @@ -190,4 +190,4 @@ target_include_directories( target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) -target_link_libraries(executorch_jni ${link_libraries}) +target_link_libraries(executorch_jni ${link_libraries} log) diff --git a/extension/android/build.gradle b/extension/android/build.gradle index de243154d6..b40f08e0c4 100644 --- a/extension/android/build.gradle +++ b/extension/android/build.gradle @@ -20,6 +20,5 @@ task makeJar(type: Jar) { dependencies { implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' implementation 'com.facebook.soloader:nativeloader:0.10.5' - testImplementation 'junit:junit:4.13.2' } } diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 6f269739c0..e1bf26fef2 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -1,5 +1,6 @@ load("@fbsource//tools/build_defs/android:fb_android_cxx_library.bzl", "fb_android_cxx_library") load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") oncall("executorch") @@ -25,7 +26,7 @@ executorch_generated_lib( 
fb_android_cxx_library( name = "executorch_jni", - srcs = ["jni_layer.cpp"], + srcs = ["jni_layer.cpp", "log.cpp"], headers = ["jni_layer_constants.h"], allow_jni_merging = False, compiler_flags = [ @@ -36,6 +37,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", @@ -49,7 +51,7 @@ fb_android_cxx_library( fb_android_cxx_library( name = "executorch_jni_full", - srcs = ["jni_layer.cpp"], + srcs = ["jni_layer.cpp", "log.cpp"], headers = ["jni_layer_constants.h"], allow_jni_merging = False, compiler_flags = [ @@ -60,6 +62,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", ":generated_op_lib_optimized_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", @@ -88,6 +91,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", @@ -101,3 +105,18 @@ fb_android_cxx_library( "//xplat/executorch/extension/threadpool:threadpool_static", ], ) + +runtime.cxx_library( + name = "log_provider", + srcs = ["log.cpp"], + exported_headers = ["log.h"], + compiler_flags = [ + "-frtti", + "-fexceptions", + "-Wno-unused-variable", + ], + deps = [ + "//executorch/runtime/core:core", + ], + visibility = ["@EXECUTORCH_CLIENTS"], +) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index a6f0045725..ddba8462b9 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -17,6 +17,7 @@ #include "jni_layer_constants.h" +#include #include #include #include @@ -33,33 +34,6 @@ #include #include -#ifdef __ANDROID__ -#include - -// For Android, write to logcat -void et_pal_emit_log_message( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) { - int android_log_level = ANDROID_LOG_UNKNOWN; - if (level == 'D') { - android_log_level = ANDROID_LOG_DEBUG; - } else if (level == 'I') { - android_log_level = ANDROID_LOG_INFO; - } else if (level == 'E') { - android_log_level = ANDROID_LOG_ERROR; - } else if (level == 'F') { - android_log_level = ANDROID_LOG_FATAL; - } - - __android_log_print(android_log_level, "ExecuTorch", "%s", message); -} -#endif - using namespace executorch::extension; using namespace torch::executor; @@ -391,12 +365,43 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return jresult; } + facebook::jni::local_ref> + readLogBuffer() { +#ifdef __ANDROID__ + + facebook::jni::local_ref> ret; + + access_log_buffer([&](std::vector& buffer) { + const auto size = buffer.size(); + ret = facebook::jni::JArrayClass::newArray(size); + for (auto i = 0u; i < size; i++) { + const auto& entry = buffer[i]; + // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL + // MESSAGE". 
+ std::stringstream ss; + ss << "[" << entry.timestamp << " " << entry.function << " " + << entry.filename << ":" << entry.line << "] " + << static_cast(entry.level) << " " << entry.message; + + facebook::jni::local_ref jstr_message = + facebook::jni::make_jstring(ss.str().c_str()); + (*ret)[i] = jstr_message; + } + }); + + return ret; +#else + return facebook::jni::JArrayClass::newArray(0); +#endif + } + static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), makeNativeMethod("forward", ExecuTorchJni::forward), makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), + makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer), }); } }; diff --git a/extension/android/jni/log.cpp b/extension/android/jni/log.cpp new file mode 100644 index 0000000000..663198e127 --- /dev/null +++ b/extension/android/jni/log.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "log.h" + +#ifdef __ANDROID__ + +#include +#include +#include +#include + +using executorch::extension::log_entry; + +// Number of entries to store in the in-memory log buffer. +const size_t log_buffer_length = 16; + +namespace { +std::vector log_buffer_; +std::mutex log_buffer_mutex_; +} // namespace + +// For Android, write to logcat +void et_pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) { + std::lock_guard guard(log_buffer_mutex_); + + while (log_buffer_.size() >= log_buffer_length) { + log_buffer_.erase(log_buffer_.begin()); + } + + log_buffer_.emplace_back( + timestamp, level, filename, function, line, message, length); + + int android_log_level = ANDROID_LOG_UNKNOWN; + if (level == 'D') { + android_log_level = ANDROID_LOG_DEBUG; + } else if (level == 'I') { + android_log_level = ANDROID_LOG_INFO; + } else if (level == 'E') { + android_log_level = ANDROID_LOG_ERROR; + } else if (level == 'F') { + android_log_level = ANDROID_LOG_FATAL; + } + + __android_log_print(android_log_level, "ExecuTorch", "%s", message); +} + +namespace executorch::extension { + +void access_log_buffer(std::function&)> accessor) { + std::lock_guard guard(log_buffer_mutex_); + accessor(log_buffer_); +} + +} // namespace executorch::extension + +#endif diff --git a/extension/android/jni/log.h b/extension/android/jni/log.h new file mode 100644 index 0000000000..4389b1d61a --- /dev/null +++ b/extension/android/jni/log.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#include +#include +#include + +namespace executorch::extension { +struct log_entry { + et_timestamp_t timestamp; + et_pal_log_level_t level; + std::string filename; + std::string function; + size_t line; + std::string message; + + log_entry( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) + : timestamp(timestamp), + level(level), + filename(filename), + function(function), + line(line), + message(message, length) {} +}; + +void access_log_buffer(std::function&)> accessor); +} // namespace executorch::extension diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java index 608439548a..879b88c5f2 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java @@ -99,6 +99,11 @@ public int loadMethod(String methodName) { return mNativePeer.loadMethod(methodName); } + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + public String[] readLogBuffer() { + return mNativePeer.readLogBuffer(); + } + /** * Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the * native object will be destroyed when this object is garbage-collected. However, the timing of diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java index 2cf2ee53d7..a5487a4702 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -54,4 +54,8 @@ public void resetNative() { */ @DoNotStrip public native int loadMethod(String methodName); + + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + @DoNotStrip + public native String[] readLogBuffer(); } diff --git a/extension/android_test/.gitignore b/extension/android_test/.gitignore new file mode 100644 index 0000000000..a43b7e827a --- /dev/null +++ b/extension/android_test/.gitignore @@ -0,0 +1,6 @@ +local.properties +.gradle +.idea/* +.externalNativeBuild +src/libs/* +build diff --git a/extension/android_test/TARGETS b/extension/android_test/TARGETS new file mode 100644 index 0000000000..5c4f482b5e --- /dev/null +++ b/extension/android_test/TARGETS @@ -0,0 +1 @@ +# This file needs to exist to avoid build system breakage, see https://fburl.com/workplace/jtdlgdmd diff --git a/extension/android_test/add_model.py b/extension/android_test/add_model.py new file mode 100644 index 0000000000..5c7cf4770e --- /dev/null +++ b/extension/android_test/add_model.py @@ -0,0 +1,26 @@ +import torch +from executorch.exir import to_edge +from torch.export import export + + +# Start with a PyTorch model that adds two input tensors (matrices) +class Add(torch.nn.Module): + def __init__(self): + super(Add, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + + +# 1. torch.export: Defines the program with the ATen operator set. +aten_dialect = export(Add(), (torch.ones(1), torch.ones(1))) + +# 2. to_edge: Make optimizations for Edge devices +edge_program = to_edge(aten_dialect) + +# 3. to_executorch: Convert the graph to an ExecuTorch program +executorch_program = edge_program.to_executorch() + +# 4. 
Save the compiled .pte program +with open("add.pte", "wb") as file: + file.write(executorch_program.buffer) diff --git a/extension/android_test/build.gradle b/extension/android_test/build.gradle new file mode 100644 index 0000000000..5beb5455cb --- /dev/null +++ b/extension/android_test/build.gradle @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + buildscript { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } + dependencies { + classpath 'com.android.tools.build:gradle:7.3.0' + } +} + + +apply plugin: 'com.android.library' + +group 'org.pytorch.executorch' + + +android { + namespace 'org.pytorch.executorch' + compileSdkVersion 31 + buildToolsVersion "29.0.0" + defaultConfig { + minSdkVersion 28 + targetSdkVersion 31 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + sourceSets { + androidTest { + resources.srcDirs += [ 'src/androidTest/resources' ] + } + } +} + +dependencies { + implementation 'com.facebook.soloader:nativeloader:0.10.5' + implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation(files("src/libs/executorch.aar")) + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.5' + androidTestImplementation 'androidx.test:rules:1.2.0' + androidTestImplementation 'commons-io:commons-io:2.4' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' + androidTestImplementation 'com.google.gms:google-services:4.3.3' +} + +task('setupNativeLibs', type: Exec){ + commandLine("sh", "setup.sh") +} + +gradle.projectsEvaluated { + preBuild.dependsOn setupNativeLibs +} diff --git a/extension/android_test/gradle.properties b/extension/android_test/gradle.properties new file mode 100644 index 0000000000..2cbd6d19d3 --- /dev/null +++ b/extension/android_test/gradle.properties @@ -0,0 +1,23 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. 
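As a usage note for the logging changes above: an app that links the AAR produced by this build can pull the most recent native log entries from Java through the new `Module.readLogBuffer()` API. A minimal sketch, with an illustrative model path; the buffer is only populated on Android builds, where `log.cpp` keeps roughly the last 16 entries in addition to forwarding them to logcat:

```java
import org.pytorch.executorch.Module;

public final class LogDump {
  public static void dumpRecentExecuTorchLogs(String modelPath) {
    // Loading and running a module emits ET_LOG messages, which log.cpp now
    // mirrors into an in-memory ring buffer as well as logcat.
    Module module = Module.load(modelPath);
    module.forward();

    // Each entry arrives pre-formatted as
    // "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE".
    String[] entries = module.readLogBuffer();
    for (String entry : entries) {
      System.out.println(entry);
    }
  }
}
```

The buffer is deliberately small and mutex-guarded: the oldest entry is evicted once the fixed capacity is reached, so the call is cheap enough to make after a failed load or inference to capture the relevant diagnostics.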
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app's APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official +# Enables namespacing of each library's R class so that its R class includes only the +# resources declared in the library itself and none from the library's dependencies, +# thereby reducing the size of the R class for that library +android.nonTransitiveRClass=true diff --git a/extension/android_test/gradle/libs.versions.toml b/extension/android_test/gradle/libs.versions.toml new file mode 100644 index 0000000000..561988cb1f --- /dev/null +++ b/extension/android_test/gradle/libs.versions.toml @@ -0,0 +1,12 @@ +# This file was generated by the Gradle 'init' task. +# https://docs.gradle.org/current/userguide/platforms.html#sub::toml-dependencies-format + +[versions] +commons-math3 = "3.6.1" +guava = "32.1.3-jre" +junit = "4.13.2" + +[libraries] +commons-math3 = { module = "org.apache.commons:commons-math3", version.ref = "commons-math3" } +guava = { module = "com.google.guava:guava", version.ref = "guava" } +junit = { module = "junit:junit", version.ref = "junit" } diff --git a/extension/android_test/gradle/wrapper/gradle-wrapper.jar b/extension/android_test/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000..d64cd49177 Binary files /dev/null and b/extension/android_test/gradle/wrapper/gradle-wrapper.jar differ diff --git a/extension/android_test/gradle/wrapper/gradle-wrapper.properties b/extension/android_test/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000..a80b22ce5c --- /dev/null +++ b/extension/android_test/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/extension/android_test/gradlew b/extension/android_test/gradlew new file mode 100755 index 0000000000..1aa94a4269 --- /dev/null +++ b/extension/android_test/gradlew @@ -0,0 +1,249 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. 
If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 
+ +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. 
+# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/extension/android_test/gradlew.bat b/extension/android_test/gradlew.bat new file mode 100644 index 0000000000..25da30dbde --- /dev/null +++ b/extension/android_test/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/extension/android_test/settings.gradle b/extension/android_test/settings.gradle new file mode 100644 index 0000000000..6b1bd4f7f8 --- /dev/null +++ b/extension/android_test/settings.gradle @@ -0,0 +1,24 @@ +/* + * This file was generated by the Gradle 'init' task. + * + * The settings file is used to specify which projects to include in your build. + * For more detailed information on multi-project builds, please refer to https://docs.gradle.org/8.6/userguide/multi_project_builds.html in the Gradle documentation. + */ +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) + repositories { + google() + mavenCentral() + } +} + +rootProject.name = 'executorch' +include('src') diff --git a/extension/android_test/setup.sh b/extension/android_test/setup.sh new file mode 100755 index 0000000000..d83aeeebb4 --- /dev/null +++ b/extension/android_test/setup.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -eu + +BUILD_AAR_DIR="$(mktemp -d)" +export BUILD_AAR_DIR + +BASEDIR=$(dirname "$0") +source "$BASEDIR"/../../build/build_android_llm_demo.sh + +build_native_library() { + ANDROID_ABI="$1" + CMAKE_OUT="cmake-out-android-${ANDROID_ABI}" + EXECUTORCH_CMAKE_BUILD_TYPE="${EXECUTORCH_CMAKE_BUILD_TYPE:-Release}" + cmake . 
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ + -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ + -DANDROID_ABI="${ANDROID_ABI}" \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -B"${CMAKE_OUT}" + + cmake --build "${CMAKE_OUT}" -j16 --target install + + cmake extension/android \ + -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI="${ANDROID_ABI}" \ + -DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ + -B"${CMAKE_OUT}"/extension/android + + cmake --build "${CMAKE_OUT}"/extension/android -j16 + + # Copy artifacts to ABI specific directory + mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" + cp "${CMAKE_OUT}"/extension/android/*.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" +} + +pushd "$BASEDIR"/../../ +build_jar +build_native_library "arm64-v8a" +build_native_library "x86_64" +build_aar +source ".ci/scripts/test_llama.sh" stories110M cmake fp16 portable ${BUILD_AAR_DIR} +popd +mkdir -p "$BASEDIR"/src/libs +cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar +python add_model.py +mv "add.pte" "$BASEDIR"/src/androidTest/resources/add.pte +unzip -o "$BUILD_AAR_DIR"/model.zip -d "$BASEDIR"/src/androidTest/resources diff --git a/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java b/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java new file mode 100644 index 0000000000..940e34d684 --- /dev/null +++ b/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import android.os.Environment; +import androidx.test.rule.GrantPermissionRule; +import android.Manifest; +import android.content.Context; +import org.junit.Test; +import org.junit.Before; +import org.junit.Rule; +import org.junit.runner.RunWith; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.ArrayList; +import java.io.IOException; +import java.io.File; +import java.io.FileOutputStream; +import org.junit.runners.JUnit4; +import org.apache.commons.io.FileUtils; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.InstrumentationRegistry; +import org.pytorch.executorch.LlamaModule; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link LlamaModule}. 
*/ +@RunWith(AndroidJUnit4.class) +public class LlamaModuleInstrumentationTest implements LlamaCallback { + private static String TEST_FILE_NAME = "/tinyllama_portable_fp16_h.pte"; + private static String TOKENIZER_FILE_NAME = "/tokenizer.bin"; + private static String TEST_PROMPT = "Hello"; + private static int OK = 0x00; + private static int SEQ_LEN = 32; + + private final List results = new ArrayList<>(); + private final List tokensPerSecond = new ArrayList<>(); + private LlamaModule mModule; + + private static String getTestFilePath(String fileName) { + return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName; + } + + @Before + public void setUp() throws IOException { + // copy zipped test resources to local device + File addPteFile = new File(getTestFilePath(TEST_FILE_NAME)); + InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, addPteFile); + inputStream.close(); + + File tokenizerFile = new File(getTestFilePath(TOKENIZER_FILE_NAME)); + inputStream = getClass().getResourceAsStream(TOKENIZER_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, tokenizerFile); + inputStream.close(); + + mModule = new LlamaModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f); + } + + @Rule + public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE); + + @Test + public void testGenerate() throws IOException, URISyntaxException{ + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(OK, loadResult); + + mModule.generate(TEST_PROMPT, SEQ_LEN, LlamaModuleInstrumentationTest.this); + assertEquals(results.size(), SEQ_LEN); + assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0); + } + + @Test + public void testGenerateAndStop() throws IOException, URISyntaxException{ + int seqLen = 32; + mModule.generate(TEST_PROMPT, SEQ_LEN, new LlamaCallback() { + @Override + public void onResult(String result) { + LlamaModuleInstrumentationTest.this.onResult(result); + mModule.stop(); + } + + @Override + public void onStats(float tps) { + LlamaModuleInstrumentationTest.this.onStats(tps); + } + }); + + int stoppedResultSize = results.size(); + assertTrue(stoppedResultSize < SEQ_LEN); + } + + @Override + public void onResult(String result) { + results.add(result); + } + + @Override + public void onStats(float tps) { + tokensPerSecond.add(tps); + } +} diff --git a/extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java b/extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java new file mode 100644 index 0000000000..e8259969ab --- /dev/null +++ b/extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
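The instrumentation test above drives `LlamaModule` through JUnit assertions; in application code the same callback-based API would typically stream tokens into a buffer instead. A small sketch under the same assumptions, with placeholder model and tokenizer paths:

```java
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public final class TinyLlamaDemo implements LlamaCallback {
  private final StringBuilder output = new StringBuilder();

  @Override
  public void onResult(String token) {
    output.append(token); // invoked once per generated token
  }

  @Override
  public void onStats(float tokensPerSecond) {
    System.out.println("tokens/s: " + tokensPerSecond);
  }

  public String run(String modelPath, String tokenizerPath, String prompt) {
    LlamaModule module = new LlamaModule(modelPath, tokenizerPath, /*temperature=*/0.0f);
    if (module.load() != 0) { // 0 is the OK status the test above checks for
      throw new RuntimeException("Failed to load " + modelPath);
    }
    module.generate(prompt, /*seqLen=*/32, this);
    return output.toString();
  }
}
```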
+ */ + +package com.example.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import android.os.Environment; +import androidx.test.rule.GrantPermissionRule; +import android.Manifest; +import android.content.Context; +import org.junit.Test; +import org.junit.Before; +import org.junit.Rule; +import org.junit.runner.RunWith; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.io.IOException; +import java.io.File; +import java.io.FileOutputStream; +import org.junit.runners.JUnit4; +import org.apache.commons.io.FileUtils; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.InstrumentationRegistry; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link Module}. */ +@RunWith(AndroidJUnit4.class) +public class ModuleInstrumentationTest { + private static String TEST_FILE_NAME = "/add.pte"; + private static String MISSING_FILE_NAME = "/missing.pte"; + private static String NON_PTE_FILE_NAME = "/test.txt"; + private static String FORWARD_METHOD = "forward"; + private static String NONE_METHOD = "none"; + private static int OK = 0x00; + private static int INVALID_ARGUMENT = 0x12; + private static int ACCESS_FAILED = 0x22; + + private static String getTestFilePath(String fileName) { + return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName; + } + + @Before + public void setUp() throws IOException { + // copy zipped test resources to local device + File addPteFile = new File(getTestFilePath(TEST_FILE_NAME)); + InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, addPteFile); + inputStream.close(); + + File nonPteFile = new File(getTestFilePath(NON_PTE_FILE_NAME)); + inputStream = getClass().getResourceAsStream(NON_PTE_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, nonPteFile); + inputStream.close(); + } + + @Rule + public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE); + + @Test + public void testModuleLoadAndForward() throws IOException, URISyntaxException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + EValue[] results = module.forward(); + assertTrue(results[0].isTensor()); + } + + @Test + public void testModuleLoadMethodAndForward() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, OK); + + EValue[] results = module.forward(); + assertTrue(results[0].isTensor()); + } + + @Test + public void testModuleLoadForwardExplicit() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + EValue[] results = module.execute(FORWARD_METHOD); + assertTrue(results[0].isTensor()); + } + + @Test + public void testModuleLoadNonExistantFile() throws IOException{ + Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); + + EValue[] results = module.forward(); + assertEquals(null, results); + } + + @Test + public void testModuleLoadMethodNonExistantFile() throws IOException{ + Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + 
assertEquals(loadMethod, ACCESS_FAILED); + } + + @Test + public void testModuleLoadMethodNonExistantMethod() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int loadMethod = module.loadMethod(NONE_METHOD); + assertEquals(loadMethod, INVALID_ARGUMENT); + } + + @Test + public void testNonPteFile() throws IOException{ + Module module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, INVALID_ARGUMENT); + } +} diff --git a/extension/android_test/src/androidTest/resources/test.txt b/extension/android_test/src/androidTest/resources/test.txt new file mode 100644 index 0000000000..039461e6a9 --- /dev/null +++ b/extension/android_test/src/androidTest/resources/test.txt @@ -0,0 +1 @@ +non pte file diff --git a/extension/android_test/src/main/AndroidManifest.xml b/extension/android_test/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..b8ac862938 --- /dev/null +++ b/extension/android_test/src/main/AndroidManifest.xml @@ -0,0 +1,12 @@ + + + + + + + + + diff --git a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java similarity index 99% rename from extension/android/src/test/java/org/pytorch/executorch/EValueTest.java rename to extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java index 35367883ef..29cabae75f 100644 --- a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java +++ b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java @@ -129,7 +129,7 @@ public void testOptionalTensorListValue() { Optional.of(Tensor.fromBlob(data[1], shape[1]))); assertTrue(evalue.isOptionalTensorList()); - assertTrue(evalue.toOptionalTensorList()[0].isEmpty()); + assertTrue(!evalue.toOptionalTensorList()[0].isPresent()); assertTrue(evalue.toOptionalTensorList()[1].isPresent()); assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0])); diff --git a/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java new file mode 100644 index 0000000000..7933113412 --- /dev/null +++ b/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link Tensor}. 
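The Module tests above exercise the exported add.pte with a no-argument `forward()`; when a caller wants to supply inputs explicitly, tensors are wrapped in `EValue` first. A hedged sketch of that pattern, assuming the `EValue.from(Tensor)` / `toTensor()` accessors from the executorch Android bindings and a placeholder model path:

```java
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Module;
import org.pytorch.executorch.Tensor;

public final class AddModelRunner {
  public static float addOnes(String modelPath) {
    Module module = Module.load(modelPath);

    // The add.pte exported by add_model.py takes two float tensors of shape [1].
    Tensor x = Tensor.fromBlob(new float[] {1.0f}, new long[] {1});
    Tensor y = Tensor.fromBlob(new float[] {1.0f}, new long[] {1});

    EValue[] outputs = module.forward(EValue.from(x), EValue.from(y));
    return outputs[0].toTensor().getDataAsFloatArray()[0]; // expected 2.0f
  }
}
```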
*/ +@RunWith(JUnit4.class) +public class TensorTest { + + @Test + public void testFloatTensor() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.FLOAT); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5); + + FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(4); + floatBuffer.put(data); + tensor = Tensor.fromBlob(floatBuffer, shape); + assertEquals(tensor.dtype(), DType.FLOAT); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5); + } + + @Test + public void testIntTensor() { + int data[] = {Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE}; + long shape[] = {1, 4, 1}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.INT32); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsIntArray()[0]); + assertEquals(data[1], tensor.getDataAsIntArray()[1]); + assertEquals(data[2], tensor.getDataAsIntArray()[2]); + assertEquals(data[3], tensor.getDataAsIntArray()[3]); + + IntBuffer intBuffer = Tensor.allocateIntBuffer(4); + intBuffer.put(data); + tensor = Tensor.fromBlob(intBuffer, shape); + assertEquals(tensor.dtype(), DType.INT32); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsIntArray()[0]); + assertEquals(data[1], tensor.getDataAsIntArray()[1]); + assertEquals(data[2], tensor.getDataAsIntArray()[2]); + assertEquals(data[3], tensor.getDataAsIntArray()[3]); + } + + @Test + public void testDoubleTensor() { + double data[] = {Double.MIN_VALUE, 0.0d, 0.1d, Double.MAX_VALUE}; + long shape[] = {1, 4}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.DOUBLE); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5); + + DoubleBuffer doubleBuffer = Tensor.allocateDoubleBuffer(4); + doubleBuffer.put(data); + tensor = Tensor.fromBlob(doubleBuffer, shape); + assertEquals(tensor.dtype(), DType.DOUBLE); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5); + assertEquals(data[3], 
tensor.getDataAsDoubleArray()[3], 1e-5); + } + + @Test + public void testLongTensor() { + long data[] = {Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE}; + long shape[] = {4, 1}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.INT64); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsLongArray()[0]); + assertEquals(data[1], tensor.getDataAsLongArray()[1]); + assertEquals(data[2], tensor.getDataAsLongArray()[2]); + assertEquals(data[3], tensor.getDataAsLongArray()[3]); + + LongBuffer longBuffer = Tensor.allocateLongBuffer(4); + longBuffer.put(data); + tensor = Tensor.fromBlob(longBuffer, shape); + assertEquals(tensor.dtype(), DType.INT64); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsLongArray()[0]); + assertEquals(data[1], tensor.getDataAsLongArray()[1]); + assertEquals(data[2], tensor.getDataAsLongArray()[2]); + assertEquals(data[3], tensor.getDataAsLongArray()[3]); + } + + @Test + public void testSignedByteTensor() { + byte data[] = {Byte.MIN_VALUE, (byte) 0, (byte) 1, Byte.MAX_VALUE}; + long shape[] = {1, 1, 4}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.INT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsByteArray()[0]); + assertEquals(data[1], tensor.getDataAsByteArray()[1]); + assertEquals(data[2], tensor.getDataAsByteArray()[2]); + assertEquals(data[3], tensor.getDataAsByteArray()[3]); + + ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4); + byteBuffer.put(data); + tensor = Tensor.fromBlob(byteBuffer, shape); + assertEquals(tensor.dtype(), DType.INT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsByteArray()[0]); + assertEquals(data[1], tensor.getDataAsByteArray()[1]); + assertEquals(data[2], tensor.getDataAsByteArray()[2]); + assertEquals(data[3], tensor.getDataAsByteArray()[3]); + } + + @Test + public void testUnsignedByteTensor() { + byte data[] = {(byte) 0, (byte) 1, (byte) 2, (byte) 255}; + long shape[] = {4, 1, 1}; + Tensor tensor = Tensor.fromBlobUnsigned(data, shape); + assertEquals(tensor.dtype(), DType.UINT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]); + assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]); + assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]); + assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]); + + ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4); + byteBuffer.put(data); + tensor = Tensor.fromBlobUnsigned(byteBuffer, shape); + assertEquals(tensor.dtype(), DType.UINT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]); + assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]); + assertEquals(data[2], 
tensor.getDataAsUnsignedByteArray()[2]); + assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]); + } + + @Test + public void testIllegalDataTypeException() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.FLOAT); + + try { + tensor.getDataAsByteArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsUnsignedByteArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsIntArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsDoubleArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsLongArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + } + + @Test + public void testIllegalArguments() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shapeWithNegativeValues[] = {-1, 2}; + long mismatchShape[] = {1, 2}; + + try { + Tensor tensor = Tensor.fromBlob((float[]) null, mismatchShape); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + try { + Tensor tensor = Tensor.fromBlob(data, null); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + try { + Tensor tensor = Tensor.fromBlob(data, shapeWithNegativeValues); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + try { + Tensor tensor = Tensor.fromBlob(data, mismatchShape); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + } +} diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp index 0324751bfa..1d097cfd98 100644 --- a/extension/data_loader/file_data_loader.cpp +++ b/extension/data_loader/file_data_loader.cpp @@ -43,8 +43,6 @@ namespace extension { namespace { -static constexpr char kFdFilesystemPrefix[] = "fd:///"; - /** * Returns true if the value is an integer power of 2. */ @@ -76,36 +74,25 @@ FileDataLoader::~FileDataLoader() { ::close(fd_); } -Result getFDFromUri(const char* file_descriptor_uri) { - // check if the uri starts with the prefix "fd://" +Result FileDataLoader::from( + const char* file_name, + size_t alignment) { ET_CHECK_OR_RETURN_ERROR( - strncmp( - file_descriptor_uri, - kFdFilesystemPrefix, - strlen(kFdFilesystemPrefix)) == 0, + is_power_of_2(alignment), InvalidArgument, - "File descriptor uri (%s) does not start with %s", - file_descriptor_uri, - kFdFilesystemPrefix); - - // strip "fd:///" from the uri - int fd_len = strlen(file_descriptor_uri) - strlen(kFdFilesystemPrefix); - char fd_without_prefix[fd_len + 1]; - memcpy( - fd_without_prefix, - &file_descriptor_uri[strlen(kFdFilesystemPrefix)], - fd_len); - fd_without_prefix[fd_len] = '\0'; + "Alignment %zu is not a power of 2", + alignment); - // check if remaining fd string is a valid integer - int fd = ::atoi(fd_without_prefix); - return fd; -} + // Use open() instead of fopen() to avoid the layer of buffering that + // fopen() does. We will be reading large portions of the file in one shot, + // so buffering does not help. 
+ int fd = ::open(file_name, O_RDONLY); + if (fd < 0) { + ET_LOG( + Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno); + return Error::AccessFailed; + } -Result FileDataLoader::fromFileDescriptor( - const char* file_name, - const int fd, - size_t alignment) { // Cache the file size. struct stat st; int err = ::fstat(fd, &st); @@ -132,47 +119,6 @@ Result FileDataLoader::fromFileDescriptor( return FileDataLoader(fd, file_size, alignment, file_name_copy); } -Result FileDataLoader::fromFileDescriptorUri( - const char* file_descriptor_uri, - size_t alignment) { - ET_CHECK_OR_RETURN_ERROR( - is_power_of_2(alignment), - InvalidArgument, - "Alignment %zu is not a power of 2", - alignment); - - auto parsed_fd = getFDFromUri(file_descriptor_uri); - if (!parsed_fd.ok()) { - return parsed_fd.error(); - } - - int fd = parsed_fd.get(); - - return fromFileDescriptor(file_descriptor_uri, fd, alignment); -} - -Result FileDataLoader::from( - const char* file_name, - size_t alignment) { - ET_CHECK_OR_RETURN_ERROR( - is_power_of_2(alignment), - InvalidArgument, - "Alignment %zu is not a power of 2", - alignment); - - // Use open() instead of fopen() to avoid the layer of buffering that - // fopen() does. We will be reading large portions of the file in one shot, - // so buffering does not help. - int fd = ::open(file_name, O_RDONLY); - if (fd < 0) { - ET_LOG( - Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno); - return Error::AccessFailed; - } - - return fromFileDescriptor(file_name, fd, alignment); -} - namespace { /** * FreeableBuffer::FreeFn-compatible callback. diff --git a/extension/data_loader/file_data_loader.h b/extension/data_loader/file_data_loader.h index 959684137b..7cf2a92c4a 100644 --- a/extension/data_loader/file_data_loader.h +++ b/extension/data_loader/file_data_loader.h @@ -26,27 +26,6 @@ namespace extension { */ class FileDataLoader final : public executorch::runtime::DataLoader { public: - /** - * Creates a new FileDataLoader that wraps the named file descriptor, and the - * ownership of the file descriptor is passed. This helper is used when ET is - * running in a process that does not have access to the filesystem, and the - * caller is able to open the file and pass the file descriptor. - * - * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///", - * followed by the file descriptor number. - * @param[in] alignment Alignment in bytes of pointers returned by this - * instance. Must be a power of two. - * - * @returns A new FileDataLoader on success. - * @retval Error::InvalidArgument `alignment` is not a power of two. - * @retval Error::AccessFailed `file_name` could not be opened, or its size - * could not be found. - * @retval Error::MemoryAllocationFailed Internal memory allocation failure. - */ - static executorch::runtime::Result fromFileDescriptorUri( - const char* file_descriptor_uri, - size_t alignment = alignof(std::max_align_t)); - /** * Creates a new FileDataLoader that wraps the named file. 
* @@ -100,11 +79,6 @@ class FileDataLoader final : public executorch::runtime::DataLoader { void* buffer) const override; private: - static executorch::runtime::Result fromFileDescriptor( - const char* file_name, - const int fd, - size_t alignment = alignof(std::max_align_t)); - FileDataLoader( int fd, size_t file_size, diff --git a/extension/data_loader/file_descriptor_data_loader.cpp b/extension/data_loader/file_descriptor_data_loader.cpp new file mode 100644 index 0000000000..48e81fd706 --- /dev/null +++ b/extension/data_loader/file_descriptor_data_loader.cpp @@ -0,0 +1,292 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +namespace executorch { +namespace extension { + +namespace { + +static constexpr char kFdFilesystemPrefix[] = "fd:///"; + +/** + * Returns true if the value is an integer power of 2. + */ +static bool is_power_of_2(size_t value) { + return value > 0 && (value & ~(value - 1)) == value; +} + +/** + * Returns the next alignment for a given pointer. + */ +static uint8_t* align_pointer(void* ptr, size_t alignment) { + intptr_t addr = reinterpret_cast(ptr); + if ((addr & (alignment - 1)) == 0) { + // Already aligned. + return reinterpret_cast(ptr); + } + // Bump forward. + addr = (addr | (alignment - 1)) + 1; + return reinterpret_cast(addr); +} +} // namespace + +FileDescriptorDataLoader::~FileDescriptorDataLoader() { + // file_descriptor_uri_ can be nullptr if this instance was moved from, but + // freeing a null pointer is safe. + std::free(const_cast(file_descriptor_uri_)); + // fd_ can be -1 if this instance was moved from, but closing a negative fd is + // safe (though it will return an error). + ::close(fd_); +} + +static Result getFDFromUri(const char* file_descriptor_uri) { + // check if the uri starts with the prefix "fd://" + ET_CHECK_OR_RETURN_ERROR( + strncmp( + file_descriptor_uri, + kFdFilesystemPrefix, + strlen(kFdFilesystemPrefix)) == 0, + InvalidArgument, + "File descriptor uri (%s) does not start with %s", + file_descriptor_uri, + kFdFilesystemPrefix); + + // strip "fd:///" from the uri + int fd_len = strlen(file_descriptor_uri) - strlen(kFdFilesystemPrefix); + char fd_without_prefix[fd_len + 1]; + memcpy( + fd_without_prefix, + &file_descriptor_uri[strlen(kFdFilesystemPrefix)], + fd_len); + fd_without_prefix[fd_len] = '\0'; + + // check if remaining fd string is a valid integer + int fd = ::atoi(fd_without_prefix); + return fd; +} + +Result +FileDescriptorDataLoader::fromFileDescriptorUri( + const char* file_descriptor_uri, + size_t alignment) { + ET_CHECK_OR_RETURN_ERROR( + is_power_of_2(alignment), + InvalidArgument, + "Alignment %zu is not a power of 2", + alignment); + + auto parsed_fd = getFDFromUri(file_descriptor_uri); + if (!parsed_fd.ok()) { + return parsed_fd.error(); + } + + int fd = parsed_fd.get(); + + // Cache the file size. 
+ struct stat st; + int err = ::fstat(fd, &st); + if (err < 0) { + ET_LOG( + Error, + "Could not get length of %s: %s (%d)", + file_descriptor_uri, + ::strerror(errno), + errno); + ::close(fd); + return Error::AccessFailed; + } + size_t file_size = st.st_size; + + // Copy the filename so we can print better debug messages if reads fail. + const char* file_descriptor_uri_copy = ::strdup(file_descriptor_uri); + if (file_descriptor_uri_copy == nullptr) { + ET_LOG(Error, "strdup(%s) failed", file_descriptor_uri); + ::close(fd); + return Error::MemoryAllocationFailed; + } + + return FileDescriptorDataLoader( + fd, file_size, alignment, file_descriptor_uri_copy); +} + +namespace { +/** + * FreeableBuffer::FreeFn-compatible callback. + * + * `context` is actually a ptrdiff_t value (not a pointer) that contains the + * offset in bytes between `data` and the actual pointer to free. + */ +void FreeSegment(void* context, void* data, ET_UNUSED size_t size) { + ptrdiff_t offset = reinterpret_cast(context); + ET_DCHECK_MSG(offset >= 0, "Unexpected offset %ld", (long int)offset); + std::free(static_cast(data) - offset); +} +} // namespace + +Result FileDescriptorDataLoader::load( + size_t offset, + size_t size, + ET_UNUSED const DataLoader::SegmentInfo& segment_info) const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + ET_CHECK_OR_RETURN_ERROR( + offset + size <= file_size_, + InvalidArgument, + "File %s: offset %zu + size %zu > file_size_ %zu", + file_descriptor_uri_, + offset, + size, + file_size_); + + // Don't bother allocating/freeing for empty segments. + if (size == 0) { + return FreeableBuffer(nullptr, 0, /*free_fn=*/nullptr); + } + + // Allocate memory for the FreeableBuffer. + size_t alloc_size = size; + if (alignment_ > alignof(std::max_align_t)) { + // malloc() will align to smaller values, but we must manually align to + // larger values. + alloc_size += alignment_; + } + void* buffer = std::malloc(alloc_size); + if (buffer == nullptr) { + ET_LOG( + Error, + "Reading from %s at offset %zu: malloc(%zd) failed", + file_descriptor_uri_, + offset, + size); + return Error::MemoryAllocationFailed; + } + + // Align. + void* aligned_buffer = align_pointer(buffer, alignment_); + + // Assert that the alignment didn't overflow the buffer. + ET_DCHECK_MSG( + reinterpret_cast(aligned_buffer) + size <= + reinterpret_cast(buffer) + alloc_size, + "aligned_buffer %p + size %zu > buffer %p + alloc_size %zu", + aligned_buffer, + size, + buffer, + alloc_size); + + auto err = load_into(offset, size, segment_info, aligned_buffer); + if (err != Error::Ok) { + // Free `buffer`, which is what malloc() gave us, not `aligned_buffer`. + std::free(buffer); + return err; + } + + // We can't naively free this pointer, since it may not be what malloc() gave + // us. Pass the offset to the real buffer as context. This is the number of + // bytes that need to be subtracted from the FreeableBuffer::data() pointer to + // find the actual pointer to free. + return FreeableBuffer( + aligned_buffer, + size, + FreeSegment, + /*free_fn_context=*/ + reinterpret_cast( + // Using signed types here because it will produce a signed ptrdiff_t + // value, though for us it will always be non-negative. + reinterpret_cast(aligned_buffer) - + reinterpret_cast(buffer))); +} + +Result FileDescriptorDataLoader::size() const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. 
+ fd_ >= 0, + InvalidState, + "Uninitialized"); + return file_size_; +} + +ET_NODISCARD Error FileDescriptorDataLoader::load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + ET_CHECK_OR_RETURN_ERROR( + offset + size <= file_size_, + InvalidArgument, + "File %s: offset %zu + size %zu > file_size_ %zu", + file_descriptor_uri_, + offset, + size, + file_size_); + ET_CHECK_OR_RETURN_ERROR( + buffer != nullptr, InvalidArgument, "Provided buffer cannot be null"); + + // Read the data into the aligned address. + size_t needed = size; + uint8_t* buf = reinterpret_cast(buffer); + + while (needed > 0) { + // Reads on macOS will fail with EINVAL if size > INT32_MAX. + const auto chunk_size = std::min( + needed, static_cast(std::numeric_limits::max())); + const auto nread = ::pread(fd_, buf, chunk_size, offset); + if (nread < 0 && errno == EINTR) { + // Interrupted by a signal; zero bytes read. + continue; + } + if (nread <= 0) { + // nread == 0 means EOF, which we shouldn't see if we were able to read + // the full amount. nread < 0 means an error occurred. + ET_LOG( + Error, + "Reading from %s: failed to read %zu bytes at offset %zu: %s", + file_descriptor_uri_, + size, + offset, + nread == 0 ? "EOF" : strerror(errno)); + return Error::AccessFailed; + } + needed -= nread; + buf += nread; + offset += nread; + } + return Error::Ok; +} + +} // namespace extension +} // namespace executorch diff --git a/extension/data_loader/file_descriptor_data_loader.h b/extension/data_loader/file_descriptor_data_loader.h new file mode 100644 index 0000000000..6f51f0f7a6 --- /dev/null +++ b/extension/data_loader/file_descriptor_data_loader.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace executorch { +namespace extension { + +/** + * A DataLoader that loads segments from a file descriptor, allocating the + * memory with `malloc()`. This data loader is used when ET is running in a + * process that does not have access to the filesystem, and the caller is able + * to open the file and pass the file descriptor. + * + * Note that this will keep the file open for the duration of its lifetime, to + * avoid the overhead of opening it again for every load() call. + */ +class FileDescriptorDataLoader final : public executorch::runtime::DataLoader { + public: + /** + * Creates a new FileDescriptorDataLoader that wraps the named file + * descriptor, and the ownership of the file descriptor is passed. + * + * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///", + * followed by the file descriptor number. + * @param[in] alignment Alignment in bytes of pointers returned by this + * instance. Must be a power of two. + * + * @returns A new FileDescriptorDataLoader on success. + * @retval Error::InvalidArgument `alignment` is not a power of two. + * @retval Error::AccessFailed `file_descriptor_uri` is incorrectly formatted, + * or its size could not be found. + * @retval Error::MemoryAllocationFailed Internal memory allocation failure. 
+ */ + static executorch::runtime::Result + fromFileDescriptorUri( + const char* file_descriptor_uri, + size_t alignment = alignof(std::max_align_t)); + + // Movable to be compatible with Result. + FileDescriptorDataLoader(FileDescriptorDataLoader&& rhs) noexcept + : file_descriptor_uri_(rhs.file_descriptor_uri_), + file_size_(rhs.file_size_), + alignment_(rhs.alignment_), + fd_(rhs.fd_) { + const_cast(rhs.file_descriptor_uri_) = nullptr; + const_cast(rhs.file_size_) = 0; + const_cast(rhs.alignment_) = 0; + const_cast(rhs.fd_) = -1; + } + + ~FileDescriptorDataLoader() override; + + ET_NODISCARD + executorch::runtime::Result load( + size_t offset, + size_t size, + const DataLoader::SegmentInfo& segment_info) const override; + + ET_NODISCARD executorch::runtime::Result size() const override; + + ET_NODISCARD executorch::runtime::Error load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const override; + + private: + FileDescriptorDataLoader( + int fd, + size_t file_size, + size_t alignment, + const char* file_descriptor_uri) + : file_descriptor_uri_(file_descriptor_uri), + file_size_(file_size), + alignment_(alignment), + fd_(fd) {} + + // Not safely copyable. + FileDescriptorDataLoader(const FileDescriptorDataLoader&) = delete; + FileDescriptorDataLoader& operator=(const FileDescriptorDataLoader&) = delete; + FileDescriptorDataLoader& operator=(FileDescriptorDataLoader&&) = delete; + + const char* const file_descriptor_uri_; // Owned by the instance. + const size_t file_size_; + const size_t alignment_; + const int fd_; // Owned by the instance. +}; + +} // namespace extension +} // namespace executorch + +namespace torch { +namespace executor { +namespace util { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::extension::FileDescriptorDataLoader; +} // namespace util +} // namespace executor +} // namespace torch diff --git a/extension/data_loader/targets.bzl b/extension/data_loader/targets.bzl index 4886df03a7..fcc7cba541 100644 --- a/extension/data_loader/targets.bzl +++ b/extension/data_loader/targets.bzl @@ -52,6 +52,21 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "file_descriptor_data_loader", + srcs = ["file_descriptor_data_loader.cpp"], + exported_headers = ["file_descriptor_data_loader.h"], + visibility = [ + "//executorch/test/...", + "//executorch/runtime/executor/test/...", + "//executorch/extension/data_loader/test/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/runtime/core:core", + ], + ) + runtime.cxx_library( name = "mmap_data_loader", srcs = ["mmap_data_loader.cpp"], diff --git a/extension/data_loader/test/file_data_loader_test.cpp b/extension/data_loader/test/file_data_loader_test.cpp index b8921aebb5..1d4f4c1619 100644 --- a/extension/data_loader/test/file_data_loader_test.cpp +++ b/extension/data_loader/test/file_data_loader_test.cpp @@ -40,103 +40,6 @@ class FileDataLoaderTest : public ::testing::TestWithParam { } }; -TEST_P(FileDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) { - // Write some heterogeneous data to a file. - uint8_t data[256]; - for (int i = 0; i < sizeof(data); ++i) { - data[i] = i; - } - TempFile tf(data, sizeof(data)); - - int fd = ::open(tf.path().c_str(), O_RDONLY); - - // Wrap it in a loader. 
- Result fdl = FileDataLoader::fromFileDescriptorUri( - ("fd:///" + std::to_string(fd)).c_str(), alignment()); - ASSERT_EQ(fdl.error(), Error::Ok); - - // size() should succeed and reflect the total size. - Result size = fdl->size(); - ASSERT_EQ(size.error(), Error::Ok); - EXPECT_EQ(*size, sizeof(data)); - - // Load the first bytes of the data. - { - Result fb = fdl->load( - /*offset=*/0, - /*size=*/8, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), 8); - EXPECT_EQ( - 0, - std::memcmp( - fb->data(), - "\x00\x01\x02\x03" - "\x04\x05\x06\x07", - fb->size())); - - // Freeing should release the buffer and clear out the segment. - fb->Free(); - EXPECT_EQ(fb->size(), 0); - EXPECT_EQ(fb->data(), nullptr); - - // Safe to call multiple times. - fb->Free(); - } - - // Load the last few bytes of the data, a different size than the first time. - { - Result fb = fdl->load( - /*offset=*/sizeof(data) - 3, - /*size=*/3, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), 3); - EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); - } - - // Loading all of the data succeeds. - { - Result fb = fdl->load( - /*offset=*/0, - /*size=*/sizeof(data), - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), sizeof(data)); - EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); - } - - // Loading zero-sized data succeeds, even at the end of the data. - { - Result fb = fdl->load( - /*offset=*/sizeof(data), - /*size=*/0, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_EQ(fb->size(), 0); - } -} - -TEST_P(FileDataLoaderTest, FileDescriptorLoadPrefixFail) { - // Write some heterogeneous data to a file. - uint8_t data[256]; - for (int i = 0; i < sizeof(data); ++i) { - data[i] = i; - } - TempFile tf(data, sizeof(data)); - - int fd = ::open(tf.path().c_str(), O_RDONLY); - - // Wrap it in a loader. - Result fdl = FileDataLoader::fromFileDescriptorUri( - std::to_string(fd).c_str(), alignment()); - ASSERT_EQ(fdl.error(), Error::InvalidArgument); -} - TEST_P(FileDataLoaderTest, InBoundsLoadsSucceed) { // Write some heterogeneous data to a file. uint8_t data[256]; diff --git a/extension/data_loader/test/file_descriptor_data_loader_test.cpp b/extension/data_loader/test/file_descriptor_data_loader_test.cpp new file mode 100644 index 0000000000..0258611cbd --- /dev/null +++ b/extension/data_loader/test/file_descriptor_data_loader_test.cpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include + +#include +#include +#include +#include + +using namespace ::testing; +using executorch::extension::FileDescriptorDataLoader; +using executorch::extension::testing::TempFile; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +class FileDescriptorDataLoaderTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + } + + // The alignment in bytes that tests should use. The values are set by the + // list in the INSTANTIATE_TEST_SUITE_P call below. + size_t alignment() const { + return GetParam(); + } +}; + +TEST_P(FileDescriptorDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); + + // Load the first bytes of the data. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/8, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 8); + EXPECT_EQ( + 0, + std::memcmp( + fb->data(), + "\x00\x01\x02\x03" + "\x04\x05\x06\x07", + fb->size())); + + // Freeing should release the buffer and clear out the segment. + fb->Free(); + EXPECT_EQ(fb->size(), 0); + EXPECT_EQ(fb->data(), nullptr); + + // Safe to call multiple times. + fb->Free(); + } + + // Load the last few bytes of the data, a different size than the first time. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) - 3, + /*size=*/3, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 3); + EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); + } + + // Loading all of the data succeeds. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), sizeof(data)); + EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); + } + + // Loading zero-sized data succeeds, even at the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data), + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_EQ(fb->size(), 0); + } +} + +TEST_P(FileDescriptorDataLoaderTest, FileDescriptorLoadPrefixFail) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. 
+ Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + std::to_string(fd).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::InvalidArgument); +} + +TEST_P(FileDescriptorDataLoaderTest, InBoundsLoadsSucceed) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); + + // Load the first bytes of the data. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/8, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 8); + EXPECT_EQ( + 0, + std::memcmp( + fb->data(), + "\x00\x01\x02\x03" + "\x04\x05\x06\x07", + fb->size())); + + // Freeing should release the buffer and clear out the segment. + fb->Free(); + EXPECT_EQ(fb->size(), 0); + EXPECT_EQ(fb->data(), nullptr); + + // Safe to call multiple times. + fb->Free(); + } + + // Load the last few bytes of the data, a different size than the first time. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) - 3, + /*size=*/3, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 3); + EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); + } + + // Loading all of the data succeeds. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), sizeof(data)); + EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); + } + + // Loading zero-sized data succeeds, even at the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data), + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_EQ(fb->size(), 0); + } +} + +TEST_P(FileDescriptorDataLoaderTest, OutOfBoundsLoadFails) { + // Create a temp file; contents don't matter. + uint8_t data[256] = {}; + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // Loading beyond the end of the data should fail. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data) + 1, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + EXPECT_NE(fb.error(), Error::Ok); + } + + // Loading zero bytes still fails if it's past the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) + 1, + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + EXPECT_NE(fb.error(), Error::Ok); + } +} + +TEST_P(FileDescriptorDataLoaderTest, BadAlignmentFails) { + // Create a temp file; contents don't matter. + uint8_t data[256] = {}; + TempFile tf(data, sizeof(data)); + + // Creating a loader with default alignment works fine. 
+ { + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + } + + // Bad alignments fail. + const std::vector bad_alignments = {0, 3, 5, 17}; + for (size_t bad_alignment : bad_alignments) { + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), bad_alignment); + ASSERT_EQ(fdl.error(), Error::InvalidArgument); + } +} + +// Tests that the move ctor works. +TEST_P(FileDescriptorDataLoaderTest, MoveCtor) { + // Create a loader. + std::string contents = "FILE_CONTENTS"; + TempFile tf(contents); + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + EXPECT_EQ(fdl->size().get(), contents.size()); + + // Move it into another instance. + FileDescriptorDataLoader fdl2(std::move(*fdl)); + + // Old loader should now be invalid. + EXPECT_EQ( + fdl->load( + 0, + 0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)) + .error(), + Error::InvalidState); + EXPECT_EQ(fdl->size().error(), Error::InvalidState); + + // New loader should point to the file. + EXPECT_EQ(fdl2.size().get(), contents.size()); + Result fb = fdl2.load( + /*offset=*/0, + contents.size(), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + ASSERT_EQ(fb->size(), contents.size()); + EXPECT_EQ(0, std::memcmp(fb->data(), contents.data(), fb->size())); +} + +// Test that the deprecated From method (capital 'F') still works. +TEST_P(FileDescriptorDataLoaderTest, DEPRECATEDFrom) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); +} + +// Run all FileDescriptorDataLoaderTests multiple times, varying the return +// value of `GetParam()` based on the `testing::Values` list. The tests will +// interpret the value as "alignment". 
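For readers beyond the unit tests here, a minimal usage sketch of the new loader (illustrative only, not part of this patch): it assumes the caller already owns an open, readable descriptor `fd` and uses the standard `executorch::runtime::Program::load()` entry point; the `fd:///` URI is built from the descriptor number, and ownership of the fd passes to the loader as documented in the header earlier in this patch.

```cpp
#include <string>

#include <executorch/extension/data_loader/file_descriptor_data_loader.h>
#include <executorch/runtime/executor/program.h>

using executorch::extension::FileDescriptorDataLoader;
using executorch::runtime::Error;
using executorch::runtime::Program;
using executorch::runtime::Result;

// Hypothetical helper: load a Program in a process that cannot open files
// itself but was handed an already-open, readable file descriptor.
Error run_program_from_fd(int fd) {
  // The loader expects a "fd:///<number>" URI and takes ownership of the fd.
  const std::string uri = "fd:///" + std::to_string(fd);
  Result<FileDescriptorDataLoader> loader =
      FileDescriptorDataLoader::fromFileDescriptorUri(uri.c_str());
  if (!loader.ok()) {
    return loader.error();
  }

  // The loader must outlive the Program, which reads segments through it.
  Result<Program> program = Program::load(&loader.get());
  if (!program.ok()) {
    return program.error();
  }

  // ... load a Method from `program` and execute it as usual ...
  return Error::Ok;
}
```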
+INSTANTIATE_TEST_SUITE_P( + VariedSegments, + FileDescriptorDataLoaderTest, + testing::Values( + 1, + 4, + alignof(std::max_align_t), + 2 * alignof(std::max_align_t), + 128, + 1024)); diff --git a/extension/data_loader/test/targets.bzl b/extension/data_loader/test/targets.bzl index 9c83d6d56b..d424413c1b 100644 --- a/extension/data_loader/test/targets.bzl +++ b/extension/data_loader/test/targets.bzl @@ -38,6 +38,17 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "file_descriptor_data_loader_test", + srcs = [ + "file_descriptor_data_loader_test.cpp", + ], + deps = [ + "//executorch/extension/testing_util:temp_file", + "//executorch/extension/data_loader:file_descriptor_data_loader", + ], + ) + runtime.cxx_test( name = "mmap_data_loader_test", srcs = [ diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 6b9f9cb959..781225afed 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -1,10 +1,14 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load( + "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", + "get_vec_preprocessor_flags", + "get_vec_deps", +) load( "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "get_compiler_optimization_flags", ) - def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -26,6 +30,7 @@ def define_common_targets(): "op_sdpa.h", "op_update_quantized_cache.h", ], + preprocessor_flags = get_vec_preprocessor_flags(), exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", @@ -38,7 +43,7 @@ def define_common_targets(): deps = [ "//executorch/kernels/portable/cpu/util:reduce_util", "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform", - ], + ] + get_vec_deps(), compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(), visibility = [ "//executorch/...", diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index d966de9a25..6f4b95e3d0 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -32,7 +32,7 @@ def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True): def get_vulkan_partitioner( - dtype_override: Optional[str] = None, quantization_mode: Optional[str] = None + dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False ): assert ( dtype_override == "fp32" or dtype_override is None @@ -41,7 +41,7 @@ def get_vulkan_partitioner( VulkanPartitioner, ) - return VulkanPartitioner({"require_dynamic_shapes": True}) + return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape}) def get_mps_partitioner(use_kv_cache: bool = False): diff --git a/extension/llm/modules/README.md b/extension/llm/modules/README.md new file mode 100644 index 0000000000..e6e1a20cec --- /dev/null +++ b/extension/llm/modules/README.md @@ -0,0 +1,17 @@ +## Export-friendly Modules + +Modules in this directory: +* Extend `torch.nn.Module`. +* Are guaranteed to work out of the box with `torch.export.export()`. +* Should work out of the box with `torch.aot_compile()`. +* Should be able to work with ExecuTorch. + +All modules should be covered by unit tests to make sure they are: +1. Give the same output as the reference eager model in PyTorch or TorchTune +2. Export-friendly + +Additionally, we aim to make these modules: +3. AOTI-friendly +4.
ExecuTorch-friendly + +These modules are subject to change (may upstream to TorchTune) so proceed with caution. diff --git a/extension/llm/modules/__init__.py b/extension/llm/modules/__init__.py new file mode 100644 index 0000000000..38245bf935 --- /dev/null +++ b/extension/llm/modules/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._position_embeddings import ( + replace_tile_positional_embedding, + TilePositionalEmbedding, +) + +__all__ = [ + "TilePositionalEmbedding", + "replace_tile_positional_embedding", +] diff --git a/extension/llm/modules/_position_embeddings.py b/extension/llm/modules/_position_embeddings.py new file mode 100644 index 0000000000..0c6a4f6ed9 --- /dev/null +++ b/extension/llm/modules/_position_embeddings.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# An torch.export() friendly version of torchtune's positional embeddings. +# Added torch._check() to make sure guards on symints are enforced. +# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py + +import logging +from typing import Any, Dict, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +class TilePositionalEmbedding(nn.Module): + """ + Positional embedding for tiles, different for every tile, same for every token within a tile. + + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + embed_dim (int): The dimensionality of each tile embedding. + """ + + def __init__( + self, + max_num_tiles: int, + embed_dim: int, + ): + super().__init__() + self.max_num_tiles = max_num_tiles + self.embed_dim = embed_dim + + scale = embed_dim**-0.5 + self.embedding = nn.Parameter( + scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim) + ) + self.gate = nn.Parameter(torch.zeros(1)) + + # Register load hook to interpolate positional embeddings + self._register_load_state_dict_pre_hook(self._load_state_dict_hook) + + # TODO: Switch to public method after 2.5 is stable + @torch.no_grad() + def _load_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + *args: Tuple[Any], + **kwargs: Dict[str, Any], + ): + """ + Interpolates positional embeddings to accomodate different number of tiles, + in case the model was instantiated with different + settings than the one you are loading the state dict from. + + For more info, check self._dynamic_resize function. + + Args: + state_dict (Dict[str, Any]): The state dict to load. + prefix (str): The prefix of the state dict. + *args (Tuple[Any]): Additional positional arguments. + **kwargs (Dict[str, Any]): Additional keyword arguments. + + Raises: + ValueError: if the shape of the loaded embedding is not compatible with the current embedding. + ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. 
+ ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. + """ + + embedding = state_dict.get(prefix + "embedding") + + if embedding is not None: + + # ckpt pos emb + ( + tgt_max_num_tiles_x, + tgt_max_num_tiles_y, + tgt_num_tokens, + tgt_emb, + ) = self.embedding.shape + + # instantiated pos emb + ( + inpt_max_num_tiles_x, + inpt_max_num_tiles_y, + inpt_num_tokens, + inpt_emb, + ) = state_dict[prefix + "embedding"].shape + + # sanity check + if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb: + raise ValueError( + "Expected embedding shape to be (..., num_tokens, tgt_emb) to match" + f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}" + ) + + if inpt_max_num_tiles_x != inpt_max_num_tiles_y: + raise ValueError( + "Expected max_num_tiles_x, max_num_tiles_y to be equal but found, but found" + f"(max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {self.embedding.shape}" + ) + + # resize ckpt to match instantiated shape + embedding_new = self._resize_position_embedding( + embedding, tgt_max_num_tiles=tgt_max_num_tiles_x + ) + + # update state dict + state_dict[prefix + "embedding"] = embedding_new + if embedding_new.shape != self.embedding.shape: + raise ValueError( + "Expected embedding shape and embedding_new.shape to match" + f" but found shapes {self.embedding.shape} and {embedding_new.shape}" + ) + + @staticmethod + def _resize_position_embedding( + embedding: torch.Tensor, tgt_max_num_tiles: int + ) -> torch.Tensor: + """ + Interpolates positional embeddings to accomodate a different max_num_tiles. These + are the only dimensions that changes during interpolation. + + Args: + embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim + tgt_max_num_tiles (int): The number of tiles to resize to. + + Returns: + torch.Tensor: The resized embedding. + + Example: + >>> import torch + >>> # create dummy embedding + >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float() + >>> resized_embed = _dynamic_resize(embedding, tgt_max_num_tiles=1) + >>> print(resized_embed.shape) + >>> torch.Size([1, 1, 2, 2]) + """ + # set max_num_tiles to the last dimension + embedding = embedding.permute(2, 3, 0, 1) + + embedding = F.interpolate( + embedding, + size=(tgt_max_num_tiles, tgt_max_num_tiles), + mode="bilinear", + align_corners=True, + ) + # permute to the original shape + embedding = embedding.permute(2, 3, 0, 1) + return embedding + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: + """ + args: + x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). + returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + torch._check(n_tiles <= self.max_num_tiles) + + for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + n_tiles_h = n_tiles_h.item() + n_tiles_w = n_tiles_w.item() + + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. n_tiles_h, n_tiles_w. 
+ torch._check_is_size(n_tiles_h) + torch._check_is_size(n_tiles_w) + torch._check(n_tiles_h >= 1) + torch._check(n_tiles_w >= 1) + torch._check(n_tiles_h <= self.max_num_tiles) + torch._check(n_tiles_w <= self.max_num_tiles) + # TODO: Remove this once pytorch/pytorch#120288 is fixed + padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1)) + pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] + + # We need to do a clone here in order to make this model export + # friendly as the reshape is collapsing dim 0 and dim 1 into a + # single dim. + pos_embed = pos_embed.clone() + pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim) + + x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) + torch._check_is_size(n_non_padded_tiles) + torch._check(n_non_padded_tiles < x.size(1)) + x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() + x = x[:, :n_tiles, :, :] + + return x + + +def replace_tile_positional_embedding(model: nn.Module) -> nn.Module: + """ + Replace the tile positional embedding from torchtune with an export-friendly one. + Recursively searches the submodules of the model and replaces the tile positional embedding if found. + Args: + model (nn.Module): The model to replace the tile positional embedding in. + + Returns: + nn.Module: The model after replacing the tile positional embedding. + + """ + from torchtune.models.clip._position_embeddings import ( + TilePositionalEmbedding as TuneTilePositionalEmbedding, + ) + + for name, module in model.named_children(): + if isinstance(module, TuneTilePositionalEmbedding): + logging.info( + f"Replacing tile positional embedding in {name} with export-friendly one." + ) + max_num_tiles, _, _, embed_dim = module.embedding.shape + mod = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + ) + mod.load_state_dict(module.state_dict()) + setattr( + model, + name, + mod, + ) + else: + replace_tile_positional_embedding(module) + return model diff --git a/extension/llm/modules/mha.py b/extension/llm/modules/mha.py new file mode 100644 index 0000000000..0bfa4eb20c --- /dev/null +++ b/extension/llm/modules/mha.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +import torchtune.modules.attention as TorchTuneAttention +from torch import nn +from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention +from torchtune.modules.kv_cache import KVCache + +logger = logging.getLogger(__name__) + + +class MultiHeadAttention(nn.Module): + """ + NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except + that SDPA is factored out so that it can be swapped for more + efficient ExecuTorch-defined SDPA ops. + + Multi-headed attention layer with support for grouped query + attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1. + + GQA is a version of multiheaded attention (MHA) which uses fewer + key/value heads than query heads by grouping n query heads for each + key and value head. Multi-Query Attention is an extreme + version where we have a single key and value head shared by all + query heads. + + Following is an example of MHA, GQA and MQA with num_heads = 4 + + (credit for the documentation: + `litgpt.Config `_). 
+ + + :: + + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ │ │ │ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + MHA GQA MQA + n_kv_heads =4 n_kv_heads=2 n_kv_heads=1 + + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. + Default value is 0.0. 
+ + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # Use flex attention if supported and we are sample packing + self._attention_call = _sdpa_or_flex_attention() + self._sdpa = SDPA( + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + head_dim=self.head_dim, + q_per_kv=self.num_heads // self.num_kv_heads, + attn_dropout=self.attn_dropout if self.training else 0.0, + is_causal=self.is_causal, + attention_fn=self._attention_call, + kv_cache=self.kv_cache, + ) + + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self._sdpa.kv_cache = self.kv_cache + self.cache_enabled = True + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." 
+ ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + v = v.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None and self.cache_enabled: + k, v = self.kv_cache.update(k, v) + + output = self._sdpa(q, k, v, b, s_x) + return self.output_proj(output) + + +class SDPA(nn.Module): + """ + TorchTune's SDPA which can be optimized and can be swapped + out for a more efficient 
implementations. + """ + + def __init__( + self, + num_kv_heads: int, + num_heads: int, + head_dim: int, + q_per_kv: int, + attn_dropout: float, + is_causal: bool, + attention_fn, + kv_cache, + ) -> None: + super().__init__() + self.num_kv_heads = num_kv_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.q_per_kv = q_per_kv + self.attn_dropout = attn_dropout + self.is_causal = is_causal + self._attention_fn = attention_fn + self.kv_cache = kv_cache + + def forward( + self, + q: torch.Tensor, # [b, s, n_h, h_d] + k: torch.Tensor, # [b, s, n_kv, h_d] + v: torch.Tensor, # [b, s, n_kv, h_d] + bsz: int, + seq_len: int, + mask: torch.Tensor = None, + ) -> torch.Tensor: + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. + + # k: [bsz, seq_len, n_kv, 1, h_d] + # v: [bsz, seq_len, n_kv, 1, h_d] + k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + + # Expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + + # [bsz, s, n_h, h_d] + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # [bsz, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output = self._attention_fn( + q, + k, + v, + mask=mask, + dropout_p=self.attn_dropout, + is_causal=self.kv_cache is None and mask is None and self.is_causal, + ) + # Reshape the output to be the same shape as the input + return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + + +def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None: + for name, child in module.named_children(): + if isinstance(child, TorchTuneAttention.MultiHeadAttention): + setattr( + module, + name, + MultiHeadAttention( + embed_dim=child.embed_dim, + num_heads=child.num_heads, + num_kv_heads=child.num_kv_heads, + head_dim=child.head_dim, + q_proj=child.q_proj, + k_proj=child.k_proj, + v_proj=child.v_proj, + output_proj=child.output_proj, + pos_embeddings=child.pos_embeddings, + q_norm=child.q_norm, + k_norm=child.k_norm, + kv_cache=child.kv_cache, + max_seq_len=child.max_seq_len, + is_causal=child.is_causal, + attn_dropout=child.attn_dropout, + ), + ) + else: + replace_mha_with_inference_mha(child) + + +def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module: + """ + Replace TorchTune's MHA with an inference friendly version of MHA that + separates out the inference-related parts for further optimization. + """ + _replace_mha_with_inference_mha(module) + return module diff --git a/extension/llm/modules/test/__init__.py b/extension/llm/modules/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/extension/llm/modules/test/test_mha.py b/extension/llm/modules/test/test_mha.py new file mode 100644 index 0000000000..0dc7cba685 --- /dev/null +++ b/extension/llm/modules/test/test_mha.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge + +from executorch.extension.llm.modules.mha import ( + MultiHeadAttention as ETMultiHeadAttention, +) +from executorch.runtime import Runtime +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention + + +torch.manual_seed(0) + + +class AttentionTest(unittest.TestCase): + def setUp(self): + super().setUp() + + # Constants + self.embed_dim = 2048 + self.num_heads = 32 + self.num_kv_heads = 8 + self.head_dim = 64 + self.max_seq_len = 128 + self.rope_base = 500_000 + self.scale_factor = 32 + + # Module dependency injections. + self.q_proj = torch.nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = torch.nn.Linear( + self.embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.v_proj = torch.nn.Linear( + self.embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.pos_embeddings = Llama3ScaledRoPE( + dim=self.head_dim, + max_seq_len=self.max_seq_len, + base=self.rope_base, + scale_factor=self.scale_factor, + ) + + # Original TorchTune reference module to test accuracy against. + self.tt_mha = TTMultiHeadAttention( + embed_dim=self.embed_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.pos_embeddings, + max_seq_len=self.max_seq_len, + ) + + # Source transformed module that we are testing. + self.et_mha = ETMultiHeadAttention( + embed_dim=self.embed_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.pos_embeddings, + max_seq_len=self.max_seq_len, + ) + + # Common inputs. + seq_len = 10 + self.x = torch.randn(1, seq_len, self.embed_dim) + seq_len_dim = torch.export.Dim("seq_len", min=1, max=100) + self.dynamic_shapes = ( + {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, + {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, + ) + + def test_attention_eager(self): + et_res = self.et_mha(self.x, self.x) # Self attention. + tt_res = self.tt_mha(self.x, self.x) # Self attention. + + self.assertTrue(torch.allclose(et_res, tt_res)) + + # TODO: KV cache. + # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20) + # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20) + + # et_res = self.et_mha(self.x, self.x) # Self attention. + # tt_res = self.tt_mha(self.x, self.x) # Self attention. + + # self.assertTrue(torch.allclose(et_res, tt_res)) + + def test_attention_export(self): + # Self attention. + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.x), + kwargs=None, + dynamic_shapes=self.dynamic_shapes, + ) + et_res = et_mha_ep.module()(self.x, self.x) + tt_res = self.tt_mha(self.x, self.x) + self.assertTrue(torch.allclose(et_res, tt_res)) + + # TODO: KV cache. + + def test_attention_aoti(self): + # TODO. + pass + + def test_attention_executorch(self): + # Self attention. 
+ et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.x), + kwargs=None, + dynamic_shapes=self.dynamic_shapes, + ) + et_program = to_edge( + et_mha_ep, + compile_config=EdgeCompileConfig(), + ).to_executorch() + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + et_res = method.execute((self.x, self.x)) + tt_res = self.tt_mha(self.x, self.x) + + self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06)) + + # TODO: KV cache. diff --git a/extension/llm/modules/test/test_position_embeddings.py b/extension/llm/modules/test/test_position_embeddings.py new file mode 100644 index 0000000000..cf4e7e7f05 --- /dev/null +++ b/extension/llm/modules/test/test_position_embeddings.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import tempfile +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.extension.llm.modules import ( + replace_tile_positional_embedding, + TilePositionalEmbedding, +) +from executorch.runtime import Runtime +from torch._inductor.package import load_package, package_aoti +from torchtune.models.clip import TilePositionalEmbedding as TuneTilePositionalEmbedding + + +class TilePositionalEmbeddingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tpe = TilePositionalEmbedding(4, 1280) + self.ref_tpe = TuneTilePositionalEmbedding(4, 1280) + self.x = torch.randn(1, 4, 1600, 1280) + self.aspect_ratio = torch.tensor([[1, 1]]) + num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4) + num_tokens = torch.export.Dim("num_tokens", min=1, max=1600) + + self.dynamic_shape = { + 0: 1, # batch + 1: num_tiles_dim, # num tiles + 2: num_tokens, # num tokens + 3: 1280, # embedding dim + } + + def test_tile_positional_embedding_smoke(self): + y = self.tpe(self.x, self.aspect_ratio) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y, ref_y)) + + def test_tile_positional_embedding_export(self): + + tpe_ep = torch.export.export( + self.tpe, + (self.x, self.aspect_ratio), + dynamic_shapes=( + self.dynamic_shape, + None, + ), # assuming aspect ratio is static + ) + + y = tpe_ep.module()(self.x, self.aspect_ratio) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y, ref_y)) + + def test_tile_positional_embedding_aoti(self): + so = torch._export.aot_compile( + self.tpe, + args=(self.x, self.aspect_ratio), + options={"aot_inductor.package": True}, + dynamic_shapes=( + self.dynamic_shape, + None, + ), # assuming aspect ratio is static + ) + with tempfile.TemporaryDirectory() as tmpdir: + path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so) + tpe_aoti = load_package(path) + + y = tpe_aoti(self.x, self.aspect_ratio) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y, ref_y)) + + def test_tile_positional_embedding_et(self): + tpe_ep = torch.export.export( + self.tpe, + (self.x, self.aspect_ratio), + dynamic_shapes=( + self.dynamic_shape, + None, + ), # assuming aspect ratio is static + ) + et_program = to_edge( + tpe_ep, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[ + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_scalar.default, + torch.ops.aten._local_scalar_dense.default, + ] + ), + 
).to_executorch() + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + y = method.execute((self.x, self.aspect_ratio)) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y[0], ref_y)) + + def test_replace_tile_positional_embedding(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.tpe = TuneTilePositionalEmbedding(4, 1280) + + def forward(self, x, aspect_ratio): + return self.tpe(x, aspect_ratio) + + m = Module() + m = replace_tile_positional_embedding(m) + self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding)) diff --git a/extension/llm/tokenizer/tokenizer.py b/extension/llm/tokenizer/tokenizer.py index ecd0231fb6..78377230b9 100644 --- a/extension/llm/tokenizer/tokenizer.py +++ b/extension/llm/tokenizer/tokenizer.py @@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str: # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. return self.sp_model.decode(t) + def decode_token(self, t: int) -> str: + # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. + return self.sp_model.decode(t) + def export(self, output_path: str, *, prepend_padding: bool = False) -> None: """ Export tokenizer.model to another serialization format. Here we did some lightweight diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index fb1c9a17f9..659c7afe09 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -15,16 +15,44 @@ load( # functions in order to declare the required compiler flags needed in order to # access CPU vector intrinsics. -def get_vec_android_preprocessor_flags(): - preprocessor_flags = [ - ( - "^android-arm64.*$", - [ +def get_vec_preprocessor_flags(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + preprocessor_flags = select({ + "ovr_config//os:linux-x86_64": [ "-DET_BUILD_ARM_VEC256_WITH_SLEEF", - ], - ), - ] - return preprocessor_flags + ] if not runtime.is_oss else [], + "ovr_config//os:iphoneos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return preprocessor_flags + return [] + +def get_vec_deps(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + deps = select({ + "ovr_config//os:iphoneos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return deps + return [] def get_vec_cxx_preprocessor_flags(): preprocessor_flags = [ diff --git a/kernels/optimized/op_registration_util.bzl b/kernels/optimized/op_registration_util.bzl index 6e74836bb7..6839454be2 100644 --- a/kernels/optimized/op_registration_util.bzl +++ b/kernels/optimized/op_registration_util.bzl @@ -2,7 +2,8 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/build:selects.bzl", "selects") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", + 
"get_vec_deps", ) load( "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", @@ -94,8 +95,8 @@ def define_op_library(name, deps): compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(), deps = [ "//executorch/runtime/kernel:kernel_includes", - ] + augmented_deps, - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + ] + augmented_deps + get_vec_deps(), + preprocessor_flags = get_vec_preprocessor_flags(), # sleef needs to be added as a direct dependency of the operator target when building for Android, # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of # dependencies are not transitive diff --git a/kernels/optimized/test/targets.bzl b/kernels/optimized/test/targets.bzl index d2ee2880c6..e4740a9ad7 100644 --- a/kernels/optimized/test/targets.bzl +++ b/kernels/optimized/test/targets.bzl @@ -1,7 +1,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", "get_vec_cxx_preprocessor_flags", ) load("@fbsource//xplat/executorch/kernels/test:util.bzl", "define_supported_features_lib") @@ -27,7 +27,7 @@ def _lib_test_bin(name, extra_deps = [], in_cpu = False): "//executorch/kernels/optimized{}:{}".format(cpu_path, lib_root), ] + extra_deps, cxx_platform_preprocessor_flags = get_vec_cxx_preprocessor_flags(), - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + preprocessor_flags = get_vec_preprocessor_flags(), ) def define_common_targets(): diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 7872b0d173..38901bb840 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -12,6 +12,8 @@ #include #include +#include + using torch::executor::function::et_copy_index; namespace torch { @@ -301,6 +303,65 @@ static Kernel prim_ops[] = { } }), + // ceil.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::ceil.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isDouble()) { + out = EValue(static_cast(ceil(a.toDouble()))); + } else { + ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + } + }), + + // round.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::round.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isDouble()) { + // Round half to even to match Python round(). Need an explicit + // implementation as not all platforms support fenv rounding modes. 
+          // See
+          // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html
+          const auto val = a.toDouble();
+          const auto r = round(val);
+          const auto d = r - val;
+          auto res = 0.0;
+
+          if (std::abs(d) != 0.5) {
+            res = r;
+          } else if (fmod(r, 2.0) == 0.0) {
+            res = r;
+          } else {
+            res = val - d;
+          }
+
+          out = EValue(static_cast<int64_t>(res));
+        } else {
+          ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
+        }
+      }),
+
+  // trunc.Scalar(Scalar a) -> Scalar
+  Kernel(
+      "executorch_prim::trunc.Scalar",
+      [](KernelRuntimeContext& context, EValue** stack) {
+        (void)context;
+        EValue& a = *stack[0];
+        EValue& out = *stack[1];
+        if (a.isDouble()) {
+          out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
+        } else {
+          ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
+        }
+      }),
+
   // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
   Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
   // executorch_prim::et_view.default(Tensor, int[]) -> Tensor
diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp
index 4b4b35a232..ab6bd28e6c 100644
--- a/kernels/prim_ops/test/prim_ops_test.cpp
+++ b/kernels/prim_ops/test/prim_ops_test.cpp
@@ -503,5 +503,66 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
       getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
 }
+TEST_F(RegisterPrimOpsTest, TestCeil) {
+  std::array inputs = {
+      0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
+  std::array expected = {0, 1, 1, 1, 1, 2, 0, -1, -1, 10};
+
+  for (auto i = 0; i < inputs.size(); i++) {
+    EValue values[2];
+    values[0] = EValue(inputs[i]);
+    values[1] = EValue(0.0);
+
+    EValue* stack[2];
+    for (size_t j = 0; j < 2; j++) {
+      stack[j] = &values[j];
+    }
+
+    getOpsFn("executorch_prim::ceil.Scalar")(context, stack);
+    EXPECT_EQ(stack[1]->toInt(), expected[i]);
+  }
+}
+
+TEST_F(RegisterPrimOpsTest, TestRound) {
+  // Note that Python uses round-to-even for halfway values.
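+  // Halfway cases below: 0.5 -> 0, 1.5 -> 2, -0.5 -> 0, and -1.5 -> -2,
+  // i.e. ties go to the nearest even integer.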
+ std::array inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.5, -0.5, -1.0, -1.5, 9.999999}; + std::array expected = {0, 0, 0, 1, 1, 2, 0, -1, -2, 10}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::round.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + +TEST_F(RegisterPrimOpsTest, TestTrunc) { + std::array inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999}; + std::array expected = {0, 0, 0, 0, 1, 1, 0, -1, -1, 9}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::trunc.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + } // namespace executor } // namespace torch diff --git a/kernels/quantized/cpu/embeddingxb.cpp b/kernels/quantized/cpu/embeddingxb.cpp index f8fdfe078c..5275f842df 100644 --- a/kernels/quantized/cpu/embeddingxb.cpp +++ b/kernels/quantized/cpu/embeddingxb.cpp @@ -65,7 +65,7 @@ static inline int32_t get_embedding_dim( void check_embedding_xbit_args( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -170,7 +170,7 @@ template void embedding_xbit_per_channel( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const Tensor& indices, Tensor& out, int weight_nbit) { @@ -260,7 +260,7 @@ Tensor& quantized_embedding_xbit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -299,7 +299,7 @@ Tensor& quantized_embedding_xbit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -325,7 +325,7 @@ Tensor& quantized_embedding_xbit_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -368,7 +368,7 @@ Tensor& quantized_embedding_xbit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/embeddingxb.h b/kernels/quantized/cpu/embeddingxb.h index ae1fccc6c2..d08c8ae745 100644 --- a/kernels/quantized/cpu/embeddingxb.h +++ b/kernels/quantized/cpu/embeddingxb.h @@ -24,7 +24,7 @@ Tensor& quantized_embedding_xbit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const 
exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -35,7 +35,7 @@ Tensor& quantized_embedding_xbit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -47,7 +47,7 @@ Tensor& quantized_embedding_xbit_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -59,7 +59,7 @@ Tensor& quantized_embedding_xbit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 9f8a365b9c..8d73d06694 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -186,7 +186,7 @@ float get_scale(const Tensor& scale, size_t channel_ix) { Tensor& dequantize_per_channel_out( const Tensor& input, const Tensor& scale, - const optional& opt_zero_points, + const exec_aten::optional& opt_zero_points, int64_t axis, int64_t quant_min, int64_t quant_max, @@ -261,7 +261,7 @@ Tensor& dequantize_per_channel_out( const auto* input_data_ptr = input.const_data_ptr(); \ ET_CHECK_MSG( \ axis == 0, "Axis must be 0 for a single dimensional tensors"); \ - const optional dim; \ + const exec_aten::optional dim; \ apply_over_dim( \ [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \ size_t numel, size_t stride, size_t base_ix) { \ @@ -331,7 +331,7 @@ Tensor& dequantize_per_channel_out( KernelRuntimeContext& context, const Tensor& input, const Tensor& scale, - const optional& opt_zero_points, + const exec_aten::optional& opt_zero_points, int64_t axis, int64_t quant_min, int64_t quant_max, diff --git a/kernels/quantized/cpu/op_embedding.cpp b/kernels/quantized/cpu/op_embedding.cpp index e48e9a7eea..0ffe363f2a 100644 --- a/kernels/quantized/cpu/op_embedding.cpp +++ b/kernels/quantized/cpu/op_embedding.cpp @@ -27,7 +27,7 @@ namespace { void check_embedding_byte_args( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -129,7 +129,7 @@ template void embedding_byte_per_channel( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const Tensor& indices, Tensor& out) { // An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a @@ -218,7 +218,7 @@ Tensor& quantized_embedding_byte_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -253,7 +253,7 @@ Tensor& quantized_embedding_byte_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const 
optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -277,7 +277,7 @@ Tensor& quantized_embedding_byte_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -316,7 +316,7 @@ Tensor& quantized_embedding_byte_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_embedding2b.cpp b/kernels/quantized/cpu/op_embedding2b.cpp index 0fdd7b731f..a2d2f8eb39 100644 --- a/kernels/quantized/cpu/op_embedding2b.cpp +++ b/kernels/quantized/cpu/op_embedding2b.cpp @@ -37,7 +37,7 @@ Tensor& quantized_embedding_2bit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -57,7 +57,7 @@ Tensor& quantized_embedding_2bit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -77,7 +77,7 @@ Tensor& quantized_embedding_2bit_out( Tensor& quantized_embedding_2bit_dtype_out( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -99,7 +99,7 @@ Tensor& quantized_embedding_2bit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_embedding4b.cpp b/kernels/quantized/cpu/op_embedding4b.cpp index 8a99073cd0..d123b40b35 100644 --- a/kernels/quantized/cpu/op_embedding4b.cpp +++ b/kernels/quantized/cpu/op_embedding4b.cpp @@ -37,7 +37,7 @@ Tensor& quantized_embedding_4bit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -57,7 +57,7 @@ Tensor& quantized_embedding_4bit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -79,7 +79,7 @@ Tensor& quantized_embedding_4bit_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -101,7 +101,7 @@ Tensor& 
quantized_embedding_4bit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_mixed_linear.cpp b/kernels/quantized/cpu/op_mixed_linear.cpp index d3552e1ca6..af3d10cedb 100644 --- a/kernels/quantized/cpu/op_mixed_linear.cpp +++ b/kernels/quantized/cpu/op_mixed_linear.cpp @@ -19,8 +19,8 @@ bool check_quantized_mixed_linear_args( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const optional dtype, + const exec_aten::optional& opt_weight_zero_points, + const exec_aten::optional dtype, Tensor& out) { ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2)); ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2)); @@ -64,8 +64,8 @@ Tensor& quantized_mixed_linear_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const optional dtype, + const exec_aten::optional& opt_weight_zero_points, + const exec_aten::optional dtype, Tensor& out) { // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available. ET_CHECK(check_quantized_mixed_linear_args( @@ -117,8 +117,8 @@ Tensor& quantized_mixed_linear_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const optional dtype, + const exec_aten::optional& opt_weight_zero_points, + const exec_aten::optional dtype, Tensor& out) { // TODO(mcandales): Remove the need for this wrapper // TODO(mkg): add support for dtype diff --git a/kernels/quantized/cpu/op_mixed_mm.cpp b/kernels/quantized/cpu/op_mixed_mm.cpp index 895c7e0af3..18d8f1e70d 100644 --- a/kernels/quantized/cpu/op_mixed_mm.cpp +++ b/kernels/quantized/cpu/op_mixed_mm.cpp @@ -19,7 +19,7 @@ bool check_quantized_mixed_mm_args( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, Tensor& out) { ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2)); ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2)); @@ -55,7 +55,7 @@ Tensor& quantized_mixed_mm_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, Tensor& out) { ET_CHECK(check_quantized_mixed_mm_args( in, weight, weight_scales, opt_weight_zero_points, out)); @@ -92,7 +92,7 @@ Tensor& quantized_mixed_mm_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, Tensor& out) { // TODO(mcandales): Remove the need for this wrapper (void)ctx; diff --git a/pytest.ini b/pytest.ini index 3666c9c879..a5041504ae 100644 --- a/pytest.ini +++ b/pytest.ini @@ -38,6 +38,7 @@ addopts = # backends/xnnpack backends/xnnpack/test # extension/ + extension/llm/modules/test extension/pybindings/test # Runtime runtime diff --git a/runtime/core/portable_type/string_view.h b/runtime/core/portable_type/string_view.h index 977a0f542d..8e28fa022c 100644 --- a/runtime/core/portable_type/string_view.h +++ b/runtime/core/portable_type/string_view.h @@ -8,572 +8,13 @@ #pragma once -#include -#include +#include -#include - -// TODO(T154113473): Document this file namespace executorch { namespace runtime { namespace etensor { 
-namespace internal { - -/** - * Reimplementation of std::string_view for C++11. - * Mostly copy pasted from the c10 implementation but modified some to remove - * broader c10 dependencies - */ -template -class basic_string_view final { - public: - using value_type = CharT; - using pointer = CharT*; - using const_pointer = const CharT*; - using reference = CharT&; - using const_reference = const CharT&; - using const_iterator = const CharT*; - using iterator = const_iterator; - using size_type = std::size_t; - - static constexpr size_type npos = size_type(-1); - - constexpr basic_string_view() noexcept : begin_(nullptr), size_(0) {} - - explicit constexpr basic_string_view(const_pointer str, size_type count) - : begin_(str), size_(count) {} - - /* implicit */ constexpr basic_string_view(const_pointer str) - : basic_string_view(str, strlen_(str)) {} - - constexpr const_iterator begin() const noexcept { - return cbegin(); - } - - constexpr const_iterator cbegin() const noexcept { - return begin_; - } - - constexpr const_iterator end() const noexcept { - return cend(); - } - - constexpr const_iterator cend() const noexcept { - return begin_ + size_; - } - - friend constexpr const_iterator begin(basic_string_view sv) noexcept { - return sv.begin(); - } - - friend constexpr const_iterator end(basic_string_view sv) noexcept { - return sv.end(); - } - - constexpr const_reference operator[](size_type pos) const { - return at_(pos); - } - - constexpr const_reference at(size_type pos) const { - ET_CHECK_MSG( - pos >= size_, - "string_view::operator[] or string_view::at() out of range"); - return at_(pos); - } - - constexpr const_reference front() const { - return *begin_; - } - - constexpr const_reference back() const { - return *(begin_ + size_ - 1); - } - - constexpr const_pointer data() const noexcept { - return begin_; - } - - constexpr size_type size() const noexcept { - return size_; - } - - constexpr size_type length() const noexcept { - return size(); - } - - constexpr bool empty() const noexcept { - return size() == 0; - } - - void remove_prefix(size_type n) { - ET_CHECK_MSG(n > size(), "basic_string_view::remove_prefix: out of range."); - begin_ += n; - size_ -= n; - } - - void remove_suffix(size_type n) { - ET_CHECK_MSG(n > size(), "basic_string_view::remove_suffix: out of range."); - size_ -= n; - } - - void swap(basic_string_view& sv) noexcept { - auto tmp = *this; - *this = sv; - sv = tmp; - } - - size_type copy(pointer dest, size_type count, size_type pos = 0) const { - ET_CHECK_MSG(pos > size_, "basic_string_view::copy: out of range."); - size_type copy_length = min_(count, size_ - pos); - for (auto iter = begin() + pos, end = iter + copy_length; iter != end;) { - *(dest++) = *(iter++); - } - return copy_length; - } - - constexpr basic_string_view substr(size_type pos = 0, size_type count = npos) - const { - ET_CHECK_MSG( - pos > size_, "basic_string_view::substr parameter out of bounds."); - return substr_(pos, count); - } - - constexpr int compare(basic_string_view rhs) const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - for (size_t i = 0, end = min_(size(), rhs.size()); i < end; ++i) { - if (at_(i) < rhs.at_(i)) { - return -1; - } else if (at_(i) > rhs.at_(i)) { - return 1; - } - } - if (size() < rhs.size()) { - return -1; - } else if (size() > rhs.size()) { - return 1; - } - return 0; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. 
- return (size() == 0 && rhs.size() == 0) ? 0 - : (size() == 0) ? -1 - : (rhs.size() == 0) ? 1 - : (front() < rhs.front()) ? -1 - : (front() > rhs.front()) ? 1 - : substr_(1).compare(rhs.substr_(1)); -#endif - } - - constexpr int compare(size_type pos1, size_type count1, basic_string_view v) - const { - return substr(pos1, count1).compare(v); - } - - constexpr int compare( - size_type pos1, - size_type count1, - basic_string_view v, - size_type pos2, - size_type count2) const { - return substr(pos1, count1).compare(v.substr(pos2, count2)); - } - - constexpr int compare(const_pointer s) const { - return compare(basic_string_view(s)); - } - - constexpr int compare(size_type pos1, size_type count1, const_pointer s) - const { - return substr(pos1, count1).compare(basic_string_view(s)); - } - - constexpr int compare( - size_type pos1, - size_type count1, - const_pointer s, - size_type count2) const { - return substr(pos1, count1).compare(basic_string_view(s, count2)); - } - - friend constexpr bool operator==( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return lhs.equals_(rhs); - } - - friend constexpr bool operator!=( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return !(lhs == rhs); - } - - friend constexpr bool operator<( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return lhs.compare(rhs) < 0; - } - - friend constexpr bool operator>=( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return !(lhs < rhs); - } - - friend constexpr bool operator>( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return rhs < lhs; - } - - friend constexpr bool operator<=( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return !(lhs > rhs); - } - - constexpr bool starts_with(basic_string_view prefix) const noexcept { - return (prefix.size() > size()) ? false - : prefix.equals_(substr_(0, prefix.size())); - } - - constexpr bool starts_with(CharT prefix) const noexcept { - return !empty() && prefix == front(); - } - - constexpr bool starts_with(const_pointer prefix) const { - return starts_with(basic_string_view(prefix)); - } - - constexpr bool ends_with(basic_string_view suffix) const noexcept { - return (suffix.size() > size()) - ? false - : suffix.equals_(substr_(size() - suffix.size(), suffix.size())); - } - - constexpr bool ends_with(CharT suffix) const noexcept { - return !empty() && suffix == back(); - } - - constexpr bool ends_with(const_pointer suffix) const { - return ends_with(basic_string_view(suffix)); - } - - constexpr size_type find(basic_string_view v, size_type pos = 0) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (v.size() == 0) { - return pos <= size() ? pos : npos; - } - - if (pos + v.size() <= size()) { - for (size_type cur = pos, end = size() - v.size(); cur <= end; ++cur) { - if (v.at_(0) == at_(cur) && - v.substr_(1).equals_(substr_(cur + 1, v.size() - 1))) { - return cur; - } - } - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (v.size() == 0) ? (pos <= size() ? pos : npos) - : (pos + v.size() > size()) ? npos - : (v.at_(0) == at_(pos) && - v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) - ? 
pos - : find(v, pos + 1); -#endif - } - - constexpr size_type find(CharT ch, size_type pos = 0) const noexcept { - return find_first_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type find(const_pointer s, size_type pos, size_type count) - const { - return find(basic_string_view(s, count), pos); - } - - constexpr size_type find(const_pointer s, size_type pos = 0) const { - return find(basic_string_view(s), pos); - } - - constexpr size_type rfind(basic_string_view v, size_type pos = npos) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (v.size() == 0) { - return pos <= size() ? pos : size(); - } - - if (v.size() <= size()) { - pos = min_(size() - v.size(), pos); - do { - if (v.at_(0) == at_(pos) && - v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) { - return pos; - } - } while (pos-- > 0); - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (v.size() == 0) ? (pos <= size() ? pos : size()) - : (v.size() > size()) ? npos - : (size() - v.size() < pos) ? rfind(v, size() - v.size()) - : (v.at_(0) == at_(pos) && - v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) - ? pos - : (pos == 0) ? npos - : rfind(v, pos - 1); -#endif - } - - constexpr size_type rfind(CharT ch, size_type pos = npos) const noexcept { - return find_last_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type rfind(const_pointer s, size_type pos, size_type count) - const { - return rfind(basic_string_view(s, count), pos); - } - - constexpr size_type rfind(const_pointer s, size_type pos = npos) const { - return rfind(basic_string_view(s), pos); - } - - constexpr size_type find_first_of(basic_string_view v, size_type pos = 0) - const noexcept { - return find_first_if_(pos, stringViewContainsChar_{v}); - } - - constexpr size_type find_first_of(CharT ch, size_type pos = 0) - const noexcept { - return find_first_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type - find_first_of(const_pointer s, size_type pos, size_type count) const { - return find_first_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_first_of(const_pointer s, size_type pos = 0) const { - return find_first_of(basic_string_view(s), pos); - } - - constexpr size_type find_last_of(basic_string_view v, size_type pos = npos) - const noexcept { - return find_last_if_(pos, stringViewContainsChar_{v}); - } - - constexpr size_type find_last_of(CharT ch, size_type pos = npos) - const noexcept { - return find_last_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type - find_last_of(const_pointer s, size_type pos, size_type count) const { - return find_last_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_last_of(const_pointer s, size_type pos = npos) - const { - return find_last_of(basic_string_view(s), pos); - } - - constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0) - const noexcept { - return find_first_if_(pos, stringViewDoesNotContainChar_{v}); - } - - constexpr size_type find_first_not_of(CharT ch, size_type pos = 0) - const noexcept { - return find_first_if_(pos, charIsNotEqual_{ch}); - } - - constexpr size_type - find_first_not_of(const_pointer s, size_type pos, size_type count) const { - return find_first_not_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_first_not_of(const_pointer s, size_type pos = 0) - const { - return find_first_not_of(basic_string_view(s), pos); - } - - constexpr size_type 
find_last_not_of( - basic_string_view v, - size_type pos = npos) const noexcept { - return find_last_if_(pos, stringViewDoesNotContainChar_{v}); - } - - constexpr size_type find_last_not_of(CharT ch, size_type pos = npos) - const noexcept { - return find_last_if_(pos, charIsNotEqual_{ch}); - } - - constexpr size_type - find_last_not_of(const_pointer s, size_type pos, size_type count) const { - return find_last_not_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_last_not_of(const_pointer s, size_type pos = npos) - const { - return find_last_not_of(basic_string_view(s), pos); - } - - private: - static constexpr std::size_t min_(const std::size_t a, const std::size_t b) { - return (b < a) ? b : a; - } - - static constexpr size_type strlen_(const_pointer str) noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - const_pointer current = str; - while (*current != '\0') { - ++current; - } - return current - str; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (*str == '\0') ? 0 : 1 + strlen_(str + 1); -#endif - } - - constexpr const_reference at_(size_type pos) const noexcept { - return *(begin_ + pos); - } - - constexpr basic_string_view substr_(size_type pos = 0, size_type count = npos) - const { - return basic_string_view{begin_ + pos, min_(count, size() - pos)}; - } - - template - constexpr size_type find_first_if_(size_type pos, Condition&& condition) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (pos + 1 <= size()) { - for (size_type cur = pos; cur < size(); ++cur) { - if (condition(at_(cur))) { - return cur; - } - } - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (pos + 1 > size()) ? npos - : condition(at_(pos)) - ? pos - : find_first_if_(pos + 1, std::forward(condition)); -#endif - } - - template - constexpr size_type find_last_if_(size_type pos, Condition&& condition) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (size() > 0) { - pos = min_(size() - 1, pos); - do { - if (condition(at_(pos))) { - return pos; - } - } while (pos-- > 0); - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (size() == 0) ? npos - : (pos >= size()) - ? find_last_if_(size() - 1, std::forward(condition)) - : condition(at_(pos)) ? pos - : (pos == 0) - ? npos - : find_last_if_(pos - 1, std::forward(condition)); -#endif - } - - constexpr bool equals_(basic_string_view rhs) const { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster than the - // recursive C++11 implementation below. - if (size() != rhs.size()) { - return false; - } - // memcmp would be faster than this loop, but memcmp isn't constexpr - for (typename basic_string_view::size_type pos = 0; pos < size(); - ++pos) { - if (at_(pos) != rhs.at_(pos)) { - return false; - } - } - return true; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (size() != rhs.size()) ? false - : (size() == 0) ? true - : (front() != rhs.front()) ? 
false - : (substr_(1).equals_(rhs.substr_(1))); -#endif - } - - struct charIsEqual_ final { - CharT expected; - constexpr bool operator()(CharT actual) const noexcept { - return expected == actual; - } - }; - - struct charIsNotEqual_ final { - CharT expected; - constexpr bool operator()(CharT actual) const noexcept { - return expected != actual; - } - }; - - struct stringViewContainsChar_ final { - basic_string_view expected; - constexpr bool operator()(CharT ch) const noexcept { - return npos != expected.find(ch); - } - }; - - struct stringViewDoesNotContainChar_ final { - basic_string_view expected; - constexpr bool operator()(CharT ch) const noexcept { - return npos == expected.find(ch); - } - }; - - const_pointer begin_; - size_type size_; -}; - -template -inline void swap( - basic_string_view& lhs, - basic_string_view& rhs) noexcept { - lhs.swap(rhs); -} - -} // namespace internal - -using string_view = internal::basic_string_view; +using std::string_view; } // namespace etensor } // namespace runtime diff --git a/runtime/executor/test/backend_integration_test.cpp b/runtime/executor/test/backend_integration_test.cpp index 9180d77aa3..bf9dc0033f 100644 --- a/runtime/executor/test/backend_integration_test.cpp +++ b/runtime/executor/test/backend_integration_test.cpp @@ -55,7 +55,7 @@ class StubBackend final : public BackendInterface { using InitFn = std::function( FreeableBuffer*, ArrayRef, - MemoryAllocator*)>; + BackendInitContext&)>; using ExecuteFn = std::function; using DestroyFn = std::function; @@ -83,8 +83,7 @@ class StubBackend final : public BackendInterface { FreeableBuffer* processed, ArrayRef compile_specs) const override { if (init_fn_) { - return init_fn_.value()( - processed, compile_specs, context.get_runtime_allocator()); + return init_fn_.value()(processed, compile_specs, context); } // Return a benign value otherwise. 
        return nullptr;
@@ -351,7 +350,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
   StubBackend::singleton().install_init(
       [&](FreeableBuffer* processed,
           ET_UNUSED ArrayRef<CompileSpec> compile_specs,
-          ET_UNUSED MemoryAllocator* runtime_allocator)
+          ET_UNUSED BackendInitContext& backend_init_context)
           -> Result<DelegateHandle*> {
         init_called = true;
         processed_data = processed->data();
@@ -395,7 +394,7 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
   StubBackend::singleton().install_init(
       [&](FreeableBuffer* processed,
           ET_UNUSED ArrayRef<CompileSpec> compile_specs,
-          ET_UNUSED MemoryAllocator* runtime_allocator)
+          ET_UNUSED BackendInitContext& backend_init_context)
           -> Result<DelegateHandle*> {
         init_processed = processed;
         return processed;
@@ -492,7 +491,7 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
   StubBackend::singleton().install_init(
       [&](FreeableBuffer* processed,
           ET_UNUSED ArrayRef<CompileSpec> compile_specs,
-          ET_UNUSED MemoryAllocator* runtime_allocator)
+          ET_UNUSED BackendInitContext& backend_init_context)
           -> Result<DelegateHandle*> {
         processed_data = processed->data();
         processed->Free();
@@ -606,7 +605,7 @@ TEST_P(DelegateDataAlignmentTest, ExpectedDataAlignment) {
   StubBackend::singleton().install_init(
       [&](FreeableBuffer* processed,
           ET_UNUSED ArrayRef<CompileSpec> compile_specs,
-          ET_UNUSED MemoryAllocator* runtime_allocator)
+          ET_UNUSED BackendInitContext& backend_init_context)
           -> Result<DelegateHandle*> {
         processed_data = processed->data();
         return nullptr;
diff --git a/shim/xplat/executorch/kernels/optimized/lib_defs.bzl b/shim/xplat/executorch/kernels/optimized/lib_defs.bzl
index 79ce6b02b3..bd3284c42a 100644
--- a/shim/xplat/executorch/kernels/optimized/lib_defs.bzl
+++ b/shim/xplat/executorch/kernels/optimized/lib_defs.bzl
@@ -16,16 +16,46 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 # functions in order to declare the required compiler flags needed in order to
 # access CPU vector intrinsics.
-def get_vec_android_preprocessor_flags():
-    preprocessor_flags = [
-        (
-            "^android-arm64.*$",
-            [
+# This copy from kernels/optimized/lib_defs.bzl is not necessary.
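+# get_vec_preprocessor_flags() and get_vec_deps() below pick the SLEEF compile
+# flag and deps per platform via ovr_config selects, and fall back to [] in OSS
+# builds where those configs are unavailable.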
+# This file really needs to be removed +def get_vec_preprocessor_flags(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + preprocessor_flags = select({ + "ovr_config//os:iphoneos": [ "-DET_BUILD_ARM_VEC256_WITH_SLEEF", - ], - ), - ] - return preprocessor_flags + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return preprocessor_flags + return [] + +def get_vec_deps(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + deps = select({ + "ovr_config//os:linux-x86_64": [ + "fbsource//third-party/sleef:sleef", + ] if not runtime.is_oss else [], + "ovr_config//os:iphoneos": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return deps + return [] def get_vec_cxx_preprocessor_flags(): preprocessor_flags = [ diff --git a/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl index c9fe4ec912..37a68abaa0 100644 --- a/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl @@ -9,7 +9,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/build:selects.bzl", "selects") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", ) def op_target(name, deps = []): @@ -98,7 +98,7 @@ def define_op_library(name, deps): deps = [ "//executorch/runtime/kernel:kernel_includes", ] + augmented_deps, - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + preprocessor_flags = get_vec_preprocessor_flags(), # sleef needs to be added as a direct dependency of the operator target when building for Android, # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of # dependencies are not transitive diff --git a/third-party/ao b/third-party/ao new file mode 160000 index 0000000000..75d06933aa --- /dev/null +++ b/third-party/ao @@ -0,0 +1 @@ +Subproject commit 75d06933aace9d1ce803158e52910e4c9fc60981
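The `decode_token` helper added to extension/llm/tokenizer/tokenizer.py above decodes a single token id, which is useful for streaming output. A minimal usage sketch follows, assuming the wrapper class in that module is named `Tokenizer` and takes a model path; the path and the `token_ids` iterable are placeholders, not part of the diff:

```python
from executorch.extension.llm.tokenizer.tokenizer import Tokenizer

# "tokenizer.model" is a placeholder path for this sketch.
tokenizer = Tokenizer("tokenizer.model")

def stream_decode(token_ids):
    # decode_token() maps one id to its text piece, so each piece can be
    # printed as soon as it is produced instead of after the full sequence.
    for token in token_ids:
        print(tokenizer.decode_token(token), end="", flush=True)
    print()
```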